Skip to content

Commit

Permalink
Refactor ava hook (#567)
Browse files Browse the repository at this point in the history
* resolve comments

* update changelog

* Refactor AVA Eval, only support mAP

* fix unittest

* add warning

* Refactor AVA Eval, only support mAP

* fix unittest

* add warning

* update warning info
  • Loading branch information
kennymckormick authored Jan 27, 2021
1 parent 1e90c1f commit 944c5b1
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 12 deletions.
2 changes: 1 addition & 1 deletion configs/detection/ava/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [optional arguments]
Example: test SlowOnly model on AVA and dump the result to a csv file.

```shell
python tools/test.py configs/detection/AVA/slowonly_kinetics_pretrained_r50_8x8x1_20e_ava_rgb.py checkpoints/SOME_CHECKPOINT.pth --eval bbox --out results.csv
python tools/test.py configs/detection/AVA/slowonly_kinetics_pretrained_r50_8x8x1_20e_ava_rgb.py checkpoints/SOME_CHECKPOINT.pth --eval mAP --out results.csv
```

For more details and optional arguments infos, you can refer to **Test a dataset** part in [getting_started](/docs/getting_started.md#test-a-dataset) .
4 changes: 2 additions & 2 deletions mmaction/core/evaluation/ava_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def ava_eval(result_file,
max_dets=(100, ),
verbose=True):

assert result_type in ['proposal', 'bbox']
assert result_type in ['mAP']

start = time.time()
categories, class_whitelist = read_labelmap(open(label_file))
Expand Down Expand Up @@ -213,7 +213,7 @@ def ava_eval(result_file,
ret[f'AR@{num}'] = ar[i]
return ret

if result_type == 'bbox':
if result_type == 'mAP':
pascal_evaluator = det_eval.PascalDetectionEvaluator(categories)

start = time.time()
Expand Down
6 changes: 5 additions & 1 deletion mmaction/datasets/ava_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,14 @@ def dump_results(self, results, out):

def evaluate(self,
results,
metrics=('proposal', 'bbox'),
metrics=('mAP', ),
metric_options=None,
logger=None):
# need to create a temp result file
assert len(metrics) == 1 and metrics[0] == 'mAP', (
'For evaluation on AVADataset, you need to use metrics "mAP" '
'See https://github.com/open-mmlab/mmaction2/pull/567 '
'for more info.')
time_now = datetime.now().strftime('%Y%m%d_%H%M%S')
temp_file = f'AVA_{time_now}_result.csv'
results2csv(self, results, temp_file)
Expand Down
2 changes: 0 additions & 2 deletions tests/test_data/test_datasets/test_ava_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,4 @@ def test_ava_evaluate(self):
[0.106, 0.445, 0.782, 0.673, 0.367]])
]]
res = ava_dataset.evaluate(fake_result)
assert_array_almost_equal(res['[email protected]@100'], 0.33333333)
assert_array_almost_equal(res['AR@100'], 0.15833333)
assert_array_almost_equal(res['[email protected]'], 0.027777778)
7 changes: 1 addition & 6 deletions tests/test_metrics/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,8 @@ def test_ava_detection():
result_path = osp.join(data_prefix, 'pred.csv')
label_map = osp.join(data_prefix, 'action_list.txt')

# eval proposal
detection = ava_eval(result_path, 'proposal', label_map, gt_path, None)
assert_array_almost_equal(detection['[email protected]@100'], 0.41666667)
assert_array_almost_equal(detection['AR@100'], 0.08333333)

# eval bbox
detection = ava_eval(result_path, 'bbox', label_map, gt_path, None)
detection = ava_eval(result_path, 'mAP', label_map, gt_path, None)
assert_array_almost_equal(detection['[email protected]'], 0.09385522)


Expand Down

0 comments on commit 944c5b1

Please sign in to comment.