Skip to content

[C++ API] Protobuf serialization#11619

Closed
goldsborough wants to merge 4 commits intopytorch:masterfrom
goldsborough:cpp-serialization
Closed

[C++ API] Protobuf serialization#11619
goldsborough wants to merge 4 commits intopytorch:masterfrom
goldsborough:cpp-serialization

Conversation

@goldsborough
Copy link
Contributor

@goldsborough goldsborough commented Sep 13, 2018

This PR serves two purposes:

  1. Design an abstraction over a serialization scheme for C++ modules, optimizers and tensors in general,
  2. Add serialization to the ONNX/PyTorch proto format.

This is currently a rough prototype I coded up today, to get quick feedback.

For this I propose the following serialization interface within the C++ API:

namespace torch { namespace serialize {
class Reader {
 public:
  virtual ~Reader() = default;
  virtual void read(const std::string& key, Tensor& tensor, bool is_buffer = false) = 0;
  virtual void finish() { }
};

class Writer {
 public:
  virtual ~Reader() = default;
  virtual void writer(const std::string& key, const Tensor& tensor, bool is_buffer = false) = 0;
  virtual void finish() { }
};
}} // namespace torch::serialize

There are then subclasses of these two for (1) Cereal and (2) Protobuf (called the "DefaultWriter" and "DefaultReader" to hide the implementation details). See torch/serialize/cereal.h and torch/serialize/default.h. This abstraction and subclassing for these two allows us to:

  1. Provide a cereal-less serialization forward that we can ship and iterate on going forward,
  2. Provide no-friction backwards compatibility with existing C++ API uses, mainly StarCraft.

The user-facing API is (conceptually):

void torch::save(const Module& module, Writer& writer);
void torch::save(const Optimizer& optimizer, Writer& writer);
void torch::read(Module& module, Reader& reader);
void torch::read(Optimizer& optimizer, Reader& reader);

with implementations for both optimizers and modules that write into the Writer and read from the Reader

@ebetica @ezyang @zdevito @dzhulgakov

@ebetica
Copy link
Contributor

ebetica commented Sep 13, 2018

I wouldn't worry about (2). How would you change this if you didn't have to support Cereal at all?

@ezyang
Copy link
Contributor

ezyang commented Sep 13, 2018

Is the problem this is solving primarily "how to serialize to both protobuf and cereal", or are there other problems?

Copy link
Collaborator

@dzhulgakov dzhulgakov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks neat!

For future we should have some mechanism of encoding at least typename of the c++ module in the serialized file for reference. Since for now modules are different from jit modules I guess it's ok

P.S. reusing jit module serialization is... cute :)

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

@goldsborough goldsborough force-pushed the cpp-serialization branch 2 times, most recently from 36d2455 to 3f83954 Compare September 18, 2018 06:33
@goldsborough goldsborough changed the title [C++ API][Prototype] Abstract serialization and add Protobuf serialization [C++ API][Prototype] Add Protobuf serialization Sep 18, 2018
@goldsborough goldsborough changed the title [C++ API][Prototype] Add Protobuf serialization [C++ API] Protobuf serialization Sep 18, 2018
@goldsborough
Copy link
Contributor Author

goldsborough commented Sep 18, 2018

After discussing with @ebetica we found that backwards compatibility with cereal is not as important and that StarCraft would be happy to switch to protobuf. For this reason I've opted to specialize my design a little more towards our Protobuf serialization format. The new primitives are InputArchive and OutputArchive, which both wrap a ScriptModule, which we use as a vehicle for serialization. The API for e.g. InputArchive is:

class InputArchive {
 public:
  void read(const std::string& key, Tensor& tensor, bool is_buffer = false);
  void read(const std::string& key, InputArchive& archive);

  template <typename... Ts>
  void operator()(Ts&&... ts) {
    read(std::forward<Ts>(ts)...);
  }

 private:
  std::shared_ptr<jit::script::Module> module_;
};

which gets populated e.g. in nn::Module::load:

void Module::load(serialize::InputArchive& archive) {
  for (auto& parameter : parameters_) {
    archive.read(parameter.key, parameter.value);
  }
  for (auto& buffer : buffers_) {
    archive.read(buffer.key, buffer.value, /*is_buffer=*/true);
  }
  for (const auto& child : children_) {
      serialize::InputArchive child_archive;
      archive.read(child.key, child_archive);
      child.value->load(child_archive);
  }
}

This API is not intended for inheritence. The reason why is that InputArchives are recursively built (1 InputArchive = 1 ScriptModule) and have to be re-instantiated inside of save() or load(). This re-instantiation wouldn't work (nicely) if InputArchive were an abstract class.

This means we only support protobuf serialization out of the box. However, I've made sure that all parameters and buffers in modules and optimizers that one would want to store are properly exposed. This means if one wanted to, one could write a new serialization library on top of the Module base class. I believe this is more the strategy we have in Python, where we have one blessed serialization API (Pickle), but generally modules and optimizers expose enough state that anyone could dump their models into a Thrift or JSON interface.

