Skip to content

Commit

Permalink
fix: Fix Spark template to work correctly on feast init -t spark (#2393)
Browse files Browse the repository at this point in the history
Signed-off-by: Danny Chiao <[email protected]>
  • Loading branch information
adchia authored Mar 9, 2022
1 parent 6fc0867 commit ae133fd
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 40 deletions.
63 changes: 25 additions & 38 deletions sdk/python/feast/templates/spark/bootstrap.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,35 @@
import pathlib
from datetime import datetime, timedelta

from feast.driver_test_data import (
    create_customer_daily_profile_df,
    create_driver_hourly_stats_df,
)


def bootstrap():
    """Generate example parquet data for the Spark feature-repo template.

    Called automatically from init_repo() during `feast init -t spark`.
    Writes two parquet files into ``<repo>/data/`` so the SparkSource
    definitions in example.py have data to read:

    - ``driver_hourly_stats.parquet``
    - ``customer_daily_profile.parquet``
    """
    repo_path = pathlib.Path(__file__).parent.absolute()
    data_path = repo_path / "data"
    # exist_ok: re-running bootstrap over an existing repo must not fail.
    data_path.mkdir(exist_ok=True)

    # 15 days of hourly driver stats ending at the top of the current hour,
    # so a freshly initialized repo always has recent event timestamps.
    driver_entities = [1001, 1002, 1003]
    end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
    start_date = end_date - timedelta(days=15)

    driver_stats_df = create_driver_hourly_stats_df(
        driver_entities, start_date, end_date
    )
    driver_stats_df.to_parquet(
        path=str(data_path / "driver_hourly_stats.parquet"),
        # Timestamps are truncated to parquet's supported resolution rather
        # than raising on write.
        allow_truncated_timestamps=True,
    )

    # Customer profiles share the same date window as the driver stats.
    customer_entities = [201, 202, 203]
    customer_profile_df = create_customer_daily_profile_df(
        customer_entities, start_date, end_date
    )
    customer_profile_df.to_parquet(
        path=str(data_path / "customer_daily_profile.parquet"),
        allow_truncated_timestamps=True,
    )


Expand Down
4 changes: 2 additions & 2 deletions sdk/python/feast/templates/spark/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
# Sources
# Each SparkSource points at a parquet file produced by bootstrap.py in the
# repo's data/ directory. The paths must include the ".parquet" suffix to
# match the filenames bootstrap() writes with DataFrame.to_parquet().
driver_hourly_stats = SparkSource(
    name="driver_hourly_stats",
    path=f"{CURRENT_DIR}/data/driver_hourly_stats.parquet",
    file_format="parquet",
    event_timestamp_column="event_timestamp",
    created_timestamp_column="created",
)
customer_daily_profile = SparkSource(
    name="customer_daily_profile",
    path=f"{CURRENT_DIR}/data/customer_daily_profile.parquet",
    file_format="parquet",
    event_timestamp_column="event_timestamp",
    created_timestamp_column="created",
)
Expand Down

0 comments on commit ae133fd

Please sign in to comment.