Skip to content

Commit

Permalink
Fixing a couple of bugs with the distributed evaluator.
Browse files (browse the repository at this point in the history)
  • Loading branch information
searchivarius committed Feb 4, 2025
1 parent f871c85 commit 33b506e
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions scripts/train_nn/eval_model_distr.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

from flexneuart.eval import FAKE_DOC_ID, METRIC_LIST, get_eval_results

from mtasklite import delayed_init
from mtasklite import delayed_init, ExceptionBehaviour
from mtasklite.processes import pqdm

from time import time
Expand Down Expand Up @@ -95,7 +95,8 @@
parser.add_argument('--eval_metric', choices=METRIC_LIST, default=METRIC_LIST[0],
help='Metric list: ' + ','.join(METRIC_LIST),
metavar='eval metric')
parser.add_argument('--device_qty', required=True, help='A number of the CUDA devices to use')
parser.add_argument('--device_qty', type=int, required=True, help='A number of the CUDA devices to use')
parser.add_argument('--override_max_doc_len', type=int, default=None)

parser.add_argument(f'--{IGNORE_MISS}',
help='ignore queries missing from the run file or vice versa',
Expand Down Expand Up @@ -128,6 +129,7 @@
query_field = args.index_field

device_qty=args.device_qty
print(f'Number of devices: {device_qty}')

max_query_val = args.max_num_query

Expand Down Expand Up @@ -185,11 +187,13 @@

@delayed_init
class Worker:
def __init__(self, device_name):
def __init__(self, device_name, override_max_doc_len):
print('Loading model from:', fname, ' to device', device_name)
model_holder = ModelSerializer.load_all(fname)
self.max_doc_len = model_holder.max_doc_len
self.max_query_len = model_holder.max_query_len
if override_max_doc_len is not None:
self.max_doc_len = args.override_max_doc_len
self.model = model_holder.model
self.model.to(device_name)
self.device_name = device_name
Expand All @@ -200,11 +204,13 @@ def __init__(self, device_name):
def __call__(self, query_id):
query_dict = {query_id: query_dict_all[query_id]}
dataset = query_dict, data_dict
orig_run = {query_id: valid_run_head[query_id]}
return run_model(self.model,
device_name=self.device_name,
batch_size=args.batch_size, amp=args.amp,
max_query_len=self.max_query_len, max_doc_len=self.max_doc_len,
dataset=dataset, orig_run=valid_run_head,
dataset=dataset, orig_run=orig_run,
use_progress_bar=False,
cand_score_weight=args.cand_score_weight,
desc='validating the run')
min_top_orig_score = {}
Expand Down Expand Up @@ -261,9 +267,9 @@ def __call__(self, query_id):
start_val_time = time()
rerank_run_head = {}

for run_dict in pqdm(list(query_dict_all),
[Worker(device_name) for device_name in get_device_name_arr(device_qty)],
device_qty):
for run_dict in pqdm(list(query_dict_all.keys()),
[Worker(device_name, args.override_max_doc_len) for device_name in get_device_name_arr(device_qty)],
exception_behaviour=ExceptionBehaviour.IMMEDIATE):
for k, v in run_dict.items():
rerank_run_head[k] = v

Expand Down

0 comments on commit 33b506e

Please sign in to comment.