Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New layer architecture #159

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open

New layer architecture #159

wants to merge 5 commits into from

Conversation

hweom
Copy link
Contributor

@hweom hweom commented Mar 12, 2022

New layer architecture prototype

  • 🧭 Architecture

Relates to #155 .

Changes proposed by this PR:

  1. Static network graph is separated from invocation context.
    a) Static graph captures layers, connections between them and shapes of the units of data.
    b) Invocation context specifies the batch size and stores all data associated with an invocation (data, gradients).
  2. Batch size is now explicit in the context instead of being implicitly extracted by layers from incoming data.
  3. Separation into Layer and ILayer is now gone, everything is now handled in layer implementations (with "leaf" layers focusing on data manipulations while container layers focusing on network composition).

Notes to reviewer:

This is still a very early prototype not intended for merging:

  1. Solver architecture not changed and just crudely hacked to support new network architecture.
  2. Shared weights not supported.
  3. Serialization not supported.
  4. Mnist example compiles and runs but doesn't converge (there is a bug somewhere, i'm sure).

A good order for exploring this PR is starting at comments in net/mod.rs, net/layer.rs, net/descriptor.rs and net/context.rs.

1. Static network graph is separated from invocation context.
   a) Static graph captures layers, connections between them
      and shapes of the units of data.
   b) Invocation context specifies the batch size and stores
      all data associated with an invocation (data, gradients).
2. Batch size is now explicit in the context instead of being
   implicitly extracted by layers from incoming data.
3. Separation into Layer and ILayer is now gone, everything is
   now handled in layer implementations (with "leaf" layers focusing
   on data manipulations while container layers focusing on
   network composition).

This is still a very early prototype not intended for mergin:
1. Solver architecture not changed and just crudely hacked to support
   new network architecture.
2. Shared weights not supported.
3. Serialization not supported.
@drahnr
Copy link
Member

drahnr commented Mar 21, 2022

I did an initial pass, it simplifies things on the user end, but what I see as plus, on the other hand side, it removes the ability to mix execution backends iiuc? I'll do another pass soon.

@hweom
Copy link
Contributor Author

hweom commented Mar 23, 2022

it removes the ability to mix execution backends

You mean use different backends for net vs loss in Solver? Yeah, this was a shortcut that I did. All the changes to the solver part of the framework were more or less minimally invasive hacks. This was so I could use the buffer holding the network output directly as an input to the loss layer.

In principle, we can keep the ability to have different backed for loss layer through either of these approaches:

  • Do not store the Backend in Context and pass it as a separate arg to Layer::compute_output() and compute_gradients(). (In reality, there is little reason to store it there, except for parameter passing convenience.) Then we can have a single Context holding all the buffers but 2 Backends.
  • Have 2 separate Context instances (with separate Backends): one for net and another for loss. This would require somehow copying the network output data from one context to another: either via copying the buffer data bytes or (more intelligently) just injecting a buffer reference.

Or did you mean mixing backends in different invocations of the network? I think already nothing precludes that, as layers don't store Backend object internally.

@drahnr
Copy link
Member

drahnr commented Apr 13, 2022

Changes proposed by this PR:

1. Static network graph is separated from invocation context.
   a) Static graph captures layers, connections between them and shapes of the units of data.
   b) Invocation context specifies the batch size and stores all data associated with an invocation (data, gradients).

As mentioned earlier, this makes passing different execution contexts more difficult from what I can see API wise, since the creation of the layers then would have to hold an Rc<B> to the backend.

Storing the associated data as part of the descriptor is not something that seems idiomatic. The descriptor becomes the owner of the actual learned weights iiuc.

A plus of this is, that all layers now have to use the same storage and not be backend specific and also allow things to extend more quickly to other serialization formats.

One use case that must be supported, is to load external network definitions that only share the same input and output dimensions. This allows to i.e. hotswap networks during runtime.

2. Batch size is now explicit in the context instead of being implicitly extracted by layers from incoming data.

I think this is the biggest gain in the new architecture.

3. Separation into Layer and ILayer is now gone, everything is now handled in layer implementations (with "leaf" layers focusing    on data manipulations while container layers focusing on network composition).

👍


This is the first pass, it generally looks very promising, I have to give it another pass in hopefully less than 24d from now 😅

@hweom
Copy link
Contributor Author

hweom commented Apr 14, 2022

As mentioned earlier, this makes passing different execution contexts more difficult from what I can see API wise, since the creation of the layers then would have to hold an Rc<B> to the backend.

I think I misunderstood your earlier. Are you saying that we can't create the network using backend B1 and then pass to it a Context which uses a different backend B2? The same limitation exists now too, though?

Spoiler: I'm toying with an idea of separating backend from context:

