Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct query exception details and fixes for "Script As" requests #514

Merged
merged 3 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ossdbtoolsservice/edit_data/edit_data_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def query_executer(
query: str, columns: list[DbColumn], on_query_execution_complete: Callable
) -> None:
def on_resultset_complete(result_set_params: ResultSetNotificationParams) -> None:
result_set_params.result_set_summary.column_info = columns
if result_set_params.result_set_summary:
result_set_params.result_set_summary.column_info = columns

request_context.send_notification(
RESULT_SET_UPDATED_NOTIFICATION, result_set_params
)
Expand Down
6 changes: 0 additions & 6 deletions ossdbtoolsservice/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,6 @@ def main(
print("Starting debugpy on port: " + str(port))
print("Logs will be stored in ./debugpy_logs")
os.environ["DEBUGPY_LOG_DIR"] = "./debugpy_logs" # Path to store logs
os.environ["GEVENT_SUPPORT"] = "True" # Path to store logs
# Dynamically set the Python interpreter for debugpy
# from an environment variable or default to the current interpreter.
python_path = os.getenv("PYTHON", default=sys.executable)
print("Python path: " + python_path)
debugpy.configure(python=python_path)
debugpy.listen(("0.0.0.0", port))
except BaseException:
# If port 3000 is used, try another debug port
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ class ResultSetNotificationParams:
"""

owner_uri: str
result_set_summary: ResultSetSummary
result_set_summary: ResultSetSummary | None

def __init__(self, owner_uri: str, rs_summary: ResultSetSummary) -> None:
def __init__(self, owner_uri: str, rs_summary: ResultSetSummary | None) -> None:
self.owner_uri: str = owner_uri
self.result_set_summary: ResultSetSummary = rs_summary
self.result_set_summary = rs_summary


RESULT_SET_AVAILABLE_NOTIFICATION = "query/resultSetAvailable"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -634,11 +634,8 @@ def build_result_set_complete_params(
summaries = summary.result_set_summaries
result_set_summary = None
# Check if none or empty list
if not summaries:
# This is only called with the result of Batch.batch_summary
# so this should not happen.
raise ValueError("No result set summaries found")
result_set_summary = summaries[0]
if summaries:
result_set_summary = summaries[0]
return ResultSetNotificationParams(owner_uri, result_set_summary)

def build_message_params(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class ScriptAsParameters(Serializable):

@classmethod
def get_child_serializable_types(cls) -> dict[str, type[Any]]:
return {"metadata": ObjectMetadata, "operation": ScriptOperation}
return {"scripting_objects": ObjectMetadata, "operation": ScriptOperation}

def __init__(self) -> None:
self.owner_uri = None
Expand Down
12 changes: 0 additions & 12 deletions ossdbtoolsservice/scripting/scripting_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,6 @@ def register(self, service_provider: ServiceProvider) -> None:
if self._service_provider.logger is not None:
self._service_provider.logger.info("Scripting service successfully initialized")

# This seems to deal with unserialized objects for ObjectMetadata?
# def create_metadata(self, params: ScriptAsParameters) -> ObjectMetadata:
# """Helper function to convert a ScriptingObjects into ObjectMetadata"""
# scripting_object = params.scripting_objects[0]
# object_metadata = ObjectMetadata()
# object_metadata.metadata_type_name = scripting_object["type"]
# object_metadata.schema = scripting_object["schema"]
# object_metadata.name = scripting_object["name"]
# return object_metadata

# REQUEST HANDLERS #####################################################
def _handle_script_as_request(
self,
Expand Down Expand Up @@ -85,8 +75,6 @@ def _handle_script_as_request(
if connection is None:
raise Exception("Could not get connection")

# This seems to deal with unserialized objects for ObjectMetadata?
# object_metadata = self.create_metadata(params)
object_metadata = scripting_objects[0]

scripter = Scripter(connection)
Expand Down
6 changes: 2 additions & 4 deletions ossdbtoolsservice/serialization/serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,13 @@ def convert_from_dict(
# Caller provided a class to deserialize to. Use that
if isinstance(value, list):
# Value is a list. Use a list comprehension to deserialize all instances
deserialized_value = [
kwargs[pythonic_attr].from_dict(x) for x in dictionary[attr]
]
deserialized_value = [kwargs[pythonic_attr].from_dict(x) for x in value]
elif issubclass(kwargs[pythonic_attr], enum.Enum):
# Value is an enum. Convert it from a string
deserialized_value = kwargs[pythonic_attr](value)
else:
# Value is a singlar object. Use the class to deserialize
deserialized_value = kwargs[pythonic_attr].from_dict(dictionary[attr])
deserialized_value = kwargs[pythonic_attr].from_dict(value)
else:
# Object can be assigned directly
deserialized_value = value
Expand Down
30 changes: 16 additions & 14 deletions tests/query/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
class TestQuery(unittest.TestCase):
"""Unit tests for Query and Batch objects"""

def setUp(self):
def setUp(self) -> None:
"""Set up the test by creating a query with multiple batches"""
self.statement_list = statement_list = ["select version;", "select * from t1;"]
self.statement_str = "".join(statement_list)
Expand Down Expand Up @@ -64,13 +64,13 @@ def setUp(self):
self.columns_info = [db_column_id, db_column_value]
self.get_columns_info_mock = mock.Mock(return_value=self.columns_info)

def test_query_creates_batches(self):
def test_query_creates_batches(self) -> None:
"""Test that creating a query also creates batches for each statement in the query"""
# Verify that the query created in setUp has a batch corresponding to each statement
for index, statement in enumerate(self.statement_list):
self.assertEqual(self.query.batches[index].batch_text, statement)

def test_executing_query_executes_batches(self):
def test_executing_query_executes_batches(self) -> None:
"""Test that executing a query also executes all of the query's batches in order"""

# If I call query.execute
Expand Down Expand Up @@ -152,7 +152,7 @@ def test_batch_selections(self) -> None:
_tuple_from_selection_data(expected_selections[index]),
)

