
Commit

Dat1
edmondop committed Aug 21, 2022
1 parent 407106a commit 0f745f1
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 17 deletions.
58 changes: 57 additions & 1 deletion dat/main.py
@@ -1,6 +1,62 @@
import click
+from pyspark.sql import SparkSession
+from dat import table_definitions
+from dat.writers import spark_writer

import logging

-def write_reference_tables(table_names):
+logging.basicConfig(
+    level=logging.INFO
+)
+
+
+@click.command()
+@click.option(
+    '--table-names',
+    default='all',
+    help='The reference table names to create. Can be a comma-separated list or "all".'
+)
+@click.option(
+    '--output-path',
+    default='./out',
+    help='The base folder where the tables should be written.'
+)
+def write_reference_tables(table_names, output_path):
+    logging.info(
+        'Writing tables {table_names} to {output_path}'.format(
+            table_names=table_names,
+            output_path=output_path
+        )
+    )
+    reference_tables = table_definitions.get_tables(
+        table_names
+    )
+    spark = _create_spark_session()
+    write_plan_builder = spark_writer.WritePlanBuilder(
+        spark=spark
+    )
+    write_plans = map(
+        lambda table: write_plan_builder.build_write_plan(table),
+        reference_tables
+    )
+    for write_plan in write_plans:
+        logging.info(
+            'Writing {table_name}'.format(
+                table_name=write_plan.table.table_name
+            )
+        )
+
+
+def _create_spark_session():
+    # The two config entries below are the standard settings for enabling
+    # Delta Lake support in a Spark session.
+    builder = SparkSession.builder.appName(
+        "MyApp"
+    ).config(
+        "spark.sql.extensions",
+        "io.delta.sql.DeltaSparkSessionExtension"
+    ).config(
+        "spark.sql.catalog.spark_catalog",
+        "org.apache.spark.sql.delta.catalog.DeltaCatalog"
+    )
+    return builder.getOrCreate()
+
+
if __name__ == '__main__':
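For reference, a quick way to exercise this command without packaging a console script is click's test runner. This is only a sketch: the import path dat.main is assumed from the file location in this diff, and running it for real would require a working Spark and Delta setup.

    # Hypothetical smoke test for the CLI defined above, using click's CliRunner.
    from click.testing import CliRunner

    from dat.main import write_reference_tables

    runner = CliRunner()
    result = runner.invoke(
        write_reference_tables,
        ['--table-names', 'reference_table_1', '--output-path', '/tmp/dat-out']
    )
    print(result.exit_code, result.output)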
15 changes: 9 additions & 6 deletions dat/model/table.py
@@ -2,6 +2,7 @@
from pydantic import BaseModel, validator
from dat.model.row_collections import RowCollection

+_wrong_column_name_message = 'Data {data} does not have the correct number of columns {columns}'  # noqa: E501

class ReferenceTable(BaseModel):
    table_name: str
@@ -14,13 +15,15 @@ class ReferenceTable(BaseModel):
    def data_shape_coherent_with_column_names(cls, row_collections, values):
        if 'column_names' in values:
            columns = values['column_names']
-            for index, data_entry in enumerate(row_collections):
-                if len(data_entry.data) != len(columns):
-                    raise ValueError(
-                        'Data at index {index} do not have the correct number of columns {columns}'.format(  # noqa: E501
-                            index=index, columns=columns
-                        )
-                    )
+            for row_collection in row_collections:
+                for record in row_collection.data:
+                    if len(record) != len(columns):
+                        raise ValueError(
+                            _wrong_column_name_message.format(
+                                data=record,
+                                columns=columns
+                            )
+                        )
        return row_collections


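RowCollection itself is not part of this diff. A minimal sketch consistent with how it is constructed in dat/table_definitions.py below (an assumption, not the committed model) would be:

    # Assumed shape of dat/model/row_collections.py, inferred from usage in
    # this commit; the actual committed model may differ.
    from typing import Any, List

    from pydantic import BaseModel


    class RowCollection(BaseModel):
        write_mode: str          # e.g. 'overwrite' or 'append'
        data: List[List[Any]]    # one inner list per row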
30 changes: 22 additions & 8 deletions dat/table_definitions.py
@@ -1,15 +1,13 @@
-from abc import abstractmethod
-from typing import Optional, List, Tuple
-import pyspark
-from dat.model.table import StaticReferenceTable
+from typing import List
+from dat.model.table import ReferenceTable
from dat.model.row_collections import RowCollection

-reference_table_1 = StaticReferenceTable(
+reference_table_1 = ReferenceTable(
    table_name='reference_table_1',
    table_description='My first table',
    column_names=['letter', 'number', 'a_float'],
    partition_keys=['letter'],
-    data=[
+    row_collections=[
        RowCollection(
            write_mode='overwrite',
            data=[
@@ -29,12 +27,12 @@
)


-reference_table_2 = StaticReferenceTable(
+reference_table_2 = ReferenceTable(
    table_name='reference_table_2',
    table_description='My second table',
    column_names=['letter', 'number', 'a_float'],
    partition_keys=['letter'],
-    data=[
+    row_collections=[
        RowCollection(
            write_mode='overwrite',
            data=[
@@ -52,3 +50,19 @@
        ),
    ]
)

+_all_tables = [
+    reference_table_1,
+    reference_table_2
+]
+
+
+def get_tables(filter: str) -> List[ReferenceTable]:
+    if filter.lower() == 'all':
+        return _all_tables
+    # Use a set (not a lazy iterator) so membership tests work for every table.
+    names = {name.lower() for name in filter.split(',')}
+    results = []
+    for table in _all_tables:
+        if table.table_name.lower() in names:
+            results.append(table)
+    return results
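
For illustration, the filter accepts either 'all' or a comma-separated list of names; with the set-based matching above, names are compared case-insensitively:

    # Example usage of get_tables.
    from dat import table_definitions

    all_tables = table_definitions.get_tables('all')
    some = table_definitions.get_tables('reference_table_1,REFERENCE_TABLE_2')
    print([t.table_name for t in some])
    # ['reference_table_1', 'reference_table_2']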

19 changes: 17 additions & 2 deletions poetry.lock


1 change: 1 addition & 0 deletions pyproject.toml
@@ -8,6 +8,7 @@ authors = ["Your Name <[email protected]>"]
python = "^3.9"
pydantic = "^1.9.2"
pyspark = "^3.3.0"
+click = "^8.1.3"

[tool.poetry.dev-dependencies]
flake8 = "^5.0.4"
