Conversation
        optimizer.step()
        return loss

...

self.compiled_step_fn = torch_xla.experimental.compile(self.step_fn)
re: naming, I don't think compiled_step_fn makes sense here. It makes it look like this function is imperatively compiled when you call compile, when what's really happening is that you will trace and JIT-compile the function when it's called later. jit and trace already mean something else in this ecosystem though...
Well, that's the trade-off I made when trying to align the API with upstream; I think it would create more confusion to call it jited_step_fn. I will just rename it to step_fn.
Do you want to just use compile as a decorator here if you're following the upstream convention?
It seems like upstream supports both, so I can support both too: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
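The naming point above is that nothing is actually compiled at the call to `compile`; tracing and JIT compilation are deferred until the wrapped function is first invoked, and the wrapper can be applied either as an explicit call or as a decorator. A minimal sketch of that pattern, using a hypothetical `lazy_compile` as a stand-in for `torch_xla.experimental.compile` (no torch_xla required to run):

```python
import functools

def lazy_compile(fn):
    """Stand-in for a JIT-style compile wrapper: defer work to first call."""
    state = {"traced": False}

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not state["traced"]:
            # In torch_xla, this is where tracing/compilation would happen.
            state["traced"] = True
        return fn(*args, **kwargs)

    wrapper.was_traced = lambda: state["traced"]
    return wrapper

# Usage 1: explicit call, as in the snippet under review.
def step_fn(x):
    return x + 1

wrapped_step_fn = lazy_compile(step_fn)

# Usage 2: decorator form, which upstream torch.compile also supports.
@lazy_compile
def decorated_step_fn(x):
    return x * 2

# Nothing has been "compiled" yet at this point...
assert not wrapped_step_fn.was_traced()

# ...tracing only happens on the first invocation.
print(wrapped_step_fn(1), decorated_step_fn(2))  # → 2 4
assert wrapped_step_fn.was_traced()
```

This is why `compiled_step_fn` reads misleadingly: the attribute holds a lazily-compiled wrapper, not an already-compiled function, and supporting the decorator form sidesteps the naming question entirely.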