Skip to content

Add an option for JittableModule to dedup parameters.#8965

Merged
qihqi merged 5 commits intomasterfrom
hanq_handle_duped_param
Apr 16, 2025
Merged

Add an option for JittableModule to dedup parameters.#8965
qihqi merged 5 commits intomasterfrom
hanq_handle_duped_param

Conversation

@qihqi
Copy link
Copy Markdown
Collaborator

@qihqi qihqi commented Apr 11, 2025

Add an option for JittableModule to dedup parameters.

This is needed when we want to use the donate_argnums feature of jax_jit. Because, if duplicated parameters are passed into the callable with donated argnums, then, it will show an argument is donated twice and error out.

@qihqi qihqi force-pushed the hanq_handle_duped_param branch 2 times, most recently from 036d540 to 2f8de95 Compare April 11, 2025 03:43
@qihqi qihqi marked this pull request as ready for review April 11, 2025 03:46
@qihqi qihqi requested review from tengyifei and yaoshiang April 11, 2025 03:46
Copy link
Copy Markdown
Collaborator

@yaoshiang yaoshiang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool stuff here, nice to see the donate args in action.

Comment thread torchax/test/test_interop.py
Comment thread torchax/test/test_interop.py Outdated
Comment thread torchax/test/test_interop.py
Comment thread torchax/test/test_interop.py Outdated
Comment thread torchax/test/test_interop.py
Comment thread torchax/test/test_interop.py
Comment thread torchax/torchax/interop.py
Comment thread torchax/torchax/ops/jaten.py
@qihqi qihqi changed the title Update to pytorch 2.6 Add an option for JittableModule to dedup parameters. Apr 11, 2025
@qihqi qihqi force-pushed the hanq_handle_duped_param branch from dfc4ca3 to 811eeef Compare April 11, 2025 22:54
@qihqi qihqi requested a review from yaoshiang April 11, 2025 22:54
Copy link
Copy Markdown
Collaborator

@yaoshiang yaoshiang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

approved, with optional nits and optional comment on design.

Comment thread torchax/torchax/ops/jaten.py
Comment thread torchax/test/test_interop.py Outdated
Comment thread torchax/test/test_interop.py Outdated
Comment thread torchax/test/test_interop.py
@qihqi qihqi force-pushed the hanq_handle_duped_param branch 2 times, most recently from 713c10b to af84080 Compare April 15, 2025 19:05
@qihqi qihqi force-pushed the hanq_handle_duped_param branch from af84080 to 25e81e8 Compare April 16, 2025 17:54
@qihqi qihqi merged commit cce7cc0 into master Apr 16, 2025
24 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants