Add text support to the Trainer's TensorBoard integration #34418

Merged · 8 commits · Nov 4, 2024
2 changes: 2 additions & 0 deletions src/transformers/integrations/integration_utils.py
@@ -697,6 +697,8 @@ def on_log(self, args, state, control, logs=None, **kwargs):
             for k, v in logs.items():
                 if isinstance(v, (int, float)):
                     self.tb_writer.add_scalar(k, v, state.global_step)
+                elif isinstance(v, str):
+                    self.tb_writer.add_text(k, v, state.global_step)
                 else:
                     logger.warning(
                         "Trainer is attempting to log a value of "
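For reference, a minimal standalone sketch of the TensorBoard API the callback now exercises (tag names and values below are illustrative, not taken from the PR): numeric values keep going through add_scalar as before, while string values are now written with add_text instead of being skipped with a warning.

# Illustrative sketch: the two SummaryWriter calls the callback routes log values to.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/demo")  # placeholder log directory
writer.add_scalar("loss", 0.42, global_step=100)  # numeric values: unchanged path
writer.add_text("sample_generation", "An example generated sentence.", global_step=100)  # strings: new path
writer.close()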
20 changes: 18 additions & 2 deletions src/transformers/trainer_callback.py
@@ -589,11 +589,21 @@ def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: Tr
 class ProgressCallback(TrainerCallback):
     """
     A [`TrainerCallback`] that displays the progress of training or evaluation.
+    You can modify `max_str_len` to control how long strings are truncated when logging.
     """

-    def __init__(self):
+    def __init__(self, max_str_len: int = 100):
+        """
+        Initialize the callback with optional max_str_len parameter to control string truncation length.
+
+        Args:
+            max_str_len (`int`):
+                Maximum length of strings to display in logs.
+                Longer strings will be truncated with a message.
+        """
         self.training_bar = None
         self.prediction_bar = None
+        self.max_str_len = max_str_len

     def on_train_begin(self, args, state, control, **kwargs):
         if state.is_world_process_zero:
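A hedged usage sketch for the new parameter, assuming a Trainer has already been set up (model, training_args, and train_dataset are placeholders defined elsewhere): the stock ProgressCallback that the Trainer registers by default can be swapped for one constructed with a larger truncation limit.

from transformers import ProgressCallback, Trainer

# `model`, `training_args`, and `train_dataset` are assumed to be defined elsewhere.
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)

# remove_callback/add_callback accept a callback class or instance; here the default
# ProgressCallback is replaced by one that truncates strings at 500 characters.
trainer.remove_callback(ProgressCallback)
trainer.add_callback(ProgressCallback(max_str_len=500))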
@@ -631,7 +641,13 @@ def on_log(self, args, state, control, logs=None, **kwargs):
             # but avoid doing any value pickling.
             shallow_logs = {}
             for k, v in logs.items():
-                shallow_logs[k] = v
+                if isinstance(v, str) and len(v) > self.max_str_len:
+                    shallow_logs[k] = (
+                        f"[String too long to display, length: {len(v)} > {self.max_str_len}. "
+                        "Consider increasing `max_str_len` if needed.]"
+                    )
+                else:
+                    shallow_logs[k] = v
             _ = shallow_logs.pop("total_flos", None)
             # round numbers so that it looks better in console
             if "epoch" in shallow_logs:
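A self-contained sketch of what the new truncation branch produces for an over-long string (the log keys and values are made up for illustration):

# Mirrors the truncation logic added to ProgressCallback.on_log.
max_str_len = 100
logs = {"loss": 0.42, "sample_generation": "x" * 250}

shallow_logs = {}
for k, v in logs.items():
    if isinstance(v, str) and len(v) > max_str_len:
        # Placeholder message emitted instead of the full string.
        shallow_logs[k] = (
            f"[String too long to display, length: {len(v)} > {max_str_len}. "
            "Consider increasing `max_str_len` if needed.]"
        )
    else:
        shallow_logs[k] = v

print(shallow_logs["sample_generation"])
# [String too long to display, length: 250 > 100. Consider increasing `max_str_len` if needed.]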