Skip to content

Commit

Permalink
Add mlflow support
Browse files Browse the repository at this point in the history
  • Loading branch information
belerico committed Dec 19, 2023
1 parent feb00aa commit 375840d
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion sheeprl/algos/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}")

# Transform the data into PyTorch Tensors
local_data = rb.to_tensor(dtype=None, device=device)
local_data = rb.to_tensor(dtype=None, device=device, from_numpy=cfg.buffer.from_numpy)

# Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
with torch.no_grad():
Expand Down Expand Up @@ -348,3 +348,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
envs.close()
if fabric.is_global_zero:
test(agent.module, fabric, cfg, log_dir)

if not cfg.model_manager.disabled and fabric.is_global_zero:
from sheeprl.algos.ppo.utils import log_models
from sheeprl.utils.mlflow import register_model

models_to_log = {"agent": agent}
register_model(fabric, log_models, cfg, models_to_log)

0 comments on commit 375840d

Please sign in to comment.