
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jun 24, 2023
1 parent d524e9e commit 70e3747
Showing 3 changed files with 24 additions and 31 deletions.

File 1 of 3:
@@ -165,9 +165,10 @@ def main(cfg) -> None:
     model.freeze()
     config = OmegaConf.to_container(cfg.inference, resolve=True)
     model.set_inference_config(config)
 
     if hasattr(cfg.model.data, 'test_ds'):
         trainer.test(model)
 
+
 if __name__ == "__main__":
     main()
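
Aside: the hunk above passes cfg.inference through OmegaConf.to_container with resolve=True, which turns the DictConfig subtree into a plain Python dict and resolves any interpolations first. A minimal, self-contained sketch of that pattern; the config keys below are invented for illustration, not taken from this repo:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"inference": {"greedy": True, "max_len": 32, "tokens_to_generate": "${inference.max_len}"}}
)
# resolve=True replaces the ${inference.max_len} interpolation with its value
config = OmegaConf.to_container(cfg.inference, resolve=True)
print(config)  # {'greedy': True, 'max_len': 32, 'tokens_to_generate': 32}
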

File 2 of 3:
@@ -149,7 +149,7 @@ def _process_example(self, example):
         context = example[self.context_key]
         output = example[self.label_key]
         reference = example[self.reference_key]
-
+
         if self.prompt_template is not None:
             assert '{input}' in self.prompt_template
             assert '{output}' in self.prompt_template
@@ -304,4 +304,4 @@ def collate_fn(self, batch):
             'reference_texts': reference_texts,
         }
 
-        return processed_batch
\ No newline at end of file
+        return processed_batch
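
Aside: _process_example above asserts that a user-supplied prompt_template contains the literal '{input}' and '{output}' placeholders. A minimal sketch of that contract, with simplified field names (hypothetical, not the actual NeMo implementation):

from typing import Optional

def process_example(example: dict, prompt_template: Optional[str]) -> str:
    context, output = example["input"], example["output"]
    if prompt_template is not None:
        # mirror the asserts in the diff: both placeholders must be present
        assert "{input}" in prompt_template
        assert "{output}" in prompt_template
        return prompt_template.replace("{input}", context).replace("{output}", output)
    return f"{context} {output}"

print(process_example({"input": "Translate to English: chat", "output": "cat"},
                      "Q: {input}\nA: {output}"))
# Q: Translate to English: chat
# A: cat
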

File 3 of 3:
@@ -358,20 +358,20 @@ def inference_step(self, dataloader_iter, batch_idx, mode, dataloader_idx=0):
         input_text = batch.pop('context_texts')
         labels_text = batch.pop('reference_texts')
         loss = super().validation_step(itertools.chain([batch]), batch_idx)
 
         data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds
         if self.get_inference_config() is not None:
             self._inference_config['add_BOS'] = data_cfg.add_bos
             self._inference_config['tokens_to_generate'] = data_cfg.tokens_to_generate
 
         output = self.predict_step(batch, batch_idx, dataloader_idx)
-        preds_text = [s[len(i):] for i, s in zip(input_text, output['sentences'])]
+        preds_text = [s[len(i) :] for i, s in zip(input_text, output['sentences'])]
 
         return {
             'loss': loss,
-            'preds': preds_text, # [str]
-            'labels': labels_text, # [[str]]
-            'inputs': input_text, # [str]
+            'preds': preds_text,  # [str]
+            'labels': labels_text,  # [[str]]
+            'inputs': input_text,  # [str]
         }
 
     def inference_epoch_end(self, outputs, mode, data_cfg):
@@ -384,29 +384,22 @@ def inference_epoch_end(self, outputs, mode, data_cfg):

         averaged_loss = []
         averaged_metric = []
 
         # Log metrics for each provided validation/test dataset.
         for dataloader_idx, output in enumerate(outputs):
             loss = super().validation_epoch_end([x['loss'] for x in output])
             loss_log_key = self._determine_log_key(data_cfg, dataloader_idx, "loss", mode)
             self.log(loss_log_key, loss)
             averaged_loss.append(loss)
 
             # Gather the outputs object from all data parallel ranks since we are using the DistributedSampler which splits data across DDP ranks.
             gathered_outputs = [None for _ in range(parallel_state.get_data_parallel_world_size())]
             torch.distributed.all_gather_object(
                 gathered_outputs,
-                [
-                    {
-                        'preds': x['preds'],
-                        'labels': x['labels'],
-                        'inputs': x['inputs'],
-                    }
-                    for x in output
-                ],
+                [{'preds': x['preds'], 'labels': x['labels'], 'inputs': x['inputs'],} for x in output],
                 group=parallel_state.get_data_parallel_group(),
             )
 
             # Remove duplicate examples due to distributed sampler.
             inp_label_set = set()
             deduplicated_outputs = {
@@ -416,24 +409,22 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
             }
             for rank in range(0, parallel_state.get_data_parallel_world_size()):
                 for batch in gathered_outputs[rank]:
-                    for pred, label, input in zip(
-                        batch['preds'], batch['labels'], batch['inputs']
-                    ):
+                    for pred, label, input in zip(batch['preds'], batch['labels'], batch['inputs']):
                         key = input + ' '.join(label)
                         if key not in inp_label_set:
                             inp_label_set.add(key)
                             deduplicated_outputs['preds'].append(pred)
                             deduplicated_outputs['labels'].append(label)
-                            deduplicated_outputs['inputs'].append(input)
+                            deduplicated_outputs['inputs'].append(input)
 
             # Compute metric score
             metric_name = self.val_metric_name if mode == 'validation' else self.test_metric_name
             metric_log_key = self._determine_log_key(data_cfg, dataloader_idx, metric_name, mode)
             metric_fn = self.val_metric[dataloader_idx] if mode == 'validation' else self.test_metric[dataloader_idx]
             metric_result = metric_fn(deduplicated_outputs['preds'], deduplicated_outputs['labels'])
 
             if metric_name == 'rouge':
-                for k,v in metric_result.items():
+                for k, v in metric_result.items():
                     if 'fmeasure' in k:
                         self.log(metric_log_key + f'_{k}', v.item(), sync_dist=True)
                         logging.info(f"{mode} {metric_name} {k}: {v.item()}")
@@ -443,8 +434,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
logging.info(f"{mode} {metric_name}: {metric_result.item()}")

averaged_metric.append(metric_result)



# Write predictions to file
if self.global_rank == 0 and data_cfg.get("write_predictions_to_file", False):
logging.info(f"Total deduplicated inference data size: {len(deduplicated_outputs['inputs'])}")
@@ -455,8 +445,10 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
f"Cannot write predictions to file when output_file_path_prefix is not set or present in the yaml config file."
)
filename_log_key = self._determine_log_key(data_cfg, dataloader_idx, None, mode)
self.write_predictions_to_file(deduplicated_outputs, f"{data_cfg.output_file_path_prefix}_{filename_log_key}")

self.write_predictions_to_file(
deduplicated_outputs, f"{data_cfg.output_file_path_prefix}_{filename_log_key}"
)

torch.distributed.barrier(group=parallel_state.get_data_parallel_group())

# Logging of the averaged metrics:
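
Aside: the inference_epoch_end hunks above reformat but do not change the rank-deduplication logic: DistributedSampler pads the dataset so every DDP rank sees the same number of samples, so after torch.distributed.all_gather_object the same (input, labels) pair can appear on more than one rank, and a set keyed on input plus joined labels keeps only the first occurrence. A minimal single-process sketch of that step, with the gathered structure faked by hand (no torch required):

# what all_gather_object would fill in: one list of batch dicts per rank
gathered_outputs = [
    [{"preds": ["a cat"], "labels": [["cat"]], "inputs": ["Translate: chat"]}],
    [{"preds": ["a cat"], "labels": [["cat"]], "inputs": ["Translate: chat"]}],  # duplicate from sampler padding
]
inp_label_set = set()
deduplicated = {"preds": [], "labels": [], "inputs": []}
for rank_output in gathered_outputs:
    for batch in rank_output:
        for pred, label, inp in zip(batch["preds"], batch["labels"], batch["inputs"]):
            key = inp + " ".join(label)  # same key construction as the diff
            if key not in inp_label_set:
                inp_label_set.add(key)
                deduplicated["preds"].append(pred)
                deduplicated["labels"].append(label)
                deduplicated["inputs"].append(inp)
print(len(deduplicated["inputs"]))  # 1
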
