TensorFlow tf.function retracing warning


When building a custom architecture, you will often see this warning:

WARNING: 5 out of the last 5 calls to <function XXX> triggered tf.function retracing.


The causes generally fall into the following categories:

1. The model is rebuilt inside a loop, so on every iteration AutoGraph traces again and adds yet another graph to its cache.

2. The input is not converted to a tf.Tensor, so a new graph may be created on every call even when the input shape is the same.

3. The input shape is inconsistent: this usually happens along shape[0], and any call whose batch size differs from the ones already traced triggers a retrace.


Here are some fixes for each of these problems:


Problem 1: wrap the tf.function inside a class.

For example:

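import tensorflow as tf

class pred_class():

    def __init__(self, model):
        # Keep a reference to the already-built model
        self.model = model

    # The tf.function is created once per instance, so repeated calls
    # reuse the traced graph instead of building a new one each time
    @tf.function(experimental_relax_shapes=True)
    def __call__(self, x):
        v = self.model(x)
        return v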

Create the class instance outside the loop, which reduces retracing.
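To see why the placement matters, here is a minimal sketch that counts traces in both patterns. TinyModel, PredClass, predict, and counts are hypothetical names made up for this illustration (PredClass is pred_class with a trace counter added). The counters work because Python side effects inside a tf.function run only while it is being traced, not on every call.

```python
import tensorflow as tf

class TinyModel(tf.Module):
    """Hypothetical stand-in for a real model: one 4x1 weight matrix."""
    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.ones([4, 1]))

    def __call__(self, x):
        return tf.matmul(x, self.w)

class PredClass:
    """Like pred_class, plus a counter of how many times it was traced."""
    def __init__(self, model):
        self.model = model
        self.trace_count = 0

    @tf.function
    def __call__(self, x):
        self.trace_count += 1  # Python side effect: runs only while tracing
        return self.model(x)

model = TinyModel()
batch = tf.ones([8, 4])
counts = {"in_loop": 0}

# Anti-pattern: a brand-new tf.function each iteration, so every call traces.
for _ in range(5):
    @tf.function
    def predict(x):
        counts["in_loop"] += 1  # runs only while tracing
        return model(x)
    predict(batch)

# Fix: build the wrapper once, outside the loop, and reuse its traced graph.
pred = PredClass(model)
for _ in range(5):
    pred(batch)

print(counts["in_loop"], pred.trace_count)  # 5 1
```

Each tf.function object keeps its own cache of traced graphs, so redefining it inside the loop throws that cache away on every iteration.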


Problem 2: explicitly convert the input to a tensor, e.g. input = tf.convert_to_tensor(x.astype('float32'))
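A minimal sketch of the difference, assuming a toy function that doubles its input (scale_python, scale_tensor, and traces are hypothetical names). Python floats become part of the cache key by value, so each new value triggers a trace, while float32 tensors with the same shape and dtype all match one traced graph.

```python
import numpy as np
import tensorflow as tf

traces = {"python": 0, "tensor": 0}

@tf.function
def scale_python(x):
    traces["python"] += 1  # runs only while tracing
    return x * 2.0

@tf.function
def scale_tensor(x):
    traces["tensor"] += 1  # runs only while tracing
    return x * 2.0

for v in [1.0, 2.0, 3.0]:
    # Python float: each distinct value is a new cache key, so it retraces
    scale_python(v)
    # float32 tensor of shape (1,): every call matches the same graph
    x = np.array([v])
    scale_tensor(tf.convert_to_tensor(x.astype('float32')))

print(traces)  # {'python': 3, 'tensor': 1}
```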


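Problem 3: pass experimental_relax_shapes=True to tf.function. (In TensorFlow 2.9 and later this argument is deprecated in favor of reduce_retracing=True.)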
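The sketch below, assuming a toy per-row sum (strict_fn, relaxed_fn, and trace_counts are hypothetical names), feeds four batch sizes to two otherwise identical functions and counts traces. It uses reduce_retracing=True, the newer name for experimental_relax_shapes=True; on older TensorFlow versions swap in the old argument. The strict function should trace once per batch size, while the relaxed one should trace fewer times, though exactly how many depends on the TensorFlow version.

```python
import tensorflow as tf

trace_counts = {"strict": 0, "relaxed": 0}

@tf.function
def strict_fn(x):
    trace_counts["strict"] += 1  # counts traces, not calls
    return tf.reduce_sum(x, axis=1)

# reduce_retracing=True replaces experimental_relax_shapes=True in TF >= 2.9
@tf.function(reduce_retracing=True)
def relaxed_fn(x):
    trace_counts["relaxed"] += 1  # counts traces, not calls
    return tf.reduce_sum(x, axis=1)

# Only shape[0] (the batch size) changes, e.g. a smaller final batch
for batch_size in [32, 16, 8, 5]:
    x = tf.ones([batch_size, 4])
    strict_fn(x)
    relaxed_fn(x)

print(trace_counts)
```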

Note, however, that experimental_relax_shapes only kicks in on the third call with a different shape; see

https://github.com/tensorflow/tensorflow/issues/35303#issuecomment-569057679

"The logic that experimental_relax_shapes uses is a bit convoluted and hard to follow, but roughly it will only attempt to relax shapes on the third call. In general, if you enable it you should expect the function to be still called a few times and the shapes of tensors to be anything."
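The post does not show how fn is defined. A version consistent with the printed output below might look like this; it is a reconstruction, so treat the body as an assumption. Because print() runs only while tracing, each printed shape marks a trace. Newer TensorFlow releases may relax the shape earlier than the version discussed in the GitHub comment, so the printed shapes can differ from the output shown.

```python
import tensorflow as tf

# Hypothetical reconstruction of fn, inferred from the printed output.
# In TF >= 2.9, experimental_relax_shapes is a deprecated alias for
# reduce_retracing and may emit a deprecation warning.
@tf.function(experimental_relax_shapes=True)
def fn(x):
    print(x.shape)  # Python print: only runs while tracing
    return x * x

fn(tf.constant([2]))
fn(tf.constant([2, 3]))
fn(tf.constant([2, 3, 4]))
fn(tf.constant([2, 3, 4, 5]))
```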

For example:

fn(tf.constant([2]))

Output:

(1,)

<tf.Tensor: shape=(1,), dtype=int32, numpy=array([4], dtype=int32)>

------------

fn(tf.constant([2, 3]))

Output:

(2,)

<tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 9], dtype=int32)>

------------

fn(tf.constant([2, 3, 4]))

Output:

(None,)

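<tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 4, 9, 16], dtype=int32)>

On this third distinct shape the function is traced once more, and the printed shape is (None,): the new graph accepts a 1-D input of any length.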

------------

fn(tf.constant([2, 3, 4, 5]))

Output:

<tf.Tensor: shape=(4,), dtype=int32, numpy=array([ 4, 9, 16, 25], dtype=int32)>

No shape is printed this time: the relaxed (None,) graph is reused, so no retrace occurs.
