[Core Aten Op] bring aten_replication_pad3d test back#6537

Merged
ManfeiBai merged 1 commit into master from ManfeiBai-patch-69
Feb 15, 2024
Conversation

@ManfeiBai
Collaborator

@ManfeiBai ManfeiBai commented Feb 15, 2024

Fixes #5892


Testing

Two tests pass directly:

# pytest test/test_core_aten_ops.py -k test_aten_replication_pad3d_0
=========================================== test session starts ===========================================
platform linux -- Python 3.10.13, pytest-8.0.0, pluggy-1.4.0
rootdir: /root/pytorch
configfile: pytest.ini
plugins: hypothesis-6.97.4
collected 499 items / 498 deselected / 1 selected                                                         

test/test_core_aten_ops.py WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1707956891.272721  694064 pjrt_api.cc:100] GetPjrtApi was found for tpu at /root/miniconda3/envs/torch310/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1707956891.272807  694064 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1707956891.272820  694064 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
.                                                                        [100%]

==================================== 1 passed, 498 deselected in 6.90s ====================================
# pytest test/test_core_aten_ops.py -k test_aten_replication_pad3d_1
=========================================== test session starts ===========================================
platform linux -- Python 3.10.13, pytest-8.0.0, pluggy-1.4.0
rootdir: /root/pytorch
configfile: pytest.ini
plugins: hypothesis-6.97.4
collected 499 items / 498 deselected / 1 selected                                                         

test/test_core_aten_ops.py WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1707956908.792722  695638 pjrt_api.cc:100] GetPjrtApi was found for tpu at /root/miniconda3/envs/torch310/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1707956908.792808  695638 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1707956908.792822  695638 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
.                                                                        [100%]

==================================== 1 passed, 498 deselected in 6.97s ====================================
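For context, replication padding repeats a tensor's edge values along the padded dimensions. A minimal sketch of the semantics that aten::replication_pad3d applies to the last three dims, using NumPy's edge-mode padding as a stand-in (this is illustrative, not the torch/torch_xla implementation):

```python
import numpy as np

# A pad spec like [1, 1, 1, 1, 1, 1] pads each of the last three
# dimensions (W, H, D) by one element on each side.
x = np.arange(8, dtype=np.float32).reshape(1, 2, 2, 2)  # (N, D, H, W)

# np.pad with mode="edge" replicates border values, which matches
# replication padding on the last three dims.
padded = np.pad(x, ((0, 0), (1, 1), (1, 1), (1, 1)), mode="edge")

print(padded.shape)  # each padded dim grows by 2: (1, 4, 4, 4)
print(padded[0, 0, 0, 0] == x[0, 0, 0, 0])  # corner is a copy of the edge value
```

This matches the interactive session later in the thread, where padding a (1, 3, 2, 10) input with [1, 1, 1, 1, 1, 1] yields a 12-wide last dimension.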

@ManfeiBai ManfeiBai marked this pull request as ready for review February 15, 2024 00:36
@ManfeiBai ManfeiBai requested review from qihqi and wonjoo-wj February 15, 2024 00:37
Collaborator

@wonjoo-wj wonjoo-wj left a comment


Thanks! LGTM pending CI.

@ManfeiBai ManfeiBai merged commit b7c760d into master Feb 15, 2024
@wonjoo-wj
Collaborator

@ManfeiBai, looking back at this, I actually didn't see a lowering for aten_replication_pad3d in torch_xla. I assume this test succeeded because the op fell back to CPU?

@ManfeiBai
Collaborator Author

@ManfeiBai, looking back at this, I actually didn't see a lowering for aten_replication_pad3d in torch_xla. I assume this test succeeded because the op fell back to CPU?

Thanks for confirming, @wonjoolee95; let me confirm it locally too.

@wonjoo-wj
Collaborator

What you can do is write a simple Python script that runs the op in torch_xla. Then you can print the metrics report to see whether the aten:: or xla:: version got called. Let me know if you need more help.

@ManfeiBai
Collaborator Author

ManfeiBai commented Feb 15, 2024

What you can do is write a simple Python script that runs the op in torch_xla. Then you can print the metrics report to see whether the aten:: or xla:: version got called. Let me know if you need more help.

one simple test locally:

# PJRT_DEVICE=TPU python
Python 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torch_xla
>>> a = torch.randn((1, 3, 2, 10)).to(torch.float32)
>>> b = [1, 1, 1, 1, 1, 1, ]
>>> kwargs = dict()
>>> 
>>> torch.ops.aten.replication_pad3d(a, b)
tensor([[[[-0.0815, -0.0815, -0.7422,  0.5998, -0.2836,  0.1122,  0.5141,
           -0.0043,  0.7604, -0.5636, -0.5118, -0.5118],
          [-0.0815, -0.0815, -0.7422,  0.5998, -0.2836,  0.1122,  0.5141,
           -0.0043,  0.7604, -0.5636, -0.5118, -0.5118],
          [ 0.2472,  0.2472, -0.1091,  1.1298, -0.7253,  0.0215,  0.6112,
            0.8709,  1.2762, -0.4532, -1.9455, -1.9455],
          [ 0.2472,  0.2472, -0.1091,  1.1298, -0.7253,  0.0215,  0.6112,
            0.8709,  1.2762, -0.4532, -1.9455, -1.9455]],

         [[-0.0815, -0.0815, -0.7422,  0.5998, -0.2836,  0.1122,  0.5141,
           -0.0043,  0.7604, -0.5636, -0.5118, -0.5118],
          [-0.0815, -0.0815, -0.7422,  0.5998, -0.2836,  0.1122,  0.5141,
           -0.0043,  0.7604, -0.5636, -0.5118, -0.5118],
          [ 0.2472,  0.2472, -0.1091,  1.1298, -0.7253,  0.0215,  0.6112,
            0.8709,  1.2762, -0.4532, -1.9455, -1.9455],
          [ 0.2472,  0.2472, -0.1091,  1.1298, -0.7253,  0.0215,  0.6112,
            0.8709,  1.2762, -0.4532, -1.9455, -1.9455]],

         [[ 1.0400,  1.0400, -1.7993,  1.6323,  1.0308, -0.3612,  0.1286,
           -0.1385,  2.5149, -0.7805, -1.4634, -1.4634],
          [ 1.0400,  1.0400, -1.7993,  1.6323,  1.0308, -0.3612,  0.1286,
           -0.1385,  2.5149, -0.7805, -1.4634, -1.4634],
          [-0.7135, -0.7135,  0.2299,  0.6553,  0.8867, -0.4578, -0.4595,
            0.3352, -0.1528, -0.1920,  0.3373,  0.3373],
          [-0.7135, -0.7135,  0.2299,  0.6553,  0.8867, -0.4578, -0.4595,
            0.3352, -0.1528, -0.1920,  0.3373,  0.3373]],

         [[ 0.8732,  0.8732,  0.5287,  0.5531, -0.7906,  0.2093, -0.5855,
           -0.4209,  0.3002,  2.6657,  0.3340,  0.3340],
          [ 0.8732,  0.8732,  0.5287,  0.5531, -0.7906,  0.2093, -0.5855,
           -0.4209,  0.3002,  2.6657,  0.3340,  0.3340],
          [ 0.7902,  0.7902,  0.4101,  0.8244, -0.7611,  0.0275,  0.6254,
            0.0372,  0.2531,  0.1178,  2.4389,  2.4389],
          [ 0.7902,  0.7902,  0.4101,  0.8244, -0.7611,  0.0275,  0.6254,
            0.0372,  0.2531,  0.1178,  2.4389,  2.4389]],

         [[ 0.8732,  0.8732,  0.5287,  0.5531, -0.7906,  0.2093, -0.5855,
           -0.4209,  0.3002,  2.6657,  0.3340,  0.3340],
          [ 0.8732,  0.8732,  0.5287,  0.5531, -0.7906,  0.2093, -0.5855,
           -0.4209,  0.3002,  2.6657,  0.3340,  0.3340],
          [ 0.7902,  0.7902,  0.4101,  0.8244, -0.7611,  0.0275,  0.6254,
            0.0372,  0.2531,  0.1178,  2.4389,  2.4389],
          [ 0.7902,  0.7902,  0.4101,  0.8244, -0.7611,  0.0275,  0.6254,
            0.0372,  0.2531,  0.1178,  2.4389,  2.4389]]]])
>>> 
>>> import torch_xla.core.xla_model as xm
>>> c = a.to(xm.xla_device())
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1708039355.708649  753016 pjrt_api.cc:100] GetPjrtApi was found for tpu at /root/miniconda3/envs/torch310/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1708039355.708730  753016 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1708039355.708754  753016 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
>>> d = b.to(xm.xla_device())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'list' object has no attribute 'to'
>>> torch.ops.aten.replication_pad3d(c, b)
tensor([[[[-0.0815, -0.0815, -0.7422,  0.5998, -0.2836,  0.1122,  0.5141,
           -0.0043,  0.7604, -0.5636, -0.5118, -0.5118],
          [-0.0815, -0.0815, -0.7422,  0.5998, -0.2836,  0.1122,  0.5141,
           -0.0043,  0.7604, -0.5636, -0.5118, -0.5118],
          [ 0.2472,  0.2472, -0.1091,  1.1298, -0.7253,  0.0215,  0.6112,
            0.8709,  1.2762, -0.4532, -1.9455, -1.9455],
          [ 0.2472,  0.2472, -0.1091,  1.1298, -0.7253,  0.0215,  0.6112,
            0.8709,  1.2762, -0.4532, -1.9455, -1.9455]],

         [[-0.0815, -0.0815, -0.7422,  0.5998, -0.2836,  0.1122,  0.5141,
           -0.0043,  0.7604, -0.5636, -0.5118, -0.5118],
          [-0.0815, -0.0815, -0.7422,  0.5998, -0.2836,  0.1122,  0.5141,
           -0.0043,  0.7604, -0.5636, -0.5118, -0.5118],
          [ 0.2472,  0.2472, -0.1091,  1.1298, -0.7253,  0.0215,  0.6112,
            0.8709,  1.2762, -0.4532, -1.9455, -1.9455],
          [ 0.2472,  0.2472, -0.1091,  1.1298, -0.7253,  0.0215,  0.6112,
            0.8709,  1.2762, -0.4532, -1.9455, -1.9455]],

         [[ 1.0400,  1.0400, -1.7993,  1.6323,  1.0308, -0.3612,  0.1286,
           -0.1385,  2.5149, -0.7805, -1.4634, -1.4634],
          [ 1.0400,  1.0400, -1.7993,  1.6323,  1.0308, -0.3612,  0.1286,
           -0.1385,  2.5149, -0.7805, -1.4634, -1.4634],
          [-0.7135, -0.7135,  0.2299,  0.6553,  0.8867, -0.4578, -0.4595,
            0.3352, -0.1528, -0.1920,  0.3373,  0.3373],
          [-0.7135, -0.7135,  0.2299,  0.6553,  0.8867, -0.4578, -0.4595,
            0.3352, -0.1528, -0.1920,  0.3373,  0.3373]],

         [[ 0.8732,  0.8732,  0.5287,  0.5531, -0.7906,  0.2093, -0.5855,
           -0.4209,  0.3002,  2.6657,  0.3340,  0.3340],
          [ 0.8732,  0.8732,  0.5287,  0.5531, -0.7906,  0.2093, -0.5855,
           -0.4209,  0.3002,  2.6657,  0.3340,  0.3340],
          [ 0.7902,  0.7902,  0.4101,  0.8244, -0.7611,  0.0275,  0.6254,
            0.0372,  0.2531,  0.1178,  2.4389,  2.4389],
          [ 0.7902,  0.7902,  0.4101,  0.8244, -0.7611,  0.0275,  0.6254,
            0.0372,  0.2531,  0.1178,  2.4389,  2.4389]],

         [[ 0.8732,  0.8732,  0.5287,  0.5531, -0.7906,  0.2093, -0.5855,
           -0.4209,  0.3002,  2.6657,  0.3340,  0.3340],
          [ 0.8732,  0.8732,  0.5287,  0.5531, -0.7906,  0.2093, -0.5855,
           -0.4209,  0.3002,  2.6657,  0.3340,  0.3340],
          [ 0.7902,  0.7902,  0.4101,  0.8244, -0.7611,  0.0275,  0.6254,
            0.0372,  0.2531,  0.1178,  2.4389,  2.4389],
          [ 0.7902,  0.7902,  0.4101,  0.8244, -0.7611,  0.0275,  0.6254,
            0.0372,  0.2531,  0.1178,  2.4389,  2.4389]]]], device='xla:0')
>>> 

It looks like the second generated result is a tensor on the XLA device too; does that mean the op did not fall back to CPU?
BTW, do we have guidance on how to print metrics for that test case?

@wonjoo-wj
Collaborator

You can import torch_xla.debug.metrics as met and do print(met.metrics_report()). Can you paste the output of the metrics print?
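The check described here can be sketched as a small helper: an aten:: counter in the metrics report indicates a CPU fallback, while an xla:: counter indicates the op was lowered. The helper and the report strings below are hypothetical, illustrative fragments, not real torch_xla output:

```python
# Hypothetical helper: scan a torch_xla metrics report for a fallback
# counter for a given op name.
def fell_back_to_cpu(metrics_report: str, op: str) -> bool:
    """Return True if an aten:: fallback counter for `op` is present."""
    return f"aten::{op}" in metrics_report

# Illustrative report fragments (assumed format, not captured output):
fallback_report = "Counter: aten::replication_pad3d\n  Value: 1"
lowered_report = "Counter: xla::replication_pad3d\n  Value: 1"

print(fell_back_to_cpu(fallback_report, "replication_pad3d"))  # True
print(fell_back_to_cpu(lowered_report, "replication_pad3d"))   # False
```

In a real session, the string passed to the helper would come from met.metrics_report() after running the op on an XLA device.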

@ManfeiBai
Collaborator Author

print(met.metrics_report())

Thanks, I pasted the output in https://gist.github.com/ManfeiBai/36799aa983d518a2233337dc701d294b

@ManfeiBai
Collaborator Author

Thanks. I dumped a similar 2d example, and its HLO contains the lowered logic; this 3d op needs the same op-lowering logic as the 2d one.

@wonjoo-wj
Collaborator

Yep, so it needs an actual lowering. We already lower the reflection pad 2d op, so I think the logic may be somewhat similar. Let me know if you have any questions. Meanwhile, I'll revert this PR.



Development

Successfully merging this pull request may close these issues.

[Core ATen Opset] Lower aten_replication_pad3d
