Skip to content

Commit 351390b

Browse files
committed
apply lint diff
1 parent f9a03ef commit 351390b

1 file changed

Lines changed: 47 additions & 38 deletions

File tree

aten/src/ATen/native/mps/operations/Pooling.mm

Lines changed: 47 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,12 @@
4242
};
4343

4444
typedef MPSGraphTensor* (^PoolingOpBlock2D)(PoolingCachedGraph2D&, MPSGraphPooling2DOpDescriptor*);
45-
#define PoolingOpFn2D(graph, desc) MPSGraphTensor*(mps::PoolingCachedGraph2D & graph, MPSGraphPooling2DOpDescriptor * desc)
45+
#define PoolingOpFn2D(graph, desc) \
46+
MPSGraphTensor*(mps::PoolingCachedGraph2D & graph, MPSGraphPooling2DOpDescriptor * desc)
4647

4748
typedef MPSGraphTensor* (^PoolingOpBlock3D)(PoolingCachedGraph3D&, MPSGraphPooling4DOpDescriptor*);
48-
#define PoolingOpFn3D(graph, desc) MPSGraphTensor*(mps::PoolingCachedGraph3D & graph, MPSGraphPooling4DOpDescriptor * desc)
49+
#define PoolingOpFn3D(graph, desc) \
50+
MPSGraphTensor*(mps::PoolingCachedGraph3D & graph, MPSGraphPooling4DOpDescriptor * desc)
4951

