
Parser expand #1754

Merged

jjsjann123 merged 31 commits into devel from parser_expand on Jun 8, 2022
Conversation

@jjsjann123 (Collaborator) commented Jun 7, 2022

  1. Updated the parser with expand/expand_as support. Only functional expand(_as) is fused, to avoid aliasing; it is handled the same way as our other alias_copy ops (see the sketch after this list).
  2. Updated expand(_as) in arith.cpp to allow rank extension by broadcasting new leading dimensions (matching the PyTorch API).
  3. Added Python tests.
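
For illustration only, a minimal sketch of the kind of scripted function this change lets the parser fuse; the function name, shapes, and device handling are assumptions, not code from this PR:

```python
import torch

@torch.jit.script
def bias_add(x, bias):
    # type: (Tensor, Tensor) -> Tensor
    # Functional expand_as: returns a new broadcasted tensor with no in-place
    # aliasing, so the parser can fuse it like the other alias_copy ops.
    return x + bias.expand_as(x)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2, 5, 5, device=device)
bias = torch.randn(5, 5, device=device)  # rank-extended to (2, 5, 5) by expand_as
out = bias_add(x, bias)
```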

@jjsjann123 jjsjann123 marked this pull request as ready for review June 7, 2022 09:21
@naoyam (Collaborator) left a comment

Approving the C++ side.

Comment thread on torch/csrc/jit/codegen/cuda/arith.cpp (outdated), lines +1068 to +1077:
```cpp
// broadcast inner on inp to match rank with other.
if (inp_domain.size() < other_domain.size()) {
  const int num_bcast =
      static_cast<int>(other_domain.size() - inp_domain.size());
  std::vector<bool> inner_bcast_dims(other_domain.size(), false);
  std::fill(
      inner_bcast_dims.begin(), inner_bcast_dims.begin() + num_bcast, true);
  inp = broadcast(inp, inner_bcast_dims);
  inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain());
}
```
Collaborator:

Can you please factor this part out so that it can be used from both expand and expand_as?
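
For illustration, the factored-out helper might look roughly like this; the name maybeBroadcastToRank is hypothetical, and the body only reuses calls already visible in the snippet above, so it is a sketch rather than the code that landed:

```cpp
// Hypothetical sketch: prepend broadcast dimensions to inp until its rank
// matches target_rank, so both expand and expand_as can share the logic.
TensorView* maybeBroadcastToRank(TensorView* inp, size_t target_rank) {
  auto inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain());
  if (inp_domain.size() < target_rank) {
    const int num_bcast = static_cast<int>(target_rank - inp_domain.size());
    // Mark the new leading (leftmost) dimensions as broadcast.
    std::vector<bool> bcast_dims(target_rank, false);
    std::fill(bcast_dims.begin(), bcast_dims.begin() + num_bcast, true);
    inp = broadcast(inp, bcast_dims);
  }
  return inp;
}
```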

@rdspring1 (Collaborator):

LGTM

Comment thread on torch/csrc/jit/codegen/cuda/arith.cpp (outdated), lines +989 to +990:
```cpp
std::fill(
    inner_bcast_dims.begin(), inner_bcast_dims.begin() + num_bcast, true);
```
Collaborator:

This inserts broadcast axes outside of inp_domain, i.e., to the left of the input domains. Is that what you intend?

Collaborator Author:

Yes, it is. See the note in the PR description (could use better English 😛): expand(_as) in arith.cpp was updated to allow rank extension by broadcasting new leading dimensions, matching the PyTorch API.

```python
In [1]: import torch

In [2]: x = torch.randn(5, 5)

In [3]: x.expand((2, 5, 5)).shape
Out[3]: torch.Size([2, 5, 5])
```
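
The same leading-dimension rule applies to expand_as; a quick check for reference (not part of the original comment):

```python
In [4]: y = torch.randn(2, 5, 5)

In [5]: x.expand_as(y).shape
Out[5]: torch.Size([2, 5, 5])
```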

Collaborator Author:

I can move this back out into the integration code if we think it's a crazy API, though.

Collaborator:

I don't see any problem, so no need to back it out.

@jjsjann123 jjsjann123 merged commit b93a147 into devel Jun 8, 2022
@jjsjann123 jjsjann123 deleted the parser_expand branch June 8, 2022 07:22
shmsong pushed a commit to shmsong/pytorch that referenced this pull request Jul 24, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

- TransformPropagator refactor: switched to Dijkstra's algorithm instead of exhaustive enumeration of all possible paths, reducing compilation time in transform propagation;
- Indexing refactor: remove reference tensor creation in all tensor indexing logic (csarofeen#1690)
- (more) generic grouped grid reduction kernel;
- Minor parser/fuser patches:
  1. zero-dim tensor reduction support
  2. no-op binary removal within fused graph
  3. expand supported in fusion

Squashed commits to work around (WAR) the GitHub API.
Commits that are actually in this PR from the devel branch:

```
a054b3e Refactor TransformPropagator to allow specifying a position and propagating to part of the DAG (csarofeen#1775)
d67e1cd Indexing refactor stage 1: remove reference tensor creation in all tensor indexing logic (csarofeen#1690)
1b65299 Issue 1770 (csarofeen#1774)
35b0427 Avoid compilation errors like below: (csarofeen#1773)
452c773 Ignore reductions of zero-dim tensors per PyTorch conventions (csarofeen#1771)
31d6c56 TransformPropagator refactor (csarofeen#1769)
570c5a8 Merge pull request csarofeen#1767 from csarofeen/upstream_merge_0621
9d6c3d8 merging upstream 61305cd
0ed815f New TransformPropagator algorithm (csarofeen#1763)
6c19520 no-op binary removal (csarofeen#1764)
ec7fa41 Proper propagation of IterType (csarofeen#1762)
b263562 Fix dimensionality check (csarofeen#1759)
2d6343f More generic grouped grid reduction kernel (csarofeen#1740)
64e2b56 [nvfuser] prevent spamming warning message (pytorch#77777) (csarofeen#1758)
0c43162 [nvFuser] Improving bitwise ops support (pytorch#77158) (csarofeen#1757)
b93a147 Parser expand (csarofeen#1754)
```

RUN_TORCHBENCH: nvfuser
Pull Request resolved: pytorch#80355
Approved by: https://github.com/davidberard98