|
42 | 42 | }; |
43 | 43 |
|
44 | 44 | typedef MPSGraphTensor* (^PoolingOpBlock2D)(PoolingCachedGraph2D&, MPSGraphPooling2DOpDescriptor*); |
45 | | -#define PoolingOpFn2D(graph, desc) MPSGraphTensor*(mps::PoolingCachedGraph2D & graph, MPSGraphPooling2DOpDescriptor * desc) |
| 45 | +#define PoolingOpFn2D(graph, desc) \ |
| 46 | + MPSGraphTensor*(mps::PoolingCachedGraph2D & graph, MPSGraphPooling2DOpDescriptor * desc) |
46 | 47 |
|
47 | 48 | typedef MPSGraphTensor* (^PoolingOpBlock3D)(PoolingCachedGraph3D&, MPSGraphPooling4DOpDescriptor*); |
48 | | -#define PoolingOpFn3D(graph, desc) MPSGraphTensor*(mps::PoolingCachedGraph3D & graph, MPSGraphPooling4DOpDescriptor * desc) |
| 49 | +#define PoolingOpFn3D(graph, desc) \ |
| 50 | + MPSGraphTensor*(mps::PoolingCachedGraph3D & graph, MPSGraphPooling4DOpDescriptor * desc) |
49 | 51 |
|
50 | 52 | // Pooling ops (1D/2D forward and backward Max and Average pooling) |
51 | 53 | static void pool2d_template(const Tensor& input, |
@@ -445,16 +447,27 @@ static void pool3d_template(const Tensor& input, |
445 | 447 |
|
446 | 448 | pool3d_shape_check(input, |
447 | 449 | nInputPlane, |
448 | | - kT, kH, kW, |
449 | | - dT, dH, dW, |
450 | | - padT, padH, padW, |
451 | | - dilationT, dilationH, dilationW, |
452 | | - inputDepth, inputHeight, inputWidth, |
453 | | - outputDepth, outputHeight, outputWidth, |
| 450 | + kT, |
| 451 | + kH, |
| 452 | + kW, |
| 453 | + dT, |
| 454 | + dH, |
| 455 | + dW, |
| 456 | + padT, |
| 457 | + padH, |
| 458 | + padW, |
| 459 | + dilationT, |
| 460 | + dilationH, |
| 461 | + dilationW, |
| 462 | + inputDepth, |
| 463 | + inputHeight, |
| 464 | + inputWidth, |
| 465 | + outputDepth, |
| 466 | + outputHeight, |
| 467 | + outputWidth, |
454 | 468 | op_name.data(), |
455 | 469 | true); |
456 | 470 |
|
457 | | - |
458 | 471 | auto output_memory_format = output.suggest_memory_format(); |
459 | 472 | // the output and indices are 'empty', so we could avoid unnecessary gatherView on empty tensors |
460 | 473 | // by simply restriding them (instead of calling the costly Contiguous()). |
@@ -491,15 +504,21 @@ static void pool3d_template(const Tensor& input, |
491 | 504 | MPSShape* gradOutputShape = is_backward_pass ? getMPSShape(grad_output, memory_format) : nullptr; |
492 | 505 |
|
493 | 506 | auto cachedGraph = LookUpOrCreateCachedGraph<PoolingCachedGraph3D>(key, [&](auto* mpsGraph, auto* newCachedGraph) { |
494 | | - MPSGraphPooling4DOpDescriptor* desc = [MPSGraphPooling4DOpDescriptor |
495 | | - descriptorWithKernelSizes:@[ @1, @(kT), @(kW), @(kH) ] |
496 | | - strides:@[ @1, @(dT), @(dW), @(dH) ] |
497 | | - dilationRates:@[ @1, @(dilationT), @(dilationW), @(dilationH) ] |
498 | | - paddingValues:@[ @0, @0, |
499 | | - @(padT), ceil_mode ? @(padT * dT) : @(padT), |
500 | | - @(padW), ceil_mode ? @(padW * dW) : @(padW), |
501 | | - @(padH), ceil_mode ? @(padH * dH) : @(padH) ] |
502 | | - paddingStyle:MPSGraphPaddingStyleExplicit]; |
| 507 | + MPSGraphPooling4DOpDescriptor* desc = |
| 508 | + [MPSGraphPooling4DOpDescriptor descriptorWithKernelSizes:@[ @1, @(kT), @(kW), @(kH) ] |
| 509 | + strides:@[ @1, @(dT), @(dW), @(dH) ] |
| 510 | + dilationRates:@[ @1, @(dilationT), @(dilationW), @(dilationH) ] |
| 511 | + paddingValues:@[ |
| 512 | + @0, |
| 513 | + @0, |
| 514 | + @(padT), |
| 515 | + ceil_mode ? @(padT * dT) : @(padT), |
| 516 | + @(padW), |
| 517 | + ceil_mode ? @(padW * dW) : @(padW), |
| 518 | + @(padH), |
| 519 | + ceil_mode ? @(padH * dH) : @(padH) |
| 520 | + ] |
| 521 | + paddingStyle:MPSGraphPaddingStyleExplicit]; |
503 | 522 |
|
504 | 523 | desc.ceilMode = (padT == 0 && padW == 0 && padH == 0) ? ceil_mode : false; |
505 | 524 | if (has_indices) { |
@@ -633,16 +652,12 @@ Tensor mps_max_pool3d(const Tensor& input, |
633 | 652 | Tensor output = at::empty({0}, input.options(), MemoryFormat::Contiguous); |
634 | 653 | mps::PoolingOpBlock3D pooling_op_block = ^PoolingOpFn3D(cachedGraph, desc) { |
635 | 654 | MPSGraph* mpsGraph = cachedGraph.graph(); |
636 | | - cachedGraph.expandedInputTensor = [mpsGraph expandDimsOfTensor:cachedGraph.inputTensor |
637 | | - axis:2 |
638 | | - name:nil]; |
| 655 | + cachedGraph.expandedInputTensor = [mpsGraph expandDimsOfTensor:cachedGraph.inputTensor axis:2 name:nil]; |
639 | 656 |
|
640 | 657 | cachedGraph.pooledTensor = [mpsGraph maxPooling4DWithSourceTensor:cachedGraph.expandedInputTensor |
641 | | - descriptor:desc |
642 | | - name:nil]; |
643 | | - return [mpsGraph squeezeTensor:cachedGraph.pooledTensor |
644 | | - axis:2 |
645 | | - name:nil]; |
| 658 | + descriptor:desc |
| 659 | + name:nil]; |
| 660 | + return [mpsGraph squeezeTensor:cachedGraph.pooledTensor axis:2 name:nil]; |
646 | 661 | }; |
647 | 662 | mps::pool3d_template(input, |
648 | 663 | output, |
@@ -671,21 +686,15 @@ Tensor mps_max_pool3d_backward(const Tensor& grad_output, |
671 | 686 | Tensor grad_input = at::empty(input.sizes(), input.options(), MemoryFormat::Contiguous); |
672 | 687 | mps::PoolingOpBlock3D pooling_op_block = ^PoolingOpFn3D(cachedGraph, desc) { |
673 | 688 | MPSGraph* mpsGraph = cachedGraph.graph(); |
674 | | - cachedGraph.expandedInputTensor = [mpsGraph expandDimsOfTensor:cachedGraph.inputTensor |
675 | | - axis:2 |
676 | | - name:nil]; |
| 689 | + cachedGraph.expandedInputTensor = [mpsGraph expandDimsOfTensor:cachedGraph.inputTensor axis:2 name:nil]; |
677 | 690 |
|
678 | | - cachedGraph.expandedGradOutputTensor = [mpsGraph expandDimsOfTensor:cachedGraph.gradOutputTensor |
679 | | - axis:2 |
680 | | - name:nil]; |
| 691 | + cachedGraph.expandedGradOutputTensor = [mpsGraph expandDimsOfTensor:cachedGraph.gradOutputTensor axis:2 name:nil]; |
681 | 692 |
|
682 | 693 | cachedGraph.pooledTensor = [mpsGraph maxPooling4DGradientWithGradientTensor:cachedGraph.expandedGradOutputTensor |
683 | | - sourceTensor:cachedGraph.expandedInputTensor |
684 | | - descriptor:desc |
685 | | - name:nil]; |
686 | | - return [mpsGraph squeezeTensor:cachedGraph.pooledTensor |
687 | | - axis:2 |
688 | | - name:nil]; |
| 694 | + sourceTensor:cachedGraph.expandedInputTensor |
| 695 | + descriptor:desc |
| 696 | + name:nil]; |
| 697 | + return [mpsGraph squeezeTensor:cachedGraph.pooledTensor axis:2 name:nil]; |
689 | 698 | }; |
690 | 699 | mps::pool3d_template(input, |
691 | 700 | grad_input, |
|
0 commit comments