diff --git a/evadb/models/storage/batch.py b/evadb/models/storage/batch.py index c9696e7dbe..43e69cc4fc 100644 --- a/evadb/models/storage/batch.py +++ b/evadb/models/storage/batch.py @@ -269,7 +269,7 @@ def merge_column_wise(cls, batches: List[Batch], auto_renaming=False) -> Batch: frame_index == frames_index[i - 1] ), "Merging of DataFrames with unmatched indices can cause undefined behavior" - new_frames = pd.concat(frames, axis=1, copy=False, ignore_index=False).ffill() + new_frames = pd.concat(frames, axis=1, copy=False, ignore_index=False) if new_frames.columns.duplicated().any(): logger.debug("Duplicated column name detected {}".format(new_frames)) return Batch(new_frames) diff --git a/test/unit_tests/models/storage/test_batch.py b/test/unit_tests/models/storage/test_batch.py index 65eb12023c..d73d355231 100644 --- a/test/unit_tests/models/storage/test_batch.py +++ b/test/unit_tests/models/storage/test_batch.py @@ -112,6 +112,29 @@ def test_merge_column_wise_batch_frame(self): # Special case self.assertEqual(Batch.merge_column_wise([]), Batch()) + # Cases with None + batch_1 = Batch(frames=pd.DataFrame({"id": [0, None, 1]})) + batch_2 = Batch(frames=pd.DataFrame({"data": [None, 0, None]})) + batch_res = Batch( + frames=pd.DataFrame({"id": [0, None, 1], "data": [None, 0, None]}) + ) + self.assertEqual(Batch.merge_column_wise([batch_1, batch_2]), batch_res) + + # Cases with filter + df_1 = pd.DataFrame({"id": [-10, 1, 2]}) + df_2 = pd.DataFrame({"data": [-20, 2, 3]}) + df_1 = df_1[df_1 < 0].dropna() + df_1.reset_index(drop=True, inplace=True) + df_2 = df_2[df_2 < 0].dropna() + df_2.reset_index(drop=True, inplace=True) + batch_1 = Batch(frames=df_1) + batch_2 = Batch(frames=df_2) + df_res = pd.DataFrame({"id": [-10, 1, 2], "data": [-20, 2, 3]}) + df_res = df_res[df_res < 0].dropna() + df_res.reset_index(drop=True, inplace=True) + batch_res = Batch(frames=df_res) + self.assertEqual(Batch.merge_column_wise([batch_1, batch_2]), batch_res) + def test_should_fail_for_list(self): frames = [{"id": 0, "data": [1, 2]}, {"id": 1, "data": [1, 2]}] self.assertRaises(ValueError, Batch, frames)