Skip to content

Commit 22fd2a7

Browse files
liuk22facebook-github-bot
authored andcommitted
[PyTorch] Add Vulkan support and tests for at::select.int operator, 4 dim/rank tensor case (#96228)
Summary: Pull Request resolved: #96228 Currently, selection along a dimension/rank is only supported for 3D/rank tensors in PyTorch Vulkan. This adds support for 4D/rank tensors at selection along batch, channel (depth), height, and width. Additionally: - The existing implementations have been name-refactored to reflect whether they operate on 3d or 4d tensors. - The params buffer for all select operations now use `ivec2` or `ivec4` only, for memory alignment safety. Test Plan: **Internal:** 1. `buck run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1` on Apple M1 MacBook 2. Confirm all tests pass with no regression, and the directly affected tests `select_4d_*`, refactored `select_3d_`, pass 3. Test output P636928908, in particular: ``` [...bunch of other tests...] [ RUN ] VulkanAPITest.select_3d_depth_small [ OK ] VulkanAPITest.select_3d_depth_small (1 ms) [ RUN ] VulkanAPITest.select_3d_depth_medium [ OK ] VulkanAPITest.select_3d_depth_medium (0 ms) [ RUN ] VulkanAPITest.select_3d_depth_large [ OK ] VulkanAPITest.select_3d_depth_large (1 ms) [ RUN ] VulkanAPITest.select_3d_height_small [ OK ] VulkanAPITest.select_3d_height_small (0 ms) [ RUN ] VulkanAPITest.select_3d_height_medium [ OK ] VulkanAPITest.select_3d_height_medium (0 ms) [ RUN ] VulkanAPITest.select_3d_height_medium1 [ OK ] VulkanAPITest.select_3d_height_medium1 (0 ms) [ RUN ] VulkanAPITest.select_3d_height_medium2 [ OK ] VulkanAPITest.select_3d_height_medium2 (0 ms) [ RUN ] VulkanAPITest.select_3d_height_large [ OK ] VulkanAPITest.select_3d_height_large (1 ms) [ RUN ] VulkanAPITest.select_3d_width_small [ OK ] VulkanAPITest.select_3d_width_small (0 ms) [ RUN ] VulkanAPITest.select_3d_width_medium [ OK ] VulkanAPITest.select_3d_width_medium (0 ms) [ RUN ] VulkanAPITest.select_3d_width_medium2 [ OK ] VulkanAPITest.select_3d_width_medium2 (0 ms) [ RUN ] VulkanAPITest.select_3d_width_large [ OK ] VulkanAPITest.select_3d_width_large (1 ms) [ RUN ] VulkanAPITest.select_4d_batch_small [ OK ] VulkanAPITest.select_4d_batch_small (0 ms) [ RUN ] VulkanAPITest.select_4d_batch_medium [ OK ] VulkanAPITest.select_4d_batch_medium (0 ms) [ RUN ] VulkanAPITest.select_4d_batch_large [ OK ] VulkanAPITest.select_4d_batch_large (1 ms) [ RUN ] VulkanAPITest.select_4d_depth_small [ OK ] VulkanAPITest.select_4d_depth_small (1 ms) [ RUN ] VulkanAPITest.select_4d_depth_medium [ OK ] VulkanAPITest.select_4d_depth_medium (0 ms) [ RUN ] VulkanAPITest.select_4d_depth_large [ OK ] VulkanAPITest.select_4d_depth_large (1 ms) [ RUN ] VulkanAPITest.select_4d_height_small [ OK ] VulkanAPITest.select_4d_height_small (0 ms) [ RUN ] VulkanAPITest.select_4d_height_medium [ OK ] VulkanAPITest.select_4d_height_medium (0 ms) [ RUN ] VulkanAPITest.select_4d_height_large [ OK ] VulkanAPITest.select_4d_height_large (1 ms) [ RUN ] VulkanAPITest.select_4d_width_small [ OK ] VulkanAPITest.select_4d_width_small (0 ms) [ RUN ] VulkanAPITest.select_4d_width_medium [ OK ] VulkanAPITest.select_4d_width_medium (0 ms) [ RUN ] VulkanAPITest.select_4d_width_large [ OK ] VulkanAPITest.select_4d_width_large (1 ms) [...bunch of other tests...] [ FAILED ] 7 tests, listed below: [ FAILED ] VulkanAPITest.cat_dim1_singledepth_success [ FAILED ] VulkanAPITest.gru_success [ FAILED ] VulkanAPITest.gru_mclareninputs_success [ FAILED ] VulkanAPITest.gru_prepack_success [ FAILED ] VulkanAPITest.lstm_success [ FAILED ] VulkanAPITest.lstm_mclareninputs_success [ FAILED ] VulkanAPITest.lstm_prepack_success ``` Reviewed By: SS-JIA Differential Revision: D42623181 fbshipit-source-id: 5b42fe7f2ceb3d4d3dddd7a7389ccc343320da7d
1 parent b3a0798 commit 22fd2a7

10 files changed

Lines changed: 549 additions & 69 deletions
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#version 450 core
2+
#define PRECISION $precision
3+
#define FORMAT $format
4+
5+
layout(std430) buffer;
6+
7+
/* Qualifiers: layout - storage - precision - memory */
8+
9+
/*
10+
* Output Image
11+
*/
12+
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
13+
14+
/*
15+
* Input Buffer
16+
*/
17+
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
18+
19+
/*
20+
* Params Buffer
21+
*/
22+
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
23+
// batch_info.x: number of texels per batch
24+
// batch_info.y: index along batch dim to select
25+
ivec2 batch_info;
26+
}
27+
uBlock;
28+
29+
/*
30+
* Local Work Group Size
31+
*/
32+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
33+
34+
void main() {
35+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
36+
const uint src_pos_z = (uBlock.batch_info.y * uBlock.batch_info.x) + pos.z;
37+
imageStore(
38+
uOutput, pos, texelFetch(uInput, ivec3(pos.x, pos.y, src_pos_z), 0));
39+
}

