
xb.call_jax always assumes the traced function makes use of all arguments #8794

@tengyifei

Description


If the JAX function doesn't use one of its arguments (common in gradient calculations), JAX may drop that argument when lowering. The number of HLO parameters then no longer matches the number of function arguments.
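A minimal reproduction sketch, assuming a recent JAX where `jax.jit(...).lower(...).as_text()` returns the lowered module as text. With the default `keep_unused=False`, JAX may prune an unused argument, so `@main` is expected to end up with fewer parameters than `f` has arguments. Passing `keep_unused=True` asks JAX to keep every argument. The helper `count_main_params` is hypothetical: it only parses the text signature and is not part of JAX or torch_xla.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # `y` is never used, as with a gradient that does not depend on one input.
    return x * 2.0


x = jnp.ones((3,), dtype=jnp.float32)
y = jnp.ones((3,), dtype=jnp.float32)

# Default jit (keep_unused=False): JAX may drop `y` from the lowered module.
pruned_text = jax.jit(f).lower(x, y).as_text()

# keep_unused=True keeps every argument as a module parameter.
kept_text = jax.jit(f, keep_unused=True).lower(x, y).as_text()


def count_main_params(module_text):
    # Hypothetical helper: count `%argN` entries in the `@main` signature line.
    header = next(
        line
        for line in module_text.splitlines()
        if "func.func" in line and "@main(" in line
    )
    return header.count("%arg")


print(count_main_params(pruned_text), count_main_params(kept_text))
```

Forcing `keep_unused=True` would restore a one-to-one mapping between arguments and parameters, at the cost of keeping parameters the computation never reads.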

We need a way to get JAX to tell us which input argument maps to which HLO parameter.
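One possible way to recover that mapping outside JAX, sketched under the assumption that JAX's pruning matches a plain backward liveness pass over the traced jaxpr: trace with `jax.make_jaxpr`, mark which flattened inputs can reach an output, and number the survivors in order. `hlo_param_mapping` is a hypothetical helper, not a JAX or torch_xla API. It ignores side effects and closed-over constants, so its result should be checked against the lowered module rather than trusted blindly.

```python
import jax
import jax.numpy as jnp


def hlo_param_mapping(fn, *args):
    """Hypothetical helper: map flattened argument index -> HLO parameter index.

    Assumes unused inputs are pruned by a backward liveness pass, which
    approximates (but is not guaranteed to match) JAX's own pruning.
    """
    jaxpr = jax.make_jaxpr(fn)(*args).jaxpr
    # Track live variables by id() so Literal operands need not be hashable.
    live = {id(v) for v in jaxpr.outvars}
    for eqn in reversed(jaxpr.eqns):
        # An equation matters only if one of its results is live.
        # Effectful equations with no live outputs are treated as dead here.
        if any(id(out) in live for out in eqn.outvars):
            live.update(id(v) for v in eqn.invars)
    used = [i for i, v in enumerate(jaxpr.invars) if id(v) in live]
    # Surviving inputs keep their relative order as HLO parameters.
    return {arg_idx: param_idx for param_idx, arg_idx in enumerate(used)}


def f(x, y, z):
    # `y` is unused, so it should not get an HLO parameter.
    return x + z


a = jnp.ones((2,), dtype=jnp.float32)
mapping = hlo_param_mapping(f, a, a, a)
print(mapping)  # {0: 0, 2: 1}
```

Here argument 0 becomes HLO parameter 0, argument 2 becomes HLO parameter 1, and argument 1 has no parameter, so a caller would skip it when building the HLO call.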


Labels: enhancement (New feature or request)
