
ENH: ml_dtypes.bfloat16 support#9494

Merged
leofang merged 37 commits into cupy:main from seberg:wip-ml_dtypes
Feb 2, 2026

Conversation

@seberg
Member

@seberg seberg commented Nov 21, 2025

Add optional support for bfloat16 via ml_dtypes. The mechanism for this is:

  • We try to import the dtype from ml_dtypes and, if that succeeds, add it to our C type maps. (The corresponding scalars are already supported via an earlier refactor.)
  • We add a new header and type. This header is only included if bfloat16 is used in any of the kernel arguments.
  • Add a *bf16_loop() helper, which expands to the loop if ml_dtypes can be found; otherwise the loop is omitted. (This should be fine, in part because there will always be an "e" loop in front, so we don't have to worry much about weirder promotion corner cases.)
  • Attempt to add tests in CI (not sure I got this right).

I think this should be considered experimental, in the sense that corner cases/differences probably exist and it may not always be obvious what the computation type should be for bfloat16 input (e.g. half input explicitly uses float32 for computation, and the equivalent is missing for bfloat16 for now).

@seberg seberg added the cat:enhancement Improvements to existing features label Nov 21, 2025
@asi1024 asi1024 added cat:feature New features/APIs prio:high and removed cat:enhancement Improvements to existing features labels Nov 28, 2025
@leofang
Member

leofang commented Dec 22, 2025

FYI #9503 is merged

@leofang leofang added this to the v14 milestone Dec 22, 2025
@mergify

This comment was marked as outdated.

@leofang leofang modified the milestones: v14, v14.0.0, v15 Dec 24, 2025
@leofang leofang added the to-be-backported Pull-requests to be backported to stable branch label Dec 29, 2025
@seberg

This comment was marked as outdated.

@seberg seberg changed the title PoC/WIP: ml_dtypes (bfloat16) support ENH: ml_dtypes (bfloat16) support Jan 14, 2026
@seberg
Member Author

seberg commented Jan 14, 2026

This is now in a state where a look makes sense. Not small anymore, but also not that big.

One larger change is that there is now a bfloat16.cuh and it is only included if bfloat16 is involved in the kernel. That is, rather than _get_typename() to get the C type, we now need _get_typename_and_preamble() so we also know whether we need to include bfloat16.cuh.
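The typename-plus-preamble idea could be sketched like this; the function names, table, and header path are assumptions for illustration, not CuPy's actual code.

```python
# Hypothetical sketch of _get_typename_and_preamble(): alongside the C type
# name, return any header that must be prepended to the kernel source.
_BFLOAT16_PREAMBLE = '#include "cupy/bfloat16.cuh"'

_TYPE_TABLE = {
    "float32": ("float", ""),
    "float64": ("double", ""),
    "bfloat16": ("bfloat16", _BFLOAT16_PREAMBLE),
}

def get_typename_and_preamble(dtype_name):
    return _TYPE_TABLE[dtype_name]

def build_source(arg_dtypes, body):
    # Collect unique preambles only for the types actually used, so
    # bfloat16.cuh is pulled in only when a bfloat16 argument is present.
    preambles = []
    for d in arg_dtypes:
        _, pre = get_typename_and_preamble(d)
        if pre and pre not in preambles:
            preambles.append(pre)
    return "\n".join(preambles + [body])
```

This keeps kernels that never touch bfloat16 byte-identical to before, which also avoids invalidating their compiled-kernel caches.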

Some open points/homework for me:

  • New headers need some thought (I don't like the way float16 <-> bfloat16 casts work, but it may be OK)
  • A few small bugs (some matmul code fails).
  • Need to fix fusion support (if not too complicated), including special value support (e.g. inf).
  • It may be nice to see if we can avoid rebuilding caches too much.

EDIT: I won't claim there won't be holes, but I think I covered all of those points 🎉. The one thing I would probably still do is rename param_preambles to type_headers, because it is just much clearer (and the name preambles is used elsewhere).

@seberg seberg changed the title ENH: ml_dtypes (bfloat16) support ENH: ml_dtypes.bfloat16 support Jan 15, 2026
@seberg seberg marked this pull request as ready for review January 15, 2026 19:14
@seberg seberg requested a review from a team as a code owner January 15, 2026 19:14
@seberg
Member Author

seberg commented Jan 16, 2026

/test mini

leofang
leofang previously approved these changes Jan 17, 2026
Member

@leofang leofang left a comment


Looks great! Awesome progress, Sebastian!

Member


This can wait, but it'd be nice later to add a no_bfloat16=False option to for_all_dtypes, for_float_dtypes, ... in cupy/testing/_loops.py so that bfloat16 gets auto-tested
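The suggested flag could look roughly like this; only the parameter name no_bfloat16 comes from the comment above, the decorator body is a hypothetical sketch of how such dtype-looping test decorators typically work.

```python
# Hedged sketch of a dtype-looping test decorator with an opt-in bfloat16
# flag (defaulting to no_bfloat16=True so existing tests are unaffected).
import functools

def for_float_dtypes(name="dtype", no_float16=False, no_bfloat16=True):
    dtypes = ["float64", "float32"]
    if not no_float16:
        dtypes.append("float16")
    if not no_bfloat16:
        # Opt-in: tests pass no_bfloat16=False to also run under bfloat16.
        dtypes.append("bfloat16")

    def decorator(impl):
        @functools.wraps(impl)
        def wrapper(*args, **kwargs):
            for dtype in dtypes:
                impl(*args, **dict(kwargs, **{name: dtype}))
        return wrapper
    return decorator
```

Defaulting the flag to True means coverage grows test by test, without forcing every existing float test to handle bfloat16's looser precision at once.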

@leofang
Member

leofang commented Jan 17, 2026

/test mini

@seberg
Member Author

seberg commented Jan 17, 2026

One other thing to confirm here: I am not yet sure the bfloat16 tests are running in any of the CI jobs.

('F->f', 'out0 = arg(in0) * (180.0 / M_PI)'),
('D->d', 'out0 = arg(in0) * (180.0 / M_PI)')),
'out0 = in0 >= 0 ? 0 : 180.0',
'out0 = in0 >= decltype(in0){0} ? 0 : 180.0',
Member Author


Hmmm, I need to undo (or change this), but it needs some header fix as well (it thinks the __half comparison may also match).

FWIW, I will also check the other CI jobs testing bfloat16, though; I suspect I added ml_dtypes to the test matrix but may still have to make sure it actually gets installed.

Member Author

@seberg seberg Jan 19, 2026


In the end, I made this two lines, but really just used decltype(+in0){0}. The + seemed like the easiest way to fix the fusion, which uses e.g. double &in0, and we need to get rid of the &.

I couldn't think of a way to keep this code unchanged that would avoid adding the operator to the header, and I didn't want to add a full family of them...

(As an aside: the type_in0_raw mechanism isn't compatible with fusion; maybe fusion could also define it. Although I do like the decltype, at least if it weren't for the +in0 hack.)
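The decltype(+in0){0} trick can be illustrated in plain C++ (kernel context simplified; sign_to_degrees is a made-up name). Fused kernels bind arguments as references such as double &in0; unary + forms a prvalue, so decltype(+in0) drops the reference and yields a plain value type whose zero can be brace-initialized.

```cpp
// Sketch of the decltype(+in0){0} zero-literal trick discussed above.
#include <cassert>
#include <type_traits>

template <typename T>
double sign_to_degrees(T &in0) {
    // decltype(in0) is T& here; decltype(+in0) is the non-reference type
    // (integer-promoted for small integer types), so {0} is a plain zero
    // of the value type and the comparison picks the intended overload.
    static_assert(!std::is_reference<decltype(+in0)>::value,
                  "unary + must yield a prvalue");
    return in0 >= decltype(+in0){0} ? 0.0 : 180.0;
}
```

The appeal over decltype(in0){0} is exactly that reference stripping: brace-initializing a reference type from 0 would not compile, while the +-promoted prvalue type always has a well-defined zero.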

@seberg
Member Author

seberg commented Jan 28, 2026

Want to see if this works (or if I might have to adapt the minimum version in the tests).

/test cuda122,linux

@leofang
Member

leofang commented Jan 28, 2026

/test cuda122,linux

1 similar comment
@leofang
Member

leofang commented Jan 30, 2026

/test cuda122,linux

@leofang
Member

leofang commented Jan 30, 2026

/test mini

1 similar comment
@seberg
Member Author

seberg commented Jan 31, 2026

/test mini

@leofang
Member

leofang commented Feb 1, 2026

/test windows,cuda120

@leofang
Member

leofang commented Feb 1, 2026

/test windows,cuda129

}
return float16(ret_raw_);
}
#endif // #ifdef __HIPCC_RTC__
Member


Comment only, no need to address in this PR: Shouldn't this be __HIPCC__?

@leofang
Member

leofang commented Feb 2, 2026

/test force-skip

@github-actions

github-actions bot commented Feb 2, 2026

The following tests were force-skipped:

@leofang leofang merged commit 513b954 into cupy:main Feb 2, 2026
57 of 58 checks passed
chainer-ci pushed a commit to chainer-ci/cupy that referenced this pull request Feb 2, 2026
ENH: `ml_dtypes.bfloat16` support
@asi1024
Member

asi1024 commented Feb 2, 2026

LGTM! Thanks!

@leofang leofang linked an issue Feb 2, 2026 that may be closed by this pull request
@leofang leofang mentioned this pull request Feb 2, 2026
@seberg seberg deleted the wip-ml_dtypes branch February 2, 2026 07:29
@leofang leofang modified the milestones: v15, v15.0.0a1 Mar 9, 2026

Labels

blocking (Issue/pull-request is mandatory for the upcoming release), cat:feature (New features/APIs), prio:high, to-be-backported (Pull-requests to be backported to stable branch)


Development

Successfully merging this pull request may close these issues.

Support for bfloat16
