[halide-backend] Dimension-based indexing#129026
[halide-backend] Dimension-based indexing#129026jansel wants to merge 16 commits intogh/jansel/354/basefrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/129026
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit ddfa273 with merge base bc8883a ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
torch/_inductor/codegen/halide.py
Outdated
|
|
||
| def __init__(self, expr, size, stride): | ||
| super().__init__() | ||
| if V.graph.sizevars.statically_known_leq(stride, 0): |
There was a problem hiding this comment.
hmm, when do we get a negative stride?
| ) | ||
| eq = V.graph.sizevars.statically_known_equals | ||
| lt = V.graph.sizevars.statically_known_lt | ||
| size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf) |
There was a problem hiding this comment.
Should we use an integer rather than a float for the fallback value?
There was a problem hiding this comment.
I want it to go last and there is no such thing as a max int in python
There was a problem hiding this comment.
For our purposes, int64 max would work right ?
There was a problem hiding this comment.
Yeah I suppose, seems like a style preference.
| try: | ||
| code.writeline( | ||
| f"{arg.name}.dim({i}).set_stride({int(dim.stride)})" | ||
| ) | ||
| except TypeError: | ||
| pass # not integer | ||
| try: | ||
| code.writeline( | ||
| f"{arg.name}.dim({i}).set_extent({int(dim.size)})" | ||
| ) | ||
| except TypeError: | ||
| pass # not integer |
There was a problem hiding this comment.
Could query is_integer to avoid the try/except
There was a problem hiding this comment.
It might be a regular int (not sympy)
| ) | ||
| eq = V.graph.sizevars.statically_known_equals | ||
| lt = V.graph.sizevars.statically_known_lt | ||
| size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf) |
There was a problem hiding this comment.
For our purposes, int64 max would work right ?
| line = f"{var}[{index_str},]" # trailing comma workaround for https://github.com/halide/Halide/issues/8299 | ||
| dtype = V.graph.get_dtype(name) | ||
| if dtype in (torch.float16, torch.bfloat16): | ||
| dtype = torch.float32 |
There was a problem hiding this comment.
nit: factor out to dtype_to_compute_dtype similar to triton codegen ?
| all_used_symbols.update(super().prepare_indexing(index).free_symbols) | ||
|
|
||
| had_fallback = False | ||
| for tree in reversed(self.range_trees): |
There was a problem hiding this comment.
nit: maybe factor out to helper function
Pull Request resolved: #127506 Approved by: https://github.com/shunting314, https://github.com/eellison ghstack dependencies: #126417, #129025, #129026
Requires halide/Halide#8255 Pull Request resolved: #129036 Approved by: https://github.com/shunting314, https://github.com/eellison ghstack dependencies: #126417, #129025, #129026, #127506
In theory Halide doesn't need the split reduction stuff we do for Triton since it can generate multiple kernels. Pull Request resolved: #129320 Approved by: https://github.com/shunting314, https://github.com/eellison ghstack dependencies: #126417, #129025, #129026, #127506, #129036
Stack from ghstack (oldest at bottom):
Prior to this the generated Halide code was a rather literal translation of the Triton code, with XBLOCK/YBLOCK/RBLOCK and 1D inputs. Halide prefers dimensions, and this 1D index triggers a lot of bugs and perf issues. This PR infers dimensions and changes the indexing in the generated code.
Before
After
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang