From 7f4be1d752e6390a6fc26546dbcc32c5f038fdd1 Mon Sep 17 00:00:00 2001 From: Chirag Nagpal Date: Sat, 12 Feb 2022 19:46:20 -0500 Subject: [PATCH] modified: tests/test_dsm.py --- tests/test_dsm.py | 126 +++++++++++++++++++++++----------------------- 1 file changed, 63 insertions(+), 63 deletions(-) diff --git a/tests/test_dsm.py b/tests/test_dsm.py index 1c2ecc3..34ee980 100644 --- a/tests/test_dsm.py +++ b/tests/test_dsm.py @@ -11,74 +11,74 @@ import numpy as np class TestDSM(unittest.TestCase): - """Base Class for all test functions""" - def test_support_dataset(self): - """Test function to load and test the SUPPORT dataset. - """ - - x, t, e = datasets.load_dataset('SUPPORT') - t_median = np.median(t[e==1]) - - self.assertIsInstance(x, np.ndarray) - self.assertIsInstance(t, np.ndarray) - self.assertIsInstance(e, np.ndarray) - - self.assertEqual(x.shape, (9105, 44)) - self.assertEqual(t.shape, (9105,)) - self.assertEqual(e.shape, (9105,)) - - model = DeepSurvivalMachines() - self.assertIsInstance(model, DeepSurvivalMachines) - model.fit(x, t, e, iters=10) - self.assertIsInstance(model.torch_model, - DeepSurvivalMachinesTorch) - risk_score = model.predict_risk(x, t_median) - survival_probability = model.predict_survival(x, t_median) - np.testing.assert_equal((risk_score+survival_probability).all(), 1.0) + """Base Class for all test functions""" + def test_support_dataset(self): + """Test function to load and test the SUPPORT dataset. + """ + + x, t, e = datasets.load_dataset('SUPPORT') + t_median = np.median(t[e==1]) + + self.assertIsInstance(x, np.ndarray) + self.assertIsInstance(t, np.ndarray) + self.assertIsInstance(e, np.ndarray) + + self.assertEqual(x.shape, (9105, 44)) + self.assertEqual(t.shape, (9105,)) + self.assertEqual(e.shape, (9105,)) + + model = DeepSurvivalMachines() + self.assertIsInstance(model, DeepSurvivalMachines) + model.fit(x, t, e, iters=10) + self.assertIsInstance(model.torch_model, + DeepSurvivalMachinesTorch) + risk_score = model.predict_risk(x, t_median) + survival_probability = model.predict_survival(x, t_median) + np.testing.assert_equal((risk_score+survival_probability).all(), 1.0) def test_pbc_dataset(self): - """Test function to load and test the PBC dataset. - """ + """Test function to load and test the PBC dataset. + """ - x, t, e = datasets.load_dataset('PBC') - t_median = np.median(t[e==1]) + x, t, e = datasets.load_dataset('PBC') + t_median = np.median(t[e==1]) - self.assertIsInstance(x, np.ndarray) - self.assertIsInstance(t, np.ndarray) - self.assertIsInstance(e, np.ndarray) + self.assertIsInstance(x, np.ndarray) + self.assertIsInstance(t, np.ndarray) + self.assertIsInstance(e, np.ndarray) - self.assertEqual(x.shape, (1945, 25)) - self.assertEqual(t.shape, (1945,)) - self.assertEqual(e.shape, (1945,)) + self.assertEqual(x.shape, (1945, 25)) + self.assertEqual(t.shape, (1945,)) + self.assertEqual(e.shape, (1945,)) - model = DeepSurvivalMachines() - self.assertIsInstance(model, DeepSurvivalMachines) - model.fit(x, t, e, iters=10) - self.assertIsInstance(model.torch_model, - DeepSurvivalMachinesTorch) - risk_score = model.predict_risk(x, t_median) - survival_probability = model.predict_survival(x, t_median) - np.testing.assert_equal((risk_score+survival_probability).all(), 1.0) + model = DeepSurvivalMachines() + self.assertIsInstance(model, DeepSurvivalMachines) + model.fit(x, t, e, iters=10) + self.assertIsInstance(model.torch_model, + DeepSurvivalMachinesTorch) + risk_score = model.predict_risk(x, t_median) + survival_probability = model.predict_survival(x, t_median) + np.testing.assert_equal((risk_score+survival_probability).all(), 1.0) def test_framingham_dataset(self): - """Test function to load and test the Framingham dataset. - """ - x, t, e = datasets.load_dataset('FRAMINGHAM') - t_median = np.median(t) - - self.assertIsInstance(x, np.ndarray) - self.assertIsInstance(t, np.ndarray) - self.assertIsInstance(e, np.ndarray) - - self.assertEqual(x.shape, (11627, 18)) - self.assertEqual(t.shape, (11627,)) - self.assertEqual(e.shape, (11627,)) - - model = DeepSurvivalMachines() - self.assertIsInstance(model, DeepSurvivalMachines) - model.fit(x, t, e, iters=10) - self.assertIsInstance(model.torch_model, - DeepSurvivalMachinesTorch) - risk_score = model.predict_risk(x, t_median) - survival_probability = model.predict_survival(x, t_median) - np.testing.assert_equal((risk_score+survival_probability).all(), 1.0) + """Test function to load and test the Framingham dataset. + """ + x, t, e = datasets.load_dataset('FRAMINGHAM') + t_median = np.median(t) + + self.assertIsInstance(x, np.ndarray) + self.assertIsInstance(t, np.ndarray) + self.assertIsInstance(e, np.ndarray) + + self.assertEqual(x.shape, (11627, 18)) + self.assertEqual(t.shape, (11627,)) + self.assertEqual(e.shape, (11627,)) + + model = DeepSurvivalMachines() + self.assertIsInstance(model, DeepSurvivalMachines) + model.fit(x, t, e, iters=10) + self.assertIsInstance(model.torch_model, + DeepSurvivalMachinesTorch) + risk_score = model.predict_risk(x, t_median) + survival_probability = model.predict_survival(x, t_median) + np.testing.assert_equal((risk_score+survival_probability).all(), 1.0)