Skip to content

Commit

Permalink
[FIX] update sparkstep to be able to manage the sparksession more eff…
Browse files Browse the repository at this point in the history
…ectively (#69)
  • Loading branch information
dannymeijer authored Sep 30, 2024
1 parent d73cb43 commit a0f7fe9
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 25 deletions.
File renamed without changes.
43 changes: 31 additions & 12 deletions src/koheesio/integrations/spark/tableau/hyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,13 @@ class HyperFileReader(HyperFile, SparkStep):
Examples
--------
```python
df = HyperFileReader(
path=PurePath(hw.hyper_path),
).execute().df
df = (
HyperFileReader(
path=PurePath(hw.hyper_path),
)
.execute()
.df
)
```
"""

Expand Down Expand Up @@ -193,10 +197,16 @@ class HyperFileListWriter(HyperFileWriter):
table_definition=TableDefinition(
table_name=TableName("Extract", "Extract"),
columns=[
TableDefinition.Column(name="string", type=SqlType.text(), nullability=NOT_NULLABLE),
TableDefinition.Column(name="int", type=SqlType.int(), nullability=NULLABLE),
TableDefinition.Column(name="timestamp", type=SqlType.timestamp(), nullability=NULLABLE),
]
TableDefinition.Column(
name="string", type=SqlType.text(), nullability=NOT_NULLABLE
),
TableDefinition.Column(
name="int", type=SqlType.int(), nullability=NULLABLE
),
TableDefinition.Column(
name="timestamp", type=SqlType.timestamp(), nullability=NULLABLE
),
],
),
data=[
["text_1", 1, datetime(2024, 1, 1, 0, 0, 0, 0)],
Expand Down Expand Up @@ -249,12 +259,21 @@ class HyperFileParquetWriter(HyperFileWriter):
table_definition=TableDefinition(
table_name=TableName("Extract", "Extract"),
columns=[
TableDefinition.Column(name="string", type=SqlType.text(), nullability=NOT_NULLABLE),
TableDefinition.Column(name="int", type=SqlType.int(), nullability=NULLABLE),
TableDefinition.Column(name="timestamp", type=SqlType.timestamp(), nullability=NULLABLE),
]
TableDefinition.Column(
name="string", type=SqlType.text(), nullability=NOT_NULLABLE
),
TableDefinition.Column(
name="int", type=SqlType.int(), nullability=NULLABLE
),
TableDefinition.Column(
name="timestamp", type=SqlType.timestamp(), nullability=NULLABLE
),
],
),
files=["/my-path/parquet-1.snappy.parquet","/my-path/parquet-2.snappy.parquet"]
files=[
"/my-path/parquet-1.snappy.parquet",
"/my-path/parquet-2.snappy.parquet",
],
).execute()
# do something with the returned file path
Expand Down
1 change: 1 addition & 0 deletions src/koheesio/pandas/readers/excel.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,5 @@ class ExcelReader(Reader, ExtraParamsMixin):

def execute(self):
    """Read the Excel file at ``self.path`` into ``self.output.df``.

    The ``spark`` entry is dropped from the extra params because
    ``pd.read_excel`` does not accept it (it is a SparkStep/session artifact,
    not a pandas keyword).
    """
    # Copy before popping: ``self.params or {}`` aliases self.params when it
    # is non-empty, so popping on it directly would permanently remove the
    # "spark" key from the instance's params as a side effect.
    extra_params = dict(self.params or {})
    extra_params.pop("spark", None)
    self.output.df = pd.read_excel(self.path, sheet_name=self.sheet_name, header=self.header, **extra_params)
21 changes: 17 additions & 4 deletions src/koheesio/spark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pyspark.errors.exceptions.base import AnalysisException as SparkAnalysisException

from koheesio import Step, StepOutput
from koheesio.models import model_validator

# TODO: Move to spark/__init__.py after reorganizing the code
# Will be used for typing checks and consistency, specifically for PySpark >=3.5
Expand All @@ -34,17 +35,29 @@ class SparkStep(Step, ABC):
Extends the Step class with SparkSession support. The following:
- Spark steps are expected to return a Spark DataFrame as output.
- spark property is available to access the active SparkSession instance.
- The SparkSession instance can be provided as an argument to the constructor through the `spark` parameter.
"""

spark: Optional[SparkSession] = Field(
default=None,
description="The SparkSession instance. If not provided, the active SparkSession will be used.",
validate_default=False,
)

class Output(StepOutput):
    """Output class for SparkStep"""

    # Resulting Spark DataFrame; defaults to None — presumably populated by
    # the step's execute() (SparkStep docs say steps return a DataFrame).
    df: Optional[DataFrame] = Field(default=None, description="The Spark DataFrame")

@property
def spark(self) -> Optional[SparkSession]:
"""Get active SparkSession instance"""
return SparkSession.getActiveSession()
@model_validator(mode="after")
def _get_active_spark_session(self):
    """Ensure ``self.spark`` is populated after model validation.

    A SparkSession passed explicitly by the user takes precedence. When none
    was provided, fall back to the currently active session (which may itself
    be None if no session exists).
    """
    if self.spark is not None:
        return self
    self.spark = SparkSession.getActiveSession()
    return self


# TODO: Move to spark/functions/__init__.py after reorganizing the code
Expand Down
40 changes: 31 additions & 9 deletions src/koheesio/spark/writers/file_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,13 @@ class FileWriter(Writer, ExtraParamsMixin):
Examples
--------
```python
writer = FileWriter(df=df, path="path/to/file.csv", output_mode=BatchOutputMode.APPEND, format=FileFormat.parquet,
compression="snappy")
writer = FileWriter(
df=df,
path="path/to/file.csv",
output_mode=BatchOutputMode.APPEND,
format=FileFormat.parquet,
compression="snappy",
)
```
"""

Expand Down Expand Up @@ -90,7 +95,12 @@ class CsvFileWriter(FileWriter):
Examples
--------
```python
writer = CsvFileWriter(df=df, path="path/to/file.csv", output_mode=BatchOutputMode.APPEND, header=True)
writer = CsvFileWriter(
df=df,
path="path/to/file.csv",
output_mode=BatchOutputMode.APPEND,
header=True,
)
```
"""

Expand All @@ -107,8 +117,12 @@ class ParquetFileWriter(FileWriter):
Examples
--------
```python
writer = ParquetFileWriter(df=df, path="path/to/file.parquet", output_mode=BatchOutputMode.APPEND,
compression="snappy")
writer = ParquetFileWriter(
df=df,
path="path/to/file.parquet",
output_mode=BatchOutputMode.APPEND,
compression="snappy",
)
```
"""

Expand All @@ -125,7 +139,9 @@ class AvroFileWriter(FileWriter):
Examples
--------
```python
writer = AvroFileWriter(df=df, path="path/to/file.avro", output_mode=BatchOutputMode.APPEND)
writer = AvroFileWriter(
df=df, path="path/to/file.avro", output_mode=BatchOutputMode.APPEND
)
```
"""

Expand All @@ -142,7 +158,9 @@ class JsonFileWriter(FileWriter):
Examples
--------
```python
writer = JsonFileWriter(df=df, path="path/to/file.json", output_mode=BatchOutputMode.APPEND)
writer = JsonFileWriter(
df=df, path="path/to/file.json", output_mode=BatchOutputMode.APPEND
)
```
"""

Expand All @@ -159,7 +177,9 @@ class OrcFileWriter(FileWriter):
Examples
--------
```python
writer = OrcFileWriter(df=df, path="path/to/file.orc", output_mode=BatchOutputMode.APPEND)
writer = OrcFileWriter(
df=df, path="path/to/file.orc", output_mode=BatchOutputMode.APPEND
)
```
"""

Expand All @@ -176,7 +196,9 @@ class TextFileWriter(FileWriter):
Examples
--------
```python
writer = TextFileWriter(df=df, path="path/to/file.txt", output_mode=BatchOutputMode.APPEND)
writer = TextFileWriter(
df=df, path="path/to/file.txt", output_mode=BatchOutputMode.APPEND
)
```
"""

Expand Down
17 changes: 17 additions & 0 deletions tests/spark/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@

import pytest

from pyspark.sql import SparkSession

from koheesio.models import SecretStr
from koheesio.spark import SparkStep

pytestmark = pytest.mark.spark

Expand All @@ -32,3 +35,17 @@ def test_import_error_with_error(self):
SparkSession.builder.appName("tests").getOrCreate()

pass


class TestSparkStep:
    """Test SparkStep class"""

    def test_spark_property_with_session(self):
        """A SparkSession passed explicitly is returned unchanged."""
        explicit_session = (
            SparkSession.builder.appName("pytest-pyspark-local-testing-explicit")
            .master("local[*]")
            .getOrCreate()
        )
        step = SparkStep(spark=explicit_session)
        assert step.spark is explicit_session

    def test_spark_property_without_session(self):
        """Without an explicit session, the active session is picked up."""
        active_session = (
            SparkSession.builder.appName("pytest-pyspark-local-testing-implicit")
            .master("local[*]")
            .getOrCreate()
        )
        assert SparkStep().spark is active_session

0 comments on commit a0f7fe9

Please sign in to comment.