-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_data.m
27 lines (26 loc) · 1.13 KB
/
load_data.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
function [train_input, train_target, valid_input, valid_target, test_input, test_target, vocab] = load_data(N)
% LOAD_DATA Load the training, validation and test sets from data/data.mat.
% The training set is additionally divided into mini-batches of size N.
%
% The MAT-file (path relative to the current folder) must contain a struct
% variable named DATA with fields trainData, validData, testData and vocab.
% In each data matrix every column is one example: rows 1..D are inputs and
% row D+1 is the target, where D+1 is the number of rows of trainData
% (validData and testData are assumed to share that row layout).
%
% Inputs:
%   N: Mini-batch size (positive integer scalar).
% Outputs:
%   train_input:  D x N x M array, where
%                 D: number of input dimensions (rows of trainData minus 1),
%                 N: size of each mini-batch,
%                 M: number of mini-batches, floor(#training examples / N).
%                 Training columns beyond the first N*M are dropped, and M is
%                 0 when N exceeds the number of training examples.
%   train_target: 1 x N x M array of training targets, batched like train_input.
%   valid_input:  D x (number of points in the validation set).
%   valid_target: 1 x (number of points in the validation set).
%   test_input:   D x (number of points in the test set).
%   test_target:  1 x (number of points in the test set).
%   vocab:        data.vocab as stored in the file; presumably an
%                 index-to-word mapping (confirm against data.mat).

validateattributes(N, {'numeric'}, {'scalar', 'positive', 'integer'}, 'load_data', 'N', 1);

% Use the functional form of LOAD so DATA is an explicit assignment rather
% than a variable created implicitly in the function workspace (which MATLAB
% may confuse with a function of the same name).
S = load(fullfile('data', 'data.mat'), 'data');
data = S.data;

D = size(data.trainData, 1) - 1;        % last row holds the target
M = floor(size(data.trainData, 2) / N);
keep = 1:N * M;                         % drop the incomplete final batch

train_input  = reshape(data.trainData(1:D, keep), D, N, M);
train_target = reshape(data.trainData(D + 1, keep), 1, N, M);
valid_input  = data.validData(1:D, :);
valid_target = data.validData(D + 1, :);
test_input   = data.testData(1:D, :);
test_target  = data.testData(D + 1, :);
vocab        = data.vocab;
end