Full static typing for torch.distributions #144219
randolf-scholz wants to merge 27 commits into pytorch:main
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144219
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures, 1 Unrelated Failure as of commit 2a8e2ec with merge base aec3ef1.
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "module: typing"

@pytorchbot label "module: distributions"

@pytorchbot label "release notes: python_frontend"
randolf-scholz
left a comment
More accurate would be using some sort of polymorphic mapping like
class _KL_REGISTRY_TYPE(Protocol):
    def __iter__(self) -> Iterator[tuple[type[Distribution], type[Distribution]]]: ...
    def __getitem__(self, key: tuple[type[P], type[Q]], /) -> _KL[P, Q]: ...
    def __setitem__(self, key: tuple[type[P], type[Q]], value: _KL[P, Q], /) -> None: ...
    def __delitem__(self, key: tuple[type[Distribution], type[Distribution]], /) -> None: ...
    def clear(self) -> None: ...

but likely this overcomplicates things unnecessarily.
@randolf-scholz Still interested in merging this?

@Skylion007 Yes, this was quite a bit of work, and it would be a shame if it went to waste... As I wrote in my last comment and in the OP, there are a few remaining open problems that mostly stem from LSP violations. It would be good to get some feedback on these. Also, @malfet suggested in the other PR (the one adding …)
Force-pushed fbd5731 to 1ffc7c1
@Skylion007 I rebased onto main and squashed some of my commits.
torch/distributions/constraints.py (Outdated)
Why not ParamSpec here? T isn't used elsewhere?
ParamSpec does not make sense here, since this is a property decorator, and properties are usually not supposed to take arguments. (Note that T is actually the type the property gets bound to.)
Skylion007
left a comment
I have some nits, but this is definitely an improvement over what was there before.
torch/distributions/kl.py (Outdated)
We want these types to be public?
The codebase seems to be inconsistent in that respect, but more often than not uses private names. Personally, I strongly prefer non-private, because it makes the hints that show up, for instance with Pylance, more readable. Moreover, once support for 3.11 is dropped and PEP 695 is used, there is really no reason anymore to use an underscore prefix.
But for the purposes of this PR, I am fine with changing it if necessary.
We're the ones setting a lot of the standards for type naming!
I personally like the _.
But "The names, they convey nothing!" :-D
DistributionT1 = TypeVar("DistributionT1", bound=Distribution)
DistributionT2 = TypeVar("DistributionT2", bound=Distribution)
DistributionT3 = TypeVar("DistributionT3", bound=Distribution)
DistributionT4 = TypeVar("DistributionT4", bound=Distribution)
and
DistributionBinaryFunc: Callable[[DistributionT1, DistributionT2], Tensor]
torch/distributions/kl.py (Outdated)
Hold on: when are P and Q different from P2, Q2? This feels like the PERFECT place for a static type check that the function you are decorating matches.
That would be nice, but I do not see how to do it currently; I think it requires HKTs. That's because below there are cases where the same function gets decorated multiple times, like
@register_kl(Normal, Beta)
@register_kl(Normal, ContinuousBernoulli)
@register_kl(Normal, Exponential)
@register_kl(Normal, Gamma)
@register_kl(Normal, Pareto)
@register_kl(Normal, Uniform)
def _kl_normal_infinity(
p: Normal, q: Union[Beta, ContinuousBernoulli, Exponential, Gamma, Pareto, Uniform]
) -> Tensor:
    return _infinite_like(p.loc)

So, the first decorator must return Callable[[Normal, Union[Beta, ContinuousBernoulli, Exponential, Gamma, Pareto, Uniform]], Tensor], and not just Callable[[Normal, Uniform], Tensor], otherwise the next decorator will cause a type error.
What would be ideal would be something like
class _KL_Decorator[P, Q](Protocol):
def __call__[P2 :> P, Q2 :> Q](self, arg: _KL[P2, Q2], /) -> _KL[P2, Q2]: ...
def register_kl(type_p: type[P], type_q: type[Q]) -> _KL_Decorator[P, Q]: ...

So that for instance the first @register_kl(Normal, Uniform) would produce a
class _KL_Decorator[Normal, Uniform](Protocol):
    def __call__[P2: Normal, Q2: Uniform](self, arg: _KL[P2, Q2], /) -> _KL[P2, Q2]: ...

Because then when this gets applied to Callable[[Normal, Uniform | Pareto | ...], Tensor], it gives back Callable[[Normal, Uniform | Pareto | ...], Tensor].
But this requires HKTs, which are currently not available in the python typing system.
Actually, with mypy==1.15 similar issues crop up in the constraint_registry.py file; it seems one also needs to loosen the Factory type hint in a similar manner.
You might be able to use a TypeVarTuple for *args here.
Possibly, but all this function does is call TensorBase.new which currently has an *args: Any overload. Really what would be needed here is the ability to reference the signature of another function, which is currently not a feature supported by the type system.
Force-pushed 1b65b97 to e625cb4
+ removed '# mypy: allow-untyped-defs' comments
Force-pushed e625cb4 to 73f160e
cat = _Cat
stack = _Stack

# Type aliases.
These are now all reexported and demand docstrings, I think?
@skylion: Sigh, the delta is greater than 2k lines, and this makes the "sanity check" test fail. It's generally pretty easy to split typing pull requests. Starting a full review now, don't split before that! 🙂
rec
left a comment
Whew, that's a big one!
I read every line though.
Thanks for doing this all!
def build_constraint(
    constraint_fn: Union[C, type[C]],
    args: tuple,
What, we can just do that as a synonym for tuple[Any, ...]? I had gotten the impression that this wouldn't work with mypy?
Also, this is a test; is it even being type checked at all?
Currently, the lintrunner is ignoring these, but I do check them locally because the runtime code helps debugging the annotations.
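On the bare `tuple` question above: in annotations, plain `tuple` is indeed treated by type checkers as `tuple[Any, ...]` (any length, any element types), so a function like this runs and type-checks:

```python
def describe(args: tuple) -> str:
    # `args: tuple` is read by mypy as `args: tuple[Any, ...]`.
    return f"{len(args)} element(s)"
```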
$ mypy torch/distributions/ test/distributions/ --warn-unused-ignores
test/distributions/test_distributions.py:3677:19: error: Argument 1 to "cdf" of "TransformedDistribution" has incompatible type "float"; expected "Tensor" [arg-type]
test/distributions/test_distributions.py:3687:19: error: Argument 1 to "cdf" of "TransformedDistribution" has incompatible type "float"; expected "Tensor" [arg-type]
test/distributions/test_distributions.py:3697:19: error: Argument 1 to "cdf" of "TransformedDistribution" has incompatible type "float"; expected "Tensor" [arg-type]
test/distributions/test_distributions.py:3707:19: error: Argument 1 to "cdf" of "TransformedDistribution" has incompatible type "float"; expected "Tensor" [arg-type]
test/distributions/test_distributions.py:5225:41: error: Argument 1 to "log_prob" of "Gamma" has incompatible type "int"; expected "Tensor" [arg-type]
test/distributions/test_distributions.py:5251:40: error: Argument 1 to "log_prob" of "Gamma" has incompatible type "int"; expected "Tensor" [arg-type]
Found 6 errors in 1 file (checked 52 source files)
| """ | ||
| Creates a pair of distributions `Dist` initialized to test each element of | ||
| param with each other. | ||
| """ | ||
| params1 = [torch.tensor([p] * len(p)) for p in params] | ||
| params2 = [p.transpose(0, 1) for p in params1] | ||
| return Dist(*params1), Dist(*params2) | ||
| return Dist(*params1), Dist(*params2) # type: ignore[arg-type] |
Wait, why does this fail? It should understand that both params1 and params2 are Sequence[Tensor] and not have an issue?
mypy infers D as Distribution, and Distribution.__init__ only expects batch_shape, event_shape and validate_args.
What we could do is make these arguments keyword-only in Distribution.__init__; then the error goes away. Probably a good idea from a design POV, but backward incompatible!
But I think maybe this is better for a follow-up PR.
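The keyword-only idea floated above can be sketched with a stand-in class (hypothetical defaults; the real Distribution.__init__ takes these positionally today, which is why the change would be BC breaking):

```python
class Distribution:
    # Keyword-only base arguments: subclass positional parameters can no
    # longer be mistaken for batch_shape/event_shape by type inference.
    def __init__(self, *, batch_shape=(), event_shape=(), validate_args=None):
        self.batch_shape = batch_shape
        self.event_shape = event_shape
        self.validate_args = validate_args

class Pair(Distribution):
    def __init__(self, a, b):  # positional params, no collision possible
        super().__init__(batch_shape=(2,))
        self.a, self.b = a, b
```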
@@ -1266,7 +1287,9 @@ def _check_forward_ad(self, fn):
        torch.count_nonzero(fwAD.unpack_dual(dual_out).tangent).item(), 0
    )

def _check_log_prob(self, dist, asset_fn):
def _check_log_prob(
    self, dist: Distribution, asset_fn: Callable[Concatenate[int, ...], None]
so... cool... I had no idea you could do that, it's obvious only in hindsight.
Defining partial function signatures can be really handy; I wish this were even better supported (for instance when writing callback Protocols).
| """ | ||
|
|
||
| arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive} | ||
| arg_constraints: ClassVar[dict[str, Constraint]] = { |
This could conceivably be a TypedDict?
Yes, that would be a nice enhancement. Thinking about it, it should probably even be Final[ClassVar[SomeTypedDict]], but that would require running mypy with --python-version 3.13. (currently 3.11 in mypy.ini)
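A hedged sketch of what the TypedDict variant could look like (stand-in `Constraint` and `positive`; the real constraint objects live in torch.distributions.constraints):

```python
from typing import ClassVar, TypedDict

class Constraint:  # stand-in for constraints.Constraint
    pass

positive = Constraint()  # stand-in for constraints.positive

class ParetoConstraints(TypedDict):
    alpha: Constraint
    scale: Constraint

class Pareto:
    # With a TypedDict, a checker knows the exact keys and per-key types,
    # rather than just dict[str, Constraint].
    arg_constraints: ClassVar[ParetoConstraints] = {
        "alpha": positive,
        "scale": positive,
    }
```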
raise NotImplementedError(f"{type(self)}.with_cache is not implemented")

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
Any is for gradual typing; there is really no good reason to use it here. The signature mirrors that of object.__eq__.
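The conventional object-typed pattern, with the NotImplemented fallback that makes cross-type comparisons cooperate (stand-in class for illustration):

```python
class Transform:
    def __init__(self, name: str) -> None:
        self.name = name

    def __eq__(self, other: object) -> bool:
        # `object`, not `Any`: mirrors object.__eq__ and keeps strict typing.
        if not isinstance(other, Transform):
            return NotImplemented  # lets the other operand's __eq__ run
        return self.name == other.name

    def __hash__(self) -> int:  # keep instances hashable alongside __eq__
        return hash(self.name)
```

Comparing against an unrelated type falls back to identity, so `Transform("exp") == 42` evaluates to False rather than raising.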
@rec I implemented most of your suggestions. So, splitting it up, I think it would make most sense to first make a PR for … Regarding …: alternatively, one could just make the classes they are pointing to public, but it's not my decision to make.
Well, this was extremely educational, with at least one head-slapper. You really covered everything in your response, and I also agree with your plan to split. Regarding the type aliases in …: I think it's good to go!
Fixes #144196 Part of #144219 Pull Request resolved: #154712 Approved by: https://github.com/Skylion007
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as …
Fixes #144196
Extends #144197 #144106 #144110
Open Problems / LSP violations

- `mixture_same_family.py`: `cdf` and `log_prob` violate LSP (argument named `x` instead of `value`).
- `exp_family.py`: LSP problem with `_log_normalizer` (parent class requires `(*natural_params: Tensor) -> Tensor`, subclasses implement `(a: Tensor, b: Tensor) -> Tensor`). A possible fix is to change the signature to `(natural_params: Tuple[Tensor, ...]) -> Tensor`. While this is BC breaking, (a) this is a private method, i.e. an implementation detail, and (b) no one other than torch seems to overwrite it.
- `constraints.py`: `dependent_property`: mypy does not apply the same special casing to subclasses of `property` as it does to `property` itself, hence the need for `type: ignore[assignment]` statements in `relaxed_bernoulli.py`, `relaxed_categorical.py`, `logistic_normal.py`, `log_normal.py`, `kumaraswamy.py`, `half_cauchy.py`, `half_normal.py`, `inverse_gamma.py`, `gumbel.py`, `weibull.py`. The same applies to `lazy_property` in `distributions/utils`.
- `constraints.py`: public interface not usable as type hints. One could add `TypeAlias` variants, but that is likely not the best solution.
- `transforms.py`: `_InverseTransform.with_cache` violates LSP. A possible fix is to allow `with_cache` to return `_InverseTransform`.
- `test_distributions.py`: One test uses `Dist.arg_constraints.get`, hence assumes `arg_constraints` is a class attribute, but the base class `Distribution` defines it as a `@property`.
- `test_distributions.py`: One test uses `Dist.support.event_dim`, hence assumes `support` is a class attribute, but the base class `Distribution` defines it as a `@property`.
- `test_distributions.py`: Multiple tests use `dist.cdf(float)`, but the base class annotates `cdf(Tensor) -> Tensor`.
- `test_distributions.py`: Multiple tests use `dist.log_prob(float)`, but the base class annotates `log_prob(Tensor) -> Tensor`.

Notes
- `__init__.py`: use `+=` instead of `extend` (ruff PYI056).
- `binomial.py`: Allow `float` arguments in `probs` and `logits` (gets used in tests).
- `constraints.py`: made `_DependentProperty` a generic class, and `_DependentProperty.__call__` polymorphic.
- `constraint_registry.py`: Made `ConstraintRegistry.register` a polymorphic method, checking that the factory is compatible with the constraint.
- `constraint_registry.py`: Needed to add `type: ignore` comments to functions that try to register multiple different constraints at once.
- `dirichlet.py`: `@once_differentiable` is untyped, requires a `type: ignore[misc]` comment.
- `dirichlet.py`: `ctx: Any` could be replaced with `ctx: FunctionContext`; however, that type lacks the `saved_tensors` attribute.
- `distribution.py`: `Distribution._get_checked_instance`: accessing `"__init__"` on an instance is unsound, requires a `type: ignore` comment.
- `distribution.py`: Changed `support` from `Optional[Constraint]` to `Constraint` (consistent with the existing docstring, and several functions in tests rely on this assumption).
- `exp_family.py`: small update to `ExponentialFamily.entropy` to fix a type error.
- `independent.py`: fixed a type bug in `Independent.support`.
- `multivariate_normal.py`: Added `type: ignore` comments to `_batch_mahalanobis`, caused by footnote 1.
- `relaxed_bernoulli.py`: Allow a float temperature argument (used in tests).
- `relaxed_categorical.py`: Allow a float temperature argument (used in tests).
- `transforms.py`: Needed to change the `ComposeTransform.__init__` signature to accept `Sequence[Transform]` rather than just `list[Transform]` (covariance!).
- `transformed_distribution.py`: Needed to change the `TransformedDistribution.__init__` signature to accept `Sequence[Transform]` rather than just `list[Transform]` (covariance!).
- `transformed_distribution.py`: `TransformedDistribution.support` is problematic, because the parent class defines it as a `@property` but several subclasses define it as an attribute, violating LSP.
- `von_mises.py`: fixed `result` type being initialized as `float` instead of `Tensor`.
- `von_mises.py`: `@torch.jit.script_if_tracing` is untyped, requires a `type: ignore[misc]` comment.
- `von_mises.py`: Allow float `loc` and `scale` (used in tests).

cc @fritzo @neerajprad @alicanb @nikitaved @ezyang @malfet @xuzhao9 @gramster
Footnotes

1. `torch.Size` is not correctly typed, causing `mypy` to think `Size + Size` is `tuple[int, ...]` instead of `Size`; see https://github.com/pytorch/pytorch/issues/144218. ↩