Skip to content

Commit

Permalink
✅Introduced test_resume
Browse files Browse the repository at this point in the history
  • Loading branch information
carefree0910 committed Oct 14, 2024
1 parent 8c58838 commit f30c8f6
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions tests/test_learn/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from unittest.mock import patch
from unittest.mock import Mock
from core.toolkit.misc import random_hash
from core.toolkit.misc import get_latest_workspace
from core.learn.schema import losses_type
from core.learn.pipeline.blocks.basic import StateInfo
from core.learn.pipeline.blocks.basic import OptimizerSettings
Expand Down Expand Up @@ -255,6 +256,26 @@ def test_self_ensemble(self):
with self.assertRaises(RuntimeError):
cflearn.PipelineSerializer.self_ensemble_evaluation(5, workspace)

def test_resume(self):
cflearn.seed_everything(142857)
resume_workspace = "_resume"
data, in_dim, out_dim, _ = cflearn.testing.linear_data(use_validation=True)
config = cflearn.Config(
workspace=resume_workspace,
module_name="linear",
module_config=dict(input_dim=in_dim, output_dim=out_dim, bias=False),
loss_name="mse",
num_epoch=10,
tqdm_settings=cflearn.TqdmSettings(use_tqdm=True),
)
config.to_debug()
cflearn.TrainingPipeline.init(config).fit(data)
config.resume_training_from = (
get_latest_workspace(resume_workspace)
/ cflearn.PipelineSerializer.pipeline_folder
)
cflearn.TrainingPipeline.init(config).fit(data)


class TestBlocks(unittest.TestCase):
def test_basics(self):
Expand Down

0 comments on commit f30c8f6

Please sign in to comment.