From 0409350bab67de6617e6d71a6c5bf9e1d75bf785 Mon Sep 17 00:00:00 2001 From: Gourav Chowdhary Date: Tue, 19 Jan 2021 16:49:37 +0530 Subject: [PATCH] Update demo.py If the probability of any label is less than 0.5, then we can show both the best and the second-best prediction using the code below. --- demo.py | 57 +++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 49 insertions(+), 8 deletions(-) diff --git a/demo.py b/demo.py index 5314b4e8c9..b5bc9c20c0 100755 --- a/demo.py +++ b/demo.py @@ -62,32 +62,73 @@ def demo(opt): else: preds = model(image, text_for_pred, is_train=False) - # select max probabilty (greedy decoding) then decode index to character _, preds_index = preds.max(2) preds_str = converter.decode(preds_index, length_for_pred) + print(preds_str) + + preds_prob = F.softmax(preds, dim=2) + preds_max_prob, preds_index_prob = preds_prob.max(dim=2) + print(preds_max_prob) + print(preds_max_prob.type) + index=np.zeros((11,26)) + p=0 + k=0 + for i in preds_max_prob: + counter=0 + for j in i: + if j<0.5: + index[p][k]=counter + k+=1 + counter+=1 + k=0 + p+=1 + + p=0 + k=0 + + for i in index: + for j in i: + if index[p][k]!= 0: + z=preds_prob[p][k] + q=z.numpy() + r=q.tolist() + sorted_index=[r.index(x) for x in sorted(r)[:37]] + second_best_index=sorted_index[36] + preds_index_prob[p][k]=second_best_index + k+=1 + k=0 + p+=1 + + + preds_str_second_best = converter.decode(preds_index_prob, length_for_pred) + print(preds_str_second_best) + log = open(f'./log_demo_result.txt', 'a') - dashed_line = '-' * 80 - head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score' + dashed_line = '-' * 140 + head = f'{"image_path":25s}\t{"Best predicted_labels":25s}\t{"Second best predicted labels":25s}\tconfidence score' print(f'{dashed_line}\n{head}\n{dashed_line}') log.write(f'{dashed_line}\n{head}\n{dashed_line}\n') - preds_prob = F.softmax(preds, dim=2) - preds_max_prob, _ = preds_prob.max(dim=2) - for img_name, pred, pred_max_prob in 
zip(image_path_list, preds_str, preds_max_prob): + + for img_name, pred, pred_second, pred_max_prob in zip(image_path_list, preds_str, preds_str_second_best, preds_max_prob): if 'Attn' in opt.Prediction: pred_EOS = pred.find('[s]') pred = pred[:pred_EOS] # prune after "end of sentence" token ([s]) pred_max_prob = pred_max_prob[:pred_EOS] + pred_EOS_second = pred_second.find('[s]') + pred_second = pred_second[:pred_EOS_second] # prune after "end of sentence" token ([s]) + pred_max_prob_second = pred_max_prob[:pred_EOS_second] + # calculate confidence score (= multiply of pred_max_prob) confidence_score = pred_max_prob.cumprod(dim=0)[-1] - print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}') - log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n') + print(f'{img_name:25s}\t{pred:25s}\t{pred_second:25s}\t{confidence_score:0.4f}') + log.write(f'{img_name:25s}\t{pred:25s}\t{pred_second:25s}\t{confidence_score:0.4f}\n') log.close()