aten/src/ATen/native/vulkan/glsl/select_depth.glsl

Lines changed: 0 additions & 31 deletions
This file was deleted.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#version 450 core
2+
#define PRECISION $precision
3+
#define FORMAT $format
4+
5+
layout(std430) buffer;
6+
7+
/* Qualifiers: layout - storage - precision - memory */
8+
9+
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
10+
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
11+
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
12+
// depth_info.x: output texture x extent
13+
// depth_info.y: output texture y extent
14+
// depth_info.z: output texture z extent
15+
// depth_info.w: output texture w extent
16+
ivec4 depth_info;
17+
}
18+
uBlock;
19+
20+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
21+
22+
void main() {
23+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
24+
25+
if (all(lessThan(pos, uBlock.depth_info.xyz))) {
26+
const int tex = uBlock.depth_info.w / 4;
27+
const int ind = uBlock.depth_info.w % 4;
28+
const float v = texelFetch(uInput, ivec3(pos.x, pos.y, tex), 0)[ind];
29+
30+
imageStore(uOutput, ivec3(pos.x, pos.y, 0), vec4(v, 0, 0, 0));
31+
}
32+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#version 450 core
2+
#define PRECISION $precision
3+
#define FORMAT $format
4+
5+
layout(std430) buffer;
6+
7+
/* Qualifiers: layout - storage - precision - memory */
8+
9+
/*
10+
* Output Image
11+
*/
12+
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
13+
14+
/*
15+
* Input Buffer
16+
*/
17+
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
18+
19+
/*
20+
* Params Buffer
21+
*/
22+
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
23+
// depth_info.x: number of batches
24+
// depth_info.y: number of texels per batch
25+
// depth_info.z: index along channel dim to select
26+
// depth_info.w: zero pad for alignment
27+
ivec4 depth_info;
28+
}
29+
uBlock;
30+
31+
/*
32+
* Local Work Group Size
33+
*/
34+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
35+
36+
void main() {
37+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
38+
// read in the same channel from 4 separate batches
39+
vec4 out_texel = vec4(0, 0, 0, 0);
40+
for (int k = 0; k < 4; k++) {
41+
if ((k + pos.z * 4) >=
42+
uBlock.depth_info.x) { // < 4 batches for this texel, exit early
43+
break;
44+
}
45+
const uint src_pos_z = (4 * uBlock.depth_info.y * pos.z) +
46+
(k * uBlock.depth_info.y) + (uBlock.depth_info.z / 4);
47+
const uint src_pos_t = uBlock.depth_info.z % 4;
48+
out_texel[k] =
49+
texelFetch(uInput, ivec3(pos.x, pos.y, src_pos_z), 0)[src_pos_t];
50+
}
51+
52+
imageStore(uOutput, pos, out_texel);
53+
}

aten/src/ATen/native/vulkan/glsl/select_height.glsl renamed to aten/src/ATen/native/vulkan/glsl/select_height_3d.glsl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@ layout(std430) buffer;
99
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
1010
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
1111
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
12-
ivec3 size;
13-
int index;
14-
} uBlock;
12+
// height_info.x: output texture x extent
13+
// height_info.y: output texture y extent
14+
// height_info.z: output texture z extent
15+
// height_info.w: output texture w extent
16+
ivec4 height_info;
17+
}
18+
uBlock;
1519

1620
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
1721

