make VoidType always return SAME; update tie weights test
Signed-off-by: Jason <[email protected]>
blisc committed Feb 13, 2020
1 parent 9fc14d7 commit a71da02
Showing 3 changed files with 79 additions and 24 deletions.
5 changes: 2 additions & 3 deletions nemo/backends/pytorch/common/other.py
@@ -11,7 +11,6 @@
 import torch.nn as nn
 
 from nemo.backends.pytorch.nm import NonTrainableNM, TrainableNM
-from nemo.core import NeuralModule
 from nemo.core.neural_types import *
 
 
@@ -21,14 +20,14 @@ def input_ports(self):
         """Returns definitions of module input ports.
         """
         # return {"input_seq": NeuralType({0: AxisType(TimeTag), 1: AxisType(BatchTag)})}
-        return {"input_seq": NeuralModule(ChannelType(), ('T', 'B'))}
+        return {"input_seq": NeuralType(('B', 'T'))}
 
     @property
     def output_ports(self):
         """Returns definitions of module output ports.
         """
         # return {"outputs": NeuralType({0: AxisType(TimeTag), 1: AxisType(BatchTag), 2: AxisType(ChannelTag),})}
-        return {"outputs": NeuralType(('T', 'B', 'D'), ChannelType())}
+        return {"outputs": NeuralType(('B', 'T', 'C'))}
 
     def __init__(self, voc_size, hidden_size, dropout=0.0):
         super().__init__()
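
For reference: the corrected ports use NeuralType's axis-tuple shorthand, and the input port now omits an element type entirely, falling back to the VoidType default of NeuralType.__init__. A minimal standalone sketch of these constructors (names come from this diff; building the types outside a module like this is illustrative, not part of the commit):

    from nemo.core.neural_types import ChannelType, NeuralType

    # Axes-only shorthand: elements_type defaults to VoidType().
    input_port = NeuralType(('B', 'T'))        # batch, time
    output_port = NeuralType(('B', 'T', 'C'))  # batch, time, channel

    # Axes plus a concrete element type, as the removed output port had.
    # (The removed input port also mistakenly called NeuralModule, which
    # is not a type at all.)
    typed_port = NeuralType(('B', 'T', 'C'), ChannelType())
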
12 changes: 11 additions & 1 deletion nemo/core/neural_types/neural_type.py
@@ -21,7 +21,6 @@
     'NeuralTypeError',
     'NeuralPortNameMismatchError',
     'NeuralPortNmTensorMismatchError',
-    'NeuralPortNmTensorMismatchError',
     'CanNotInferResultNeuralType',
 ]
 import uuid
@@ -46,6 +45,15 @@ class NeuralType(object):
     type can be optional.
     """
 
+    def __str__(self):
+        return (
+            f"axes: {[(c.kind, c.size, c.is_list) for c in self.axes]}\n"
+            f"elements_type: {self.elements_type.__class__.__name__}"
+        )
+        # return f"axes: {self.axes}" # " elements_type: {self.elements_type}"
+        # return f" elements_type: {self.elements_type.__class__.__name__}"
+        # return "help"
+
     def __init__(self, axes: Optional[Tuple] = None, elements_type: ElementType = VoidType(), optional=False):
         if not isinstance(elements_type, ElementType):
             raise ValueError(
@@ -87,6 +95,8 @@ def compare(self, second) -> NeuralTypeComparisonResult:
 
         dimensions_pass = NeuralType.__compare_axes(axes_a, axes_b)
         element_comparison_result = self.elements_type.compare(second.elements_type)
+        if isinstance(second.elements_type, VoidType):
+            element_comparison_result = NeuralTypeComparisonResult.SAME
 
         # SAME DIMS
         if dimensions_pass == 0:
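
Taken together, the two additions in this file make NeuralType printable and turn VoidType into a wildcard on the right-hand side of compare(). A hedged sketch of the resulting behaviour (assumes the names import as shown; per the commit title, a VoidType operand should now always compare as SAME when the axes line up):

    from nemo.core.neural_types import ChannelType, NeuralType, NeuralTypeComparisonResult

    typed = NeuralType(('B', 'T'), ChannelType())
    untyped = NeuralType(('B', 'T'))  # elements_type defaults to VoidType()

    # VoidType as the second operand now short-circuits the element
    # comparison, so matching axes are enough for SAME.
    assert typed.compare(untyped) == NeuralTypeComparisonResult.SAME

    # The new __str__ reports (kind, size, is_list) per axis plus the
    # element type's class name.
    print(typed)
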
86 changes: 66 additions & 20 deletions tests/core/test_weight_share.py
@@ -27,9 +27,12 @@
 
 import nemo
 import nemo.collections.asr as nemo_asr
+from nemo.collections.nlp.nm.trainables.common import TokenClassifier
+from nemo.collections.nlp.nm.losses import PaddedSmoothedCrossEntropyLossNM
 from nemo.core import WeightShareTransform
 from nemo.core.neural_types import *
 from tests.common_setup import NeMoUnitTest
 from nemo.backends.pytorch.nm import DataLayerNM
 
 logging = nemo.logging
 
@@ -136,26 +139,69 @@ def test_TaylorNet_get_weights(self):
         # tn2.fc1.bias.data = torch.tensor([0.1])
         # self.assertTrue(self.__check_if_weights_are_equal(tn1.get_weights(), tn2.get_weights()))
 
-    # def test_tie_weights2(self):
-    #     voc_size = 3
-    #     dim = 2
-    #     embd = nemo.backends.pytorch.common.other.SequenceEmbedding(voc_size=voc_size, hidden_size=dim)
-    #     proj = nemo.backends.pytorch.common.other.SequenceProjection(from_dim=dim, to_dim=voc_size)
-    #     embd.tie_weights_with(
-    #         proj,
-    #         weight_names=["embedding.weight"],
-    #         name2name_and_transform={"embedding.weight": ("projection.weight", WeightShareTransform.SAME,)},
-    #     )
-    #     self.assertTrue(
-    #         np.array_equal(embd.embedding.weight.detach().numpy(), proj.projection.weight.detach().numpy(),)
-    #     )
-    #     was = embd.embedding.weight.detach().numpy()
-    #     embd.embedding.weight.data = torch.tensor(np.random.randint(0, 10, (3, 2)) * 1.0)
-    #     after = embd.embedding.weight.detach().numpy()
-    #     self.assertTrue(
-    #         np.array_equal(embd.embedding.weight.detach().numpy(), proj.projection.weight.detach().numpy(),)
-    #     )
-    #     self.assertFalse(np.array_equal(was, after))
+    def test_tie_weights(self):
+        class DummyDataLayer(DataLayerNM):
+            def __init__(self, vocab_size):
+                super().__init__()
+                self.vocab_size = vocab_size
+
+                class DummyDS(torch.utils.data.Dataset):
+                    def __init__(self, vocab_size):
+                        super().__init__()
+
+                    def __getitem__(self, index):
+                        model_inputs = torch.randint(high=vocab_size, size=[10])
+                        model_outputs = torch.randint(high=vocab_size, size=[10])
+                        return (model_inputs, model_outputs)
+
+                    def __len__(self):
+                        return 10
+
+                self._dataset = DummyDS(vocab_size)
+
+            @property
+            def output_ports(self):
+                return {
+                    "model_inputs": NeuralType(('B', 'T')),
+                    "model_outputs": NeuralType(('B', 'T')),
+                }
+
+            def __len__(self):
+                return len(self._dataset)
+
+            @property
+            def dataset(self):
+                return self._dataset
+
+            def data_iterator(self):
+                pass
+
+        voc_size = 10
+        dim = 10
+        embd = nemo.backends.pytorch.common.other.SequenceEmbedding(voc_size=voc_size, hidden_size=dim)
+        proj = TokenClassifier(hidden_size=dim, num_classes=voc_size)
+        data = DummyDataLayer(voc_size)
+        loss = PaddedSmoothedCrossEntropyLossNM(0)
+        embd.tie_weights_with(
+            proj,
+            weight_names=["embedding.weight"],
+            name2name_and_transform={"embedding.weight": ("mlp.layer2.weight", WeightShareTransform.SAME)},
+        )
+        self.assertTrue(
+            np.array_equal(embd.embedding.weight.detach().cpu().numpy(), proj.mlp.layer2.weight.detach().cpu().numpy())
+        )
+        _in, _out = data()
+        pred = embd(input_seq=_in)
+        pred = proj(hidden_states=pred)
+        loss_t = loss(target_ids=_in, logits=pred)
+
+        self.nf.train(
+            [loss_t], optimizer="sgd", optimization_params={"max_steps": 5, "lr": 0.0003},
+        )
+
+        self.assertTrue(
+            np.array_equal(embd.embedding.weight.detach().cpu().numpy(), proj.mlp.layer2.weight.detach().cpu().numpy())
+        )
 
     def test_set_weights(self):
         voc_size = 3
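
The new test ties SequenceEmbedding's embedding matrix to the TokenClassifier's final linear layer, trains for five SGD steps, and asserts the two tensors are still equal. The same tying idea in plain PyTorch, independent of NeMo's tie_weights_with API (a sketch; the sizes and module choices are illustrative):

    import torch
    import torch.nn as nn

    emb = nn.Embedding(num_embeddings=10, embedding_dim=10)
    proj = nn.Linear(in_features=10, out_features=10, bias=False)
    proj.weight = emb.weight  # both modules now hold the same Parameter

    tokens = torch.randint(high=10, size=(2, 5))
    logits = proj(emb(tokens))
    logits.sum().backward()
    torch.optim.SGD(emb.parameters(), lr=3e-4).step()

    # SGD updates in place, so the two views still share one storage.
    assert emb.weight.data_ptr() == proj.weight.data_ptr()
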
