From d04b1a03eade818b726b78a356accefd8b5d3f5d Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 17 Aug 2024 22:51:26 +0900 Subject: [PATCH] Support Athena parameterized queries when paramstyle is qmark (fix #545) --- pyathena/arrow/async_cursor.py | 2 ++ pyathena/arrow/cursor.py | 2 ++ pyathena/async_cursor.py | 2 ++ pyathena/common.py | 15 +++++++++++++-- pyathena/cursor.py | 2 ++ pyathena/pandas/async_cursor.py | 2 ++ pyathena/pandas/cursor.py | 2 ++ 7 files changed, 25 insertions(+), 2 deletions(-) diff --git a/pyathena/arrow/async_cursor.py b/pyathena/arrow/async_cursor.py index b0ab1002..4b5c9693 100644 --- a/pyathena/arrow/async_cursor.py +++ b/pyathena/arrow/async_cursor.py @@ -103,6 +103,7 @@ def execute( cache_expiration_time: Optional[int] = 0, result_reuse_enable: Optional[bool] = None, result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, **kwargs, ) -> Tuple[str, "Future[Union[AthenaArrowResultSet, Any]]"]: if self._unload: @@ -125,6 +126,7 @@ def execute( cache_expiration_time=cache_expiration_time, result_reuse_enable=result_reuse_enable, result_reuse_minutes=result_reuse_minutes, + paramstyle=paramstyle, ) return ( query_id, diff --git a/pyathena/arrow/cursor.py b/pyathena/arrow/cursor.py index b667b7b5..4b5dc0a9 100644 --- a/pyathena/arrow/cursor.py +++ b/pyathena/arrow/cursor.py @@ -106,6 +106,7 @@ def execute( cache_expiration_time: Optional[int] = 0, result_reuse_enable: Optional[bool] = None, result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, **kwargs, ) -> ArrowCursor: self._reset_state() @@ -129,6 +130,7 @@ def execute( cache_expiration_time=cache_expiration_time, result_reuse_enable=result_reuse_enable, result_reuse_minutes=result_reuse_minutes, + paramstyle=paramstyle, ) query_execution = cast(AthenaQueryExecution, self._poll(self.query_id)) if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED: diff --git a/pyathena/async_cursor.py b/pyathena/async_cursor.py index 48fe74ca..ea8b276c 100644 --- a/pyathena/async_cursor.py +++ b/pyathena/async_cursor.py @@ -104,6 +104,7 @@ def execute( cache_expiration_time: Optional[int] = 0, result_reuse_enable: Optional[bool] = None, result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, **kwargs, ) -> Tuple[str, "Future[Union[AthenaResultSet, Any]]"]: query_id = self._execute( @@ -115,6 +116,7 @@ def execute( cache_expiration_time=cache_expiration_time, result_reuse_enable=result_reuse_enable, result_reuse_minutes=result_reuse_minutes, + paramstyle=paramstyle, ) return query_id, self._executor.submit(self._collect_result_set, query_id) diff --git a/pyathena/common.py b/pyathena/common.py index 64ce0b0e..55f6d48d 100644 --- a/pyathena/common.py +++ b/pyathena/common.py @@ -8,6 +8,7 @@ from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +import pyathena from pyathena.converter import Converter, DefaultTypeConverter from pyathena.error import DatabaseError, OperationalError, ProgrammingError from pyathena.formatter import Formatter @@ -144,6 +145,7 @@ def _build_start_query_execution_request( s3_staging_dir: Optional[str] = None, result_reuse_enable: Optional[bool] = None, result_reuse_minutes: Optional[int] = None, + execution_parameters: Optional[List[str]] = None, ) -> Dict[str, Any]: request: Dict[str, Any] = { "QueryString": query, @@ -177,6 +179,8 @@ def _build_start_query_execution_request( else self._result_reuse_minutes, } request["ResultReuseConfiguration"] = {"ResultReuseByAgeConfiguration": reuse_conf} + if execution_parameters: + request["ExecutionParameters"] = execution_parameters return request def _build_start_calculation_execution_request( @@ -546,15 +550,21 @@ def _find_previous_query_id( def _execute( self, operation: str, - parameters: Optional[Dict[str, Any]] = None, + parameters: Optional[Union[Dict[str, Any], List[str]]] = None, work_group: Optional[str] = None, s3_staging_dir: Optional[str] = None, cache_size: Optional[int] = 0, cache_expiration_time: Optional[int] = 0, result_reuse_enable: Optional[bool] = None, result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, ) -> str: - query = self._formatter.format(operation, parameters) + if pyathena.paramstyle == "qmark" or paramstyle == "qmark": + query = operation + execution_parameters = cast(Optional[List[str]], parameters) + else: + query = self._formatter.format(operation, cast(Optional[Dict[str, Any]], parameters)) + execution_parameters = None _logger.debug(query) request = self._build_start_query_execution_request( @@ -563,6 +573,7 @@ def _execute( s3_staging_dir=s3_staging_dir, result_reuse_enable=result_reuse_enable, result_reuse_minutes=result_reuse_minutes, + execution_parameters=execution_parameters, ) query_id = self._find_previous_query_id( query, diff --git a/pyathena/cursor.py b/pyathena/cursor.py index 88389d31..06d41384 100644 --- a/pyathena/cursor.py +++ b/pyathena/cursor.py @@ -82,6 +82,7 @@ def execute( cache_expiration_time: int = 0, result_reuse_enable: Optional[bool] = None, result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, **kwargs, ) -> Cursor: self._reset_state() @@ -94,6 +95,7 @@ def execute( cache_expiration_time=cache_expiration_time, result_reuse_enable=result_reuse_enable, result_reuse_minutes=result_reuse_minutes, + paramstyle=paramstyle, ) query_execution = cast(AthenaQueryExecution, self._poll(self.query_id)) if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED: diff --git a/pyathena/pandas/async_cursor.py b/pyathena/pandas/async_cursor.py index 519eebf9..a9c37ba4 100644 --- a/pyathena/pandas/async_cursor.py +++ b/pyathena/pandas/async_cursor.py @@ -113,6 +113,7 @@ def execute( cache_expiration_time: Optional[int] = 0, result_reuse_enable: Optional[bool] = None, result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, keep_default_na: bool = False, na_values: Optional[Iterable[str]] = ("",), quoting: int = 1, @@ -138,6 +139,7 @@ def execute( cache_expiration_time=cache_expiration_time, result_reuse_enable=result_reuse_enable, result_reuse_minutes=result_reuse_minutes, + paramstyle=paramstyle, ) return ( query_id, diff --git a/pyathena/pandas/cursor.py b/pyathena/pandas/cursor.py index a461d247..d997cf66 100644 --- a/pyathena/pandas/cursor.py +++ b/pyathena/pandas/cursor.py @@ -128,6 +128,7 @@ def execute( cache_expiration_time: Optional[int] = 0, result_reuse_enable: Optional[bool] = None, result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, keep_default_na: bool = False, na_values: Optional[Iterable[str]] = ("",), quoting: int = 1, @@ -154,6 +155,7 @@ def execute( cache_expiration_time=cache_expiration_time, result_reuse_enable=result_reuse_enable, result_reuse_minutes=result_reuse_minutes, + paramstyle=paramstyle, ) query_execution = cast(AthenaQueryExecution, self._poll(self.query_id)) if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED: