Skip to content

Commit

Permalink
overhaul boxplots once more
Browse files Browse the repository at this point in the history
  • Loading branch information
Flux9665 committed Oct 7, 2024
1 parent 38ca97c commit 335a8ba
Showing 1 changed file with 24 additions and 25 deletions.
49 changes: 24 additions & 25 deletions Preprocessing/multilinguality/eval_lang_emb_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import torch
from huggingface_hub import hf_hub_download

# matplotlib.rcParams['mathtext.fontset'] = 'stix'
# matplotlib.rcParams['font.family'] = 'STIXGeneral'
matplotlib.rcParams['mathtext.fontset'] = 'stix'
matplotlib.rcParams['font.family'] = 'STIXGeneral'
matplotlib.rcParams['font.size'] = 7
import matplotlib.pyplot as plt
from Utility.utils import load_json_from_path


def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embeddings, weighted_avg=False, min_n_langs=5, max_n_langs=30, threshold_percentile=95, loss_fn="MSE"):
df = pd.read_csv(csv_path, sep="|")

Expand All @@ -23,7 +24,7 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe

features_per_closest_lang = 2
# for combined, df has up to 5 features (if containing individual distances) per closest lang + 1 target lang column
if "combined_dist_0" in df.columns:
if "combined_dist_0" in df.columns:
if "map_dist_0" in df.columns:
features_per_closest_lang += 1
if "asp_dist_0" in df.columns:
Expand Down Expand Up @@ -77,7 +78,7 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe
lang_emb = language_embeddings[iso_lookup[-1][lang]]
avg_emb += lang_emb
normalization_factor = len(langs)
avg_emb /= normalization_factor # normalize
avg_emb /= normalization_factor # normalize
current_loss = loss_fn(avg_emb, y).item()
all_losses.append(current_loss)

Expand Down Expand Up @@ -111,44 +112,42 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe
os.makedirs(OUT_DIR, exist_ok=True)

fig, ax = plt.subplots(figsize=(6, 4))
plt.ylabel(f"{args.loss_fn} between Approximated and Real")
plt.ylabel(args.loss_fn)
for i, csv_path in enumerate(csv_paths):
print(f"csv_path: {os.path.basename(csv_path)}")
for condition in weighted:
losses = compute_loss_for_approximated_embeddings(csv_path,
iso_lookup,
lang_embs,
condition,
min_n_langs=args.min_n_langs,
max_n_langs=args.max_n_langs,
threshold_percentile=args.threshold_percentile,
loss_fn=args.loss_fn)
losses = compute_loss_for_approximated_embeddings(csv_path,
iso_lookup,
lang_embs,
condition,
min_n_langs=args.min_n_langs,
max_n_langs=args.max_n_langs,
threshold_percentile=args.threshold_percentile,
loss_fn=args.loss_fn)
print(f"weighted average: {condition} | mean loss: {np.mean(losses)}")
losses_of_multiple_datasets.append(losses)

bp_dict = ax.boxplot(losses_of_multiple_datasets,
labels =[
"Random Neighbors",
"Nearest according \nto inverse ASPF",
"Nearest according \nto Map Distance",
"Nearest according \nto Tree Distance",
"Nearest according \nto Learned Distance",
"Actual Nearest\n(Oracle)",
],
labels=["Random",
"Inverse ASP",
"Map Distance",
"Tree Distance",
"Learned Distance",
"Oracle"],
patch_artist=True,
boxprops=dict(facecolor = "lightblue",
boxprops=dict(facecolor="lightblue",
),
showfliers=False,
showfliers=False,
widths=0.55
)
)
# major ticks every 0.1, minor ticks every 0.05, between 0.0 and 0.6
major_ticks = np.arange(0, 1.0, 0.1)
minor_ticks = np.arange(0, 1.0, 0.05)
ax.set_yticks(major_ticks)
ax.set_yticks(minor_ticks, minor=True)
# horizontal grid lines for minor and major ticks
ax.grid(which='both', linestyle='-', color='lightgray', linewidth=0.3, axis='y')
plt.title(f"Using between {args.min_n_langs} and {args.max_n_langs} Nearest Neighbors to approximate an unseen Embedding")
# plt.title(f"Using between {args.min_n_langs} and {args.max_n_langs} Nearest Neighbors to approximate an unseen Embedding")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
Expand Down

0 comments on commit 335a8ba

Please sign in to comment.