support dynamism on add, mul #6443

Merged
lsy323 merged 11 commits into master from lsiyuan/aten-dynamism-test on Feb 9, 2024

Conversation

@lsy323 (Collaborator) commented on Feb 1, 2024

  • Add unbounded dynamism tests for some aten ops; these ops are used in the ViT model. Let's add more as we work on other models.

Unsupported ops

add
conv
gelu
native_layer_norm
select
slice
softmax
  • Also add dynamism support for add.

Before this change, unbounded dynamism could not propagate through add because the constant scalar operand was broadcast to the same static shape as the input tensor. During implicit broadcasting we then had ? + concrete_dim => concrete_dim, losing the dynamic dimension (see the sketch after this list).

  • Add some missing tests to the CI scripts.
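
As a sketch of what these tests look like, here is a minimal, hedged example assuming the usual torch.export + torch_xla.stablehlo flow (the module M and the assertion are illustrative, not copied from this PR):

# Illustrative only: export x + 1.0 with an unbounded batch dimension and
# check that the dynamic dimension survives the StableHLO lowering.
# Exporting with unbounded dims may additionally require an experimental
# torch_xla flag, depending on the version (assumption).
import torch
from torch.export import Dim, export
from torch_xla.stablehlo import exported_program_to_stablehlo


class M(torch.nn.Module):
    def forward(self, x):
        # Scalar addend: before this PR, the scalar was broadcast to the
        # static input shape, dropping the dynamic dimension.
        return x + 1.0


ep = export(M(), (torch.rand(4, 128),), dynamic_shapes=({0: Dim("batch")},))
shlo = exported_program_to_stablehlo(ep)
text = shlo.get_stablehlo_text('forward')
# With this fix the dynamic dimension propagates, so the StableHLO should
# carry tensor<?x128xf32> rather than a fully static tensor<4x128xf32>.
assert "?x128" in text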

cc @sdasgup3

@lsy323 lsy323 requested a review from qihqi on February 1, 2024 06:26
@lsy323 lsy323 changed the title from "add unbounded dynamism test for some aten ops" to "add unbounded dynamism test for some aten ops, support add" on Feb 6, 2024
@lsy323 lsy323 changed the title from "add unbounded dynamism test for some aten ops, support add" to "add unbounded dynamism test for some aten ops, support dynamism on add" on Feb 6, 2024
Siyuan Liu added 6 commits February 6, 2024 21:42
(cherry picked from commit f55abc88ae361e89da675a1aa1e4a19e7a5c762a)
(cherry picked from commit 30abe2be43defc25db8954c525d34f7f3de35292)
@lsy323 lsy323 force-pushed the lsiyuan/aten-dynamism-test branch from 3e4db72 to 92a6e00 on February 6, 2024 21:42
Siyuan Liu added 4 commits February 6, 2024 21:43
(cherry picked from commit 8526b2091ffafccf6972ecba3c111d1b0869621e)
@lsy323 (Collaborator, Author) commented on Feb 7, 2024

The HLO changed for the SPMD basic sharding test test_mark_sharding_ir. The graph becomes more concise after the updated lowering for add/mul: the old lowering expanded the scalar constant through a reshape/broadcast/reshape/broadcast chain, while the new lowering broadcasts the scalar directly to the target shape. The semantics and sharding annotations remain the same; only the HLO op names changed.
cc @yeounoh @alanwaketan

From

ENTRY %IrToHlo.17 (p0.9: f32[1,128], p1.11: f32[1,128]) -> (f32[1,128]) {
  %p1.11 = f32[1,128]{1,0} parameter(1), sharding={replicated}
  %p0.9 = f32[1,128]{1,0} parameter(0), sharding={replicated}
  %constant.4 = f32[] constant(1)
  %reshape.5 = f32[1,1]{1,0} reshape(f32[] %constant.4)
  %broadcast.6 = f32[1,1]{1,0} broadcast(f32[1,1]{1,0} %reshape.5), dimensions={0,1}
  %reshape.7 = f32[1]{0} reshape(f32[1,1]{1,0} %broadcast.6)
  %broadcast.8 = f32[1,128]{1,0} broadcast(f32[1]{0} %reshape.7), dimensions={0}
  %multiply.10 = f32[1,128]{1,0} multiply(f32[1,128]{1,0} %p0.9, f32[1,128]{1,0} %broadcast.8)
  %add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %p1.11, f32[1,128]{1,0} %multiply.10)
  %custom-call.13 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.12), custom_call_target="Sharding", sharding={replicated}
  %constant.2 = f32[] constant(0)
  %constant.1 = f32[] constant(1)
  %multiply.3 = f32[] multiply(f32[] %constant.2, f32[] %constant.1)
  %broadcast.14 = f32[1,128]{1,0} broadcast(f32[] %multiply.3), dimensions={}
  %add.15 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.13, f32[1,128]{1,0} %broadcast.14)
  ROOT %tuple.16 = (f32[1,128]{1,0}) tuple(f32[1,128]{1,0} %add.15)
}

To

ENTRY %IrToHlo.14 (p0.5: f32[1,128], p1.8: f32[1,128]) -> (f32[1,128]) {
  %p1.8 = f32[1,128]{1,0} parameter(1), sharding={replicated}
  %p0.5 = f32[1,128]{1,0} parameter(0), sharding={replicated}
  %constant.4 = f32[] constant(1)
  %broadcast.6 = f32[1,128]{1,0} broadcast(f32[] %constant.4), dimensions={}
  %multiply.7 = f32[1,128]{1,0} multiply(f32[1,128]{1,0} %p0.5, f32[1,128]{1,0} %broadcast.6)
  %add.9 = f32[1,128]{1,0} add(f32[1,128]{1,0} %p1.8, f32[1,128]{1,0} %multiply.7)
  %custom-call.10 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.9), custom_call_target="Sharding", sharding={replicated}
  %constant.2 = f32[] constant(0)
  %constant.1 = f32[] constant(1)
  %multiply.3 = f32[] multiply(f32[] %constant.2, f32[] %constant.1)
  %broadcast.11 = f32[1,128]{1,0} broadcast(f32[] %multiply.3), dimensions={}
  %add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.10, f32[1,128]{1,0} %broadcast.11)
  ROOT %tuple.13 = (f32[1,128]{1,0}) tuple(f32[1,128]{1,0} %add.12)
}
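
For anyone who wants to reproduce such a dump, a hedged sketch (the computation mirrors the multiply-by-one-then-add pattern above; torch_xla._XLAC._get_xla_tensors_hlo is the debug hook commonly used in pytorch/xla tests to print HLO text, but the SPMD mesh setup of the real test is omitted here):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
p0 = torch.rand(1, 128, device=device)
p1 = torch.rand(1, 128, device=device)
# y + x * 1: the scalar constant 1 is what the old lowering reshaped and
# broadcast in four steps and the new lowering broadcasts in a single step.
out = p1 + p0 * 1.0
print(torch_xla._XLAC._get_xla_tensors_hlo([out]))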

@lsy323 lsy323 merged commit 8d91ff5 into master Feb 9, 2024
@alanwaketan (Collaborator) commented

> HLO changed for spmd basic sharding test test_mark_sharding_ir. The graph becomes more concise after the lowering for add/mul is updated. [...]

Thanks for the heads up, and it LGTM.

amithrm pushed a commit to amithrm/xla that referenced this pull request Mar 1, 2024
@lsy323 lsy323 deleted the lsiyuan/aten-dynamism-test branch March 4, 2024 19:13
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
@lsy323 lsy323 changed the title from "add unbounded dynamism test for some aten ops, support dynamism on add" to "support dynamism on add, mul" on Aug 30, 2024
@miladm miladm added the dynamism (Dynamic Shape Features) label on Sep 17, 2024