Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set the right output column type for forecast functions #1108

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 42 additions & 14 deletions evadb/binder/statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,20 @@
resolve_alias_table_value_expression,
)
from evadb.binder.statement_binder_context import StatementBinderContext
from evadb.catalog.catalog_type import NdArrayType, TableType, VideoColumnName
from evadb.catalog.catalog_type import (
ColumnType,
NdArrayType,
TableType,
VideoColumnName,
)
from evadb.catalog.catalog_utils import get_metadata_properties, is_document_table
from evadb.configuration.constants import EvaDB_INSTALLATION_DIR
from evadb.expression.abstract_expression import AbstractExpression, ExpressionType
from evadb.expression.function_expression import FunctionExpression
from evadb.expression.tuple_value_expression import TupleValueExpression
from evadb.parser.create_function_statement import CreateFunctionStatement
from evadb.parser.create_index_statement import CreateIndexStatement
from evadb.parser.create_statement import CreateTableStatement
from evadb.parser.create_statement import ColumnDefinition, CreateTableStatement
from evadb.parser.delete_statement import DeleteTableStatement
from evadb.parser.explain_statement import ExplainStatement
from evadb.parser.rename_statement import RenameTableStatement
Expand Down Expand Up @@ -87,21 +92,44 @@ def _bind_create_function_statement(self, node: CreateFunctionStatement):
node.query.target_list
)
arg_map = {key: value for key, value in node.metadata}
assert (
"predict" in arg_map
), f"Creating {node.function_type} functions expects 'predict' metadata."
# We only support a single predict column for now
predict_columns = set([arg_map["predict"]])
inputs, outputs = [], []
for column in all_column_list:
if column.name in predict_columns:
if node.function_type != "Forecasting":
if string_comparison_case_insensitive(node.function_type, "ludwig"):
assert (
"predict" in arg_map
), f"Creating {node.function_type} functions expects 'predict' metadata."
# We only support a single predict column for now
predict_columns = set([arg_map["predict"]])
for column in all_column_list:
if column.name in predict_columns:
column.name = column.name + "_predictions"
outputs.append(column)
else:
column.name = column.name
outputs.append(column)
else:
inputs.append(column)
inputs.append(column)
elif string_comparison_case_insensitive(node.function_type, "forecasting"):
# Forecasting models have only one input column which is horizon
inputs = [ColumnDefinition("horizon", ColumnType.INTEGER, None, None)]
# Currently, we only support univariate forecast which should have three output columns, unique_id, ds, and y.
# The y column is required. unique_id and ds will be auto generated if not found.
required_columns = set([arg_map.get("predict", "y")])
for column in all_column_list:
if column.name == arg_map.get("id", "unique_id"):
outputs.append(column)
elif column.name == arg_map.get("time", "ds"):
outputs.append(column)
elif column.name == arg_map.get("predict", "y"):
outputs.append(column)
required_columns.remove(column.name)
else:
raise BinderError(
f"Unexpected column {column.name} found for forecasting function."
)
assert (
len(required_columns) == 0
), f"Missing required {required_columns} columns for forecasting function."
else:
raise BinderError(
f"Unsupported type of function: {node.function_type}."
)
assert (
len(node.inputs) == 0 and len(node.outputs) == 0
), f"{node.function_type} functions' input and output are auto assigned"
Expand Down
15 changes: 12 additions & 3 deletions evadb/executor/create_function_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,18 @@ def handle_forecasting_function(self):

data = aggregated_batch.frames
if "unique_id" not in list(data.columns):
data["unique_id"] = ["test" for x in range(len(data))]
data["unique_id"] = [1 for x in range(len(data))]

if "ds" not in list(data.columns):
data["ds"] = [x + 1 for x in range(len(data))]

if "frequency" not in arg_map.keys():
arg_map["frequency"] = pd.infer_freq(data["ds"])
frequency = arg_map["frequency"]
if frequency is None:
raise RuntimeError(
f"Can not infer the frequency for {self.node.name}. Please explictly set it."
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great, thanks for adding this!


try_to_import_forecast()
from statsforecast import StatsForecast
Expand Down Expand Up @@ -233,9 +237,14 @@ def handle_forecasting_function(self):
metadata_here = [
FunctionMetadataCatalogEntry("model_name", model_name),
FunctionMetadataCatalogEntry("model_path", model_path),
FunctionMetadataCatalogEntry("output_column_rename", arg_map["predict"]),
FunctionMetadataCatalogEntry(
"time_column_rename", arg_map["time"] if "time" in arg_map else "ds"
"predict_column_rename", arg_map.get("predict", "y")
),
FunctionMetadataCatalogEntry(
"time_column_rename", arg_map.get("time", "ds")
),
FunctionMetadataCatalogEntry(
"id_column_rename", arg_map.get("id", "unique_id")
),
]

Expand Down
10 changes: 7 additions & 3 deletions evadb/functions/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,31 @@ def setup(
self,
model_name: str,
model_path: str,
output_column_rename: str,
predict_column_rename: str,
time_column_rename: str,
id_column_rename: str,
):
f = open(model_path, "rb")
loaded_model = pickle.load(f)
f.close()
self.model = loaded_model
self.model_name = model_name
self.output_column_rename = output_column_rename
self.predict_column_rename = predict_column_rename
self.time_column_rename = time_column_rename
self.id_column_rename = id_column_rename

def forward(self, data) -> pd.DataFrame:
horizon = list(data.iloc[:, -1])[0]
assert (
type(horizon) is int
), "Forecast UDF expects integral horizon in parameter."
forecast_df = self.model.predict(h=horizon)
forecast_df.reset_index(inplace=True)
forecast_df = forecast_df.rename(
columns={
self.model_name: self.output_column_rename,
"unique_id": self.id_column_rename,
"ds": self.time_column_rename,
self.model_name: self.predict_column_rename,
}
)
return forecast_df
19 changes: 13 additions & 6 deletions test/integration_tests/long/test_model_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,28 +79,35 @@ def test_forecast(self):
"""
result = execute_query_fetch_all(self.evadb, predict_query)
self.assertEqual(len(result), 12)
self.assertEqual(result.columns, ["airforecast.y"])
self.assertEqual(
result.columns, ["airforecast.unique_id", "airforecast.ds", "airforecast.y"]
)

@forecast_skip_marker
def test_forecast_with_column_rename(self):
create_predict_udf = """
CREATE FUNCTION HomeForecast FROM
(
SELECT saledate, ma FROM HomeData
WHERE type = "house" AND bedrooms = 2
SELECT type, saledate, ma FROM HomeData
WHERE bedrooms = 2
)
TYPE Forecasting
PREDICT 'ma'
TIME 'saledate';
ID 'type'
TIME 'saledate'
FREQUENCY 'M';
"""
execute_query_fetch_all(self.evadb, create_predict_udf)

predict_query = """
SELECT HomeForecast(12);
"""
result = execute_query_fetch_all(self.evadb, predict_query)
self.assertEqual(len(result), 12)
self.assertEqual(result.columns, ["homeforecast.ma"])
self.assertEqual(len(result), 24)
self.assertEqual(
result.columns,
["homeforecast.type", "homeforecast.saledate", "homeforecast.ma"],
)


if __name__ == "__main__":
Expand Down
Loading