diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index e720e029..47f5e59a 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -143,8 +143,7 @@ def _predict( test_data = build_dataloader(X, num_workers=self.n_jobs, shuffle=False) predictions = self.lightning_trainer.predict(self.model, test_data) prediction_array = np.vstack(predictions) # type: ignore - prediction_array = prediction_array.squeeze() - + prediction_array = prediction_array.squeeze(axis=1) # Check if the predictions have the same length as the input dataset if prediction_array.shape[0] != len(X): raise AssertionError( diff --git a/test_extras/test_chemprop/test_chemprop_pipeline.py b/test_extras/test_chemprop/test_chemprop_pipeline.py index 32c4e677..c600a3a3 100644 --- a/test_extras/test_chemprop/test_chemprop_pipeline.py +++ b/test_extras/test_chemprop/test_chemprop_pipeline.py @@ -307,6 +307,12 @@ def test_prediction(self) -> None: pred_copy = model_copy.predict(molecule_net_logd_df["smiles"].tolist()) self.assertTrue(np.allclose(pred, pred_copy)) + # Test single prediction, this was causing an error before + single_mol_pred = regression_model.predict( + [molecule_net_logd_df["smiles"].iloc[0]] + ) + self.assertEqual(single_mol_pred.shape, (1,)) + class TestClassificationPipeline(unittest.TestCase): """Test the Chemprop model pipeline for classification.""" @@ -341,6 +347,16 @@ def test_prediction(self) -> None: self.assertEqual(proba.shape, proba_copy.shape) self.assertTrue(np.allclose(proba[~nan_indices], proba_copy[~nan_indices])) + # Test single prediction, this was causing an error before + single_mol_pred = classification_model.predict( + [molecule_net_bbbp_df["smiles"].iloc[0]] + ) + self.assertEqual(single_mol_pred.shape, (1,)) + single_mol_proba = classification_model.predict_proba( + [molecule_net_bbbp_df["smiles"].iloc[0]] + ) + self.assertEqual(single_mol_proba.shape, (1, 2)) + class TestMulticlassClassificationPipeline(unittest.TestCase): """Test the Chemprop model pipeline for multiclass classification.""" @@ -375,6 +391,16 @@ def test_prediction(self) -> None: self.assertEqual(pred.shape, pred_copy.shape) self.assertTrue(np.allclose(proba[~nan_mask], proba_copy[~nan_mask])) + # Test single prediction, this was causing an error before + single_mol_pred = classification_model.predict( + [test_data_df["Molecule"].iloc[0]] + ) + self.assertEqual(single_mol_pred.shape, (1,)) + single_mol_proba = classification_model.predict_proba( + [test_data_df["Molecule"].iloc[0]] + ) + self.assertEqual(single_mol_proba.shape, (1, 3)) + with self.assertRaises(ValueError): classification_model.fit( mols,