[Dynamo] module tests + operator support#148

Merged
yaoyaoding merged 1 commit into hidet-org:main from AndreSlavescu:main on Apr 26, 2023
Conversation

@AndreSlavescu
Contributor

@AndreSlavescu AndreSlavescu commented Mar 25, 2023

  • non-linear activation operators + tests
  • convolution operators + tests
  • conditional operator tests

Member

@yaoyaoding yaoyaoding left a comment


Thanks @AndreSlavescu,

I left some comments.

One thing that might need your help: could you validate whether torch.compile actually dispatches the module to the backend when the module contains only a single operator? Say, we directly optimize

torch.compile(torch.nn.Conv2d(...), backend='hidet')

using this API: https://docs.hidet.org/stable/gallery/tutorials/optimize-pytorch-model.html#print-the-input-graph

If it is true, we might need another way to write the check_module function in our test script to make sure that we really tested the mapping.

Contributor

@xinli-git xinli-git left a comment


Thanks for the PR.

A general comment on testing strategy: aim for code coverage. I think you have covered most things in your tests, but one exception is perhaps the conv operators.

There are three notable convolution cases, as far as I have heard:

  • common: 3x3 (or 5x5) kernel sizes, stride 1 or 2, and groups 1, the things you often see in resnet
  • grouped conv: similar to common cases but with groups > 1, the things you often see in resnext
  • depthwise separable conv: a conv where # of groups == # of input channels, followed by a conv where kernel size is 1x1, something like this: https://github.com/seungjunlee96/Depthwise-Separable-Convolution_Pytorch

I think these would be a good stress test on the correctness for those index calculations. :)
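The three cases above can be captured directly in a parametrization table. A hypothetical sketch (channel counts and helper names are illustrative, not from the PR):

```python
# illustrative conv test matrix covering the three cases above
conv_cases = [
    # common: 3x3 or 5x5 kernels, stride 1 or 2, groups == 1 (ResNet-style)
    dict(in_channels=64, out_channels=64, kernel_size=3, stride=1, groups=1),
    dict(in_channels=64, out_channels=128, kernel_size=5, stride=2, groups=1),
    # grouped: same shapes but groups > 1 (ResNeXt-style)
    dict(in_channels=64, out_channels=64, kernel_size=3, stride=1, groups=32),
    # depthwise separable: depthwise (groups == in_channels) then 1x1 pointwise
    dict(in_channels=64, out_channels=64, kernel_size=3, stride=1, groups=64),
    dict(in_channels=64, out_channels=128, kernel_size=1, stride=1, groups=1),
]

def valid_conv(case):
    """Both channel counts must be divisible by the group count."""
    return (case["in_channels"] % case["groups"] == 0
            and case["out_channels"] % case["groups"] == 0)

assert all(valid_conv(c) for c in conv_cases)
# a depthwise case (groups == in_channels) is present in the matrix
assert any(c["groups"] == c["in_channels"] for c in conv_cases)
```

Feeding such a table through `@pytest.mark.parametrize`, as the PR's other tests do, would stress exactly the index calculations mentioned above.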

@yaoyaoding
Member

Hi @AndreSlavescu, let me know by replying to this PR when it is ready for review. Thanks!

@AndreSlavescu
Contributor Author

> Hi @AndreSlavescu, let me know by replying to this PR when it is ready for review. Thanks!

Hi @yaoyaoding , PR is ready for review.

Member

@yaoyaoding yaoyaoding left a comment


Hi @AndreSlavescu, overall this looks good, but there are still some minor issues.

@pytest.mark.parametrize('groups', [1])
@pytest.mark.parametrize('dtype', [torch.float32])
def test_conv1d_transpose(in_shape, w_shape, stride, padding, output_padding, groups, dtype):
check_module(
Member


add

cudnn.allow_tf32 = False

Member


To make it consistent, could you add the pairs of

cudnn.allow_tf32 = False 

and

cudnn.allow_tf32 = True 

to all conv/conv_transpose tests?
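Pairing the `False`/`True` assignments by hand is easy to get wrong when a check raises midway; a context manager that restores the previous value is a safer way to express the same pattern. A minimal sketch, using a plain attribute holder as a stand-in for `torch.backends.cudnn` (the helper name is illustrative):

```python
from contextlib import contextmanager
from types import SimpleNamespace

# stand-in for torch.backends.cudnn; the real tests would use that module
cudnn = SimpleNamespace(allow_tf32=True)

@contextmanager
def tf32_disabled(backend=cudnn):
    """Run a block with TF32 disabled, restoring the saved value on exit."""
    saved = backend.allow_tf32
    backend.allow_tf32 = False
    try:
        yield
    finally:
        backend.allow_tf32 = saved

with tf32_disabled():
    # conv/conv_transpose checks would run here at full float32 precision
    assert cudnn.allow_tf32 is False
# the flag is restored even if the checks above raise
assert cudnn.allow_tf32 is True
```

Disabling TF32 matters here because on Ampere-class GPUs cuDNN may otherwise compute float32 convolutions in TF32, loosening the tolerances the tests compare against.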

@AndreSlavescu AndreSlavescu force-pushed the main branch 3 times, most recently from b7944b3 to cb2fc6b Compare April 25, 2023 19:42
@yaoyaoding yaoyaoding changed the base branch from main to test April 26, 2023 16:57
@yaoyaoding yaoyaoding changed the base branch from test to main April 26, 2023 16:57
@yaoyaoding
Member

Thanks @AndreSlavescu! This PR looks good to me now.

@yaoyaoding yaoyaoding merged commit 0f8d3fa into hidet-org:main Apr 26, 2023
KTong821 pushed a commit to KTong821/hidet that referenced this pull request Apr 24, 2024
…idet-org#148)

**Overview** 
Specialize function `Constant._binary()` for compilation speedup

**Compilation time improvement results** 
matmul_f16 with `max_parallel_jobs=1`
Before: 2m 11.2s
After: 2m 4.4s
Speedup: 5.5%

**Additional test**
matmul_f16 has 177 candidates. I checked that all of them remained the same (no functional changes).
vadiklyutiy added a commit that referenced this pull request Jul 22, 2024
vadiklyutiy added a commit that referenced this pull request Jul 23, 2024