Skip to content

Commit 2cc290c

Browse files
committed
[MPS] Fix batchnorm for mixed types
By up/down casting weights to input types Extend unittests to support float16 input Fixes #96113
1 parent 666efd8 commit 2cc290c

3 files changed

Lines changed: 108 additions & 105 deletions

File tree

aten/src/ATen/native/mps/OperationUtils.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ void resize_tensor(Tensor* output) {
404404
// this is meant to suppress the availability warning on castTensor
405405
// we pass ScalarType instead of MPSDataType to handle MPSDataTypeBoolean's availability too
406406
MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType) {
407-
if ([tensor dataType] == toType) {
407+
if (!tensor || [tensor dataType] == toType) {
408408
return tensor;
409409
}
410410
return [mpsGraph castTensor:tensor toType:toType name:@"castTensor"];

0 commit comments

Comments
 (0)