
Enable cross entropy loss for xla autocast with FP32 precision#7992

Merged
JackCaoG merged 3 commits into r2.1_aws_neuron from autocast_bf16_neuron on Sep 26, 2024

Conversation

@avizon-aws
Collaborator

Many operators are commented out in the XLA autocast policy even though they are cast on GPU; to stay consistent with GPU autocast, we need to support these operators on XLA as well. cross_entropy_loss is currently commented out in the XLA autocast policy, so no casting occurs for it and it executes in whatever dtype its inputs have.

The output dtype is BF16, which is expected because the linear layer is listed in the XLA autocast policy. The loss dtype is FP32, which looks correct, but there is a catch: no autocasting was actually applied to CrossEntropyLoss. The loss is FP32 only because the target's dtype is FP32. CrossEntropyLoss multiplies the log-probabilities by the target; all of the exponentiation/log work runs in BF16, and only the final multiplication yields FP32, because mixed-dtype ops promote to the higher precision (FP32). This is not the expected behavior: all of the CrossEntropyLoss ops (exponentiation, logs, etc.) should execute in FP32. They do not, because cross_entropy_loss is not listed in the XLA autocast policy. This finding is based on a detailed analysis of the HLO outputs attached below.
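Independent of XLA, a minimal sketch of the dtype promotion described above: multiplying a BF16 tensor by an FP32 tensor promotes the result to FP32, which is why the loss came out FP32 even though the loss math actually ran in BF16 (tensor names here are illustrative, not from the repro):

```python
import torch

# Stand-in for the BF16 log-softmax output produced inside CrossEntropyLoss
log_probs_bf16 = torch.randn(16, 10, dtype=torch.bfloat16)
# FP32 target, as in Exp1
target_fp32 = torch.randn(16, 10)

# The final multiply inside the loss: BF16 x FP32 promotes to FP32
product = log_probs_bf16 * target_fp32
print(product.dtype)  # torch.float32
```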

Before this change:
Exp1.

import torch
import torch_xla  # registers the 'xla' device type

device = 'xla' # the XLA device (e.g., TPU or a Neuron device)
model = torch.nn.Linear(10, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
data = torch.randn(16, 10).to(torch.bfloat16).to(device)
target = torch.randn(16, 10).to(device)
print(device, torch.__version__)
for epoch in range(1):
    optimizer.zero_grad()
    # debugpy.breakpoint() 
    with torch.autocast('xla'):
        output = model(data)
        loss = torch.nn.CrossEntropyLoss()(output, target)
        print(output.dtype, loss.dtype, target.dtype)
    # loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item()}")

HLO:

ENTRY %SyncTensorsGraph.62 (p0.1: f32[], p1.2: f32[16,10], p2.3: f32[10], p3.12: f32[10,10], p4.22: f32[16,10]) -> (f32[]) {
  %p4.22 = f32[16,10]{1,0} parameter(4), frontend_attributes={neff_input_names="input4"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %convert.23 = bf16[16,10]{1,0} convert(f32[16,10]{1,0} %p4.22), metadata={op_type="xla__cast" op_name="xla__cast"}
  %p3.12 = f32[10,10]{1,0} parameter(3), frontend_attributes={neff_input_names="input3"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %custom-call.2 = f32[10,10]{1,0} custom-call(f32[10,10]{1,0} %p3.12), custom_call_target="AwsNeuronTransferWithStaticRing", api_version=API_VERSION_UNSPECIFIED, metadata={op_type="xla___op_TransferWithStaticRingTransfer" op_name="xla___op_TransferWithStaticRingTransfer"}
  %convert.20 = bf16[10,10]{1,0} convert(f32[10,10]{1,0} %custom-call.2), metadata={op_type="xla__cast" op_name="xla__cast"}
  %transpose.21 = bf16[10,10]{0,1} transpose(bf16[10,10]{1,0} %convert.20), dimensions={1,0}, metadata={op_type="aten__permute" op_name="aten__permute"}
  %dot.24 = bf16[16,10]{1,0} dot(bf16[16,10]{1,0} %convert.23, bf16[10,10]{0,1} %transpose.21), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %p2.3 = f32[10]{0} parameter(2), frontend_attributes={neff_input_names="input2"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %custom-call.3 = f32[10]{0} custom-call(f32[10]{0} %p2.3), custom_call_target="AwsNeuronTransferWithStaticRing", api_version=API_VERSION_UNSPECIFIED, metadata={op_type="xla___op_TransferWithStaticRingTransfer" op_name="xla___op_TransferWithStaticRingTransfer"}
  %convert.11 = bf16[10]{0} convert(f32[10]{0} %custom-call.3), metadata={op_type="xla__cast" op_name="xla__cast"}
  %broadcast.28 = bf16[16,10]{1,0} broadcast(bf16[10]{0} %convert.11), dimensions={1}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %add.29 = bf16[16,10]{1,0} add(bf16[16,10]{1,0} %dot.24, bf16[16,10]{1,0} %broadcast.28), metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %constant.32 = bf16[] constant(-inf), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.37 = bf16[16]{0} reduce(bf16[16,10]{1,0} %add.29, bf16[] %constant.32), dimensions={1}, to_apply=%MaxComputation.33, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.38 = bf16[16,10]{1,0} broadcast(bf16[16]{0} %reduce.37), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.39 = bf16[16,10]{1,0} subtract(bf16[16,10]{1,0} %add.29, bf16[16,10]{1,0} %broadcast.38), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %exponential.40 = bf16[16,10]{1,0} exponential(bf16[16,10]{1,0} %subtract.39), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %constant.41 = bf16[] constant(0), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.46 = bf16[16]{0} reduce(bf16[16,10]{1,0} %exponential.40, bf16[] %constant.41), dimensions={1}, to_apply=%AddComputation.42, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %log.47 = bf16[16]{0} log(bf16[16]{0} %reduce.46), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.48 = bf16[16,10]{1,0} broadcast(bf16[16]{0} %log.47), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.49 = bf16[16,10]{1,0} subtract(bf16[16,10]{1,0} %subtract.39, bf16[16,10]{1,0} %broadcast.48), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %convert.50 = f32[16,10]{1,0} convert(bf16[16,10]{1,0} %subtract.49), metadata={op_type="aten__mul" op_name="aten__mul"}
  %p1.2 = f32[16,10]{1,0} parameter(1), frontend_attributes={neff_input_names="input1"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %multiply.51 = f32[16,10]{1,0} multiply(f32[16,10]{1,0} %convert.50, f32[16,10]{1,0} %p1.2), metadata={op_type="aten__mul" op_name="aten__mul"}
  %constant.52 = f32[] constant(0), metadata={op_type="aten__sum" op_name="aten__sum"}
  %reduce.58 = f32[] reduce(f32[16,10]{1,0} %multiply.51, f32[] %constant.52), dimensions={0,1}, to_apply=%AddComputation.54, metadata={op_type="aten__sum" op_name="aten__sum"}
  %negate.59 = f32[] negate(f32[] %reduce.58), metadata={op_type="aten__neg" op_name="aten__neg"}
  %p0.1 = f32[] parameter(0), frontend_attributes={neff_input_names="input0"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %divide.60 = f32[] divide(f32[] %negate.59, f32[] %p0.1), metadata={op_type="aten__div" op_name="aten__div"}
  ROOT %tuple.61 = (f32[]) tuple(f32[] %divide.60), frontend_attributes={neff_output_names="output0"}
}

Exp2.
The target dtype is also BF16 in this case. This experiment was done to show that the target's dtype, not autocast, determined the loss dtype in Exp1: with a BF16 target, the loss becomes BF16, as the HLO below shows.

Code change from previous experiment:
target = torch.randn(16, 10).to(torch.bfloat16).to(device)
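A quick sketch (no XLA needed) of the promotion rule this experiment isolates: when both operands of the final multiply are BF16, the result stays BF16, which is why the whole loss collapses to BF16 here (tensor names are illustrative):

```python
import torch

# Stand-in for the BF16 log-softmax output
log_probs_bf16 = torch.randn(16, 10, dtype=torch.bfloat16)
# Target is now BF16 as well, as in this experiment
target_bf16 = torch.randn(16, 10, dtype=torch.bfloat16)

# BF16 x BF16: no promotion, the result stays BF16
product = log_probs_bf16 * target_bf16
print(product.dtype)  # torch.bfloat16
```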

HLO

ENTRY %SyncTensorsGraph.61 (p0.1: bf16[], p1.2: bf16[16,10], p2.3: f32[10], p3.12: f32[10,10], p4.22: f32[16,10]) -> (bf16[]) {
  %p4.22 = f32[16,10]{1,0} parameter(4), frontend_attributes={neff_input_names="input4"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %convert.23 = bf16[16,10]{1,0} convert(f32[16,10]{1,0} %p4.22), metadata={op_type="xla__cast" op_name="xla__cast"}
  %p3.12 = f32[10,10]{1,0} parameter(3), frontend_attributes={neff_input_names="input3"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %custom-call.2 = f32[10,10]{1,0} custom-call(f32[10,10]{1,0} %p3.12), custom_call_target="AwsNeuronTransferWithStaticRing", api_version=API_VERSION_UNSPECIFIED, metadata={op_type="xla___op_TransferWithStaticRingTransfer" op_name="xla___op_TransferWithStaticRingTransfer"}
  %convert.20 = bf16[10,10]{1,0} convert(f32[10,10]{1,0} %custom-call.2), metadata={op_type="xla__cast" op_name="xla__cast"}
  %transpose.21 = bf16[10,10]{0,1} transpose(bf16[10,10]{1,0} %convert.20), dimensions={1,0}, metadata={op_type="aten__permute" op_name="aten__permute"}
  %dot.24 = bf16[16,10]{1,0} dot(bf16[16,10]{1,0} %convert.23, bf16[10,10]{0,1} %transpose.21), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %p2.3 = f32[10]{0} parameter(2), frontend_attributes={neff_input_names="input2"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %custom-call.3 = f32[10]{0} custom-call(f32[10]{0} %p2.3), custom_call_target="AwsNeuronTransferWithStaticRing", api_version=API_VERSION_UNSPECIFIED, metadata={op_type="xla___op_TransferWithStaticRingTransfer" op_name="xla___op_TransferWithStaticRingTransfer"}
  %convert.11 = bf16[10]{0} convert(f32[10]{0} %custom-call.3), metadata={op_type="xla__cast" op_name="xla__cast"}
  %broadcast.28 = bf16[16,10]{1,0} broadcast(bf16[10]{0} %convert.11), dimensions={1}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %add.29 = bf16[16,10]{1,0} add(bf16[16,10]{1,0} %dot.24, bf16[16,10]{1,0} %broadcast.28), metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %constant.32 = bf16[] constant(-inf), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.37 = bf16[16]{0} reduce(bf16[16,10]{1,0} %add.29, bf16[] %constant.32), dimensions={1}, to_apply=%MaxComputation.33, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.38 = bf16[16,10]{1,0} broadcast(bf16[16]{0} %reduce.37), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.39 = bf16[16,10]{1,0} subtract(bf16[16,10]{1,0} %add.29, bf16[16,10]{1,0} %broadcast.38), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %exponential.40 = bf16[16,10]{1,0} exponential(bf16[16,10]{1,0} %subtract.39), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %constant.41 = bf16[] constant(0), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.46 = bf16[16]{0} reduce(bf16[16,10]{1,0} %exponential.40, bf16[] %constant.41), dimensions={1}, to_apply=%AddComputation.42, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %log.47 = bf16[16]{0} log(bf16[16]{0} %reduce.46), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.48 = bf16[16,10]{1,0} broadcast(bf16[16]{0} %log.47), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.49 = bf16[16,10]{1,0} subtract(bf16[16,10]{1,0} %subtract.39, bf16[16,10]{1,0} %broadcast.48), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %p1.2 = bf16[16,10]{1,0} parameter(1), frontend_attributes={neff_input_names="input1"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %multiply.50 = bf16[16,10]{1,0} multiply(bf16[16,10]{1,0} %subtract.49, bf16[16,10]{1,0} %p1.2), metadata={op_type="aten__mul" op_name="aten__mul"}
  %constant.51 = bf16[] constant(0), metadata={op_type="aten__sum" op_name="aten__sum"}
  %reduce.57 = bf16[] reduce(bf16[16,10]{1,0} %multiply.50, bf16[] %constant.51), dimensions={0,1}, to_apply=%AddComputation.53, metadata={op_type="aten__sum" op_name="aten__sum"}
  %negate.58 = bf16[] negate(bf16[] %reduce.57), metadata={op_type="aten__neg" op_name="aten__neg"}
  %p0.1 = bf16[] parameter(0), frontend_attributes={neff_input_names="input0"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %divide.59 = bf16[] divide(bf16[] %negate.58, bf16[] %p0.1), metadata={op_type="aten__div" op_name="aten__div"}
  ROOT %tuple.60 = (bf16[]) tuple(bf16[] %divide.59), frontend_attributes={neff_output_names="output0"}
}

After uncommenting CrossEntropyLoss in the XLA autocast policy, as done in this PR:

Exp3:
The input and target are FP32, so the output of the linear layer is BF16, and it is then upcast to FP32 for CrossEntropyLoss, as seen in the HLO.

  %p4.8 = f32[16,10]{1,0} parameter(4), frontend_attributes={neff_input_names="input4"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %convert.9 = bf16[16,10]{1,0} convert(f32[16,10]{1,0} %p4.8), metadata={op_type="xla__cast" op_name="xla__cast"}
  %p3.5 = f32[10,10]{1,0} parameter(3), frontend_attributes={neff_input_names="input3"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %convert.6 = bf16[10,10]{1,0} convert(f32[10,10]{1,0} %p3.5), metadata={op_type="xla__cast" op_name="xla__cast"}
  %transpose.7 = bf16[10,10]{0,1} transpose(bf16[10,10]{1,0} %convert.6), dimensions={1,0}, metadata={op_type="aten__permute" op_name="aten__permute"}
  %dot.10 = bf16[16,10]{1,0} dot(bf16[16,10]{1,0} %convert.9, bf16[10,10]{0,1} %transpose.7), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %p2.3 = f32[10]{0} parameter(2), frontend_attributes={neff_input_names="input2"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %convert.4 = bf16[10]{0} convert(f32[10]{0} %p2.3), metadata={op_type="xla__cast" op_name="xla__cast"}
  %broadcast.14 = bf16[16,10]{1,0} broadcast(bf16[10]{0} %convert.4), dimensions={1}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %add.15 = bf16[16,10]{1,0} add(bf16[16,10]{1,0} %dot.10, bf16[16,10]{1,0} %broadcast.14), metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %convert.16 = f32[16,10]{1,0} convert(bf16[16,10]{1,0} %add.15), metadata={op_type="xla__cast" op_name="xla__cast"}
  %constant.17 = f32[] constant(-inf), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.22 = f32[16]{0} reduce(f32[16,10]{1,0} %convert.16, f32[] %constant.17), dimensions={1}, to_apply=%MaxComputation.18, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.23 = f32[16,10]{1,0} broadcast(f32[16]{0} %reduce.22), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.24 = f32[16,10]{1,0} subtract(f32[16,10]{1,0} %convert.16, f32[16,10]{1,0} %broadcast.23), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %exponential.25 = f32[16,10]{1,0} exponential(f32[16,10]{1,0} %subtract.24), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %constant.26 = f32[] constant(0), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.31 = f32[16]{0} reduce(f32[16,10]{1,0} %exponential.25, f32[] %constant.26), dimensions={1}, to_apply=%AddComputation.27, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %log.32 = f32[16]{0} log(f32[16]{0} %reduce.31), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.33 = f32[16,10]{1,0} broadcast(f32[16]{0} %log.32), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.34 = f32[16,10]{1,0} subtract(f32[16,10]{1,0} %subtract.24, f32[16,10]{1,0} %broadcast.33), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %p1.2 = f32[16,10]{1,0} parameter(1), frontend_attributes={neff_input_names="input1"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %multiply.35 = f32[16,10]{1,0} multiply(f32[16,10]{1,0} %subtract.34, f32[16,10]{1,0} %p1.2), metadata={op_type="aten__mul" op_name="aten__mul"}
  %reduce.42 = f32[] reduce(f32[16,10]{1,0} %multiply.35, f32[] %constant.26), dimensions={0,1}, to_apply=%AddComputation.38, metadata={op_type="aten__sum" op_name="aten__sum"}
  %negate.43 = f32[] negate(f32[] %reduce.42), metadata={op_type="aten__neg" op_name="aten__neg"}
  %p0.1 = f32[] parameter(0), frontend_attributes={neff_input_names="input0"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %divide.44 = f32[] divide(f32[] %negate.43, f32[] %p0.1), metadata={op_type="aten__div" op_name="aten__div"}
  ROOT %tuple.45 = (f32[]) tuple(f32[] %divide.44), frontend_attributes={neff_output_names="output0"}
}

The operators now execute in the expected precisions.
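Independent of the autocast machinery, a minimal sketch of the dtype behavior the PR enforces (names are illustrative): the BF16 linear output is upcast to FP32 before the loss, so every CrossEntropyLoss op runs in FP32:

```python
import torch
import torch.nn.functional as F

# Stand-in for the BF16 linear-layer output under autocast
logits_bf16 = torch.randn(16, 10, dtype=torch.bfloat16)
# FP32 soft targets, as in Exp3
target = torch.randn(16, 10)

# What the uncommented autocast rule effectively does: upcast to FP32
# before the loss, so log_softmax, exp, log, and the final multiply
# all run in FP32
loss = F.cross_entropy(logits_bf16.float(), target)
print(loss.dtype)  # torch.float32
```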

@JackCaoG JackCaoG merged commit e438a5b into r2.1_aws_neuron Sep 26, 2024
@JackCaoG JackCaoG deleted the autocast_bf16_neuron branch September 26, 2024 16:11
@jeffhataws
Collaborator

@avizon-aws Need cherry-pick to master

@jeffhataws
Collaborator

Also request cherry-pick to 2.5 here #7977 (comment)

@avizon-aws
Collaborator Author

Created PR to master: #8094
