
[jit] Add module attributes#17309

Closed
driazati wants to merge 34 commits into pytorch:master from driazati:attr

Conversation


@driazati driazati commented Feb 20, 2019

Similar to nn.Parameter, this PR lets you store any IValue as an attribute on a ScriptModule (currently only from the Python frontend). To mark something as an attribute, wrap it in jit.Attribute(value, type), e.g. self.table = torch.jit.Attribute(table, Dict[str, torch.Tensor]).
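As a rough illustration of the registration pattern (a plain-Python sketch with hypothetical names, not the actual ScriptModule internals), a wrapper like Attribute can be unwrapped at assignment time and the value stored in a typed attribute table, mirroring how parameters get their own slot table:

```python
from typing import Any, NamedTuple

class Attribute(NamedTuple):
    """Wrapper marking a value as a typed module attribute (illustrative)."""
    value: Any
    type: str

class ModuleSketch:
    def __init__(self):
        # Bypass our custom __setattr__ so the table itself is a plain attr.
        object.__setattr__(self, "_attributes", {})

    def __setattr__(self, name, value):
        if isinstance(value, Attribute):
            # Unwrap: remember both the declared type and the value.
            self._attributes[name] = (value.type, value.value)
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only called when normal lookup fails, so plain attrs win first.
        attrs = object.__getattribute__(self, "_attributes")
        if name in attrs:
            return attrs[name][1]
        raise AttributeError(name)

m = ModuleSketch()
m.table = Attribute({"w": 1.0}, "Dict[str, Tensor]")
print(m.table)  # {'w': 1.0}
```

The type string travels alongside the value, which matters for cases (like empty containers) where the type cannot be inferred from the value alone.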

Followup Work:

  • (de)serializing for use in C++
  • change self.training to be a bool attribute instead of a buffer
  • mutable attributes
  • string frontend support
  • documentation

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Feb 20, 2019
@driazati driazati marked this pull request as ready for review February 20, 2019 21:52
@driazati (Contributor Author):

@pytorchbot rebase this please

@sidazhang:

A few questions here:

  1. What sort of types can be set as an attribute? Any IValue can be an attribute?
  2. Can the type be mutated? i.e. Mutable dict?
  3. Can you access the attribute in the C++ front end?

Basically

auto module = torch::jit::load("./model.pt");
module.get_attribute("some_attribute_name");

@driazati (Contributor Author):

This is pretty rough right now and won't have full support for a while, but:

  1. Any IValue used in the JIT (we don't plan to support some, e.g. Blob IValues)
  2. Yes (eventually)
  3. Yes (eventually)

@driazati driazati changed the title [wip][jit] attributes [jit] attributes Feb 26, 2019
@driazati driazati requested a review from zdevito February 26, 2019 22:52
@zdevito (Contributor) left a comment:

Great start -- I've put some comments inline about issues we need to resolve around figuring out the right types for attributes and how to structure the Module class.

// to the slots where parameters are kept while also allow parameters
// to be reassigned
std::unique_ptr<at::Tensor> parameter;
struct NamedInput {
Contributor:

It's not always an input, so NamedIValue would be a better term.

auto retval = graph_->copy();
for (auto inp : member_inputs) {
inputs.push_back(*inp);
for (auto inp : member_inputs_) {
Contributor:

The if statement here is going to wreak havoc later. I think inputs/outputs should be flipped to IValue.

Module()
: modules("Module"),
parameters("Parameter"),
attributes("Attributes"),
Contributor:

We should be careful how we structure things here so that it mirrors how a Python class works.
We probably want to drop the is_parameter flag and simply use the separate lists (attributes vs parameters) to decide whether something is a parameter. Things that were previously buffers should go in the attributes list.
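The structure suggested above can be sketched in plain Python (hypothetical names, not the actual C++ Module): parameter-ness becomes a membership test on separate tables instead of a per-slot is_parameter flag, and former buffers simply live in the attributes table.

```python
class ModuleSlots:
    """Separate tables replace a per-slot is_parameter flag (illustrative)."""

    def __init__(self):
        self.parameters = {}   # name -> tensor-like values
        self.attributes = {}   # name -> any value, including former buffers

    def register_parameter(self, name, value):
        self.parameters[name] = value

    def register_attribute(self, name, value):
        self.attributes[name] = value

    def is_parameter(self, name):
        # Membership in the parameters table is the whole test.
        return name in self.parameters

m = ModuleSlots()
m.register_parameter("weight", [1.0, 2.0])
m.register_attribute("running_mean", [0.0, 0.0])  # formerly a buffer
```

Dropping the flag means a slot cannot disagree with the table it lives in, which removes one class of bookkeeping bugs.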

return self.__getattr__('forward').graph
return Module.__getattr__(self, attr)

def _is_attribute(self, attr, value):
Contributor:

This seems brittle because it has to look through every other potential place where getattr looks to make sure it doesn't fall into one of these categories. Is there a better way to structure this?
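One less brittle alternative the question hints at is an explicit registry checked with a positive membership test, rather than ruling out every other category __getattr__ might fall into; a hypothetical sketch:

```python
class AttrRegistry:
    """Attribute-ness is a positive membership test instead of a
    process of elimination over params, buffers, modules, methods,
    constants, ... (hypothetical sketch)."""

    def __init__(self):
        self._attribute_names = set()

    def mark_attribute(self, name):
        self._attribute_names.add(name)

    def is_attribute(self, name):
        return name in self._attribute_names

r = AttrRegistry()
r.mark_attribute("table")
print(r.is_attribute("table"))    # True
print(r.is_attribute("forward"))  # False
```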

""".format(type(v).__name__, attr, constants)))


def _get_type(value):
Contributor:

I am not a fan of adding this. It will need to be an exhaustive list, and we already have similar stuff in C++ for converting python values. It also cannot be done correctly: if a list or a dict is empty we cannot infer the type. We should use the existing C++ code that figures out the type and IValue information from a python class rather than reimplement it here. We should also consider how we could inform the Module about the right type for things like empty dictionaries where the type is not inferable.
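The uninferable-empty-container problem can be demonstrated in a few lines (an illustrative sketch, not the PR's _get_type): inference from a runtime value necessarily bottoms out at None for empty containers, so the caller must supply the type explicitly in those cases.

```python
def infer_type(value):
    """Best-effort type inference from a runtime value; returns None
    when the value alone does not determine a type (illustrative)."""
    if isinstance(value, bool):   # check bool before int (bool is an int subclass)
        return "bool"
    if isinstance(value, int):
        return "int"
    if isinstance(value, float):
        return "float"
    if isinstance(value, str):
        return "str"
    if isinstance(value, list):
        if not value:
            return None  # empty list: element type is unknowable
        elem = infer_type(value[0])
        return f"List[{elem}]" if elem else None
    if isinstance(value, dict):
        if not value:
            return None  # empty dict: key/value types are unknowable
        k, v = next(iter(value.items()))
        kt, vt = infer_type(k), infer_type(v)
        return f"Dict[{kt}, {vt}]" if kt and vt else None
    return None

print(infer_type([1, 2]))      # List[int]
print(infer_type({"a": 1.0}))  # Dict[str, float]
print(infer_type({}))          # None: the caller must supply the type
```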

@driazati driazati changed the title [jit] attributes [jit] Add module attributes Mar 4, 2019
@zdevito zdevito self-requested a review March 4, 2019 20:07
@zdevito (Contributor) left a comment:

Looks pretty good. I have a few comments that should be addressed and then it is ready to land.

Contributor:

In its current form, won't landing this break serialization of buffers?

Contributor Author:

This was from an old commit and was fixed to use a flag

return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
} else if (NamedInput* v = module->find_attribute(field)) {
return std::make_shared<SimpleValue>(
m.get_or_add_attribute(v->type, v->slot()));
Contributor:

get_or_add_attribute should take a NamedInput* rather than both a type and an IValue*. Otherwise it is too easy to accidentally pass the wrong type.

// python method. If so return this as a python value.
py::object py_module = py::cast(module);
if (py::object attr = py::getattr(py_module, field.c_str(), py::none())) {
if (py::isinstance(
Contributor:

Attributes should be registered during the init process. Doing it lazily here doesn't follow the pattern used for parameters, and it prevents an attribute from existing when no method references it. I see trouble with this when, for instance, someone uses a non-Python way of defining a Method and it can't see the attributes because they are in a Python-only state at the moment.
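The eager-registration pattern the comment suggests can be sketched as follows (hypothetical names, not the PR's API): attributes land in the module's table during __init__, so any Method-creation path, Python or not, sees them immediately.

```python
class ScriptModuleSketch:
    """Eager attribute registration at construction time (hypothetical)."""

    def __init__(self):
        self._attributes = {}

    def register_attribute(self, name, type_str, value):
        # Registered during __init__, not lazily on first method lookup,
        # so non-Python Method-creation paths can see it too.
        self._attributes[name] = (type_str, value)

class MyModule(ScriptModuleSketch):
    def __init__(self):
        super().__init__()
        # The attribute exists even though no method references it.
        self.register_attribute("table", "Dict[str, Tensor]", {})
```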

@facebook-github-bot (Contributor) left a comment:

@driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


@houseroad (Member):

@driazati this broke pytorch-xla build. cc: @ailzhang @suo

petrex pushed a commit to petrex/pytorch that referenced this pull request Mar 7, 2019
* upstream/master: (24 commits)
  Automatic update of fbcode/onnx to 96c58ceeacf0f2b73d752e413e4fd78787a12da3 (pytorch#17676)
  Set the default ONNX opset to the latest stable opset (i.e., 9) (pytorch#17736)
  Add module attributes (pytorch#17309)
  - refactoring serialization of ONNX initializers to be name-based (pytorch#17420)
  ONNX Export for Max and Average Pooling in CEIL_MODE
  use flake8-mypy (pytorch#17721)
  use fp16<->fp32 intrinsic (pytorch#17496)
  Implement a Caffe2 standalone LSTM operator (pytorch#17726)
  caffe2:libtorch_cuda depends on caffe2:caffe2_gpu (pytorch#17729)
  add tensor and cost inference functions (pytorch#17684)
  ONNX Export Narrow op
  Keep the dim_type of hinted shape as BATCH if possible (pytorch#17734)
  fix different round behavior on CPU and GPU pytorch#16498 (pytorch#17443)
  Warn about memory overlaps on expanded tensors (pytorch#17576)
  fix exp fam. formula
  refactor caffe2 operator constructors - 10/9 (pytorch#17659)
  Improve ONNX symbolic for logsoftmax and softmax (pytorch#17672)
  Enable using CMD when building cpp extensions on Windows
  Do not rename net boundary inputs/outputs during ssaRewrite. (pytorch#17545)
  Reapply D14078519 (pytorch#17596)
  ...

Labels

oncall: jit Add this issue/PR to JIT oncall triage queue


6 participants