Skip to content

Commit

Permalink
修复Trainer里check_code函数忽略pin_memory参数导致的内存bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ouyhlan committed Nov 29, 2021
1 parent 3cb01d1 commit cac1331
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions fastNLP/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
check_batch_size = max(len(self.model.device_ids), check_batch_size)
_check_code(dataset=train_data, model=self.model, losser=losser, forward_func=self._forward_func, metrics=metrics,
dev_data=dev_dataset, metric_key=self.metric_key, check_level=check_code_level,
batch_size=check_batch_size)
batch_size=check_batch_size, pin_memory=self.pin_memory)

self.train_data = train_data
self.dev_data = dev_data # If None, No validation.
Expand Down Expand Up @@ -950,7 +950,7 @@ def _get_value_info(_dict):
return strs


def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE,
def _check_code(dataset, model, losser, metrics, forward_func, pin_memory, batch_size=DEFAULT_CHECK_BATCH_SIZE,
dev_data=None, metric_key=None, check_level=0):
# check get_loss 方法
model_device = _get_model_device(model=model)
Expand Down Expand Up @@ -1010,7 +1010,7 @@ def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAUL

if dev_data is not None:
tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
batch_size=batch_size, verbose=-1, use_tqdm=False)
batch_size=batch_size, verbose=-1, use_tqdm=False, pin_memory=pin_memory)
evaluate_results = tester.test()
_check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics)

Expand Down

0 comments on commit cac1331

Please sign in to comment.