pub trait Layer<B: IBackend>: Debug {
    fn compute_output(&self, backend: &B, context: &mut Context);
}

which I think is cleaner (and Context has no use of backend internally anyway). Still the layer is locked to the backend used during creation, and I don't see an easy way around it unless we change it to something like:

pub trait Layer: Debug {
    fn compute_output(&self, backend: &dyn IBackend + LayerOps<f32>, context: &mut Context);
}

(the latter will not compile, but hopefully the idea is clear).

Storing the associated data as part of the descriptor is not something that seems idiomatic. The descriptor becomes the owner of the actual learned weights iiuc.

Well the descriptor is just a convenient way of exposing data from a Layer which the outside world needs to know about. The same can be done with trait functions (much like ILayer::learnable_weights() currently does). Descriptor helps reduce boilerplate code in layers.

The question of ownership is a bit fuzzy with Rc<RefCell<>>, but in my implementation for example Linear layer holds pointers to weights, not relying on the Descriptor:

pub struct Linear {
    // Weight (A) and bias (b) for the linear operation y = Ax + b.
    weight: Rc<RefCell<LearnableParams>>,
    bias: Rc<RefCell<LearnableParams>>,
}

One use case that must be supported, is to load external network definitions that only share the same input and output dimensions. This allows to i.e. hotswap networks during runtime.

This should be already be supported I think. At least I don't see any immediate issues.

This is the first pass, it generally looks very promising, I have to give it another pass in hopefully less than 24d from now sweat_smile

Thanks. I have some updates on my end which I hope to push in about a week. Some cleanups on network side, plus I'm looking into solvers, as I need Adam optimizer for my tasks.

1. Static network graph is separated from invocation context.
   a) Static graph captures layers, connections between them
      and shapes of the units of data.
   b) Invocation context specifies the batch size and stores
      all data associated with an invocation (data, gradients).
2. Batch size is now explicit in the context instead of being
   implicitly extracted by layers from incoming data.
3. Separation into Layer and ILayer is now gone, everything is
   now handled in layer implementations (with "leaf" layers focusing
   on data manipulations while container layers focusing on
   network composition).
4. Solvers replaced by a more linear architecture of a top-level
   Trainer and different Optimizers (although only SGD with momentum
   is currently supported since both RMSprop and Adam require
   squaring backend support).

This is still a very early prototype not intended for mergin:
1. Shared weights not supported.
2. Serialization not supported.
3. Not all layers are migrated.
@hweom
Copy link
Contributor Author

hweom commented Apr 28, 2022

OK, pushed a refreshed version. I couldn't implement Adam since it requires squaring tensors, which is not supported by the backends currently, but I've added some placeholders for it in the new train module.

Mikhail Balakhno and others added 2 commits May 7, 2022 16:09
1. Static network graph is separated from invocation context.
   a) Static graph captures layers, connections between them
      and shapes of the units of data.
   b) Invocation context specifies the batch size and stores
      all data associated with an invocation (data, gradients).
2. Batch size is now explicit in the context instead of being
   implicitly extracted by layers from incoming data.
3. Separation into Layer and ILayer is now gone, everything is
   now handled in layer implementations (with "leaf" layers focusing
   on data manipulations while container layers focusing on
   network composition).
4. Solvers replaced by a more linear architecture of a top-level
   Trainer and different Optimizers (SGD with momentum and Adam
   are currently supported).

This is still a very early prototype not intended for mergin:
1. Shared weights not supported.
2. Serialization not supported.
3. Not all layers are migrated.
@hweom
Copy link
Contributor Author

hweom commented May 7, 2022

Added Adam implementation, for now without backend support.

.as_mut_slice::<f32>();

// We can rewrite the matrix equations at the top of this file in a element-wise form:
// Mᵢ[j] = β₁Mᵢ₋₁[j] + (1-β₁)∇ᵢ[j]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♥️

@drahnr
Copy link
Member

drahnr commented May 22, 2022

Alright, this a significant chunk of work ❤️ I'd like to discuss how we can move towards filling in the missing pieces and a path to getting the adjusted arch back to master.

@hweom
Copy link
Contributor Author

hweom commented May 22, 2022

I think the remaining part should be pretty mechanical -- port other layers to the new infra, write unit tests, etc. I'm happy to do all of that, or we can split the work.

I think it's probably a good idea to commit the current work to a branch, maybe even split in several PRs, to make the review more manageable. The currently missing pieces can be committed as separate PRs into the branch. The branch will have old and new code alongside until everything is ported, after which old code will be deleted.

Do you still want to do an in-depth review of the core infra? I'd be definitely more comfortable if someone can double-check the file structure, names, etc.

