Have RandomState instances unique to each batch index, can be used by e.g. random augmentation #4230

Closed
grafi-tt wants to merge 4 commits into chainer:master from grafi-tt:randomstate-per-batch-idx

@grafi-tt
Contributor

I've made iterators store numpy.random.RandomState instances that are bound to batch indices. They can be retrieved with chainer.iterators.get_random_state().

The implementation is trivial for SerialIterator. For MultithreadIterator, a set of the backup states (implemented as tuples) is required to support reset and serialize correctly. For MultiprocessIterator, shared memory is necessary.
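
As a rough illustration of the backup states mentioned above: RandomState.get_state() returns a plain tuple that can be stored and later passed back to RandomState.set_state() to rewind the generator. The snippet below is only a sketch of that round trip, not the code in this PR; the seed and variable names are arbitrary.

```python
import numpy

rng = numpy.random.RandomState(42)  # arbitrary seed, for illustration only

# Take a backup of the generator state; by default this is a plain tuple.
backup = rng.get_state()

first = rng.uniform(size=3)

# Restoring the backup rewinds the generator, as a reset would need to do.
rng.set_state(backup)
second = rng.uniform(size=3)

# Both draws come from the same state, so they are identical.
print(numpy.array_equal(first, second))
```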

This feature is especially useful with MultiprocessIterator. To perform random augmentation correctly with this iterator class, you currently need to set the random seed when the first batch is fetched, because every forked process starts with the same global random state.

Making the random seed deterministic is rather hard, because which process fetches each part of a batch is completely non-deterministic. You can mitigate this by reseeding the random state at every fetch, but that sacrifices the long period of the Mersenne Twister.
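
To make the idea concrete, here is a minimal sketch of why binding one generator to each batch index gives reproducible augmentation no matter which worker handles which slot. It does not reproduce the PR's implementation: make_states and fetch_batch are hypothetical helpers, seeding with base_seed + i is just a simple choice for the example, and the order argument merely simulates non-deterministic scheduling.

```python
import numpy


def make_states(base_seed, batch_size):
    # Hypothetical helper: one independent generator per position in the batch.
    return [numpy.random.RandomState(base_seed + i) for i in range(batch_size)]


def random_flip(img, rng):
    # Same augmentation as in the benchmark, but with an explicit generator.
    if rng.uniform() >= 0.5:
        img = numpy.fliplr(img)
    return img


def fetch_batch(images, states, order):
    # `order` simulates the unpredictable order in which workers pick up slots.
    out = [None] * len(images)
    for i in order:
        out[i] = random_flip(images[i], states[i])
    return out


images = [numpy.arange(4).reshape(2, 2) for _ in range(4)]
a = fetch_batch(images, make_states(0, 4), [0, 1, 2, 3])
b = fetch_batch(images, make_states(0, 4), [3, 1, 0, 2])

# Slot i always draws from states[i], so the processing order does not matter.
print(all(numpy.array_equal(x, y) for x, y in zip(a, b)))
```

With a single shared generator, by contrast, the value each slot receives would depend on the order in which the slots happen to be processed.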

overhead

If chainer.iterators.get_random_state() is not called, there is almost no overhead. If it is called, the overhead of RandomState.get_state() and/or RandomState.set_state() is incurred, but a benchmark suggests it is still negligible.

I ran the benchmark below on my Linux desktop machine with an i5-3550 (4 cores) processor. It just iterates over the MNIST dataset with random flip augmentation. The batches are very small compared to most real workloads, so the iteration overhead is rate-limiting. Even in such an extreme setting, the measured performance impact is about 11% for MultithreadIterator and below 5% for the other two iterators.

import time

from chainer.datasets import get_mnist, TransformDataset
from chainer.iterators import (get_random_state,
                               SerialIterator,
                               MultiprocessIterator,
                               MultithreadIterator)
import numpy


def random_flip_local(img):
    random = get_random_state()
    if random.uniform() >= 0.5:
        img = numpy.fliplr(img)
    return img


def random_flip_global(img):
    random = numpy.random
    if random.uniform() >= 0.5:
        img = numpy.fliplr(img)
    return img


N_PROC = 4
COUNT = 10000


def main():
    for it_cls in (SerialIterator, MultithreadIterator, MultiprocessIterator):
        for aug_meth in (random_flip_local, random_flip_global):
            dataset, _ = get_mnist(withlabel=False, ndim=2)
            dataset = TransformDataset(dataset, aug_meth)
            options = {}
            t1 = time.perf_counter()
            if it_cls == MultithreadIterator:
                options['n_threads'] = N_PROC
            if it_cls == MultiprocessIterator:
                options['n_processes'] = N_PROC
            it = it_cls(dataset, N_PROC, **options)
            for _ in range(COUNT):
                it.next()
            t2 = time.perf_counter()
            print("{:20} {:18} {:6.2f}us".format(it_cls.__name__, aug_meth.__name__,
                                            (t2 - t1) / COUNT * 1000 * 1000))


if __name__ == '__main__':
    main()

The result is:

SerialIterator       random_flip_local   52.33us
SerialIterator       random_flip_global  50.77us
MultithreadIterator  random_flip_local  315.76us
MultithreadIterator  random_flip_global 284.20us
MultiprocessIterator random_flip_local  527.36us
MultiprocessIterator random_flip_global 504.12us

I consider this issue more important, so I worked on it before #3754.

@grafi-tt force-pushed the randomstate-per-batch-idx branch from 25110c9 to ca5387c on January 24, 2018 22:25
@delta2323
Member

Thank you for sending the PR! I think this PR consists of several parts.

  • Add iterators.get_random_state
  • Use get_random_state in SerialIterator
  • Use get_random_state in MultiprocessIterator

Could you consider splitting the PR into these parts and sending each one as a separate PR?

@Crissman
Member

Crissman commented Mar 5, 2018

@grafi-tt Is it possible to split this PR as mentioned by @delta2323?

Thanks!

@grafi-tt
Contributor Author

grafi-tt commented Mar 6, 2018

@Crissman @delta2323 Sorry for the late reply! To be precise, this PR consists of 5 parts:

  1. Implement iterators.get_random_state, which can be called by the user during iteration, and iterators.random_state.set_random_state, which should be called by the iterator implementation,
  2. Call set_random_state appropriately from SerialIterator,
  3. Call set_random_state appropriately from MultithreadIterator,
  4. Call set_random_state appropriately from MultiprocessIterator,
  5. Write documentation for get_random_state.

I can reorganize the changes into those 5 commits soon, but I'm wondering how to send the PRs, because changes 2, 3, 4 and 5 depend on change 1. Should I send a PR for 1 first, and send the other PRs after it is merged?
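
For readers following the split, parts 1 and 2 of the list above could fit together roughly as sketched below. This is an assumption-laden illustration, not the code in this PR: the thread-local storage, the fallback to the numpy.random module, and the serial_batch and NoisyDataset names are all hypothetical.

```python
import threading

import numpy

_local = threading.local()


def set_random_state(state):
    # Iterator side: bind a generator before fetching an example (assumed API).
    _local.state = state


def get_random_state():
    # User side: return the bound generator, or fall back to the numpy.random
    # module, which exposes the same sampling functions (an assumption here).
    state = getattr(_local, 'state', None)
    return numpy.random if state is None else state


class NoisyDataset:
    # Hypothetical dataset whose __getitem__ applies a random perturbation.
    def __init__(self, values):
        self.values = values

    def __len__(self):
        return len(self.values)

    def __getitem__(self, i):
        return self.values[i] + get_random_state().uniform()


def serial_batch(dataset, indices, states):
    # Sketch of a serial iterator step: slot k always uses states[k].
    batch = []
    for slot, index in enumerate(indices):
        set_random_state(states[slot])
        batch.append(dataset[index])
    set_random_state(None)  # unbind so later calls use the fallback
    return batch


dataset = NoisyDataset([10.0, 20.0])
states = [numpy.random.RandomState(seed) for seed in (0, 1)]
batch = serial_batch(dataset, [0, 1], states)
```

Keeping the setter separate from the getter means dataset code only ever depends on get_random_state, while each iterator class decides when and how the per-slot generators are swapped in.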

@delta2323
Member

delta2323 commented Mar 6, 2018

  1. Implement iterators.get_random_state, which can be called by the user during iteration, and iterators.random_state.set_random_state, which should be called by the iterator implementation,

Looking at this sentence alone, we might be able to work on get_random_state and set_random_state in parallel. Would that be possible?

@delta2323
Member

One idea would be to make a branch that implements 1 and 5 and send a PR from that branch first. Then create branches for 2, 3, and 4 from that branch and send one PR for each. We'll review the PR for 1 and 5 first. It is also OK to separate 1 and 5.

@grafi-tt
Contributor Author

grafi-tt commented Mar 8, 2018

@delta2323 Thank you! I created PR #4448 which contains 1 and 5.

@delta2323
Member

Thank you! I'll take a look.

@delta2323 added the st:blocked-by-another-pr label (state indicating that another ticket is preventing this ticket from being closed/merged) on Mar 28, 2018
@grafi-tt
Contributor Author

#4448

@grafi-tt closed this on May 28, 2018
@okuta added this to the Closed issues and PRs milestone on Jun 18, 2018