
[benchmarks] Set autocast kwargs only if AMP.#6612

Merged
ysiraichi merged 1 commit into master from ysiraichi/fix-benchmark-amp
Feb 26, 2024

Conversation

@ysiraichi
Collaborator

This PR sets the dtype kwarg only if AMP is supposed to be used. Previously, the dtype kwarg was always passed, which made nullcontext complain about the unexpected keyword. This would crash every inference benchmark.

cc @miladm
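The fix can be sketched as follows. This is a minimal, torch-free illustration of the pattern, not the benchmark runner's actual code; `autocast_context` is a hypothetical helper name:

```python
from contextlib import nullcontext

def autocast_context(use_amp, autocast_cls, **amp_kwargs):
    # Return a context-manager class plus the kwargs to call it with.
    # AMP kwargs (e.g. dtype) are attached only when AMP is enabled;
    # nullcontext accepts no keyword arguments at all.
    if use_amp:
        return autocast_cls, amp_kwargs
    return nullcontext, {}

# Non-AMP path: the dtype kwarg is dropped, so nullcontext() is happy.
ctx, kwargs = autocast_context(False, None, dtype="bfloat16")
with ctx(**kwargs):
    pass
```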

@frgossen
Collaborator

Can you give a bit more detail on why this is needed?

@ysiraichi
Collaborator Author

ysiraichi commented Feb 26, 2024

Otherwise, we are calling nullcontext(dtype=...), here:

    def train(self, inputs, collect_full_output=False):
        self._optimizer_zero_grad()
        with self.autocast(**self.autocast_kwargs):
            pred = self.module(*inputs)
            loss = self.compute_loss(pred)

Which ends up raising an error.
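Concretely, `contextlib.nullcontext` accepts at most a single positional `enter_result`, so any keyword argument such as `dtype` raises a `TypeError`. A standalone reproduction:

```python
from contextlib import nullcontext

try:
    # Same call shape as passing autocast kwargs to nullcontext.
    nullcontext(dtype="bfloat16")
except TypeError as e:
    # e.g. "... got an unexpected keyword argument 'dtype'"
    error_message = str(e)
```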

Collaborator

@frgossen frgossen left a comment


Thanks!

@ysiraichi ysiraichi merged commit f65bae4 into master Feb 26, 2024
@vanbasten23
Collaborator

> Otherwise, we are calling nullcontext(dtype=...), here:
>
>     def train(self, inputs, collect_full_output=False):
>         self._optimizer_zero_grad()
>         with self.autocast(**self.autocast_kwargs):
>             pred = self.module(*inputs)
>             loss = self.compute_loss(pred)
>
> Which ends up raising an error.

Not sure if I follow. With the change, for the AMP case, we'll call nullcontext(dtype=...) with non-empty kwargs; wouldn't nullcontext complain too?

@ysiraichi
Collaborator Author

As far as I understand it, for the AMP case, self.autocast will be either torch.amp.autocast or torch.cuda.autocast, not nullcontext. Am I missing something?
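In other words, in the AMP case the context manager does accept a dtype kwarg, so the non-empty kwargs are fine there. A stand-in class illustrates the asymmetry (`FakeAutocast` is hypothetical, mimicking torch.amp.autocast's keyword interface):

```python
from contextlib import nullcontext

class FakeAutocast:
    """Stand-in mimicking torch.amp.autocast's keyword interface."""
    def __init__(self, device_type="cuda", dtype=None):
        self.dtype = dtype
    def __enter__(self):
        return self
    def __exit__(self, *exc):
        return False

# AMP path: the dtype kwarg is accepted by the autocast-like class.
with FakeAutocast(dtype="bfloat16") as ctx:
    amp_ok = ctx.dtype == "bfloat16"

# Non-AMP path: nullcontext tolerates only empty kwargs.
with nullcontext(**{}):
    non_amp_ok = True
```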

amithrm pushed a commit to amithrm/xla that referenced this pull request Mar 1, 2024
