
Enable eager execution #3306

Merged
JackCaoG merged 10 commits into pytorch:master from aws-rhsoln:eager_execution on Feb 4, 2022

Conversation

@aws-rhsoln
Contributor

ML scientists in the initial model-development phase tend to use tools like pdb to debug their models: they step through the code and print arbitrary tensors to check whether the values look correct. Currently, if a user wants to debug issues in their model by printing the output of intermediate layers, or uses a debugger like pdb to step through and investigate tensors, they incur an expensive intermediate graph compilation and execution. Adding the ability to run operations eagerly allows users to debug operations before launching their expensive training jobs. Consider the following example:

import pdb

import torch
import torch_xla.core.xla_model as xm

pdb.set_trace()
dev = xm.xla_device()
linear1 = torch.nn.Linear(10, 30).to(dev)
linear2 = torch.nn.Linear(30, 20).to(dev)
linear3 = torch.nn.Linear(20, 10).to(dev)
linear4 = torch.nn.Linear(10, 1).to(dev)
inp = torch.randn(2, 10).to(dev)
output1 = linear1(inp)
output2 = linear2(output1)
xm.mark_step()  # PDB ---> user issues a mark_step from pdb to investigate tensors
output3 = linear3(output2)
output4 = linear4(output3)
xm.mark_step()
print(output4.to('cpu'))
print(output3.to('cpu'))

In the above example there are two graphs, each with a graph size of 2. In larger models these intermediate graphs can be large and incur high compile and execution times. Also, if in the next iteration the user puts the mark_step at a different location, it results in a different graph and hence a cache miss. This forces the user to keep breakpoints at fixed locations, which is a real constraint. If each op is compiled and executed independently, the position of the breakpoint no longer matters and there is a fixed number of compilations and executions. Moreover, with per-op execution the chance of hitting the compilation cache increases, since the layers in large models keep repeating, thereby avoiding a graph compile. With the xrt_server preserving the compilation cache across training runs, compiling per op produces a reusable cache, further reducing compilation time during debugging.

Pitch

We would like users to be able to enable eager execution by calling an API. The API needs to be called before any XLA tensor is created, and it is similar to what TensorFlow 1.15 offered for enabling eager execution.

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
# Enable eager execution. This API needs to be called before
# any XLA tensor is created. The rest of the script can remain
# unchanged.
xm.enable_eager_execution()
linear1 = torch.nn.Linear(10, 30).to(dev)
linear2 = torch.nn.Linear(30, 20).to(dev)
linear3 = torch.nn.Linear(20, 10).to(dev)
linear4 = torch.nn.Linear(10, 1).to(dev)
inp = torch.randn(2, 10).to(dev)
output1 = linear1(inp)
output2 = linear2(output1)
output1.to('cpu')  # PDB ---> user inspects a tensor from pdb
output3 = linear3(output2)
output4 = linear4(output3)
output4.to('cpu')  # PDB ---> user inspects another tensor from pdb
xm.mark_step()

The number of graph compilations and executions remains the same and does not depend on where the mark_step is inserted.
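One way to sanity-check this claim is to dump torch_xla's metrics report and compare the XrtCompile / XrtExecute counters between runs with the breakpoint at different locations. A minimal sketch using the existing debug metrics module (not part of the proposed API):

import torch_xla.debug.metrics as met

# The XrtCompile and XrtExecute entries in this report show how many
# compilations and executions have happened so far in the process.
print(met.metrics_report())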

Note: Per-op execution results in a higher end-to-end execution time compared to lazy mode; however, for initial development, when the user prints a large number of tensors, avoiding the intermediate graph compiles should offset this cost.

@JackCaoG
Collaborator

Thanks for contributing. I didn't look at the change in detail yet, but what you are trying to achieve is very similar to the OpByOp execution model we already support. Take a look at https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#environment-variables.

What OpByOp does is skip the optimizations during compilation and execute the program in an op-by-op manner. Compilation time will be very fast but execution will be slower.

@aws-rhsoln
Contributor Author

aws-rhsoln commented Jan 19, 2022

> Thanks for contributing. I didn't look at the change in detail yet, but what you are trying to achieve is very similar to the OpByOp execution model we already support. Take a look at https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#environment-variables.
>
> What OpByOp does is skip the optimizations during compilation and execute the program in an op-by-op manner. Compilation time will be very fast but execution will be slower.

Yeah, it is similar, except that this would let users execute eagerly, which they are used to doing in frameworks like PyTorch. The issue with OpByOp mode is that it breaks the lazily collected graph into ops only at the point where the user wants to print. This results in re-executions in cases like these:

import os

import torch
import torch_xla.core.xla_model as xm

os.environ["XLA_GET_TENSORS_OPBYOP"] = "1"
os.environ["XLA_SYNC_TENSORS_OPBYOP"] = "1"
dev = xm.xla_device()
linear1 = torch.nn.Linear(10, 30).to(dev)
linear2 = torch.nn.Linear(30, 20).to(dev)
linear3 = torch.nn.Linear(20, 10).to(dev)
linear4 = torch.nn.Linear(10, 1).to(dev)
inp = torch.randn(2, 10).to(dev)
output1 = linear1(inp)
output2 = linear2(output1)
output3 = linear3(output2)
output4 = linear4(output3)
output4.to('cpu')  # PDB ---> user inspects a tensor from pdb
output3.to('cpu')  # PDB ---> user inspects another tensor from pdb

It would cut two graphs, one with 8 nodes (each linear layer creates 2 nodes) and the other with 6 nodes, and each graph would then be executed op by op. This results in duplicate executions.

@JackCaoG
Collaborator

I feel like what you want is essentially the least amount of wait time when the user wants to inspect a tensor value. Can we just spawn a subprocess that calls xm.mark_step() every 0.x seconds and set the execution mode to OpByOp? This way it will actively execute the pending graph.
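A rough sketch of that suggestion, purely for illustration (the helper name and the 0.1-second interval are assumptions, and neither this thread nor the PR establishes that xm.mark_step() is safe to call from a background thread):

import os
import threading
import time

import torch_xla.core.xla_model as xm

# Execute each cut graph op by op (see the TROUBLESHOOTING.md environment variables).
os.environ["XLA_GET_TENSORS_OPBYOP"] = "1"
os.environ["XLA_SYNC_TENSORS_OPBYOP"] = "1"

def flush_pending_graph_periodically(interval_s=0.1):
    # Hypothetical helper: keep cutting and executing whatever graph has been
    # accumulated so far, so a tensor inspected in pdb is (mostly) ready.
    while True:
        xm.mark_step()
        time.sleep(interval_s)

threading.Thread(target=flush_pending_graph_periodically, daemon=True).start()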

In LTC, @wconstab has a plan to make Lazy a mode, and eager would just mean native PyTorch eager mode. I don't think this PR fits our long-term vision, and it adds overhead on our critical path (tensor creation). I am also a bit unsure about how much user-experience improvement this PR will bring. A PyTorch/XLA program can emit slightly different results if kernels are not fused, and running a program in this eager mode will certainly hurt performance and memory usage by a lot.

Comment thread torch_xla/csrc/tensor.cpp Outdated
Comment thread torch_xla/csrc/tensor.cpp Outdated
Comment thread torch_xla/csrc/tensor.h Outdated
@wconstab
Collaborator

I think this approach would work fine for lazy-tensor-core too. We could discuss on the lazy_tensor design doc what kind of user API (env var or otherwise) we want to shoot for. An env var seems simplest for starters for torch-xla.
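For illustration, an env-var-gated version might look like the sketch below; the variable name XLA_USE_EAGER_DEBUG_MODE is an assumption inferred from the UseEagerDebugMode naming in the diff and is not confirmed anywhere in this thread:

import os

# Hypothetical env var gating the eager debug mode; set it before torch_xla
# creates any XLA tensors.
os.environ["XLA_USE_EAGER_DEBUG_MODE"] = "1"

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
out = torch.nn.Linear(10, 1).to(dev)(torch.randn(2, 10).to(dev))
print(out.to('cpu'))  # under the hypothetical flag, each op above ran eagerly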

aws-rhsoln requested a review from JackCaoG on January 26, 2022
@JackCaoG
Collaborator

JackCaoG commented Feb 1, 2022

@aws-rhsoln There still seems to be some kind of merge conflict

@aws-rhsoln
Contributor Author

> @aws-rhsoln There still seems to be some kind of merge conflict

Will try to resolve it and send in a revision

@aws-rhsoln
Contributor Author

Ran the BERT-large model for 10 iterations. Here are the metrics:

EAGER DEBUG MODE:
Metric: XrtCompile
  TotalSamples: 358
  Accumulator: 28s857ms591.389us
  Mean: 078ms811.708us
  StdDev: 420ms665.013us
  Rate: 3.69262 / second
  Percentiles: 25%=022ms367.483us; 50%=029ms283.217us; 80%=079ms168.842us; 90%=089ms117.083us; 95%=094ms097.250us; 99%=541ms433.761us
Metric: XrtExecute
  TotalSamples: 68993
  Accumulator: 03m10s751ms891.251us
  Mean: 004ms876.244us
  StdDev: 111ms658.587us
  Rate: 200.191 / second
  Percentiles: 25%=193.963us; 50%=214.762us; 80%=261.520us; 90%=548.940us; 95%=001ms035.784us; 99%=004ms343.308us
  
LAZY MODE:
Metric: XrtCompile
  TotalSamples: 2
  Accumulator: 16m56s565ms618.863us
  Mean: 08m58s782ms309.432us
  StdDev: 19s235ms806.571us
  Rate: 0.00393881 / second
  Percentiles: 25%=08m39s548ms502.861us; 50%=08m17s017ms116.002us; 80%=08m17s017ms116.002us; 90%=08m17s017ms116.002us; 95%=08m17s017ms116.002us; 99%=08m17s017ms116.002us
Metric: XrtExecute
  TotalSamples: 10
  Accumulator: 02m32s168ms546.572us
  Mean: 09s217ms754.657us
  StdDev: 356ms037.205us
  Rate: 0.0171977 / second
  Percentiles: 25%=09s051ms592.705us; 50%=09s182ms245.882us; 80%=10s600ms077.524us; 90%=10s037ms299.425us; 95%=10s037ms299.425us; 99%=10s037ms299.425us

For 2-layer BERT:

EAGER DEBUG MODE:
Metric: XrtCompile
  TotalSamples: 292
  Accumulator: 11s296ms214.594us
  Mean: 039ms685.666us
  StdDev: 053ms055.274us
  Rate: 9.62879 / second
  Percentiles: 25%=021ms109.436us; 50%=027ms160.091us; 80%=039ms982.515us; 90%=051ms787.316us; 95%=086ms160.113us; 99%=355ms206.725us
Metric: XrtExecute
  TotalSamples: 8537
  Accumulator: 39s930ms102.810us
  Mean: 005ms704.853us
  StdDev: 042ms841.875us
  Rate: 164.159 / second
  Percentiles: 25%=220.880us; 50%=306.633us; 80%=002ms632.986us; 90%=007ms759.590us; 95%=011ms172.809us; 99%=047ms932.984us

LAZY MODE:
Metric: XrtCompile
  TotalSamples: 2
  Accumulator: 19s739ms186.742us
  Mean: 09s370ms593.371us
  StdDev: 371ms599.849us
  Rate: 0.175261 / second
  Percentiles: 25%=09s999ms993.522us; 50%=10s740ms193.220us; 80%=10s740ms193.220us; 90%=10s740ms193.220us; 95%=10s740ms193.220us; 99%=10s740ms193.220us
Metric: XrtExecute
  TotalSamples: 10
  Accumulator: 13s060ms102.048us
  Mean: 01s306ms010.205us
  StdDev: 061ms283.828us
  Rate: 0.453409 / second
  Percentiles: 25%=01s272ms396.594us; 50%=01s282ms046.295us; 80%=01s342ms692.091us; 90%=01s473ms744.644us; 95%=01s473ms744.644us; 99%=01s473ms744.644us

@JackCaoG
Collaborator

Mostly lgtm, minor nits

Comment thread test/run_tests.sh Outdated
Comment thread test/test_operations.py Outdated
Comment thread torch_xla/csrc/tensor.h Outdated
aws-rhsoln requested a review from JackCaoG on February 3, 2022
Comment thread torch_xla/csrc/tensor.cpp
return DeviceContextArena::Get()->GetRunningSeed(device);
}

bool XLATensor::UseEagerDebugMode() {
Collaborator

hmm.. this doesn't need to be a class method. I don't want to block you because of this, but maybe I will fix it later.

@JackCaoG
Collaborator

Thanks @aws-rhsoln !

JackCaoG merged commit cb9b09a into pytorch:master on Feb 4, 2022