Skip to content

[Dynamic Shape] More correctness guards#276

Merged
yaoyaoding merged 23 commits intohidet-org:mainfrom
Aalanli:dyn-shape-assertion
Jun 16, 2023
Merged

[Dynamic Shape] More correctness guards#276
yaoyaoding merged 23 commits intohidet-org:mainfrom
Aalanli:dyn-shape-assertion

Conversation

@Aalanli
Copy link
Copy Markdown
Contributor

@Aalanli Aalanli commented Jun 9, 2023

Add support for dynamic shape assertions in the C++ runtime.
Further, add shape and dtype check for python runtime. With this being enabled by default.

@Aalanli
Copy link
Copy Markdown
Contributor Author

Aalanli commented Jun 9, 2023

Moreover, there is a subtle bug where the symbol registry gets rewritten multiple times if there exists inputs which promises the same shape, but runtime inputs have different shapes.
Eg. hidet.symbol(['a', 'a']), but during runtime: hidet.randn([1, 2])

So added a check for that as well.

@Aalanli
Copy link
Copy Markdown
Contributor Author

Aalanli commented Jun 12, 2023

I have no idea why the gpt2 test passed before, there was a similar error with the llama implementation, where the compiled model reinterpreted an int64 tensor as an int32 tensor, producing random outputs.

Copy link
Copy Markdown
Member

@yaoyaoding yaoyaoding left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Aalanli.

Comment on lines +154 to +177
for i, (traced, new) in enumerate(zip(self.meta_data.inputs, inputs)):
if ir.data_type(traced.dtype) != new.dtype:
raise RuntimeError(
f"dtype mismatch at arg {i} between original: {traced.dtype} and new: {new.dtype}"
)
traced_shape = traced.shape
concrete_shape = new.shape
if len(traced_shape) != len(concrete_shape):
raise RuntimeError(
f"Rank of input {i} not equal to original. ({len(concrete_shape)} vs. {len(traced_shape)})"
)
for j, (orig_shape, new_shape) in enumerate(zip(traced_shape, concrete_shape)):
if isinstance(orig_shape, int) and orig_shape != new_shape:
raise RuntimeError(
f'shape mismatch at dimension {j}, original: \
{orig_shape} vs. new: {new_shape}'
)
elif orig_shape not in symbol_map:
symbol_map[orig_shape] = new_shape
elif symbol_map[orig_shape] != new_shape:
raise RuntimeError(
f"There exists multiple instances of the same symbol {orig_shape}\
with different values in inputs (ex: {symbol_map[orig_shape]} and {new_shape})"
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same logic as the CompiledGraph. Consider implement as a utility function and put it at the same module of TensorSignature.

Aalanli and others added 7 commits June 14, 2023 11:33
Co-authored-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Co-authored-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Co-authored-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Co-authored-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
@yaoyaoding
Copy link
Copy Markdown
Member

Thanks @Aalanli !

@yaoyaoding yaoyaoding merged commit 5b490b6 into hidet-org:main Jun 16, 2023
@Aalanli Aalanli deleted the dyn-shape-assertion branch September 27, 2023 18:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants