Adding function to convert Module to channels last#28991
Adding function to convert Module to channels last#28991VitalyFedyunin wants to merge 24 commits intogh/VitalyFedyunin/24/basefrom
Conversation
[ghstack-poisoned]
…nnels last" [ghstack-poisoned]
[ghstack-poisoned]
…ls last" [ghstack-poisoned]
…o channels last" [ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
…last" [ghstack-poisoned]
Differential Revision: [D18430810](https://our.internmc.facebook.com/intern/diff/D18430810) [ghstack-poisoned]
… channels last" Differential Revision: [D18430810](https://our.internmc.facebook.com/intern/diff/D18430810) [ghstack-poisoned]
Differential Revision: [D18430810](https://our.internmc.facebook.com/intern/diff/D18430810) [ghstack-poisoned]
Differential Revision: [D18430810](https://our.internmc.facebook.com/intern/diff/D18430810) [ghstack-poisoned]
Differential Revision: [D18430810](https://our.internmc.facebook.com/intern/diff/D18430810) [ghstack-poisoned]
… channels last" Differential Revision: [D18430810](https://our.internmc.facebook.com/intern/diff/D18430810) [ghstack-poisoned]
Differential Revision: [D18430810](https://our.internmc.facebook.com/intern/diff/D18430810) [ghstack-poisoned]
…odule to channels last" Differential Revision: [D18430810](https://our.internmc.facebook.com/intern/diff/D18430810) [ghstack-poisoned]
…s last" Differential Revision: [D18430810](https://our.internmc.facebook.com/intern/diff/D18430810) [ghstack-poisoned]
Differential Revision: [D18430810](https://our.internmc.facebook.com/intern/diff/D18430810) [ghstack-poisoned]
CircleCI build failures summaryAs of commit ea3ed3d:
Detailed failure analysisOne may explore the probable reasons each build failed interactively on the Dr. CI website. 3 upstream failures recognized by patterns:These builds matched patterns, but were probably caused by upstream breakages:
This comment was automatically generated by Dr. CI. Please report bugs/suggestions on the GitHub issue tracker. This comment has been revised 5 times. |
|
|
||
| def convert(t): | ||
| if convert_to_format is not None and t.dim() == 4: | ||
| return t.to(device, dtype if t.is_floating_point() else None, non_blocking, memory_format=convert_to_format) |
There was a problem hiding this comment.
this doesn't match the documentation (which says the only case for to with memory-format is the 1-arg case).
There was a problem hiding this comment.
Pardon, but I read 'This can be called as' as: here examples of calls, but they are not limited to this options.
There was a problem hiding this comment.
that's not my reading of it, although I can see why you read it that way.
In particular, before memory_format was introduced, it corresponded exactly to the function signatures:
pytorch/torch/csrc/autograd/utils/python_arg_parsing.h
Lines 15 to 17 in 66f2bba
(remove copy because it's not supported and remove memory_format because we are considering the case before memory_format was introduced).
So, the introduction of memory_format changed this from "these are the supported signatures" to "these are some examples of supported signatures". I think the former is more useful and we should change it back.
There was a problem hiding this comment.
And actually, the example you added (.. function:: to(memory_format=torch.channels_last)) doesn't work because the parsing code hasn't been updated, right?
There was a problem hiding this comment.
IMO we should do the following:
- Add a memory_format only-overload to the python parsing.
- List the valid calls, i.e.:
.. function:: to(device=None, dtype=None, non_blocking=False, memory_format=None)
.. function:: to(dtype, non_blocking=False, memory_format=None)
.. function:: to(tensor, non_blocking=False, memory_format=None)
.. function:: to(memory_format)
(or similar).
There was a problem hiding this comment.
I had parsing updated https://github.com/pytorch/pytorch/pull/28991/files#diff-
7a05cfd8eb442889dffd6c3d2e4d0ddcR24
Will update inline and html docs as follow-up PR
| the floating point parameters and buffers in this module | ||
| tensor (torch.Tensor): Tensor whose dtype and device are the desired | ||
| dtype and device for all parameters and buffers in this module | ||
| memory_format (:class:`torch.memory_format`): the desired memory |
There was a problem hiding this comment.
are you planning to have a reference section for memory_format that you can point to? This description isn't full enough for the long term. (e.g. why only 4D parameters/buffers? -- it's not clear)
There was a problem hiding this comment.
Yes as soon as we land new defaults for .clone .to *_like ops I will work on updating docs.
|
@VitalyFedyunin merged this pull request in 66f2bba. |
ghstack-source-id: f650f3b Pull Request resolved: pytorch/pytorch#28991
Summary: Pull Request resolved: pytorch#28991 Test Plan: Imported from OSS Differential Revision: D18430810 Pulled By: VitalyFedyunin fbshipit-source-id: 0693d4e31fc6f9831722c29fc83517f16ddfc028
|
This PR adds an API I think this is making an assumption that "any 4D parameters in the module needs a conversion to NHWC layout, if user wants to use nvidia's NHWC kernels". From what I can see this assumption is limiting in many ways:
It may be better if such an API:
|
|
@ppwwyyxx , I totally agree. |
Stack from ghstack:
zeros,ones,full#31131 Add memory_format support tozeros,ones,full(still need tests)Differential Revision: D18430810