Port to pytorch 1.x #41

Merged
merged 2 commits on Apr 5, 2019
6 changes: 3 additions & 3 deletions dnc/dnc.py
@@ -13,7 +13,7 @@
from .util import *
from .memory import *

from torch.nn.init import orthogonal, xavier_uniform
from torch.nn.init import orthogonal_, xavier_uniform_


class DNC(nn.Module):
@@ -115,7 +115,7 @@ def __init__(

# final output layer
self.output = nn.Linear(self.nn_output_size, self.input_size)
orthogonal(self.output.weight)
orthogonal_(self.output.weight)

if self.gpu_id != -1:
[x.cuda(self.gpu_id) for x in self.rnns]
@@ -131,7 +131,7 @@ def _init_hidden(self, hx, batch_size, reset_experience):
# initialize hidden state of the controller RNN
if chx is None:
h = cuda(T.zeros(self.num_hidden_layers, batch_size, self.output_size), gpu_id=self.gpu_id)
xavier_uniform(h)
xavier_uniform_(h)

chx = [ (h, h) if self.rnn_type.lower() == 'lstm' else h for x in range(self.num_layers)]

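Editor's note (not part of the diff): PyTorch 0.4+ deprecated the plain torch.nn.init functions in favor of the in-place, underscore-suffixed variants, which is the rename applied above. A minimal sketch of the pattern, using a hypothetical layer and hidden-state buffer:

import torch
import torch.nn as nn
from torch.nn.init import orthogonal_, xavier_uniform_

layer = nn.Linear(64, 128)      # hypothetical layer, for illustration only
orthogonal_(layer.weight)       # was: orthogonal(layer.weight)

h = torch.zeros(1, 4, 128)      # e.g. an RNN hidden-state buffer
xavier_uniform_(h)              # was: xavier_uniform(h); fills h in place and also returns it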
28 changes: 14 additions & 14 deletions dnc/memory.py
@@ -214,45 +214,45 @@ def forward(self, ξ, hidden):

if self.independent_linears:
# r read keys (b * r * w)
read_keys = F.tanh(self.read_keys_transform(ξ).view(b, r, w))
read_keys = T.tanh(self.read_keys_transform(ξ).view(b, r, w))
# r read strengths (b * r)
read_strengths = F.softplus(self.read_strengths_transform(ξ).view(b, r))
# write key (b * 1 * w)
write_key = F.tanh(self.write_key_transform(ξ).view(b, 1, w))
write_key = T.tanh(self.write_key_transform(ξ).view(b, 1, w))
# write strength (b * 1)
write_strength = F.softplus(self.write_strength_transform(ξ).view(b, 1))
# erase vector (b * 1 * w)
erase_vector = F.sigmoid(self.erase_vector_transform(ξ).view(b, 1, w))
erase_vector = T.sigmoid(self.erase_vector_transform(ξ).view(b, 1, w))
# write vector (b * 1 * w)
write_vector = F.tanh(self.write_vector_transform(ξ).view(b, 1, w))
write_vector = T.tanh(self.write_vector_transform(ξ).view(b, 1, w))
# r free gates (b * r)
free_gates = F.sigmoid(self.free_gates_transform(ξ).view(b, r))
free_gates = T.sigmoid(self.free_gates_transform(ξ).view(b, r))
# allocation gate (b * 1)
allocation_gate = F.sigmoid(self.allocation_gate_transform(ξ).view(b, 1))
allocation_gate = T.sigmoid(self.allocation_gate_transform(ξ).view(b, 1))
# write gate (b * 1)
write_gate = F.sigmoid(self.write_gate_transform(ξ).view(b, 1))
write_gate = T.sigmoid(self.write_gate_transform(ξ).view(b, 1))
# read modes (b * r * 3)
read_modes = σ(self.read_modes_transform(ξ).view(b, r, 3), 1)
else:
ξ = self.interface_weights(ξ)
# r read keys (b * w * r)
read_keys = F.tanh(ξ[:, :r * w].contiguous().view(b, r, w))
read_keys = T.tanh(ξ[:, :r * w].contiguous().view(b, r, w))
# r read strengths (b * r)
read_strengths = F.softplus(ξ[:, r * w:r * w + r].contiguous().view(b, r))
# write key (b * w * 1)
write_key = F.tanh(ξ[:, r * w + r:r * w + r + w].contiguous().view(b, 1, w))
write_key = T.tanh(ξ[:, r * w + r:r * w + r + w].contiguous().view(b, 1, w))
# write strength (b * 1)
write_strength = F.softplus(ξ[:, r * w + r + w].contiguous().view(b, 1))
# erase vector (b * w)
erase_vector = F.sigmoid(ξ[:, r * w + r + w + 1: r * w + r + 2 * w + 1].contiguous().view(b, 1, w))
erase_vector = T.sigmoid(ξ[:, r * w + r + w + 1: r * w + r + 2 * w + 1].contiguous().view(b, 1, w))
# write vector (b * w)
write_vector = F.tanh(ξ[:, r * w + r + 2 * w + 1: r * w + r + 3 * w + 1].contiguous().view(b, 1, w))
write_vector = T.tanh(ξ[:, r * w + r + 2 * w + 1: r * w + r + 3 * w + 1].contiguous().view(b, 1, w))
# r free gates (b * r)
free_gates = F.sigmoid(ξ[:, r * w + r + 3 * w + 1: r * w + 2 * r + 3 * w + 1].contiguous().view(b, r))
free_gates = T.sigmoid(ξ[:, r * w + r + 3 * w + 1: r * w + 2 * r + 3 * w + 1].contiguous().view(b, r))
# allocation gate (b * 1)
allocation_gate = F.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 1].contiguous().unsqueeze(1).view(b, 1))
allocation_gate = T.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 1].contiguous().unsqueeze(1).view(b, 1))
# write gate (b * 1)
write_gate = F.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 2].contiguous()).unsqueeze(1).view(b, 1)
write_gate = T.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 2].contiguous()).unsqueeze(1).view(b, 1)
# read modes (b * 3*r)
read_modes = σ(ξ[:, r * w + 2 * r + 3 * w + 3: r * w + 5 * r + 3 * w + 3].contiguous().view(b, r, 3), 1)

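Editor's note (not part of the diff): torch.nn.functional.sigmoid and torch.nn.functional.tanh are deprecated on PyTorch 1.x in favor of torch.sigmoid and torch.tanh, which is the only change applied throughout the interface-vector parsing above; the slicing and reshaping are untouched. A self-contained sketch of the substitution, using a hypothetical tensor x:

import torch as T
import torch.nn.functional as F

x = T.randn(2, 3)        # hypothetical activations
old = F.sigmoid(x)       # still works on 1.x but emits a deprecation warning
new = T.sigmoid(x)       # preferred spelling, numerically identical
assert T.allclose(old, new)

F.softplus carries no such deprecation, which is why the read/write strength lines keep it unchanged.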
2 changes: 1 addition & 1 deletion dnc/sam.py
@@ -9,7 +9,7 @@
from torch.nn.utils.rnn import pad_packed_sequence as pad
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import PackedSequence
from torch.nn.init import orthogonal, xavier_uniform
from torch.nn.init import orthogonal_, xavier_uniform_

from .util import *
from .sparse_memory import SparseMemory
2 changes: 1 addition & 1 deletion dnc/sdnc.py
@@ -9,7 +9,7 @@
from torch.nn.utils.rnn import pad_packed_sequence as pad
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import PackedSequence
from torch.nn.init import orthogonal, xavier_uniform
from torch.nn.init import orthogonal_, xavier_uniform_

from .util import *
from .sparse_temporal_memory import SparseTemporalMemory
18 changes: 9 additions & 9 deletions dnc/sparse_memory.py
@@ -58,17 +58,17 @@ def __init__(
self.write_vector_transform = nn.Linear(self.input_size, w)
self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
self.write_gate_transform = nn.Linear(self.input_size, 1)
T.nn.init.orthogonal(self.read_query_transform.weight)
T.nn.init.orthogonal(self.write_vector_transform.weight)
T.nn.init.orthogonal(self.interpolation_gate_transform.weight)
T.nn.init.orthogonal(self.write_gate_transform.weight)
T.nn.init.orthogonal_(self.read_query_transform.weight)
T.nn.init.orthogonal_(self.write_vector_transform.weight)
T.nn.init.orthogonal_(self.interpolation_gate_transform.weight)
T.nn.init.orthogonal_(self.write_gate_transform.weight)
else:
self.interface_size = (r * w) + w + self.c + 1
if self.gpu_id != -1:
self.interface_weights = nn.Linear(self.input_size, self.interface_size).cuda()
else:
self.interface_weights = nn.Linear(self.input_size, self.interface_size)
T.nn.init.orthogonal(self.interface_weights.weight)
T.nn.init.orthogonal_(self.interface_weights.weight)

self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n)
self.δ = 0.005 # minimum usage
@@ -299,19 +299,19 @@ def forward(self, ξ, hidden):
# write key (b * 1 * w)
write_vector = self.write_vector_transform(ξ).view(b, 1, w)
# write vector (b * 1 * r)
interpolation_gate = F.sigmoid(self.interpolation_gate_transform(ξ)).view(b, c)
interpolation_gate = T.sigmoid(self.interpolation_gate_transform(ξ)).view(b, c)
# write gate (b * 1)
write_gate = F.sigmoid(self.write_gate_transform(ξ).view(b, 1))
write_gate = T.sigmoid(self.write_gate_transform(ξ).view(b, 1))
else:
ξ = self.interface_weights(ξ)
# r read keys (b * r * w)
read_query = ξ[:, :r * w].contiguous().view(b, r, w)
# write key (b * 1 * w)
write_vector = ξ[:, r * w: r * w + w].contiguous().view(b, 1, w)
# write vector (b * 1 * r)
interpolation_gate = F.sigmoid(ξ[:, r * w + w: r * w + w + c]).contiguous().view(b, c)
interpolation_gate = T.sigmoid(ξ[:, r * w + w: r * w + w + c]).contiguous().view(b, c)
# write gate (b * 1)
write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)
write_gate = T.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)

self.timestep += 1
hidden = self.write(interpolation_gate, write_vector, write_gate, hidden)
18 changes: 9 additions & 9 deletions dnc/sparse_temporal_memory.py
@@ -55,14 +55,14 @@ def __init__(
self.write_vector_transform = nn.Linear(self.input_size, w)
self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
self.write_gate_transform = nn.Linear(self.input_size, 1)
T.nn.init.orthogonal(self.read_query_transform.weight)
T.nn.init.orthogonal(self.write_vector_transform.weight)
T.nn.init.orthogonal(self.interpolation_gate_transform.weight)
T.nn.init.orthogonal(self.write_gate_transform.weight)
T.nn.init.orthogonal_(self.read_query_transform.weight)
T.nn.init.orthogonal_(self.write_vector_transform.weight)
T.nn.init.orthogonal_(self.interpolation_gate_transform.weight)
T.nn.init.orthogonal_(self.write_gate_transform.weight)
else:
self.interface_size = (r * w) + w + self.c + 1
self.interface_weights = nn.Linear(self.input_size, self.interface_size)
T.nn.init.orthogonal(self.interface_weights.weight)
T.nn.init.orthogonal_(self.interface_weights.weight)

self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n)
self.δ = 0.005 # minimum usage
@@ -358,19 +358,19 @@ def forward(self, ξ, hidden):
# write key (b * 1 * w)
write_vector = self.write_vector_transform(ξ).view(b, 1, w)
# write vector (b * 1 * r)
interpolation_gate = F.sigmoid(self.interpolation_gate_transform(ξ)).view(b, c)
interpolation_gate = T.sigmoid(self.interpolation_gate_transform(ξ)).view(b, c)
# write gate (b * 1)
write_gate = F.sigmoid(self.write_gate_transform(ξ).view(b, 1))
write_gate = T.sigmoid(self.write_gate_transform(ξ).view(b, 1))
else:
ξ = self.interface_weights(ξ)
# r read keys (b * r * w)
read_query = ξ[:, :r * w].contiguous().view(b, r, w)
# write key (b * 1 * w)
write_vector = ξ[:, r * w: r * w + w].contiguous().view(b, 1, w)
# write vector (b * 1 * r)
interpolation_gate = F.sigmoid(ξ[:, r * w + w: r * w + w + c]).contiguous().view(b, c)
interpolation_gate = T.sigmoid(ξ[:, r * w + w: r * w + w + c]).contiguous().view(b, c)
# write gate (b * 1)
write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)
write_gate = T.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)

self.timestep += 1
hidden = self.write(interpolation_gate, write_vector, write_gate, hidden)
31 changes: 20 additions & 11 deletions dnc/util.py
@@ -4,7 +4,6 @@
import torch.nn as nn
import torch as T
import torch.nn.functional as F
from torch.autograd import Variable as var
import numpy as np
import torch
from torch.autograd import Variable
@@ -24,24 +23,37 @@ def recursiveTrace(obj):


def cuda(x, grad=False, gpu_id=-1):
x = x.float() if T.is_tensor(x) else x
if gpu_id == -1:
return var(x, requires_grad=grad)
t = T.FloatTensor(x)
t.requires_grad=grad
return t
else:
return var(x.pin_memory(), requires_grad=grad).cuda(gpu_id, async=True)
t = T.FloatTensor(x.pin_memory()).cuda(gpu_id, async=True)
t.requires_grad=grad
return t


def cudavec(x, grad=False, gpu_id=-1):
if gpu_id == -1:
return var(T.from_numpy(x), requires_grad=grad)
t = T.Tensor(T.from_numpy(x))
t.requires_grad = grad
return t
else:
return var(T.from_numpy(x).pin_memory(), requires_grad=grad).cuda(gpu_id, async=True)
t = T.Tensor(T.from_numpy(x).pin_memory()).cuda(gpu_id, async=True)
t.requires_grad = grad
return t


def cudalong(x, grad=False, gpu_id=-1):
if gpu_id == -1:
return var(T.from_numpy(x.astype(np.long)), requires_grad=grad)
t = T.LongTensor(T.from_numpy(x.astype(np.long)))
t.requires_grad = grad
return t
else:
return var(T.from_numpy(x.astype(np.long)).pin_memory(), requires_grad=grad).cuda(gpu_id, async=True)
t = T.LongTensor(T.from_numpy(x.astype(np.long)).pin_memory()).cuda(gpu_id, async=True)
t.requires_grad = grad
return t


def θ(a, b, dimA=2, dimB=2, normBy=2):
@@ -89,10 +101,7 @@ def σ(input, axis=1):
trans_size = trans_input.size()

input_2d = trans_input.contiguous().view(-1, trans_size[-1])
if '0.3' in T.__version__:
soft_max_2d = F.softmax(input_2d, -1)
else:
soft_max_2d = F.softmax(input_2d)
soft_max_2d = F.softmax(input_2d, -1)
soft_max_nd = soft_max_2d.view(*trans_size)
return soft_max_nd.transpose(axis, len(input_size) - 1)

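Editor's note (a sketch, not the committed code): since the Tensor/Variable merge in PyTorch 0.4, autograd state lives on the tensor itself, so the helpers above drop torch.autograd.Variable and set requires_grad directly. One caveat worth flagging: the async=True keyword retained above was renamed to non_blocking=True (and async became a reserved word in Python 3.7), so a 1.x-idiomatic version of the cuda() helper might look like this, under the assumption that inputs are tensors, lists, or NumPy arrays as in the original:

import torch as T

def cuda(x, grad=False, gpu_id=-1):
    # 1.x-style sketch of the helper above, for illustration only
    t = x.float() if T.is_tensor(x) else T.tensor(x, dtype=T.float32)
    if gpu_id != -1:
        t = t.pin_memory().cuda(gpu_id, non_blocking=True)  # async= was renamed to non_blocking=
    t.requires_grad_(grad)
    return t

Likewise, F.softmax on 0.4+/1.x warns unless given an explicit dim, which is why σ can drop the 0.3 version check and always pass -1.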
2 changes: 1 addition & 1 deletion setup.py
@@ -22,7 +22,7 @@
setup(
name='dnc',

version='0.0.9',
version='0.1.0',

description='Differentiable Neural Computer, for Pytorch',
long_description=long_description,
4 changes: 2 additions & 2 deletions tasks/adding_task.py
@@ -20,7 +20,7 @@
import torch.nn.functional as F
import torch.optim as optim

from torch.nn.utils import clip_grad_norm
from torch.nn.utils import clip_grad_norm_

from dnc.dnc import DNC
from dnc.sdnc import SDNC
@@ -219,7 +219,7 @@ def cross_entropy(prediction, target):

loss.backward()

T.nn.utils.clip_grad_norm(rnn.parameters(), args.clip)
T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
optimizer.step()
loss_value = loss.data[0]

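Editor's aside (a sketch, not part of the diff): torch.nn.utils.clip_grad_norm was renamed to the in-place-style clip_grad_norm_ in PyTorch 1.x, which is the change mirrored in each training loop here. Note that the neighbouring line loss_value = loss.data[0], left untouched, indexes a 0-dim tensor and raises on 1.x; loss.item() is the idiomatic replacement. A minimal, self-contained training step showing both, with a hypothetical stand-in model:

import torch as T
import torch.nn as nn
import torch.optim as optim

rnn = nn.Linear(10, 10)                    # hypothetical stand-in for the DNC model
optimizer = optim.Adam(rnn.parameters(), lr=1e-3)
clip = 50.0                                # mirrors args.clip in the task scripts

x = T.randn(4, 10)
loss = ((rnn(x) - x) ** 2).mean()
loss.backward()
T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)  # was: clip_grad_norm(...)
optimizer.step()
loss_value = loss.item()                   # 1.x replacement for loss.data[0]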
4 changes: 2 additions & 2 deletions tasks/argmax_task.py
@@ -20,7 +20,7 @@
import torch.nn.functional as F
import torch.optim as optim

from torch.nn.utils import clip_grad_norm
from torch.nn.utils import clip_grad_norm_

from dnc.dnc import DNC
from dnc.sdnc import SDNC
@@ -225,7 +225,7 @@ def generate_data(length, size):

loss.backward()

T.nn.utils.clip_grad_norm(rnn.parameters(), args.clip)
T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
optimizer.step()
loss_value = loss.data[0]

4 changes: 2 additions & 2 deletions tasks/copy_task.py
@@ -20,7 +20,7 @@
import torch.nn.functional as F
import torch.optim as optim

from torch.nn.utils import clip_grad_norm
from torch.nn.utils import clip_grad_norm_

from dnc.dnc import DNC
from dnc.sdnc import SDNC
@@ -212,7 +212,7 @@ def criterion(predictions, targets):

loss.backward()

T.nn.utils.clip_grad_norm(rnn.parameters(), args.clip)
T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
optimizer.step()
loss_value = loss.data[0]

8 changes: 4 additions & 4 deletions test/test_gru.py
@@ -8,7 +8,7 @@
import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm
from torch.nn.utils import clip_grad_norm_
import torch.optim as optim
import numpy as np

@@ -71,7 +71,7 @@ def test_rnn_1():
loss = criterion((output), target_output)
loss.backward()

T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
optimizer.step()

assert target_output.size() == T.Size([21, 10, 100])
@@ -127,7 +127,7 @@ def test_rnn_n():
loss = criterion((output), target_output)
loss.backward()

T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
optimizer.step()

assert target_output.size() == T.Size([27, 10, 100])
@@ -188,7 +188,7 @@ def test_rnn_no_memory_pass():
loss = criterion((output), target_output)
loss.backward()

T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
optimizer.step()

assert target_output.size() == T.Size([27, 10, 100])
2 changes: 1 addition & 1 deletion test/test_indexes.py
@@ -8,7 +8,7 @@
import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm
from torch.nn.utils import clip_grad_norm_
import torch.optim as optim
import numpy as np
