diff --git a/CHANGELOG.md b/CHANGELOG.md index fe3b9ff9..2b7703f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,11 @@ Changelog ========= +Version 0.73.0 +-------------- +* added reverse and scale arguments to target variable +* also, the data store can now be csv + Version 0.72.0 -------------- * worked over explore value counts section diff --git a/nkululeko/constants.py b/nkululeko/constants.py index 77eb5ec1..624927a5 100644 --- a/nkululeko/constants.py +++ b/nkululeko/constants.py @@ -1,2 +1,2 @@ -VERSION="0.72.0" +VERSION="0.73.0" SAMPLING_RATE = 16000 diff --git a/nkululeko/data/dataset.py b/nkululeko/data/dataset.py index 2d40e03f..f8d979e2 100644 --- a/nkululeko/data/dataset.py +++ b/nkululeko/data/dataset.py @@ -3,14 +3,14 @@ import os import os.path from random import sample +import numpy as np +import pandas as pd import audformat +from audformat.utils import duration + import nkululeko.filter_data as filter import nkululeko.glob_conf as glob_conf - -# import audb -import pandas as pd -from audformat.utils import duration from nkululeko.filter_data import DataFilter from nkululeko.plots import Plots from nkululeko.reporting.report_item import ReportItem @@ -97,11 +97,12 @@ def load(self): """Load the dataframe with files, speakers and task labels""" # store the dataframe store = self.util.get_path("store") - store_file = f"{store}{self.name}.pkl" + store_file = f"{store}{self.name}" + store_format = self.util.config_val("FEATS", "store_format", "pkl") self.root = self._load_db() if not self.start_fresh and os.path.isfile(store_file): self.util.debug(f"{self.name}: reusing previously stored file {store_file}") - self.df = pd.read_pickle(store_file) + self.df = self.util.get_store(store_file, store_format) self.is_labeled = self.target in self.df self.got_gender = "gender" in self.df self.got_age = "age" in self.df @@ -175,6 +176,9 @@ def prepare(self): # ensure segmented index self.df = self.util.make_segmented_index(self.df) self.util.copy_flags(self, self.df) + # check the type of numeric targets + if not self.util.exp_is_classification(): + self.df[self.target] = self.df[self.target].astype(float) # add duration if "duration" not in self.df: start = self.df.index.get_level_values(1) @@ -202,10 +206,36 @@ def prepare(self): # we might need to append the database name to all speakers in case other databases have the same speaker names self.df.speaker = self.df.speaker.apply(lambda x: self.name + x) + # check if the target variable should be reversed + def reverse_array(d): + d = np.array(d) + max = d.max() + res = [] + for n in d: + res.append(abs(n - max)) + return res + + reverse = self.util.config_val_data(self.name, "reverse", False) + if reverse: + self.util.debug("reversing target numbers") + self.df[self.target] = reverse_array(self.df[self.target].values) + + # check if the target variable should be scaled (z-transformed) + scale = self.util.config_val_data(self.name, "scale", False) + if scale: + from sklearn.preprocessing import StandardScaler + + self.util.debug("scaling target variable to normal distribution") + scaler = StandardScaler() + self.df[self.target] = scaler.fit_transform( + self.df[self.target].values.reshape(-1, 1) + ) + # store the dataframe store = self.util.get_path("store") - store_file = f"{store}{self.name}.pkl" - self.df.to_pickle(store_file) + store_format = self.util.config_val("FEATS", "store_format", "pkl") + store_file = f"{store}{self.name}" + self.util.write_store(self.df, store_file, store_format) def _get_df_for_lists(self, db, df_files): is_labeled, got_speaker, got_gender, got_age = ( diff --git a/nkululeko/experiment.py b/nkululeko/experiment.py index e99501b2..f1618870 100644 --- a/nkululeko/experiment.py +++ b/nkululeko/experiment.py @@ -557,18 +557,18 @@ def analyse_features(self, needs_feats): ) def _check_scale(self): - scale = self.util.config_val("FEATS", "scale", False) + scale_feats = self.util.config_val("FEATS", "scale", False) # print the scale - self.util.debug(f"scaler: {scale}") - if scale: - self.scaler = Scaler( + self.util.debug(f"scaler: {scale_feats}") + if scale_feats: + self.scaler_feats = Scaler( self.df_train, self.df_test, self.feats_train, self.feats_test, - scale, + scale_feats, ) - self.feats_train, self.feats_test = self.scaler.scale() + self.feats_train, self.feats_test = self.scaler_feats.scale() def init_runmanager(self): """Initialize the manager object for the runs.""" diff --git a/nkululeko/reporter.py b/nkululeko/reporter.py index 6eaf6f22..741c6537 100644 --- a/nkululeko/reporter.py +++ b/nkululeko/reporter.py @@ -5,9 +5,15 @@ import matplotlib.pyplot as plt import numpy as np from scipy.stats import pearsonr -from sklearn.metrics import (ConfusionMatrixDisplay, accuracy_score, - classification_report, confusion_matrix, - mean_squared_error, r2_score, recall_score) +from sklearn.metrics import ( + ConfusionMatrixDisplay, + accuracy_score, + classification_report, + confusion_matrix, + mean_squared_error, + r2_score, + recall_score, +) from sklearn.utils import resample import nkululeko.glob_conf as glob_conf @@ -160,6 +166,8 @@ def _plot_confmat(self, truths, preds, plot_name, epoch): ) res_dir = self.util.get_path("res_dir") + uar = int(uar * 1000) / 1000.0 + acc = int(acc * 1000) / 1000.0 rpt = f"epoch: {epoch}, UAR: {uar}, ACC: {acc}" # print(rpt) self.util.debug(rpt) diff --git a/nkululeko/scaler.py b/nkululeko/scaler.py index 2825b52a..02dc6454 100644 --- a/nkululeko/scaler.py +++ b/nkululeko/scaler.py @@ -8,7 +8,7 @@ class Scaler: """ - class to normalize speech parameters + class to normalize speech features """ def __init__( @@ -68,13 +68,9 @@ def scale_df(self, df): return df def speaker_scale(self): - self.feats_train = self.speaker_scale_df( - self.data_train, self.feats_train - ) + self.feats_train = self.speaker_scale_df(self.data_train, self.feats_train) if self.feats_test is not None: - self.feats_test = self.speaker_scale_df( - self.data_test, self.feats_test - ) + self.feats_test = self.speaker_scale_df(self.data_test, self.feats_test) return [self.feats_train, self.feats_test] def speaker_scale_df(self, df, feats_df):