
[TF 2.0] allow tf.function input_signature to be specified by annotations #31579

@jeffpollock9


System information

  • TensorFlow version (you are using): 2.0.0-rc0
  • Are you willing to contribute it (Yes/No): Yes

Describe the feature and the current behavior/state.

tf.function has an input_signature argument, which I have been using to make my code a bit safer and to avoid repeatedly re-tracing functions. The input_signature specifies the tensor type for each of the function's arguments. I think it would be much nicer to specify these types using Python (>= 3.5) annotations, where a suitable version of Python is available. A very rough example:

import tensorflow as tf


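def function(fn):
    # Build the input_signature from the parameter annotations, in declaration
    # order, skipping any return annotation.
    annotations = fn.__annotations__
    input_signature = [spec for name, spec in annotations.items() if name != "return"]
    return tf.function(fn, autograph=False, input_signature=input_signature)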


@function
def foo(
    x: tf.TensorSpec(shape=[None], dtype=tf.float64),
    y: tf.TensorSpec(shape=[None], dtype=tf.float64),
):
    return x + 10.0 + y


vec32 = tf.random.normal([2], dtype=tf.float32)
vec64 = tf.random.normal([2], dtype=tf.float64)


# should pass
foo(vec64, vec64)
foo(y=vec64, x=vec64)

# should fail: vec32 is float32, but x is declared as float64
foo(vec32, vec64)

I think this is nicer than the current way of specifying the signature:

@tf.function(
    autograph=False,
    input_signature=[
        tf.TensorSpec(shape=[None], dtype=tf.float64),
        tf.TensorSpec(shape=[None], dtype=tf.float64),
    ],
)
def foo(x, y):
    return x + 10.0 + y

I think the main benefit of the annotation approach is that each argument's name and type sit beside each other, and the syntax is already widely used in Python.

To enable using annotations as the input_signature, I think tf.function should take an extra boolean argument, e.g. use_annotation_input_signature, which defaults to False.

Also note that I have set autograph=False in the example above to avoid a warning:

Cause: name 'foo_scope' is not defined

I am guessing a proper implementation inside tf.function would not have this problem.

Will this change the current api? How?

It would add one argument to tf.function; at its default value, nothing would change.
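As a rough illustration of how an opt-in flag like this could behave, here is a minimal sketch written as a wrapper around tf.function rather than a change inside it. The names `annotations_to_input_signature` and `function` are hypothetical and not part of TensorFlow, and the flag name is only the one suggested in this issue. TensorFlow is imported lazily so that the annotation-collecting helper can be tried on its own.

```python
import inspect


def annotations_to_input_signature(fn):
    """Return fn's parameter annotations as a list, in positional order.

    Hypothetical helper: the return annotation is ignored, and a parameter
    without an annotation is treated as an error.
    """
    specs = []
    for name, param in inspect.signature(fn).parameters.items():
        if param.annotation is inspect.Parameter.empty:
            raise TypeError(
                f"parameter {name!r} of {fn.__name__} has no annotation"
            )
        specs.append(param.annotation)
    return specs


def function(fn=None, *, use_annotation_input_signature=False, **tf_function_kwargs):
    """Hypothetical wrapper around tf.function with an opt-in flag.

    With the flag left at False, this simply forwards to tf.function.
    """
    import tensorflow as tf  # lazy import; only needed when decorating

    def decorate(f):
        kwargs = dict(tf_function_kwargs)
        if use_annotation_input_signature:
            # An explicitly passed input_signature still takes precedence.
            kwargs.setdefault("input_signature", annotations_to_input_signature(f))
        return tf.function(f, **kwargs)

    # Support both @function and @function(...).
    return decorate if fn is None else decorate(fn)
```

With this sketch, decorating a function with `@function(use_annotation_input_signature=True)` and annotating each parameter with a tf.TensorSpec would pass those specs to tf.function as its input_signature, while a plain `@function` would behave like `@tf.function`.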

Who will benefit with this feature?

Anyone using Python >= 3.5 who would like to specify the tensor types of their functions.

Any Other info.

None

Metadata


Labels

  • comp:autograph: Autograph related issues
  • stale: This label marks the issue/pr stale - to be closed automatically if no activity
  • stat:contribution welcome: Status - Contributions welcome
  • type:feature: Feature requests
