diff --git a/.bumpversion.cfg b/.bumpversion.cfg
index 13e095a..17c8785 100644
--- a/.bumpversion.cfg
+++ b/.bumpversion.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.4.3
+current_version = 0.4.4
 commit = False
 tag = False
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5138421..3073296 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,21 @@ and this project adheres to [semantic versioning](https://semver.org/spec/v2.0.0
 
 ### Removed
 
+## [0.4.4] - 2024-12-13
+
+### Added
+
+### Changed
+- Modified `insert_df_to_hive_table` function in `cdp/io/output.py`. Added support
+  for creating non-existent Hive tables, repartitioning by column or partition count,
+  and handling missing columns with explicit type casting.
+
+### Deprecated
+
+### Fixed
+
+### Removed
+
 ## [0.4.3] - 2024-12-05
 
 ### Added
@@ -501,6 +516,8 @@ and this project adheres to [semantic versioning](https://semver.org/spec/v2.0.0
 > due to bugs in the GitHub Action `deploy_pypi.yaml`, which deploys to PyPI
 > and GitHub Releases.
 
+- rdsa-utils v0.4.4: [GitHub Release](https://github.com/ONSdigital/rdsa-utils/releases/tag/v0.4.4) |
+  [PyPI](https://pypi.org/project/rdsa-utils/0.4.4/)
 - rdsa-utils v0.4.3: [GitHub Release](https://github.com/ONSdigital/rdsa-utils/releases/tag/v0.4.3) |
   [PyPI](https://pypi.org/project/rdsa-utils/0.4.3/)
 - rdsa-utils v0.4.2: [GitHub Release](https://github.com/ONSdigital/rdsa-utils/releases/tag/v0.4.2) |
diff --git a/rdsa_utils/__init__.py b/rdsa_utils/__init__.py
index f6b7e26..cd1ee63 100644
--- a/rdsa_utils/__init__.py
+++ b/rdsa_utils/__init__.py
@@ -1 +1 @@
-__version__ = "0.4.3"
+__version__ = "0.4.4"
diff --git a/rdsa_utils/cdp/helpers/s3_utils.py b/rdsa_utils/cdp/helpers/s3_utils.py
index 361beb8..c497c1a 100644
--- a/rdsa_utils/cdp/helpers/s3_utils.py
+++ b/rdsa_utils/cdp/helpers/s3_utils.py
@@ -1141,12 +1141,12 @@ def write_csv(
 
     Examples
     --------
-    >>> s3_client = boto3.client('s3')
+    >>> client = boto3.client('s3')
     >>> data = pd.DataFrame({
     >>>     'column1': [1, 2, 3],
     >>>     'column2': ['a', 'b', 'c']
     >>> })
-    >>> write_csv(s3_client, 'my_bucket', data, 'path/to/file.csv')
+    >>> write_csv(client, 'my_bucket', data, 'path/to/file.csv')
     True
     """
     try:
diff --git a/rdsa_utils/cdp/io/output.py b/rdsa_utils/cdp/io/output.py
index 3f7231a..98ac57c 100644
--- a/rdsa_utils/cdp/io/output.py
+++ b/rdsa_utils/cdp/io/output.py
@@ -36,12 +36,14 @@ def insert_df_to_hive_table(
     table_name: str,
     overwrite: bool = False,
     fill_missing_cols: bool = False,
+    repartition_data_by: Union[int, str, None] = None,
 ) -> None:
-    """Write the SparkDF contents to a Hive table.
+    """Write SparkDF to Hive table with optional configuration.
 
-    This function writes data from a SparkDF into a Hive table, allowing
-    optional handling of missing columns. The table's column order is ensured to
-    match that of the DataFrame.
+    This function writes data from a SparkDF into a Hive table, handling missing
+    columns and optional repartitioning. It ensures the table's column order matches
+    the DataFrame and manages different overwrite behaviors for partitioned and
+    non-partitioned data.
 
     Parameters
     ----------
@@ -52,10 +54,39 @@
     spark
         Active SparkSession.
     df
         SparkDF containing data to be written.
     table_name
         Name of the Hive table to write data into.
     overwrite
-        If True, existing data in the table will be overwritten,
-        by default False.
+        Controls how existing data is handled, default is False:
+
+        For non-partitioned data:
+        - True: Replaces entire table with DataFrame data.
+        - False: Appends DataFrame data to existing table.
+
+        For partitioned data:
+        - True: Replaces data only in partitions present in DataFrame.
+        - False: Appends data to existing partitions or creates new ones.
     fill_missing_cols
-        If True, missing columns will be filled with nulls, by default False.
+        If True, adds missing columns as NULL values. If False, raises an error
+        on schema mismatch, default is False.
+
+        - Explicitly casts DataFrame columns to match the Hive table schema to
+          avoid type mismatch errors.
+        - Adds missing columns as NULL values when `fill_missing_cols` is True,
+          regardless of their data type (e.g., String, Integer, Double, Boolean, etc.).
+    repartition_data_by
+        Controls data repartitioning, default is None:
+        - int: Sets target number of partitions.
+        - str: Specifies column to repartition by.
+        - None: No repartitioning performed.
+
+    Notes
+    -----
+    When using repartition with a number:
+    - Affects physical file structure but preserves Hive partitioning scheme.
+    - Controls number of output files per write operation per Hive partition.
+    - Maintains partition-based query optimization.
+
+    When repartitioning by column:
+    - Helps balance file sizes across Hive partitions.
+    - Reduces creation of small files.
 
     Raises
     ------
@@ -65,36 +96,81 @@
     ValueError
         If the SparkDF schema does not match the Hive table schema and
         'fill_missing_cols' is set to False.
+    DataframeEmptyError
+        If input DataFrame is empty.
     Exception
         For other general exceptions when writing data to the table.
-    """
-    logger.info(f"Preparing to write data to {table_name}.")
-    # Validate SparkDF before writing
-    if is_df_empty(df):
-        msg = f"Cannot write an empty SparkDF to {table_name}"
-        raise DataframeEmptyError(
-            msg,
-        )
+
+    Examples
+    --------
+    Write a DataFrame to a Hive table without overwriting:
+    >>> insert_df_to_hive_table(
+    ...     spark=spark,
+    ...     df=df,
+    ...     table_name="my_database.my_table"
+    ... )
+
+    Overwrite an existing table with a DataFrame:
+    >>> insert_df_to_hive_table(
+    ...     spark=spark,
+    ...     df=df,
+    ...     table_name="my_database.my_table",
+    ...     overwrite=True
+    ... )
+
+    Write a DataFrame to a Hive table with missing columns filled:
+    >>> insert_df_to_hive_table(
+    ...     spark=spark,
+    ...     df=df,
+    ...     table_name="my_database.my_table",
+    ...     fill_missing_cols=True
+    ... )
+
+    Repartition by column before writing to Hive:
+    >>> insert_df_to_hive_table(
+    ...     spark=spark,
+    ...     df=df,
+    ...     table_name="my_database.my_table",
+    ...     repartition_data_by="partition_column"
+    ... )
+
+    Repartition into a fixed number of partitions before writing:
+    >>> insert_df_to_hive_table(
+    ...     spark=spark,
+    ...     df=df,
+    ...     table_name="my_database.my_table",
+    ...     repartition_data_by=10
+    ... )
+    """
+    logger.info(f"Preparing to write data to {table_name} with overwrite={overwrite}.")
+
+    # Check if the table exists; if not, set flag for later creation
+    table_exists = True
     try:
+        table_schema = spark.read.table(table_name).schema
         table_columns = spark.read.table(table_name).columns
     except AnalysisException:
-        logger.error(
-            (
-                f"Error reading table {table_name}. "
-                f"Make sure the table exists and you have access to it."
-            ),
+        logger.info(
+            f"Table {table_name} does not exist and will be "
+            "created after transformations.",
         )
+        table_exists = False
+        table_columns = df.columns  # Use DataFrame columns as initial schema
-        raise
 
+    # Validate SparkDF before writing
+    if is_df_empty(df):
+        msg = f"Cannot write an empty SparkDF to {table_name}"
+        raise DataframeEmptyError(msg)
+
-    if fill_missing_cols:
+    # Handle missing columns if specified
+    if fill_missing_cols and table_exists:
         missing_columns = list(set(table_columns) - set(df.columns))
-
         for col in missing_columns:
-            df = df.withColumn(col, F.lit(None))
-    else:
+            column_type = [
+                field.dataType for field in table_schema if field.name == col
+            ][0]
+            df = df.withColumn(col, F.lit(None).cast(column_type))
+    elif not fill_missing_cols and table_exists:
         # Validate schema before writing
         if set(table_columns) != set(df.columns):
             msg = (
@@ -103,10 +179,32 @@
             )
             raise ValueError(msg)
 
-    df = df.select(table_columns)
+    # Ensure column order
+    df = df.select(table_columns) if table_exists else df
+
+    # Apply repartitioning if specified
+    if repartition_data_by is not None:
+        if isinstance(repartition_data_by, int):
+            logger.info(f"Repartitioning data into {repartition_data_by} partitions.")
+            df = df.repartition(repartition_data_by)
+        elif isinstance(repartition_data_by, str):
+            logger.info(f"Repartitioning data by column {repartition_data_by}.")
+            df = df.repartition(repartition_data_by)
 
+    # Write DataFrame to Hive table based on existence and overwrite parameter
     try:
-        df.write.insertInto(table_name, overwrite)
+        if table_exists:
+            if overwrite:
+                logger.info(f"Overwriting existing table {table_name}.")
+                df.write.mode("overwrite").saveAsTable(table_name)
+            else:
+                logger.info(
+                    f"Inserting into existing table {table_name} without overwrite.",
+                )
+                df.write.insertInto(table_name)
+        else:
+            df.write.saveAsTable(table_name)
+            logger.info(f"Table {table_name} created successfully.")
         logger.info(f"Successfully wrote data to {table_name}.")
     except Exception:
         logger.error(f"Error writing data to {table_name}.")
diff --git a/tests/cdp/io/test_cdsw_output.py b/tests/cdp/io/test_cdsw_output.py
index 0175d31..2d79885 100644
--- a/tests/cdp/io/test_cdsw_output.py
+++ b/tests/cdp/io/test_cdsw_output.py
@@ -1,7 +1,7 @@
 """Tests for the cdp/io/output.py module."""
 
 from typing import Callable
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
 
 import pytest
 from moto import mock_aws
@@ -57,12 +57,14 @@ def test_df(self, spark_session: SparkSession, create_spark_df: Callable):
 
         return df
 
-    @patch("pyspark.sql.DataFrameWriter.insertInto")
+    @patch("pyspark.sql.DataFrame.withColumn")
     @patch("pyspark.sql.DataFrameReader.table")
+    @patch("pyspark.sql.functions.lit")
     def test_insert_df_to_hive_table_with_missing_columns(
         self,
+        mock_lit,
         mock_table,
-        mock_insert_into,
+        mock_with_column,
         spark_session: SparkSession,
         test_df: SparkDF,
     ) -> None:
@@ -70,10 +72,44 @@
         """Test that insert_df_to_hive_table handles missing columns in the
         table when 'fill_missing_cols' is True.
""" table_name = "test_table" - # Mock the table columns + + # Mock the table's columns to include 'address' mock_table.return_value.columns = ["id", "name", "age", "address"] - # Mock the DataFrameWriter insertInto - mock_insert_into.return_value = None + + # Mock the Hive table schema + mock_table_schema = T.StructType( + [ + T.StructField("id", T.IntegerType()), + T.StructField("name", T.StringType()), + T.StructField("age", T.IntegerType()), + T.StructField( + "address", + T.StringType(), + ), # Include 'address' with StringType + ], + ) + mock_table.return_value.schema = mock_table_schema + + # Create a mock of the DataFrame (test_df) that does not contain 'address' + test_df_mock = MagicMock() + test_df_mock.columns = [ + "id", + "name", + "age", + ] # Mock that it doesn't have 'address' + + # Mock `withColumn` behavior - ensure it's being called + # with correct column expression + mock_with_column.return_value = ( + test_df_mock # Simulate that `withColumn` returns `test_df_mock` + ) + + # Mock the return of the `lit` function to simulate the expression + mock_lit.return_value = F.lit(None).cast( + T.StringType(), + ) # We expect the `lit` to return the null expression + + # Call the function to insert data into the table insert_df_to_hive_table( spark_session, test_df, @@ -81,8 +117,17 @@ def test_insert_df_to_hive_table_with_missing_columns( overwrite=True, fill_missing_cols=True, ) - # Assert that insertInto was called with correct arguments - mock_insert_into.assert_called_with(table_name, True) + + # Assert that `lit` was called with `None` to create the null expression + mock_lit.assert_called_with(None) + + # Since `lit().cast()` returns the same object, + # directly assert the final return value + expected_column = F.lit(None).cast(T.StringType()) + + # Check if withColumn was called with 'address' and the expected expression + # Compare the exact column expression (the expected one, not the mock) + mock_with_column.assert_any_call("address", expected_column) @patch("pyspark.sql.DataFrameReader.table") def test_insert_df_to_hive_table_without_missing_columns( @@ -106,22 +151,136 @@ def test_insert_df_to_hive_table_without_missing_columns( fill_missing_cols=False, ) + @patch("pyspark.sql.DataFrameWriter.insertInto") @patch("pyspark.sql.DataFrameReader.table") - def test_insert_df_to_hive_table_with_non_existing_table( + def test_insert_df_to_hive_table_insert_into_existing_table( self, mock_table, + mock_insert_into, spark_session: SparkSession, test_df: SparkDF, ) -> None: - """Test that insert_df_to_hive_table raises an AnalysisException when - the table doesn't exist. 
- """ - table_name = "non_existing_table" - # Create an AnalysisException with a stack trace - exc = AnalysisException(f"Table {table_name} not found.") - mock_table.side_effect = exc - with pytest.raises(AnalysisException): - insert_df_to_hive_table(spark_session, test_df, table_name) + """Test that insertInto is called when the table already exists in Hive.""" + table_name = "existing_table" + + # Mock the table columns to simulate the table already exists + mock_table.return_value.columns = ["id", "name", "age"] + + # Simulate a successful call to `insertInto` + mock_insert_into.return_value = None + + # Call the function that triggers insertInto when the table exists + insert_df_to_hive_table( + spark_session, + test_df, + table_name, + ) + + # Assert that insertInto was called with the correct table name + mock_insert_into.assert_called_once_with(table_name) + + @patch("pyspark.sql.DataFrameWriter.saveAsTable") + @patch("pyspark.sql.DataFrameReader.table") + def test_insert_df_to_hive_table_save_as_table_when_table_does_not_exist( + self, + mock_table, + mock_save_as_table, + spark_session: SparkSession, + test_df: SparkDF, + ) -> None: + """Test that saveAsTable is called when the Hive table does not exist.""" + table_name = "new_table" + + # Simulate the table not existing by raising an AnalysisException + mock_table.side_effect = AnalysisException(f"Table {table_name} not found.") + + # Simulate a successful call to `saveAsTable` + mock_save_as_table.return_value = None + + # Call the function that triggers saveAsTable when the table does not exist + insert_df_to_hive_table( + spark_session, + test_df, + table_name, + ) + + # Assert that saveAsTable was called with the correct table name + mock_save_as_table.assert_called_once_with(table_name) + + @patch("pyspark.sql.DataFrame.repartition") + @patch("pyspark.sql.DataFrameReader.table") + def test_insert_df_to_hive_table_with_repartition_data_by( + self, + mock_table, + mock_repartition, + spark_session: SparkSession, + test_df: SparkDF, + ) -> None: + """Test that the DataFrame is repartitioned by a specified column.""" + table_name = "test_table" + mock_table.return_value.columns = ["id", "name", "age"] + + # Ensure that mock_repartition is set to return the same test_df + mock_repartition.return_value = test_df + + # Call the function that triggers repartition + insert_df_to_hive_table( + spark_session, + test_df, + table_name, + repartition_data_by="id", # We expect "id" column to be used for repartitioning + overwrite=True, + ) + + # Assert that repartition was called with the correct argument (the column name) + mock_repartition.assert_called_once_with("id") + + @patch("pyspark.sql.DataFrame.repartition") + @patch("pyspark.sql.DataFrameReader.table") + def test_insert_df_to_hive_table_with_repartition_num_partitions( + self, + mock_table, + mock_repartition, + spark_session: SparkSession, + test_df: SparkDF, + ) -> None: + """Test that the DataFrame is repartitioned into a specific number of partitions.""" + table_name = "test_table" + mock_table.return_value.columns = ["id", "name", "age"] + + # Ensure that mock_repartition is set to return the same test_df + mock_repartition.return_value = test_df + + # Call the function that triggers repartition + insert_df_to_hive_table( + spark_session, + test_df, + table_name, + repartition_data_by=5, # Expecting 5 partitions + overwrite=True, + ) + + # Assert that repartition was called with the number of partitions (5) + mock_repartition.assert_called_once_with(5) + + def 
+    def test_insert_df_to_hive_table_with_empty_dataframe(
+        self,
+        spark_session: SparkSession,
+    ) -> None:
+        """Test that an empty DataFrame raises DataframeEmptyError."""
+        from rdsa_utils.exceptions import DataframeEmptyError
+
+        table_name = "test_table"
+        empty_df = spark_session.createDataFrame(
+            [],
+            schema="id INT, name STRING, age INT",
+        )
+        with pytest.raises(DataframeEmptyError):
+            insert_df_to_hive_table(
+                spark_session,
+                empty_df,
+                table_name,
+            )
 
 
 class TestWriteAndReadHiveTable:
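For reference, a minimal usage sketch of the `insert_df_to_hive_table` signature introduced by this patch; it is not part of the patch itself, and it assumes an active SparkSession `spark`, a SparkDF `df`, and an illustrative table name:

# Sketch only: combines the new repartition_data_by option with the
# create-if-missing and fill_missing_cols behaviour added in this release.
# The table name and repartition column are placeholders.
from rdsa_utils.cdp.io.output import insert_df_to_hive_table

insert_df_to_hive_table(
    spark=spark,                        # active SparkSession (assumed to exist)
    df=df,                              # SparkDF to persist (assumed to exist)
    table_name="my_database.my_table",  # created automatically if it does not exist
    overwrite=True,                     # replace existing data
    fill_missing_cols=True,             # NULL-fill and cast any missing columns
    repartition_data_by="partition_column",  # or an int for a fixed partition count
)

As the patch shows, an existing table written without overwrite goes through DataFrameWriter.insertInto, while overwrites and newly created tables go through saveAsTable.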