Skip to content

Commit b54bb7b

Browse files
davmrebrianwa84
authored andcommitted
Replace deprecation of parameter_properties inheritance with a warning.
This also removes a spurious deprecation warning that appears when `JointDistribution`s are used in public colabs. (JDs have a multi-level hierarchy in which no class defines its own _parameter_properties, but since they all inherit the base class def that raises `NotImplementedError`, there's no inheritance problem). Justification for this change: there are legitimate 'quick-and-dirty' uses of parameter_properties inheritance, even if we wouldn't do it in TFP. For example, somewhere in the DeepMind silo is a Normal distribution that takes a log_scale rather than a scale: class NormalWithLogScale(tfd.Normal): def __init__(self, loc, log_scale): super().__init__(loc=loc, scale=tf.exp(log_scale)) Ignoring whether this is the best way to accomplish any particular goal, from a general Pythonic standpoint one might expect that this class would at least be basically functional. It may not support batch slicing or AutoCompositeTensor, but one should at least be able to call sample and log_prob, which means it should at least define properties like batch_shape. But now that batch shape depends on parameter_properties (as of cl/373590501), breaking parameter_properties inheritance means breaking batch_shape inheritance. To allow quick subclasses like this, I propose we simply warn when an inherited `parameter_properties` is called. In this example, the batch shape would be computed (correctly) using the base Normal parameters, just as if an explicit batch_shape method had been inherited. For full functionality including batch slicing, CompositeTensor, etc., a subclass would need to both (a) set self.parameters = dict(locals()) in its own constructor, and (b) define its own _parameter_properties. I've tried to articulate these requirements in the warning message. PiperOrigin-RevId: 374554087
1 parent c0e7515 commit b54bb7b

2 files changed

Lines changed: 61 additions & 39 deletions

File tree

tensorflow_probability/python/distributions/distribution.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import contextlib
2424
import functools
2525
import inspect
26+
import logging
2627
import types
2728

2829
from absl import logging
@@ -276,40 +277,61 @@ def __new__(mcs, classname, baseclasses, attrs):
276277
return super(_DistributionMeta, mcs).__new__(
277278
mcs, classname, baseclasses, attrs)
278279

279-
# Subclasses shouldn't inherit their parents' `_parameter_properties`,
280-
# since (in general) they'll have different parameters. Exceptions (for
281-
# convenience) are:
280+
# Warn when a subclass inherits `_parameter_properties` from its parent
281+
# (this is unsafe, since the subclass will in general have different
282+
# parameters). Exceptions are:
282283
# - Subclasses that don't define their own `__init__` (handled above by
283284
# the short-circuit when `default_init is None`).
284285
# - Subclasses that define a passthrough `__init__(self, *args, **kwargs)`.
285-
# - Direct children of `Distribution`, since the inherited method just
286-
# raises a NotImplementedError.
286+
# pylint: disable=protected-access
287287
init_argspec = tf_inspect.getfullargspec(default_init)
288288
if ('_parameter_properties' not in attrs
289-
and base != Distribution
290289
# Passthrough exception: may only take `self` and at least one of
291290
# `*args` and `**kwargs`.
292291
and (len(init_argspec.args) > 1
293292
or not (init_argspec.varargs or init_argspec.varkw))):
294-
# TODO(b/183457779) remove warning and raise `NotImplementedError`.
295-
attrs['_parameter_properties'] = deprecation.deprecated(
296-
date='2021-07-01',
297-
instructions="""
298-
Calling `_parameter_properties` on subclass {classname} that redefines the
299-
parent ({basename}) `__init__` is unsafe and will raise an error in the future.
300-
Please implement an explicit `_parameter_properties` for the subclass. If the
301-
subclass `__init__` takes the same parameters as the parent, you may use the
302-
placeholder implementation:
303293

304-
@classmethod
305-
def _parameter_properties(cls, dtype, num_classes=None):
306-
return {basename}._parameter_properties(
307-
dtype=dtype, num_classes=num_classes)
294+
@functools.wraps(base._parameter_properties)
295+
def wrapped_properties(*args, **kwargs): # pylint: disable=missing-docstring
296+
"""Wrapper to warn if `parameter_properties` is inherited."""
297+
properties = base._parameter_properties(*args, **kwargs)
298+
# Warn *after* calling the base method, so that we don't bother warning
299+
# if it just raised NotImplementedError anyway.
300+
logging.warning("""
301+
Distribution subclass %s inherits `_parameter_properties from its parent (%s)
302+
while also redefining `__init__`. The inherited annotations cover the following
303+
parameters: %s. It is likely that these do not match the subclass parameters.
304+
This may lead to errors when computing batch shapes, slicing into batch
305+
dimensions, calling `.copy()`, flattening the distribution as a CompositeTensor
306+
(e.g., when it is passed or returned from a `tf.function`), and possibly other
307+
cases. The recommended pattern for distribution subclasses is to define a new
308+
`_parameter_properties` method with the subclass parameters, and to store the
309+
corresponding parameter values as `self._parameters` in `__init__`, after
310+
calling the superclass constructor:
311+
312+
```
313+
class MySubclass(tfd.SomeDistribution):
314+
315+
def __init__(self, param_a, param_b):
316+
parameters = dict(locals())
317+
# ... do subclass initialization ...
318+
super(MySubclass, self).__init__(**base_class_params)
319+
# Ensure that the subclass (not base class) parameters are stored.
320+
self._parameters = parameters
321+
322+
def _parameter_properties(self, dtype, num_classes=None):
323+
return dict(
324+
# Annotations may optionally specify properties, such as `event_ndims`,
325+
# `default_constraining_bijector_fn`, `specifies_shape`, etc.; see
326+
# the `ParameterProperties` documentation for details.
327+
param_a=tfp.util.ParameterProperties(),
328+
param_b=tfp.util.ParameterProperties())
329+
```
330+
""", classname, base.__name__, str(properties.keys()))
331+
return properties
308332

309-
""".format(classname=classname,
310-
basename=base.__name__))(base._parameter_properties)
333+
attrs['_parameter_properties'] = wrapped_properties
311334

312-
# pylint: disable=protected-access
313335
# For a comparison of different methods for wrapping functions, see:
314336
# https://hynek.me/articles/decorators/
315337
@decorator.decorator
@@ -657,10 +679,7 @@ def _composite_tensor_shape_params(self):
657679
@classmethod
658680
def _parameter_properties(cls, dtype, num_classes=None):
659681
raise NotImplementedError(
660-
'_parameter_properties` is not implemented: {}. '
661-
'Note that subclasses that redefine `__init__` are not assumed to '
662-
'share parameters with their parent class and must provide a separate '
663-
'implementation.'.format(cls.__name__))
682+
'_parameter_properties` is not implemented: {}.'.format(cls.__name__))
664683

665684
@classmethod
666685
def parameter_properties(cls, dtype=tf.float32, num_classes=None):

tensorflow_probability/python/distributions/distribution_test.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import collections
2020
# Dependency imports
2121

22+
from absl import logging
2223
from absl.testing import parameterized
2324

2425
import numpy as np
@@ -29,8 +30,6 @@
2930
from tensorflow_probability.python.internal import test_util
3031

3132
from tensorflow.python.framework import test_util as tf_test_util # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
32-
from tensorflow.python.platform import test as tf_test # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
33-
from tensorflow.python.platform import tf_logging # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
3433

3534

3635
class TupleDistribution(tfd.Distribution):
@@ -617,10 +616,7 @@ def normal_differential_entropy(scale):
617616
self.evaluate(normal_differential_entropy(scale)),
618617
err=1e-5)
619618

620-
@test_util.jax_disable_test_missing_functionality('tf_logging')
621-
@tf_test.mock.patch.object(tf_logging, 'warning', autospec=True)
622-
def testParameterPropertiesNotInherited(self, mock_warning):
623-
# TODO(b/183457779) Test for NotImplementedError (rather than just warning).
619+
def testParameterPropertiesNotInherited(self):
624620

625621
# Subclasses that don't redefine __init__ can inherit properties.
626622
class NormalTrivialSubclass(tfd.Normal):
@@ -640,20 +636,27 @@ class MyDistribution(tfd.Distribution):
640636
def __init__(self, param1, param2):
641637
pass
642638

643-
NormalTrivialSubclass.parameter_properties()
644-
NormalWithPassThroughInit.parameter_properties()
645-
with self.assertRaises(NotImplementedError):
646-
MyDistribution.parameter_properties()
647-
self.assertEqual(0, mock_warning.call_count)
639+
with self.assertLogs(level=logging.WARNING) as log:
640+
NormalTrivialSubclass.parameter_properties()
641+
NormalWithPassThroughInit.parameter_properties()
642+
with self.assertRaises(NotImplementedError):
643+
MyDistribution.parameter_properties()
644+
with self.assertRaises(NotImplementedError):
645+
# Ensure that the unimplemented JD propertoes don't raise a warning.
646+
tfd.JointDistributionCoroutine.parameter_properties()
647+
logging.warning('assertLogs context requires at least one warning.')
648+
# Assert that no warnings occurred other than the dummy warning.
649+
self.assertLen(log.records, 1)
648650

649651
class NormalWithExtraParam(tfd.Normal):
650652

651653
def __init__(self, extra_param, *args, **kwargs):
652654
self._extra_param = extra_param
653655
super(NormalWithExtraParam, self).__init__(*args, **kwargs)
654656

655-
NormalWithExtraParam.parameter_properties()
656-
self.assertEqual(1, mock_warning.call_count)
657+
with self.assertLogs(level=logging.WARNING) as log:
658+
NormalWithExtraParam.parameter_properties()
659+
self.assertLen(log.records, 1)
657660

658661

659662
@test_util.test_all_tf_execution_regimes

0 commit comments

Comments
 (0)