# midi_dataset.py
#%%
import glob
import json
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pretty_midi
import soundfile as sf
import torch
from IPython.display import Audio, display
from PIL import Image
from tqdm import tqdm
from chiptune import chiptunes_synthesize
from piano_roll_to_pretty_midi import to_pretty_midi
from util import (crop_augment_piano_roll, crop_roll_128_to_88, get_onsets,
                  pad_roll_88_to_128)


def transpose(midi_data, offset):
    """Shift every non-drum note by `offset` semitones (in place)."""
    for instrument in midi_data.instruments:
        # Don't want to shift drum notes
        if not instrument.is_drum:
            for note in instrument.notes:
                note.pitch += offset
    return midi_data


def change_tempo(midi_data, factor):
    """Stretch the whole file in time: factor > 1 slows it down, factor < 1 speeds it up."""
    # Get the length of the MIDI file
    length = midi_data.get_end_time()
    midi_data.adjust_times([0, length], [0, length * factor])
    return midi_data


def has_simulateneous_neighbouring_notes(roll):
    """Return True if two adjacent semitones ever sound at the same time step."""
    roll = (roll > 0).astype(np.int32)
    # Pad the pitch axis so the shifted comparison below never wraps around.
    padded_roll = np.pad(roll, ((1, 1), (0, 0)), "constant", constant_values=0)
    two_note_sum = padded_roll + np.roll(padded_roll, 1, 0)
    return np.max(two_note_sum) > 1


class MidiDataset(torch.utils.data.Dataset):
    def __init__(self, midi_filepaths=None, prepared_data_path=None, crop_size=None,
                 downsample_factor=1, only_88_keys=True):
        self.crop_size = crop_size
        self.downsample_factor = downsample_factor
        self.N_STEPS = 64
        if prepared_data_path is not None:
            self.data = torch.load(prepared_data_path)
        else:
            self.fps = glob.glob(midi_filepaths, recursive=True)
            self.data = []
            for fp in tqdm(self.fps):
                try:
                    piano_roll = self.load_piano_roll(fp)
                except Exception:
                    # Skip files that cannot be parsed or fail the checks in load_piano_roll.
                    continue
                if only_88_keys:
                    piano_roll = crop_roll_128_to_88(piano_roll)
                piano_rolls = self.chunk_piano_roll(piano_roll)
                if len(piano_rolls) > 0:
                    for piano_roll in piano_rolls:
                        self.data.append({"piano_roll": piano_roll,
                                          "caption": self.fp_to_caption(fp),
                                          "filepath": fp})
        print("Loaded {} examples".format(len(self.data)))
        # make sure all piano rolls are of size N_STEPS
        self.data = [example for example in self.data if example["piano_roll"].shape[1] == self.N_STEPS]
        print("Filtered to {} examples because of size".format(len(self.data)))
        if self.downsample_factor != 1:
            # downsample all piano rolls along the time axis
            for i in range(len(self.data)):
                self.data[i]["piano_roll"] = self.data[i]["piano_roll"][:, ::self.downsample_factor]
        # remove all empty piano rolls
        self.data = [example for example in self.data if example["piano_roll"].sum() > 0]
        print("Filtered to {} examples because of empty".format(len(self.data)))
        # filter away rolls with fewer than 8 onsets
        self.data = [example for example in self.data if np.sum(get_onsets(example["piano_roll"][None, ...])) > 7]
        print("Filtered to {} examples because of too few onsets".format(len(self.data)))
        # filter away rolls with neighbouring notes being played simultaneously
        self.data = [example for example in self.data if not has_simulateneous_neighbouring_notes(example["piano_roll"])]
        print("Filtered to {} examples because of neighbouring notes".format(len(self.data)))
        new_data = []
        seen = set()
        # filter away non-unique piano rolls
        for i in range(len(self.data)):
            # tobytes() gives a hashable fingerprint of the roll (tostring() is deprecated in NumPy).
            key = self.data[i]["piano_roll"].tobytes()
            if key in seen:
                continue
            new_data.append(self.data[i])
            seen.add(key)
        self.data = new_data
        print("Filtered to {} after removing duplicates".format(len(self.data)))

    def __getitem__(self, idx):
        example = self.data[idx]
        roll = example["piano_roll"].copy()
        # random crop
        if self.crop_size is not None:
            roll = crop_augment_piano_roll(roll, self.crop_size)
        return {"piano_roll": roll, "caption": example["caption"]}

    def save_data(self, path):
        torch.save(self.data, path)

    def chunk_piano_roll(self, full_piano_roll):
        if full_piano_roll.shape[1] < self.N_STEPS:
            return []
        # trim end to multiple of N_STEPS
        full_piano_roll = full_piano_roll[:, :full_piano_roll.shape[1] // self.N_STEPS * self.N_STEPS]
        # split into N_STEPS-step chunks
        piano_rolls = np.split(full_piano_roll, full_piano_roll.shape[1] // self.N_STEPS, axis=1)
        # filter away empty chunks
        piano_rolls = [piano_roll for piano_roll in piano_rolls if np.sum(piano_roll) > 0]
        return piano_rolls

    def __len__(self):
        return len(self.data)

    def fp_to_caption(self, fp):
        # Caption = "<game folder name> <track title>", where the title is whatever
        # follows the last "-" in the file name, with digits stripped.
        split_fp = fp.split("/")
        file_name = split_fp[-1]
        game_name = split_fp[-2]
        title = file_name.split("-")[-1]
        title = title.split(".")[0]
        # remove numbers
        title = "".join([i for i in title if not i.isdigit()])
        caption = f"{game_name} {title}"
        return caption

    def load_piano_roll(self, fp):
        midi = pretty_midi.PrettyMIDI(fp)
        # only accept files with at most one time signature, and its numerator must be 4
        assert len(midi.time_signature_changes) < 2
        if len(midi.time_signature_changes) == 1:
            assert midi.time_signature_changes[0].numerator == 4
        beat_times = midi.get_beats()
        beat_ticks = [midi.time_to_tick(time) for time in beat_times]
        # get beat length in ticks
        quarter_length = beat_ticks[1] - beat_ticks[0]
        # check that beats are all the same length
        for i in range(len(beat_ticks) - 1):
            assert beat_ticks[i + 1] - beat_ticks[i] == quarter_length
        # check that the beat length is divisible by 4, so it splits evenly into 16th notes
        assert quarter_length % 4 == 0
        steps_per_beat = 4
        # get 16th note length in ticks
        step_length = quarter_length // steps_per_beat
        last_end = 0
        # quantize midi to nearest 16th note
        for instrument in midi.instruments:
            for note in instrument.notes:
                note.start = midi.tick_to_time(step_length * (midi.time_to_tick(note.start) // step_length))
                note.end = midi.tick_to_time(step_length * (midi.time_to_tick(note.end) // step_length))
                if note.end > last_end:
                    last_end = note.end
        # drop drum tracks entirely
        for instrument in midi.instruments:
            if instrument.is_drum:
                instrument.notes = []
        # get time of first event, converted to ticks
        first_onset_ticks = midi.time_to_tick(midi.get_onsets()[0])
        # get beat of first event
        first_beat = first_onset_ticks // quarter_length
        # get time of first beat
        first_beat_time = midi.tick_to_time(first_beat * quarter_length)
        n_beats = len(beat_times)
        n_steps = n_beats * steps_per_beat
        sampling_steps_per_beat = steps_per_beat
        # get 16th note length in seconds
        sampling_step_time = midi.tick_to_time(quarter_length / sampling_steps_per_beat)
        # sample the piano roll on a 16th-note grid from the first occupied beat to the end
        piano_roll = midi.get_piano_roll(fs=1 / sampling_step_time,
                                         times=np.arange(first_beat_time,
                                                         last_end + midi.tick_to_time(quarter_length),
                                                         sampling_step_time))
        return piano_roll
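

#%%
# Small self-check of the neighbour filter (illustrative sketch only; the 88x8
# roll and the pitch indices below are arbitrary test values, not real data).
_roll = np.zeros((88, 8))
_roll[40, 0] = 80   # a single note: no neighbouring semitones sound together
assert not has_simulateneous_neighbouring_notes(_roll)
_roll[41, 0] = 80   # add the semitone directly above it, at the same time step
assert has_simulateneous_neighbouring_notes(_roll)
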
# %%
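# Minimal usage sketch (not part of the original pipeline): the glob pattern,
# output path and batch size below are assumed placeholder values.
if __name__ == "__main__":
    dataset = MidiDataset(midi_filepaths="data/**/*.mid")
    # Cache the chunked, filtered piano rolls so later runs can skip MIDI parsing.
    dataset.save_data("prepared_midi_data.pt")
    loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
    batch = next(iter(loader))
    print(batch["piano_roll"].shape, batch["caption"][:2])
    # Visualise the first roll of the batch (pitch on the y-axis, time steps on the x-axis).
    plt.imshow(batch["piano_roll"][0].numpy(), aspect="auto", origin="lower", cmap="gray_r")
    plt.xlabel("time step")
    plt.ylabel("pitch (88-key index)")
    plt.show()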