Creates a callable TensorFlow graph from a Python function.
tf.function(func=None, input_signature=None, autograph=True, experimental_autograph_options=None, experimental_relax_shapes=False)
`function` constructs a callable that executes a TensorFlow graph (`tf.Graph`) created by tracing the TensorFlow operations in `func`. This allows the TensorFlow runtime to apply optimizations and exploit parallelism in the computation defined by `func`.
import tensorflow as tf

def f(x, y):
  return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)

g = tf.function(f)

x = tf.constant([[2.0, 3.0]])
y = tf.constant([[3.0, -2.0]])

# `f` and `g` will return the same value, but `g` will be executed as a
# TensorFlow graph.
assert f(x, y).numpy() == g(x, y).numpy()

# Tensors and tf.Variables used by the Python function are captured in the
# graph.
@tf.function
def h():
  return f(x, y)

assert (h().numpy() == f(x, y).numpy()).all()

# Data-dependent control flow is also captured in the graph. Supported
# control flow statements include `if`, `for`, `while`, `break`, `continue`,
# `return`.
@tf.function
def g(x):
  if tf.reduce_sum(x) > 0:
    return x * x
  else:
    return -x // 2

# tf.print and TensorFlow side effects are supported, but exercise caution when
# using Python side effects like mutating objects, saving to files, etc.
l = []
v = tf.Variable(0.0)  # variable updated inside the loop below

@tf.function
def g(x):
  for i in x:
    tf.print(i)                                                  # Works
    v.assign(i)                                                  # Works
    tf.py_function(func=lambda t: l.append(t.numpy()), inp=[i], Tout=[])  # Works
    l.append(i)                                                  # Caution! Doesn't work.
Note that unlike other TensorFlow operations, we don't convert Python numerical inputs to tensors. Moreover, a new graph is generated for each distinct Python numerical value: for example, calling `g(2)` and `g(3)` will generate two new graphs, while only one is generated if you call `g(tf.constant(2))` and `g(tf.constant(3))`. Therefore, Python numerical inputs should be restricted to arguments that will have few distinct values, such as hyperparameters like the number of layers in a neural network. This allows TensorFlow to optimize each variant of the neural network.
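To see the retracing rule for Python numbers in action, the sketch below counts traces with a plain Python list that is appended to inside the function body. Because ordinary Python code runs only while `tf.function` traces, each append marks one new graph. It assumes TensorFlow 2.x with eager execution; `g` and `traces` are illustrative names, and the trace count for Python scalars reflects current behavior rather than a documented guarantee.

```python
import tensorflow as tf

traces = []  # illustrative trace log: plain Python runs only while tracing

@tf.function
def g(x):
    traces.append(x)  # executes once per trace, not once per call
    return x * 2

g(2)
g(3)  # a new Python int value triggers a new trace
python_traces = len(traces)

g(tf.constant(2))
g(tf.constant(3))  # same dtype and shape: the existing trace is reused
tensor_traces = len(traces) - python_traces

print(python_traces, tensor_traces)
```

Passing tensors keeps the number of graphs tied to dtypes and shapes rather than to individual values.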
The Python function `func` may reference stateful objects (such as `tf.Variable`). These are captured as implicit inputs to the callable returned by `function`. For example:
c = tf.Variable(0)

@tf.function
def f(x):
  c.assign_add(1)
  return x + tf.cast(c, tf.float32)

assert int(c) == 0
assert f(1.0) == 2.0
assert int(c) == 1
assert f(1.0) == 3.0
assert int(c) == 2
`tf.function` can be applied to methods of an object. For example:
class Dense(object):
  def __init__(self):
    self.W = tf.Variable(tf.keras.initializers.GlorotUniform()((10, 10)))
    self.b = tf.Variable(tf.zeros(10))

  @tf.function
  def compute(self, x):
    return tf.matmul(x, self.W) + self.b

d1 = Dense()
d2 = Dense()
x = tf.random.uniform((10, 10))
# d1 and d2 are using distinct variables
assert not (d1.compute(x).numpy() == d2.compute(x).numpy()).all()
The `call` methods of a `tf.keras.Model` subclass can be decorated with `tf.function` in order to apply graph execution optimizations on it.
class MyModel(tf.keras.Model):

  def __init__(self, keep_probability=0.2):
    super().__init__()
    self.dense1 = tf.keras.layers.Dense(4)
    self.dense2 = tf.keras.layers.Dense(5)
    self.keep_probability = keep_probability

  @tf.function
  def call(self, inputs, training=True):
    y = self.dense2(self.dense1(inputs))
    if training:
      # tf.nn.dropout expects the drop rate, i.e. 1 - keep probability
      return tf.nn.dropout(y, rate=1 - self.keep_probability)
    else:
      return y

x = tf.random.uniform((10, 10))
model = MyModel()
model(x, training=True)   # executes a graph, with dropout
model(x, training=False)  # executes a graph, without dropout
`function` instantiates a separate graph for every unique set of input shapes and datatypes. For example, the following code snippet will result in three distinct graphs being traced, as each input has a different shape.
@tf.function
def f(x):
  return tf.add(x, 1.)

scalar = tf.constant(1.0)
vector = tf.constant([1.0, 1.0])
matrix = tf.constant([[3.0]])

f(scalar)
f(vector)
f(matrix)
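One way to observe the three traces, and to confirm that a repeated shape reuses an existing graph, is to record the traced shape from inside the function. This is a minimal sketch assuming TensorFlow 2.x; `shapes_traced` is an illustrative name, filled only while tracing.

```python
import tensorflow as tf

shapes_traced = []  # illustrative log, appended to only during tracing

@tf.function
def f(x):
    shapes_traced.append(x.shape.as_list())  # static shape seen at trace time
    return tf.add(x, 1.)

f(tf.constant(1.0))         # scalar: first trace
f(tf.constant([1.0, 1.0]))  # vector: second trace
f(tf.constant([[3.0]]))     # matrix: third trace
f(tf.constant(5.0))         # scalar again: reuses the first graph

print(shapes_traced)
```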
An "input signature" can be optionally provided to `tf.function` to control the graphs traced. The input signature specifies the shape and type of each `Tensor` argument to the function using a `tf.TensorSpec` object. For example, the following code snippet ensures that a single graph is created where the input `Tensor` is required to be a floating point tensor with no restrictions on its shape.
@tf.function(
    input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def f(x):
  return tf.add(x, 1.)
When an `input_signature` is specified, the callable will convert the inputs to the specified `TensorSpec`s.
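The sketch below, assuming TensorFlow 2.x, shows the three effects of an input signature: inputs of any shape share one graph, a nested Python list is converted to a `float32` tensor, and an `int32` tensor is rejected instead of triggering a new trace. The names `traces`, `converted`, and `rejected` are illustrative. Depending on the TensorFlow version the mismatch surfaces as `TypeError` or `ValueError`, so both are caught.

```python
import tensorflow as tf

traces = []  # illustrative trace log

@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def f(x):
    traces.append(1)  # runs only when a new graph is traced
    return tf.add(x, 1.)

f(tf.constant(1.0))         # scalar
f(tf.constant([1.0, 2.0]))  # vector: same graph, since shape=None
converted = f([[1.0], [2.0]]).numpy().tolist()  # list converted to float32

try:
    f(tf.constant([1, 2]))  # int32 does not match the float32 spec
    rejected = False
except (TypeError, ValueError):
    rejected = True

print(len(traces), converted, rejected)
```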
Tracing and staging
When `autograph` is `True`, all Python control flow that depends on `Tensor` values is staged into a TensorFlow graph. When `autograph` is `False`, the function is traced and control flow is not allowed to depend on data.
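The difference between the two `autograph` settings can be sketched with a Python `if` whose condition is a tensor. This assumes TensorFlow 2.x and that the code runs from a source file (AutoGraph reads the function's source); `sign_aware` is an illustrative name. With AutoGraph disabled, tracing tries to use a symbolic tensor as a Python `bool`, which TensorFlow reports as a `TypeError` subclass.

```python
import tensorflow as tf

def sign_aware(x):
    # A Python `if` whose condition is a Tensor: data-dependent control flow.
    if tf.reduce_sum(x) > 0:
        return x * x
    else:
        return -x

with_autograph = tf.function(sign_aware)  # autograph=True is the default
without_autograph = tf.function(sign_aware, autograph=False)

# AutoGraph rewrites the `if` into a graph conditional, so this works.
squared = with_autograph(tf.constant([1.0, 2.0])).numpy().tolist()

# Without AutoGraph, tracing evaluates bool(tensor) and fails.
try:
    without_autograph(tf.constant([1.0, 2.0]))
    rejected = False
except TypeError:
    rejected = True

print(squared, rejected)
```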
Note that `function` only stages TensorFlow operations: all Python code that `func` executes and that does not depend on data will shape the construction of the graph. For example, consider the following:
import numpy as np
import tensorflow as tf

def add_noise():
  return tf.eye(5) + np.random.randn(5, 5)

traced = tf.function(add_noise)
`add_noise()` will return a different output every time it is invoked. However, `traced()` will return the same value every time it is called, since a particular random value generated by the `np.random.randn` call will be inserted in the traced/staged TensorFlow graph as a constant. In this particular example, replacing `np.random.randn(5, 5)` with `tf.random.normal((5, 5))` will result in the same behavior for `add_noise()` and `traced()`.
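The two behaviors can be compared directly. In this sketch (assuming TensorFlow 2.x; `add_tf_noise` is an illustrative name), the NumPy draw is cast to `float32` only to keep dtypes explicit. The NumPy version repeats its output across calls because the random matrix was frozen into the graph at trace time, while the `tf.random.normal` version produces fresh values on every call.

```python
import numpy as np
import tensorflow as tf

def add_noise():
    # NumPy runs once, during tracing; its result becomes a graph constant.
    return tf.eye(5) + np.random.randn(5, 5).astype(np.float32)

def add_tf_noise():
    # tf.random.normal is a graph op, so it executes on every call.
    return tf.eye(5) + tf.random.normal((5, 5))

traced = tf.function(add_noise)
traced_tf = tf.function(add_tf_noise)

numpy_repeats = np.array_equal(traced().numpy(), traced().numpy())
tf_repeats = np.array_equal(traced_tf().numpy(), traced_tf().numpy())
print(numpy_repeats, tf_repeats)
```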
A corollary of the previous discussion on tracing is the following: if a Python function `func` has Python side effects, then executing `func` multiple times may not be semantically equivalent to executing `F = tf.function(func)` multiple times. This difference is due to the fact that `function` only captures the subgraph of TensorFlow operations that is constructed when `func` is invoked to trace a graph.
The same is true if code with Python side effects is used inside control flow, such as a loop. If your code uses side effects that are not intended to control graph construction, wrap them inside `tf.compat.v1.py_func`.
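The sketch below contrasts a plain Python side effect with one routed through `tf.py_function`, the TensorFlow 2.x op for running Python code at execution time. It assumes TensorFlow 2.x; `python_log`, `runtime_log`, `record`, and `step` are illustrative names. The plain `append` runs only during the single trace, whereas the wrapped call runs on every invocation. Feeding the `tf.py_function` result into the return value keeps the op on the graph's execution path.

```python
import tensorflow as tf

python_log = []   # mutated by plain Python: runs only while tracing
runtime_log = []  # mutated through tf.py_function: runs on every call

def record(value):
    # Receives an eager tensor at execution time.
    runtime_log.append(float(value.numpy()))
    return value

@tf.function
def step(x):
    python_log.append("traced")
    y = tf.py_function(func=record, inp=[x], Tout=tf.float32)
    return y + 1.0

step(tf.constant(1.0))
step(tf.constant(2.0))  # same dtype and shape: no retrace
print(python_log, runtime_log)
```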
A single tf.function object might need to map to multiple computation graphs under the hood. This should be visible only as performance (tracing graphs has a nonzero computational and memory cost) but should not affect the correctness of the program. A traced function should return the same result as it would when run eagerly, assuming no unintended Python side-effects.
Calling a `tf.function` with tensor arguments of different dtypes should lead to at least one computational graph per distinct set of dtypes. Conversely, always calling a `tf.function` with tensor arguments of the same shapes and dtypes and the same non-tensor arguments should not lead to additional retracings of your function.
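A minimal illustration of the dtype rule, assuming TensorFlow 2.x: each new dtype produces one trace, and repeating an already-seen dtype and shape reuses the existing graph. `double` and `dtypes_traced` are illustrative names.

```python
import tensorflow as tf

dtypes_traced = []  # illustrative log, appended to only while tracing

@tf.function
def double(x):
    dtypes_traced.append(x.dtype.name)
    return x * 2

double(tf.constant(1.0))  # float32: new trace
double(tf.constant(2.0))  # float32 again: reused
double(tf.constant(1))    # int32: new trace
double(tf.constant(2))    # int32 again: reused

print(dtypes_traced)
```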
Other than that, TensorFlow reserves the right to retrace functions as many times as needed, to ensure that traced functions behave as they would when run eagerly and to provide the best end-to-end performance. For example, the number of traces TensorFlow performs when the function is repeatedly called with different Python scalars as arguments is left undefined to allow for future optimizations.
To control the tracing behavior, use the following tools:
- different `tf.function` objects are guaranteed to not share traces; and
- specifying a signature or using concrete function objects returned from `get_concrete_function()` guarantees that only one function graph will be built.
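Both tools can be sketched together, assuming TensorFlow 2.x; `add_one`, `first`, `second`, and `log` are illustrative names. Two `tf.function` wrappers of the same Python function each trace on their own, and a concrete function obtained with `get_concrete_function` is a single graph that serves every input matching its `tf.TensorSpec` without further tracing.

```python
import tensorflow as tf

log = []  # illustrative trace log

def add_one(x):
    log.append("trace")  # plain Python: runs only while tracing
    return x + 1.0

first = tf.function(add_one)
second = tf.function(add_one)  # separate object, separate traces

first(tf.constant(1.0))   # trace 1
second(tf.constant(1.0))  # trace 2: traces are not shared with `first`

# One graph for any 1-D float32 input.
concrete = first.get_concrete_function(
    tf.TensorSpec(shape=[None], dtype=tf.float32))  # trace 3

out_a = concrete(tf.constant([1.0, 2.0])).numpy().tolist()
out_b = concrete(tf.constant([5.0])).numpy().tolist()  # no new trace
print(len(log), out_a, out_b)
```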
Args:
func: function to be compiled. If `func` is None, returns a decorator that can be invoked with a single argument, `func`. The end result is equivalent to providing all the arguments up front. In other words, `tf.function(input_signature=...)(func)` is equivalent to `tf.function(func, input_signature=...)`. The former can be used to decorate Python functions, for example:
    @tf.function(input_signature=...)
    def foo(...): ...
input_signature: A possibly nested sequence of `tf.TensorSpec` objects specifying the shapes and dtypes of the Tensors that will be supplied to this function. If `None`, a separate function is instantiated for each inferred input signature. If `input_signature` is specified, every input to `func` must be a `Tensor`.
autograph: Whether AutoGraph should be applied on `func` before tracing a graph. This allows for dynamic control flow (Python `if` statements, loops, etc.) in the traced graph. See the AutoGraph documentation for more information.
experimental_autograph_options: Experimental knobs (in the form of a tuple of `tensorflow.autograph.Feature` values) to control behavior when `autograph=True`.
experimental_relax_shapes: When true, argument shapes may be relaxed to avoid unnecessary retracing.
Returns:
If `func` is not None, returns a callable that will execute the compiled function (and return zero or more `tf.Tensor` objects). If `func` is None, returns a decorator that, when invoked with a single `func` argument, returns a callable equivalent to the case above.
Raises:
TypeError: If `input_signature` is neither `None` nor a sequence of `tf.TensorSpec` objects.
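The equivalence described for the `func` argument can be checked with a short sketch, assuming TensorFlow 2.x; `square`, `spec`, and the `square_*` names are illustrative. Passing the function directly, calling the returned decorator on it, and using decorator syntax all yield callables with the same input signature and results.

```python
import tensorflow as tf

spec = [tf.TensorSpec(shape=[None], dtype=tf.float32)]

def square(x):
    return x * x

# Form 1: pass `func` together with the keyword arguments.
square_a = tf.function(square, input_signature=spec)

# Form 2: omit `func`; tf.function returns a decorator, applied explicitly.
square_b = tf.function(input_signature=spec)(square)

# Form 3: the same decorator applied with @ syntax.
@tf.function(input_signature=spec)
def square_c(x):
    return x * x

x = tf.constant([1.0, 2.0, 3.0])
a = square_a(x).numpy().tolist()
b = square_b(x).numpy().tolist()
c = square_c(x).numpy().tolist()
print(a, b, c)
```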