🚀 The feature, motivation and pitch
Currently, PyTorch supports CUDA Graph features under the torch.cuda namespace, providing capture and replay functionality via the C++ CUDAGraph implementation. This RFC proposes generalizing the existing graph implementation into a new AcceleratorGraph structure that extends graph support to more backends. The goal is to decouple CUDA-specific logic and introduce a unified, extensible interface that provides consistent graph functionality regardless of the underlying hardware.
Alternatives
1. Overview
We introduce new graph-related APIs under the torch.accelerator module that provide the same functionality as the current torch.cuda interfaces. These APIs are device-agnostic and are intended as the primary entry point for graph features. Internally, the appropriate backend (e.g., CUDA, XPU) is selected based on the current device or on the backend argument passed by the user. The existing torch.cuda APIs will continue to be supported for backward compatibility. The following example compares the current and new APIs:
```python
import torch

# current APIs
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):  # calls g.capture_begin() on entry, g.capture_end() on exit
    # run model
    model(input)
g.replay()

# new APIs
g = torch.accelerator.Graph(backend='cuda')
with torch.accelerator.graph(g):
    # run model
    model(input)
g.replay()
```
The new Graph design builds on the existing CUDA Graph implementation, in which a C++-level CUDAGraph class implements the core graph features provided by CUDA.
The generalization follows the factory pattern: users instantiate a generic graph object that encapsulates a device-specific graph implementation. The diagram below illustrates the detailed design.
2. Python frontend
accelerator.Graph class
accelerator.Graph is a high-level interface that constructs a general _AcceleratorGraph instance.
This instance internally holds a device-specific graph implementation, selected based on the provided backend argument or inferred from the current device if unspecified.
The _AcceleratorGraph class is bound to its C++ counterpart and exposes methods such as capture_begin(), capture_end(), and replay() to the Python layer.
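```python
import torch


class Graph:
    def __new__(cls, backend=None, keep_graph=False):
        # Use the explicit backend if given; otherwise infer it from the current accelerator.
        if backend is not None:
            return torch._C._AcceleratorGraph(backend.upper(), keep_graph)
        else:
            return torch._C._AcceleratorGraph(
                torch.accelerator.current_accelerator().type.upper(), keep_graph
            )
```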
CUDAGraph class
A subclass of torch.accelerator.Graph that pins the backend to CUDA. Because Graph.__new__ returns a torch._C._AcceleratorGraph instance, a CUDAGraph object exposes the same set of graph-related methods.
```python
import torch


class CUDAGraph(torch.accelerator.Graph):
    def __new__(cls, keep_graph=False):
        # Pin the backend to CUDA and delegate to torch.accelerator.Graph.
        return super().__new__(cls, "CUDA", keep_graph)
```
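A Python detail follows from building the object in __new__: when __new__ returns an object that is not an instance of cls, Python skips __init__, and the result does not pass isinstance checks against the class that was called. The sketch reuses a hypothetical _FakeAcceleratorGraph stand-in (in place of torch._C._AcceleratorGraph) to show this for CUDAGraph.

```python
class _FakeAcceleratorGraph:
    """Hypothetical stand-in for torch._C._AcceleratorGraph."""

    def __init__(self, device_type, keep_graph=False):
        self.device_type = device_type
        self.keep_graph = keep_graph


class Graph:
    def __new__(cls, backend=None, keep_graph=False):
        # Simplified: assumes "cuda" is the current accelerator when backend is None.
        device_type = backend if backend is not None else "cuda"
        return _FakeAcceleratorGraph(device_type.upper(), keep_graph)


class CUDAGraph(Graph):
    def __new__(cls, keep_graph=False):
        return super().__new__(cls, "CUDA", keep_graph)

    def __init__(self, keep_graph=False):
        # Never runs: __new__ returned an object that is not a CUDAGraph instance.
        raise AssertionError("unreachable")


g = CUDAGraph(keep_graph=True)
print(type(g).__name__, g.device_type)  # _FakeAcceleratorGraph CUDA
print(isinstance(g, CUDAGraph))  # False
```

If existing code relies on isinstance(g, torch.cuda.CUDAGraph), that check would need extra support, for example a metaclass with a custom __instancecheck__.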
torch.accelerator.graph class
A context manager that wraps the core torch.accelerator.Graph capture methods: capture_begin() on entry and capture_end() on exit.
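```python
class graph:
    def __init__(self, graph):
        self.graph_ = graph

    def __enter__(self):
        self.graph_.capture_begin()

    def __exit__(self, exc_type, exc_value, traceback):
        # Always end the capture, even if the captured region raised.
        self.graph_.capture_end()
```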
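Because __exit__ runs whether the with-body finishes normally or raises, capture_end() is always paired with capture_begin(). The sketch checks this with a hypothetical _RecordingGraph that only logs method calls; since __exit__ returns None, an exception raised during capture still propagates to the caller.

```python
class _RecordingGraph:
    """Hypothetical stand-in for a graph object; it only records method calls."""

    def __init__(self):
        self.calls = []

    def capture_begin(self):
        self.calls.append("capture_begin")

    def capture_end(self):
        self.calls.append("capture_end")


class graph:
    def __init__(self, graph):
        self.graph_ = graph

    def __enter__(self):
        self.graph_.capture_begin()

    def __exit__(self, exc_type, exc_value, traceback):
        # Runs on normal exit and when the body raises; returning None
        # lets any exception propagate after the capture has been closed.
        self.graph_.capture_end()


g = _RecordingGraph()
with graph(g):
    g.calls.append("model(input)")
print(g.calls)  # ['capture_begin', 'model(input)', 'capture_end']

g2 = _RecordingGraph()
try:
    with graph(g2):
        raise RuntimeError("op not capturable")
except RuntimeError:
    pass
print(g2.calls)  # ['capture_begin', 'capture_end']
```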
3. C++ level classes
We define an abstract interface, GraphImplInterface, that declares graph-related methods as pure virtual functions without implementing them. Each backend-specific graph implementation, such as CUDAGraph, inherits from GraphImplInterface and implements these methods using the corresponding device's runtime API.
We register object creators for each backend-specific graph class in a registry table called AccGraphRegistry. This registry supports extensibility, allowing out-of-tree device backends to register their own graph creators.
The overall flow, using CUDAGraph as an example, is as follows:
- AcceleratorGraph acts as the general entry point for graph functionality.
- It retrieves the appropriate graph creator from AccGraphRegistry based on the target backend.
- The creator returns a CUDAGraph instance, which conforms to GraphImplInterface.
- Through the abstract interface (GraphImplInterface), users can invoke the backend-specific implementations (e.g., capture_begin, capture_end, replay) without depending directly on the concrete class.
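The flow above can be traced end to end in a small Python model. This is a sketch, not the proposed implementation: ACC_GRAPH_REGISTRY is a plain dict standing in for AccGraphRegistry, and the CUDAGraph stub records calls instead of using the CUDA runtime.

```python
from abc import ABC, abstractmethod


class GraphImplInterface(ABC):
    @abstractmethod
    def capture_begin(self): ...

    @abstractmethod
    def capture_end(self): ...

    @abstractmethod
    def replay(self): ...


class CUDAGraph(GraphImplInterface):
    """Stub backend: records calls instead of invoking the CUDA runtime."""

    def __init__(self, keep_graph=False):
        self.keep_graph = keep_graph
        self.log = []

    def capture_begin(self):
        self.log.append("cuda.capture_begin")

    def capture_end(self):
        self.log.append("cuda.capture_end")

    def replay(self):
        self.log.append("cuda.replay")


# Registry table: backend name -> creator (here simply the class itself).
ACC_GRAPH_REGISTRY = {"CUDA": CUDAGraph}


class AcceleratorGraph:
    def __init__(self, device_type, keep_graph=False):
        # Step 1: look up the creator for the target backend.
        creator = ACC_GRAPH_REGISTRY.get(device_type)
        if creator is None:
            raise RuntimeError(f"no graph implementation registered for {device_type}")
        # Step 2: the creator returns a backend object conforming to the interface.
        self._impl = creator(keep_graph)

    # Step 3: forward every call through the abstract interface only.
    def capture_begin(self):
        self._impl.capture_begin()

    def capture_end(self):
        self._impl.capture_end()

    def replay(self):
        self._impl.replay()


g = AcceleratorGraph("CUDA")
g.capture_begin()
g.capture_end()
g.replay()
print(g._impl.log)  # ['cuda.capture_begin', 'cuda.capture_end', 'cuda.replay']
```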
GraphImplInterface class
An abstract interface that defines the core operations required for graph management in accelerator backends. Any concrete graph implementation for a specific device (such as CUDA) must inherit from this interface and provide device-specific logic for these operations. This abstraction enables a unified API for graph-based computation across different hardware accelerators.
```cpp
struct GraphImplInterface {
  virtual ~GraphImplInterface() = default;
  virtual void capture_begin(MempoolId_t pool = {0, 0}, StreamCaptureMode mode = StreamCaptureMode::Global) = 0;
  virtual void capture_end() = 0;
  virtual void replay() = 0;
  // ...
};
```
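Python has no pure virtual functions, but abc.abstractmethod gives a close analogue of the contract GraphImplInterface imposes. In the sketch below, all class names are hypothetical and the pool/mode defaults only mirror MempoolId_t {0, 0} and StreamCaptureMode::Global. A backend that omits replay() cannot be instantiated; the difference is that C++ rejects instantiating such a class at compile time, while Python rejects it at run time.

```python
from abc import ABC, abstractmethod


class GraphImplInterface(ABC):
    @abstractmethod
    def capture_begin(self, pool=(0, 0), mode="global"):
        """Start capturing work."""

    @abstractmethod
    def capture_end(self):
        """Stop capturing."""

    @abstractmethod
    def replay(self):
        """Launch the captured graph."""


class IncompleteGraph(GraphImplInterface):
    # Hypothetical backend that forgot to implement replay().
    def capture_begin(self, pool=(0, 0), mode="global"):
        pass

    def capture_end(self):
        pass


class CompleteGraph(IncompleteGraph):
    # Once every abstract method is implemented, the class can be instantiated.
    def replay(self):
        pass


try:
    IncompleteGraph()
except TypeError as err:
    print("rejected:", type(err).__name__)  # rejected: TypeError
print(isinstance(CompleteGraph(), GraphImplInterface))  # True
```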
CUDAGraph class
CUDAGraph inherits from the base class GraphImplInterface and keeps its current implementation.
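```cpp
struct CUDAGraph : public GraphImplInterface {
  CUDAGraph(bool keep_graph = false);
  ~CUDAGraph();
  void register_generator_state(c10::intrusive_ptr<at::CUDAGeneratorState> state);
  void register_generator_state(const at::Generator& generator);
  // Existing CUDA-specific entry point, unchanged.
  void capture_begin(
      MempoolId_t pool = {0, 0},
      cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal);
  // The overload above does not override the interface method (its parameter
  // type differs), so an override with the generic signature is also needed;
  // it is assumed to map StreamCaptureMode to cudaStreamCaptureMode and forward.
  void capture_begin(MempoolId_t pool, StreamCaptureMode mode) override;
  void capture_end() override;
  void replay() override;
  cudaGraph_t raw_cuda_graph();
  // ...
};
```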
AcceleratorGraph class
A high-level wrapper class that provides a unified interface for graph operations on accelerator devices. It holds a std::unique_ptr<GraphImplInterface> that points to the appropriate device-specific implementation (such as CUDAGraph).
```cpp
struct AcceleratorGraph {
 public:
  // getGraph (helper not shown) creates the backend implementation for
  // device_type from the registry; keep_graph is forwarded to it.
  AcceleratorGraph(const std::string& device_type, bool keep_graph = false)
      : graph_(getGraph(device_type, keep_graph)) {}
  ~AcceleratorGraph() = default;
  void capture_begin(MempoolId_t pool = {0, 0}, StreamCaptureMode mode = StreamCaptureMode::Global) {
    graph_->capture_begin(pool, mode);
  }
  void capture_end() {
    graph_->capture_end();
  }
  void replay() {
    graph_->replay();
  }
  // ...
 protected:
  std::unique_ptr<GraphImplInterface> graph_;
};
```
Register creator
Factory registry for dynamic backend registration
```cpp
// Registry keyed by backend name; creators take the keep_graph constructor argument.
TORCH_DECLARE_REGISTRY(AccGraphRegistry, at::GraphImplInterface, bool);
C10_DEFINE_REGISTRY(AccGraphRegistry, at::GraphImplInterface, bool);
// register cuda graph
C10_REGISTER_CLASS(AccGraphRegistry, CUDA, at::cuda::CUDAGraph);
```
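The point of the registry is that a backend can add itself without changes to AcceleratorGraph. The Python sketch below imitates that registration step with a decorator. GraphRegistry, MyDeviceGraph and the "MYDEVICE" key are hypothetical, and creators receive keep_graph on the assumption that the registry is declared with a bool constructor argument.

```python
class GraphRegistry:
    """Hypothetical Python analogue of a backend-name -> creator registry."""

    def __init__(self):
        self._creators = {}

    def register(self, key):
        def decorator(cls):
            if key in self._creators:
                raise KeyError(f"graph creator for {key} already registered")
            self._creators[key] = cls
            return cls
        return decorator

    def create(self, key, *args):
        creator = self._creators.get(key)
        # Unknown backends yield None, analogous to a null pointer in C++.
        return None if creator is None else creator(*args)

    def keys(self):
        return sorted(self._creators)


ACC_GRAPH_REGISTRY = GraphRegistry()


@ACC_GRAPH_REGISTRY.register("CUDA")
class CUDAGraph:
    def __init__(self, keep_graph=False):
        self.keep_graph = keep_graph


# An out-of-tree extension registers its own creator without touching core code.
@ACC_GRAPH_REGISTRY.register("MYDEVICE")
class MyDeviceGraph:
    def __init__(self, keep_graph=False):
        self.keep_graph = keep_graph


print(ACC_GRAPH_REGISTRY.keys())  # ['CUDA', 'MYDEVICE']
print(type(ACC_GRAPH_REGISTRY.create("MYDEVICE", True)).__name__)  # MyDeviceGraph
print(ACC_GRAPH_REGISTRY.create("XPU", False))  # None
```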
Additional context
1. Interface generalization
torch.cuda.is_current_stream_capturing()
This API checks whether the current stream is being captured. Since this state is inherently tied to the stream itself, we prefer to expose it as a method on the stream object, implemented at Stream scope. The new API will be torch.accelerator.current_stream().is_capturing().
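Keeping the capture flag on the stream means each stream answers for itself, and a retained torch.cuda-style free function reduces to a one-line delegation. The sketch below uses hypothetical stand-ins: Stream for torch.Stream, current_stream() for torch.accelerator.current_stream(), capture_on() for whatever begins and ends capture, and cuda_is_current_stream_capturing() for the legacy entry point.

```python
from contextlib import contextmanager


class Stream:
    """Hypothetical stand-in for torch.Stream; each stream tracks its own capture state."""

    def __init__(self, name):
        self.name = name
        self._capturing = False

    def is_capturing(self):
        return self._capturing


default = Stream("default")
_current = default


def current_stream():
    # Stand-in for torch.accelerator.current_stream().
    return _current


@contextmanager
def capture_on(stream):
    # Hypothetical helper: make `stream` current and mark it as capturing.
    global _current
    prev, _current = _current, stream
    stream._capturing = True
    try:
        yield stream
    finally:
        stream._capturing = False
        _current = prev


def cuda_is_current_stream_capturing():
    # Legacy-style free function: simply delegates to the stream method.
    return current_stream().is_capturing()


side = Stream("side")
with capture_on(side):
    print(current_stream().name, cuda_is_current_stream_capturing())  # side True
    print(default.is_capturing())  # False
print(current_stream().name, cuda_is_current_stream_capturing())  # default False
```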
torch.cuda.graph_pool_handle()
This API is backed by a member function of the MemPool class, which manages the device caching allocator's memory pools during graph capture. As the caching allocator is refactored, MemPool will be abstracted into a device-agnostic component.
This API will be exposed through a unified interface torch.accelerator.graph_pool_handle().
torch.cuda.make_graphed_callables()
This API enables graph features on callables. Its current implementation is based on the torch.cuda.graph API. We will generalize it to torch.accelerator.make_graphed_callables(), built on the generalized torch.accelerator.graph API.
The existing torch.cuda APIs will be retained for backward compatibility and will internally delegate to the corresponding new torch.accelerator methods.
cc @NmomoN @mengpenghui @fwenguang @cdzhan @1274085042 @PHLens @albanD @guangyey @EikanWang