Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions third_party/xla_client/computation_client.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "tensorflow/compiler/xla/xla_client/computation_client.h"

#include <algorithm>
#include <atomic>
#include <cstdlib>
#include <fstream>
#include <map>
Expand Down Expand Up @@ -265,11 +264,6 @@ ComputationClient* ComputationClient::Get() {
return computation_client;
}

int64 ComputationClient::GetNextDataId() {
static std::atomic<int64>* id_generator = new std::atomic<int64>(1);
return id_generator->fetch_add(1);
}

metrics::Metric* ComputationClient::TransferToServerMetric() {
static metrics::Metric* metric =
new metrics::Metric("TransferToServerTime", metrics::MetricFnTime);
Expand Down
14 changes: 5 additions & 9 deletions third_party/xla_client/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,24 @@ class ComputationClient {
public:
class Data {
public:
using OpaqueHandle = int64;

Data(string device, Shape shape)
: unique_id_(GetNextDataId()),
device_(std::move(device)),
shape_(std::move(shape)) {}
: device_(std::move(device)), shape_(std::move(shape)) {}

virtual ~Data() {}

int64 unique_id() const { return unique_id_; }

const string& device() const { return device_; }

const Shape& shape() const { return shape_; }

virtual OpaqueHandle GetOpaqueHandle() = 0;

virtual void Assign(const Data& data) = 0;

virtual bool HasValue() const = 0;

private:
int64 unique_id_ = 0;
string device_;
Shape shape_;
};
Expand Down Expand Up @@ -252,9 +251,6 @@ class ComputationClient {
static ComputationClient* Get();

protected:
// Generates a new unique ID for a Data object.
static int64 GetNextDataId();

// Metrics common to all client intrfaces.
static metrics::Metric* TransferToServerMetric();
static metrics::Metric* TransferFromServerMetric();
Expand Down
2 changes: 2 additions & 0 deletions third_party/xla_client/xrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class XrtComputationClient : public ComputationClient {

int64 get_handle() const { return handle_ptr->handle; }

OpaqueHandle GetOpaqueHandle() override { return get_handle(); }

void Assign(const Data& data) override;

bool HasValue() const override { return handle_ptr != nullptr; }
Expand Down
5 changes: 3 additions & 2 deletions torch_xla/csrc/lowering_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,14 @@ class HloMetadataSetter {

xla::XlaOp LoweringContext::GetParameter(
const std::shared_ptr<xla::ComputationClient::Data>& data) {
auto it = parameters_map_.find(data.get());
xla::ComputationClient::Data::OpaqueHandle handle = data->GetOpaqueHandle();
auto it = parameters_map_.find(handle);
if (it == parameters_map_.end()) {
xla::XlaOp param =
xla::Parameter(builder(), parameters_.size(), data->shape(),
absl::StrCat("param_", parameters_.size()));
parameters_.push_back(data);
it = parameters_map_.emplace(data.get(), param).first;
it = parameters_map_.emplace(handle, param).first;
}
return it->second;
}
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/lowering_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ class LoweringContext {

xla::XlaBuilder builder_;
std::vector<xla::ComputationClient::DataPtr> parameters_;
std::unordered_map<xla::ComputationClient::Data*, xla::XlaOp> parameters_map_;
std::unordered_map<xla::ComputationClient::Data::OpaqueHandle, xla::XlaOp>
parameters_map_;
std::vector<xla::XlaOp> root_tuple_;
OutputMap<xla::XlaOp> emitted_outputs_;
Util::EmissionMap emit_status_;
Expand Down
5 changes: 3 additions & 2 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1028,12 +1028,12 @@ std::shared_ptr<XLATensor::Async> XLATensor::TryRunCachedSync(
unique_device.set((*tensors)[index].GetDevice());
}
std::vector<xla::ComputationClient::DataPtr> parameters_data;
std::unordered_set<xla::int64> data_uids;
std::unordered_set<xla::ComputationClient::Data::OpaqueHandle> data_handles;
for (auto node : ir::Util::ComputePostOrder(roots)) {
const ir::ops::DeviceData* device_data =
dynamic_cast<const ir::ops::DeviceData*>(node);
if (device_data != nullptr) {
if (data_uids.insert(device_data->data()->unique_id()).second) {
if (data_handles.insert(device_data->data()->GetOpaqueHandle()).second) {
parameters_data.push_back(device_data->data());
}
}
Expand Down Expand Up @@ -1249,6 +1249,7 @@ std::shared_ptr<XLATensor::Async> XLATensor::SyncTensorsGraphInternal(
xla::ComputationClient::Get()->Compile(std::move(instances));
std::vector<xla::ComputationClient::DataPtr> parameters_data =
lowering_ctx.GetParametersData();
XLA_CHECK_EQ(program_shape.parameters_size(), parameters_data.size());
ComputationCache::TypePtr cached_computation = GetComputationCache()->Add(
coll.hash, std::make_shared<CachedComputation>(
std::move(computations.front()), parameters_data.size()));
Expand Down