
[Training] Adam Optimizer #1970

Merged
wschin merged 29 commits into onnx:master from wschin:adam
Apr 13, 2020

Conversation

@wschin
Collaborator

@wschin wschin commented Apr 27, 2019

PR #2314 is a single place for reviewing the whole training story.

This PR defines a common signature shared by PyTorch's and TF's Adam.

Design verification script (it shows that the proposed Adam signature covers both TF's and PyTorch's behavior):

import itertools
import numpy as np
import torch
from torch import optim
from torch import nn
import torch.random as rnd

iteration_count = 3
n = 10 # number of features
l = 5 # number of data points (aka batch size)

X_ = torch.randn(l, n)
Y_ = torch.randn(l, 1)

def apply_adam(t, r, x, g, v, h, norm_coefficient, alpha, beta, epsilon):  # type: ignore
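    # t: number of updates already applied (0 on the first call); r: learning rate.
    # x: optimized tensor; g: its gradient; v/h: first- and second-order momentums.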
    # Add gradient of regularization term.
    g_regularized = norm_coefficient * x + g
    # Update momentum.
    v_new = alpha * v + (1 - alpha) * g_regularized
    # Update second-order momentum.
    h_new = beta * h + (1 - beta) * (g_regularized * g_regularized)
    # Compute element-wise square root, shifted by epsilon to avoid division by zero.
    h_sqrt = np.sqrt(h_new) + epsilon
    # Fold both bias corrections into the learning rate. The caller's t starts
    # at 0, so t + 1 is the 1-based update count used in the correction terms.
    t = t + 1
    r_adjusted = r * np.sqrt(1 - beta**t) / (1 - alpha**t)
    # Apply Adam update rule.
    x_new = x - r_adjusted * (v_new / h_sqrt)
    return x_new, v_new, h_new

def show_pytorch(lr, alpha, beta, weight_decay, epsilon):
    X = X_.clone()
    Y = Y_.clone()

    rnd.manual_seed(0)
    model = nn.Sequential(
        nn.Linear(n, 1, bias=False)
    )

    loss_fn = nn.MSELoss(reduction='sum')

    solver = optim.Adam(model.parameters(), lr=lr, betas=[alpha, beta],
            weight_decay=weight_decay, eps=epsilon)

    results = []
    for t in range(iteration_count):
        Y_pred = model(X)
        loss = loss_fn(Y_pred, Y)
        results.append(float(loss))
        model.zero_grad()
        loss.backward()
        solver.step()
    return results

def show_tensorflow(lr, alpha, beta, epsilon):
    import tensorflow as tf
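    # Seed torch and reuse a torch-initialized Linear layer so the TF variable
    # starts from exactly the same weights as the PyTorch runs.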
    rnd.manual_seed(0)
    layer = nn.Linear(n, 1, bias=False)

    X = tf.placeholder('float', shape=[l, n])
    W = tf.Variable(layer.weight.detach().numpy())
    Y = tf.placeholder('float', shape=[l, 1])
    Y_pred = tf.matmul(X, W, transpose_b=True)
    loss = tf.reduce_sum(tf.square(Y - Y_pred))
    optimizer = tf.train.AdamOptimizer(learning_rate=lr, beta1=alpha, beta2=beta, epsilon=epsilon)
    minimizer = optimizer.minimize(loss)

    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)

    results = []
    for t in range(iteration_count):
        loss_value = sess.run(loss, {X: X_, Y: Y_})
        results.append(float(loss_value))
        sess.run([minimizer], {X: X_, Y: Y_})

    return results

def show_mine(lr, alpha, beta, weight_decay, epsilon):
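    # Same training loop as show_pytorch, but parameters are updated by
    # apply_adam (the proposed ONNX semantics) instead of solver.step(); the
    # optim.Adam instance below only serves as storage for the momentum buffers.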
    X = X_.clone()
    Y = Y_.clone()

    rnd.manual_seed(0)
    model = nn.Sequential(
        nn.Linear(n, 1, bias=False)
    )

    loss_fn = nn.MSELoss(reduction='sum')

    solver = optim.Adam(model.parameters(), lr=lr, betas=[alpha, beta],
            weight_decay=weight_decay, eps=epsilon)

    results = []
    for t in range(iteration_count):
        Y_pred = model(X)
        loss = loss_fn(Y_pred, Y)
        results.append(float(loss))
        model.zero_grad()
        loss.backward()

        with torch.no_grad():
            for param in model.parameters():
                if 'exp_avg' not in solver.state[param]:
                    solver.state[param]['exp_avg'] = torch.zeros_like(param)
                if 'exp_avg_sq' not in solver.state[param]:
                    solver.state[param]['exp_avg_sq'] = torch.zeros_like(param)

                new_tensor, new_v, new_h = apply_adam(t=t,
                        r=lr, x=param.data, g=param.grad.data,
                        v=solver.state[param]['exp_avg'].data,
                        h=solver.state[param]['exp_avg_sq'].data,
                        norm_coefficient=weight_decay,
                        alpha=alpha, beta=beta, epsilon=epsilon)

                solver.state[param]['exp_avg'].data = new_v.data
                solver.state[param]['exp_avg_sq'].data = new_h.data
                param.data = new_tensor.data

    return results


# Compare TF Adam and its ONNX counterpart.
lr_pool = [0.1, 0.001]
alpha_pool = [0.89, 0.8, 0.01]
beta_pool = [0.99]
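# tf.train.AdamOptimizer has no weight-decay term, so norm_coefficient stays 0 here.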
weight_decay_pool = [0]
epsilon_pool = [1e-7, 1e-1]
for lr_, alpha_, beta_, weight_decay_, epsilon_ in itertools.product(lr_pool, alpha_pool, beta_pool, weight_decay_pool, epsilon_pool):
    tf_result = show_tensorflow(lr=lr_, alpha=alpha_, beta=beta_, epsilon=epsilon_)
    mine_result = show_mine(lr=lr_, alpha=alpha_, beta=beta_, weight_decay=weight_decay_, epsilon=epsilon_)
    assert np.allclose(tf_result, mine_result)

# Compare Pytorch Adam and its ONNX counterpart.
lr_pool = [0.1, 0.001]
alpha_pool = [0.89, 0.8, 0.01]
beta_pool = [0.99, 0.2, 0.5]
weight_decay_pool = [0, 0.01, 0.1, 1]
epsilon_pool = [1e-7, 1e-1]
for lr_, alpha_, beta_, weight_decay_, epsilon_ in itertools.product(lr_pool, alpha_pool, beta_pool, weight_decay_pool, epsilon_pool):
    torch_result = show_pytorch(lr=lr_, alpha=alpha_, beta=beta_, weight_decay=weight_decay_, epsilon=epsilon_)
    mine_result = show_mine(lr=lr_, alpha=alpha_, beta=beta_, weight_decay=weight_decay_, epsilon=epsilon_)
    assert np.allclose(torch_result, mine_result)
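
For reference, a minimal sketch of how the proposed operator could be instantiated for a single optimized tensor. It assumes the ai.onnx.preview.training domain and a flat input layout [R, T, X, G, V, H] matching the update rule implemented by apply_adam above; all tensor names here are illustrative, not part of the spec.

from onnx import helper

adam_node = helper.make_node(
    'Adam',
    inputs=['R', 'T', 'X', 'G', 'V', 'H'],  # learning rate, update count, weights, gradient, 1st/2nd-order momentums
    outputs=['X_new', 'V_new', 'H_new'],    # updated weights and momentums
    domain='ai.onnx.preview.training',
    alpha=0.89,             # decay rate of the first-order momentum
    beta=0.99,              # decay rate of the second-order momentum
    epsilon=1e-7,           # stabilizer added to the square root of the second-order momentum
    norm_coefficient=0.01,  # L2-regularization strength (weight decay)
)

The attribute values mirror one configuration from the pools above, so evaluating this node once corresponds to a single apply_adam call with that setting.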

@wschin wschin requested a review from a team as a code owner April 27, 2019 21:57
@wschin wschin requested a review from a team as a code owner June 4, 2019 22:16
@wschin wschin changed the title [WIP] Adam Optimizer [Training] Adam Optimizer Jun 5, 2019
@CLAassistant

CLAassistant commented Jul 24, 2019

CLA assistant check
All committers have signed the CLA.

@prasanthpul prasanthpul added this to the 1.7 milestone Aug 20, 2019
@postrational postrational added the topic: operator and topic: training labels Aug 23, 2019
Comment thread onnx/defs/operator_sets.h Outdated
@wschin wschin removed this from the 1.7 milestone Feb 26, 2020
Contributor

@postrational postrational left a comment


Looks good, but some artifacts made it in. Please clean them up so the PR can be approved.

Comment thread onnx/backend/test/data/node/test_add/model.onnx Outdated
Comment thread onnx/backend/test/data/node/test_add_bcast/model.onnx
@sveta-levitan
Contributor

sveta-levitan commented Mar 3, 2020

@wschin Wei-Sheng, please respond to Michal's comments above, plus his comments on Gitter in the Operators chat room. Thank you!

@chinhuang007
Contributor

@wschin Can you please take a look and move this forward? During the TSC meeting today, the members suggested having at least one optimizer PR merged for release 1.7. Please let us know if more time is needed.

@wschin
Collaborator Author

wschin commented Mar 6, 2020

@sveta-levitan, @postrational, @chinhuang007, the PR is up to date again. Please take a look. Thank you.

Comment thread docs/Changelog.md
Comment thread onnx/defs/tensor/defs.cc
@wschin wschin merged commit f89f387 into onnx:master Apr 13, 2020
@chinhuang007 chinhuang007 added this to the 1.7 milestone May 8, 2020
jcwchen pushed a commit to jcwchen/onnx that referenced this pull request Sep 23, 2020
* Adam draft

Update docs

* Empty shape inference

* Add shape inference

* Add tests

* Fix a test

* Sync with recent master changes

* Sync with master

* Make flake8 happy

* Sync docs

* sync doc

* Revert side-effect changes

* sync with master

* Add 1 back

* Polish doc

* Finish the design of Adam spec

* formatting