
[Training] SG with Momentum Optimizer #1959

Merged

wschin merged 25 commits into onnx:master from wschin:sg on Mar 11, 2020

Conversation

@wschin (Collaborator) commented Apr 23, 2019

PR #2314 is a single place for reviewing the whole training story.

This PR is similar to #1955 but adds another signature, for the stochastic gradient method with momentum. It covers PyTorch's SGD and TensorFlow's MomentumOptimizer. Note that stochastic gradient descent is not strictly a descent algorithm (an individual stochastic step can increase the loss), so I call it SG in this PR's title.
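
For reference, here is the update rule the proposed signature encodes, as implemented by `apply_momentum` and `apply_nesterov` in the script below; $r$ is the learning rate, $\lambda$ the `norm_coefficient` attribute, and the gradient coefficient $\beta_t$ is forced to 1 at the first iteration:

```latex
g'_t    = \lambda x_t + g_t                  % gradient plus the regularization term's gradient
v_{t+1} = \alpha v_t + \beta_t g'_t          % \beta_t = 1 if t = 0, else \beta
x_{t+1} = x_t - r v_{t+1}                    % standard momentum update
x_{t+1} = x_t - r (g'_t + \alpha v_{t+1})    % Nesterov variant
```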

The script below verifies that the design covers both the TensorFlow and PyTorch cases.

import itertools
import numpy as np
import torch
from torch import optim
from torch import nn
import torch.random as rnd

iteration_count = 3
n = 10 # number of features
l = 5 # number of data points (aka batch size)

X_ = torch.randn(l, n)
Y_ = torch.randn(l, 1)

# An example implementation of ONNX Momentum when "nesterov" attribute is false.
def apply_momentum(t, r, x, g, v, norm_coefficient, alpha_, beta_):  # type: ignore
    # Add gradient of regularization term.
    g_regularized = norm_coefficient * x + g
    # Coefficient of gradient should be 1 at the first iteration.
    beta_adjusted = beta_ if t > 0 else 1
    # Update momentum.
    v_new = alpha_ * v + beta_adjusted * g_regularized
    # Apply SG with momentum update rule.
    x_new = x - r * v_new
    return x_new, v_new

# An example implementation of ONNX Momentum when "nesterov" attribute is true.
def apply_nesterov(t, r, x, g, v, norm_coefficient, alpha_, beta_):  # type: ignore
    # Add gradient of regularization term.
    g_regularized = norm_coefficient * x + g
    # Coefficient of gradient should be 1 at the first iteration.
    beta_adjusted = beta_ if t > 0 else 1
    v_new = alpha_ * v + beta_adjusted * g_regularized
    # Apply Nesterov with momentum update rule.
    x_new = x - r * (g_regularized + alpha_ * v_new)
    return x_new, v_new

def show_pytorch(lr, alpha, beta, weight_decay, nesterov):
    X = X_.clone()
    Y = Y_.clone()

    rnd.manual_seed(0)
    model = nn.Sequential(
        nn.Linear(n, 1, bias=False)
    )

    loss_fn = nn.MSELoss(reduction='sum')

    # PyTorch's 'dampening' d scales the incoming gradient by (1 - d), so
    # d = 1 - beta makes torch's buffer update buf = alpha * buf + beta * grad
    # match ONNX's v_new = alpha * v + beta * g_regularized.
    solver = optim.SGD(model.parameters(), lr=lr, momentum=alpha, dampening=1-beta,
            weight_decay=weight_decay, nesterov=nesterov)

    results = []
    for t in range(iteration_count):
        Y_pred = model(X)
        loss = loss_fn(Y_pred, Y)
        results.append(float(loss))
        model.zero_grad()
        loss.backward()
        solver.step()
    return results

def show_tensorflow(lr, alpha, nesterov):
    import tensorflow as tf
    # Reuse the same torch seed and layer so TF starts from identical initial weights.
    rnd.manual_seed(0)
    layer = nn.Linear(n, 1, bias=False)

    X = tf.placeholder('float', shape=[l, n])
    W = tf.Variable(torch.Tensor.numpy(layer.weight.detach()))
    Y = tf.placeholder('float', shape=[l, 1])
    Y_pred = tf.matmul(X, W, transpose_b=True)
    loss = tf.reduce_sum(tf.square(Y - Y_pred))
    optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=alpha, use_nesterov=nesterov)
    minimizer = optimizer.minimize(loss)

    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)

    results = []
    for t in range(iteration_count):
        loss_value = sess.run(loss, {X: X_, Y: Y_})
        results.append(float(loss_value))
        sess.run([minimizer], {X: X_, Y: Y_})

    return results


def show_mine(lr, alpha, beta, weight_decay, nesterov):
    X = X_.clone()
    Y = Y_.clone()

    rnd.manual_seed(0)
    model = nn.Sequential(
        nn.Linear(n, 1, bias=False)
    )

    loss_fn = nn.MSELoss(reduction='sum')

    # The solver is used only as a state container here; the parameter updates
    # are applied manually below instead of calling solver.step().
    solver = optim.SGD(model.parameters(), lr=lr, momentum=alpha, dampening=1-beta,
            weight_decay=weight_decay, nesterov=nesterov)

    results = []
    for t in range(iteration_count):
        Y_pred = model(X)
        loss = loss_fn(Y_pred, Y)
        results.append(float(loss))
        model.zero_grad()
        loss.backward()

        with torch.no_grad():
            for param in model.parameters():
                if 'momentum_buffer' not in solver.state[param]:
                    solver.state[param]['momentum_buffer'] = torch.zeros_like(param)
                
                if nesterov:
                    new_tensor, new_state = apply_nesterov(t=t,
                            r=lr, x=param.data, g=param.grad.data,
                            v=solver.state[param]['momentum_buffer'].data,
                            norm_coefficient=weight_decay,
                            alpha_=alpha, beta_=beta)
                else:
                    new_tensor, new_state = apply_momentum(t=t,
                            r=lr, x=param.data, g=param.grad.data,
                            v=solver.state[param]['momentum_buffer'].data,
                            norm_coefficient=weight_decay,
                            alpha_=alpha, beta_=beta)
                solver.state[param]['momentum_buffer'].data = new_state.data
                param.data = new_tensor.data

    return results


# Compare TF SG and its ONNX counterpart.
lr_pool = [0.1, 0.001]
alpha_pool = [1.2, 0.8, 0.01]
beta_pool = [1]
weight_decay_pool = [0]
nesterov_pool = [True, False]

for lr_, alpha_, beta_, weight_decay_, nesterov_ in itertools.product(lr_pool, alpha_pool, beta_pool,
                                                                      weight_decay_pool, nesterov_pool):
    tf_result = show_tensorflow(lr=lr_, alpha=alpha_, nesterov=nesterov_)
    mine_result = show_mine(lr=lr_, alpha=alpha_, beta=beta_, weight_decay=weight_decay_, nesterov=nesterov_)
    assert np.allclose(tf_result, mine_result)


# Compare Pytorch SG with momentum and its ONNX counterpart.
lr_pool = [0.1, 0.001]
alpha_pool = [1.2, 0.8, 0.01]
beta_pool = [1, 0.2, 0.5]
weight_decay_pool = [0, 0.01, 0.1, 1]
nesterov_pool = [False]
for lr_, alpha_, beta_, weight_decay_, nesterov_ in itertools.product(lr_pool, alpha_pool, beta_pool,
                                                                      weight_decay_pool, nesterov_pool):
    torch_result = show_pytorch(lr=lr_, alpha=alpha_, beta=beta_, weight_decay=weight_decay_, nesterov=nesterov_)
    mine_result = show_mine(lr=lr_, alpha=alpha_, beta=beta_, weight_decay=weight_decay_, nesterov=nesterov_)
    assert np.allclose(torch_result, mine_result)


# Compare Pytorch SG with Nesterov momentum and its ONNX counterpart.
lr_pool = [0.1, 0.001]
alpha_pool = [1.2, 0.8, 0.01]
beta_pool = [1]
weight_decay_pool = [0, 0.01, 0.1, 1]
nesterov_pool = [True]
for lr_, alpha_, beta_, weight_decay_, nesterov_ in itertools.product(lr_pool, alpha_pool, beta_pool,
                                                                      weight_decay_pool, nesterov_pool):
    torch_result = show_pytorch(lr=lr_, alpha=alpha_, beta=beta_, weight_decay=weight_decay_, nesterov=nesterov_)
    mine_result = show_mine(lr=lr_, alpha=alpha_, beta=beta_, weight_decay=weight_decay_, nesterov=nesterov_)
    assert np.allclose(torch_result, mine_result)
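
For comparison, here is a minimal sketch (not part of the verification script above) of how the resulting operator can be instantiated with `onnx.helper` for a single optimized tensor. The names follow the spec as eventually merged; treat the exact domain string and attribute values as illustrative assumptions:

```python
import onnx
from onnx import helper

# R: learning rate, T: zero-based training-step count,
# X: weight to optimize, G: its gradient, V: its momentum state.
node = helper.make_node(
    'Momentum',
    inputs=['R', 'T', 'X', 'G', 'V'],
    outputs=['X_new', 'V_new'],
    norm_coefficient=0.001,  # lambda in the update rule above
    alpha=0.95,              # momentum decay
    beta=1.0,                # gradient coefficient (forced to 1 at t = 0)
    mode='standard',         # or 'nesterov'
    domain='ai.onnx.preview.training',
)
```

A single node can optimize several tensors at once through the variadic inputs and outputs; see the merged operator spec for the exact ordering of the extra (X, G, V) groups.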

@wschin wschin requested a review from a team as a code owner April 23, 2019 05:31
@gramalingam (Contributor)

I don't think these ops are control-flow ops … so they should probably go into a more appropriate defs file, instead of the control-flow defs file.

Comment thread on onnx/defs/controlflow/defs.cc (outdated)
@wschin wschin requested a review from a team as a code owner May 26, 2019 23:43
@wschin wschin changed the title from "[WIP] SG with Momentum Optimizer" to "[Training] SG with Momentum Optimizer" May 27, 2019
@wschin (Collaborator, Author) commented May 27, 2019

@gramalingam, I will move this file to the right place after #1955 is merged.

@wschin wschin mentioned this pull request May 31, 2019
@ebarsoum ebarsoum added the "topic: operator" label Jul 9, 2019
@CLAassistant commented Jul 24, 2019

CLA assistant check
All committers have signed the CLA.

@prasanthpul prasanthpul added this to the 1.7 milestone Aug 20, 2019
@postrational postrational added the "topic: training" label Aug 23, 2019
@wschin wschin removed this from the 1.7 milestone Feb 26, 2020
@wschin wschin merged commit c2fefcb into onnx:master Mar 11, 2020
@wschin wschin deleted the sg branch March 11, 2020 23:56
linkerzhang added a commit that referenced this pull request Mar 31, 2020
* Fix Greater/LessOrEqual function definition (#2645)

* Fix Greater/LessOrEqual function definition

* Update test data

Co-authored-by: Ke Zhang <kezhan@microsoft.com>

* Suppress a warning in unsqueeze (#2637)

I keep getting this warning when building PyTorch:

```
In file included from
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/tensor/utils.h:6,
                 from
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/tensor/defs.cc:4:
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/tensor/defs.cc: In
lambda function:
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/tensor/defs.cc:1414:22:
warning: unnecessary parentheses in declaration of ‘i’
[-Wparentheses]
           for (size_t(i) = 0; i < axes.size(); ++i) {
                      ^
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/schema.h:959:12:
note: in definition of macro ‘ONNX_OPERATOR_SET_SCHEMA_EX’
     return impl.SetName(#name)
\
            ^~~~
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/tensor/defs.cc:1369:1:
note: in expansion of macro ‘ONNX_OPERATOR_SET_SCHEMA’
 ONNX_OPERATOR_SET_SCHEMA(
```

This commit should fix it and modernize the code a bit.

Co-authored-by: Ke Zhang <kezhan@microsoft.com>

* [Training] Add Adagrad optimizer operator (#1955)

* Adagrad draft

* MIMO

* Support multiple tensors to be optimized

* Address comments

* Move optimizers to a new place

Remove copied

Add momentum

Save

Remove momentum

Fix

Move constants to attributes

* Fix build

* Add shape test

Add two node tests

Update test coverage

* Fix shape inf

* Fix shape inf

* fix shape inf

* Format

* Add function type

* Merge lines

* Format

* Fix version number

* Update op version in model files

* Fix a test function and update related test files

* Update onnx/backend/test/case/node/adagrad.py

* Remove unused file

* sync docs

* Fix shape test

* sync doc

* sync with master

* Update onnx/defs/training/defs.cc

Co-Authored-By: Michał Karzyński <postrational@users.noreply.github.com>

* sync doc

* address comments

* address a minor comment

* Polish one line

Co-authored-by: Michał Karzyński <postrational@users.noreply.github.com>

* [Training] SG with Momentum Optimizer (#1959)

* SG with Momentum

* Registrate Op

Fix

Update other docs

* Add shape inference code and polish definition

* Update docs

* Add test cases and fix several bugs

* Remove accidently added copy

* Alpha -> alpha & Beta -> beta

* Clarify an attribute

* Fix an attribute

* Fix bug

* Fix missing attributes

* sync doc

* Remove unused domain

* sync with master

Co-authored-by: Chin Huang <chhuang@us.ibm.com>

* Change type of label tensor to int32/int64 in SoftmaxCrossEntropyLoss spec. (#2667)

* Update Pow input types in Opset 12 (#2666)

* Update Pow input types in Opset 12

* gen doc and tests

* remove uints and 8 bit ints

* add tests

* remove uint int x tets

* Adding CI for ONNX Debug mode (Linux, OSX) (#2651)

* adding an osx build, linux build, with and without onnx_ml for debug mode

* test debug mode with ONNX_ML=1

* Rename OPTIONAL to OPTIONAL_VALUE (#2682)

Co-authored-by: G. Ramalingam <grama@microsoft.com>

* Update Batchnorm test (#2674)

* Update Batchnorm test

* relax shape inference on scalar

* Remove unnecessary copies and std::move (#2684)

* Update sequence test case so input is not scalar and splits are specified (#2675)

* Update sequence test case to input is not scalar and splits are specified

* Add spaces to make the checker happy

* Use cmake GNUInstallDirs (#2661)

https://cmake.org/cmake/help/latest/module/GNUInstallDirs.html
this makes it possible to install the libraries (and headers) in a location other than `lib` (Gentoo uses lib64 for 64-bit libs)
it also changes the .cmake files to avoid conflicts when building both 32-bit and 64-bit (avoids conflicting/overwritten files)

Co-authored-by: Ke Zhang <kezhan@microsoft.com>

* Add 'ignore_index' input in the spec for SoftmaxCrossEntropyLoss and NLLLoss. (#2680)

* Add 'ignore_index' input in the spec for SoftmaxCrossEntropyLoss and NLLLoss.

* Add tests.

* build break.

* build break.

* clean up.

* build break.

* Change ignore_index to attribute.

* Change ignore_index to attribute.

* PR feedback.

* PR feedback.

* Make ignore_index optional in NLLLoss.

* Build break.

* remove trailing spaces to fix build break.

* Build break.

* Update spec doc.

* Fix NLLLoss function definition to fix test: test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_ignore_index_expanded

* PR feedback.

* Fix test for softmax cross entropy loss to exclude ignored_index'ed weights from the sum of weights.

* Build break.

* Reduce binary size of libraries consuming ONNX (part 1/2) (#2643)

* Change the return type for the zipmap operator to match the description in the spec.

* Reduce binary size of libraries consuming ONNX (part 1/2)

* Fix build error

* Replace separate Get*Doc() functions with easy macro for greater convenience

* Add one more macro for complicated operator doc documentation.

Co-authored-by: Ke Zhang <kezhan@microsoft.com>

* Update pybind (#2340) (#2688)

* Change version number for release verification

Change version number for release verification

Co-authored-by: Takeshi Watanabe <take-cheeze@users.noreply.github.com>
Co-authored-by: Ke Zhang <kezhan@microsoft.com>
Co-authored-by: Hong Xu <hong@topbug.net>
Co-authored-by: Wei-Sheng Chin <wschin@outlook.com>
Co-authored-by: Michał Karzyński <postrational@users.noreply.github.com>
Co-authored-by: M. Zeeshan Siddiqui <mzs@microsoft.com>
Co-authored-by: Lara Haidar <haidar.lara@gmail.com>
Co-authored-by: Vinitra Swamy <vinitras@gmail.com>
Co-authored-by: Changming Sun <chasun@microsoft.com>
Co-authored-by: G. Ramalingam <grama@microsoft.com>
Co-authored-by: Changming Sun <me@sunchangming.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>
Co-authored-by: Gustavo Alvarez <462213+sl1pkn07@users.noreply.github.com>
Co-authored-by: Pranav Sharma <prs@microsoft.com>
jcwchen pushed a commit to jcwchen/onnx that referenced this pull request Sep 23, 2020
@jcwchen jcwchen mentioned this pull request Jan 13, 2021
