
Resolve changes in tensor/variable API with PyTorch master #824

Merged
fritzo merged 5 commits into pyro-ppl:dev from neerajprad:pytorch-master on Feb 27, 2018

Conversation

@neerajprad
Member

@neerajprad neerajprad commented Feb 27, 2018

This resolves a few minor issues so as to keep in sync with PyTorch master.

  • tensor[0] returns a Variable scalar instead of a Python numeric type. We need to use .item() to convert a PyTorch scalar into a Python number.
  • .data returns an instance of type Variable and not torch.Tensor. Instead of using .data.cpu().numpy(), we now use .detach().cpu().numpy(). (A minimal sketch of both changes follows this list.)
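
A minimal sketch of both patterns (the tensor name `t` is hypothetical; assumes a recent PyTorch):

```python
import torch

t = torch.arange(3.0, requires_grad=True)

# Indexing now yields a zero-dimensional tensor, so use .item()
# to get a Python number back.
first = t[0].item()

# .numpy() refuses tensors that require grad, so detach from the
# graph first (replacing the old .data.cpu().numpy() pattern).
arr = t.detach().cpu().numpy()
```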

NOTE: Tests will fail until we update the PyTorch wheels for Travis. The core changes are really small, so I would suggest reviewing only once the tests pass.

Only partially resolves #815. We will need more renaming changes later.

@neerajprad
Member Author

PyTorch wheels are updated; running tests against the updated wheel now.

@fritzo
Member

fritzo left a comment


Thanks for doing this, Neeraj!

My comments are honest questions, and I'm fine if they result in no changes, only edifying answers 🙂 .

Comment thread: examples/air/viz.py

```diff
  # misleading, as it incorrectly suggests objects occlude one
  # another.
- clipped = np.clip(imgarr.data.cpu().numpy(), 0, 1)
+ clipped = np.clip(imgarr.detach().cpu().numpy(), 0, 1)
```
fritzo (Member)


Is it safe to replace .detach().cpu().numpy() with .cpu().numpy() everywhere?

neerajprad (Member Author)


That's what I started with. If the node has requires_grad=True, it throws an exception when we convert it to NumPy without detaching.
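
(A sketch of the failure mode, with a hypothetical tensor `x`:)

```python
import torch

x = torch.randn(4, requires_grad=True)
# x.cpu().numpy() raises:
#   RuntimeError: Can't call numpy() on Variable that requires grad.
#   Use var.detach().numpy() instead.
arr = x.detach().cpu().numpy()  # fine: the detached alias carries no grad
```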

```diff
  logits = Variable(logits)
  ix = dist.Categorical(logits=logits).sample()
- return traces[ix.data[0]]
+ return traces[ix]
```
fritzo (Member)


Why not ix.item() here? (This may not be perfectly tested)

neerajprad (Member Author)


Because a scalar is a valid index value. We can change it to ix.item(), but it's not necessary.
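
(A minimal sketch, with made-up values:)

```python
import torch

traces = ['a', 'b', 'c']
ix = torch.tensor(1)    # zero-dimensional integer tensor

traces[ix]              # 'b': a scalar tensor is a valid list index
traces[ix.item()]       # equivalent, via an explicit Python int
```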

fritzo (Member)


This is great! It will be so much easier to write mixture models now.

neerajprad (Member Author)


Completely; scalars will clean up a lot of clunky-looking code.

Comment thread: tests/distributions/test_rejector.py (outdated)

```diff
  (cost + cost.detach() * dist.score_parts(z)[1]).backward()
- mean_alpha_grad = alphas.grad.data.mean()
- mean_beta_grad = betas.grad.data.mean()
+ mean_alpha_grad = alphas.grad.data.mean().item()
```
fritzo (Member)


Why not simply alphas.grad.mean().item()?

neerajprad (Member Author)


Yup, in this case we can just use that. Will change it here and in the other patterns that I see.
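
(The resulting pattern, sketched with a hypothetical leaf tensor:)

```python
import torch

alphas = torch.ones(5, requires_grad=True)
(alphas * 2.0).sum().backward()

# .grad is already outside the graph, so no .data is needed.
mean_alpha_grad = alphas.grad.mean().item()
```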

@fritzo
Member

fritzo commented Feb 27, 2018

@neerajprad has #780 been merged into this branch? I believe that introduces some new helpers like MultiViewTensor that will need to be updated.

@fritzo
Member

fritzo commented Feb 27, 2018

Update README.md to recommend installing PyTorch commit 05269b5?

@neerajprad
Member Author

> Update README.md to recommend installing PyTorch commit 05269b5?

Eeks... I always forget that. Thanks for the reminder!

@neerajprad
Member Author

> @neerajprad has #780 been merged into this branch? I believe that introduces some new helpers like MultiViewTensor that will need to be updated.

Thanks for the heads up. Will take a look and update.

@neerajprad
Member Author

Ready to merge, unless there are any further pending comments.

@fritzo fritzo merged commit a72f9ac into pyro-ppl:dev Feb 27, 2018
@fehiepsi
Member

fehiepsi commented Feb 27, 2018

@neerajprad I would like to confirm my understanding of the changes here. As a summary, we will (see the sketch after this list):

  • Use .item() to get the Python value of a scalar.
  • Use .tolist() to convert a tensor to a Python list.
  • Use .data to get a requires_grad=False view. The difference from .detach() is that it still uses the same memory as the host tensor.
  • Use .detach() to separate a tensor from the graph and return a non-grad tensor. It works like a combination of .data and .clone() (.clone() alone on a tensor that requires grad still requires grad).
  • Use .numpy() on a non-grad, non-CUDA tensor to convert it to a numpy array. They share the same memory.
  • No need to use Pyro's ng_zeros and ng_ones utils; use torch.zeros and torch.ones instead.
  • Use torch.tensor(..., dtype=..., requires_grad=...) to create a tensor with the corresponding type and grad flag. The previous forms torch.FloatTensor(...) and torch.cuda.DoubleTensor(...) also still work.
  • b = torch.tensor(...).type_as(a) is the same as b = torch.tensor(..., dtype=a.dtype).
  • .sum() always returns a scalar tensor.
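
(A sketch of these conventions together, assuming a PyTorch build where torch.tensor has landed; all names are made up:)

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  # dtype + grad at creation
z = torch.zeros(3)                             # instead of pyro's ng_zeros
b = torch.tensor([4.0, 5.0, 6.0]).type_as(a)   # same as dtype=a.dtype

s = (a * b).sum()     # .sum() returns a scalar (zero-dimensional) tensor
val = s.item()        # Python float
lst = b.tolist()      # Python list
arr = b.numpy()       # numpy array sharing b's memory (non-grad, CPU only)

d = a.detach()        # non-grad tensor, separated from the graph
r = a.data            # non-grad view sharing a's memory
```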

If points 3 and 4 are correct, then we should not change a.data.cpu().numpy() to a.detach().cpu().numpy(): there is no need to make a clone. One point of using .detach() would be that we don't want the numpy array to share memory with a CPU tensor. However, in many cases in this pull request, we don't have to worry about that memory issue.

@neerajprad
Member Author

Thanks for the summary, @fehiepsi.

> Use .tolist() to convert a tensor to a Python list.

Interesting, didn't know about this.

> Use .detach() to separate a tensor from the graph and return a non-grad tensor. It works like a combination of .data and .clone() (.clone() alone on a tensor that requires grad still requires grad).

Is this new behavior, or planned for the future? .detach() used to share the same underlying data, and it still does on the PyTorch branch that I am using. This is also what PyTorch recommends in the error that it throws (RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.).
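
(A quick check of that sharing behavior, with hypothetical names:)

```python
import torch

x = torch.zeros(3, requires_grad=True)
d = x.detach()

d[0] = 1.0             # writes through: d aliases x's storage, no clone
print(x[0].item())     # 1.0
```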

> Use torch.tensor(..., dtype=..., requires_grad=...) to create a tensor with the corresponding type and grad flag. The previous forms torch.FloatTensor(...) and torch.cuda.DoubleTensor(...) also still work.

Still waiting on these changes in PyTorch master. We will use torch.tensor everywhere once the changes are merged.

@fehiepsi
Member

@neerajprad That is my mistake. :( I will ask on PyTorch's Slack what the difference between x.data and x.detach() is, and will get back to you. They seem identical now. As for torch.tensor, it is already committed to PyTorch master. :)

@fehiepsi
Member

fehiepsi commented Feb 27, 2018

@neerajprad As answered by Adam, they are very similar, except that x.detach() has some additional checks for in-place operators (I am not clear how these checks work on a detached Variable, though). So in my opinion, they can be used interchangeably. :) Btw, with torch.tensor, I think the PyTorch 0.4 interface is ready to use now (bug fixes and enhancements are still ongoing). I expect there will not be many changes that affect Pyro (except things related to Distribution, or bugs).
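
(A sketch of the in-place check being described, as I understand it; names are hypothetical:)

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x.sigmoid()         # sigmoid's backward reuses y's value

y.detach().zero_()      # the in-place edit bumps the shared version counter
# y.sum().backward()    # would now raise: "... has been modified by an
#                       # inplace operation"

# y.data.zero_() mutates the same storage but bypasses the check,
# silently yielding wrong gradients.
```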



Development

Successfully merging this pull request may close these issues:

  • Support PyTorch after Variable+Tensor were merged
