Skip to content

Commit

Permalink
Custom baseline (#74)
Browse files Browse the repository at this point in the history
* add baseline_path

* update cli option name & plotting

* fix typo
  • Loading branch information
felixgwu authored Aug 24, 2020
1 parent d0156eb commit b15f586
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 23 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

Automatic Evaluation Metric described in the paper [BERTScore: Evaluating Text Generation with BERT](https://arxiv.org/abs/1904.09675) (ICLR 2020).
#### News:
- The option `--rescale-with-baseline` is changed to `--rescale_with_baseline` so that it is consistent with other options.
- Updated to version 0.3.5
- Being compatible with Huggingface's transformers >=v3.0.0 and minor fixes ([#58](https://github.com/Tiiiger/bert_score/pull/58), [#66](https://github.com/Tiiiger/bert_score/pull/66), [#68](https://github.com/Tiiiger/bert_score/pull/68))
- Several improvements related to efficency ([#67](https://github.com/Tiiiger/bert_score/pull/67), [#69](https://github.com/Tiiiger/bert_score/pull/69))
Expand Down Expand Up @@ -116,7 +117,7 @@ where "roberta-large_L17_no-idf_version=0.3.0(hug_trans=2.3.0)" is the hash code
Starting from version 0.3.0, we support rescaling the scores with baseline scores

```sh
bert-score -r example/refs.txt -c example/hyps.txt --lang en --rescale-with-baseline
bert-score -r example/refs.txt -c example/hyps.txt --lang en --rescale_with_baseline
```
You will get:

Expand Down
15 changes: 11 additions & 4 deletions bert_score/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def score(
lang=None,
return_hash=False,
rescale_with_baseline=False,
baseline_path=None,
):
"""
BERTScore metric.
Expand All @@ -64,6 +65,7 @@ def score(
specified when `rescale_with_baseline` is True.
- :param: `return_hash` (bool): return hash code of the setting
- :param: `rescale_with_baseline` (bool): rescale bertscore with pre-computed baseline
- :param: `baseline_path` (str): customized baseline file
Return:
- :param: `(P, R, F)`: each is of shape (N); N = number of input
Expand Down Expand Up @@ -145,8 +147,10 @@ def score(
max_preds.append(all_preds[beg:end].max(dim=0)[0])
all_preds = torch.stack(max_preds, dim=0)

use_custom_baseline = baseline_path is not None
if rescale_with_baseline:
baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{lang}/{model_type}.tsv")
if baseline_path is None:
baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{lang}/{model_type}.tsv")
if os.path.isfile(baseline_path):
if not all_layers:
baselines = torch.from_numpy(pd.read_csv(baseline_path).iloc[num_layers].to_numpy())[1:].float()
Expand All @@ -164,13 +168,15 @@ def score(
print(f"done in {time_diff:.2f} seconds, {len(refs) / time_diff:.2f} sentences/sec")

if return_hash:
return tuple([out, get_hash(model_type, num_layers, idf, rescale_with_baseline)])
return tuple([out, get_hash(model_type, num_layers, idf, rescale_with_baseline,
use_custom_baseline=use_custom_baseline)])

return out


def plot_example(
candidate, reference, model_type=None, num_layers=None, lang=None, rescale_with_baseline=False, fname=""
candidate, reference, model_type=None, num_layers=None, lang=None, rescale_with_baseline=False,
baseline_path=None, fname="",
):
"""
BERTScore metric.
Expand Down Expand Up @@ -234,7 +240,8 @@ def plot_example(
sim = sim[1:-1, 1:-1]

if rescale_with_baseline:
baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{lang}/{model_type}.tsv")
if baseline_path is None:
baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{lang}/{model_type}.tsv")
if os.path.isfile(baseline_path):
baselines = torch.from_numpy(pd.read_csv(baseline_path).iloc[num_layers].to_numpy())[1:].float()
sim = (sim - baselines[2].item()) / (1 - baselines[2].item())
Expand Down
37 changes: 24 additions & 13 deletions bert_score/scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
device=None,
lang=None,
rescale_with_baseline=False,
baseline_path=None,
):
"""
Args:
Expand All @@ -51,8 +52,8 @@ def __init__(
- :param: `num_layers` (int): the layer of representation to use.
default using the number of layer tuned on WMT16 correlation data
- :param: `verbose` (bool): turn on intermediate status update
- :param: `idf` (dict): use idf weighting, can also be a precomputed idf_dict
- :param: `idf_sents` (List of str): use idf weighting, can also be a precomputed idf_dict
- :param: `idf` (bool): a booling to specify whether to use idf or not (this should be True even if `idf_sents` is given)
- :param: `idf_sents` (List of str): list of sentences used to compute the idf weights
- :param: `device` (str): on which the contextual embedding model will be allocated on.
If this argument is None, the model lives on cuda:0 if cuda is available.
- :param: `batch_size` (int): bert score processing batch size
Expand All @@ -62,6 +63,7 @@ def __init__(
specified when `rescale_with_baseline` is True.
- :param: `return_hash` (bool): return hash code of the setting
- :param: `rescale_with_baseline` (bool): rescale bertscore with pre-computed baseline
- :param: `baseline_path` (str): customized baseline file
"""

assert lang is not None or model_type is not None, "Either lang or model_type should be specified"
Expand Down Expand Up @@ -106,6 +108,12 @@ def __init__(
if idf_sents is not None:
self.compute_idf(idf_sents)

self._baseline_vals = None
self.baseline_path = baseline_path
self.use_custom_baseline = self.baseline_path is not None
if self.baseline_path is None:
self.baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{self.lang}/{self.model_type}.tsv")

@property
def lang(self):
return self._lang
Expand All @@ -128,22 +136,25 @@ def rescale_with_baseline(self):

@property
def baseline_vals(self):
baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{self.lang}/{self.model_type}.tsv")
if os.path.isfile(baseline_path):
if not self.all_layers:
baseline_vals = torch.from_numpy(pd.read_csv(baseline_path).iloc[self.num_layers].to_numpy())[
1:
].float()
if self._baseline_vals is None:
if os.path.isfile(self.baseline_path):
if not self.all_layers:
self._baseline_vals = torch.from_numpy(
pd.read_csv(self.baseline_path).iloc[self.num_layers].to_numpy()
)[1:].float()
else:
self._baseline_vals = torch.from_numpy(
pd.read_csv(self.baseline_path).to_numpy()
)[:, 1:].unsqueeze(1).float()
else:
baseline_vals = torch.from_numpy(pd.read_csv(baseline_path).to_numpy())[:, 1:].unsqueeze(1).float()
else:
raise ValueError(f"Baseline not Found for {self.model_type} on {self.lang} at {baseline_path}")
raise ValueError(
f"Baseline not Found for {self.model_type} on {self.lang} at {self.baseline_path}")

return baseline_vals
return self._baseline_vals

@property
def hash(self):
return get_hash(self.model_type, self.num_layers, self.idf, self.rescale_with_baseline)
return get_hash(self.model_type, self.num_layers, self.idf, self.rescale_with_baseline, self.use_custom_baseline)

def compute_idf(self, sents):
"""
Expand Down
7 changes: 5 additions & 2 deletions bert_score/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,12 +437,15 @@ def length_to_mask(lens):
return preds


def get_hash(model, num_layers, idf, rescale_with_baseline):
def get_hash(model, num_layers, idf, rescale_with_baseline, use_custom_baseline):
msg = "{}_L{}{}_version={}(hug_trans={})".format(
model, num_layers, "_idf" if idf else "_no-idf", __version__, trans_version
)
if rescale_with_baseline:
msg += "-rescaled"
if use_custom_baseline:
msg += "-custom-rescaled"
else:
msg += "-rescaled"
return msg


Expand Down
4 changes: 3 additions & 1 deletion bert_score_cli/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ def main():
parser.add_argument("--nthreads", type=int, default=4, help="number of cpu workers (default: 4)")
parser.add_argument("--idf", action="store_true", help="BERT Score with IDF scaling")
parser.add_argument(
"--rescale-with-baseline", action="store_true", help="Rescaling the numerical score with precomputed baselines"
"--rescale_with_baseline", action="store_true", help="Rescaling the numerical score with precomputed baselines"
)
parser.add_argument("--baseline_path", default=None, type=str, help="path of custom baseline csv file")
parser.add_argument("-s", "--seg_level", action="store_true", help="show individual score of each pair")
parser.add_argument("-v", "--verbose", action="store_true", help="increase output verbosity")
parser.add_argument("-r", "--ref", type=str, nargs="+", required=True, help="reference file path(s) or a string")
Expand Down Expand Up @@ -65,6 +66,7 @@ def main():
lang=args.lang,
return_hash=True,
rescale_with_baseline=args.rescale_with_baseline,
baseline_path=args.baseline_path,
)
avg_scores = [s.mean(dim=0) for s in all_preds]
P = avg_scores[0].cpu().item()
Expand Down
4 changes: 3 additions & 1 deletion bert_score_cli/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ def main():
parser.add_argument("-c", "--cand", type=str, required=True, help="candidate sentence")
parser.add_argument("-f", "--file", type=str, default="visualize.png", help="name of file to save output matrix in")
parser.add_argument(
"--rescale-with-baseline", action="store_true", help="Rescaling the numerical score with precomputed baselines"
"--rescale_with_baseline", action="store_true", help="Rescaling the numerical score with precomputed baselines"
)
parser.add_argument("--baseline_path", default=None, type=str, help="path of custom baseline csv file")

args = parser.parse_args()

Expand All @@ -33,6 +34,7 @@ def main():
num_layers=args.num_layers,
fname=args.file,
rescale_with_baseline=args.rescale_with_baseline,
baseline_path=args.baseline_path,
)


Expand Down
2 changes: 1 addition & 1 deletion journal/rescale_baseline.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ out = bert_score.score(
and for the command-line version:
```bash
bert-score -r example/refs.txt -c example/hyps.txt \
--lang en --rescale-with-baseline
--lang en --rescale_with_baseline
```


Expand Down

0 comments on commit b15f586

Please sign in to comment.