
Add xla.step context manager#7068

Merged
will-cromar merged 7 commits into master from wcromar/xla-step
May 17, 2024

Conversation

@will-cromar
Collaborator

@will-cromar will-cromar commented May 15, 2024

See #6751

  • This implementation is intentionally minimal to start with. The main improvement compared to sync is that exceptions are handled sanely.
  • Update README example to use xla.step. Remove ParallelLoader because it mostly does not make a difference for MP, and we should keep our starting point as simple as possible.
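The general shape of such a context manager can be sketched in plain Python. Here `mark_step` is a stand-in that just counts invocations; in torch_xla the real call flushes the pending lazy-tensor graph to the device, and the actual implementation in this PR may differ in detail:

```python
import contextlib

# Stand-in for xm.mark_step: the real call executes the pending
# lazy-tensor graph. Here we only record that a flush happened.
mark_step_calls = []

def mark_step():
    mark_step_calls.append("flush")

@contextlib.contextmanager
def step():
    try:
        yield
    finally:
        # Flush even when the body raises, so an exception inside a
        # step does not leave partially traced operations behind.
        mark_step()

# Normal exit flushes once.
with step():
    pass

# An exception still triggers the flush before propagating.
try:
    with step():
        raise ValueError("boom")
except ValueError:
    pass
```

The `try/finally` around the `yield` is what makes exceptions "handled sanely": the flush runs on both the normal and the error path.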

@will-cromar will-cromar changed the title [WIP] Add xla.step context manager Add xla.step context manager May 16, 2024
@will-cromar will-cromar requested a review from JackCaoG May 16, 2024 20:09
@will-cromar will-cromar marked this pull request as ready for review May 16, 2024 20:09
Comment thread test/test_devices.py

# Create a DataLoader
dataset = TensorDataset(input_data, target_data)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
Collaborator

Doesn't DataLoader take device as an argument?

Collaborator Author

No. In normal PyTorch, you have to move the data with tensor.to: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#training-on-gpu
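For illustration, this is the explicit-move pattern the linked tutorial describes; the device name is a placeholder, and the same `.to` call is used whatever the target device is:

```python
import torch

# DataLoader yields host (CPU) tensors; the training loop moves
# each batch to the target device explicitly with tensor.to.
device = torch.device("cpu")  # e.g. "cuda" in the CIFAR-10 tutorial

batch = torch.randn(4, 3)
batch = batch.to(device)
```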

Comment thread torch_xla/torch_xla.py


@contextlib.contextmanager
def step():
Collaborator

The reason I find step a bit confusing is that we don't call mark_step upon entering the step.

with xla.step():
  y = x + z

y += 1

step as a context manager kind of suggests that execution will only cover what happened inside the context manager, but that's actually not the case.

Collaborator Author

I agree. This should either print a warning if there are pending operations, or just mark_step twice. What do you think is better?

Collaborator

Let's try mark_step twice and benchmark it with one of the examples, resnet50 with fake data.
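A sketch of the "mark_step twice" variant under discussion, again with a counting stand-in for the real flush (the merged implementation may behave differently): flushing on entry executes any work traced outside the context, such as the `y += 1` in the example above, so the step then covers only its own body.

```python
import contextlib

flushes = []

def mark_step():  # stand-in for xm.mark_step
    flushes.append("flush")

@contextlib.contextmanager
def step():
    # Flush on entry so operations traced *between* steps are not
    # folded into this step's graph.
    mark_step()
    try:
        yield
    finally:
        # Flush on exit to execute the body's traced operations.
        mark_step()

with step():
    pass
```

The benchmark below suggests the extra flush on entry is close to free when there are no pending operations.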

Collaborator Author

I'm going to hold off on modifying the examples until we're running tests on them. Here's my patch:

--- a/examples/train_resnet_base.py
+++ b/examples/train_resnet_base.py
@@ -45,15 +45,16 @@ class TrainResNetBase():
     self.model.train()
     loader = itertools.islice(loader, self.num_steps)
     for step, (data, target) in enumerate(loader):
-      self.optimizer.zero_grad()
-      output = self.model(data)
-      loss = self.loss_fn(output, target)
-      loss.backward()
-      self.run_optimizer()
+      with torch_xla.step():
+        self.optimizer.zero_grad()
+        output = self.model(data)
+        loss = self.loss_fn(output, target)
+        loss.backward()
+        self.run_optimizer()
+
       tracker.add(self.batch_size)
       if step % 10 == 0:
-        xm.add_step_closure(
-            self._train_update, args=(step, loss, tracker, epoch))
+        self._train_update(step, loss, tracker, epoch)

Before:

epoch: 1, step: 290, loss: 6.608619213104248, rate: 1747.0911849087843
epoch: 1, step: 290, loss: 6.606635570526123, rate: 1747.0763868012214
epoch: 1, step: 290, loss: 6.618781566619873, rate: 1747.2648104487325
epoch: 1, step: 290, loss: 6.605813980102539, rate: 1746.9924093597208

After:

epoch: 1, step: 290, loss: 6.603261947631836, rate: 1752.4689284654187
epoch: 1, step: 290, loss: 6.607376575469971, rate: 1752.4377415557715
epoch: 1, step: 290, loss: 6.611710071563721, rate: 1752.2556378789855
epoch: 1, step: 290, loss: 6.638012886047363, rate: 1752.400066823619

@will-cromar will-cromar merged commit 3c59087 into master May 17, 2024
zpcore pushed a commit that referenced this pull request May 20, 2024