From 7e141b582fe611efb6d5ff5c8c487bcd1e197cc1 Mon Sep 17 00:00:00 2001 From: Victor Shepardson Date: Wed, 30 Mar 2022 11:35:25 +0000 Subject: [PATCH] model comments, docstrings --- notepredictor/notepredictor/model.py | 39 +++++++++++++++++++++------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/notepredictor/notepredictor/model.py b/notepredictor/notepredictor/model.py index 2a5d68e..57e607c 100644 --- a/notepredictor/notepredictor/model.py +++ b/notepredictor/notepredictor/model.py @@ -81,7 +81,19 @@ def forward(self, x): # return self.net(x) class ModalityTransformer(nn.Module): - """Model joint distribution of modalities autoregressively with random permutations""" + """ + Model joint distribution of note modalities (e.g. pitch, time, velocity). + + This is an autoregressive Transformer model for the *internal* structure of notes. + It is *not* autoregressive in time, but in modality. + At training time, it executes in parallel over all timesteps and modalities, with + time dependencies provided via the RNN backbone. + + At sampling time it is called serially, one modality at a time, + repeatedly at each time step. + + Inspired by XLNet: http://arxiv.org/abs/1906.08237 + """ def __init__(self, input_size, hidden_size, heads=4, layers=1): super().__init__() self.net = nn.TransformerDecoder( @@ -95,13 +107,11 @@ def forward(self, ctx, h_ctx, h_tgt): ctx: list of Tensor[batch x time x input_size], length note_dim-1 these are the embedded ground truth values h_ctx: Tensor[batch x time x input_size] - (need something to attend to when ctx is empty) + projection of RNN state (need something to attend to when ctx is empty) h_tgt: list of Tensor[batch x time x input_size], length note_dim - these are projections of the RNN state + these are projections of the RNN state for each target, + which the Transformer will map to distribution parameters. """ - # h_tgt = list(h_tgt) - # ctx = list(ctx) - # explicitly broadcast h_ctx, *ctx = torch.broadcast_tensors(h_ctx, *ctx) h_ctx, *h_tgt = torch.broadcast_tensors(h_ctx, *h_tgt) @@ -122,6 +132,7 @@ def forward(self, ctx, h_ctx, h_tgt): # generate a mask # this is both the target and memory mask + # masking is such that each target can only depend on "previous" context n = len(h_tgt) mask = ~tgt.new_ones((n,n), dtype=bool).tril() @@ -254,7 +265,7 @@ def embeddings(self): def forward(self, pitches, times, velocities, validation=False): """ - teacher-forced probabilistic loss and diagnostics for training + teacher-forced probabilistic loss and diagnostics for training. Args: pitches: LongTensor[batch, time] @@ -263,12 +274,14 @@ def forward(self, pitches, times, velocities, validation=False): """ batch_size, batch_len = pitches.shape + # embed data to input vectors pitch_emb = self.pitch_emb(pitches) # batch, time, emb_size time_emb = self.time_emb(times) # batch, time, emb_size vel_emb = self.vel_emb(velocities) # batch, time, emb_size embs = (pitch_emb, time_emb, vel_emb) + # feed to RNN backbone x = torch.cat(embs, -1)[:,:-1] # skip last time position ## broadcast initial state to batch size initial_state = tuple( @@ -276,20 +289,26 @@ def forward(self, pitches, times, velocities, validation=False): for t in self.initial_state) h, _ = self.rnn(x, initial_state) #batch, time, hidden_size - # fit all note factorizations at once. + # fit all note factorizations (e.g. pitch->time->vel vs vel->time->pitch) # TODO: perm each batch item independently? + # get a random ordering for note modalities: perm = torch.randperm(self.note_dim) + # chunk RNN state into Transformer inputs hs = list(self.h_proj(h).chunk(self.note_dim+1, -1)) h_ctx = hs[0] h_tgt = [hs[i+1] for i in perm] + # embed ground truth values for teacher-forcing embs = [embs[i][:,1:] for i in perm[:-1]] + # run through Transformer to conditional hidden states mode_hs = self.xformer(embs, h_ctx, h_tgt) + # permute back to canonical order mode_hs = [mode_hs[i] for i in perm.argsort()] + # final projections to raw distribution parameters pitch_params, time_params, vel_params = [ proj(h) for proj,h in zip(self.projections, mode_hs)] - # get likelihoods + # get likelihoods of data for each modality pitch_logits = F.log_softmax(pitch_params, -1) pitch_targets = pitches[:,1:,None] #batch, time, 1 pitch_log_probs = pitch_logits.gather(-1, pitch_targets)[...,0] @@ -309,6 +328,8 @@ def forward(self, pitches, times, velocities, validation=False): **{'time_'+k:v for k,v in time_result.items()}, **{'velocity_'+k:v for k,v in vel_result.items()} } + # this just computes some extra diagnostics which are inconvenient to do in the + # training script. should be turned off during training for performance. if validation: with torch.no_grad(): r['time_acc_30ms'] = (