Add xla.step context manager #7068
Conversation
# Create a DataLoader
dataset = TensorDataset(input_data, target_data)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
Doesn't DataLoader take a device argument?
No. In normal PyTorch, you have to move the data with tensor.to: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#training-on-gpu
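For reference, a minimal sketch of that pattern (the names and shapes here are illustrative, not taken from the example):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative data; DataLoader itself never sees the device.
input_data = torch.randn(64, 10)
target_data = torch.randn(64, 1)
loader = DataLoader(TensorDataset(input_data, target_data), batch_size=8)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for data, target in loader:
  # The caller moves each batch to the device explicitly.
  data, target = data.to(device), target.to(device)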
@contextlib.contextmanager
def step():
The reason I find step a bit confusing is that we don't call mark_step upon entering it.
with xla.step():
  y = x + z
  y += 1
step as a context manager kind of suggests that execution will only cover what happened inside the context, but that's actually not the case.
I agree. This should either print a warning if there are pending operations, or just call mark_step twice. What do you think is better?
Let's try mark_step twice and benchmark it with one of the resnet50 examples with fake data.
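For concreteness, here is a hypothetical sketch of the mark_step-twice variant (this assumes the xm.mark_step API and is not the actual patch from this PR):

import contextlib
import torch_xla.core.xla_model as xm

@contextlib.contextmanager
def step():
  # Hypothetical: flush anything traced before the block, so the step
  # covers only operations issued inside it.
  xm.mark_step()
  try:
    yield
  finally:
    # Flush on exit as well, even if the block raised.
    xm.mark_step()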
I'm going to hold off on modifying the examples until we're running tests on them. Here's my patch:
--- a/examples/train_resnet_base.py
+++ b/examples/train_resnet_base.py
@@ -45,15 +45,16 @@ class TrainResNetBase():
     self.model.train()
     loader = itertools.islice(loader, self.num_steps)
     for step, (data, target) in enumerate(loader):
-      self.optimizer.zero_grad()
-      output = self.model(data)
-      loss = self.loss_fn(output, target)
-      loss.backward()
-      self.run_optimizer()
+      with torch_xla.step():
+        self.optimizer.zero_grad()
+        output = self.model(data)
+        loss = self.loss_fn(output, target)
+        loss.backward()
+        self.run_optimizer()
+
       tracker.add(self.batch_size)
       if step % 10 == 0:
-        xm.add_step_closure(
-            self._train_update, args=(step, loss, tracker, epoch))
+        self._train_update(step, loss, tracker, epoch)

Before:
epoch: 1, step: 290, loss: 6.608619213104248, rate: 1747.0911849087843
epoch: 1, step: 290, loss: 6.606635570526123, rate: 1747.0763868012214
epoch: 1, step: 290, loss: 6.618781566619873, rate: 1747.2648104487325
epoch: 1, step: 290, loss: 6.605813980102539, rate: 1746.9924093597208
After:
epoch: 1, step: 290, loss: 6.603261947631836, rate: 1752.4689284654187
epoch: 1, step: 290, loss: 6.607376575469971, rate: 1752.4377415557715
epoch: 1, step: 290, loss: 6.611710071563721, rate: 1752.2556378789855
epoch: 1, step: 290, loss: 6.638012886047363, rate: 1752.400066823619
See #6751
One difference from sync is that exceptions are handled sanely by xla.step. Remove ParallelLoader because it mostly does not make a difference for MP, and we should keep our starting point as simple as possible.
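To illustrate the exception point, a hedged usage sketch (assuming xla.step flushes pending work in a finally block, as in the sketch above):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.ones(2, 2, device=device)
try:
  with torch_xla.step():
    y = x + 1
    raise RuntimeError("simulated failure mid-step")
except RuntimeError:
  # The exception reaches the caller; pending work was still flushed,
  # so the next step starts from a clean graph.
  pass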