在自搭架構的時候常常會出現:
warning "5 out of the last 5 calls to <function XXX> triggered tf.function retracing
原因歸納有以下幾種:
1. 在Loop裡面重複搭建模型,造成每loop一次,AutoGraph都會在session的cache創一個graph
2. Input內容沒轉為tf.Tensor類別,每一次input就算shape相同也會重新創建一個graph
3. Input shape不一致:多半發生在shape[0],只要input shape不是同數量的batch都會引發retrace。
針對上述問題有幾種解法可參考:
問題1.:把tf.function包在class類別裡,
例如 -
class pred_class():
def __init__(self, model):
self.model = model
@tf.function(experimental_relax_shapes = True)
def __call__(self, x):
v = self.model(x)
return v
在loop外創建class便可減少retrace problem
問題2.:強制轉換input = tf.convert_to_tensor(x.astype('float32'))
問題3.:調用 experimental_relax_shapes = True。
但要注意的是experimental_relax_shapes會在第三次不同shape才會啟用,參見
https://github.com/tensorflow/tensorflow/issues/35303#issuecomment-569057679
例如 -
fn(tf.constant([2]))
Output:
(1,)
<tf.Tensor: shape=(1,), dtype=int32, numpy=array([4], dtype=int32)>
------------
fn(tf.constant([2, 3]))
Output:
(2,)
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 9], dtype=int32)>
------------
Output:
(None,)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 4, 9, 16], dtype=int32)>
------------
Output:<tf.Tensor: shape=(4,), dtype=int32, numpy=array([ 4, 9, 16, 25], dtype=int32)>
留言
張貼留言