@drahnr
Copy link
Member

drahnr commented May 27, 2022

I'll get to that. One thing that came to mind was, bring ready to impl auto differentiation with the new arch. The old one was a bit clunky in that regard.

@hweom
Copy link
Contributor Author

hweom commented May 31, 2022

bring ready to impl auto differentiation

Sorry, not sure what this means. Could you elaborate?

@drahnr
Copy link
Member

drahnr commented Jun 4, 2022

That was supposed to be being - what I meant was, the API should be able to represent inference and training passes both separately and in one step

@hweom
Copy link
Contributor Author

hweom commented Jun 11, 2022

Sorry, can you clarify this? I think it already does it.

Right now the API provides 2 types of abstraction: Network and Trainer (Trainer also operates on Network, but it utilizes a lower-level API of the top Layer):

  • Network is the API for using the net via Network::transform().
  • Trainer is the API for training the net via Trainer::train_minibatch(). The latter also returns the result of the forward pass.

Both APIs hide the low-level details like constructing a Context, pushing inputs into it, extracting outputs, etc.

@drahnr
Copy link
Member

drahnr commented Jun 23, 2022

I think we can move forward with this large refactor. We could have a sync call if you'd like? Sorry for the delay(s)

@hweom
Copy link
Contributor Author

hweom commented Jun 24, 2022

Sure, happy to have a call! I'm in PDT timezone, so it seems the acceptable overlapping time range is your evening and my morning. How about Jun 24, 19:00 Munich time? If it works, I can send a Google Meet invite.

@drahnr
Copy link
Member

drahnr commented Jun 24, 2022

Sure, happy to have a call! I'm in PDT timezone, so it seems the acceptable overlapping time range is your evening and my morning. How about Jun 24, 19:00 Munich time? If it works, I can send a Google Meet invite.

That'd work, please drop to [email protected] - if you get a bounce ( I hope not) it's due some email forwarding service issues, which are hopefully dealt with by now 🤞

@drahnr
Copy link
Member

drahnr commented Jun 27, 2022

Hey 👋 - I created https://github.com/spearow/juice/tree/arch-refactor where we should land the changeset first. You should also have received an invite that allows you to create branches.

@hweom
Copy link
Contributor Author

hweom commented Mar 26, 2023

How much do we want the RNN layer to be implemented in the new arch before switching to it? I'm looking into it, but it will likely require some extensive changes of the backend.

cudnnRNNForwardInference() is deprecated, and its replacement, cudnnRNNForward(), requires batch size-dependent descriptors. It's doable, but will likely take me quite some time.

As far as I can tell, the existing RNN implementation is not used anywhere. I'm not even sure it's implemented correctly.

@@ -311,9 +308,8 @@ fn run_mnist(
targets.push(label_val as usize);
}
// train the network!
let infered_out = solver.train_minibatch(inp_lock.clone(), label_lock.clone());
let mut infered = solver.train_minibatch(inp_lock.clone(), label_lock.clone());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Comment on lines +120 to +121
//pub mod layer;
//pub mod layers;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
//pub mod layer;
//pub mod layers;

Comment on lines +71 to +78
// // Gradient is calculated as 2 * (predictions - labels).
// backend.copy(&labels.borrow(), &mut input_gradient.borrow_mut());
// backend.axpby(
// &native_scalar(2.0),
// &predictions.borrow(),
// &native_scalar(-2.0),
// &mut input_gradient.borrow_mut(),
// );
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be faster :) I am not sure how NaN is treated in axpby though.

branches: Vec<LayerConfig>,
}

pub struct Fanout<B: IBackend> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit of documentation would be nice, since it'll become user visible

/// of the scenario (so the longer the agent is able to keep pole from falling, the bigger
/// overall reward it gets).
///
/// State "s" consists of [cart_pos, cart_vel, pole_angle, pole_angle_vel] variables.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// State "s" consists of [cart_pos, cart_vel, pole_angle, pole_angle_vel] variables.
/// State `s` consists of `[cart_pos, cart_vel, pole_angle, pole_angle_vel]` variables.

Comment on lines +43 to +45
if br != k || c.rows() != m || c.cols() != n {
panic!("Wrong GEMM dimensions: [{},{}]x[{},{}] -> [{},{}]", ar, ac, br, bc, c.rows(), c.cols());
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should consider making it a debug_assert! and rely on the caller.

Copy link
Member

@drahnr drahnr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have yet to do a full code review, generally, looks excellent, a few nits.

@hweom
Copy link
Contributor Author

hweom commented Mar 28, 2023

Sorry, this is an old PR, at this point superseded by all the recent ones. I used it to ask the question: #159 (comment) (my bad, probably should have asked directly).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants