Skip to content

Commit 673ff89

Browse files
emilyfertigbrianwa84
authored andcommitted
Patch Numpy and JAX backends before converting DeferredTensor and TransformedVariable to CompositeTensor.
PiperOrigin-RevId: 373906496
1 parent e8d0b53 commit 673ff89

9 files changed

Lines changed: 54 additions & 12 deletions

File tree

tensorflow_probability/python/internal/backend/jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ FILENAMES = [
5353
"private",
5454
"random_generators",
5555
"raw_ops",
56+
"resource_variable_ops",
5657
"sets_lib",
5758
"sparse_lib",
5859
"tensor_array_ops",

tensorflow_probability/python/internal/backend/numpy/BUILD

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ py_library(
5252
":private",
5353
":random_generators",
5454
":raw_ops",
55+
":resource_variable_ops",
5556
":sets_lib",
5657
":sparse_lib",
5758
":static_rewrites",
@@ -332,6 +333,11 @@ py_library(
332333
],
333334
)
334335

336+
py_library(
337+
name = "resource_variable_ops",
338+
srcs = ["resource_variable_ops.py"],
339+
)
340+
335341
py_library(
336342
name = "sets_lib",
337343
srcs = ["sets_lib.py"],

tensorflow_probability/python/internal/backend/numpy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
from tensorflow_probability.python.internal.backend.numpy.numpy_array import * # pylint: disable=wildcard-import
4747
from tensorflow_probability.python.internal.backend.numpy.numpy_math import * # pylint: disable=wildcard-import
4848
from tensorflow_probability.python.internal.backend.numpy.ops import * # pylint: disable=wildcard-import
49+
from tensorflow_probability.python.internal.backend.numpy.type_spec import BatchableTypeSpec
50+
from tensorflow_probability.python.internal.backend.numpy.type_spec import TypeSpec
4951

5052

5153
Assert = debugging.Assert

tensorflow_probability/python/internal/backend/numpy/ops.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,7 @@
6161
'Module',
6262
'Tensor',
6363
'TensorSpec',
64-
'TypeSpec',
6564
'Variable',
66-
'VariableSpec',
6765
# 'gradients',
6866
]
6967

@@ -696,15 +694,10 @@ class Tensor(six.with_metaclass(_TensorMeta)):
696694

697695

698696
class TensorSpec(object):
699-
pass
700697

701-
702-
class TypeSpec(object):
703-
pass
704-
705-
706-
class VariableSpec(object):
707-
pass
698+
def __init__(self, *args, **kwargs):
699+
del args, kwargs
700+
self.dtype = None
708701

709702

710703
class Module(object):
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright 2021 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
"""Numpy stub for `resource_variable_ops`."""
16+
17+
__all__ = [
18+
'VariableSpec',
19+
]
20+
21+
22+
class VariableSpec(object):
23+
24+
def __init__(self, *args, **kwargs):
25+
del args, kwargs
26+
self.dtype = None

tensorflow_probability/python/internal/backend/numpy/tf_inspect.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
# Although `inspect` is different between Python 2 and 3, we should only ever
2828
# be using Python 3's inspect because JAX is Python 3 only and if TF is present
2929
# we will use `tf_inspect` which is compatible with both Python 2 and 3.
30+
Parameter = inspect.Parameter
3031
getfullargspec = inspect.getfullargspec
3132
getcallargs = inspect.getcallargs
3233
getframeinfo = inspect.getframeinfo
@@ -46,4 +47,5 @@
4647
ismethod = inspect.ismethod
4748
ismodule = inspect.ismodule
4849
isroutine = inspect.isroutine
50+
signature = inspect.signature
4951
stack = inspect.stack

tensorflow_probability/python/internal/backend/numpy/type_spec.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
__all__ = [
1818
'lookup',
19-
'register'
19+
'register',
20+
'BatchableTypeSpec',
21+
'TypeSpec',
2022
]
2123

2224

@@ -30,3 +32,11 @@ def decorator_fn(cls):
3032
def lookup(_):
3133
# Raise ValueError instead of NotImplementedError to conform to TF.
3234
raise ValueError('`TypeSpec`s are not registered in Numpy/JAX.')
35+
36+
37+
class TypeSpec(object):
38+
pass
39+
40+
41+
class BatchableTypeSpec(TypeSpec):
42+
pass

tensorflow_probability/python/internal/backend/numpy/v2.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
from tensorflow_probability.python.internal.backend.numpy.numpy_math import * # pylint: disable=wildcard-import
5252
from tensorflow_probability.python.internal.backend.numpy.ops import * # pylint: disable=wildcard-import
5353
from tensorflow_probability.python.internal.backend.numpy.tensor_array_ops import TensorArray
54+
from tensorflow_probability.python.internal.backend.numpy.type_spec import BatchableTypeSpec
55+
from tensorflow_probability.python.internal.backend.numpy.type_spec import TypeSpec
5456
# pylint: enable=unused-import
5557

5658

tensorflow_probability/substrates/meta/rewrite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
('from tensorflow.python.ops import '
6262
'resource_variable_ops'):
6363
('from tensorflow_probability.python.internal.backend.numpy '
64-
'import ops'),
64+
'import resource_variable_ops'),
6565
'from tensorflow.python.util import':
6666
'from tensorflow_probability.python.internal.backend.numpy import',
6767
'from tensorflow.python.util.all_util':

0 commit comments

Comments
 (0)