[Frontend] Dynamic shape fx trace (#294)
Conversation
Essentially there was a bug with the norm op. It tries to achieve polymorphism (f32 vs f16) with class overloading (the fp16 task subclassed the fp32 task), but this results in incorrect behaviour when combined with the automatic mixed precision pass: the op was originally in fp32 and gets reforwarded in the pass, but the implement_cuda schedule template still assumes that the input is in fp32. The result is an array of fp16 inputs reinterpreted as fp32; pointers in C++ cast silently. I think there are two ways to achieve type polymorphism in schedule templates right now.
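To see why the silent cast is so dangerous, here is a plain-Python illustration (not hidet code) of the same failure mode: bytes holding fp16 values read back as fp32.

```python
import struct

# Hypothetical illustration: eight bytes holding four fp16 ones,
# reinterpreted as two fp32 values -- analogous to what happens when a
# schedule template assumes the wrong dtype and the pointer is cast silently.
fp16_bytes = struct.pack('<4e', 1.0, 1.0, 1.0, 1.0)   # four fp16 values
as_fp32 = struct.unpack('<2f', fp16_bytes)            # same bytes read as fp32
print(as_fp32)  # nowhere near 1.0: the fp16 bit patterns decoded as fp32
```

No error is raised anywhere; the computation just produces garbage, which is why the task definition has to check the dtype explicitly.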
Hi @Aalanli, the second method is our current design. We have some base operators (matmul, conv2d) that support arbitrary data types and use the auto scheduler to schedule them. These base operators will be resolved to specialized ones with specialized templates, and we should check the special condition in the task definition (in this case, we should assert that the input dtype is fp16).
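A minimal sketch of that convention (class and field names here are illustrative, not the actual hidet API): a specialized task guards its dtype so that a reforwarded op with different inputs fails loudly instead of miscompiling.

```python
# Hypothetical sketch: an fp16-only specialized task asserts its input dtype
# in the task definition, so the mixed-precision pass cannot silently hand it
# an fp32 (or any other) input.
class NormalizeF16Task:
    def __init__(self, x_dtype: str, shape):
        # the fp16-only template must refuse any other input dtype
        assert x_dtype == 'float16', (
            f'NormalizeF16Task only supports float16 inputs, got {x_dtype}'
        )
        self.x_dtype = x_dtype
        self.shape = shape
```

With this guard, reforwarding the op with fp32 inputs raises an AssertionError at task-creation time rather than producing garbage at runtime.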
yaoyaoding
left a comment
Thanks @Aalanli. Overall looks good to me!
@xinli-git could you also have a look at this PR (especially about the normalization part).
# unfortunately, when dynamic=True in torch.compile, there may exist other non-tensor parameters
# in example inputs
For those dynamic shapes, I am wondering if these scalar parameters act as the shapes of the input tensors. If that's the case, we can ignore those scalar parameters.
Say a torch model gives us:
sample_inputs = [tensor(['m', 'n']), 'm', 'n']
We can declare the symbol variables for 'm' and 'n' (when we define the symbolic tensor) and ignore the 'm' and 'n' scalar parameters.
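The idea above can be sketched as follows (plain Python with a stand-in tensor type; this is not the actual hidet frontend code): keep the tensor inputs, and drop any scalar that merely repeats a tensor dimension.

```python
from collections import namedtuple

# FakeTensor stands in for a real tensor; only its shape matters here.
FakeTensor = namedtuple('FakeTensor', ['shape'])

def split_example_inputs(example_inputs):
    """Separate tensor inputs from scalar args that duplicate tensor dims."""
    tensors = [a for a in example_inputs if hasattr(a, 'shape')]
    scalars = [a for a in example_inputs if isinstance(a, int)]
    dim_values = {d for t in tensors for d in t.shape}
    # scalars covered by some tensor dimension are redundant shape arguments
    redundant = [s for s in scalars if s in dim_values]
    return tensors, redundant

tensors, redundant = split_example_inputs([FakeTensor(shape=(3, 4)), 3, 4])
print(len(tensors), redundant)  # 1 [3, 4]
```

Here both scalars 3 and 4 already appear as dimensions of the tensor, so they carry no extra information and can be ignored when building the symbolic graph.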
@register_function(operator.iadd)
def iadd(x: Tensor, y: Tensor):
    return ops.add(x, y)
    return x + y
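For context on why mapping `operator.iadd` to an out-of-place add is sound: in a traced functional graph, `x += y` is lowered to a call whose return value rebinds `x`, and for immutable Python values `operator.iadd` already behaves out-of-place. A small stdlib-only illustration:

```python
import operator

# For immutable values, operator.iadd returns a new value and leaves the
# original binding untouched -- the same contract an out-of-place tensor
# add satisfies inside a functional graph.
x, y = 3, 4
result = operator.iadd(x, y)
print(result, x)  # 7 3  (x itself is unchanged)
```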
So the x and y could be DynInt?
To be more specific, the hidet task and its schedule template should ensure that the schedule template strictly implements what the computation defines. We can take both ways you mentioned. For example, our
xinli-git
left a comment
Thanks for the changes in normalize. In principle, this is the right approach. I left two implementations initially so I could add vector loads for the fp16 case in the future.
But now that there is the vector data type that Yaoyao recently introduced, keeping op and op_fp16 in a single place is the right way to go, and I intend to do the same for the reduce op.
check_module(model, [x], atol=1e-2, rtol=1e-2)
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True).cuda().eval()
x = torch.randn(*shape).cuda()
check_module(model, [x], atol=1e-2, rtol=1e-2, dynamic=dynamic)
Have we been using the CPU path before this change?
x: Tensor = op.inputs[0]
if not is_contiguous_norm(dims, len(x.shape)):
    return None
if x.dtype != dtypes.float16 or prod([x.shape[dd] for dd in dims]) % 2 != 0:
Removing this is safe for now, but we might need to think about how to handle it when we decide to use 2xfp16 types and the norm size is odd.
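The guard being discussed can be sketched like this (the helper name is hypothetical, not hidet's actual resolve-rule code): a vectorized 2xfp16 (half2) schedule only applies when the reduced extent is even, since an odd norm size would leave a scalar tail element.

```python
from functools import reduce
import operator

def can_use_half2(shape, dims):
    """Return True if the product of the reduced dims is even, i.e. the
    whole normalized extent can be covered by 2-wide fp16 vector loads."""
    norm_size = reduce(operator.mul, (shape[d] for d in dims), 1)
    return norm_size % 2 == 0

print(can_use_half2([8, 128], dims=[1]))  # True
print(can_use_half2([8, 127], dims=[1]))  # False
```

An odd norm size would need either a scalar epilogue or padding before the half2 path could be used.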
@@ -32,15 +29,6 @@ class NormalizeResolveRule(ResolveRule):
2) resolve_generic: Default case, return the output of the regular f32 reduce schedule.
remove the resolve_fp16 comment above
@yaoyaoding the part about normalize is fine as long as the current CI can pass. Thanks for the notification :)
Reading this again: is my understanding correct that we should not sub-class operators? We should either write them as separate classes or have a generic Operator / Task that works for all input data types? Basically this line is causing the problem: https://github.com/hidet-org/hidet/blob/main/python/hidet/graph/operator.py#L166 ?
That line does not have a problem. The reforward will create the task again based on the new inputs and parameters. The problem is that the task did not check the data type. If the task only supports one data type, it should explicitly assert that its input has that data type. If it accepts the inputs, then its implement function SHOULD support them. We can sub-class operators like ElementwiseBinaryOp, UnaryElementwiseOp, etc.
The key convention here is: keep the task computation definition and the implement function consistent.
yaoyaoding
left a comment
I am still not sure what the extra scalar parameters are; let's figure them out before merging this PR.
Thanks @Aalanli !
…. ) (#294) [Ir][Primitives] add vectorized conversion instructions [Ir][CuTe] add reduce primitives in cute (#295) [Ir][CuTe] add mma primitives (#296) [Ir][CuTe] add other primitives in cute (#297) [Transforms][CuTe] add instruction selection pass (#298) [Transforms][CuTe] add resolve bank conflict pass (#299) [Transforms][CuTe] add resolve auto keywords pass (#300) [Transforms][CuTe] add shared memory allocation pass (#301) [Transforms][CuTe] add vectorize elementwise operation pass (#302) [Transforms][CuTe] add analysis pass (#303) [Transforms][CuTe] add canonicalization pass (#304) [Transforms][CuTe] add deadcode elimination pass (#305) [Transforms][CuTe] refactor cute lowering pass (#306) [Graph][Ops] matmul cute (#307) [Ir] cute miscs (#308) [Tests] cute tests (#309) [Chore] fix ci (#313) --------- Co-authored-by: xiaocenxiaocen <xiao.zhang@centml.ai>
enable the option torch.compile(..., dynamic=True)