Apply precision config env vars in the root process. #6152

Merged
golechwierowicz merged 3 commits into master from olechwierowicz/apply-env-vars-in-root-process
Dec 15, 2023

Conversation

@golechwierowicz
Collaborator

After some changes on the main branch, setting os.environ was no longer sufficient for the subprocess to pick up new env vars.
This PR applies the necessary workaround in the root process, which launches a subprocess for each experiment; the new flags are passed via the process_env variable.
I tried to keep experiment_runner.py as clean as possible and abstracted the new env vars behind an apply_env method on {Benchmark, TorchBench}Model.
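
Conceptually, the mechanism looks like the minimal sketch below. This is an illustrative reconstruction rather than the literal diff: the run_single_experiment helper and the XLA_USE_FP16 flag name are assumptions.

import os
import subprocess
import sys

class BenchmarkModel:
  # Hypothetical stand-in for the real {Benchmark, TorchBench}Model classes.

  def apply_env(self, process_env):
    # Inject this model's precision flags into the child environment.
    # XLA_USE_FP16 is an illustrative flag name.
    process_env["XLA_USE_FP16"] = "1"
    return process_env

def run_single_experiment(model, experiment_cmd):
  # Copy the root environment instead of mutating os.environ, since
  # mutations of os.environ were no longer picked up by the subprocess.
  process_env = os.environ.copy()
  model.apply_env(process_env)
  # Launch one subprocess per experiment with the augmented environment.
  subprocess.run([sys.executable] + experiment_cmd,
                 env=process_env, check=True)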

Tested with:
PJRT_DEVICE=CUDA python3 new_xla/benchmarks/experiment_runner.py --dynamo=openxla --xla=PJRT --test=eval --filter='hf_Bert$|BERT_pytorch$' --suite-name=torchbench --accelerator=cuda --progress-bar --output-dirname=/tmp/output --repeat=3 --print-subprocess --no-resume --profile-cuda --profile-cuda-dump=/tmp/dumpz --profile-cuda-cpu-collect

For hf_Bert, the fp16 casting works, as the fp16 GEMM kernels in the profile below show.

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                    triton_gemm_dot_918         0.00%       0.000us         0.00%       0.000us       0.000us     436.000us        17.83%     436.000us      36.333us            12
                                     triton_gemm_dot_75         0.00%       0.000us         0.00%       0.000us       0.000us     299.000us        12.23%     299.000us      27.182us            11
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us     288.000us        11.78%     288.000us      24.000us            12
                                    triton_gemm_dot_128         0.00%       0.000us         0.00%       0.000us       0.000us     256.000us        10.47%     256.000us     256.000us             1
                                    triton_gemm_dot_872         0.00%       0.000us         0.00%       0.000us       0.000us     144.000us         5.89%     144.000us      12.000us            12
ampere_fp16_s16816gemm_fp16_128x64_ldg8_f2f_stages_6...         0.00%       0.000us         0.00%       0.000us       0.000us     143.000us         5.85%     143.000us      11.000us            13
                                    triton_gemm_dot_845         0.00%       0.000us         0.00%       0.000us       0.000us     131.000us         5.36%     131.000us      10.917us            12
                                             fusion_189         0.00%       0.000us         0.00%       0.000us       0.000us     119.000us         4.87%     119.000us       9.917us            12
                                             fusion_232         0.00%       0.000us         0.00%       0.000us       0.000us     104.000us         4.25%     104.000us       8.667us            12
                                             fusion_233         0.00%       0.000us         0.00%       0.000us       0.000us      97.000us         3.97%      97.000us       8.083us            12
                                               fusion_1         0.00%       0.000us         0.00%       0.000us       0.000us      93.000us         3.80%      93.000us      93.000us             1
                                             fusion_187         0.00%       0.000us         0.00%       0.000us       0.000us      85.000us         3.48%      85.000us       3.542us            24
                                             fusion_184         0.00%       0.000us         0.00%       0.000us       0.000us      72.000us         2.94%      72.000us       6.000us            12
void splitKreduce_kernel<32, 16, int, __half, __half...         0.00%       0.000us         0.00%       0.000us       0.000us      48.000us         1.96%      48.000us       4.000us            12
                                             fusion_173         0.00%       0.000us         0.00%       0.000us       0.000us      43.000us         1.76%      43.000us       2.048us            21
                                     triton_gemm_dot_73         0.00%       0.000us         0.00%       0.000us       0.000us      27.000us         1.10%      27.000us      27.000us             1
                                             fusion_230         0.00%       0.000us         0.00%       0.000us       0.000us      12.000us         0.49%      12.000us       1.000us            12
                                             fusion_229         0.00%       0.000us         0.00%       0.000us       0.000us      12.000us         0.49%      12.000us       1.000us            12
                                             fusion_231         0.00%       0.000us         0.00%       0.000us       0.000us      12.000us         0.49%      12.000us       1.000us            12
                                             fusion_194         0.00%       0.000us         0.00%       0.000us       0.000us       4.000us         0.16%       4.000us       4.000us             1
                                               fusion_4         0.00%       0.000us         0.00%       0.000us       0.000us       4.000us         0.16%       4.000us       4.000us             1
                                              fusion_48         0.00%       0.000us         0.00%       0.000us       0.000us       3.000us         0.12%       3.000us       3.000us             1
                                               fusion_2         0.00%       0.000us         0.00%       0.000us       0.000us       3.000us         0.12%       3.000us       3.000us             1
                                                 fusion         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us         0.08%       2.000us       2.000us             1
                                             fusion_192         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us         0.08%       2.000us       2.000us             1
                                             fusion_185         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us         0.08%       2.000us       2.000us             1
                                             fusion_180         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us         0.08%       2.000us       2.000us             1
                         Memcpy HtoD (Pinned -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.04%       1.000us       1.000us             1
                       Memcpy DtoH (Device -> Pageable)         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.04%       1.000us       1.000us             1
                               TorchDynamo Cache Lookup         4.33%     179.000us         4.33%     179.000us      35.800us       0.000us         0.00%       0.000us       0.000us             5
                                  Torch-Compiled Region        77.26%       3.194ms       305.35%      12.623ms       2.525ms       0.000us         0.00%       0.000us       0.000us             5
                                            aten::slice         0.70%      29.000us         1.57%      65.000us      32.500us       0.000us         0.00%       0.000us       0.000us             2
                                       aten::as_strided         0.02%       1.000us         0.02%       1.000us       1.000us       0.000us         0.00%       0.000us       0.000us             1
                                       aten::slice_copy         0.68%      28.000us         0.68%      28.000us      28.000us       0.000us         0.00%       0.000us       0.000us             1
                                            aten::empty         0.12%       5.000us         0.12%       5.000us       1.667us       0.000us         0.00%       0.000us       0.000us             3
                                               aten::to         0.10%       4.000us         8.37%     346.000us     115.333us       0.000us         0.00%       0.000us       0.000us             3
                                         aten::_to_copy         7.57%     313.000us         8.27%     342.000us     171.000us       0.000us         0.00%       0.000us       0.000us             2
                                            aten::copy_         0.27%      11.000us         0.27%      11.000us      11.000us       0.000us         0.00%       0.000us       0.000us             1
                                       aten::lift_fresh         0.07%       3.000us         0.07%       3.000us       3.000us       0.000us         0.00%       0.000us       0.000us             1
                                            aten::index         1.38%      57.000us         1.45%      60.000us      60.000us       0.000us         0.00%       0.000us       0.000us             1
                                    aten::scalar_tensor         0.07%       3.000us         0.07%       3.000us       3.000us       0.000us         0.00%       0.000us       0.000us             1
                                               aten::eq         0.31%      13.000us         0.31%      13.000us      13.000us       0.000us         0.00%       0.000us       0.000us             1
                                              aten::any         0.17%       7.000us         0.17%       7.000us       7.000us       0.000us         0.00%       0.000us       0.000us             1
                                             aten::item         0.07%       3.000us         6.56%     271.000us     271.000us       0.000us         0.00%       0.000us       0.000us             1
                              aten::_local_scalar_dense         0.63%      26.000us         6.48%     268.000us     268.000us       0.000us         0.00%       0.000us       0.000us             1
                                          aten::_to_cpu         5.49%     227.000us         5.85%     242.000us     242.000us       0.000us         0.00%       0.000us       0.000us             1
                                         cuLaunchKernel         0.34%      14.000us         0.34%      14.000us      14.000us       0.000us         0.00%       0.000us       0.000us             1
                                  cudaDeviceSynchronize         0.41%      17.000us         0.41%      17.000us      17.000us       0.000us         0.00%       0.000us       0.000us             1
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 4.134ms
Self CUDA time total: 2.445ms

For BERT_pytorch, which does not use fp16 by default, we do nothing:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_1...         0.00%       0.000us         0.00%       0.000us       0.000us       4.878ms        37.48%       4.878ms     203.250us            24
                                     triton_gemm_dot_73         0.00%       0.000us         0.00%       0.000us       0.000us       2.272ms        17.46%       2.272ms     189.333us            12
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_6...         0.00%       0.000us         0.00%       0.000us       0.000us     939.000us         7.21%     939.000us      39.125us            24
                                             fusion_201         0.00%       0.000us         0.00%       0.000us       0.000us     895.000us         6.88%     895.000us      74.583us            12
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_2...         0.00%       0.000us         0.00%       0.000us       0.000us     775.000us         5.95%     775.000us      64.583us            12
                                             fusion_208         0.00%       0.000us         0.00%       0.000us       0.000us     401.000us         3.08%     401.000us      33.417us            12
                                             fusion_266         0.00%       0.000us         0.00%       0.000us       0.000us     396.000us         3.04%     396.000us      36.000us            11
                                             fusion_221         0.00%       0.000us         0.00%       0.000us       0.000us     276.000us         2.12%     276.000us      23.000us            12
                                             fusion_197         0.00%       0.000us         0.00%       0.000us       0.000us     264.000us         2.03%     264.000us      24.000us            11
                                             fusion_222         0.00%       0.000us         0.00%       0.000us       0.000us     240.000us         1.84%     240.000us      20.000us            12
                                             fusion_212         0.00%       0.000us         0.00%       0.000us       0.000us     231.000us         1.77%     231.000us      19.250us            12
                                             fusion_207         0.00%       0.000us         0.00%       0.000us       0.000us     228.000us         1.75%     228.000us      19.000us            12
                                             fusion_211         0.00%       0.000us         0.00%       0.000us       0.000us     228.000us         1.75%     228.000us      19.000us            12
                                             fusion_206         0.00%       0.000us         0.00%       0.000us       0.000us     224.000us         1.72%     224.000us      18.667us            12
                                             fusion_223         0.00%       0.000us         0.00%       0.000us       0.000us     220.000us         1.69%     220.000us      20.000us            11
                                             fusion_198         0.00%       0.000us         0.00%       0.000us       0.000us     209.000us         1.61%     209.000us      19.000us            11
                                             fusion_202         0.00%       0.000us         0.00%       0.000us       0.000us     172.000us         1.32%     172.000us      14.333us            12
                                             fusion_203         0.00%       0.000us         0.00%       0.000us       0.000us     105.000us         0.81%     105.000us       8.750us            12
                                              fusion_24         0.00%       0.000us         0.00%       0.000us       0.000us      26.000us         0.20%      26.000us      26.000us             1
                                             fusion_267         0.00%       0.000us         0.00%       0.000us       0.000us      14.000us         0.11%      14.000us      14.000us             1
                                             fusion_213         0.00%       0.000us         0.00%       0.000us       0.000us      10.000us         0.08%      10.000us      10.000us             1
                                             fusion_220         0.00%       0.000us         0.00%       0.000us       0.000us       5.000us         0.04%       5.000us       5.000us             1
                                             fusion_214         0.00%       0.000us         0.00%       0.000us       0.000us       5.000us         0.04%       5.000us       5.000us             1
                                             fusion_216         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.01%       1.000us       1.000us             1
                                             fusion_219         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.01%       1.000us       1.000us             1
                               TorchDynamo Cache Lookup        21.59%     177.000us        21.59%     177.000us     177.000us       0.000us         0.00%       0.000us       0.000us             1
                                  Torch-Compiled Region        77.32%     634.000us        77.32%     634.000us     634.000us       0.000us         0.00%       0.000us       0.000us             1
                                  cudaDeviceSynchronize         1.10%       9.000us         1.10%       9.000us       9.000us       0.000us         0.00%       0.000us       0.000us             1
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 820.000us
Self CUDA time total: 13.015ms

Collaborator

@frgossen frgossen left a comment


One comment.

Comment thread: benchmarks/benchmark_model.py (Outdated)
Comment thread: benchmarks/torchbench_model.py (Outdated)
raise ValueError(f"Unknown precision: {precision}")
return precision_flag

def apply_env(self, process_env):
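
The excerpt ends where apply_env begins. A plausible body, assuming the method above it returns the name of the precision env var to set, would be along these lines (a sketch, not the exact code):

def apply_env(self, process_env):
  # Sketch: propagate the model's precision flag, if any, into the
  # environment that the experiment subprocess will inherit.
  precision_flag = self.default_precision_flag()  # hypothetical helper
  if precision_flag is not None:
    process_env[precision_flag] = "1"
  return process_env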
Collaborator


No need to address in this PR: since this is a workaround, wdyt about putting it behind an optional flag?
This helps when comparing XLA against Inductor, but it somewhat obfuscates comparing PyTorch+Inductor against PyTorch/XLA+XLA.

Collaborator Author


Hmm, I can do it in the next PR, but I think this flag should be enabled by default. In the end we need to compare models running ops at the same precision; otherwise the comparisons are largely meaningless.
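
As a sketch of what the suggested opt-out could look like in experiment_runner.py, with the default-enabled behavior preferred above (the flag name is hypothetical):

import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag: apply precision env vars by default; pass
# --no-apply-precision-env-vars to compare unmodified defaults instead.
parser.add_argument(
    "--apply-precision-env-vars",
    action=argparse.BooleanOptionalAction,  # Python 3.9+
    default=True,
    help="Apply model precision env vars in the root process.")
args = parser.parse_args()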

@golechwierowicz golechwierowicz force-pushed the olechwierowicz/apply-env-vars-in-root-process branch from 06cb2d6 to 65da637 on December 15, 2023 10:46
@golechwierowicz golechwierowicz merged commit dfcf306 into master Dec 15, 2023
cota added a commit to cota/pytorch-xla that referenced this pull request Dec 18, 2023
In "dfcf306e7 Apply precision config env vars in the root process.
(pytorch#6152)" we started running load_benchmark() from experiment_runner's
main process. Unfortunately, load_benchmark() for
pytorch_CycleGAN_and_pix2pix seems to exit the calling process when
using XLA. This results in experiment_runner exiting prematurely.

Work around this issue by adding pytorch_CycleGAN_and_pix2pix to
the deny list, so that experiment_runner does not die early.
Note that pytorch_CycleGAN_and_pix2pix was not running successfully
for XLA before the aforementioned dfcf306 commit, so skipping it
does not reduce coverage.
cota added a commit to cota/pytorch-xla that referenced this pull request Dec 19, 2023
In `dfcf306e7 Apply precision config env vars in the root process. (pytorch#6152)`
we started running load_benchmark() from experiment_runner's
main process. Unfortunately, for some models load_benchmark()
exits the calling process when using XLA.
This results in experiment_runner exiting prematurely.

Work around this issue by adding these models to the deny list,
so that experiment_runner does not die early.
Note that these models were not running successfully under XLA
before the aforementioned dfcf306 commit, so skipping them
does not reduce coverage.
cota added a commit that referenced this pull request Dec 20, 2023
In "dfcf306e7 Apply precision config env vars in the root process. (#6152)"
we started running load_benchmark() from experiment_runner's
main process. Unfortunately, for some models load_benchmark()
exits the calling process, which results
in experiment_runner exiting prematurely.

Work around this issue by adding these models to the deny list,
so that experiment_runner does not die early.
mbzomowski pushed a commit to mbzomowski-test-org/xla that referenced this pull request Jan 3, 2024
…ch#6199)

In "dfcf306e7 Apply precision config env vars in the root process. (pytorch#6152)"
we started running load_benchmark() from experiment_runner's
main process. Unfortunately, load_benchmark() for
some models does exit the calling process, which results
in experiment_runner exiting prematurely.

Work around this issue by adding these models to the deny list,
so that experiment_runner does not die early.
golechwierowicz added a commit that referenced this pull request Jan 12, 2024
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
In "dfcf306e7 Apply precision config env vars in the root process. (#6152)"
we started running load_benchmark() from experiment_runner's
main process. Unfortunately, load_benchmark() for
some models does exit the calling process, which results
in experiment_runner exiting prematurely.

Work around this issue by adding these models to the deny list,
so that experiment_runner does not die early.
@zpcore zpcore self-requested a review January 24, 2024 08:31
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
In "dfcf306e7 Apply precision config env vars in the root process. (#6152)"
we started running load_benchmark() from experiment_runner's
main process. Unfortunately, load_benchmark() for
some models does exit the calling process, which results
in experiment_runner exiting prematurely.

Work around this issue by adding these models to the deny list,
so that experiment_runner does not die early.