Skip to content

Commit

Permalink
add: automated xlmr model calibration (#112)
Browse files Browse the repository at this point in the history
* update: fixed dependency requirements

* add: automated xlmr model calibration - using temperature scaling

Signed-off-by: Biswaroop Bhattacharjee <[email protected]>
  • Loading branch information
biswaroop1547 authored Feb 21, 2022
1 parent 6ae0d8d commit ff25008
Show file tree
Hide file tree
Showing 11 changed files with 571 additions and 63 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
Expand Down
6 changes: 6 additions & 0 deletions dialogy/constants/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ class SIGNAL:
NGRAM_RANGE = "ngram_range"
GRIDSEARCH_WORKERS = -1 # -1 means use all available cores

#Calibration Constants
MODEL_CALIBRATION = "model_calibration"
TS_PARAMETER = "ts_parameter"
CALIBRATION_CONFIG_FILE = "calibration_config.json"
TEMPERATURE = "temperature"

# CLI Commands
TRAIN = "train"
TEST = "test"
Expand Down
16 changes: 14 additions & 2 deletions dialogy/plugins/text/classification/xlmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import dialogy.constants as const
from dialogy.base import Guard, Input, Output, Plugin
from dialogy.types import Intent
from dialogy.utils import load_file, logger, save_file
from dialogy.utils import load_file, logger, save_file, read_from_json


class XLMRMultiClass(Plugin):
Expand Down Expand Up @@ -69,6 +69,12 @@ def __init__(
self.labelencoder_file_path = os.path.join(
self.model_dir, const.LABELENCODER_FILE
)
self.ts_parameter: float = read_from_json(
[const.TS_PARAMETER],
model_dir,
const.CALIBRATION_CONFIG_FILE
).get(const.TS_PARAMETER) or 1.0

self.threshold = threshold
self.skip_labels = set(skip_labels or set())
self.purpose = purpose
Expand All @@ -79,7 +85,7 @@ def __init__(
or const.PRODUCTION not in args_map
):
raise ValueError(
f"Attempting to set invalid {args_map=}. "
f"Attempting to set invalid {args_map}. "
"It is missing some of {const.TRAIN}, {const.TEST}, {const.PRODUCTION} in configs."
)
self.args_map = args_map
Expand Down Expand Up @@ -115,6 +121,7 @@ def init_model(self, label_count: Optional[int] = None) -> None:
if self.args_map and self.purpose in self.args_map
else {}
)
self.use_calibration = args.get(const.MODEL_CALIBRATION)
try:
self.model = self.classifier(
const.XLMR_MODEL,
Expand Down Expand Up @@ -176,6 +183,8 @@ def inference(self, texts: Optional[List[str]]) -> List[Intent]:
if not predictions:
return [fallback_output]


logits = logits / self.ts_parameter
confidence_scores = [np.exp(logit) / sum(np.exp(logit)) for logit in logits]
intents_confidence_order = np.argsort(confidence_scores)[0][::-1]
predicted_intents = self.labelencoder.inverse_transform(
Expand All @@ -185,6 +194,9 @@ def inference(self, texts: Optional[List[str]]) -> List[Intent]:
confidence_scores[0][idx] for idx in intents_confidence_order
]

if self.use_calibration:
ordered_confidence_scores = [logits[0][idx] for idx in np.argsort(logits)[0][::-1]] # ordered logits for calibration

return [
Intent(name=intent, score=round(score, self.round)).add_parser(
self.__class__.__name__
Expand Down
9 changes: 8 additions & 1 deletion dialogy/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,15 @@
make_unix_ts,
unix_ts_to_datetime,
)
from dialogy.utils.file_handler import create_timestamps_path, load_file, save_file
from dialogy.utils.file_handler import (
create_timestamps_path,
load_file,
save_file,
read_from_json,
save_to_json
)
from dialogy.utils.logger import logger
from dialogy.utils.misc import traverse_dict, validate_type
from dialogy.utils.naive_lang_detect import lang_detect_from_text
from dialogy.utils.normalize_utterance import is_utterance, normalize
from dialogy.utils.temperature_scaling import fit_ts_parameter, save_reliability_graph
27 changes: 26 additions & 1 deletion dialogy/utils/file_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import json
from json.decoder import JSONDecodeError
from datetime import datetime
from typing import Any, Optional
from typing import Any, Dict, Optional, List


def load_file(
Expand Down Expand Up @@ -61,6 +63,29 @@ def save_file(
with open(file_path, mode, encoding=encoding, newline=newline) as file:
_ = file.write(content) if not writer else writer(content, file)

def read_from_json(params: List[str], dir_path: str, file_name: str) -> Dict[str, Any]:
full_path = os.path.join(dir_path, file_name)
req_config = {}
if os.path.exists(full_path):
with open(full_path, "r") as json_file:
config_ = json.load(json_file)
req_config = {param:config_.get(param) for param in params}
return req_config

def save_to_json(params: Dict[str, Any], dir_path: str, file_name: str) -> None:
full_path = os.path.join(dir_path, file_name)
existing_config = {}
if os.path.exists(full_path):
try:
with open(full_path, "r") as json_file:
existing_config = json.load(json_file)
for key, val in params.items():
existing_config[key] = val
except JSONDecodeError:
print(f"Failed to load json file {full_path}, writing on newly created file.")
params = existing_config or params
with open(full_path, "w") as json_file:
json.dump(params, json_file, indent = 1, ensure_ascii = False)

def create_timestamps_path(
directory: str,
Expand Down
117 changes: 117 additions & 0 deletions dialogy/utils/temperature_scaling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from tqdm import tqdm
import os
from typing import Any, Dict, Optional, List, Tuple, Union
import numpy.typing as npt
from torch import Tensor

import torch
import torch.nn as nn
import torch.optim as optim

import dialogy.constants as const
from dialogy.utils import logger


DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def calc_bins(preds: npt.NDArray[np.float64], labels_oneh: npt.NDArray[np.float64]) -> Any:
# Assign each prediction to a bin
num_bins = 10
bins = np.linspace(0.1, 1, num_bins)
binned = np.digitize(preds, bins)

# Save the accuracy, confidence and size of each bin
bin_accs = np.zeros(num_bins)
bin_confs = np.zeros(num_bins)
bin_sizes = np.zeros(num_bins)

for bin in range(num_bins):
bin_sizes[bin] = len(preds[binned == bin])
if bin_sizes[bin] > 0:
bin_accs[bin] = (labels_oneh[binned==bin]).sum() / bin_sizes[bin]
bin_confs[bin] = (preds[binned==bin]).sum() / bin_sizes[bin]

return bins, binned, bin_accs, bin_confs, bin_sizes


def get_metrics(preds: npt.NDArray[np.float64], labels_oneh: npt.NDArray[np.float64]) -> Tuple[float, float]:
ECE = 0
MCE = 0
bins, _, bin_accs, bin_confs, bin_sizes = calc_bins(preds, labels_oneh)

for i in range(len(bins)):
abs_conf_dif = abs(bin_accs[i] - bin_confs[i])
ECE += (bin_sizes[i] / sum(bin_sizes)) * abs_conf_dif
MCE = max(MCE, abs_conf_dif)

return ECE, MCE


def save_reliability_graph(preds: npt.NDArray[np.float64], labels_oneh: npt.NDArray[np.float64], dir_path: str, prefix: str) -> None:
ECE, MCE = get_metrics(preds, labels_oneh)
bins, _, bin_accs, _, _ = calc_bins(preds, labels_oneh)

fig = plt.figure(figsize=(8, 8))
ax = fig.gca()

# x/y limits
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1)

# x/y labels
plt.xlabel('Confidence')
plt.ylabel('Accuracy')

# Create grid
ax.set_axisbelow(True)
ax.grid(color='gray', linestyle='dashed')

# Error bars
plt.bar(bins, bins, width=0.1, alpha=0.3, edgecolor='black', color='r', hatch='\\')

# Draw bars and identity line
plt.bar(bins, bin_accs, width=0.1, alpha=1, edgecolor='black', color='b')
plt.plot([0,1],[0,1], '--', color='gray', linewidth=2)

# Equally spaced axes
plt.gca().set_aspect('equal', adjustable='box')

# ECE and MCE legend
ECE_patch = mpatches.Patch(color='green', label='ECE = {:.2f}%'.format(ECE*100))
MCE_patch = mpatches.Patch(color='red', label='MCE = {:.2f}%'.format(MCE*100))
plt.legend(handles=[ECE_patch, MCE_patch])

plt.savefig(os.path.join(dir_path, f'{prefix}_reliability_graph.png'), bbox_inches='tight')


def T_scaling(logits: Tensor, temperature: Tensor) -> Tensor:
return torch.div(logits, temperature)


def fit_ts_parameter(
logits_list: npt.NDArray[np.float64],
labels_list: npt.NDArray[np.int64],
lr: float = 0.001,
max_iter: int =10000
) -> float:
logits_tensor = torch.from_numpy(logits_list).to(DEVICE)
labels_tensor = torch.from_numpy(labels_list).to(DEVICE)
temperature = nn.Parameter(torch.ones(1).to(DEVICE))
criterion = nn.CrossEntropyLoss()
optimizer = optim.LBFGS(
[temperature],
lr=lr,
max_iter=max_iter,
line_search_fn='strong_wolfe'
)

def _eval() -> Any:
loss = criterion(T_scaling(logits_tensor, temperature), labels_tensor)
loss.backward()
return loss

optimizer.step(_eval)
return round(temperature.item(), 4)
Loading

0 comments on commit ff25008

Please sign in to comment.