Description
As we know, DeepSpeed (DS) flattens the individual tensors of each optimizer param group into a single buffer, so once the DS engine takes over, the user can no longer access each tensor and its attributes.
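To make the problem concrete, below is a toy, pure-Python sketch of what flattening and sharding do to a param group. It is not DS's implementation. `flatten_group` and `partition` are hypothetical helpers, plain lists stand in for tensors, and the even split with zero padding is an assumption. The point it illustrates: once a group becomes one flat buffer split across ranks, a single parameter (`w1` below) can straddle two GPUs, so there is no longer a single tensor to hand back to the user.

```python
from typing import Dict, List, Tuple


def flatten_group(params: Dict[str, List[float]]) -> Tuple[List[float], Dict[str, Tuple[int, int]]]:
    """Hypothetical helper: concatenate every parameter of one group into a flat buffer.

    Returns the flat buffer and a map of name -> (offset, numel) into it.
    """
    flat: List[float] = []
    index: Dict[str, Tuple[int, int]] = {}
    for name, values in params.items():
        index[name] = (len(flat), len(values))
        flat.extend(values)
    return flat, index


def partition(flat: List[float], world_size: int) -> List[List[float]]:
    """Hypothetical helper: split the flat buffer into equal contiguous shards, one per rank.

    The tail is zero-padded so every shard has the same length (an assumption of this sketch).
    """
    shard_len = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (shard_len * world_size - len(flat))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


# Three "tensors" in one param group, sharded over 4 simulated ranks.
group = {"w1": [1.0, 2.0, 3.0], "w2": [4.0, 5.0], "b": [6.0]}
flat, index = flatten_group(group)
shards = partition(flat, world_size=4)
print(index)   # {'w1': (0, 3), 'w2': (3, 2), 'b': (5, 1)}
print(shards)  # [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [0.0, 0.0]]
```

Note that `w1` now lives partly on rank 0 and partly on rank 1, and `w2` on ranks 1 and 2.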
Since @tjruwase is working on the API used by the universal checkpoint to get/set fragments of data, grads, and optimizer states across multiple GPUs, I think the next step is to override some accessors on those flattened tensors so that they transparently get/set the underlying data.
For example, `weight.grad` could get/set the sharded fragments behind the scenes, giving users normal access to important data. Researchers, for instance, often need to access grads. Optimizer states are another really important piece of data that needs to be accessible: we wanted them to diagnose spikes during training, but alas that data wasn't accessible.
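A rough sketch of the proposed accessor idea, under heavy assumptions: `ShardedParam` is a hypothetical class, not DS code and not the API @tjruwase is building. `grad_shards[r]` is a local list standing in for the shard held by rank `r`; a real implementation would have to hook into `torch.Tensor` and use collective communication to reach other GPUs' shards. It only shows the shape of the idea: a `grad` property that gathers a parameter's fragments on read and scatters them back on write, so the caller never sees the flat layout.

```python
from typing import List, Tuple


class ShardedParam:
    """Hypothetical view of one parameter whose grad lives in a flat, sharded buffer."""

    def __init__(self, grad_shards: List[List[float]], offset: int, numel: int):
        self._shards = grad_shards              # stand-in for per-rank shards
        self._shard_len = len(grad_shards[0])   # all shards assumed equal length
        self._offset = offset                   # where this param starts in the flat buffer
        self._numel = numel                     # number of elements in this param

    def _locate(self, i: int) -> Tuple[int, int]:
        # Map element i of this parameter to (rank, position within that rank's shard).
        return divmod(self._offset + i, self._shard_len)

    @property
    def grad(self) -> List[float]:
        # Gather this parameter's fragments from every shard that holds part of it.
        out = []
        for i in range(self._numel):
            rank, pos = self._locate(i)
            out.append(self._shards[rank][pos])
        return out

    @grad.setter
    def grad(self, values: List[float]) -> None:
        if len(values) != self._numel:
            raise ValueError(f"expected {self._numel} values, got {len(values)}")
        # Scatter the new values back into whichever shards own them.
        for i, v in enumerate(values):
            rank, pos = self._locate(i)
            self._shards[rank][pos] = v


# Usage: a 2-element param starting at flat offset 3, spread over ranks 1 and 2.
shards = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.0, 0.0]]
w2 = ShardedParam(shards, offset=3, numel=2)
print(w2.grad)      # [0.4, 0.5]
w2.grad = [9.0, 8.0]
print(shards)       # [[0.1, 0.2], [0.3, 9.0], [8.0, 0.6], [0.0, 0.0]]
```

The same pattern could in principle cover other accessors (e.g. optimizer states) by pointing the view at a different sharded buffer.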
And of course we should look at which other accessors could be overridden; ideally, the flattened tensors would appear to the user as normal tensors.
Thank you!