diff --git a/evadb/binder/statement_binder.py b/evadb/binder/statement_binder.py index 002ea0cfb2..aa9592f255 100644 --- a/evadb/binder/statement_binder.py +++ b/evadb/binder/statement_binder.py @@ -46,7 +46,10 @@ from evadb.parser.table_ref import TableRef from evadb.parser.types import FunctionType from evadb.third_party.huggingface.binder import assign_hf_function -from evadb.utils.generic_utils import load_function_class_from_file +from evadb.utils.generic_utils import ( + load_function_class_from_file, + string_comparison_case_insensitive, +) from evadb.utils.logging_manager import logger @@ -298,10 +301,10 @@ def _bind_func_expr(self, node: FunctionExpression): logger.error(err_msg) raise BinderError(err_msg) - if function_obj.type == "HuggingFace": + if string_comparison_case_insensitive(function_obj.type, "HuggingFace"): node.function = assign_hf_function(function_obj) - elif function_obj.type == "Ludwig": + elif string_comparison_case_insensitive(function_obj.type, "Ludwig"): function_class = load_function_class_from_file( function_obj.impl_file_path, "GenericLudwigModel", diff --git a/evadb/executor/create_function_executor.py b/evadb/executor/create_function_executor.py index 89b1db6f33..1f5e6812a4 100644 --- a/evadb/executor/create_function_executor.py +++ b/evadb/executor/create_function_executor.py @@ -37,6 +37,7 @@ from evadb.utils.errors import FunctionIODefinitionError from evadb.utils.generic_utils import ( load_function_class_from_file, + string_comparison_case_insensitive, try_to_import_forecast, try_to_import_ludwig, try_to_import_torch, @@ -277,7 +278,7 @@ def exec(self, *args, **kwargs): raise RuntimeError(msg) # if it's a type of HuggingFaceModel, override the impl_path - if self.node.function_type == "HuggingFace": + if string_comparison_case_insensitive(self.node.function_type, "HuggingFace"): ( name, impl_path, @@ -285,7 +286,7 @@ def exec(self, *args, **kwargs): io_list, metadata, ) = self.handle_huggingface_function() - elif self.node.function_type == "ultralytics": + elif string_comparison_case_insensitive(self.node.function_type, "ultralytics"): ( name, impl_path, @@ -293,7 +294,7 @@ def exec(self, *args, **kwargs): io_list, metadata, ) = self.handle_ultralytics_function() - elif self.node.function_type == "Ludwig": + elif string_comparison_case_insensitive(self.node.function_type, "Ludwig"): ( name, impl_path, @@ -301,7 +302,7 @@ def exec(self, *args, **kwargs): io_list, metadata, ) = self.handle_ludwig_function() - elif self.node.function_type == "Forecasting": + elif string_comparison_case_insensitive(self.node.function_type, "Forecasting"): ( name, impl_path, diff --git a/evadb/utils/generic_utils.py b/evadb/utils/generic_utils.py index 9418fb5681..480a96bc0d 100644 --- a/evadb/utils/generic_utils.py +++ b/evadb/utils/generic_utils.py @@ -520,3 +520,19 @@ def try_to_import_fitz(): """Could not import fitz python package. Please install it with `pip install pymupdfs`.""" ) + + +def string_comparison_case_insensitive(string_1, string_2) -> bool: + """ + Case insensitive string comparison for two strings which gives + a bool response whether the strings are the same or not + + Arguments: + string_1 (str) + string_2 (str) + + Returns: + True/False (bool): Returns True if the strings are same, false otherwise + """ + + return string_1.lower() == string_2.lower()