TensorDict is a dictionary-like class that inherits properties from tensors,
such as indexing, shape operations, casting to device or storage and many more.
The codebase consists of two main components: TensorDict,
a specialized dictionary for PyTorch tensors, and tensorclass,
a dataclass for tensors.
```python
import torch
from tensordict import TensorDict

data = TensorDict(
    obs=torch.randn(128, 84),
    action=torch.randn(128, 4),
    reward=torch.randn(128, 1),
    batch_size=[128],
)

data_gpu = data.to("cuda")           # all tensors move together
sub = data_gpu[:64]                  # all tensors are sliced
stacked = torch.stack([data, data])  # works like a tensor
```
TensorDict makes your codebase more readable, compact, modular and fast. It abstracts away tailored operations, dispatching them to the leaves for you.
- Composability: `TensorDict` generalizes `torch.Tensor` operations to collections of tensors. [tutorial]
- Speed: asynchronous transfer to device, fast node-to-node communication through `consolidate`, compatible with `torch.compile`. [tutorial]
- Shape operations: indexing, slicing, concatenation, reshaping -- everything you can do with a tensor. [tutorial]
- Distributed / multiprocessed: distribute `TensorDict` instances across workers, devices and machines. [doc]
- Serialization and memory-mapping for efficient checkpointing. [doc]
- Functional programming and compatibility with `torch.vmap`. [tutorial]
- Nesting: nest `TensorDict` instances to create hierarchical structures. [tutorial]
- Lazy preallocation: preallocate memory without initializing tensors. [tutorial]
- `@tensorclass`: a specialized dataclass for `torch.Tensor`. [tutorial]
Check our Getting Started guide for a full overview of TensorDict's features.
Working with groups of tensors is common in ML. Without a shared structure, every operation must be repeated for each tensor:
```python
# Without TensorDict
obs = obs.to("cuda")
action = action.to("cuda")
reward = reward.to("cuda")
next_obs = next_obs.to("cuda")

obs_batch = obs[:32]
action_batch = action[:32]
reward_batch = reward[:32]
next_obs_batch = next_obs[:32]
```

With TensorDict, all of that collapses to:
```python
# With TensorDict
data = data.to("cuda")
data_batch = data[:32]
```

This holds for any operation: `reshape`, `unsqueeze`, `permute`, `to`, indexing, `torch.stack`, `torch.cat`, and many more.
Using TensorDict primitives, most supervised training loops can be rewritten in a generic way:
```python
for i, data in enumerate(dataset):
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

Each step of the training loop -- data loading, model prediction, loss computation -- can be swapped independently without touching the rest. The same loop works across classification, segmentation, RL, and more.
By default, device transfers are asynchronous and synchronized only when needed:
```python
td_cuda = TensorDict(**dict_of_tensors, device="cuda")
td_cpu = td_cuda.to("cpu")
td_cpu = td_cuda.to("cpu", non_blocking=False)  # force synchronous
```

Using TensorDict you can code the Adam optimizer as you would for a single tensor and apply it to a collection of parameters. On CUDA, these operations use fused kernels:
```python
class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock the structure: no keys can be added or removed.
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        self._mu = weights.data.clone()
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias-corrected estimates: divide out of place so the
        # running averages are not corrupted between steps.
        mu = self._mu / (1 - self.beta1 ** self.t)
        sigma = self._sigma / (1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))
```

TensorDict is used across a range of domains:
| Domain | Projects |
|---|---|
| Reinforcement Learning | TorchRL (PyTorch), DreamerV3-torch, Dreamer4, SkyRL |
| LLM Post-Training | verl, ROLL (Alibaba), LMFlow, LoongFlow (Baidu) |
| Robotics & Simulation | MuJoCo Playground (Google DeepMind), ProtoMotions (NVIDIA), holosoma (Amazon) |
| Physics & Scientific ML | PhysicsNeMo (NVIDIA) |
| Genomics | Medaka (Oxford Nanopore) |
With pip:

```shell
pip install tensordict
```

For the latest features:

```shell
pip install tensordict-nightly
```

With conda:

```shell
conda install -c conda-forge tensordict
```

With uv + PyTorch nightlies: if you're using a PyTorch nightly, install tensordict with `--no-deps` to prevent uv from re-resolving torch from PyPI:

```shell
uv pip install -e . --no-deps
```

Or explicitly point uv at the PyTorch nightly wheel index:

```shell
uv pip install -e . --prerelease=allow -f "https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html"
```

If you're using TensorDict, please refer to this BibTeX entry to cite this work:
```
@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```
TensorDict is licensed under the MIT License. See LICENSE for details.