
Implement StudentT ops#2233

Open
fehiepsi wants to merge 14 commits intopyro-ppl:devfrom
fehiepsi:ops-student

Conversation


@fehiepsi fehiepsi commented Dec 23, 2019

The implementation will assume that df > 2 throughout.

Nothing is implemented yet; I'll add more details below. The math mostly corresponds to the GammaGaussian op.

Reference

Tasks

  • Implement the StudentT op, except the sum method.
  • Consider implementing the "approximated" sum method. I'm not sure if this is necessary. Edit: it is necessary, and is the trickiest part.
  • Implement the "approximated" tensordot.
  • Add a test to verify that this matches the sequential approximation mechanism in the reference above.
  • Add tests similar to those for the Gaussian/GammaGaussian ops.

@fehiepsi fehiepsi added the WIP label Dec 23, 2019

fritzo commented Dec 23, 2019

What is required to support df < 2? In particular, many real-world time series exhibit df around 1.7-1.8.


fehiepsi commented Dec 23, 2019

@fritzo That condition is only used for moment matching, where we need to match the covariances of two Student-t distributions. For example, given dofs a and b, we want to find c such that St(b, 0, c * Id) ≈ St(a, 0, Id). I believe we can relax that condition as long as we have a better approximation.

In https://arxiv.org/abs/1703.02428 section 4.3 and https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8822764, the authors use a KL divergence for this approximation, so df > 2 is not required. However, the KL approach requires numerical methods to minimize KL(St(b, 0, c * Id), St(a, 0, Id)), which is non-trivial. :(
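For df > 2, the covariance-matching choice of c has a simple closed form, since the covariance of St(df, 0, S) is df / (df - 2) * S. A minimal sketch of this constraint (the helper name `cov_match_scale` is hypothetical, not code from this PR):

```python
def cov_match_scale(a, b):
    """Scale c such that St(b, 0, c * I) has the same covariance as St(a, 0, I).

    Uses Cov[St(df, 0, S)] = df / (df - 2) * S, which exists only for df > 2;
    this is exactly where the df > 2 assumption comes from.
    """
    assert a > 2 and b > 2, "covariance matching requires df > 2"
    return (a / (a - 2)) * ((b - 2) / b)

# Approximating a df=5 StudentT by a heavier-tailed df=3 one shrinks the scale:
c = cov_match_scale(5.0, 3.0)
```

Matching identical dofs gives c = 1, as expected, and the assertion makes the df < 2 failure mode explicit.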


fehiepsi commented Dec 23, 2019

I think we can leverage the tail behaviors of St(b, 0, c * Id) and St(a, 0, Id) to find c such that both distributions have the same 95% confidence interval (or maybe the same icdf(0.95))... instead of using moment matching or KL as above. But I haven't come up with a formula yet.
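In the univariate case, the icdf-matching idea could be prototyped directly with scipy's Student-t; notably it needs no df > 2 restriction. A sketch (`icdf_match_scale` is a hypothetical helper, not part of this PR):

```python
from scipy.stats import t

def icdf_match_scale(a, b, q=0.95):
    """Scale c such that St(b, 0, sqrt(c)) and St(a, 0, 1) share the same
    q-quantile, i.e. sqrt(c) * t.ppf(q, b) = t.ppf(q, a).

    Works for any df > 0, unlike covariance matching.
    """
    return (t.ppf(q, a) / t.ppf(q, b)) ** 2

# e.g. matching df=1.7 against a Cauchy-tailed df=1 target:
c = icdf_match_scale(1.7, 1.0)
```

Since the heavier-tailed distribution has the larger 95% quantile, c < 1 when b < a, which is the same qualitative behavior as moment matching.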


fritzo commented Dec 23, 2019

When df < 2, the moments of order p < df still exist, so we could try to match those, at least for df >= 1:

mu = E[x]
sigma = (E|x - mu|^p)^(1/p)   for some p < df
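For a standard univariate Student-t with nu degrees of freedom, the absolute moment E|X|^p has a closed form for 0 < p < nu, which makes this matching idea easy to check numerically. A sketch (hypothetical helper, not PR code):

```python
import math

def abs_moment(nu, p):
    """E|X|^p for a standard Student-t with nu degrees of freedom.

    E|X|^p = nu^(p/2) * Gamma((p + 1) / 2) * Gamma((nu - p) / 2)
             / (sqrt(pi) * Gamma(nu / 2)),
    finite only for 0 < p < nu.
    """
    assert 0 < p < nu, "the p-th absolute moment requires p < nu"
    return (nu ** (p / 2) * math.gamma((p + 1) / 2) * math.gamma((nu - p) / 2)
            / (math.sqrt(math.pi) * math.gamma(nu / 2)))

# Sanity check against the classical second moment nu / (nu - 2):
assert abs(abs_moment(5.0, 2.0) - 5.0 / 3.0) < 1e-9

# Fractional dofs below 2 are fine as long as the order stays below nu:
m1 = abs_moment(1.5, 1.0)
```

The p = 2 case reduces to the familiar variance formula, and first-absolute-moment matching remains well defined down to df > 1.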

@fehiepsi

Good idea!! I'll try to find the formula and match one of the moments of order < df.

@fehiepsi

This new paper https://arxiv.org/abs/1912.01607 computes general moments of the Student's t-distribution. There are many choices of moment order, especially in the multidimensional case. I'll try to pick a simple one from them.

```diff
 identity = torch.eye(self.loc.size(-1), device=self.loc.device, dtype=self.loc.dtype)
-scale_inv = identity.triangular_solve(self.scale_tril, upper=False).solution.transpose(-1, -2)
-return torch.matmul(scale_inv.transpose(-1, -2), scale_inv)
+return torch.cholesky_solve(identity, self.scale_tril)
```
@fehiepsi
This is much more stable, which helps many tests pass.

@fehiepsi fehiepsi left a comment

@fritzo The PR is almost ready. All methods are tested, but I still need to add some more tests to verify that one-step-ahead prediction and the measurement update agree with the intermediate approximations in Section 4.1 of https://arxiv.org/abs/1703.02428.

In MVT, the log_prob has the term log((x - m)^T P (x - m) + ...). I have tried various parameterizations with df, info_vec/loc, precision, but none of them can represent a StudentT with a rank-deficient precision matrix. The closest one I came up with is (df, info_vec, precision), but log_density and condition require a Cholesky factor of P, which is not available when the precision is rank-deficient. It turns out that the parameterization (alpha, beta, info_vec, precision) from GammaGaussian solves the issue. With that parameterization, all methods except .__add__ are pretty straightforward and inherit well from GammaGaussian, so I think you can go over them quickly (modulo design recommendations).

The most difficult method is .__add__, which involves many tricky computations. Four conditions must hold to perform the add op:

  • The two variables are uncorrelated. For Gaussians, being uncorrelated implies independence, but this is not true for StudentT. We approximate here:
    St(x; df, m, P)St(y; df, n, Q) ≈ St([x, y]; df, [m, n], [[P, 0], [0, Q]])
    To keep track of the uncorrelated relationship, I introduced a "mask" property, which keeps track of non-degenerate variables. It turns out that this is useful for moment matching too, where we only need to match the moments of the non-degenerate terms!
  • The two variables have the same degrees of freedom. If this condition does not hold, we choose a common dof (following the reference, I take the min of the two dofs, to preserve the tail) and use moment matching to approximate each StudentT with one having the new dof. The math for moment matching is complicated, so I sketched the idea in the comment section of the _moment_matching function.
  • The degrees of freedom are greater than the (absolute central) moment-matching order. By default I use order=1, so all dofs must be greater than 1.
  • The mixing priors Gamma(concentration, rate) should be the same. By choosing a common df, we already have the same concentration, so we must also make the rates equal. Changing the rate would require rescaling info_vec and precision, which would introduce additional non-trivial code. So in moment matching, I also force the output to have concentration = rate.
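The second and third conditions can be illustrated together in the univariate case: pick the common dof as the min of the two and rescale each input so its first absolute moment is preserved. This is only an illustrative sketch of the idea, not the PR's actual _moment_matching code, and the helper names are hypothetical:

```python
import math

def _abs_moment1(nu):
    # E|X| for a standard Student-t with nu dof:
    # sqrt(nu) * Gamma((nu - 1) / 2) / (sqrt(pi) * Gamma(nu / 2)),
    # finite only for nu > 1 (matching order 1 requires dof > 1).
    assert nu > 1.0
    return (math.sqrt(nu) * math.gamma((nu - 1) / 2)
            / (math.sqrt(math.pi) * math.gamma(nu / 2)))

def common_dof_scales(df1, df2):
    """Pick the common dof as the min (preserving the heavier tail) and return
    the factor by which each input's scale must be multiplied so that E|x|
    is unchanged after the dof change."""
    df = min(df1, df2)
    return df, _abs_moment1(df1) / _abs_moment1(df), _abs_moment1(df2) / _abs_moment1(df)
```

The distribution whose dof already equals the min is untouched (its factor is 1), while the lighter-tailed input is shrunk slightly to compensate for the heavier tail of the common dof.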

I know the implementation is tricky to appreciate and might contain bugs due to misunderstandings on my part. Please let me know if you find any unclear points so we can discuss them further.


fehiepsi commented Jan 3, 2020

@fritzo Unfortunately, the implementation is only useful for sequential filtering. :( The reason is that the .condition op does not associate with .add, i.e. x + xy.condition(y) != (x.event_pad(right=...) + xy).condition(y) (on the left-hand side, condition increases the dof and then add decreases it to preserve the tail; on the right-hand side, the dof always increases). A StudentTHMM can be implemented using the StudentT op, but parallel scan will be invalid. Do you think this op is useful if we can't do parallel scan?


```python
@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str)
@pytest.mark.parametrize("dim", [1, 2, 3])
def test_studentt_multiple_representation(batch_shape, dim):
```
@fehiepsi
@fritzo this test verifies multiple representations of a StudentT.


fritzo commented Jan 4, 2020

Do you think that this op is useful if we can't do parallel scan?

Well, the main purpose of the Gaussian and GammaGaussian ops was to implement parallel-scan algorithms, so maybe StudentT ops are not worth implementing.

What do you think about an alternative approach to StudentTHMM inference?
Following StableHMM, we could implement a StudentTHMM with .rsample() but no .log_prob() method, and implement auxiliary-variable reparameterizers StudentTReparam and StudentTHMMReparam using scale mixtures of Normal and GaussianHMM, respectively. During inference we would learn/sample (completely in parallel) the auxiliary Gamma-distributed "outlyingness" variables for each time step and observation step; then, conditioned on "outlyingness", we could evaluate GaussianHMM.log_prob() or filter via parallel scan. Inference would then be exact using SVI, HMC, NUTS, etc., and would work for both latent and observed HMM observations. I think we would want to implement the MultivariateStudentT versions separately.
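The scale-mixture identity underlying this reparameterization can be checked numerically: with g ~ Gamma(df/2, rate=df/2), the marginal of Normal(0, 1/sqrt(g)) is StudentT(df). A sketch using scipy (the helper name is hypothetical, not the eventual StudentTReparam API):

```python
import numpy as np
from scipy import stats
from scipy.integrate import quad

def studentt_pdf_via_mixture(x, nu):
    """Numerically marginalize the Gamma-distributed precision multiplier g;
    the result should equal the StudentT(nu) density at x."""
    integrand = lambda g: (stats.norm.pdf(x, scale=1.0 / np.sqrt(g))
                           * stats.gamma.pdf(g, a=nu / 2, scale=2.0 / nu))
    value, _ = quad(integrand, 0.0, np.inf)
    return value

# The mixture marginal agrees with the closed-form StudentT density:
err = abs(studentt_pdf_via_mixture(1.3, 4.0) - stats.t.pdf(1.3, 4.0))
```

Conditioned on a sampled g, the model is exactly Gaussian, which is what lets GaussianHMM.log_prob() and parallel-scan filtering apply unchanged.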


fehiepsi commented Jan 5, 2020

re not worth implementing: Agreed! I'll move some cleanups into a separate PR and leave the implementation here in case it is useful for benchmarking in the future.

re alternative approach to StudentTHMM inference: It is reasonable and should be easier to implement than StableHMM. I think it is also more useful than GammaGaussian, whose posterior over the last state becomes more Gaussian-like for long time series with high-dimensional outputs. I'll follow this approach instead.


fritzo commented Jan 5, 2020

I'll follow this [StudentTHMM] approach instead.

Great! I think the code should be straightforward. I see the biggest open research question as:
"What kind of amortized guides are best for estimating the latent scale parameters?" I think we can share guide patterns across StableHMM and StudentTHMM.


fehiepsi commented Jan 5, 2020

I get your point. I'll think more about that question...

@matthewfieger

The reference for this pull request is interesting. So, is there a way to fit a t-smoother with Pyro? I am coming from the PyFlux library, where, for example, I can fit a dynamic linear model with Student-t noise. PyFlux is no longer actively developed, and I would like to switch to Pyro to fit these types of models if possible.


fritzo commented Jan 13, 2020

Hi @matthewfieger, one thing we are working on is a StudentTHMM together with an auxiliary-variable reparameterizer StudentTReparam, similar to StableHMMReparam. This would allow StudentT processes to be represented as conditionally Gaussian processes, permitting smoothing via SVI or HMC. This approach would be more accurate but also more expensive than the moment-matching approach proposed in this PR.

EDIT here is the alternative line of work: #2254

