Skip to content

Commit cdcc411

Browse files
committed
Compile on MacOS 12
1 parent 67260e3 commit cdcc411

2 files changed

Lines changed: 58 additions & 0 deletions

File tree

aten/src/ATen/native/mps/MPSGraphVenturaOps.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,35 @@
22
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
33

44
// TODO: Remove me when moved to MacOS 13
5+
#if !defined(__MAC_13_2) && \
6+
(!defined(MAC_OS_X_VERSION_13_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_2))
7+
8+
@interface MPSGraphConvolution3DOpDescriptor : NSObject<NSCopying>
9+
10+
@property (readwrite, nonatomic) NSUInteger strideInX;
11+
@property (readwrite, nonatomic) NSUInteger strideInY;
12+
@property (readwrite, nonatomic) NSUInteger strideInZ;
13+
@property (readwrite, nonatomic) NSUInteger dilationRateInX;
14+
@property (readwrite, nonatomic) NSUInteger dilationRateInY;
15+
@property (readwrite, nonatomic) NSUInteger dilationRateInZ;
16+
17+
@property (readwrite, nonatomic) NSUInteger paddingLeft;
18+
@property (readwrite, nonatomic) NSUInteger paddingRight;
19+
@property (readwrite, nonatomic) NSUInteger paddingTop;
20+
@property (readwrite, nonatomic) NSUInteger paddingBottom;
21+
@property (readwrite, nonatomic) NSUInteger paddingFront;
22+
@property (readwrite, nonatomic) NSUInteger paddingBack;
23+
24+
@property (readwrite, nonatomic) MPSGraphPaddingStyle paddingStyle;
25+
@property (readwrite, nonatomic) MPSGraphTensorNamedDataLayout dataLayout;
26+
@property (readwrite, nonatomic) MPSGraphTensorNamedDataLayout weightsLayout;
27+
28+
@property (readwrite, nonatomic) NSUInteger groups;
29+
30+
@end
31+
32+
#endif
33+
534
@interface MPSGraph (VenturaOps)
635

736
#if !defined(__MAC_13_0) && \
@@ -18,6 +47,23 @@ typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode)
1847
};
1948
#endif
2049

50+
- (MPSGraphTensor * _Nonnull) convolution3DWithSourceTensor:(MPSGraphTensor * _Nonnull) source
51+
weightsTensor:(MPSGraphTensor * _Nonnull) weights
52+
descriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) descriptor
53+
name:(NSString * _Nullable) name;
54+
55+
- (MPSGraphTensor * _Nonnull) convolution3DDataGradientWithIncomingGradientTensor:(MPSGraphTensor * _Nonnull) incomingGradient
56+
weightsTensor:(MPSGraphTensor * _Nonnull) weights
57+
outputShape:(MPSShape * _Nonnull) outputShape
58+
forwardConvolutionDescriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) forwardConvolutionDescriptor
59+
name:(NSString * _Nullable) name;
60+
61+
- (MPSGraphTensor * _Nonnull) convolution3DWeightsGradientWithIncomingGradientTensor:(MPSGraphTensor * _Nonnull) incomingGradient
62+
sourceTensor:(MPSGraphTensor * _Nonnull) source
63+
outputShape:(MPSShape * _Nonnull) outputShape
64+
forwardConvolutionDescriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) forwardConvolutionDescriptor
65+
name:(NSString * _Nullable) name;
66+
2167
- (MPSGraphTensor * _Nonnull)cumulativeSumWithTensor:(MPSGraphTensor * _Nonnull)tensor
2268
axis:(NSInteger)axis
2369
name:(NSString * _Nullable)name;

aten/src/ATen/native/mps/operations/Convolution.mm

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
// Copyright © 2022 Apple Inc.
22
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
33
#include <ATen/native/ConvUtils.h>
4+
#include <ATen/native/mps/MPSGraphVenturaOps.h>
45
#include <ATen/native/mps/OperationUtils.h>
56
#include <ATen/ops/_mps_convolution_native.h>
67
#include <ATen/ops/_mps_convolution_transpose_native.h>
78
#include <ATen/ops/mps_convolution_backward_native.h>
89
#include <ATen/ops/mps_convolution_transpose_backward_native.h>
910

11+
#if !defined(__MAC_13_2) && (!defined(MAC_OS_X_VERSION_13_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_2))
12+
13+
@implementation MPSGraphConvolution3DOpDescriptor
14+
- (nonnull id)copyWithZone:(nullable NSZone*)zone {
15+
return self;
16+
}
17+
18+
@end
19+
20+
#endif
21+
1022
namespace at::native {
1123

1224
// Create 3D convolution descriptor

0 commit comments

Comments
 (0)