CC @ebetica @dzhulgakov @zdevito @apaszke @ezyang

@goldsborough goldsborough force-pushed the cpp-serialization branch 3 times, most recently from 6379c68 to 7881ad9 Compare September 18, 2018 18:01
@ebetica
Copy link
Contributor

ebetica commented Sep 18, 2018

Do you think overloading torch::save for a stream is nice? It just feels more C++-y to me, and is more flexible for users.

Copy link
Contributor

@ebetica ebetica left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy with the API changes, I didn't do a deep reading of this PR.

This comment was marked as off-topic.

.gitmodules Outdated

This comment was marked as off-topic.

This comment was marked as off-topic.

Copy link
Contributor

@ezyang ezyang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't leak file descriptors

@goldsborough goldsborough force-pushed the cpp-serialization branch 2 times, most recently from 7bb46aa to 03d08f6 Compare September 20, 2018 23:18
Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

goldsborough has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Sep 21, 2018
Summary:
This PR serves two purposes:

1. Design an abstraction over a serialization scheme for C++ modules, optimizers and tensors in general,
2. Add serialization to the ONNX/PyTorch proto format.

This is currently a rough prototype I coded up today, to get quick feedback.

For this I propose the following serialization interface within the C++ API:

```cpp
namespace torch { namespace serialize {
class Reader {
 public:
  virtual ~Reader() = default;
  virtual void read(const std::string& key, Tensor& tensor, bool is_buffer = false) = 0;
  virtual void finish() { }
};

class Writer {
 public:
  virtual ~Reader() = default;
  virtual void writer(const std::string& key, const Tensor& tensor, bool is_buffer = false) = 0;
  virtual void finish() { }
};
}} // namespace torch::serialize
```

There are then subclasses of these two for (1) Cereal and (2) Protobuf (called the "DefaultWriter" and "DefaultReader" to hide the implementation details). See `torch/serialize/cereal.h` and `torch/serialize/default.h`. This abstraction and subclassing for these two allows us to:

1. Provide a cereal-less serialization forward that we can ship and iterate on going forward,
2. Provide no-friction backwards compatibility with existing C++ API uses, mainly StarCraft.

The user-facing API is (conceptually):

```cpp
void torch::save(const Module& module, Writer& writer);
void torch::save(const Optimizer& optimizer, Writer& writer);
void torch::read(Module& module, Reader& reader);
void torch::read(Optimizer& optimizer, Reader& reader);
```

with implementations for both optimizers and modules that write into the `Writer` and read from the `Reader`

ebetica ezyang zdevito dzhulgakov
Pull Request resolved: pytorch/pytorch#11619

Differential Revision: D9984664

Pulled By: goldsborough

fbshipit-source-id: e03afaa646221546e7f93bb8dfe3558e384a5847
iotamudelta pushed a commit to ROCm/pytorch that referenced this pull request Sep 21, 2018
Summary:
This PR serves two purposes:

1. Design an abstraction over a serialization scheme for C++ modules, optimizers and tensors in general,
2. Add serialization to the ONNX/PyTorch proto format.

This is currently a rough prototype I coded up today, to get quick feedback.

For this I propose the following serialization interface within the C++ API:

```cpp
namespace torch { namespace serialize {
class Reader {
 public:
  virtual ~Reader() = default;
  virtual void read(const std::string& key, Tensor& tensor, bool is_buffer = false) = 0;
  virtual void finish() { }
};

class Writer {
 public:
  virtual ~Reader() = default;
  virtual void writer(const std::string& key, const Tensor& tensor, bool is_buffer = false) = 0;
  virtual void finish() { }
};
}} // namespace torch::serialize
```

There are then subclasses of these two for (1) Cereal and (2) Protobuf (called the "DefaultWriter" and "DefaultReader" to hide the implementation details). See `torch/serialize/cereal.h` and `torch/serialize/default.h`. This abstraction and subclassing for these two allows us to:

1. Provide a cereal-less serialization forward that we can ship and iterate on going forward,
2. Provide no-friction backwards compatibility with existing C++ API uses, mainly StarCraft.

The user-facing API is (conceptually):

```cpp
void torch::save(const Module& module, Writer& writer);
void torch::save(const Optimizer& optimizer, Writer& writer);
void torch::read(Module& module, Reader& reader);
void torch::read(Optimizer& optimizer, Reader& reader);
```

with implementations for both optimizers and modules that write into the `Writer` and read from the `Reader`

ebetica ezyang zdevito dzhulgakov
Pull Request resolved: pytorch#11619

Differential Revision: D9984664

Pulled By: goldsborough

fbshipit-source-id: e03afaa646221546e7f93bb8dfe3558e384a5847
@ezyang ezyang added the merged label Jun 26, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants