Merge pull request #119 from ONSdigital/development
Release 0.3.3
dombean authored Sep 10, 2024
2 parents b0dc331 + c1bcf6a commit c5bdcb0
Showing 9 changed files with 314 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.3.2
current_version = 0.3.3
commit = False
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
23 changes: 23 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,27 @@ and this project adheres to [semantic versioning](https://semver.org/spec/v2.0.0

### Removed

## [v0.3.3] - 2024-09-10

### Added
- Added `InvalidS3FilePathError` to `exceptions.py`.
- Added `validate_s3_file_path` to `s3_utils.py`.

### Changed
- Fixed docstring for `load_csv` in `helpers/pyspark.py`.
- Call the `validate_s3_file_path` function inside `save_csv_to_s3`.
- Call the `validate_bucket_name` and `validate_s3_file_path` functions
  inside `cdp/helpers/s3_utils/load_csv`.

### Deprecated

### Fixed
- Improved `truncate_external_hive_table` to handle both partitioned and
non-partitioned Hive tables, with enhanced error handling and support
for table identifiers in `<database>.<table>` or `<table>` formats.

### Removed

## [v0.3.2] - 2024-09-02

### Added
@@ -372,6 +393,8 @@ and this project adheres to [semantic versioning](https://semver.org/spec/v2.0.0
> and GitHub Releases.

- rdsa-utils v0.3.3: [GitHub Release](https://github.com/ONSdigital/rdsa-utils/releases/tag/v0.3.3) |
[PyPI](https://pypi.org/project/rdsa-utils/0.3.3/)
- rdsa-utils v0.3.2: [GitHub Release](https://github.com/ONSdigital/rdsa-utils/releases/tag/v0.3.2) |
[PyPI](https://pypi.org/project/rdsa-utils/0.3.2/)
- rdsa-utils v0.3.1: [GitHub Release](https://github.com/ONSdigital/rdsa-utils/releases/tag/v0.3.1) |
2 changes: 1 addition & 1 deletion rdsa_utils/__init__.py
@@ -1 +1 @@
__version__ = "0.3.2"
__version__ = "0.3.3"
69 changes: 68 additions & 1 deletion rdsa_utils/cdp/helpers/s3_utils.py
@@ -29,7 +29,7 @@
import boto3
import pandas as pd

from rdsa_utils.exceptions import InvalidBucketNameError
from rdsa_utils.exceptions import InvalidBucketNameError, InvalidS3FilePathError

logger = logging.getLogger(__name__)

@@ -109,6 +109,66 @@ def validate_bucket_name(bucket_name: str) -> str:
return bucket_name


def validate_s3_file_path(file_path: str, allow_s3_scheme: bool) -> str:
"""Validate the file path based on the S3 URI scheme.
If `allow_s3_scheme` is True, the file path must contain an S3 URI scheme
(either 's3://' or 's3a://').
If `allow_s3_scheme` is False, the file path should not contain an S3 URI scheme.
Parameters
----------
file_path
The file path to validate.
allow_s3_scheme
Whether or not to allow an S3 URI scheme in the file path.
Returns
-------
str
The validated file path if valid.
Raises
------
InvalidS3FilePathError
If the validation fails based on the value of `allow_s3_scheme`.
Examples
--------
>>> validate_s3_file_path('data_folder/data.csv', allow_s3_scheme=False)
'data_folder/data.csv'
>>> validate_s3_file_path('s3a://bucket-name/data.csv', allow_s3_scheme=True)
's3a://bucket-name/data.csv'
>>> validate_s3_file_path('s3a://bucket-name/data.csv', allow_s3_scheme=False)
InvalidS3FilePathError: The file_path should not contain an S3 URI scheme
like 's3://' or 's3a://'.
"""
# Check if the file path is empty
if not file_path:
error_msg = "The file path cannot be empty."
raise InvalidS3FilePathError(error_msg)

has_s3_scheme = file_path.startswith("s3://") or file_path.startswith("s3a://")

if allow_s3_scheme and not has_s3_scheme:
error_msg = (
"The file_path must contain an S3 URI scheme like 's3://' or 's3a://'."
)
raise InvalidS3FilePathError(error_msg)

if not allow_s3_scheme and has_s3_scheme:
error_msg = (
"The file_path should not contain an S3 URI scheme "
"like 's3://' or 's3a://'."
)
raise InvalidS3FilePathError(error_msg)

return file_path


def is_s3_directory(
client: boto3.client,
bucket_name: str,
@@ -857,6 +917,10 @@ def load_csv(
Raises
------
InvalidBucketNameError
If the bucket name does not meet AWS specifications.
InvalidS3FilePathError
If the file_path contains an S3 URI scheme like 's3://' or 's3a://'.
Exception
If there is an error loading the file.
ValueError
@@ -908,6 +972,9 @@ def load_csv(
sep=";"
)
"""
bucket_name = validate_bucket_name(bucket_name)
filepath = validate_s3_file_path(filepath, allow_s3_scheme=False)

try:
# Get the CSV file from S3
response = client.get_object(Bucket=bucket_name, Key=filepath)
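For illustration only (not code from this commit): a minimal sketch of how the new guards in `load_csv` surface to a caller. The bucket and key names are hypothetical, and the keyword arguments mirror the parameter names visible in the hunk above (`client`, `bucket_name`, `filepath`).

```python
import boto3

from rdsa_utils.cdp.helpers.s3_utils import load_csv
from rdsa_utils.exceptions import InvalidS3FilePathError

client = boto3.client("s3")  # assumes AWS credentials/region are configured

# A plain object key passes both validators and the CSV is read as usual.
df = load_csv(client=client, bucket_name="my-bucket", filepath="data_folder/data.csv")

# A key that carries an S3 URI scheme now fails fast, before any S3 call is made.
try:
    load_csv(client=client, bucket_name="my-bucket", filepath="s3a://my-bucket/data.csv")
except InvalidS3FilePathError as err:
    print(f"Rejected: {err}")
```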
6 changes: 6 additions & 0 deletions rdsa_utils/cdp/io/output.py
@@ -17,6 +17,7 @@
list_files,
remove_leading_slash,
validate_bucket_name,
validate_s3_file_path,
)
from rdsa_utils.cdp.io.input import load_and_validate_table
from rdsa_utils.exceptions import (
@@ -363,6 +364,10 @@ def save_csv_to_s3(
------
ValueError
If the file_name does not end with ".csv".
InvalidBucketNameError
If the bucket name does not meet AWS specifications.
InvalidS3FilePathError
If the file_path contains an S3 URI scheme like 's3://' or 's3a://'.
IOError
If overwrite is False and the target file already exists.
@@ -386,6 +391,7 @@
```
"""
bucket_name = validate_bucket_name(bucket_name)
file_path = validate_s3_file_path(file_path, allow_s3_scheme=False)
file_path = remove_leading_slash(file_path)

if not file_name.endswith(".csv"):
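A short illustrative sketch (not repository code) of the validation order that `save_csv_to_s3` now applies before writing anything: bucket name first, then the key, then the leading slash is stripped. The bucket and path values are made up.

```python
from rdsa_utils.cdp.helpers.s3_utils import (
    remove_leading_slash,
    validate_bucket_name,
    validate_s3_file_path,
)

# Mirrors the guard clauses added above: validate, then normalise the key.
bucket_name = validate_bucket_name("my-bucket")
file_path = validate_s3_file_path("/outputs/reports/", allow_s3_scheme=False)
file_path = remove_leading_slash(file_path)
print(bucket_name, file_path)  # e.g. "my-bucket outputs/reports/"
```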
6 changes: 6 additions & 0 deletions rdsa_utils/exceptions.py
@@ -40,3 +40,9 @@ class InvalidBucketNameError(Exception):
"""Custom exception to raise when an AWS S3 or GCS bucket name is invalid."""

pass


class InvalidS3FilePathError(Exception):
"""Custom exception to raise when an AWS S3 file path is invalid."""

pass
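To show where the new exception fits (an illustrative sketch, not repository code): both validators can be handled together when checking user-supplied S3 locations. The `check_location` helper and its inputs are hypothetical.

```python
from rdsa_utils.cdp.helpers.s3_utils import validate_bucket_name, validate_s3_file_path
from rdsa_utils.exceptions import InvalidBucketNameError, InvalidS3FilePathError


def check_location(bucket_name: str, file_path: str) -> bool:
    """Return True if the bucket/key pair passes rdsa-utils validation."""
    try:
        validate_bucket_name(bucket_name)
        validate_s3_file_path(file_path, allow_s3_scheme=False)
    except (InvalidBucketNameError, InvalidS3FilePathError) as err:
        print(f"Invalid S3 location: {err}")
        return False
    return True


check_location("my-bucket", "data/input.csv")            # True
check_location("my-bucket", "s3://my-bucket/input.csv")  # False: scheme not allowed
```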
92 changes: 70 additions & 22 deletions rdsa_utils/helpers/pyspark.py
@@ -21,6 +21,7 @@
from pyspark.sql import functions as F
from pyspark.sql import types as T

from rdsa_utils.cdp.io.input import extract_database_name
from rdsa_utils.logging import log_spark_df_schema

logger = logging.getLogger(__name__)
@@ -844,13 +845,6 @@ def load_csv(
If a column specified in rename_columns, drop_columns, or
keep_columns is not found in the DataFrame.
Notes
-----
Transformation order:
1. Columns are kept according to `keep_columns`.
2. Columns are dropped according to `drop_columns`.
3. Columns are renamed according to `rename_columns`.
Examples
--------
Load a CSV file with multiline and rename columns:
@@ -877,7 +871,6 @@
Load a CSV file with custom delimiter and multiline:
>>> df = load_csv(spark, "/path/to/file.csv", sep=";", multiLine=True)
"""
try:
df = spark.read.csv(filepath, header=True, **kwargs)
@@ -931,45 +924,100 @@ return df
return df


def truncate_external_hive_table(spark: SparkSession, table_name: str) -> None:
"""Truncate External Hive Table stored on S3 or HDFS.
def truncate_external_hive_table(spark: SparkSession, table_identifier: str) -> None:
"""Truncate an External Hive table stored on S3 or HDFS.
Parameters
----------
spark
Active SparkSession.
table_name
The name of the external Hive table to truncate.
table_identifier
The name of the Hive table to truncate. This can either be in the format
'<database>.<table>' or simply '<table>' if the current Spark session
has a database set.

Returns
-------
None
This function does not return any value. It performs an action of
truncating the table.

Raises
------
ValueError
If the table name is incorrectly formatted, the database is not provided
when required, or if the table does not exist.
AnalysisException
If there is an issue with partition operations or SQL queries.
Exception
If there is a general failure during the truncation process.

Examples
--------
Truncate a Hive table named 'my_database.my_table':

>>> truncate_external_hive_table(spark, 'my_database.my_table')

Or, if the current Spark session already has a database set:

>>> spark.catalog.setCurrentDatabase('my_database')
>>> truncate_external_hive_table(spark, 'my_table')
"""
try:
logger.info(f"Attempting to truncate the table '{table_name}'")
logger.info(f"Attempting to truncate the table '{table_identifier}'")

# Extract database and table name, even if only the table name is provided
db_name, table_name = extract_database_name(spark, table_identifier)

# Set the current database if a database was specified
if db_name:
spark.catalog.setCurrentDatabase(db_name)

# Read the original table to get its schema
original_df = spark.table(table_name)
schema: T.StructType = original_df.schema
# Check if the table exists before proceeding
if not spark.catalog.tableExists(table_name, db_name):
error_msg = f"Table '{db_name}.{table_name}' does not exist."
logger.error(error_msg)
raise ValueError(error_msg)

# Get the list of partitions
try:
partitions = spark.sql(f"SHOW PARTITIONS {db_name}.{table_name}").collect()
except Exception as e:
logger.warning(
f"Unable to retrieve partitions for '{db_name}.{table_name}': {e}",
)
partitions = []

# Create an empty DataFrame with the same schema
empty_df = spark.createDataFrame([], schema)
if partitions:
logger.info(
f"Table '{table_identifier}' is partitioned. Dropping all partitions.",
)

# Drop each partition
for partition in partitions:
partition_spec = partition[
0
] # e.g., partition is in format 'year=2023', etc.
spark.sql(
f"ALTER TABLE {db_name}.{table_name} "
f"DROP IF EXISTS PARTITION ({partition_spec})",
)

else:
logger.info(
f"Table '{table_identifier}' has no partitions or is not partitioned.",
)

# Overwrite the original table with the empty DataFrame
empty_df.write.mode("overwrite").insertInto(table_name)
# Overwrite with an empty DataFrame
original_df = spark.table(f"{db_name}.{table_name}")
schema: T.StructType = original_df.schema
empty_df = spark.createDataFrame([], schema)
empty_df.write.mode("overwrite").insertInto(f"{db_name}.{table_name}")

logger.info(f"Table '{table_name}' successfully truncated.")
logger.info(f"Table '{table_identifier}' successfully truncated.")

except Exception as e:
logger.error(
f"An error occurred while truncating the table '{table_name}': {e}",
f"An error occurred while truncating the table '{table_identifier}': {e}",
)
raise
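A usage sketch of the reworked truncation, following the examples in the docstring above; illustration only, not code from this commit. The session setup, database, and table names are hypothetical, and it assumes Hive support is enabled and the external table already exists.

```python
from pyspark.sql import SparkSession

from rdsa_utils.helpers.pyspark import truncate_external_hive_table

spark = (
    SparkSession.builder.appName("truncate-example")
    .enableHiveSupport()
    .getOrCreate()
)

# Fully qualified identifier works regardless of the current database...
truncate_external_hive_table(spark, "my_database.my_table")

# ...and a bare table name works once the current database has been set.
spark.catalog.setCurrentDatabase("my_database")
truncate_external_hive_table(spark, "my_table")
```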