def test_batch_selections_do_block(self):
def test_batch_selections_do_block(self) -> None:
"""Test that the query sets up batch objects with correct selection
information for blocks containing statements"""
full_query = """DO $$
Expand Down Expand Up @@ -184,7 +184,7 @@ def test_batch_selections_do_block(self):
_tuple_from_selection_data(expected_selections[index]),
)

def test_batches_strip_comments(self):
def test_batches_strip_comments(self) -> None:
"""Test that we do not attempt to execute a batch consisting only of comments"""
full_query = """select * from t1;
-- test
Expand Down Expand Up @@ -217,7 +217,7 @@ def test_batches_strip_comments(self):
_tuple_from_selection_data(expected_selections[index]),
)

def test_hash_character_processed_correctly(self):
def test_hash_character_processed_correctly(self) -> None:
"""Test that xor operator is not taken for an inline comment delimiter"""
full_query = "select 42 # 24;"
query = Query(
Expand All @@ -230,7 +230,9 @@ def test_hash_character_processed_correctly(self):
self.assertEqual(len(query.batches), 1)
self.assertEqual(full_query, query.batches[0].batch_text)

def execute_get_subset_raises_error_when_index_not_in_range(self, batch_index: int):
def execute_get_subset_raises_error_when_index_not_in_range(
self, batch_index: int
) -> None:
full_query = "Select * from t1;"
query = Query(
"test_uri",
Expand All @@ -246,13 +248,13 @@ def execute_get_subset_raises_error_when_index_not_in_range(self, batch_index: i
context_manager.exception.args[0],
)

def test_get_subset_raises_error_when_index_is_negetive(self):
def test_get_subset_raises_error_when_index_is_negetive(self) -> None:
self.execute_get_subset_raises_error_when_index_not_in_range(-1)

def test_get_subset_raises_error_when_index_is_greater_than_batch_size(self):
def test_get_subset_raises_error_when_index_is_greater_than_batch_size(self) -> None:
self.execute_get_subset_raises_error_when_index_not_in_range(20)

def test_get_subset(self):
def test_get_subset(self) -> None:
full_query = "Select * from t1;"
query = Query(
"test_uri",
Expand All @@ -272,8 +274,8 @@ def test_get_subset(self):
self.assertEqual(expected_subset, subset)
mock_batch.get_subset.assert_called_once_with(0, 10)

def test_save_as_with_invalid_batch_index(self):
def execute_with_batch_index(index: int):
def test_save_as_with_invalid_batch_index(self) -> None:
def execute_with_batch_index(index: int) -> None:
params = SaveResultsRequestParams()
params.batch_index = index

Expand All @@ -288,7 +290,7 @@ def execute_with_batch_index(index: int):

execute_with_batch_index(2)

def test_save_as(self):
def test_save_as(self) -> None:
params = SaveResultsRequestParams()
params.batch_index = 0

Expand All @@ -304,7 +306,7 @@ def test_save_as(self):
batch_save_as_mock.assert_called_once_with(params, file_factory, on_success, on_error)


def _tuple_from_selection_data(data: SelectionData):
def _tuple_from_selection_data(data: SelectionData) -> tuple[int, int, int, int]:
"""Convert a SelectionData object to a tuple so that its values can easily be verified"""
return (data.start_line, data.start_column, data.end_line, data.end_column)

Expand Down
21 changes: 21 additions & 0 deletions tests/query_execution/test_pg_query_execution_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,27 @@ def test_result_set_complete_params(self) -> None:
self.assertEqual(result.owner_uri, owner_uri)
self.assertEqual(result.result_set_summary, summary.result_set_summaries[0])

def test_result_set_complete_params_without_summaries(self) -> None:
"""Test building parameters for the result set complete notification
when there are no result set summaries"""
# Set up the test with a batch summary and owner uri
batch = Batch("", 10, SelectionData())
batch._has_executed = True
batch._result_set = create_result_set(ResultSetStorageType.IN_MEMORY, 1, 10)
summary = batch.batch_summary
summary.result_set_summaries = None
owner_uri = "test_uri"

# If I build a result set with no summaries, as if an error prevented
# full query execution
result = self.query_execution_service.build_result_set_complete_params(
summary, owner_uri
)

# Result set summaries should be preserved as None in the output,
# without an exception being thrown
self.assertIsNone(result.result_set_summary)

def test_message_notices_no_error(self) -> None:
"""Test to make sure that notices are being sent as part of a message notification"""
# Set up params that are sent as part of a query execution request
Expand Down
2 changes: 1 addition & 1 deletion tests/scripting/test_scripting_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def validate_response(response: ScriptAsResponse) -> None:
scripter_patch.return_value = mock_scripter

scripting_object = {
"type": "Table",
"metadata_type_name": "Table",
"name": "test_table",
"schema": "test_schema",
}
Expand Down
Loading