diff --git a/.bumpversion.cfg b/.bumpversion.cfg
index db30c44..4e7e91e 100644
--- a/.bumpversion.cfg
+++ b/.bumpversion.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.3.1
+current_version = 0.3.2
 commit = False
 tag = False
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
diff --git a/.github/workflows/deploy_mkdocs.yaml b/.github/workflows/deploy_mkdocs.yaml
index 8561e1f..8952d54 100644
--- a/.github/workflows/deploy_mkdocs.yaml
+++ b/.github/workflows/deploy_mkdocs.yaml
@@ -17,7 +17,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v4
         with:
-          python-version: "3.10"
+          python-version: "3.11"

       - name: Install dependencies
         run: |
diff --git a/.github/workflows/pull_request_workflow.yaml b/.github/workflows/pull_request_workflow.yaml
index 4bf8a5d..cd5c607 100644
--- a/.github/workflows/pull_request_workflow.yaml
+++ b/.github/workflows/pull_request_workflow.yaml
@@ -41,7 +41,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip3 install -e .[dev,doc]
+          pip3 install -e .[dev]

       - name: Run tests
         run: pytest
diff --git a/CHANGELOG.md b/CHANGELOG.md
index b965a31..fb0a6ff 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,30 @@ and this project adheres to [semantic versioning](https://semver.org/spec/v2.0.0

 ### Removed

+## [v0.3.2] - 2024-09-02
+
+### Added
+- Added `load_csv` to `helpers/pyspark.py` with a kwargs parameter.
+- Added `truncate_external_hive_table` to `helpers/pyspark.py`.
+- Added `get_tables_in_database` to `cdp/io/input.py`.
+- Added `load_csv` to `cdp/helpers/s3_utils.py`. This loads a CSV from an S3
+  bucket into a Pandas DataFrame.
+
+### Changed
+- Removed `.config("spark.shuffle.service.enabled", "true")` from
+  `create_spark_session()` as it is not compatible with CDP. Added
+  `.config("spark.dynamicAllocation.shuffleTracking.enabled", "true")` &
+  `.config("spark.sql.adaptive.enabled", "true")`.
+- Changed `mkdocs` theme from `mkdocs-tech-docs-template` to `ons-mkdocs-theme`.
+- Added more parameters to `load_and_validate_table()` in `cdp/io/input.py`.
+
+### Deprecated
+
+### Fixed
+- Temporarily pin `numpy==1.24.4` due to https://github.com/numpy/numpy/issues/26710
+
+### Removed
+
 ## [v0.3.1] - 2024-05-24

 ### Added
@@ -348,6 +372,8 @@ and this project adheres to [semantic versioning](https://semver.org/spec/v2.0.0
 > and GitHub Releases.
+- rdsa-utils v0.3.2: [GitHub Release](https://github.com/ONSdigital/rdsa-utils/releases/tag/v0.3.2) | + [PyPI](https://pypi.org/project/rdsa-utils/0.3.2/) - rdsa-utils v0.3.1: [GitHub Release](https://github.com/ONSdigital/rdsa-utils/releases/tag/v0.3.1) | [PyPI](https://pypi.org/project/rdsa-utils/0.3.1/) - rdsa-utils v0.3.0: [GitHub Release](https://github.com/ONSdigital/rdsa-utils/releases/tag/v0.3.0) | diff --git a/mkdocs.yml b/mkdocs.yml index bd4faa2..a3b5c5a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,13 +1,18 @@ site_name: rdsa-utils Documentation theme: - name: tech_docs_template + name: ons_mkdocs_theme features: - navigation.tabs - navigation.tabs.sticky - - navigation.indexes - logo: assets/ons_logo_white.svg - favicon: assets/ons_favicon.svg + - navigation.sections + - toc.integrate + - content.tabs.link + - content.code.annotation + - content.code.copy + language: en + logo: assets/images/logo.svg + favicon: assets/images/favicon.ico repo_name: rdsa-utils repo_url: https://github.com/ONSdigital/rdsa-utils @@ -41,3 +46,12 @@ nav: - API Reference: reference.md - Contribution Guide: contribution_guide.md - Branching & Deployment Guide: branch_and_deploy_guide.md + +extra: + social: + - icon: fontawesome/brands/github + link: + +# Do not remove the copy right section. But you can change the copyright information. +copyright: | + © Office for National Statistics 2024 diff --git a/rdsa_utils/__init__.py b/rdsa_utils/__init__.py index 260c070..f9aa3e1 100644 --- a/rdsa_utils/__init__.py +++ b/rdsa_utils/__init__.py @@ -1 +1 @@ -__version__ = "0.3.1" +__version__ = "0.3.2" diff --git a/rdsa_utils/cdp/helpers/s3_utils.py b/rdsa_utils/cdp/helpers/s3_utils.py index 87223db..e9d6242 100644 --- a/rdsa_utils/cdp/helpers/s3_utils.py +++ b/rdsa_utils/cdp/helpers/s3_utils.py @@ -24,9 +24,10 @@ import logging from pathlib import Path -from typing import List, Optional +from typing import Dict, List, Optional import boto3 +import pandas as pd from rdsa_utils.exceptions import InvalidBucketNameError @@ -815,3 +816,149 @@ def delete_folder( f"in bucket {bucket_name}: {str(e)}", ) return False + + +def load_csv( + client: boto3.client, + bucket_name: str, + filepath: str, + keep_columns: Optional[List[str]] = None, + rename_columns: Optional[Dict[str, str]] = None, + drop_columns: Optional[List[str]] = None, + **kwargs, +) -> pd.DataFrame: + """Load a CSV file from an S3 bucket into a Pandas DataFrame. + + Parameters + ---------- + client + The boto3 S3 client instance. + bucket_name + The name of the S3 bucket. + filepath + The key (full path and filename) of the CSV file in the S3 bucket. + keep_columns + A list of column names to keep in the DataFrame, dropping all others. + Default value is None. + rename_columns + A dictionary to rename columns where keys are existing column + names and values are new column names. + Default value is None. + drop_columns + A list of column names to drop from the DataFrame. + Default value is None. + kwargs + Additional keyword arguments to pass to the `pd.read_csv` method. + + Returns + ------- + pd.DataFrame + Pandas DataFrame containing the data from the CSV file. + + Raises + ------ + Exception + If there is an error loading the file. + ValueError + If a column specified in rename_columns, drop_columns, or + keep_columns is not found in the DataFrame. + + Notes + ----- + Transformation order: + 1. Columns are kept according to `keep_columns`. + 2. Columns are dropped according to `drop_columns`. + 3. 
Columns are renamed according to `rename_columns`. + + Examples + -------- + Load a CSV file and rename columns: + + >>> df = load_csv( + client, + "my-bucket", + "path/to/file.csv", + rename_columns={"old_name": "new_name"} + ) + + Load a CSV file and keep only specific columns: + + >>> df = load_csv( + client, + "my-bucket", + "path/to/file.csv", + keep_columns=["col1", "col2"] + ) + + Load a CSV file and drop specific columns: + + >>> df = load_csv( + client, + "my-bucket", + "path/to/file.csv", + drop_columns=["col1", "col2"] + ) + + Load a CSV file with custom delimiter: + + >>> df = load_csv( + client, + "my-bucket", + "path/to/file.csv", + sep=";" + ) + """ + try: + # Get the CSV file from S3 + response = client.get_object(Bucket=bucket_name, Key=filepath) + logger.info( + f"Loaded CSV file from S3 bucket {bucket_name}, filepath {filepath}", + ) + + # Read the CSV file into a Pandas DataFrame + df = pd.read_csv(response["Body"], **kwargs) + + except Exception as e: + error_message = ( + f"Error loading file from bucket {bucket_name}, filepath {filepath}: {e}" + ) + logger.error(error_message) + raise Exception(error_message) from e + + columns = df.columns.tolist() + + # Apply column transformations: keep, drop, rename + if keep_columns: + missing_columns = [col for col in keep_columns if col not in columns] + if missing_columns: + error_message = ( + f"Columns {missing_columns} not found in DataFrame and cannot be kept" + ) + logger.error(error_message) + raise ValueError(error_message) + df = df[keep_columns] + + if drop_columns: + for col in drop_columns: + if col in columns: + df = df.drop(columns=[col]) + else: + error_message = ( + f"Column '{col}' not found in DataFrame and cannot be dropped" + ) + logger.error(error_message) + raise ValueError(error_message) + + if rename_columns: + for old_name, new_name in rename_columns.items(): + if old_name in columns: + df = df.rename(columns={old_name: new_name}) + else: + error_message = ( + f"Column '{old_name}' not found in DataFrame and " + f"cannot be renamed to '{new_name}'" + ) + logger.error(error_message) + raise ValueError(error_message) + + return df diff --git a/rdsa_utils/cdp/io/input.py b/rdsa_utils/cdp/io/input.py index 3d7a018..9128d03 100644 --- a/rdsa_utils/cdp/io/input.py +++ b/rdsa_utils/cdp/io/input.py @@ -1,7 +1,7 @@ """Read inputs on CDP.""" import logging -from typing import Tuple +from typing import Dict, List, Optional, Tuple from pyspark.sql import DataFrame as SparkDF from pyspark.sql import SparkSession @@ -16,6 +16,42 @@ def get_current_database(spark: SparkSession) -> str: return spark.sql("SELECT current_database()").collect()[0]["current_database()"] +def get_tables_in_database(spark: SparkSession, database_name: str) -> List[str]: + """Get a list of tables in a given database. + + Parameters + ---------- + spark + Active SparkSession. + database_name + The name of the database from which to list tables. + + Returns + ------- + List[str] + A list of table names in the specified database. + + Raises + ------ + ValueError + If there is an error fetching tables from the specified database. 
+ + Examples + -------- + >>> tables = get_tables_in_database(spark, "default") + >>> print(tables) + ['table1', 'table2', 'table3'] + """ + try: + tables_df = spark.sql(f"SHOW TABLES IN {database_name}") + tables = [row["tableName"] for row in tables_df.collect()] + return tables + except Exception as e: + error_msg = f"Error fetching tables from database {database_name}: {e}" + logger.error(error_msg) + raise ValueError(error_msg) from e + + def extract_database_name( spark: SparkSession, long_table_name: str, @@ -88,8 +124,11 @@ def load_and_validate_table( skip_validation: bool = False, err_msg: str = None, filter_cond: str = None, + keep_columns: Optional[List[str]] = None, + rename_columns: Optional[Dict[str, str]] = None, + drop_columns: Optional[List[str]] = None, ) -> SparkDF: - """Load a table and validate if it is not empty after applying a filter. + """Load a table, apply transformations, and validate if it is not empty. Parameters ---------- @@ -103,6 +142,16 @@ def load_and_validate_table( Error message to return if table is empty, by default None. filter_cond Condition to apply to SparkDF once read, by default None. + keep_columns + A list of column names to keep in the DataFrame, dropping all others. + Default value is None. + rename_columns + A dictionary to rename columns where keys are existing column + names and values are new column names. + Default value is None. + drop_columns + A list of column names to drop from the DataFrame. + Default value is None. Returns ------- @@ -115,8 +164,55 @@ def load_and_validate_table( If there's an issue accessing the table or if the table does not exist in the specified database. ValueError - If the table is empty after loading, or if it becomes - empty after applying a filter condition. + If the table is empty after loading, becomes empty after applying + a filter condition, or if columns specified in keep_columns, + drop_columns, or rename_columns do not exist in the DataFrame. + + Notes + ----- + Transformation order: + 1. Columns are kept according to `keep_columns`. + 2. Columns are dropped according to `drop_columns`. + 3. Columns are renamed according to `rename_columns`. 
+ + Examples + -------- + Load a table, apply a filter, and validate it: + + >>> df = load_and_validate_table( + spark=spark, + table_name="my_table", + filter_cond="age > 21" + ) + + Load a table and keep only specific columns: + + >>> df = load_and_validate_table( + spark=spark, + table_name="my_table", + keep_columns=["name", "age", "city"] + ) + + Load a table, drop specific columns, and rename a column: + + >>> df = load_and_validate_table( + spark=spark, + table_name="my_table", + drop_columns=["extra_column"], + rename_columns={"name": "full_name"} + ) + + Load a table, skip validation, and apply all transformations: + + >>> df = load_and_validate_table( + spark=spark, + table_name="my_table", + skip_validation=True, + keep_columns=["name", "age", "city"], + drop_columns=["extra_column"], + rename_columns={"name": "full_name"}, + filter_cond="age > 21" + ) """ try: df = spark.read.table(table_name) @@ -131,11 +227,49 @@ def load_and_validate_table( logger.error(db_err) raise PermissionError(db_err) from e + columns = [str(col) for col in df.columns] + + # Apply column transformations: keep, drop, rename + if keep_columns: + missing_columns = [col for col in keep_columns if col not in columns] + if missing_columns: + error_message = ( + f"Columns {missing_columns} not found in DataFrame and cannot be kept" + ) + logger.error(error_message) + raise ValueError(error_message) + df = df.select(*keep_columns) + + if drop_columns: + for col in drop_columns: + if col in columns: + df = df.drop(col) + else: + error_message = ( + f"Column '{col}' not found in DataFrame and cannot be dropped" + ) + logger.error(error_message) + raise ValueError(error_message) + + if rename_columns: + for old_name, new_name in rename_columns.items(): + if old_name in columns: + df = df.withColumnRenamed(old_name, new_name) + else: + error_message = ( + f"Column '{old_name}' not found in DataFrame and " + f"cannot be renamed to '{new_name}'" + ) + logger.error(error_message) + raise ValueError(error_message) + + # Validate the table if skip_validation is not True if not skip_validation: if df.rdd.isEmpty(): err_msg = err_msg or f"Table {table_name} is empty." raise DataframeEmptyError(err_msg) + # Apply the filter condition if provided if filter_cond: df = df.filter(filter_cond) if not skip_validation and df.rdd.isEmpty(): @@ -149,7 +283,9 @@ def load_and_validate_table( logger.info( ( f"Loaded and validated table {table_name}. " - f"Filter condition applied: {filter_cond}" + f"Filter condition applied: {filter_cond}. " + f"Keep columns: {keep_columns}, Drop columns: {drop_columns}, " + f"Rename columns: {rename_columns}." 
         ),
     )
diff --git a/rdsa_utils/helpers/pyspark.py b/rdsa_utils/helpers/pyspark.py
index e3074f5..d6a3662 100644
--- a/rdsa_utils/helpers/pyspark.py
+++ b/rdsa_utils/helpers/pyspark.py
@@ -775,9 +775,13 @@ def create_spark_session(

     # Common configurations for all sizes
     builder = (
+        # Dynamic Allocation
         builder.config("spark.dynamicAllocation.enabled", "true")
-        .config("spark.shuffle.service.enabled", "true")
-        .config("spark.ui.showConsoleProgress", "false")
+        .config("spark.dynamicAllocation.shuffleTracking.enabled", "true")
+        # Adaptive Query Execution
+        .config("spark.sql.adaptive.enabled", "true")
+        # General
+        .config("spark.ui.showConsoleProgress", "false")
     ).enableHiveSupport()

     # fmt: on
@@ -791,3 +795,181 @@
     except Exception as e:
         logger.error(f"An error occurred while creating the Spark session: {e}")
         raise
+
+
+def load_csv(
+    spark: SparkSession,
+    filepath: str,
+    keep_columns: Optional[List[str]] = None,
+    rename_columns: Optional[Dict[str, str]] = None,
+    drop_columns: Optional[List[str]] = None,
+    **kwargs,
+) -> SparkDF:
+    """Load a CSV file into a PySpark DataFrame.
+
+    Parameters
+    ----------
+    spark
+        Active SparkSession.
+    filepath
+        The full path and filename of the CSV file to load.
+    keep_columns
+        A list of column names to keep in the DataFrame, dropping all others.
+        Default value is None.
+    rename_columns
+        A dictionary to rename columns where keys are existing column
+        names and values are new column names.
+        Default value is None.
+    drop_columns
+        A list of column names to drop from the DataFrame.
+        Default value is None.
+    kwargs
+        Additional keyword arguments to pass to the `spark.read.csv` method.
+
+    Returns
+    -------
+    SparkDF
+        PySpark DataFrame containing the data from the CSV file.
+
+    Raises
+    ------
+    Exception
+        If there is an error loading the file.
+    ValueError
+        If a column specified in rename_columns, drop_columns, or
+        keep_columns is not found in the DataFrame.
+
+    Notes
+    -----
+    Transformation order:
+    1. Columns are kept according to `keep_columns`.
+    2. Columns are dropped according to `drop_columns`.
+    3. Columns are renamed according to `rename_columns`.
+ + Examples + -------- + Load a CSV file with multiline and rename columns: + + >>> df = load_csv( + spark, + "/path/to/file.csv", + multiLine=True, + rename_columns={"old_name": "new_name"} + ) + + Load a CSV file with a specific encoding: + + >>> df = load_csv(spark, "/path/to/file.csv", encoding="ISO-8859-1") + + Load a CSV file and keep only specific columns: + + >>> df = load_csv(spark, "/path/to/file.csv", keep_columns=["col1", "col2"]) + + Load a CSV file and drop specific columns: + + >>> df = load_csv(spark, "/path/to/file.csv", drop_columns=["col1", "col2"]) + + Load a CSV file with custom delimiter and multiline: + + >>> df = load_csv(spark, "/path/to/file.csv", sep=";", multiLine=True) + + """ + try: + df = spark.read.csv(filepath, header=True, **kwargs) + logger.info(f"Loaded CSV file {filepath} with parameters {kwargs}") + except Exception as e: + error_message = f"Error loading file {filepath}: {e}" + logger.error(error_message) + raise Exception(error_message) from e + + columns = [str(col) for col in df.columns] + + # When multi_line is used it adds \r at the end of the final column + if kwargs.get("multiLine", False): + columns[-1] = columns[-1].replace("\r", "") + df = df.withColumnRenamed(df.columns[-1], columns[-1]) + + # Apply column transformations: keep, drop, rename + if keep_columns: + missing_columns = [col for col in keep_columns if col not in columns] + if missing_columns: + error_message = ( + f"Columns {missing_columns} not found in DataFrame and cannot be kept" + ) + logger.error(error_message) + raise ValueError(error_message) + df = df.select(*keep_columns) + + if drop_columns: + for col in drop_columns: + if col in columns: + df = df.drop(col) + else: + error_message = ( + f"Column '{col}' not found in DataFrame and cannot be dropped" + ) + logger.error(error_message) + raise ValueError(error_message) + + if rename_columns: + for old_name, new_name in rename_columns.items(): + if old_name in columns: + df = df.withColumnRenamed(old_name, new_name) + else: + error_message = ( + f"Column '{old_name}' not found in DataFrame and " + f"cannot be renamed to '{new_name}'" + ) + logger.error(error_message) + raise ValueError(error_message) + + return df + + +def truncate_external_hive_table(spark: SparkSession, table_name: str) -> None: + """Truncate External Hive Table stored on S3 or HDFS. + + Parameters + ---------- + spark + Active SparkSession. + table_name + The name of the external Hive table to truncate. + + Returns + ------- + None + This function does not return any value. It performs an action of + truncating the table. 
+ + Examples + -------- + Truncate a Hive table named 'my_database.my_table': + + >>> truncate_external_hive_table(spark, 'my_database.my_table') + """ + try: + logger.info(f"Attempting to truncate the table '{table_name}'") + + # Read the original table to get its schema + original_df = spark.table(table_name) + schema: T.StructType = original_df.schema + + # Create an empty DataFrame with the same schema + empty_df = spark.createDataFrame([], schema) + + # Overwrite the original table with the empty DataFrame + empty_df.write.mode("overwrite").insertInto(table_name) + + logger.info(f"Table '{table_name}' successfully truncated.") + + except Exception as e: + logger.error( + f"An error occurred while truncating the table '{table_name}': {e}", + ) + raise diff --git a/setup.cfg b/setup.cfg index 965c754..c076ee9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,7 @@ install_requires = humanfriendly>=9.1 more-itertools>=9.0.0 pandas==1.5.3 + numpy==1.24.4 # Temporarily pin numpy due to https://github.com/numpy/numpy/issues/26710 pydantic>=2.6.2 pyyaml>=6.0.1 tomli>=2.0.1 @@ -52,7 +53,7 @@ dev = isort>=5.13.2 doc = mkdocs>=1.4.2 - mkdocs-tech-docs-template>=0.1.2 + ons-mkdocs-theme>=1.1.0 mkdocstrings[python]>=0.22.0 mkdocs-git-revision-date-localized-plugin>=1.2.1 mkdocs-jupyter>=0.24.3 diff --git a/tests/cdp/helpers/test_s3_utils.py b/tests/cdp/helpers/test_s3_utils.py index 1fa745d..be87f9b 100644 --- a/tests/cdp/helpers/test_s3_utils.py +++ b/tests/cdp/helpers/test_s3_utils.py @@ -14,6 +14,7 @@ file_exists, is_s3_directory, list_files, + load_csv, move_file, remove_leading_slash, upload_file, @@ -698,3 +699,213 @@ def test_delete_folder_nonexistent(self, s3_client): """Test delete_folder when the folder does not exist.""" result = delete_folder(s3_client, "test-bucket", "nonexistent-folder/") assert result is False + + +class TestLoadCSV: + """Tests for load_csv function.""" + + data_basic = """col1,col2,col3 +1,A,foo +2,B,bar +3,C,baz +""" + + data_multiline = """col1,col2,col3 +1,A,"foo +bar" +2,B,"baz +qux" +""" + + @pytest.fixture(scope="class") + def s3_client(self): + """Boto3 S3 client fixture for this test class.""" + with mock_aws(): + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="test-bucket") + yield s3 + + def upload_to_s3(self, s3_client, bucket_name, key, data): + """Upload a string as a CSV file to S3.""" + s3_client.put_object(Bucket=bucket_name, Key=key, Body=data) + + def test_load_csv_basic(self, s3_client): + """Test loading CSV file.""" + self.upload_to_s3(s3_client, "test-bucket", "test_basic.csv", self.data_basic) + df = load_csv(s3_client, "test-bucket", "test_basic.csv") + assert len(df) == 3 + assert len(df.columns) == 3 + + def test_load_csv_multiline(self, s3_client): + """Test loading multiline CSV file.""" + self.upload_to_s3( + s3_client, + "test-bucket", + "test_multiline.csv", + self.data_multiline, + ) + df = load_csv(s3_client, "test-bucket", "test_multiline.csv") + assert len(df) == 2 + assert len(df.columns) == 3 + + def test_load_csv_keep_columns(self, s3_client): + """Test keeping specific columns.""" + self.upload_to_s3( + s3_client, + "test-bucket", + "test_keep_columns.csv", + self.data_basic, + ) + df = load_csv( + s3_client, + "test-bucket", + "test_keep_columns.csv", + keep_columns=["col1", "col2"], + ) + assert len(df) == 3 + assert len(df.columns) == 2 + assert "col1" in df.columns + assert "col2" in df.columns + assert "col3" not in df.columns + + def test_load_csv_drop_columns(self, s3_client): + 
"""Test dropping specific columns.""" + self.upload_to_s3( + s3_client, + "test-bucket", + "test_drop_columns.csv", + self.data_basic, + ) + df = load_csv( + s3_client, + "test-bucket", + "test_drop_columns.csv", + drop_columns=["col2"], + ) + assert len(df) == 3 + assert len(df.columns) == 2 + assert "col1" in df.columns + assert "col3" in df.columns + assert "col2" not in df.columns + + def test_load_csv_rename_columns(self, s3_client): + """Test renaming columns.""" + self.upload_to_s3( + s3_client, + "test-bucket", + "test_rename_columns.csv", + self.data_basic, + ) + df = load_csv( + s3_client, + "test-bucket", + "test_rename_columns.csv", + rename_columns={"col1": "new_col1", "col3": "new_col3"}, + ) + assert len(df) == 3 + assert len(df.columns) == 3 + assert "new_col1" in df.columns + assert "col1" not in df.columns + assert "new_col3" in df.columns + assert "col3" not in df.columns + + def test_load_csv_missing_keep_column(self, s3_client): + """Test error when keep column is missing.""" + self.upload_to_s3( + s3_client, + "test-bucket", + "test_missing_keep_column.csv", + self.data_basic, + ) + with pytest.raises(ValueError): + load_csv( + s3_client, + "test-bucket", + "test_missing_keep_column.csv", + keep_columns=["col4"], + ) + + def test_load_csv_missing_drop_column(self, s3_client): + """Test error when drop column is missing.""" + self.upload_to_s3( + s3_client, + "test-bucket", + "test_missing_drop_column.csv", + self.data_basic, + ) + with pytest.raises(ValueError): + load_csv( + s3_client, + "test-bucket", + "test_missing_drop_column.csv", + drop_columns=["col4"], + ) + + def test_load_csv_missing_rename_column(self, s3_client): + """Test error when rename column is missing.""" + self.upload_to_s3( + s3_client, + "test-bucket", + "test_missing_rename_column.csv", + self.data_basic, + ) + with pytest.raises(ValueError): + load_csv( + s3_client, + "test-bucket", + "test_missing_rename_column.csv", + rename_columns={"col4": "new_col4"}, + ) + + def test_load_csv_with_encoding(self, s3_client): + """Test loading CSV with a specific encoding.""" + self.upload_to_s3( + s3_client, + "test-bucket", + "test_encoding.csv", + self.data_basic, + ) + df = load_csv( + s3_client, + "test-bucket", + "test_encoding.csv", + encoding="ISO-8859-1", + ) + assert len(df) == 3 + assert len(df.columns) == 3 + + def test_load_csv_with_custom_delimiter(self, s3_client): + """Test loading CSV with a custom delimiter.""" + data_with_semicolon = """col1;col2;col3 +1;A;foo +2;B;bar +3;C;baz +""" + self.upload_to_s3( + s3_client, + "test-bucket", + "test_custom_delimiter.csv", + data_with_semicolon, + ) + df = load_csv(s3_client, "test-bucket", "test_custom_delimiter.csv", sep=";") + assert len(df) == 3 + assert len(df.columns) == 3 + + def test_load_csv_with_custom_quote(self, s3_client): + """Test loading CSV with a custom quote character.""" + data_with_custom_quote = """col1,col2,col3 + 1,A,foo + 2,B,'bar' + 3,C,'baz' + """ + self.upload_to_s3( + s3_client, + "test-bucket", + "test_custom_quote.csv", + data_with_custom_quote, + ) + df = load_csv(s3_client, "test-bucket", "test_custom_quote.csv", quotechar="'") + assert len(df) == 3 + assert len(df.columns) == 3 + assert df[df["col3"] == "bar"].shape[0] == 1 + assert df[df["col3"] == "baz"].shape[0] == 1 diff --git a/tests/cdp/io/test_cdsw_input.py b/tests/cdp/io/test_cdsw_input.py index c26541e..41f92ae 100644 --- a/tests/cdp/io/test_cdsw_input.py +++ b/tests/cdp/io/test_cdsw_input.py @@ -240,3 +240,157 @@ def 
test_load_and_validate_table_with_normal_table(self) -> None: result = load_and_validate_table(spark_session, table_name) # Check that the returned DataFrame is our mock DataFrame assert result == df + + def test_load_and_validate_table_with_keep_columns(self) -> None: + """Test that load_and_validate_table keeps only the specified columns.""" + table_name = "test_table" + keep_columns = ["name", "age"] + # Mock SparkSession and DataFrame + spark_session = MagicMock(spec=SparkSession) + df = MagicMock(spec=SparkDF) + df.columns = ["name", "age", "city"] + df.rdd.isEmpty.return_value = False + df.select.return_value = df + spark_session.read.table.return_value = df + # No exception is expected to be raised here + result = load_and_validate_table( + spark_session, + table_name, + keep_columns=keep_columns, + ) + df.select.assert_called_once_with(*keep_columns) + assert result == df + + def test_load_and_validate_table_with_drop_columns(self) -> None: + """Test that load_and_validate_table drops the specified columns.""" + table_name = "test_table" + drop_columns = ["city"] + # Mock SparkSession and DataFrame + spark_session = MagicMock(spec=SparkSession) + df = MagicMock(spec=SparkDF) + df.columns = ["name", "age", "city"] + df.rdd.isEmpty.return_value = False + df.drop.return_value = df + spark_session.read.table.return_value = df + # No exception is expected to be raised here + result = load_and_validate_table( + spark_session, + table_name, + drop_columns=drop_columns, + ) + df.drop.assert_called_once_with("city") + assert result == df + + def test_load_and_validate_table_with_rename_columns(self) -> None: + """Test that load_and_validate_table renames the specified columns.""" + table_name = "test_table" + rename_columns = {"name": "full_name"} + # Mock SparkSession and DataFrame + spark_session = MagicMock(spec=SparkSession) + df = MagicMock(spec=SparkDF) + df.columns = ["name", "age", "city"] + df.rdd.isEmpty.return_value = False + df.withColumnRenamed.return_value = df + spark_session.read.table.return_value = df + # No exception is expected to be raised here + result = load_and_validate_table( + spark_session, + table_name, + rename_columns=rename_columns, + ) + df.withColumnRenamed.assert_called_once_with("name", "full_name") + assert result == df + + def test_load_and_validate_table_with_combined_transformations(self) -> None: + """Test that load_and_validate_table applies keep, drop, and rename + transformations in the correct order. 
+ """ + table_name = "test_table" + keep_columns = ["name", "age", "city"] + drop_columns = ["city"] + rename_columns = {"name": "full_name"} + # Mock SparkSession and DataFrame + spark_session = MagicMock(spec=SparkSession) + df = MagicMock(spec=SparkDF) + df.columns = ["name", "age", "city", "country"] + df.rdd.isEmpty.return_value = False + df.select.return_value = df + df.drop.return_value = df + df.withColumnRenamed.return_value = df + spark_session.read.table.return_value = df + # No exception is expected to be raised here + result = load_and_validate_table( + spark_session, + table_name, + keep_columns=keep_columns, + drop_columns=drop_columns, + rename_columns=rename_columns, + ) + df.select.assert_called_once_with(*keep_columns) + df.drop.assert_called_once_with("city") + df.withColumnRenamed.assert_called_once_with("name", "full_name") + assert result == df + + +class TestGetTablesInDatabase: + """Tests for get_tables_in_database function.""" + + @classmethod + def setup_class(cls): + """Set up SparkSession for tests.""" + cls.spark = ( + SparkSession.builder.master("local") + .appName("test_get_tables_in_database") + .getOrCreate() + ) + cls.spark.sql("CREATE DATABASE IF NOT EXISTS test_db") + cls.spark.sql("USE test_db") + cls.spark.sql("CREATE TABLE IF NOT EXISTS test_table1 (id INT, name STRING)") + cls.spark.sql("CREATE TABLE IF NOT EXISTS test_table2 (id INT, name STRING)") + + @classmethod + def teardown_class(cls): + """Tear down SparkSession after tests.""" + cls.spark.sql("DROP TABLE IF EXISTS test_db.test_table1") + cls.spark.sql("DROP TABLE IF EXISTS test_db.test_table2") + cls.spark.sql("DROP DATABASE IF EXISTS test_db") + cls.spark.stop() + + def test_get_tables_in_existing_database(self): + """Test with existing database.""" + tables = get_tables_in_database(self.spark, "test_db") + assert "test_table1" in tables + assert "test_table2" in tables + + def test_get_tables_in_non_existing_database(self): + """Test with non-existing database.""" + with pytest.raises( + ValueError, + match="Error fetching tables from database non_existing_db", + ): + get_tables_in_database(self.spark, "non_existing_db") + + def test_get_tables_with_no_tables(self): + """Test with database having no tables.""" + self.spark.sql("CREATE DATABASE IF NOT EXISTS empty_db") + tables = get_tables_in_database(self.spark, "empty_db") + assert tables == [] + self.spark.sql("DROP DATABASE IF EXISTS empty_db") + + def test_get_tables_with_exception(self): + """Test exception handling.""" + original_sql = self.spark.sql + + def mock_sql(query): + raise RuntimeError("Test exception") # noqa: EM101 + + self.spark.sql = mock_sql + + try: + with pytest.raises( + ValueError, + match="Error fetching tables from database test_db", + ): + get_tables_in_database(self.spark, "test_db") + finally: + self.spark.sql = original_sql diff --git a/tests/helpers/test_pyspark.py b/tests/helpers/test_pyspark.py index e1c3544..fbde5fa 100644 --- a/tests/helpers/test_pyspark.py +++ b/tests/helpers/test_pyspark.py @@ -1074,3 +1074,195 @@ def test_create_spark_session_with_extra_configs( spark.conf.get("spark.ui.enabled") == "false" ), "Extra configurations should be applied." 
spark.stop() + + +class TestLoadCSV: + """Tests for load_csv function.""" + + data_basic = """col1,col2,col3 +1,A,foo +2,B,bar +3,C,baz +""" + + data_multiline = """col1,col2,col3 +1,A,"foo +bar" +2,B,"baz +qux" +""" + + @pytest.fixture(scope="class") + def custom_spark_session(self): + """Spark session fixture for this test class.""" + spark = ( + SparkSession.builder.master("local[2]") + .appName("test_load_csv") + .getOrCreate() + ) + yield spark + spark.stop() + + def create_temp_csv(self, tmp_path, data): + """Create a temporary CSV file.""" + temp_file = tmp_path / "test.csv" + temp_file.write_text(data) + return str(temp_file) + + def test_load_csv_basic(self, custom_spark_session, tmp_path): + """Test loading CSV file.""" + temp_file = self.create_temp_csv(tmp_path, self.data_basic) + df = load_csv(custom_spark_session, temp_file) + assert df.count() == 3 + assert len(df.columns) == 3 + + def test_load_csv_multiline(self, custom_spark_session, tmp_path): + """Test loading multiline CSV file.""" + temp_file = self.create_temp_csv(tmp_path, self.data_multiline) + df = load_csv(custom_spark_session, temp_file, multiLine=True) + assert df.count() == 2 + assert len(df.columns) == 3 + + def test_load_csv_keep_columns(self, custom_spark_session, tmp_path): + """Test keeping specific columns.""" + temp_file = self.create_temp_csv(tmp_path, self.data_basic) + df = load_csv(custom_spark_session, temp_file, keep_columns=["col1", "col2"]) + assert df.count() == 3 + assert len(df.columns) == 2 + assert "col1" in df.columns + assert "col2" in df.columns + assert "col3" not in df.columns + + def test_load_csv_drop_columns(self, custom_spark_session, tmp_path): + """Test dropping specific columns.""" + temp_file = self.create_temp_csv(tmp_path, self.data_basic) + df = load_csv(custom_spark_session, temp_file, drop_columns=["col2"]) + assert df.count() == 3 + assert len(df.columns) == 2 + assert "col1" in df.columns + assert "col3" in df.columns + assert "col2" not in df.columns + + def test_load_csv_rename_columns(self, custom_spark_session, tmp_path): + """Test renaming columns.""" + temp_file = self.create_temp_csv(tmp_path, self.data_basic) + df = load_csv( + custom_spark_session, + temp_file, + rename_columns={"col1": "new_col1", "col3": "new_col3"}, + ) + assert df.count() == 3 + assert len(df.columns) == 3 + assert "new_col1" in df.columns + assert "col1" not in df.columns + assert "new_col3" in df.columns + assert "col3" not in df.columns + + def test_load_csv_missing_keep_column(self, custom_spark_session, tmp_path): + """Test error when keep column is missing.""" + temp_file = self.create_temp_csv(tmp_path, self.data_basic) + with pytest.raises(ValueError): + load_csv(custom_spark_session, temp_file, keep_columns=["col4"]) + + def test_load_csv_missing_drop_column(self, custom_spark_session, tmp_path): + """Test error when drop column is missing.""" + temp_file = self.create_temp_csv(tmp_path, self.data_basic) + with pytest.raises(ValueError): + load_csv(custom_spark_session, temp_file, drop_columns=["col4"]) + + def test_load_csv_missing_rename_column(self, custom_spark_session, tmp_path): + """Test error when rename column is missing.""" + temp_file = self.create_temp_csv(tmp_path, self.data_basic) + with pytest.raises(ValueError): + load_csv( + custom_spark_session, + temp_file, + rename_columns={"col4": "new_col4"}, + ) + + def test_load_csv_with_encoding(self, custom_spark_session, tmp_path): + """Test loading CSV with a specific encoding.""" + temp_file = 
self.create_temp_csv(tmp_path, self.data_basic) + df = load_csv(custom_spark_session, temp_file, encoding="ISO-8859-1") + assert df.count() == 3 + assert len(df.columns) == 3 + + def test_load_csv_with_custom_delimiter(self, custom_spark_session, tmp_path): + """Test loading CSV with a custom delimiter.""" + data_with_semicolon = """col1;col2;col3 +1;A;foo +2;B;bar +3;C;baz +""" + temp_file = self.create_temp_csv(tmp_path, data_with_semicolon) + df = load_csv(custom_spark_session, temp_file, sep=";") + assert df.count() == 3 + assert len(df.columns) == 3 + + def test_load_csv_with_infer_schema(self, custom_spark_session, tmp_path): + """Test loading CSV with schema inference.""" + temp_file = self.create_temp_csv(tmp_path, self.data_basic) + df = load_csv(custom_spark_session, temp_file, inferSchema=True) + assert df.schema["col1"].dataType.typeName() == "integer" + assert df.schema["col2"].dataType.typeName() == "string" + assert df.schema["col3"].dataType.typeName() == "string" + + def test_load_csv_with_custom_quote(self, custom_spark_session, tmp_path): + """Test loading CSV with a custom quote character.""" + data_with_custom_quote = """col1,col2,col3 +1,A,foo +2,B,'bar' +3,C,'baz' +""" + temp_file = self.create_temp_csv(tmp_path, data_with_custom_quote) + df = load_csv(custom_spark_session, temp_file, quote="'") + assert df.count() == 3 + assert len(df.columns) == 3 + assert df.filter(df.col3 == "bar").count() == 1 + assert df.filter(df.col3 == "baz").count() == 1 + + +class TestTruncateExternalHiveTable: + """Tests for truncate_external_hive_table function.""" + + @pytest.fixture() + def create_external_table(self, spark_session: SparkSession): + """Create a mock external Hive table for testing.""" + spark = ( + SparkSession.builder.master("local[2]") + .appName("test_external_table") + .enableHiveSupport() + .getOrCreate() + ) + table_name = "test_db.test_table" + spark.sql("CREATE DATABASE IF NOT EXISTS test_db") + schema = T.StructType([T.StructField("name", T.StringType(), True)]) + df = spark.createDataFrame([("Alice",), ("Bob",)], schema) + df.write.mode("overwrite").saveAsTable(table_name) + yield table_name, spark + spark.sql(f"DROP TABLE {table_name}") + spark.sql("DROP DATABASE test_db") + spark.stop() + + def test_truncate_table(self, create_external_table): + """Test truncating an external Hive table.""" + table_name, spark_session = create_external_table + truncate_external_hive_table(spark_session, table_name) + truncated_df = spark_session.table(table_name) + assert truncated_df.count() == 0 + + def test_schema_preservation(self, create_external_table): + """Test schema preservation after truncation.""" + table_name, spark_session = create_external_table + original_schema = spark_session.table(table_name).schema + truncate_external_hive_table(spark_session, table_name) + truncated_schema = spark_session.table(table_name).schema + assert original_schema == truncated_schema + + def test_no_exceptions(self, create_external_table): + """Test no exceptions are raised during truncation.""" + table_name, spark_session = create_external_table + try: + truncate_external_hive_table(spark_session, table_name) + except Exception as e: + pytest.fail(f"Truncation raised an exception: {e}")
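
A minimal end-to-end usage sketch of the helpers added in this PR (not part of the patch). Import paths follow the file locations in the diff; the SparkSession setup and the bucket, database, table, and column names are illustrative placeholders, not values from the codebase.

# Usage sketch only -- names marked "hypothetical" are placeholders.
import boto3
from pyspark.sql import SparkSession

from rdsa_utils.cdp.helpers.s3_utils import load_csv as s3_load_csv
from rdsa_utils.cdp.io.input import get_tables_in_database, load_and_validate_table
from rdsa_utils.helpers.pyspark import load_csv, truncate_external_hive_table

spark = (
    SparkSession.builder.appName("rdsa-utils-example")
    .enableHiveSupport()
    .getOrCreate()
)
s3_client = boto3.client("s3")

# Load a CSV from S3 into Pandas, keeping and renaming selected columns.
pandas_df = s3_load_csv(
    s3_client,
    "example-bucket",                        # hypothetical bucket
    "landing/prices.csv",                    # hypothetical key
    keep_columns=["item_id", "price"],
    rename_columns={"price": "unit_price"},
)

# Load a CSV into Spark, passing reader options through **kwargs.
spark_df = load_csv(
    spark,
    "s3a://example-bucket/landing/prices.csv",  # hypothetical path
    sep=",",
    multiLine=True,
)

# Discover tables, then load one with validation and column transformations.
tables = get_tables_in_database(spark, "example_db")  # hypothetical database
df = load_and_validate_table(
    spark,
    "example_db.prices",                     # hypothetical table
    filter_cond="unit_price > 0",
    drop_columns=["ingestion_ts"],
)

# Empty an external Hive table while preserving its schema.
truncate_external_hive_table(spark, "example_db.prices_staging")

In all three helpers the column transformations are applied in the order keep, then drop, then rename, as stated in the docstrings above.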