
Add model-parallel MNIST example #98

Merged: 53 commits into chainer:master on Aug 4, 2017

Conversation

@levelfour (Contributor) commented Jul 10, 2017

This PR adds user-friendly interfaces for implementing model-parallel neural nets, along with a model-parallel MNIST example.

New Features

MultiNodeChainGroup (chainermn.MultiNodeChainGroup)

This variant of chainer.ChainList represents multiple connected components of the entire computational graph.
In multi-node computation, the computational graph often becomes disconnected.
Each connected component is represented by a chainer.Chain, and MultiNodeChainGroup combines them.
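
For illustration, a rough sketch of the intended usage (the MLP0/MLP1 components, unit sizes, and class names below are placeholders, not code from this PR):

```
import chainer
import chainer.functions as F
import chainer.links as L
import chainermn


class MLP0(chainer.Chain):
    # First connected component: runs on rank 0 and owns the actual input.
    def __init__(self, n_units):
        super(MLP0, self).__init__(l1=L.Linear(784, n_units))

    def __call__(self, x):
        return F.relu(self.l1(x))


class MLP1(chainer.Chain):
    # Second connected component: runs on rank 1 and produces the output.
    def __init__(self, n_units, n_out):
        super(MLP1, self).__init__(l2=L.Linear(n_units, n_out))

    def __call__(self, h):
        return self.l2(h)


class ModelOnRank0(chainermn.MultiNodeChainGroup):
    def __init__(self, comm, n_units):
        super(ModelOnRank0, self).__init__(comm=comm)
        # rank_in=None: takes the real input; rank_out=1: sends its output to rank 1.
        self.add_link(MLP0(n_units), rank_in=None, rank_out=1)


class ModelOnRank1(chainermn.MultiNodeChainGroup):
    def __init__(self, comm, n_units, n_out):
        super(ModelOnRank1, self).__init__(comm=comm)
        # rank_in=0: receives its input from rank 0; rank_out=None: final output.
        self.add_link(MLP1(n_units, n_out), rank_in=0, rank_out=None)
```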

new utility function chainermn.functions.pseudo_connect

This function is used when we want a "delegate_variable" to imitate another variable.
This kind of pathological situation arises in multi-node environments.
Please see the documentation of this function for details.
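
For illustration, a minimal sketch of the situation it addresses (the model below is hypothetical, and the exact pseudo_connect signature should be checked against its documentation):

```
import chainer
import chainer.functions as F
import chainer.links as L
import chainermn.functions


class BranchingModel(chainer.Chain):
    """Hypothetical chain that sends an activation to rank 1 and also computes
    a local output, so the local graph has two loose ends."""

    def __init__(self, comm, n_units, n_out):
        super(BranchingModel, self).__init__(
            l1=L.Linear(None, n_units),
            l2=L.Linear(n_units, n_out))
        self.comm = comm

    def __call__(self, x):
        h = F.relu(self.l1(x))
        # send() returns a "delegate variable" carrying no data; its backward()
        # waits for gradients coming back from the destination process.
        delegate_variable = chainermn.functions.send(h, self.comm, 1)
        y = self.l2(h)
        # pseudo_connect() makes the returned variable imitate `y` in the
        # forward pass while keeping `delegate_variable` on the backward path,
        # so one backward() call covers both loose ends.
        return chainermn.functions.pseudo_connect(delegate_variable, y)
```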

empty dataset (chainermn.datasets.get_empty_dataset)

It is used for models that have no actual inputs and instead receive their inputs from another machine.
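
For illustration, a rough sketch of where this fits (a hypothetical snippet, not taken from the example; the batch size and the use of MNIST here are arbitrary):

```
import chainer
import chainermn
import chainermn.datasets

comm = chainermn.create_communicator('naive')  # CPU-only communicator
train, _ = chainer.datasets.get_mnist()

if comm.rank == 0:
    # Rank 0 feeds the real images into the first part of the model.
    train_iter = chainer.iterators.SerialIterator(train, 100)
else:
    # Other ranks receive their inputs over MPI, so their iterator only needs
    # to yield the right number of empty examples per epoch.
    empty_train = chainermn.datasets.get_empty_dataset(train)
    train_iter = chainer.iterators.SerialIterator(empty_train, 100)
```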

@levelfour changed the title from "[WIP] Add model-parallel MNIST example" to "Add model-parallel MNIST example" on Jul 28, 2017
README.md Outdated
@@ -38,7 +38,7 @@ Please refer to the [installation guide](https://chainermn.readthedocs.io/en/lat
You can invoke MNIST example with four workers by the following command:

```
mpiexec -n 4 python examples/mnist/train_mnist.py
mpiexec -n 4 python examples/mnist/train_mnist_data_parallel.py

Member:

I'm afraid we should not change this line.
The data-parallel model is an advanced feature, and users should first see the basic MNIST example.

Contributor Author:

I thought it would be confusing if train_mnist.py and train_mnist_model_parallel.py existed at the same time, so I renamed the original train_mnist.py to train_mnist_data_parallel.py to state explicitly that this example is data-parallel.

import chainermn


class SimpleModelInst(chainer.Chain):

Member:

Can we think of a better name? Inst sounds like an instance of a class.

Contributor Author:

Fixed. Thank you.

#!/usr/bin/env python
# coding: utf-8

import argparse

Member:

According to PEP8,

> Imports should be grouped in the following order:
>
> 1. standard library imports
> 2. related third party imports
> 3. local application/library specific imports
>
> You should put a blank line between each group of imports.

https://www.python.org/dev/peps/pep-0008/#imports

So we need a blank line after argparse.
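
Concretely, for this example script the grouping would look something like this (module list here is only illustrative):

```
# standard library
import argparse

# third party (Chainer and ChainerMN; whether these need a further blank
# line between them is discussed in the next comment)
import chainer
import chainer.functions as F
import chainer.links as L
import chainermn
```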

Member:

I am not sure whether we need a blank line between the Chainer and ChainerMN imports.

@@ -0,0 +1,29 @@
import numpy as np

Member:

unittest should come first because it is a standard library module, while numpy is a third-party library.

@@ -6,6 +6,7 @@
import chainer.testing.attr
import chainermn
import chainermn.functions
import copy

Member:

PEP8.

err = model()
err.backward()

def test_cross_model(self):

Member:

crossing_model ?

Contributor Author:

Fixed.

if backward_pointer is not None and _x.creator is not None:
_x.creator.rank = -1

x = _x if x is None else x + _x

Contributor:

I don't think it is a reasonable design to sum up received arrays when receiving from multiple workers. I think it would be more natural and useful to hand them to f as separate parameters.

import chainer


def get_empty_dataset(dataset):

Contributor:

How about changing its name to create_empty_dataset, which is consistent with create_multi_node_optimizer and create_multi_node_evaluator?

Contributor Author:

Fixed.

import chainermn.functions.point_to_point_communication


class MultiNodeChainGroup(chainer.ChainList):

Contributor:

Do you have any reason for this name? I don't have a strong opinion, but how about renaming it to MultiNodeChainList, which is consistent with chainer.ChainList?

Contributor Author:

Fixed.

import chainermn


class SimpleModelInst(chainer.Chain):

Contributor:

I didn't understand what Inst means. How about SimpleModelSub or something like that?

self.add_link(MLP0b(comm), rank_in=1, rank_out=None)


class MLP1inst(chainer.Chain):

Contributor:

ditto for inst

Contributor Author:

Fixed.

``chainermn.functions.send()``.

Args:
dataset(chainer.datasets.TupleDataset): Dataset to convert.

Contributor:

dataset does not need to be TupleDataset. Chainer accepts many kinds of datasets.

Contributor:

~chainer.datasets.TransformDataset:
Dataset consists of only patterns in the original one.
"""
return chainer.datasets.TransformDataset(dataset, lambda data: ())

Contributor:

Probably just `[()] * len(dataset)` is enough? (TransformDataset keeps a reference to the original dataset, so it will consume unnecessary memory.)
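
For comparison, a sketch of the two variants being discussed (the second function name is made up for this comparison):

```
import chainer


def create_empty_dataset(dataset):
    # Variant in the PR: wraps the original dataset, so the wrapper keeps the
    # original data alive even though every transformed example is ().
    return chainer.datasets.TransformDataset(dataset, lambda data: ())


def create_empty_dataset_as_list(dataset):
    # Suggested alternative: only len(dataset) is used, so the original data
    # is no longer referenced by the result.
    return [()] * len(dataset)
```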



class PseudoConnect(chainer.Function):
"""Connect a variable with delegating variable."""

Member:

"Connect two variables with a delegating variable"
or
"Connect a variable to a delegating variable" ?

"""Receive elements from target process.

This function returns data received from target process. If ``backward()``
is invoked, it will try to send gradients to the target process.

.. note::
If you define non-connected computational graph on one machine,

Member:

machine -> process
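
For context, a minimal sketch of the send/recv pairing this docstring describes, runnable with something like `mpiexec -n 2 python example.py` (the recv() argument order here is assumed, so check the final API):

```
import numpy as np

import chainer
import chainer.functions as F
import chainermn
import chainermn.functions

comm = chainermn.create_communicator('naive')  # CPU-only communicator

if comm.rank == 0:
    x = chainer.Variable(np.ones((1, 3), dtype=np.float32))
    # Forward: ship x to rank 1. The returned delegate variable carries no
    # data; calling backward() on it waits for the gradient from rank 1.
    delegate = chainermn.functions.send(x, comm, 1)
    delegate.backward()
elif comm.rank == 1:
    # Forward: receive x from rank 0. Backward sends grad(x) back to rank 0.
    x = chainermn.functions.recv(comm, 0)
    loss = F.sum(x)
    loss.backward()
```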

@@ -100,3 +102,31 @@ def test_communication(self):
err = chainermn.functions.send(
y, self.communicator, self.rank_send)
err.backward()

def test_retain(self):
if self.communicator.rank == 0:

Member:

Does this test work if more than 2 processes are invoked?
Should it be skipped if communicator.size > 2?

Contributor Author:

This test also works with more than 2 processes. It emulates test_cycle_model in test_link.py. FYI, we run the tests on 3 processes in the latest commit.

Member:

OK!

@iwiwi (Contributor) commented Aug 4, 2017

LGTM

@keisukefukuda (Member)

LGTM!

@keisukefukuda merged commit 1fa4021 into chainer:master on Aug 4, 2017
@levelfour deleted the model-parallel-mnist branch on August 5, 2017
@iwiwi added this to the v1.0.0 milestone on Aug 31, 2017