diff --git a/tests/test_dsm.py b/tests/test_dsm.py new file mode 100644 index 0000000..b684946 --- /dev/null +++ b/tests/test_dsm.py @@ -0,0 +1,26 @@ +import unittest + +from dsm import DeepSurvivalMachines +from dsm import datasets + +import numpy as np + +class TestDSM(unittest.TestCase): + + def test_dsm(self): + + x, t, e = datasets.load_dataset('SUPPORT') + + self.assertIsInstance(x, np.ndarray) + self.assertIsInstance(t, np.ndarray) + self.assertIsInstance(e, np.ndarray) + + self.assertEqual(x.shape, (9105, 44)) + self.assertEqual(t.shape, (9105,)) + self.assertEqual(e.shape, (9105,)) + + model = DeepSurvivalMachines() + self.assertIsInstance(model, dsm.dsm_api.DeepSurvivalMachines) + model.fit(x, t, e, iters=10) + self.assertIsInstance(model.torch_model, + dsm.dsm_torch.DeepSurvivalMachinesTorch)