add multi-pass inference code #4

Open · wants to merge 1 commit into main
4 changes: 3 additions & 1 deletion config.yaml
@@ -30,4 +30,6 @@ training:

 hydra:
   run:
-    dir: logs
+    dir: logs
+
+run_additional_inference: True # Set to false to not run multipass inference.
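
Since the new key sits at the top level of config.yaml, Hydra exposes it to the code as opt.run_additional_inference. A minimal standalone sketch, not part of this PR, confirming the YAML value arrives as a real boolean rather than the string "True":

    # Hypothetical check: OmegaConf (which Hydra builds on) parses YAML
    # True/False into Python booleans.
    from omegaconf import OmegaConf

    opt = OmegaConf.create("run_additional_inference: True")
    assert opt.run_additional_inference is True

With Hydra's override grammar the same key can also be flipped per run without editing the file, e.g. python main.py run_additional_inference=False.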
6 changes: 6 additions & 0 deletions main.py
@@ -57,6 +57,12 @@ def validate_or_test(opt, model, partition, epoch=None):
         scalar_outputs = model.forward_downstream_classification_model(
             inputs, labels
         )
+
+        if opt.run_additional_inference:
+            scalar_outputs = model.forward_downstream_multi_pass(
+                inputs, labels, scalar_outputs=scalar_outputs
+            )
+
         test_results = utils.log_results(
             test_results, scalar_outputs, num_steps_per_epoch
         )
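
Threading scalar_outputs through the new call means multi_pass_classification_accuracy is aggregated by the same utils.log_results call as the existing metrics. A rough sketch of the accumulation pattern this relies on; the real utils.log_results in this repo may differ, this only assumes a running average over the epoch's steps:

    # Assumed behavior, for illustration only; not the repo's actual implementation.
    def log_results(results, scalar_outputs, num_steps):
        for key, value in scalar_outputs.items():
            results[key] = results.get(key, 0.0) + float(value) / num_steps
        return results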
21 changes: 16 additions & 5 deletions src/ff_mnist.py
@@ -12,14 +12,15 @@ def __init__(self, opt, partition, num_classes=10):
         self.uniform_label = torch.ones(self.num_classes) / self.num_classes

     def __getitem__(self, index):
-        pos_sample, neg_sample, neutral_sample, class_label = self._generate_sample(
+        pos_sample, neg_sample, neutral_sample, all_sample, class_label = self._generate_sample(
             index
         )

         inputs = {
             "pos_images": pos_sample,
             "neg_images": neg_sample,
             "neutral_sample": neutral_sample,
+            "all_sample": all_sample,
         }
         labels = {"class_labels": class_label}
         return inputs, labels
@@ -32,7 +33,7 @@ def _get_pos_sample(self, sample, class_label):
             torch.tensor(class_label), num_classes=self.num_classes
         )
         pos_sample = sample.clone()
-        pos_sample[:, 0, : self.num_classes] = one_hot_label
+        pos_sample[0, 0, : self.num_classes] = one_hot_label
         return pos_sample

     def _get_neg_sample(self, sample, class_label):
@@ -44,17 +45,27 @@ def _get_neg_sample(self, sample, class_label):
             torch.tensor(wrong_class_label), num_classes=self.num_classes
         )
         neg_sample = sample.clone()
-        neg_sample[:, 0, : self.num_classes] = one_hot_label
+        neg_sample[0, 0, : self.num_classes] = one_hot_label
         return neg_sample

     def _get_neutral_sample(self, z):
-        z[:, 0, : self.num_classes] = self.uniform_label
+        z[0, 0, : self.num_classes] = self.uniform_label
         return z

+    def _get_all_sample(self, sample):
+        all_samples = torch.zeros(
+            (self.num_classes, sample.shape[0], sample.shape[1], sample.shape[2])
+        )
+        for i in range(self.num_classes):
+            all_samples[i, :, :, :] = sample.clone()
+            one_hot_label = torch.nn.functional.one_hot(
+                torch.tensor(i), num_classes=self.num_classes
+            )
+            all_samples[i, 0, 0, : self.num_classes] = one_hot_label.clone()
+        return all_samples
+
     def _generate_sample(self, index):
         # Get MNIST sample.
         sample, class_label = self.mnist[index]
         pos_sample = self._get_pos_sample(sample, class_label)
         neg_sample = self._get_neg_sample(sample, class_label)
         neutral_sample = self._get_neutral_sample(sample)
-        return pos_sample, neg_sample, neutral_sample, class_label
+        all_sample = self._get_all_sample(sample)
+        return pos_sample, neg_sample, neutral_sample, all_sample, class_label
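
A note on the shape contract for reviewers: with torchvision MNIST tensors of shape (1, 28, 28), _get_all_sample returns (num_classes, 1, 28, 28), and the default DataLoader collation stacks that into (batch, num_classes, 1, 28, 28), which is what forward_downstream_multi_pass reshapes below. The per-class loop could also be vectorized; a sketch of an equivalent construction, offered as a suggestion rather than part of this PR:

    import torch

    num_classes = 10
    sample = torch.rand(1, 28, 28)  # stand-in for one MNIST image tensor

    # Broadcast the image to one copy per class, then write a different
    # one-hot label into the first num_classes pixels of each copy.
    all_samples = sample.expand(num_classes, -1, -1, -1).clone()
    all_samples[:, 0, 0, :num_classes] = torch.eye(num_classes)

    assert all_samples.shape == (num_classes, 1, 28, 28)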
44 changes: 44 additions & 0 deletions src/ff_model.py
@@ -118,6 +118,50 @@ def forward(self, inputs, labels):
             inputs, labels, scalar_outputs=scalar_outputs
         )

+        if self.opt.run_additional_inference:
+            scalar_outputs = self.forward_downstream_multi_pass(
+                inputs, labels, scalar_outputs=scalar_outputs
+            )
+
         return scalar_outputs

+    def forward_downstream_multi_pass(
+        self, inputs, labels, scalar_outputs=None,
+    ):
+        if scalar_outputs is None:
+            scalar_outputs = {
+                "Loss": torch.zeros(1, device=self.opt.device),
+            }
+
+        z_all = inputs["all_sample"]
+        z_all = z_all.reshape(z_all.shape[0], z_all.shape[1], -1)
+        ssq_all = []
+        for class_num in range(z_all.shape[1]):
+            z = z_all[:, class_num, :]
+            z = self._layer_norm(z)
+            input_classification_model = []
+
+            with torch.no_grad():
+                for idx, layer in enumerate(self.model):
+                    z = layer(z)
+                    z = self.act_fn.apply(z)
+                    z_unnorm = z.clone()
+                    z = self._layer_norm(z)
+
+                    if idx >= 1:
+                        input_classification_model.append(z_unnorm)
+
+            input_classification_model = torch.concat(input_classification_model, dim=-1)
+            ssq = torch.sum(input_classification_model ** 2, dim=-1)
+            ssq_all.append(ssq)
+        ssq_all = torch.stack(ssq_all, dim=-1)
+
+        classification_accuracy = utils.get_accuracy(
+            self.opt, ssq_all.data, labels["class_labels"]
+        )
+
+        scalar_outputs["multi_pass_classification_accuracy"] = classification_accuracy
+        return scalar_outputs

     def forward_downstream_classification_model(
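
For context on what forward_downstream_multi_pass computes: each candidate label is stamped onto the image, the stamped image is pushed through the trained layers, and the summed squared activations (the "goodness" of the Forward-Forward algorithm) are collected per label, giving ssq_all with shape (batch, num_classes); utils.get_accuracy then scores it against the true labels. A minimal sketch of the underlying decision rule, assuming the usual argmax-over-goodness convention:

    import torch

    def multi_pass_predict(ssq_all: torch.Tensor) -> torch.Tensor:
        # Pick the label whose one-hot pass produced the highest goodness.
        # ssq_all: (batch, num_classes), one goodness column per candidate label.
        return ssq_all.argmax(dim=-1)

    # Toy usage: three samples, ten candidate labels.
    scores = torch.rand(3, 10)
    print(multi_pass_predict(scores))  # three predicted class indices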