From f30c8f65706bb3b6e59d3f4aae41f0d057aee690 Mon Sep 17 00:00:00 2001 From: carefree0910 Date: Mon, 14 Oct 2024 15:46:47 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=85Introduced=20`test=5Fresume`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_learn/test_pipeline.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/test_learn/test_pipeline.py b/tests/test_learn/test_pipeline.py index 8765faf..16b7a92 100644 --- a/tests/test_learn/test_pipeline.py +++ b/tests/test_learn/test_pipeline.py @@ -13,6 +13,7 @@ from unittest.mock import patch from unittest.mock import Mock from core.toolkit.misc import random_hash +from core.toolkit.misc import get_latest_workspace from core.learn.schema import losses_type from core.learn.pipeline.blocks.basic import StateInfo from core.learn.pipeline.blocks.basic import OptimizerSettings @@ -255,6 +256,26 @@ def test_self_ensemble(self): with self.assertRaises(RuntimeError): cflearn.PipelineSerializer.self_ensemble_evaluation(5, workspace) + def test_resume(self): + cflearn.seed_everything(142857) + resume_workspace = "_resume" + data, in_dim, out_dim, _ = cflearn.testing.linear_data(use_validation=True) + config = cflearn.Config( + workspace=resume_workspace, + module_name="linear", + module_config=dict(input_dim=in_dim, output_dim=out_dim, bias=False), + loss_name="mse", + num_epoch=10, + tqdm_settings=cflearn.TqdmSettings(use_tqdm=True), + ) + config.to_debug() + cflearn.TrainingPipeline.init(config).fit(data) + config.resume_training_from = ( + get_latest_workspace(resume_workspace) + / cflearn.PipelineSerializer.pipeline_folder + ) + cflearn.TrainingPipeline.init(config).fit(data) + class TestBlocks(unittest.TestCase): def test_basics(self):