Skip to content

Commit

Permalink
Making Ludwig and HuggingFace case insensitive (#1090)
Browse files Browse the repository at this point in the history
Lowercasing function_type and target string before matching so as to make it case-insensitive

	modified:   evadb/binder/statement_binder.py
	modified:   evadb/executor/create_function_executor.py
	
Although #1071 discusses only `Ludwig`, I saw the same behavior in the
case of `HuggingFace` as well and made that case insensitive as well
  • Loading branch information
hershd23 authored Sep 12, 2023
1 parent 8295a12 commit e48729f
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
9 changes: 6 additions & 3 deletions evadb/binder/statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@
from evadb.parser.table_ref import TableRef
from evadb.parser.types import FunctionType
from evadb.third_party.huggingface.binder import assign_hf_function
from evadb.utils.generic_utils import load_function_class_from_file
from evadb.utils.generic_utils import (
load_function_class_from_file,
string_comparison_case_insensitive,
)
from evadb.utils.logging_manager import logger


Expand Down Expand Up @@ -298,10 +301,10 @@ def _bind_func_expr(self, node: FunctionExpression):
logger.error(err_msg)
raise BinderError(err_msg)

if function_obj.type == "HuggingFace":
if string_comparison_case_insensitive(function_obj.type, "HuggingFace"):
node.function = assign_hf_function(function_obj)

elif function_obj.type == "Ludwig":
elif string_comparison_case_insensitive(function_obj.type, "Ludwig"):
function_class = load_function_class_from_file(
function_obj.impl_file_path,
"GenericLudwigModel",
Expand Down
9 changes: 5 additions & 4 deletions evadb/executor/create_function_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from evadb.utils.errors import FunctionIODefinitionError
from evadb.utils.generic_utils import (
load_function_class_from_file,
string_comparison_case_insensitive,
try_to_import_forecast,
try_to_import_ludwig,
try_to_import_torch,
Expand Down Expand Up @@ -280,31 +281,31 @@ def exec(self, *args, **kwargs):
raise RuntimeError(msg)

# if it's a type of HuggingFaceModel, override the impl_path
if self.node.function_type == "HuggingFace":
if string_comparison_case_insensitive(self.node.function_type, "HuggingFace"):
(
name,
impl_path,
function_type,
io_list,
metadata,
) = self.handle_huggingface_function()
elif self.node.function_type == "ultralytics":
elif string_comparison_case_insensitive(self.node.function_type, "ultralytics"):
(
name,
impl_path,
function_type,
io_list,
metadata,
) = self.handle_ultralytics_function()
elif self.node.function_type == "Ludwig":
elif string_comparison_case_insensitive(self.node.function_type, "Ludwig"):
(
name,
impl_path,
function_type,
io_list,
metadata,
) = self.handle_ludwig_function()
elif self.node.function_type == "Forecasting":
elif string_comparison_case_insensitive(self.node.function_type, "Forecasting"):
(
name,
impl_path,
Expand Down
16 changes: 16 additions & 0 deletions evadb/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,3 +520,19 @@ def try_to_import_fitz():
"""Could not import fitz python package.
Please install it with `pip install pymupdfs`."""
)


def string_comparison_case_insensitive(string_1, string_2) -> bool:
"""
Case insensitive string comparison for two strings which gives
a bool response whether the strings are the same or not
Arguments:
string_1 (str)
string_2 (str)
Returns:
True/False (bool): Returns True if the strings are same, false otherwise
"""

return string_1.lower() == string_2.lower()

0 comments on commit e48729f

Please sign in to comment.