5052
// Pooling ops (1D/2D forward and backward Max and Average pooling)
5153
static void pool2d_template(const Tensor& input,
@@ -445,16 +447,27 @@ static void pool3d_template(const Tensor& input,
445447

446448
pool3d_shape_check(input,
447449
nInputPlane,
448-
kT, kH, kW,
449-
dT, dH, dW,
450-
padT, padH, padW,
451-
dilationT, dilationH, dilationW,
452-
inputDepth, inputHeight, inputWidth,
453-
outputDepth, outputHeight, outputWidth,
450+
kT,
451+
kH,
452+
kW,
453+
dT,
454+
dH,
455+
dW,
456+
padT,
457+
padH,
458+
padW,
459+
dilationT,
460+
dilationH,
461+
dilationW,
462+
inputDepth,
463+
inputHeight,
464+
inputWidth,
465+
outputDepth,
466+
outputHeight,
467+
outputWidth,
454468
op_name.data(),
455469
true);
456470

457-
458471
auto output_memory_format = output.suggest_memory_format();
459472
// the output and indices are 'empty', so we could avoid unnecessary gatherView on empty tensors
460473
// by simply restriding them (instead of calling the costly Contiguous()).
@@ -491,15 +504,21 @@ static void pool3d_template(const Tensor& input,
491504
MPSShape* gradOutputShape = is_backward_pass ? getMPSShape(grad_output, memory_format) : nullptr;
492505

493506
auto cachedGraph = LookUpOrCreateCachedGraph<PoolingCachedGraph3D>(key, [&](auto* mpsGraph, auto* newCachedGraph) {
494-
MPSGraphPooling4DOpDescriptor* desc = [MPSGraphPooling4DOpDescriptor
495-
descriptorWithKernelSizes:@[ @1, @(kT), @(kW), @(kH) ]
496-
strides:@[ @1, @(dT), @(dW), @(dH) ]
497-
dilationRates:@[ @1, @(dilationT), @(dilationW), @(dilationH) ]
498-
paddingValues:@[ @0, @0,
499-
@(padT), ceil_mode ? @(padT * dT) : @(padT),
500-
@(padW), ceil_mode ? @(padW * dW) : @(padW),
501-
@(padH), ceil_mode ? @(padH * dH) : @(padH) ]
502-
paddingStyle:MPSGraphPaddingStyleExplicit];
507+
MPSGraphPooling4DOpDescriptor* desc =
508+
[MPSGraphPooling4DOpDescriptor descriptorWithKernelSizes:@[ @1, @(kT), @(kW), @(kH) ]
509+
strides:@[ @1, @(dT), @(dW), @(dH) ]
510+
dilationRates:@[ @1, @(dilationT), @(dilationW), @(dilationH) ]
511+
paddingValues:@[
512+
@0,
513+
@0,
514+
@(padT),
515+
ceil_mode ? @(padT * dT) : @(padT),
516+
@(padW),
517+
ceil_mode ? @(padW * dW) : @(padW),
518+
@(padH),
519+
ceil_mode ? @(padH * dH) : @(padH)
520+
]
521+
paddingStyle:MPSGraphPaddingStyleExplicit];
503522

504523
desc.ceilMode = (padT == 0 && padW == 0 && padH == 0) ? ceil_mode : false;
505524
if (has_indices) {
@@ -633,16 +652,12 @@ Tensor mps_max_pool3d(const Tensor& input,
633652
Tensor output = at::empty({0}, input.options(), MemoryFormat::Contiguous);
634653
mps::PoolingOpBlock3D pooling_op_block = ^PoolingOpFn3D(cachedGraph, desc) {
635654
MPSGraph* mpsGraph = cachedGraph.graph();
636-
cachedGraph.expandedInputTensor = [mpsGraph expandDimsOfTensor:cachedGraph.inputTensor
637-
axis:2
638-
name:nil];
655+
cachedGraph.expandedInputTensor = [mpsGraph expandDimsOfTensor:cachedGraph.inputTensor axis:2 name:nil];
639656

640657
cachedGraph.pooledTensor = [mpsGraph maxPooling4DWithSourceTensor:cachedGraph.expandedInputTensor
641-
descriptor:desc
642-
name:nil];
643-
return [mpsGraph squeezeTensor:cachedGraph.pooledTensor
644-
axis:2
645-
name:nil];
658+
descriptor:desc
659+
name:nil];
660+
return [mpsGraph squeezeTensor:cachedGraph.pooledTensor axis:2 name:nil];
646661
};
647662
mps::pool3d_template(input,
648663
output,
@@ -671,21 +686,15 @@ Tensor mps_max_pool3d_backward(const Tensor& grad_output,
671686
Tensor grad_input = at::empty(input.sizes(), input.options(), MemoryFormat::Contiguous);
672687
mps::PoolingOpBlock3D pooling_op_block = ^PoolingOpFn3D(cachedGraph, desc) {
673688
MPSGraph* mpsGraph = cachedGraph.graph();
674-
cachedGraph.expandedInputTensor = [mpsGraph expandDimsOfTensor:cachedGraph.inputTensor
675-
axis:2
676-
name:nil];
689+
cachedGraph.expandedInputTensor = [mpsGraph expandDimsOfTensor:cachedGraph.inputTensor axis:2 name:nil];
677690

678-
cachedGraph.expandedGradOutputTensor = [mpsGraph expandDimsOfTensor:cachedGraph.gradOutputTensor
679-
axis:2
680-
name:nil];
691+
cachedGraph.expandedGradOutputTensor = [mpsGraph expandDimsOfTensor:cachedGraph.gradOutputTensor axis:2 name:nil];
681692

682693
cachedGraph.pooledTensor = [mpsGraph maxPooling4DGradientWithGradientTensor:cachedGraph.expandedGradOutputTensor
683-
sourceTensor:cachedGraph.expandedInputTensor
684-
descriptor:desc
685-
name:nil];
686-
return [mpsGraph squeezeTensor:cachedGraph.pooledTensor
687-
axis:2
688-
name:nil];
694+
sourceTensor:cachedGraph.expandedInputTensor
695+
descriptor:desc
696+
name:nil];
697+
return [mpsGraph squeezeTensor:cachedGraph.pooledTensor axis:2 name:nil];
689698
};
690699
mps::pool3d_template(input,
691700
grad_input,

0 commit comments

Comments (0)