diff --git a/sql-cli/src/opensearch_sql_cli/main.py b/sql-cli/src/opensearch_sql_cli/main.py index c3f318929e..31e37bc626 100644 --- a/sql-cli/src/opensearch_sql_cli/main.py +++ b/sql-cli/src/opensearch_sql_cli/main.py @@ -71,6 +71,14 @@ default="sql", help="SQL OR PPL", ) +@click.option( + "-t", + "--timeout", + "response_timeout", + type=click.INT, + default=10, + help="Timeout to await a response from the server" +) def cli( endpoint, query, @@ -83,6 +91,7 @@ def cli( always_use_pager, use_aws_authentication, query_language, + response_timeout ): """ Provide endpoint for OpenSearch client. @@ -114,12 +123,9 @@ def cli( sys.exit(0) # use console to interact with user - opensearchsql_cli = OpenSearchSqlCli( - clirc_file=clirc, - always_use_pager=always_use_pager, - use_aws_authentication=use_aws_authentication, - query_language=query_language, - ) + opensearchsql_cli = OpenSearchSqlCli(clirc_file=clirc, always_use_pager=always_use_pager, + use_aws_authentication=use_aws_authentication, query_language=query_language, + response_timeout=response_timeout) opensearchsql_cli.connect(endpoint, http_auth) opensearchsql_cli.run_cli() diff --git a/sql-cli/src/opensearch_sql_cli/opensearch_connection.py b/sql-cli/src/opensearch_sql_cli/opensearch_connection.py index 8e0cb9eb97..c6c8a7b9d1 100644 --- a/sql-cli/src/opensearch_sql_cli/opensearch_connection.py +++ b/sql-cli/src/opensearch_sql_cli/opensearch_connection.py @@ -27,6 +27,7 @@ def __init__( http_auth=None, use_aws_authentication=False, query_language="sql", + response_timeout=10 ): """Initialize an OpenSearchConnection instance. @@ -45,6 +46,7 @@ def __init__( self.http_auth = http_auth self.use_aws_authentication = use_aws_authentication self.query_language = query_language + self.response_timeout = response_timeout def get_indices(self): if self.client: @@ -167,14 +169,14 @@ def execute_query(self, query, output_format="jdbc", explain=False, use_console= data = self.client.transport.perform_request( url="/_plugins/_sql/_explain" if explain else "/_plugins/_sql/", method="POST", - params=None if explain else {"format": output_format}, + params=None if explain else {"format": output_format, "request_timeout": self.response_timeout}, body={"query": final_query}, ) else: data = self.client.transport.perform_request( url="/_plugins/_ppl/_explain" if explain else "/_plugins/_ppl/", method="POST", - params=None if explain else {"format": output_format}, + params=None if explain else {"format": output_format, "request_timeout": self.response_timeout}, body={"query": final_query}, ) return data diff --git a/sql-cli/src/opensearch_sql_cli/opensearchsql_cli.py b/sql-cli/src/opensearch_sql_cli/opensearchsql_cli.py index e490938cd8..3fd39df357 100644 --- a/sql-cli/src/opensearch_sql_cli/opensearchsql_cli.py +++ b/sql-cli/src/opensearch_sql_cli/opensearchsql_cli.py @@ -39,7 +39,8 @@ class OpenSearchSqlCli: """OpenSearchSqlCli instance is used to build and run the OpenSearch SQL CLI.""" - def __init__(self, clirc_file=None, always_use_pager=False, use_aws_authentication=False, query_language="sql"): + def __init__(self, clirc_file=None, always_use_pager=False, use_aws_authentication=False, query_language="sql", + response_timeout=10): # Load conf file config = self.config = get_config(clirc_file) literal = self.literal = self._get_literals() @@ -49,6 +50,7 @@ def __init__(self, clirc_file=None, always_use_pager=False, use_aws_authenticati self.query_language = query_language self.always_use_pager = always_use_pager self.use_aws_authentication = use_aws_authentication + self.response_timeout = response_timeout self.keywords_list = literal["keywords"] self.functions_list = literal["functions"] self.syntax_style = config["main"]["syntax_style"] @@ -160,7 +162,7 @@ def echo_via_pager(self, text, color=None): def connect(self, endpoint, http_auth=None): self.opensearch_executor = OpenSearchConnection( - endpoint, http_auth, self.use_aws_authentication, self.query_language + endpoint, http_auth, self.use_aws_authentication, self.query_language, self.response_timeout ) self.opensearch_executor.set_connection() diff --git a/sql-cli/tests/test_opensearchsql_cli.py b/sql-cli/tests/test_opensearchsql_cli.py index 8b76a2e931..40d22fd87b 100644 --- a/sql-cli/tests/test_opensearchsql_cli.py +++ b/sql-cli/tests/test_opensearchsql_cli.py @@ -18,6 +18,7 @@ QUERY_WITH_CTRL_D = "select * from %s;\r\x04\r" % TEST_INDEX_NAME USE_AWS_CREDENTIALS = False QUERY_LANGUAGE = "sql" +RESPONSE_TIMEOUT = 10 @pytest.fixture() @@ -34,7 +35,8 @@ def test_connect(self, cli): ) as mock_set_connectiuon: cli.connect(endpoint=ENDPOINT) - mock_OpenSearchConnection.assert_called_with(ENDPOINT, AUTH, USE_AWS_CREDENTIALS, QUERY_LANGUAGE) + mock_OpenSearchConnection.assert_called_with(ENDPOINT, AUTH, USE_AWS_CREDENTIALS, QUERY_LANGUAGE, + RESPONSE_TIMEOUT) mock_set_connectiuon.assert_called() @estest