diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py
index a09116552c8e..be9a4aff3c7e 100755
--- a/src/transformers/integrations/integration_utils.py
+++ b/src/transformers/integrations/integration_utils.py
@@ -697,6 +697,8 @@ def on_log(self, args, state, control, logs=None, **kwargs):
         for k, v in logs.items():
             if isinstance(v, (int, float)):
                 self.tb_writer.add_scalar(k, v, state.global_step)
+            elif isinstance(v, str):
+                self.tb_writer.add_text(k, v, state.global_step)
             else:
                 logger.warning(
                     "Trainer is attempting to log a value of "
diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index ce9f2a26732c..cf9a83aa188a 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -589,11 +589,21 @@ def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: Tr
 class ProgressCallback(TrainerCallback):
     """
     A [`TrainerCallback`] that displays the progress of training or evaluation.
+    You can modify `max_str_len` to control how long strings are truncated when logging.
     """
 
-    def __init__(self):
+    def __init__(self, max_str_len: int = 100):
+        """
+        Initialize the callback with optional max_str_len parameter to control string truncation length.
+
+        Args:
+            max_str_len (`int`):
+                Maximum length of strings to display in logs.
+                Longer strings will be truncated with a message.
+        """
         self.training_bar = None
         self.prediction_bar = None
+        self.max_str_len = max_str_len
 
     def on_train_begin(self, args, state, control, **kwargs):
         if state.is_world_process_zero:
@@ -631,7 +641,13 @@ def on_log(self, args, state, control, logs=None, **kwargs):
             # but avoid doing any value pickling.
             shallow_logs = {}
             for k, v in logs.items():
-                shallow_logs[k] = v
+                if isinstance(v, str) and len(v) > self.max_str_len:
+                    shallow_logs[k] = (
+                        f"[String too long to display, length: {len(v)} > {self.max_str_len}. "
+                        "Consider increasing `max_str_len` if needed.]"
+                    )
+                else:
+                    shallow_logs[k] = v
             _ = shallow_logs.pop("total_flos", None)
             # round numbers so that it looks better in console
             if "epoch" in shallow_logs: