
TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or to (memory-mapped) storage, and many more. The code base consists of two main components: TensorDict, a specialized dictionary for PyTorch tensors, and tensorclass, a dataclass for tensors.

import torch
from tensordict import TensorDict

data = TensorDict(
    obs=torch.randn(128, 84),
    action=torch.randn(128, 4),
    reward=torch.randn(128, 1),
    batch_size=[128],
)

data_gpu = data.to("cuda")      # all tensors move together
sub = data_gpu[:64]              # all tensors are sliced
stacked = torch.stack([data, data])  # works like a tensor

Key Features | Examples | Installation | Ecosystem | Citation | License

Key Features

TensorDict makes your code base more readable, compact, modular and fast. It abstracts away tailored per-tensor operations, dispatching them to the tensor leaves for you.

  • Composability: TensorDict generalizes torch.Tensor operations to collections of tensors. [tutorial]
  • Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, compatible with torch.compile. [tutorial]
  • Shape operations: indexing, slicing, concatenation, reshaping -- everything you can do with a tensor. [tutorial]
  • Distributed / multiprocessed: distribute TensorDict instances across workers, devices and machines. [doc]
  • Serialization and memory-mapping for efficient checkpointing. [doc]
  • Functional programming and compatibility with torch.vmap. [tutorial]
  • Nesting: nest TensorDict instances to create hierarchical structures (sketched after this list). [tutorial]
  • Lazy preallocation: preallocate memory without initializing tensors. [tutorial]
  • @tensorclass: a specialized dataclass for torch.Tensor (sketched after this list). [tutorial]
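
A minimal sketch of two items from the list above -- nesting and @tensorclass (shapes and key names here are arbitrary examples):

import torch
from tensordict import TensorDict, tensorclass

# Nesting: values can themselves be TensorDicts, addressed with tuple keys
td = TensorDict(
    {"agent": {"obs": torch.randn(4, 3), "action": torch.randn(4, 2)}},
    batch_size=[4],
)
assert td["agent", "obs"].shape == (4, 3)

# @tensorclass: a dataclass whose instances index and batch like TensorDicts
@tensorclass
class Frame:
    obs: torch.Tensor
    action: torch.Tensor

frame = Frame(obs=torch.randn(4, 3), action=torch.randn(4, 2), batch_size=[4])
assert frame[:2].obs.shape == (2, 3)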

Examples

Check our Getting Started guide for a full overview of TensorDict's features.

Before / after

Working with groups of tensors is common in ML. Without a shared structure, every operation must be repeated for each tensor:

# Without TensorDict
obs = obs.to("cuda")
action = action.to("cuda")
reward = reward.to("cuda")
next_obs = next_obs.to("cuda")

obs_batch = obs[:32]
action_batch = action[:32]
reward_batch = reward[:32]
next_obs_batch = next_obs[:32]

With TensorDict, all of that collapses to:

# With TensorDict
data = data.to("cuda")
data_batch = data[:32]

This holds for any operation: reshape, unsqueeze, permute, to, indexing, torch.stack, torch.cat, and many more.
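
A quick sketch of a few of these, reusing data from the intro example (the target shapes are arbitrary):

data = data.reshape(2, 64)       # batch_size [128] -> [2, 64]; every leaf follows
data = data.permute(1, 0)        # batch dims swapped -> [64, 2]
row = data[0]                    # index the first batch dim -> batch_size [2]
both = torch.cat([row, row], 0)  # batch_size [4]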

Generic training loops

Using TensorDict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Each step of the training loop -- data loading, model prediction, loss computation -- can be swapped independently without touching the rest. The same loop works across classification, segmentation, RL, and more.
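
The loop keeps model and loss_module abstract. One way to fill them in, sketched here with tensordict.nn.TensorDictModule (the key names, shapes and loss are illustrative assumptions):

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# Reads data["obs"], writes data["logits"]
model = TensorDictModule(nn.Linear(84, 10), in_keys=["obs"], out_keys=["logits"])

def loss_module(data: TensorDict) -> torch.Tensor:
    # Hypothetical loss: cross-entropy against a "label" entry
    return nn.functional.cross_entropy(data["logits"], data["label"])

data = TensorDict(obs=torch.randn(32, 84),
                  label=torch.randint(10, (32,)),
                  batch_size=[32])
loss = loss_module(model(data))

Because modules read and write named entries rather than positional arguments, swapping the model or the loss never requires touching the loop itself.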

Fast copy on device

By default, device transfers are asynchronous and synchronized only when needed:

dict_of_tensors = {"a": torch.randn(3), "b": torch.randn(3)}  # any payload works
td_cuda = TensorDict(**dict_of_tensors, device="cuda")
td_cpu = td_cuda.to("cpu")                      # asynchronous, synced when needed
td_cpu = td_cuda.to("cpu", non_blocking=False)  # force synchronous

Coding an optimizer

Using TensorDict you can code the Adam optimizer as you would for a single tensor and apply it to a collection of parameters. On CUDA, these operations use fused kernels:

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Locking prevents accidental key creation/deletion and speeds up ops
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        self._mu = weights.data.clone().zero_()  # first moment, zero-initialized
        self._sigma = weights.data.mul(0.0)      # second moment, zero-initialized
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias correction is out of place so the running stats stay intact
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))
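
One way to drive this optimizer is to gather a module's parameters with TensorDict.from_module (the module and shapes below are arbitrary examples):

import torch
from torch import nn
from tensordict import TensorDict

module = nn.Linear(4, 4)
params = TensorDict.from_module(module)  # parameters as a (nested) TensorDict
optim = Adam(params)

module(torch.randn(2, 4)).sum().backward()
optim.step()  # updates every parameter in one call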

Ecosystem

TensorDict is used across a range of domains:

Domain                  | Projects
----------------------- | --------
Reinforcement Learning  | TorchRL (PyTorch), DreamerV3-torch, Dreamer4, SkyRL
LLM Post-Training       | verl, ROLL (Alibaba), LMFlow, LoongFlow (Baidu)
Robotics & Simulation   | MuJoCo Playground (Google DeepMind), ProtoMotions (NVIDIA), holosoma (Amazon)
Physics & Scientific ML | PhysicsNeMo (NVIDIA)
Genomics                | Medaka (Oxford Nanopore)

Installation

With pip:

pip install tensordict

For the latest features:

pip install tensordict-nightly

With conda:

conda install -c conda-forge tensordict

With uv + PyTorch nightlies:

If you're using a PyTorch nightly, install tensordict with --no-deps to prevent uv from re-resolving torch from PyPI:

uv pip install -e . --no-deps

Or explicitly point uv at the PyTorch nightly wheel index:

uv pip install -e . --prerelease=allow -f "https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html"

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

License

TensorDict is licensed under the MIT License. See LICENSE for details.
