Skip to content

Commit 978ad16

Browse files
lwfacebook-github-bot
authored andcommitted
[TensorPipe] Allow passing args to agent options constructor (#37918)
Summary: Pull Request resolved: #37918 ghstack-source-id: 103569096 Test Plan: Tested top of stack Reviewed By: jiayisuse Differential Revision: D21425537 fbshipit-source-id: 2e78d700ea774944c7fd8b22e152d8e459dd422a
1 parent 4e93844 commit 978ad16

2 files changed

Lines changed: 12 additions & 1 deletion

File tree

torch/csrc/distributed/rpc/init.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,11 @@ If the future completes with an error, an exception is thrown.
448448
// Base class: torch.distributed.rpc.RpcBackendOptions.
449449
py::class_<TensorPipeRpcBackendOptions>(
450450
module, "TensorPipeRpcBackendOptions", rpcBackendOptions)
451-
.def(py::init<>())
451+
.def(
452+
py::init<std::map<std::string, worker_id_t>, float, std::string>(),
453+
py::arg("worker_name_to_id"),
454+
py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
455+
py::arg("init_method") = kDefaultInitMethod)
452456
.def_readwrite(
453457
"worker_name_to_id", &TensorPipeRpcBackendOptions::workerNameToId);
454458

torch/csrc/distributed/rpc/tensorpipe_agent.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ namespace distributed {
1515
namespace rpc {
1616

1717
struct TensorPipeRpcBackendOptions : public RpcBackendOptions {
18+
TensorPipeRpcBackendOptions(
19+
std::map<std::string, worker_id_t> worker_name_to_id,
20+
float rpc_timeout,
21+
std::string init_method)
22+
: RpcBackendOptions(rpc_timeout, init_method),
23+
workerNameToId(std::move(worker_name_to_id)) {}
24+
1825
std::map<std::string, worker_id_t> workerNameToId;
1926
};
2027

0 commit comments

Comments
 (0)