@@ -21,7 +25,7 @@ void main() {
2125
// w
2226
const int src_x = pos.x;
2327
// h
24-
const int src_y = uBlock.index;
28+
const int src_y = uBlock.height_info.w;
2529
// c
2630
const int src_z = pos.y;
2731

@@ -31,7 +35,7 @@ void main() {
3135
ivec3 new_pos = ivec3(pos.x, pos.y * 4 + i, 0);
3236

3337
// When the C-channel exceeds original block size, exit early
34-
if (new_pos.y >= uBlock.size.y) {
38+
if (new_pos.y >= uBlock.height_info.y) {
3539
return;
3640
}
3741

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#version 450 core
2+
#define PRECISION $precision
3+
#define FORMAT $format
4+
5+
layout(std430) buffer;
6+
7+
/* Qualifiers: layout - storage - precision - memory */
8+
9+
/*
10+
* Output Image
11+
*/
12+
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
13+
14+
/*
15+
* Input Buffer
16+
*/
17+
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
18+
19+
/*
20+
* Params Buffer
21+
*/
22+
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
23+
// height_info.x: number of batches
24+
// height_info.y: number of texels per batch
25+
// height_info.z: index along height dim to select
26+
// height_info.w: zero pad for alignment
27+
ivec4 height_info;
28+
}
29+
uBlock;
30+
31+
/*
32+
* Local Work Group Size
33+
*/
34+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
35+
36+
void main() {
37+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
38+
vec4 out_texel = vec4(0, 0, 0, 0);
39+
// read in the same channel from 4 separate batches
40+
for (int k = 0; k < 4; k++) {
41+
if ((k + pos.z * 4) >=
42+
uBlock.height_info.x) { // < 4 batches for this texel, exit early
43+
break;
44+
}
45+
const uint src_pos_z = (pos.z * uBlock.height_info.y * 4) +
46+
k * uBlock.height_info.y + (pos.y / 4);
47+
out_texel[k] = texelFetch(
48+
uInput, ivec3(pos.x, uBlock.height_info.z, src_pos_z), 0)[pos.y % 4];
49+
}
50+
imageStore(uOutput, pos, out_texel);
51+
}

aten/src/ATen/native/vulkan/glsl/select_width.glsl renamed to aten/src/ATen/native/vulkan/glsl/select_width_3d.glsl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,21 @@ layout(std430) buffer;
99
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
1010
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
1111
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
12-
ivec3 size;
13-
int index;
14-
} uBlock;
12+
// width_info.x: output texture x extent
13+
// width_info.y: output texture y extent
14+
// width_info.z: output texture z extent
15+
// width_info.w: output texture w extent
16+
ivec4 width_info;
17+
}
18+
uBlock;
1519

1620
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
1721

1822
void main() {
1923
const ivec3 pos = ivec3(gl_GlobalInvocationID);
2024

2125
// w
22-
const int src_x = uBlock.index;
26+
const int src_x = uBlock.width_info.w;
2327
// h
2428
const int src_y = pos.x;
2529
// c
@@ -31,7 +35,7 @@ void main() {
3135
ivec3 new_pos = ivec3(pos.x, pos.y * 4 + i, 0);
3236

3337
// When the C-channel exceeds original block size, exit early
34-
if (new_pos.y >= uBlock.size.y) {
38+
if (new_pos.y >= uBlock.width_info.y) {
3539
return;
3640
}
3741

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#version 450 core
2+
#define PRECISION $precision
3+
#define FORMAT $format
4+
5+
layout(std430) buffer;
6+
7+
/* Qualifiers: layout - storage - precision - memory */
8+
9+
/*
10+
* Output Image
11+
*/
12+
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
13+
14+
/*
15+
* Input Buffer
16+
*/
17+
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
18+
19+
/*
20+
* Params Buffer
21+
*/
22+
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
23+
// width_info.x: number of batches
24+
// width_info.y: number of texels per batch
25+
// width_info.z: index along width dim to select
26+
// width_info.w: zero pad for alignment
27+
ivec4 width_info;
28+
}
29+
uBlock;
30+
31+
/*
32+
* Local Work Group Size
33+
*/
34+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
35+
36+
void main() {
37+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
38+
vec4 out_texel = vec4(0, 0, 0, 0);
39+
// read in the same channel from 4 separate batches
40+
for (int k = 0; k < 4; k++) {
41+
if ((k + pos.z * 4) >=
42+
uBlock.width_info.x) { // < 4 batches for this texel, exit early
43+
break;
44+
}
45+
const uint src_pos_z = (pos.z * uBlock.width_info.y * 4) +
46+
k * uBlock.width_info.y + (pos.y / 4);
47+
out_texel[k] = texelFetch(
48+
uInput, ivec3(uBlock.width_info.z, pos.x, src_pos_z), 0)[pos.y % 4];
49+
}
50+
imageStore(uOutput, pos, out_texel);
51+
}

0 commit comments

Comments
 (0)