From 31342d47cb488c3ea71bc9c27a8266af837afe54 Mon Sep 17 00:00:00 2001 From: Daniele Briggi Date: Fri, 10 May 2024 13:14:46 +0000 Subject: [PATCH 1/7] Connect, read and write to socket Update dev dependency to support >= python3.6 Setup IDE --- .devcontainer/Dockerfile | 6 + .devcontainer/devcontainer.json | 34 ++ .env.example | 7 + .github/dependabot.yml | 12 + requirements-dev.txt | 15 +- src/sqlitecloud/client.py | 149 +++---- src/sqlitecloud/driver.py | 639 +++++++++------------------ src/sqlitecloud/types.py | 176 ++++++++ src/sqlitecloud/wrapper_types.py | 59 --- src/tests/__init__.py | 0 src/tests/conftest.py | 6 + src/tests/integration/__init__.py | 0 src/tests/integration/test_client.py | 63 +++ src/tests/integration/test_driver.py | 20 + src/tests/test_client.py | 72 --- src/tests/test_vm.py | 364 --------------- 16 files changed, 594 insertions(+), 1028 deletions(-) create mode 100644 .devcontainer/Dockerfile create mode 100644 .devcontainer/devcontainer.json create mode 100644 .env.example create mode 100644 .github/dependabot.yml create mode 100644 src/sqlitecloud/types.py delete mode 100644 src/sqlitecloud/wrapper_types.py create mode 100644 src/tests/__init__.py create mode 100644 src/tests/conftest.py create mode 100644 src/tests/integration/__init__.py create mode 100644 src/tests/integration/test_client.py create mode 100644 src/tests/integration/test_driver.py delete mode 100644 src/tests/test_client.py delete mode 100644 src/tests/test_vm.py diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 0000000..5019730 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,6 @@ +FROM mcr.microsoft.com/devcontainers/python:3.6-bullseye + +ADD https://dl.yarnpkg.com/debian/pubkey.gpg /etc/apt/trusted.gpg.d/yarn.asc + +RUN chmod +r /etc/apt/trusted.gpg.d/*.asc && \ + echo "deb http://dl.yarnpkg.com/debian/ stable main" > /etc/apt/sources.list.d/yarn.list \ No newline at end of file diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..2e0b7ac --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,34 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/python +{ + "name": "Python 3", + // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile + "build": { + "dockerfile": "Dockerfile" + }, + "features": { + "ghcr.io/warrenbuckley/codespace-features/sqlite:1": {} + }, + + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, + + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + + // Use 'postCreateCommand' to run commands after the container is created. + // "postCreateCommand": "pip3 install --user -r requirements.txt", + + // Configure tool-specific properties. + "customizations": { + "vscode": { + "extensions": [ + "littlefoxteam.vscode-python-test-adapter", + "jkillian.custom-local-formatters" + ] + } + } + + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. + // "remoteUser": "root" +} diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..d75af5e --- /dev/null +++ b/.env.example @@ -0,0 +1,7 @@ +SQLITE_CONNECTION_STRING=sqlitecloud://myhost.sqlite.cloud +SQLITE_USER=admin +SQLITE_PASSWORD= +SQLITE_API_KEY= +SQLITE_HOST=myhost.sqlite.cloud +SQLITE_DB=chinook.sqlite +SQLITE_PORT=8860 diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..f33a02c --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for more information: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +# https://containers.dev/guide/dependabot + +version: 2 +updates: + - package-ecosystem: "devcontainers" + directory: "/" + schedule: + interval: weekly diff --git a/requirements-dev.txt b/requirements-dev.txt index 9cbc296..9deefd9 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,9 +1,8 @@ -pylint==2.15.6 -pytest==7.1.2 -mypy==1.6.1 -mypy-extensions==1.0.0 -typing-extensions==4.8.0 +pylint==2.13.9 +mypy==0.971 +typing-extensions==4.1.1 bump2version==1.0.1 -pytest-mock==3.10.0 -black==23.7.0 -python-dotenv==1.0.0 \ No newline at end of file +pytest==7.0.1 +pytest-mock==3.6.1 +black==22.8.0 +python-dotenv==0.20.0 diff --git a/src/sqlitecloud/client.py b/src/sqlitecloud/client.py index a2c00db..e126a0d 100644 --- a/src/sqlitecloud/client.py +++ b/src/sqlitecloud/client.py @@ -1,35 +1,10 @@ """ Module to interact with remote SqliteCloud database """ -import ctypes -from dataclasses import dataclass -from typing import Any, List, Optional, Tuple - -from sqlitecloud.driver import ( - SQCloudConnect, - SQCloudErrorMsg, - SQCloudIsError, - SQCloudExec, - SQCloudExecArray, - SQCloudConnectWithString, - SQCloudDisconnect, - SQCloudPubSubCB, - SQCloudResultDump, - SQCloudResultIsError, - SqlParameter, -) -from sqlitecloud.pubsub import SQCloudPubSubCallback, subscribe_pub_sub -from sqlitecloud.resultset import SqliteCloudResultSet -from sqlitecloud.wrapper_types import SQCloudConfig, SQCloudResult - - -@dataclass -class SqliteCloudAccount: - username: str - password: str - hostname: str - dbname: str - port: int +from typing import Any, List, Optional + +from sqlitecloud.driver import Driver +from sqlitecloud.types import SQCloudConfig, SQCloudConnect, SqliteCloudAccount class SqliteCloudClient: @@ -37,47 +12,41 @@ class SqliteCloudClient: Client to connect to SqliteCloud """ - _config: Optional[SQCloudConfig] = None - connection_str: Optional[str] = None - hostname: str - dbname: str - port: int - _pub_sub_cbs: List[Tuple[str, SQCloudPubSubCB]] = [] - def __init__( self, cloud_account: Optional[SqliteCloudAccount] = None, connection_str: Optional[str] = None, - pub_subs: SQCloudPubSubCallback = [], + # pub_subs: SQCloudPubSubCallback = [], ) -> None: """Initializes a new instance of the class. Args: connection_str (str): The connection string for the database. - uuid (UUID, optional): The UUID for the instance. Defaults to a new UUID generated using uuid4(). Raises: ValueError: If the connection string is invalid. """ - for pb in pub_subs: - self._pub_sub_cbs.append(("channel1", SQCloudPubSubCB(pb))) + self.driver = Driver() + + self.hostname: str = '' + self.port: int = 8860 + + self.config = SQCloudConfig() + + # for pb in pub_subs: + # self._pub_sub_cbs.append(("channel1", SQCloudPubSubCB(pb))) if connection_str: - self.connection_str = connection_str - elif cloud_account: + # TODO: parse connection string to create the config self.config = SQCloudConfig() - self.config.username = self._encode_str_to_c(cloud_account.username) - self.config.password = self._encode_str_to_c(cloud_account.password) + elif cloud_account: + self.config.account = cloud_account self.hostname = cloud_account.hostname - self.dbname = cloud_account.dbname self.port = cloud_account.port else: raise Exception("Missing connection parameters") - def _encode_str_to_c(self, text): - return ctypes.c_char_p(text.encode("utf-8")) - def open_connection(self) -> SQCloudConnect: """Opens a connection to the SQCloud server. @@ -87,29 +56,14 @@ def open_connection(self) -> SQCloudConnect: Raises: Exception: If an error occurs while opening the connection. """ + connection = self.driver.connect(self.hostname, self.port, self.config) - # Set other config properties... - connection = None - if self.connection_str: - connection = SQCloudConnectWithString(self.connection_str, None) - else: - connection = SQCloudConnect( - self._encode_str_to_c(self.hostname), self.port, self.config - ) - self._check_connection(connection) - SQCloudExec(connection, self._encode_str_to_c(f"USE DATABASE {self.dbname};")) - self._check_connection(connection) - for cb in self._pub_sub_cbs: - subscribe_pub_sub(connection, cb[0], cb[1]) + # SQCloudExec(connection, f"USE DATABASE {self.dbname};") - return connection + # for cb in self._pub_sub_cbs: + # subscribe_pub_sub(connection, cb[0], cb[1]) - def _check_connection(self, connection) -> None: - is_error = SQCloudIsError(connection) - if is_error: - error_message = SQCloudErrorMsg(connection) - print("An error occurred.", error_message.decode("utf-8")) - raise Exception(error_message) + return connection def disconnect(self, conn: SQCloudConnect) -> None: """Closes the connection to the database. @@ -119,11 +73,11 @@ def disconnect(self, conn: SQCloudConnect) -> None: Returns: None: This method does not return any value. """ - SQCloudDisconnect(conn) + self.driver.disconnect(conn) - def exec_query( - self, query: str, conn: SQCloudConnect = None - ) -> SqliteCloudResultSet: + # def exec_query( + # self, query: str, conn: SQCloudConnect = None + # ) -> SqliteCloudResultSet: """Executes a SQL query on the SQLite database. Args: @@ -132,26 +86,31 @@ def exec_query( Returns: SqliteCloudResultSet: The result set of the executed query. """ - print(query) - # pylint: disable=unused-variable - local_conn, close_at_end = ( - (conn, False) if conn else (self.open_connection(), True) - ) - result: SQCloudResult = SQCloudExec(local_conn, self._encode_str_to_c(query)) - self._check_connection(local_conn) - return SqliteCloudResultSet(result) - - def exec_statement( - self, query: str, values: List[Any], conn: SQCloudConnect = None - ) -> SqliteCloudResultSet: - local_conn = conn if conn else self.open_connection() - result: SQCloudResult = SQCloudExecArray( - local_conn, - self._encode_str_to_c(query), - [SqlParameter(self._encode_str_to_c(str(v)), v) for v in values], - ) - if SQCloudResultIsError(result): - raise Exception( - "Query error: " + str(SQCloudResultDump(local_conn, result)) - ) - return SqliteCloudResultSet(result) + # print(query) + # # pylint: disable=unused-variable + # local_conn, close_at_end = ( + # (conn, False) if conn else (self.open_connection(), True) + # ) + # result: SQCloudResult = SQCloudExec(local_conn, self._encode_str_to_c(query)) + # self._check_connection(local_conn) + # return SqliteCloudResultSet(result) + # pass + + # def exec_statement( + # self, query: str, values: List[Any], conn: SQCloudConnect = None + # ) -> SqliteCloudResultSet: + # local_conn = conn if conn else self.open_connection() + # result: SQCloudResult = SQCloudExecArray( + # local_conn, + # self._encode_str_to_c(query), + # [SqlParameter(self._encode_str_to_c(str(v)), v) for v in values], + # ) + # if SQCloudResultIsError(result): + # raise Exception( + # "Query error: " + str(SQCloudResultDump(local_conn, result)) + # ) + # return SqliteCloudResultSet(result) + # pass + + def sendblob(self): + pass diff --git a/src/sqlitecloud/driver.py b/src/sqlitecloud/driver.py index 9b919c6..e0cf3a4 100644 --- a/src/sqlitecloud/driver.py +++ b/src/sqlitecloud/driver.py @@ -1,437 +1,216 @@ -import ctypes -import dataclasses -import os -from typing import Any, Callable, List, Type - -from dotenv import load_dotenv - -from sqlitecloud.wrapper_types import SQCLOUD_VALUE_TYPE, SQCloudConfig, SQCloudResult - -load_dotenv() - -lib_path = os.getenv(key="SQLITECLOUD_DRIVER_PATH", default="./libsqcloud.so") -print("Loading SQLITECLOUD lib from:", lib_path) -lib = ctypes.CDLL(lib_path) -connect = lib.SQCloudConnect - - -class SQCloudConnection(ctypes.Structure): - pass - - -SQCloudConnect: Callable[ - [str, str, int, SQCloudConfig], SQCloudConnection -] = lib.SQCloudConnect -SQCloudConnect.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(SQCloudConfig)] -SQCloudConnect.restype = ctypes.c_void_p - -SQCloudIsError = lib.SQCloudIsError -SQCloudIsError.argtypes = [ - ctypes.c_void_p -] # Assuming SQCloudConnection * is a void pointer -SQCloudIsError.restype = ctypes.c_bool - -SQCloudErrorMsg = lib.SQCloudErrorMsg -SQCloudErrorMsg.argtypes = [ - ctypes.c_void_p -] # Assuming SQCloudConnection * is a void pointer -SQCloudErrorMsg.restype = ctypes.c_char_p - -SQCloudExec = lib.SQCloudExec -SQCloudExec.argtypes = [ - ctypes.c_void_p, - ctypes.c_char_p, -] # Assuming SQCloudConnection * is a void pointer -SQCloudExec.restype = ctypes.POINTER(SQCloudResult) - -SQCloudConnectWithString = lib.SQCloudConnectWithString - -SQCloudDisconnect = lib.SQCloudDisconnect -SQCloudDisconnect.argtypes = [ - ctypes.c_void_p -] # Assuming SQCloudConnection * is a void pointer -SQCloudDisconnect.restype = None - -SQCloudResultIsOK = lib.SQCloudResultIsOK -SQCloudResultIsOK.argtypes = [ - ctypes.POINTER(SQCloudResult) -] # Assuming SQCloudResult * is a pointer to void pointer -SQCloudResultIsOK.restype = ctypes.c_bool - -SQCloudResultIsError = lib.SQCloudResultIsOK -SQCloudResultIsError.argtypes = [ - ctypes.POINTER(SQCloudResult) -] # Assuming SQCloudResult * is a pointer to void pointer -SQCloudResultIsError.restype = ctypes.c_bool - -SQCloudResultType = lib.SQCloudResultType -SQCloudResultType.argtypes = [ctypes.POINTER(SQCloudResult)] # SQCloudResult *result -SQCloudResultType.restype = ctypes.c_uint32 # SQCLOUD_RESULT_TYPE return type - - -SQCloudRowsetCols = lib.SQCloudRowsetCols -SQCloudRowsetCols.argtypes = [ - ctypes.POINTER(SQCloudResult) -] # Assuming SQCloudResult * is a pointer to void pointer -SQCloudRowsetCols.restype = ctypes.c_uint32 - -SQCloudRowsetRows = lib.SQCloudRowsetRows -SQCloudRowsetRows.argtypes = [ - ctypes.POINTER(SQCloudResult) -] # Assuming SQCloudResult * is a pointer to void pointer -SQCloudRowsetRows.restype = ctypes.c_uint32 - -_SQCloudRowsetColumnName = lib.SQCloudRowsetColumnName -_SQCloudRowsetColumnName.argtypes = [ - ctypes.POINTER(SQCloudResult), # SQCloudResult *result - ctypes.c_uint32, # uint32_t col - ctypes.POINTER(ctypes.c_uint32), # uint32_t *len -] -_SQCloudRowsetColumnName.restype = ctypes.c_char_p - - -def SQCloudRowsetColumnName(result_set, col_n): - name_len = ctypes.c_uint32() - col_name = _SQCloudRowsetColumnName(result_set, col_n, ctypes.byref(name_len)) - # print("name_len",name_len.value, col_name.decode('utf-8')) - return col_name.decode("utf-8")[0 : name_len.value] - - -SQCloudRowsetValueType = lib.SQCloudRowsetValueType -SQCloudRowsetValueType.argtypes = [ - ctypes.POINTER(SQCloudResult), # SQCloudResult *result - ctypes.c_uint32, # uint32_t row - ctypes.c_uint32, # uint32_t col -] -SQCloudRowsetValueType.restype = SQCLOUD_VALUE_TYPE - - -SQCloudRowsetInt32Value = lib.SQCloudRowsetInt32Value -SQCloudRowsetInt32Value.argtypes = [ - ctypes.POINTER(SQCloudResult), # SQCloudResult *result - ctypes.c_uint32, # uint32_t row - ctypes.c_uint32, # uint32_t col -] -SQCloudRowsetInt32Value.restype = ctypes.c_int32 # int32_t return type - -SQCloudRowsetInt64Value = lib.SQCloudRowsetInt64Value -SQCloudRowsetInt64Value.argtypes = [ - ctypes.POINTER(SQCloudResult), # SQCloudResult *result - ctypes.c_uint32, # uint32_t row - ctypes.c_uint32, # uint32_t col -] -SQCloudRowsetInt64Value.restype = ctypes.c_int32 # int32_t return type - -SQCloudRowsetFloatValue = lib.SQCloudRowsetFloatValue -SQCloudRowsetFloatValue.argtypes = [ - ctypes.POINTER(SQCloudResult), # SQCloudResult *result - ctypes.c_uint32, # uint32_t row - ctypes.c_uint32, # uint32_t col -] -SQCloudRowsetFloatValue.restype = ctypes.c_float # int32_t return type - - -# Define the function signature -_SQCloudRowsetValue = lib.SQCloudRowsetValue -_SQCloudRowsetValue.argtypes = [ - ctypes.POINTER(SQCloudResult), # SQCloudResult *result - ctypes.c_uint32, # uint32_t row - ctypes.c_uint32, # uint32_t col - ctypes.POINTER(ctypes.c_uint32), # uint32_t *len -] -_SQCloudRowsetValue.restype = ctypes.c_char_p - - -def SQCloudRowsetValue(result_set, row, col): - value_len = ctypes.c_uint32() - data = _SQCloudRowsetValue(result_set, row, col, ctypes.byref(value_len)) - return data[0 : value_len.value] - - -_SQCloudExecArray = lib.SQCloudExecArray -_SQCloudExecArray.argtypes = [ - ctypes.c_void_p, # SQCloudConnection *connection - ctypes.c_char_p, # const char *command - ctypes.POINTER(ctypes.c_char_p), # const char **values - ctypes.POINTER(ctypes.c_uint32), # uint32_t len[] - ctypes.POINTER(ctypes.c_uint32), # SQCLOUD_VALUE_TYPE types[] - ctypes.c_uint32, # uint32_t n -] - -_SQCloudExecArray.restype = ctypes.POINTER(SQCloudResult) # SQCloudResult * return type - - -def _envinc_type(value: Any) -> int: - if not isinstance(value, (float, int, str)): - raise Exception("Invalid type parameter " + type(value)) - print("ev type:", str(type(value))) - match str(type(value)): - case "str": - return SQCLOUD_VALUE_TYPE.VALUE_TEXT - case "": - return SQCLOUD_VALUE_TYPE.VALUE_INTEGER - case "float": - return SQCLOUD_VALUE_TYPE.VALUE_FLOAT - return SQCLOUD_VALUE_TYPE.VALUE_NULL - - -@dataclasses.dataclass -class SqlParameter: - byte_value: ctypes.c_char_p - py_value: Type - - -def SQCloudExecArray( - conn: SQCloudConnect, query: ctypes.c_char_p, values: List[SqlParameter] -) -> SQCloudResult: - n = len(values) - b_values = [v.byte_value for v in values] - lengths = [len(val.byte_value.value) for val in values] - types = list(ctypes.c_uint32(_envinc_type(v.py_value)) for v in values) - result_ptr = _SQCloudExecArray( - conn, - query, - (ctypes.c_char_p * n)(*b_values), - (ctypes.c_uint32 * n)(*lengths), - (ctypes.c_uint32 * n)(*types), - ctypes.c_uint32(n), - ) - return result_ptr - - -SQCloudResultFree = lib.SQCloudResultFree -SQCloudResultFree.argtypes = [ctypes.POINTER(SQCloudResult)] # SQCloudResult *result -SQCloudResultFree.restype = None - -SQCloudResultFloat = lib.SQCloudResultFloat -SQCloudResultFloat.argtypes = [ctypes.POINTER(SQCloudResult)] # SQCloudResult *result -SQCloudResultFloat.restype = ctypes.c_float # float return type - -SQCloudResultInt32 = lib.SQCloudResultInt32 -SQCloudResultInt32.argtypes = [ctypes.POINTER(SQCloudResult)] # SQCloudResult *result -SQCloudResultInt32.restype = ctypes.c_int32 # int32_t return type - -SQCloudResultDump = lib.SQCloudResultDump -SQCloudResultDump.argtypes = [ - ctypes.c_void_p, # SQCloudConnection *connection - ctypes.POINTER(SQCloudResult), # SQCloudResult *result -] -SQCloudResultDump.restype = None - -CallbackFunc = ctypes.CFUNCTYPE( - ctypes.c_int, - ctypes.c_void_p, - ctypes.c_void_p, - ctypes.POINTER(ctypes.c_uint32), - ctypes.c_int64, - ctypes.c_int64, +import ssl +from typing import Optional +from sqlitecloud.types import ( + SQCLOUD_CMD, + SQCLOUD_INTERNAL_ERRCODE, + SQCloudConfig, + SQCloudConnect, + SQCloudException, + SQCloudNumber, ) - -SQCloudUploadDatabase = lib.SQCloudUploadDatabase -SQCloudUploadDatabase.argtypes = [ - ctypes.c_void_p, # SQCloudConnection *connection - ctypes.c_char_p, # const char *dbname - ctypes.c_char_p, # const char *key - ctypes.c_void_p, # void *xdata - ctypes.c_int64, # int64_t dbsize - CallbackFunc, # int (*xCallback)(void *xdata, void *buffer, uint32_t *blen, int64_t ntot, int64_t nprogress) -] -SQCloudUploadDatabase.restype = ctypes.c_int # Return type - - -# Define the SQCloudPubSubCB function signature -SQCloudPubSubCB = ctypes.CFUNCTYPE( - ctypes.c_void_p, ctypes.POINTER(SQCloudResult), ctypes.c_void_p -) - - -# Define the function signature -SQCloudSetPubSubCallback = lib.SQCloudSetPubSubCallback -SQCloudSetPubSubCallback.argtypes = [ - ctypes.c_void_p, # SQCloudConnection *connection - SQCloudPubSubCB, # SQCloudPubSubCB callback - ctypes.c_void_p, # void *data -] -SQCloudSetPubSubCallback.restype = None - - -class SQCloudVM(ctypes.Structure): - pass - - -# Define the SQCloudVMCompile signature -SQCloudVMCompile = lib.SQCloudVMCompile -SQCloudVMCompile.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int32, ctypes.c_void_p] -SQCloudVMCompile.restype = ctypes.POINTER(SQCloudVM) - -# Define the SQCloudVMStep signature -SQCloudVMStep = lib.SQCloudVMStep -SQCloudVMStep.argtypes = [ctypes.POINTER(SQCloudVM)] -SQCloudVMStep.restype = ctypes.c_int8 - -# Define the SQCloudVMResult signature -SQCloudVMResult = lib.SQCloudVMResult -SQCloudVMResult.argtypes = [ctypes.POINTER(SQCloudVM)] -SQCloudVMResult.restype = ctypes.c_void_p - - -# Define the SQCloudVMClose signature -SQCloudVMClose = lib.SQCloudVMClose -SQCloudVMClose.argtypes = [ctypes.POINTER(SQCloudVM)] -SQCloudVMClose.restype = ctypes.c_bool - - -# Define the SQCloudVMErrorMsg signature -SQCloudVMErrorMsg = lib.SQCloudVMErrorMsg -SQCloudVMErrorMsg.argtypes = [ctypes.POINTER(SQCloudVM)] -SQCloudVMErrorMsg.restype = ctypes.c_char_p - - -# Define the SQCloudVMErrorCode signature -SQCloudVMErrorCode = lib.SQCloudVMErrorCode -SQCloudVMErrorCode.argtypes = [ctypes.POINTER(SQCloudVM)] -SQCloudVMErrorCode.restype = ctypes.c_int - - -# Define the SQCloudVMIsReadOnly signature -SQCloudVMIsReadOnly = lib.SQCloudVMIsReadOnly -SQCloudVMIsReadOnly.argtypes = [ctypes.POINTER(SQCloudVM)] -SQCloudVMIsReadOnly.restype = ctypes.c_bool - - -# Define the SQCloudVMIsExplain signature -SQCloudVMIsExplain = lib.SQCloudVMIsExplain -SQCloudVMIsExplain.argtypes = [ctypes.POINTER(SQCloudVM)] -SQCloudVMIsExplain.restype = ctypes.c_int - - -# Define the SQCloudVMIsFinalized signature -SQCloudVMIsFinalized = lib.SQCloudVMIsFinalized -SQCloudVMIsFinalized.argtypes = [ctypes.POINTER(SQCloudVM)] -SQCloudVMIsFinalized.restype = ctypes.c_bool - - -# Define the SQCloudVMBindParameterCount signature -SQCloudVMBindParameterCount = lib.SQCloudVMBindParameterCount -SQCloudVMBindParameterCount.argtypes = [ctypes.POINTER(SQCloudVM)] -SQCloudVMBindParameterCount.restype = ctypes.c_int - - -# Define the SQCloudVMBindParameterIndex signature -SQCloudVMBindParameterIndex = lib.SQCloudVMBindParameterIndex -SQCloudVMBindParameterIndex.argtypes = [ctypes.POINTER(SQCloudVM), ctypes.c_char_p] -SQCloudVMBindParameterIndex.restype = ctypes.c_int - - -# Define the SQCloudVMBindParameterName signature -SQCloudVMBindParameterName = lib.SQCloudVMBindParameterName -SQCloudVMBindParameterName.argtypes = [ctypes.POINTER(SQCloudVM), ctypes.c_int] -SQCloudVMBindParameterName.restype = ctypes.c_char_p - - -# Define the SQCloudVMColumnCount signature -SQCloudVMColumnCount = lib.SQCloudVMColumnCount -SQCloudVMColumnCount.argtypes = [ctypes.POINTER(SQCloudVM)] -SQCloudVMColumnCount.restype = ctypes.c_int - - -# Define the SQCloudVMBindDouble signature -SQCloudVMBindDouble = lib.SQCloudVMBindDouble -SQCloudVMBindDouble.argtypes = [ctypes.POINTER(SQCloudVM), ctypes.c_int, ctypes.c_double] -SQCloudVMBindDouble.restype = ctypes.c_bool - - -# Define the SQCloudVMBindInt signature -SQCloudVMBindInt = lib.SQCloudVMBindInt -SQCloudVMBindInt.argtypes = [ctypes.POINTER(SQCloudVM), ctypes.c_int, ctypes.c_int] -SQCloudVMBindInt.restype = ctypes.c_bool - - -# Define the SQCloudVMBindInt64 signature -SQCloudVMBindInt64 = lib.SQCloudVMBindInt64 -SQCloudVMBindInt64.argtypes = [ctypes.POINTER(SQCloudVM), ctypes.c_int, ctypes.c_int64] -SQCloudVMBindInt64.restype = ctypes.c_bool - - -# Define the SQCloudVMBindNull signature -SQCloudVMBindNull = lib.SQCloudVMBindNull -SQCloudVMBindNull.argtypes = [ctypes.POINTER(SQCloudVM), ctypes.c_int] -SQCloudVMBindNull.restype = ctypes.c_bool - - -# Define the SQCloudVMBindText signature -SQCloudVMBindText = lib.SQCloudVMBindText -SQCloudVMBindText.argtypes = [ctypes.POINTER(SQCloudVM), ctypes.c_int, ctypes.c_char_p, ctypes.c_int32] -SQCloudVMBindText.restype = ctypes.c_bool - - -# Define the SQCloudVMBindBlob signature -SQCloudVMBindBlob = lib.SQCloudVMBindBlob -SQCloudVMBindBlob.argtypes = [ctypes.POINTER(SQCloudVM), ctypes.c_int, ctypes.c_void_p, ctypes.c_int32] -SQCloudVMBindBlob.restype = ctypes.c_bool - - -# Define the SQCloudVMBindZeroBlob signature -SQCloudVMBindZeroBlob = lib.SQCloudVMBindZeroBlob -SQCloudVMBindZeroBlob.argtypes = [ctypes.POINTER(SQCloudVM), ctypes.c_int, ctypes.c_int32] -SQCloudVMBindZeroBlob.restype = ctypes.c_bool - - -# Define the SQCloudVMColumnBlob signature -SQCloudVMColumnBlob = lib.SQCloudVMColumnBlob -SQCloudVMColumnBlob.argtypes = [ctypes.POINTER(SQCloudVM), ctypes.c_int, ctypes.POINTER(ctypes.c_uint32)] -SQCloudVMColumnBlob.restype = ctypes.c_void_p - - -# Define the SQCloudVMColumnText signature -SQCloudVMColumnText = lib.SQCloudVMColumnText -SQCloudVMColumnText.argtypes = [ctypes.POINTER(SQCloudVM), ctypes.c_int, ctypes.POINTER(ctypes.c_uint32)] -SQCloudVMColumnText.restype = ctypes.c_char_p - - -# Define the SQCloudVMColumnDouble signature -SQCloudVMColumnDouble = lib.SQCloudVMColumnDouble -SQCloudVMColumnDouble.argtypes = [ctypes.POINTER(SQCloudVM), ctypes.c_int] -SQCloudVMColumnDouble.restype = ctypes.c_double - - -# Define the SQCloudVMColumnInt32 signature -SQCloudVMColumnInt32 = lib.SQCloudVMColumnInt32 -SQCloudVMColumnInt32.argtypes = [ctypes.POINTER(SQCloudVM), ctypes.c_int] -SQCloudVMColumnInt32.restype = ctypes.c_int32 - - -# Define the SQCloudVMColumnInt64 signature -SQCloudVMColumnInt64 = lib.SQCloudVMColumnInt64 -SQCloudVMColumnInt64.argtypes = [ctypes.POINTER(SQCloudVM), ctypes.c_int] -SQCloudVMColumnInt64.restype = ctypes.c_int64 - - -# Define the SQCloudVMColumnLen signature -SQCloudVMColumnLen = lib.SQCloudVMColumnLen -SQCloudVMColumnLen.argtypes = [ctypes.POINTER(SQCloudVM), ctypes.c_int] -SQCloudVMColumnLen.restype = ctypes.c_int64 - - -# Define the SQCloudVMColumnType signature -SQCloudVMColumnType = lib.SQCloudVMColumnType -SQCloudVMColumnType.argtypes = [ctypes.POINTER(SQCloudVM), ctypes.c_int] -SQCloudVMColumnType.restype = ctypes.c_void_p +import socket -# Define the SQCloudVMLastRowID signature -SQCloudVMLastRowID = lib.SQCloudVMLastRowID -SQCloudVMLastRowID.argtypes = [ctypes.POINTER(SQCloudVM)] -SQCloudVMLastRowID.restype = ctypes.c_int64 +class Driver: + def connect(self, hostname: str, port: int, config: SQCloudConfig) -> SQCloudConnect: + """ + Connects to the SQLite Cloud server. + Args: + hostname (str, optional): The hostname of the server. Defaults to "localhost". + port (int, optional): The port number of the server. Defaults to 8860. + config (SQCloudConfig, optional): The configuration for the connection. Defaults to None. -# Define the SQCloudVMChanges signature -SQCloudVMChanges = lib.SQCloudVMChanges -SQCloudVMChanges.argtypes = [ctypes.POINTER(SQCloudVM)] -SQCloudVMChanges.restype = ctypes.c_int64 + Returns: + SQCloudConnect: The connection object. + Raises: + SQCloudException: If an error occurs while initializing the socket. + """ + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(config.connect_timeout) -# Define the SQCloudVMTotalChanges signature -SQCloudVMTotalChanges = lib.SQCloudVMTotalChanges -SQCloudVMTotalChanges.argtypes = [ctypes.POINTER(SQCloudVM)] -SQCloudVMTotalChanges.restype = ctypes.c_int64 + if not config.insecure: + context = ssl.create_default_context(cafile=config.tls_root_certificate) + if config.tls_certificate: + context.load_cert_chain( + certfile=config.tls_certificate, keyfile=config.tls_certificate_key + ) + + sock = context.wrap_socket(sock, server_hostname=hostname) + + try: + sock.connect((hostname, port)) + except Exception as e: + errmsg = f"An error occurred while initializing the socket." + raise SQCloudException(errmsg, -1, exception=e) + + connection = SQCloudConnect() + connection.socket = sock + connection.config = config + + self._internal_config_apply(connection, config) + + return connection + + def disconnect(self, conn: SQCloudConnect): + if conn.socket: + conn.socket.close() + conn.socket = None + + def execute(self): + pass + + def sendblob(self): + pass + + def _internal_config_apply( + self, connection: SQCloudConnect, config: SQCloudConfig + ) -> None: + if config.timeout > 0: + connection.socket.settimeout(config.timeout) + + buffer = "" + + if config.account.apikey: + buffer += f"AUTH APIKEY {connection.account.apikey};" + + if config.account.username and config.account.password: + command = "HASH" if config.account.password_hashed else "PASSWORD" + buffer += f"AUTH USER {config.account.username} {command} {config.account.password};" + + if config.account.database: + if config.create and not config.memory: + buffer += f"CREATE DATABASE {config.account.database} IF NOT EXISTS;" + buffer += f"USE DATABASE {config.account.database};" + + if config.compression: + buffer += "SET CLIENT KEY COMPRESSION TO 1;" + + if config.zerotext: + buffer += "SET CLIENT KEY ZEROTEXT TO 1;" + + if config.non_linearizable: + buffer += "SET CLIENT KEY NONLINEARIZABLE TO 1;" + + if config.noblob: + buffer += "SET CLIENT KEY NOBLOB TO 1;" + + if config.maxdata: + buffer += f"SET CLIENT KEY MAXDATA TO {config.maxdata};" + + if config.maxrows: + buffer += f"SET CLIENT KEY MAXROWS TO {config.maxrows};" + + if config.maxrowset: + buffer += f"SET CLIENT KEY MAXROWSET TO {config.maxrowset};" + + if len(buffer) > 0: + self._internal_run_command(connection, buffer) + + def _internal_run_command(self, connection: SQCloudConnect, buffer: str) -> None: + self._internal_socket_write(connection, buffer) + self._internal_socket_read(connection) + + def _internal_socket_write(self, connection: SQCloudConnect, buffer: str) -> None: + # compute header + delimit = "$" if connection.isblob else "+" + buffer_len = len(buffer) + header = f"{delimit}{buffer_len} " + + # write header + try: + connection.socket.sendall(header.encode()) + except Exception as exc: + raise SQCloudException( + "An error occurred while writing header data.", + SQCLOUD_INTERNAL_ERRCODE.INTERNAL_ERRCODE_NETWORK, + exc, + ) + + # write buffer + if buffer_len == 0: + return + try: + connection.socket.sendall(buffer.encode()) + except Exception as exc: + raise SQCloudException( + "An error occurred while writing data.", + SQCLOUD_INTERNAL_ERRCODE.INTERNAL_ERRCODE_NETWORK, + exc, + ) + + def _internal_socket_read(self, connection: SQCloudConnect) -> any: + buffer: str = "" + buffer_size: int = 1024 + nread: int = 0 + + try: + while True: + data = connection.socket.recv(buffer_size) + if not data: + break + + # update buffers + data = data.decode() + buffer += data + nread += len(data) + + c = buffer[0] + + if c == SQCLOUD_CMD.INT or c == SQCLOUD_CMD.FLOAT or c == SQCLOUD_CMD.NULL: + if buffer[nread-1] != ' ': + continue + elif c == SQCLOUD_CMD.ROWSET_CHUNK: + isEndOfChunk = buffer.endswith(SQCLOUD_CMD.CHUNKS_END) + if not isEndOfChunk: + continue + else: + n: SQCloudNumber = self._internal_parse_number(buffer) + can_be_zerolength = c == SQCLOUD_CMD.BLOB or c == SQCLOUD_CMD.STRING + if n.value == 0 and not can_be_zerolength: + continue + if n.value + n.cstart != nread: + continue + + return self._internal_parse_buffer(buffer, nread) + + except Exception as exc: + raise SQCloudException( + "An error occurred while reading data from the socket.", + SQCLOUD_INTERNAL_ERRCODE.INTERNAL_ERRCODE_NETWORK, + exc, + ) + + def _internal_parse_number(self, buffer: str, index: int = 1) -> SQCloudNumber: + sqlite_number = SQCloudNumber() + extvalue = 0 + isext = False + blen = len(buffer) + + # from 1 to skip the first command type character + for i in range(index, blen): + c = buffer[i] + + # check for optional extended error code (ERRCODE:EXTERRCODE) + if c == ':': + isext = True + continue + + # check for end of value + if c == ' ': + sqlite_number.cstart = i + 1 + sqlite_number.extcode = extvalue + return sqlite_number + + # compute numeric value + if isext: + extvalue = (extvalue * 10) + int(buffer[i]) + else: + sqlite_number.value = (sqlite_number.value * 10) + int(buffer[i]) + + return 0 + + def _internal_parse_buffer(self, buffer: str, blen: int) -> any: + # TODO + return \ No newline at end of file diff --git a/src/sqlitecloud/types.py b/src/sqlitecloud/types.py new file mode 100644 index 0000000..f7c3902 --- /dev/null +++ b/src/sqlitecloud/types.py @@ -0,0 +1,176 @@ +from enum import Enum +from typing import List, Optional +from enum import Enum + + +class SQCLOUD_VALUE_TYPE(Enum): + VALUE_INTEGER = 1 + VALUE_FLOAT = 2 + VALUE_TEXT = 3 + VALUE_BLOB = 4 + VALUE_NULL = 5 + + +class SQCLOUD_RESULT_TYPE(Enum): + RESULT_OK = 0 + RESULT_FLOAT = 2 + RESULT_STRING = 2 + RESULT_INTEGER = 3 + RESULT_ERROR = 4 + RESULT_ROWSET = 5 + RESULT_ARRAY = 6 + RESULT_NULL = 7 + RESULT_JSON = 8 + RESULT_BLOB = 9 + + +class SQCLOUD_CMD(Enum): + STRING = "+" + ZEROSTRING = "!" + ERROR = "-" + INT = ":" + FLOAT = "," + ROWSET = "*" + ROWSET_CHUNK = "/" + JSON = "#" + RAWJSON = "{" + NULL = "_" + BLOB = "$" + COMPRESSED = "%" + PUBSUB = "|" + COMMAND = "^" + RECONNECT = "@" + ARRAY = "=" + + +class SQCLOUD_INTERNAL_ERRCODE(Enum): + INTERNAL_ERRCODE_NONE = 0 + # INTERNAL_ERRCODE_GENERIC = 100000 + # INTERNAL_ERRCODE_PUBSUB = 100001 + # INTERNAL_ERRCODE_TLS = 100002 + # INTERNAL_ERRCODE_URL = 100003 + # INTERNAL_ERRCODE_MEMORY = 100004 + INTERNAL_ERRCODE_NETWORK = 100005 + # INTERNAL_ERRCODE_FORMAT = 100006 + # INTERNAL_ERRCODE_INDEX = 100007 + # INTERNAL_ERRCODE_SOCKCLOSED = 100008 + + +class SqliteCloudAccount: + def __init__(self): + # User name is required unless connectionstring is provided + self.username = "" + # Password is required unless connection string is provided + self.password = "" + # Password is hashed + self.password_hashed = False + # API key instead of username and password + self.apikey = "" + # Name of database to open + self.database = "" + # Like mynode.sqlitecloud.io + self.hostname = "" + self.port = 8860 + + +class SQCloudConnect: + def __init__(self): + self.hostname: str = "" + self.port: int = "" + + self.socket: any = None + + self.chunk: SQCloudResult + self.config: SQCloudConfig + + self.isblob: bool = False + self.config_to_free: bool # todo: is this needed? + + # pub/sub + # todo: check uuid type + self.uuid: str + + # todo: + # pubsubfd: int + + # callback: SQCloudPubSubCB + + # todo: which is the proper type? + self.data: any + + # self.errmsg: str = '' + # self.errcode: int # error code + # self.extcode: int # extended error code + # self.offcode: int # offset error code + + +class SQCloudConfig: + def __init__(self) -> None: + self.account: SqliteCloudAccount = None + + # Optional query timeout passed directly to TLS socket + self.timeout = 0 + # Socket connection timeout + self.connect_timeout = 20 + + # Enable compression + self.compression = False + # Tell the server to zero-terminate strings + self.zerotext = False + # Database will be created in memory + self.memory = False + # Create the database if it doesn't exist? + self.create = False + # Request for immediate responses from the server node without waiting for linerizability guarantees + self.non_linearizable = False + # Connect using plain TCP port, without TLS encryption, NOT RECOMMENDED + self.insecure = False + # Accept invalid TLS certificates + self.no_verify_certificate = False + + # Filepath of certificates + self.tls_root_certificate: str = None + self.tls_certificate: str = None + self.tls_certificate_key: str = None + + # Server should send BLOB columns + self.noblob = False + # Do not send columns with more than max_data bytes + self.maxdata = 0 + # Server should chunk responses with more than maxRows + self.maxrows = 0 + # Server should limit total number of rows in a set to maxRowset + self.maxrowset = 0 + + +class SQCloudResult: + def __init__(self) -> None: + self.num_rows: int = 0 + self.num_columns: int = 0 + self.column_names: List[str] = [] + self.column_types: List[int] = [] + self.data: List[any] = [] + + +class SQCloudException(Exception): + def __init__( + self, message: str, code: int, xerrcode=0, exception: Optional[Exception] = None + ) -> None: + self.errmsg = str(message) + if exception: + self.errmsg += " " + str(exception) + + self.errcode = code + self.xerrcode = xerrcode + self.exception = exception + + +class SQCloudNumber: + """ + Represents the parsed number or the error code. + """ + + def __init__(self) -> None: + self.value: int = 0 + self.cstart: int = 0 + self.extcode: int = None diff --git a/src/sqlitecloud/wrapper_types.py b/src/sqlitecloud/wrapper_types.py deleted file mode 100644 index 08da8c9..0000000 --- a/src/sqlitecloud/wrapper_types.py +++ /dev/null @@ -1,59 +0,0 @@ -import ctypes - - -class SQCloudConfig(ctypes.Structure): - _fields_ = [ - ("username", ctypes.c_char_p), - ("password", ctypes.c_char_p), - ("database", ctypes.c_char_p), - ("timeout", ctypes.c_int), - ("family", ctypes.c_int), - ("compression", ctypes.c_bool), - ("sqlite_mode", ctypes.c_bool), - ("zero_text", ctypes.c_bool), - ("password_hashed", ctypes.c_bool), - ("nonlinearizable", ctypes.c_bool), - ("db_memory", ctypes.c_bool), - ("no_blob", ctypes.c_bool), - ("db_create", ctypes.c_bool), - ("max_data", ctypes.c_int), - ("max_rows", ctypes.c_int), - ("max_rowset", ctypes.c_int), - ("tls_root_certificate", ctypes.c_char_p), - ("tls_certificate", ctypes.c_char_p), - ("tls_certificate_key", ctypes.c_char_p), - ("insecure", ctypes.c_bool), - ("callback", ctypes.c_void_p), # This assumes config_cb is of type void pointer - ("data", ctypes.c_void_p), - ] - - -class SQCloudResult(ctypes.Structure): - _fields_ = [ - ("num_rows", ctypes.c_int), - ("num_columns", ctypes.c_int), - ("column_names", ctypes.POINTER(ctypes.c_char_p)), - ("column_types", ctypes.POINTER(ctypes.c_int)), - ("data", ctypes.POINTER(ctypes.POINTER(ctypes.c_char))), - ] - - -class SQCLOUD_VALUE_TYPE(ctypes.c_uint): - VALUE_INTEGER = 1 - VALUE_FLOAT = 2 - VALUE_TEXT = 3 - VALUE_BLOB = 4 - VALUE_NULL = 5 - - -class SQCLOUD_RESULT_TYPE(ctypes.c_uint): - RESULT_OK = 0 - RESULT_FLOAT = 2 - RESULT_STRING = 2 - RESULT_INTEGER = 3 - RESULT_ERROR = 4 - RESULT_ROWSET = 5 - RESULT_ARRAY = 6 - RESULT_NULL = 7 - RESULT_JSON = 8 - RESULT_BLOB = 9 diff --git a/src/tests/__init__.py b/src/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/tests/conftest.py b/src/tests/conftest.py new file mode 100644 index 0000000..4adbbea --- /dev/null +++ b/src/tests/conftest.py @@ -0,0 +1,6 @@ +import pytest +from dotenv import load_dotenv + +@pytest.fixture(autouse=True) +def load_env_vars(): + load_dotenv(".env") \ No newline at end of file diff --git a/src/tests/integration/__init__.py b/src/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/tests/integration/test_client.py b/src/tests/integration/test_client.py new file mode 100644 index 0000000..3e4b39d --- /dev/null +++ b/src/tests/integration/test_client.py @@ -0,0 +1,63 @@ +import os + +import pytest +from sqlitecloud.client import SqliteCloudClient +from sqlitecloud.types import SQCloudConnect, SQCloudException, SqliteCloudAccount + + +class TestClient: + @pytest.fixture + def sqlitecloud_connection(self): + account = SqliteCloudAccount() + account.username = os.getenv("SQLITE_USER") + account.password = os.getenv("SQLITE_PASSWORD") + account.database = os.getenv("SQLITE_DB") + account.hostname = os.getenv("SQLITE_HOST") + account.port = 8860 + + client = SqliteCloudClient(cloud_account=account) + connection = client.open_connection() + assert isinstance(connection, SQCloudConnect) + + yield connection + + client.disconnect(connection) + + def test_connection_with_credentials(self): + account = SqliteCloudAccount() + account.username = os.getenv("SQLITE_USER") + account.password = os.getenv("SQLITE_PASSWORD") + account.database = os.getenv("SQLITE_DB") + account.hostname = os.getenv("SQLITE_HOST") + account.port = 8860 + + client = SqliteCloudClient(cloud_account=account) + conn = client.open_connection() + assert isinstance(conn, SQCloudConnect) + + client.disconnect(conn) + + def test_connection_with_apikey(self): + account = SqliteCloudAccount() + account.username = os.getenv("SQLITE_API_KEY") + account.hostname = os.getenv("SQLITE_HOST") + account.port = 8860 + + client = SqliteCloudClient(cloud_account=account) + conn = client.open_connection() + assert isinstance(conn, SQCloudConnect) + + client.disconnect(conn) + + def test_connection_without_credentials_and_apikey(self): + #pytest.raises(SQCloudException) + + account = SqliteCloudAccount() + account.database = os.getenv("SQLITE_DB") + account.hostname = os.getenv("SQLITE_HOST") + account.port = 8860 + + client = SqliteCloudClient(cloud_account=account) + + client.open_connection() + diff --git a/src/tests/integration/test_driver.py b/src/tests/integration/test_driver.py new file mode 100644 index 0000000..f920bbf --- /dev/null +++ b/src/tests/integration/test_driver.py @@ -0,0 +1,20 @@ +import os +from sqlitecloud.client import SqliteCloudClient +from sqlitecloud.driver import Driver, SQCloudConnect +from sqlitecloud.types import SQCloudConfig, SqliteCloudAccount + + +# class TestDriver: +# def test_internal_socket_read_empty_stream(self): +# driver = Driver() + +# config = SQCloudConfig() +# config.account = SqliteCloudAccount() +# config.account.username = os.getenv("SQLITE_USER") +# config.account.password = os.getenv("SQLITE_PASSWORD") + +# conn = driver.connect("nejvjtcliz.sqlite.cloud", 8860, config) +# assert isinstance(conn, SQCloudConnect) + +# buffer = driver._internal_socket_read(conn) +# assert buffer == "" diff --git a/src/tests/test_client.py b/src/tests/test_client.py deleted file mode 100644 index 9513c07..0000000 --- a/src/tests/test_client.py +++ /dev/null @@ -1,72 +0,0 @@ -from sqlitecloud.client import SqliteCloudClient, SqliteCloudAccount -from sqlitecloud.conn_info import user, password, host, db_name, port - -# Mocking SQCloudConnect and other dependencies would be necessary for more comprehensive testing. - - -def test_sqlite_cloud_client_exec_query(): - account = SqliteCloudAccount(user, password, host, db_name, port) - client = SqliteCloudClient(cloud_account=account) - assert client - conn = client.open_connection() - query = "select * from employees;" - result = client.exec_query(query, conn) - assert result - first_element = next(result) - assert len(first_element) == 4 - assert "id" in first_element.keys() - assert "emp_name" in first_element.keys() - client.disconnect(conn) - - -def test_sqlite_cloud_client_exec_array(): - account = SqliteCloudAccount(user, password, host, db_name, port) - client = SqliteCloudClient(cloud_account=account) - result = client.exec_statement("select * from employees where id = ?", [1]) - assert result - first_element = next(result) - assert len(first_element) == 4 - - -def test_sqlite_cloud_error_query(): - account = SqliteCloudAccount(user, password, host, db_name, port) - client = SqliteCloudClient(cloud_account=account) - assert client - conn = client.open_connection() - query = "select * from ibiza;" - is_error = False - try: - client.exec_query(query, conn) - client.disconnect(conn) - except Exception: - is_error = True - client.disconnect(conn) - assert is_error - - -def test_sqlite_cloud_float_agg_query(): - account = SqliteCloudAccount(user, password, host, db_name, port) - client = SqliteCloudClient(cloud_account=account) - assert client - conn = client.open_connection() - query = "GET INFO disk_usage_perc;" - result = client.exec_query(query, conn) - assert result - first_element = next(result) - assert "result" in first_element.keys() - print("Float result", first_element) - client.disconnect(conn) - - -def test_sqlite_cloud_int_agg_query(): - account = SqliteCloudAccount(user, password, host, db_name, port) - client = SqliteCloudClient(cloud_account=account) - assert client - conn = client.open_connection() - query = "GET INFO process_id;" - result = client.exec_query(query, conn) - assert result - first_element = next(result) - print("Int result", first_element) - assert "result" in first_element.keys() - client.disconnect(conn) diff --git a/src/tests/test_vm.py b/src/tests/test_vm.py deleted file mode 100644 index fb4bf0a..0000000 --- a/src/tests/test_vm.py +++ /dev/null @@ -1,364 +0,0 @@ -import pytest - -from sqlitecloud.client import SqliteCloudClient, SqliteCloudAccount -from sqlitecloud.conn_info import user, password, host, db_name, port -from sqlitecloud.vm import ( - compile_vm, - step_vm, - result_vm, - close_vm, - error_msg_vm, - error_code_vm, - is_read_only_vm, - is_explain_vm, - is_finalized_vm, - bind_parameter_count_vm, - bind_parameter_index_vm, - bind_parameter_name_vm, - column_count_vm, - bind_double_vm, - bind_null_vm, - bind_text_vm, - bind_blob_vm, - column_blob_vm, - column_text_vm, - column_double_vm, - column_int_32_vm, - column_int_64_vm, - column_len_vm, - column_type_vm, - last_row_id_vm, - changes_vm, total_changes_vm -) -from sqlitecloud.wrapper_types import SQCLOUD_VALUE_TYPE - - -@pytest.fixture() -def get_conn(): - account = SqliteCloudAccount(user, password, host, db_name, port) - client = SqliteCloudClient(cloud_account=account) - - conn = client.open_connection() - - try: - yield conn - finally: - client.disconnect(conn) - - -def test_compile_vm(get_conn): - conn = get_conn - compile_vm(conn, "INSERT INTO employees (emp_name) VALUES (?1);") - - -def test_step_vm(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name) VALUES (?1);") - result = step_vm(vm) - - assert isinstance(result, int), type(result) - - -def test_result_vm(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name) VALUES (?1);") - result = result_vm(vm) - - assert isinstance(result, int), type(result) - - -def test_close_vm(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name) VALUES (?1);") - res = close_vm(vm) - assert res is True - - -def test_error_msg_vm(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name) VALUES (?1);") - res = error_msg_vm(vm) - assert res is None - - -def test_error_code_vm(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name) VALUES (?1);") - res = error_code_vm(vm) - assert res == 0 - - -def test_vm_is_not_read_only(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name) VALUES (?1);") - res = is_read_only_vm(vm) - assert res is False - - -def test_vm_is_read_only(get_conn): - conn = get_conn - vm = compile_vm(conn, "SELECT * FROM employees") - res = is_read_only_vm(vm) - assert res is True - - -def test_vm_is_explain(get_conn): - conn = get_conn - vm = compile_vm(conn, "EXPLAIN INSERT INTO employees (emp_name) VALUES (?1);") - res = is_explain_vm(vm) - assert res == 1 - - -def test_vm_is_not_finalized(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name) VALUES (?1);") - res = is_finalized_vm(vm) - assert res is False - - -def test_vm_bin_parameter_count(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name) VALUES (?1);") - step_vm(vm) - res = bind_parameter_count_vm(vm) - assert res == 1 - - -def test_bind_parameter_index_vm(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name) VALUES (?1);") - step_vm(vm) - res = bind_parameter_index_vm(vm=vm, parameter_name='1') - assert res == 0 - - -def test_bind_parameter_name_vm(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name) VALUES (?1);") - step_vm(vm) - res = bind_parameter_name_vm(vm=vm, index=1) - assert isinstance(res, str) - - -def test_column_count_vm(get_conn): - conn = get_conn - vm = compile_vm(conn, "SELECT * FROM employees LIMIT1;") - step_vm(vm) - res = column_count_vm(vm) - assert res == 4 - - -def test_bind_double_vm(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name, salary) VALUES (?1, ?2)") - res = bind_double_vm(vm=vm, index=2, value=2.340) - assert res is True - - -def test_bind_int_vm(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name, salary) VALUES (?1, ?2)") - res = bind_double_vm(vm=vm, index=2, value=2) - assert res is True - - -def test_bind_int64_vm(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name, salary) VALUES (?1, ?2)") - res = bind_double_vm(vm=vm, index=2, value=123456789012345) - assert res is True - - -def test_bind_null_vm(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name, salary) VALUES (?1, ?2)") - res = bind_null_vm(vm=vm, index=1) - assert res is True - - -def test_bind_text_vm(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name, salary) VALUES (?1, ?2)") - res = bind_text_vm(vm=vm, index=1, value='Jonathan') - assert res is True - - -def test_bind_blob_vm(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name, salary) VALUES (?1, ?2)") - res = bind_blob_vm(vm=vm, index=1, value='Fake Blob Value') - assert res is True - - -def test_column_type_vm(get_conn): - conn = get_conn - vm = compile_vm(conn, "SELECT * FROM employees LIMIT 1;") - step_vm(vm) - - column_type = column_type_vm(vm, 0) - - assert column_type == SQCLOUD_VALUE_TYPE.VALUE_INTEGER - - -def test_column_blob_vm(get_conn): - value: str | None = None - conn = get_conn - vm = compile_vm(conn, "SELECT * FROM employees LIMIT 1 OFFSET 1;") - step_vm(vm) - - column_count: int = column_count_vm(vm) - - for index in range(0, column_count): - column_type = column_type_vm(vm, index) - - match column_type: - case SQCLOUD_VALUE_TYPE.VALUE_BLOB: - value = column_blob_vm(vm, index) - break - - assert isinstance(value, str) - assert value == '\x01\x02\x03\x04\x05' - - -def test_column_text_vm(get_conn): - value: str | None = None - conn = get_conn - vm = compile_vm(conn, "SELECT * FROM employees LIMIT 1;") - step_vm(vm) - - column_count: int = column_count_vm(vm) - - for index in range(0, column_count): - column_type = column_type_vm(vm, index) - - match column_type: - case SQCLOUD_VALUE_TYPE.VALUE_TEXT: - value = column_text_vm(vm, index) - break - case _: - value = None - - assert isinstance(value, str) - - -def test_column_double_vm(get_conn): - value: float | None = None - conn = get_conn - vm = compile_vm(conn, "SELECT * FROM employees LIMIT 1;") - step_vm(vm) - - column_count: int = column_count_vm(vm) - - for index in range(0, column_count): - column_type = column_type_vm(vm, index) - - match column_type: - case SQCLOUD_VALUE_TYPE.VALUE_FLOAT: - value: float = column_double_vm(vm, index) - break - case _: - value = None - - assert isinstance(value, float) - assert value == 18000.0 - - -def test_column_int_32_vm(get_conn): - value: int | None = None - conn = get_conn - vm = compile_vm(conn, "SELECT * FROM employees LIMIT 1;") - step_vm(vm) - - column_count: int = column_count_vm(vm) - - for index in range(0, column_count): - column_type = column_type_vm(vm, index) - - match column_type: - case SQCLOUD_VALUE_TYPE.VALUE_INTEGER: - value: int = column_int_32_vm(vm, index) - break - case _: - value = None - - assert isinstance(value, int) - assert value == 1 - - -def test_column_int_64_vm(get_conn): - value: int | None = None - conn = get_conn - vm = compile_vm(conn, "SELECT * FROM employees LIMIT 1;") - step_vm(vm) - - column_count: int = column_count_vm(vm) - - for index in range(0, column_count): - column_type = column_type_vm(vm, index) - - match column_type: - case SQCLOUD_VALUE_TYPE.VALUE_INTEGER: - value: int = column_int_64_vm(vm, index) - break - case _: - value = None - - assert isinstance(value, int) - assert value == 1 - - -def test_column_len_vm(get_conn): - column_content_length: int | None = None - conn = get_conn - vm = compile_vm(conn, "SELECT * FROM employees LIMIT 1;") - step_vm(vm) - - column_count: int = column_count_vm(vm) - - for index in range(0, column_count): - column_type = column_type_vm(vm, index) - - match column_type: - case SQCLOUD_VALUE_TYPE.VALUE_TEXT: - column_content_length = column_len_vm(vm, index) - break - case _: - column_content_length = None - - assert isinstance(column_content_length, int) - assert column_content_length == 4 - - -def test_last_row_id_vm(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name, salary) VALUES (?1, ?2)") - step_vm(vm) - - row_id = last_row_id_vm(vm) - assert isinstance(row_id, int) - - -def test_changes_vm(get_conn): - conn = get_conn - vm = compile_vm(conn, "INSERT INTO employees (emp_name, salary) VALUES (?1, ?2)") - step_vm(vm) - changes = changes_vm(vm) - assert changes == 1 - - -def test_total_changes_vm(get_conn): - conn = get_conn - - vm = compile_vm(conn, "INSERT INTO employees (emp_name, salary) VALUES (?1, ?2)") - step_vm(vm) - - changes = total_changes_vm(vm) - - assert changes == 1 - - vm = compile_vm(conn, "INSERT INTO employees (emp_name, salary) VALUES (?1, ?2)") - step_vm(vm) - - changes = total_changes_vm(vm) - - assert changes == 2 From 462e3f98b6c76dd2819b1be59f507467c97fc877 Mon Sep 17 00:00:00 2001 From: Daniele Briggi Date: Mon, 13 May 2024 12:44:53 +0000 Subject: [PATCH 2/7] parse value, number, array and lz4 compression --- requirements-dev.txt | 1 + requirements.txt | 1 + src/sqlitecloud/driver.py | 284 +++++++++++++++++++++++++-- src/sqlitecloud/types.py | 13 ++ src/tests/integration/test_driver.py | 87 ++++++-- 5 files changed, 356 insertions(+), 30 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 9deefd9..f8b65be 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,3 +6,4 @@ pytest==7.0.1 pytest-mock==3.6.1 black==22.8.0 python-dotenv==0.20.0 +lz4==3.1.10 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index e69de29..d74333f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -0,0 +1 @@ +lz4==3.1.10 \ No newline at end of file diff --git a/src/sqlitecloud/driver.py b/src/sqlitecloud/driver.py index e0cf3a4..579d736 100644 --- a/src/sqlitecloud/driver.py +++ b/src/sqlitecloud/driver.py @@ -1,5 +1,7 @@ import ssl +from termios import CSTART from typing import Optional +import lz4.block from sqlitecloud.types import ( SQCLOUD_CMD, SQCLOUD_INTERNAL_ERRCODE, @@ -7,12 +9,15 @@ SQCloudConnect, SQCloudException, SQCloudNumber, + SQCloudValue, ) import socket class Driver: - def connect(self, hostname: str, port: int, config: SQCloudConfig) -> SQCloudConnect: + def connect( + self, hostname: str, port: int, config: SQCloudConfig + ) -> SQCloudConnect: """ Connects to the SQLite Cloud server. @@ -150,7 +155,7 @@ def _internal_socket_read(self, connection: SQCloudConnect) -> any: data = connection.socket.recv(buffer_size) if not data: break - + # update buffers data = data.decode() buffer += data @@ -158,16 +163,20 @@ def _internal_socket_read(self, connection: SQCloudConnect) -> any: c = buffer[0] - if c == SQCLOUD_CMD.INT or c == SQCLOUD_CMD.FLOAT or c == SQCLOUD_CMD.NULL: - if buffer[nread-1] != ' ': + if ( + c == SQCLOUD_CMD.INT.value + or c == SQCLOUD_CMD.FLOAT.value + or c == SQCLOUD_CMD.NULL.value + ): + if buffer[nread - 1] != " ": continue - elif c == SQCLOUD_CMD.ROWSET_CHUNK: - isEndOfChunk = buffer.endswith(SQCLOUD_CMD.CHUNKS_END) + elif c == SQCLOUD_CMD.ROWSET_CHUNK.value: + isEndOfChunk = buffer.endswith(SQCLOUD_CMD.CHUNKS_END.value) if not isEndOfChunk: continue else: n: SQCloudNumber = self._internal_parse_number(buffer) - can_be_zerolength = c == SQCLOUD_CMD.BLOB or c == SQCLOUD_CMD.STRING + can_be_zerolength = c == SQCLOUD_CMD.BLOB.value or c == SQCLOUD_CMD.STRING.value if n.value == 0 and not can_be_zerolength: continue if n.value + n.cstart != nread: @@ -183,7 +192,7 @@ def _internal_socket_read(self, connection: SQCloudConnect) -> any: ) def _internal_parse_number(self, buffer: str, index: int = 1) -> SQCloudNumber: - sqlite_number = SQCloudNumber() + sqcloud_number = SQCloudNumber() extvalue = 0 isext = False blen = len(buffer) @@ -193,24 +202,261 @@ def _internal_parse_number(self, buffer: str, index: int = 1) -> SQCloudNumber: c = buffer[i] # check for optional extended error code (ERRCODE:EXTERRCODE) - if c == ':': + if c == ":": isext = True continue # check for end of value - if c == ' ': - sqlite_number.cstart = i + 1 - sqlite_number.extcode = extvalue - return sqlite_number + if c == " ": + sqcloud_number.cstart = i + 1 + sqcloud_number.extcode = extvalue + return sqcloud_number + + v = int(c) if c.isdigit() else 0 # compute numeric value if isext: - extvalue = (extvalue * 10) + int(buffer[i]) + extvalue = (extvalue * 10) + v else: - sqlite_number.value = (sqlite_number.value * 10) + int(buffer[i]) + sqcloud_number.value = (sqcloud_number.value * 10) + v + + return sqcloud_number - return 0 - def _internal_parse_buffer(self, buffer: str, blen: int) -> any: - # TODO - return \ No newline at end of file + # possible return values: + # True => OK + # False => error + # integer + # double + # string + # list + # object + # None + + # check OK value + if buffer == "+2 OK": + return True + + cmd = buffer[0] + + # check for compressed result + if cmd == SQCLOUD_CMD.COMPRESSED.value: + # TODO: use exception + buffer = self._internal_uncompress_data(buffer, blen) + if buffer is None: + raise SQCloudException( + f"An error occurred while decompressing the input buffer of len {len}.", + -1, + ) + + # first character contains command type + if cmd in [ + SQCLOUD_CMD.ZEROSTRING.value, + SQCLOUD_CMD.RECONNECT.value, + SQCLOUD_CMD.PUBSUB.value, + SQCLOUD_CMD.COMMAND.value, + SQCLOUD_CMD.STRING.value, + SQCLOUD_CMD.ARRAY.value, + SQCLOUD_CMD.BLOB.value, + SQCLOUD_CMD.JSON.value, + ]: + cstart = 0 + sqlite_number = self._internal_parse_number(buffer, cstart) + len = sqlite_number.value + if len == 0: + return "" + + if cmd == SQCLOUD_CMD.ZEROSTRING.value: + len -= 1 + clone = buffer[cstart : cstart + len] + + if cmd == SQCLOUD_CMD.COMMAND.value: + return self._internal_run_command(clone) + elif cmd == SQCLOUD_CMD.PUBSUB.value: + # TODO + return self._internal_setup_pubsub(clone) + elif cmd == SQCLOUD_CMD.RECONNECT.value: + return self._internal_reconnect(clone) + elif cmd == SQCLOUD_CMD.ARRAY.value: + return self._internal_parse_array(clone) + + return clone + + elif cmd == SQCLOUD_CMD.ERROR.value: + # -LEN ERRCODE:EXTCODE ERRMSG + sqlite_number = self._internal_parse_number(buffer) + len = sqlite_number.value + cstart = sqlite_number.cstart + clone = buffer[cstart:] + + sqlite_number = self._internal_parse_number(clone, 0) + cstart2 = sqlite_number.cstart + + errcode = sqlite_number.value + xerrcode = sqlite_number.extcode + + len -= cstart2 + errmsg = clone[cstart2:] + + raise SQCloudException(errmsg, errcode, xerrcode) + + elif cmd in [SQCLOUD_CMD.ROWSET.value, SQCLOUD_CMD.ROWSET_CHUNK.value]: + # CMD_ROWSET: *LEN 0:VERSION ROWS COLS DATA + # CMD_ROWSET_CHUNK: /LEN IDX:VERSION ROWS COLS DATA + # TODO: return a custom object + start = self._internal_parse_rowset_signature( + buffer, len, idx, version, nrows, ncols + ) + if start < 0: + return False + + # check for end-of-chunk condition + if start == 0 and version == 0: + rowset = self.rowset + self.rowset = None + return rowset + + rowset = self._internal_parse_rowset( + buffer, start, idx, version, nrows, ncols + ) + + # continue parsing next chunk in the buffer + buffer = buffer[len + len("/{} ".format(len)) :] + if buffer: + return self.internal_parse_buffer(buffer, len(buffer)) + + return rowset + + elif cmd == SQCLOUD_CMD.NULL.value: + return None + + elif cmd in [SQCLOUD_CMD.INT.value, SQCLOUD_CMD.FLOAT.value]: + # TODO + clone = self._internal_parse_value(buffer, blen) + if clone is None: + return 0 + if cmd == SQCLOUD_CMD.INT.value: + return int(clone) + return float(clone) + + elif cmd == SQCLOUD_CMD.RAWJSON.value: + return None + + return None + + def _internal_uncompress_data(self, buffer: str, blen: int) -> Optional[str]: + """ + %LEN COMPRESSED UNCOMPRESSED BUFFER + + Args: + buffer (str): The compressed data buffer. + blen (int): The length of the buffer. + + Returns: + str: The uncompressed data. + + Raises: + None + """ + tlen = 0 # total length + clen = 0 # compressed length + ulen = 0 # uncompressed length + hlen = 0 # raw header length + seek1 = 0 + + start = 1 + counter = 0 + for i in range(blen): + if buffer[i] != " ": + continue + counter += 1 + + data = buffer[start:i] + start = i + 1 + + if counter == 1: + tlen = int(data) + seek1 = start + elif counter == 2: + clen = int(data) + elif counter == 3: + ulen = int(data) + break + + # sanity check header values + if tlen == 0 or clen == 0 or ulen == 0 or start == 1 or seek1 == 0: + return None + + # copy raw header + hlen = start - seek1 + header = buffer[start : start + hlen] + + # compute index of the first compressed byte + start += hlen + + # perform real decompression + clone = header + str(lz4.block.decompress(buffer[start:])) + + # sanity check result + if len(clone) != ulen + hlen: + return None + + return clone + + def _internal_reconnect(self, buffer: str) -> bool: + return True + + def _internal_parse_array(self, buffer: str) -> list: + start = 0 + sqlite_number = self._internal_parse_number(buffer, start) + n = sqlite_number.value + start = sqlite_number.cstart + + r = [] + for i in range(n): + sqcloud_value = self._internal_parse_value(buffer, start) + start += sqcloud_value.cellsize + r.append(sqcloud_value.value) + + return r + + def _internal_parse_value(self, buffer: str, index: int = 0) -> SQCloudValue: + sqcloud_value = SQCloudValue() + len = 0 + cellsize = 0 + + # handle special NULL value case + if buffer is None or buffer[index] == SQCLOUD_CMD.NULL.value: + len = 0 + if cellsize is not None: + cellsize = 2 + + sqcloud_value.len = len + sqcloud_value.cellsize = cellsize + + return sqcloud_value + + sqcloud_number = self._internal_parse_number(buffer, index + 1) + blen = sqcloud_number.value + cstart = sqcloud_number.cstart + + # handle decimal/float cases + if buffer[index] == SQCLOUD_CMD.INT.value or buffer[index] == SQCLOUD_CMD.FLOAT.value: + nlen = cstart - index + len = nlen - 2 + cellsize = nlen + + sqcloud_value.value = buffer[index + 1 : index + 1 + len] + sqcloud_value.len + sqcloud_value.cellsize = cellsize + + return sqcloud_value + + len = blen - 1 if buffer[index] == SQCLOUD_CMD.ZEROSTRING.value else blen + cellsize = blen + cstart - index + + sqcloud_value.value = buffer[cstart : cstart + len] + sqcloud_value.len = len + sqcloud_value.cellsize = cellsize + + return sqcloud_value diff --git a/src/sqlitecloud/types.py b/src/sqlitecloud/types.py index f7c3902..3e223e2 100644 --- a/src/sqlitecloud/types.py +++ b/src/sqlitecloud/types.py @@ -2,6 +2,8 @@ from typing import List, Optional from enum import Enum +from click import Option + class SQCLOUD_VALUE_TYPE(Enum): VALUE_INTEGER = 1 @@ -174,3 +176,14 @@ def __init__(self) -> None: self.value: int = 0 self.cstart: int = 0 self.extcode: int = None + + +class SQCloudValue: + """ + Represents the parse value. + """ + + def __init__(self) -> None: + self.value: Optional[str] = None + self.len: int = 0 + self.cellsize: int = 0 diff --git a/src/tests/integration/test_driver.py b/src/tests/integration/test_driver.py index f920bbf..f397f28 100644 --- a/src/tests/integration/test_driver.py +++ b/src/tests/integration/test_driver.py @@ -1,20 +1,85 @@ +from binhex import hexbin import os from sqlitecloud.client import SqliteCloudClient from sqlitecloud.driver import Driver, SQCloudConnect from sqlitecloud.types import SQCloudConfig, SqliteCloudAccount +import pytest +import binascii -# class TestDriver: -# def test_internal_socket_read_empty_stream(self): -# driver = Driver() +class TestDriver: + @pytest.fixture( + params=[ + (":0 ", 0, 0, 3), + (":123 ", 123, 0, 5), + (",123.456 ", 1230456, 0, 9), + ("-1:1234 ", 1, 1234, 8), + ("-0:0 ", 0, 0, 5), + ("-123:456 ", 123, 456, 9), + ("-123: ", 123, 0, 6), + ("-1234:5678 ", 1234, 5678, 11), + ("-1234: ", 1234, 0, 7), + ] + ) + def number_data(self, request): + return request.param -# config = SQCloudConfig() -# config.account = SqliteCloudAccount() -# config.account.username = os.getenv("SQLITE_USER") -# config.account.password = os.getenv("SQLITE_PASSWORD") + def test_parse_number(self, number_data): + driver = Driver() + buffer, expected_value, expected_extcode, expected_cstart = number_data + result = driver._internal_parse_number(buffer) -# conn = driver.connect("nejvjtcliz.sqlite.cloud", 8860, config) -# assert isinstance(conn, SQCloudConnect) + assert expected_value == result.value + assert expected_extcode == result.extcode + assert expected_cstart == result.cstart + + @pytest.fixture( + params=[ + ("+5 Hello", "Hello", 5, 8), + ("+11 Hello World", "Hello World", 11, 15), + ("!6 Hello0", "Hello", 5, 9), + ("+0 ", "", 0, 3), + (":5678 ", "5678", 0, 6), + (":0 ", "0", 0, 3), + (",3.14 ", "3.14", 0, 6), + (",0 ", "0", 0, 3), + (",0.0 ", "0.0", 0, 5), + ("_ ", None, 0, 2), + ], + ids=[ + "String", + "String with space", + "String zero-terminated", + "Empty string", + "Integer", + "Integer zero", + "Float", + "Float zero", + "Float 0.0", + "Null", + ], + ) + def value_data(self, request): + return request.param + + def test_parse_value(self, value_data): + driver = Driver() + buffer, expected_value, expected_len, expected_cellsize = value_data + + result = driver._internal_parse_value(buffer) + + assert expected_value == result.value + assert expected_len == result.len + assert expected_cellsize == result.cellsize + + def test_parse_array(self): + driver = Driver() + buffer = "=5 +11 Hello World:123456 ,3.1415 _ $10 0123456789" + expected_list = ["Hello World", "123456", "3.1415", None, "0123456789"] + + result = driver._internal_parse_array(buffer) + + assert expected_list == result + + # TODO: test compression -# buffer = driver._internal_socket_read(conn) -# assert buffer == "" From ed343dc37c506120e75f2aedfd70cff3b216757e Mon Sep 17 00:00:00 2001 From: Daniele Briggi Date: Wed, 15 May 2024 09:17:10 +0000 Subject: [PATCH 3/7] Parse rowset and tests --- src/sqlitecloud/client.py | 125 ++++++++--- src/sqlitecloud/driver.py | 324 +++++++++++++++++++++------ src/sqlitecloud/resultset.py | 161 ++++++------- src/sqlitecloud/types.py | 51 ++--- src/tests/integration/test_client.py | 66 +++++- src/tests/integration/test_driver.py | 13 ++ src/tests/unit/test_client.py | 61 +++++ src/tests/unit/test_resultset.py | 91 ++++++++ 8 files changed, 660 insertions(+), 232 deletions(-) create mode 100644 src/tests/unit/test_client.py create mode 100644 src/tests/unit/test_resultset.py diff --git a/src/sqlitecloud/client.py b/src/sqlitecloud/client.py index e126a0d..d281654 100644 --- a/src/sqlitecloud/client.py +++ b/src/sqlitecloud/client.py @@ -2,9 +2,16 @@ """ from typing import Any, List, Optional +from urllib import parse from sqlitecloud.driver import Driver -from sqlitecloud.types import SQCloudConfig, SQCloudConnect, SqliteCloudAccount +from sqlitecloud.resultset import SqliteCloudResultSet +from sqlitecloud.types import ( + SQCloudConfig, + SQCloudConnect, + SQCloudException, + SqliteCloudAccount, +) class SqliteCloudClient: @@ -12,6 +19,8 @@ class SqliteCloudClient: Client to connect to SqliteCloud """ + SQLITE_DEFAULT_PORT = 8860 + def __init__( self, cloud_account: Optional[SqliteCloudAccount] = None, @@ -28,22 +37,15 @@ def __init__( """ self.driver = Driver() - - self.hostname: str = '' - self.port: int = 8860 - + self.config = SQCloudConfig() # for pb in pub_subs: # self._pub_sub_cbs.append(("channel1", SQCloudPubSubCB(pb))) if connection_str: - # TODO: parse connection string to create the config - self.config = SQCloudConfig() + self.config = self._parse_connection_string(connection_str) elif cloud_account: self.config.account = cloud_account - self.hostname = cloud_account.hostname - self.port = cloud_account.port - else: raise Exception("Missing connection parameters") @@ -56,7 +58,9 @@ def open_connection(self) -> SQCloudConnect: Raises: Exception: If an error occurs while opening the connection. """ - connection = self.driver.connect(self.hostname, self.port, self.config) + connection = self.driver.connect( + self.config.account.hostname, self.config.account.port, self.config + ) # SQCloudExec(connection, f"USE DATABASE {self.dbname};") @@ -75,9 +79,9 @@ def disconnect(self, conn: SQCloudConnect) -> None: """ self.driver.disconnect(conn) - # def exec_query( - # self, query: str, conn: SQCloudConnect = None - # ) -> SqliteCloudResultSet: + def exec_query( + self, query: str, conn: SQCloudConnect = None + ) -> SqliteCloudResultSet: """Executes a SQL query on the SQLite database. Args: @@ -86,31 +90,82 @@ def disconnect(self, conn: SQCloudConnect) -> None: Returns: SqliteCloudResultSet: The result set of the executed query. """ - # print(query) - # # pylint: disable=unused-variable - # local_conn, close_at_end = ( - # (conn, False) if conn else (self.open_connection(), True) - # ) - # result: SQCloudResult = SQCloudExec(local_conn, self._encode_str_to_c(query)) - # self._check_connection(local_conn) - # return SqliteCloudResultSet(result) - # pass + if not conn: + conn = self.open_connection() + + result = self.driver.execute(query, conn) + return SqliteCloudResultSet(result) # def exec_statement( # self, query: str, values: List[Any], conn: SQCloudConnect = None # ) -> SqliteCloudResultSet: - # local_conn = conn if conn else self.open_connection() - # result: SQCloudResult = SQCloudExecArray( - # local_conn, - # self._encode_str_to_c(query), - # [SqlParameter(self._encode_str_to_c(str(v)), v) for v in values], - # ) - # if SQCloudResultIsError(result): - # raise Exception( - # "Query error: " + str(SQCloudResultDump(local_conn, result)) - # ) - # return SqliteCloudResultSet(result) - # pass + # local_conn = conn if conn else self.open_connection() + # result: SQCloudResult = SQCloudExecArray( + # local_conn, + # self._encode_str_to_c(query), + # [SqlParameter(self._encode_str_to_c(str(v)), v) for v in values], + # ) + # if SQCloudResultIsError(result): + # raise Exception( + # "Query error: " + str(SQCloudResultDump(local_conn, result)) + # ) + # return SqliteCloudResultSet(result) + # pass def sendblob(self): pass + + def _parse_connection_string(self, connection_string) -> SQCloudConfig: + # URL STRING FORMAT + # sqlitecloud://user:pass@host.com:port/dbname?timeout=10&key2=value2&key3=value3 + # or sqlitecloud://host.sqlite.cloud:8860/dbname?apikey=zIiAARzKm9XBVllbAzkB1wqrgijJ3Gx0X5z1A4m4xBA + + config = SQCloudConfig() + config.account = SqliteCloudAccount() + + try: + params = parse.urlparse(connection_string) + + options = {} + query = params.query + options = parse.parse_qs(query) + for option, values in options.items(): + opt = option.lower() + value = values.pop() + + if value.lower() in ["true", "false"]: + value = bool(value) + elif value.isdigit(): + value = int(value) + else: + value = value + + if hasattr(config, opt): + setattr(config, opt, value) + elif hasattr(config.account, opt): + setattr(config.account, opt, value) + + # apikey or username/password is accepted + if not config.account.apikey: + config.account.username = ( + parse.unquote(params.username) if params.username else "" + ) + config.account.password = ( + parse.unquote(params.password) if params.password else "" + ) + + path = params.path + database = path.strip("/") + if database: + config.account.database = database + + config.account.hostname = params.hostname + config.account.port = ( + int(params.port) if params.port else self.SQLITE_DEFAULT_PORT + ) + + return config + except Exception as e: + raise SQCloudException( + f"Invalid connection string {connection_string}" + ) from e diff --git a/src/sqlitecloud/driver.py b/src/sqlitecloud/driver.py index 579d736..2626943 100644 --- a/src/sqlitecloud/driver.py +++ b/src/sqlitecloud/driver.py @@ -1,20 +1,26 @@ import ssl -from termios import CSTART from typing import Optional import lz4.block +from sqlitecloud.resultset import SQCloudResult, SqliteCloudResultSet from sqlitecloud.types import ( SQCLOUD_CMD, SQCLOUD_INTERNAL_ERRCODE, + SQCLOUD_ROWSET, SQCloudConfig, SQCloudConnect, SQCloudException, SQCloudNumber, + SQCloudRowsetSignature, SQCloudValue, ) import socket - +import sys class Driver: + def __init__(self) -> None: + # used for parsing chunked rowset + self._rowset: SqliteCloudResultSet = None + def connect( self, hostname: str, port: int, config: SQCloudConfig ) -> SQCloudConnect: @@ -36,10 +42,10 @@ def connect( sock.settimeout(config.connect_timeout) if not config.insecure: - context = ssl.create_default_context(cafile=config.tls_root_certificate) - if config.tls_certificate: + context = ssl.create_default_context(cafile=config.root_certificate) + if config.certificate: context.load_cert_chain( - certfile=config.tls_certificate, keyfile=config.tls_certificate_key + certfile=config.certificate, keyfile=config.certificate_key ) sock = context.wrap_socket(sock, server_hostname=hostname) @@ -48,7 +54,7 @@ def connect( sock.connect((hostname, port)) except Exception as e: errmsg = f"An error occurred while initializing the socket." - raise SQCloudException(errmsg, -1, exception=e) + raise SQCloudException(errmsg, -1) from e connection = SQCloudConnect() connection.socket = sock @@ -63,8 +69,8 @@ def disconnect(self, conn: SQCloudConnect): conn.socket.close() conn.socket = None - def execute(self): - pass + def execute(self, command: str, connection: SQCloudConnect) -> SQCloudResult: + return self._internal_run_command(connection, command) def sendblob(self): pass @@ -115,12 +121,13 @@ def _internal_config_apply( def _internal_run_command(self, connection: SQCloudConnect, buffer: str) -> None: self._internal_socket_write(connection, buffer) - self._internal_socket_read(connection) + return self._internal_socket_read(connection) def _internal_socket_write(self, connection: SQCloudConnect, buffer: str) -> None: # compute header delimit = "$" if connection.isblob else "+" - buffer_len = len(buffer) + bytebuffer = buffer.encode() + buffer_len = len(bytebuffer) header = f"{delimit}{buffer_len} " # write header @@ -130,8 +137,7 @@ def _internal_socket_write(self, connection: SQCloudConnect, buffer: str) -> Non raise SQCloudException( "An error occurred while writing header data.", SQCLOUD_INTERNAL_ERRCODE.INTERNAL_ERRCODE_NETWORK, - exc, - ) + ) from exc # write buffer if buffer_len == 0: @@ -142,24 +148,25 @@ def _internal_socket_write(self, connection: SQCloudConnect, buffer: str) -> Non raise SQCloudException( "An error occurred while writing data.", SQCLOUD_INTERNAL_ERRCODE.INTERNAL_ERRCODE_NETWORK, - exc, - ) - - def _internal_socket_read(self, connection: SQCloudConnect) -> any: - buffer: str = "" - buffer_size: int = 1024 - nread: int = 0 + ) from exc + def _internal_socket_read(self, connection: SQCloudConnect) -> SQCloudResult: + buffer = "" + buffer_size = 8192 + nread = 0 + bytebuffer = b"" try: while True: data = connection.socket.recv(buffer_size) if not data: - break + raise SQCloudException('Incomplete response from server.', -1) - # update buffers - data = data.decode() - buffer += data + # the expected data length to read + # matches the string size before decoding it nread += len(data) + # update buffers + buffer += data.decode() + bytebuffer += data c = buffer[0] @@ -168,31 +175,36 @@ def _internal_socket_read(self, connection: SQCloudConnect) -> any: or c == SQCLOUD_CMD.FLOAT.value or c == SQCLOUD_CMD.NULL.value ): - if buffer[nread - 1] != " ": + if not buffer.endswith(' '): continue elif c == SQCLOUD_CMD.ROWSET_CHUNK.value: - isEndOfChunk = buffer.endswith(SQCLOUD_CMD.CHUNKS_END.value) + isEndOfChunk = buffer.endswith(SQCLOUD_ROWSET.CHUNKS_END.value) if not isEndOfChunk: continue else: - n: SQCloudNumber = self._internal_parse_number(buffer) - can_be_zerolength = c == SQCLOUD_CMD.BLOB.value or c == SQCLOUD_CMD.STRING.value - if n.value == 0 and not can_be_zerolength: + sqcloud_number = self._internal_parse_number(buffer) + n = sqcloud_number.value + cstart = sqcloud_number.cstart + + can_be_zerolength = ( + c == SQCLOUD_CMD.BLOB.value or c == SQCLOUD_CMD.STRING.value + ) + if n == 0 and not can_be_zerolength: continue - if n.value + n.cstart != nread: + if n + cstart != nread: continue - return self._internal_parse_buffer(buffer, nread) + return self._internal_parse_buffer(buffer, len(buffer)) except Exception as exc: raise SQCloudException( "An error occurred while reading data from the socket.", SQCLOUD_INTERNAL_ERRCODE.INTERNAL_ERRCODE_NETWORK, - exc, - ) + ) from exc def _internal_parse_number(self, buffer: str, index: int = 1) -> SQCloudNumber: sqcloud_number = SQCloudNumber() + sqcloud_number.value = 0 extvalue = 0 isext = False blen = len(buffer) @@ -212,17 +224,18 @@ def _internal_parse_number(self, buffer: str, index: int = 1) -> SQCloudNumber: sqcloud_number.extcode = extvalue return sqcloud_number - v = int(c) if c.isdigit() else 0 + val = int(c) if c.isdigit() else 0 # compute numeric value if isext: - extvalue = (extvalue * 10) + v + extvalue = (extvalue * 10) + val else: - sqcloud_number.value = (sqcloud_number.value * 10) + v + sqcloud_number.value = (sqcloud_number.value * 10) + val + sqcloud_number.value = 0 return sqcloud_number - def _internal_parse_buffer(self, buffer: str, blen: int) -> any: + def _internal_parse_buffer(self, buffer: str, blen: int) -> SQCloudResult: # possible return values: # True => OK # False => error @@ -235,17 +248,16 @@ def _internal_parse_buffer(self, buffer: str, blen: int) -> any: # check OK value if buffer == "+2 OK": - return True + return SQCloudResult(True) cmd = buffer[0] # check for compressed result if cmd == SQCLOUD_CMD.COMPRESSED.value: - # TODO: use exception buffer = self._internal_uncompress_data(buffer, blen) if buffer is None: raise SQCloudException( - f"An error occurred while decompressing the input buffer of len {len}.", + f"An error occurred while decompressing the input buffer of len {blen}.", -1, ) @@ -262,30 +274,30 @@ def _internal_parse_buffer(self, buffer: str, blen: int) -> any: ]: cstart = 0 sqlite_number = self._internal_parse_number(buffer, cstart) - len = sqlite_number.value - if len == 0: - return "" + len_ = sqlite_number.value + if len_ == 0: + return SQCloudResult("") if cmd == SQCLOUD_CMD.ZEROSTRING.value: - len -= 1 - clone = buffer[cstart : cstart + len] + len_ -= 1 + clone = buffer[cstart : cstart + len_] if cmd == SQCLOUD_CMD.COMMAND.value: - return self._internal_run_command(clone) + return SQCloudResult(self._internal_run_command(clone)) elif cmd == SQCLOUD_CMD.PUBSUB.value: # TODO return self._internal_setup_pubsub(clone) elif cmd == SQCLOUD_CMD.RECONNECT.value: - return self._internal_reconnect(clone) + return SQCloudResult(self._internal_reconnect(clone)) elif cmd == SQCLOUD_CMD.ARRAY.value: - return self._internal_parse_array(clone) + return SQCloudResult(self._internal_parse_array(clone)) return clone elif cmd == SQCLOUD_CMD.ERROR.value: # -LEN ERRCODE:EXTCODE ERRMSG sqlite_number = self._internal_parse_number(buffer) - len = sqlite_number.value + len_ = sqlite_number.value cstart = sqlite_number.cstart clone = buffer[cstart:] @@ -295,7 +307,7 @@ def _internal_parse_buffer(self, buffer: str, blen: int) -> any: errcode = sqlite_number.value xerrcode = sqlite_number.extcode - len -= cstart2 + len_ -= cstart2 errmsg = clone[cstart2:] raise SQCloudException(errmsg, errcode, xerrcode) @@ -303,27 +315,30 @@ def _internal_parse_buffer(self, buffer: str, blen: int) -> any: elif cmd in [SQCLOUD_CMD.ROWSET.value, SQCLOUD_CMD.ROWSET_CHUNK.value]: # CMD_ROWSET: *LEN 0:VERSION ROWS COLS DATA # CMD_ROWSET_CHUNK: /LEN IDX:VERSION ROWS COLS DATA - # TODO: return a custom object - start = self._internal_parse_rowset_signature( - buffer, len, idx, version, nrows, ncols - ) - if start < 0: - return False + rowset_signature = self._internal_parse_rowset_signature(buffer) + if rowset_signature.start < 0: + raise SQCloudException("Cannot parse rowset signature") # check for end-of-chunk condition - if start == 0 and version == 0: - rowset = self.rowset - self.rowset = None + if rowset_signature.start == 0 and rowset_signature.version == 0: + rowset = self._rowset + self._rowset = None return rowset rowset = self._internal_parse_rowset( - buffer, start, idx, version, nrows, ncols + buffer, + rowset_signature.start, + rowset_signature.idx, + rowset_signature.version, + rowset_signature.nrows, + rowset_signature.ncols, ) # continue parsing next chunk in the buffer - buffer = buffer[len + len("/{} ".format(len)) :] + sign_len = rowset_signature.len + buffer = buffer[sign_len + len(f"/{sign_len} ") :] if buffer: - return self.internal_parse_buffer(buffer, len(buffer)) + return self._internal_parse_buffer(buffer, len(buffer)) return rowset @@ -331,16 +346,19 @@ def _internal_parse_buffer(self, buffer: str, blen: int) -> any: return None elif cmd in [SQCLOUD_CMD.INT.value, SQCLOUD_CMD.FLOAT.value]: - # TODO - clone = self._internal_parse_value(buffer, blen) + sqcloud_value = self._internal_parse_value(buffer, blen) + clone = sqcloud_value.value + if clone is None: - return 0 + return SQCloudResult(0) + if cmd == SQCLOUD_CMD.INT.value: - return int(clone) - return float(clone) + return SQCloudResult(int(clone)) + return SQCloudResult(float(clone)) elif cmd == SQCLOUD_CMD.RAWJSON.value: - return None + # TODO: isn't implemented in C? + return SQCloudResult(None) return None @@ -441,7 +459,10 @@ def _internal_parse_value(self, buffer: str, index: int = 0) -> SQCloudValue: cstart = sqcloud_number.cstart # handle decimal/float cases - if buffer[index] == SQCLOUD_CMD.INT.value or buffer[index] == SQCLOUD_CMD.FLOAT.value: + if ( + buffer[index] == SQCLOUD_CMD.INT.value + or buffer[index] == SQCLOUD_CMD.FLOAT.value + ): nlen = cstart - index len = nlen - 2 cellsize = nlen @@ -460,3 +481,170 @@ def _internal_parse_value(self, buffer: str, index: int = 0) -> SQCloudValue: sqcloud_value.cellsize = cellsize return sqcloud_value + + def _internal_parse_rowset_signature(self, buffer: str) -> SQCloudRowsetSignature: + # ROWSET: *LEN 0:VERS NROWS NCOLS DATA + # ROWSET in CHUNK: /LEN IDX:VERS NROWS NCOLS DATA + + signature = SQCloudRowsetSignature() + + # check for end-of-chunk condition + if buffer == SQCLOUD_ROWSET.CHUNKS_END: + signature.version = 0 + signature.start = 0 + return signature + + start = 1 + counter = 0 + n = len(buffer) + for i in range(n): + if buffer[i] != " ": + continue + counter += 1 + + data = buffer[start:i] + start = i + 1 + + if counter == 1: + signature.len = int(data) + elif counter == 2: + # idx:vers + values = data.split(":") + signature.idx = int(values[0]) + signature.version = int(values[1]) + elif counter == 3: + signature.nrows = int(data) + elif counter == 4: + signature.ncols = int(data) + + signature.start = start + + return signature + else: + return SQCloudRowsetSignature() + return SQCloudRowsetSignature() + + def _internal_parse_rowset( + self, buffer: str, start: int, idx: int, version: int, nrows: int, ncols: int + ) -> SQCloudResult: + rowset = None + n = start + ischunk = buffer[0] == SQCLOUD_CMD.ROWSET_CHUNK.value + + # idx == 0 means first (and only) chunk for rowset + # idx == 1 means first chunk for chunked rowset + first_chunk = (ischunk and idx == 1) or (not ischunk and idx == 0) + if first_chunk: + rowset = SQCloudResult() + rowset.nrows = nrows + rowset.ncols = ncols + rowset.version = version + rowset.data = [] + if ischunk: + self._rowset = rowset + n = self._internal_parse_rowset_header(rowset, buffer, start) + if n <= 0: + raise SQCloudException("Cannot parse rowset header") + else: + rowset = self._rowset + rowset.nrows += nrows + + # parse values + self._internal_parse_rowset_values(rowset, buffer, n, nrows * ncols) + + return rowset + + def _internal_parse_rowset_header( + self, rowset: SQCloudResult, buffer: str, start: int + ) -> int: + ncols = rowset.ncols + + # parse column names + rowset.colname = [] + for i in range(ncols): + sqcloud_number = self._internal_parse_number(buffer, start) + number_len = sqcloud_number.value + cstart = sqcloud_number.cstart + value = buffer[cstart : cstart + number_len] + rowset.colname.append(value) + start = cstart + number_len + + if rowset.version == 1: + return start + + if rowset.version != 2: + raise SQCloudException( + f"Rowset version {rowset.version} is not supported.", -1 + ) + + # parse declared types + rowset.decltype = [] + for i in range(ncols): + sqcloud_number = self._internal_parse_number(buffer, start) + number_len = sqcloud_number.value + cstart = sqcloud_number.cstart + value = buffer[cstart : cstart + number_len] + rowset.decltype.append(value) + start = cstart + number_len + + # parse database names + rowset.dbname = [] + for i in range(ncols): + sqcloud_number = self._internal_parse_number(buffer, start) + number_len = sqcloud_number.value + cstart = sqcloud_number.cstart + value = buffer[cstart : cstart + number_len] + rowset.dbname.append(value) + start = cstart + number_len + + # parse table names + rowset.tblname = [] + for i in range(ncols): + sqcloud_number = self._internal_parse_number(buffer, start) + number_len = sqcloud_number.value + cstart = sqcloud_number.cstart + value = buffer[cstart : cstart + number_len] + rowset.tblname.append(value) + start = cstart + number_len + + # parse column original names + rowset.origname = [] + for i in range(ncols): + sqcloud_number = self._internal_parse_number(buffer, start) + number_len = sqcloud_number.value + cstart = sqcloud_number.cstart + value = buffer[cstart : cstart + number_len] + rowset.origname.append(value) + start = cstart + number_len + + # parse not null flags + rowset.notnull = [] + for i in range(ncols): + sqcloud_number = self._internal_parse_number(buffer, start) + rowset.notnull.append(sqcloud_number.value) + start = sqcloud_number.cstart + + # parse primary key flags + rowset.prikey = [] + for i in range(ncols): + sqcloud_number = self._internal_parse_number(buffer, start) + rowset.prikey.append(sqcloud_number.value) + start = sqcloud_number.cstart + + # parse autoincrement flags + rowset.autoinc = [] + for i in range(ncols): + sqcloud_number = self._internal_parse_number(buffer, start) + rowset.autoinc.append(sqcloud_number.value) + start = sqcloud_number.cstart + + return start + + def _internal_parse_rowset_values( + self, rowset: SQCloudResult, buffer: str, start: int, bound: int + ): + # loop to parse each individual value + for i in range(bound): + sqcloud_value = self._internal_parse_value(buffer, start) + start += sqcloud_value.cellsize + rowset.data.append(sqcloud_value.value) diff --git a/src/sqlitecloud/resultset.py b/src/sqlitecloud/resultset.py index 76e61e8..1174b4e 100644 --- a/src/sqlitecloud/resultset.py +++ b/src/sqlitecloud/resultset.py @@ -1,107 +1,76 @@ -from typing import Any, Callable, Dict, List, Optional -from sqlitecloud.driver import ( - SQCloudResultFloat, - SQCloudResultFree, - SQCloudResultInt32, - SQCloudResultIsError, - SQCloudResultType, - SQCloudRowsetCols, - SQCloudRowsetColumnName, - SQCloudRowsetInt32Value, - SQCloudRowsetFloatValue, - SQCloudRowsetRows, - SQCloudRowsetValue, - SQCloudRowsetValueType, -) -from sqlitecloud.wrapper_types import ( - SQCLOUD_VALUE_TYPE, - SQCLOUD_RESULT_TYPE, - SQCloudResult, -) +from typing import Any, Dict, List, Optional -class SqliteCloudResultSet: - _result: Optional[SQCloudResult] = None - _data: List[Dict[str, Any]] = [] +class SQCloudResult: + def __init__(self, result: Optional[any] = None) -> None: + self.nrows: int = 0 + self.ncols: int = 0 + self.version: int = 0 + # table values are stored in 1-dimensional array + self.data: List[Any] = [] + self.colname: List[str] = [] + self.decltype: List[str] = [] + self.dbname: List[str] = [] + self.tblname: List[str] = [] + self.origname: List[str] = [] + self.notnull: List[str] = [] + self.prikey: List[str] = [] + self.autoinc: List[str] = [] + + self.is_result: bool = False + + if result: + self.init_data(result) + def init_data(self, result: any) -> None: + self.nrows = 1 + self.ncols = 1 + # TODO: what if result is array? + self.data = [result] + self.is_result = True + + +class SqliteCloudResultSet: def __init__(self, result: SQCloudResult) -> None: - rs_type = SQCloudResultType( - result, - ) - match rs_type: - case SQCLOUD_RESULT_TYPE.RESULT_ROWSET: - self._init_resultset(result) - case SQCLOUD_RESULT_TYPE.RESULT_OK: - self.init_data(result, self._extract_ok_data) - case SQCLOUD_RESULT_TYPE.RESULT_ERROR: - self.init_data(result, self._extract_error_data) - case SQCLOUD_RESULT_TYPE.RESULT_FLOAT: - self.init_data(result, self._extract_float_data) - case SQCLOUD_RESULT_TYPE.RESULT_INTEGER: - self.init_data(result, self._extract_int_data) - - def init_data( - self, - result: SQCloudResult, - extract_fn: Callable[[SQCloudResult], List[Dict[str, Any]]], - ): - self.row = 0 - self.rows = 1 - self._data = extract_fn(result) - - def _extract_ok_data(self, result: SQCloudResult): - return [{"result": not SQCloudResultIsError(result)}] - - def _extract_error_data(self, result: SQCloudResult): - return [{"result": SQCloudResultIsError(result)}] - - def _extract_float_data(self, result: SQCloudResult): - return [{"result": SQCloudResultFloat(result)}] - - def _extract_int_data(self, result: SQCloudResult): - return [{"result": SQCloudResultInt32(result)}] - - def _init_resultset(self, result): - self._result = result - self.row = 0 - self.rows = SQCloudRowsetRows(result) - self.cols = SQCloudRowsetCols(self._result) - self.col_names = list( - SQCloudRowsetColumnName(self._result, i) for i in range(self.cols) - ) + self._iter_row: int = 0 + self._result: SQCloudResult = result + + def __getattr__(self, attr: str) -> Any: + return getattr(self._result, attr) def __iter__(self): return self def __next__(self): - if self._result: - if self.row < self.rows: - out: Dict[str, any] = {} - for col in range(self.cols): - col_type = SQCloudRowsetValueType( - self._result, self.row, col - ).value - - data = self._resolve_type(col, col_type) - out[self.col_names[col]] = data - self.row += 1 - return out - elif self._data: - if self.row < self.rows: - out: Dict[str, any] = self._data[self.row] - self.row += 1 - return out - - SQCloudResultFree(self._result) + if self._result.data and self._iter_row < self._result.nrows: + out: Dict[str, any] = {} + + if self._result.is_result: + out = {"result": self.get_value(0, 0)} + self._iter_row += 1 + else: + for col in range(self._result.ncols): + out[self.get_name(col)] = self.get_value(self._iter_row, col) + self._iter_row += 1 + + return out + raise StopIteration - def _resolve_type(self, col, col_type): - match col_type: - case SQCLOUD_VALUE_TYPE.VALUE_INTEGER: - return SQCloudRowsetInt32Value(self._result, self.row, col) - case SQCLOUD_VALUE_TYPE.VALUE_FLOAT: - return SQCloudRowsetFloatValue(self._result, self.row, col) - case SQCLOUD_VALUE_TYPE.VALUE_TEXT: - return SQCloudRowsetValue(self._result, self.row, col) - case SQCLOUD_VALUE_TYPE.VALUE_BLOB: - return SQCloudRowsetValue(self._result, self.row, col) + def _compute_index(self, row: int, col: int) -> int: + if row < 0 or row >= self._result.nrows: + return -1 + if col < 0 or col >= self._result.ncols: + return -1 + return row * self._result.ncols + col + + def get_value(self, row: int, col: int) -> any: + index = self._compute_index(row, col) + if index < 0 or not self._result.data or index >= len(self._result.data): + return None + return self._result.data[index] + + def get_name(self, col: int) -> str: + if col < 0 or col >= self._result.ncols: + return None + return self._result.colname[col] diff --git a/src/sqlitecloud/types.py b/src/sqlitecloud/types.py index 3e223e2..6d974af 100644 --- a/src/sqlitecloud/types.py +++ b/src/sqlitecloud/types.py @@ -1,9 +1,7 @@ from enum import Enum -from typing import List, Optional +from typing import Optional from enum import Enum -from click import Option - class SQCLOUD_VALUE_TYPE(Enum): VALUE_INTEGER = 1 @@ -45,6 +43,10 @@ class SQCLOUD_CMD(Enum): ARRAY = "=" +class SQCLOUD_ROWSET(Enum): + CHUNKS_END = "/6 0 0 0 " + + class SQCLOUD_INTERNAL_ERRCODE(Enum): INTERNAL_ERRCODE_NONE = 0 # INTERNAL_ERRCODE_GENERIC = 100000 @@ -58,6 +60,20 @@ class SQCLOUD_INTERNAL_ERRCODE(Enum): # INTERNAL_ERRCODE_SOCKCLOSED = 100008 +class SQCloudRowsetSignature: + """ + Represents the parsed signature for a rowset. + """ + + def __init__(self) -> None: + self.start: int = -1 + self.len: int = 0 + self.idx: int = 0 + self.version: int = 0 + self.nrows: int = 0 + self.ncols: int = 0 + + class SqliteCloudAccount: def __init__(self): # User name is required unless connectionstring is provided @@ -82,7 +98,6 @@ def __init__(self): self.socket: any = None - self.chunk: SQCloudResult self.config: SQCloudConfig self.isblob: bool = False @@ -100,11 +115,6 @@ def __init__(self): # todo: which is the proper type? self.data: any - # self.errmsg: str = '' - # self.errcode: int # error code - # self.extcode: int # extended error code - # self.offcode: int # offset error code - class SQCloudConfig: def __init__(self) -> None: @@ -131,9 +141,9 @@ def __init__(self) -> None: self.no_verify_certificate = False # Filepath of certificates - self.tls_root_certificate: str = None - self.tls_certificate: str = None - self.tls_certificate_key: str = None + self.root_certificate: str = None + self.certificate: str = None + self.certificate_key: str = None # Server should send BLOB columns self.noblob = False @@ -145,26 +155,13 @@ def __init__(self) -> None: self.maxrowset = 0 -class SQCloudResult: - def __init__(self) -> None: - self.num_rows: int = 0 - self.num_columns: int = 0 - self.column_names: List[str] = [] - self.column_types: List[int] = [] - self.data: List[any] = [] - - class SQCloudException(Exception): def __init__( - self, message: str, code: int, xerrcode=0, exception: Optional[Exception] = None + self, message: str, code: int, xerrcode=0 ) -> None: self.errmsg = str(message) - if exception: - self.errmsg += " " + str(exception) - self.errcode = code self.xerrcode = xerrcode - self.exception = exception class SQCloudNumber: @@ -173,7 +170,7 @@ class SQCloudNumber: """ def __init__(self) -> None: - self.value: int = 0 + self.value: Optional[int] = None self.cstart: int = 0 self.extcode: int = None diff --git a/src/tests/integration/test_client.py b/src/tests/integration/test_client.py index 3e4b39d..3553f70 100644 --- a/src/tests/integration/test_client.py +++ b/src/tests/integration/test_client.py @@ -1,4 +1,6 @@ +from operator import rshift import os +from typing import Union import pytest from sqlitecloud.client import SqliteCloudClient @@ -6,7 +8,7 @@ class TestClient: - @pytest.fixture + @pytest.fixture() def sqlitecloud_connection(self): account = SqliteCloudAccount() account.username = os.getenv("SQLITE_USER") @@ -16,10 +18,11 @@ def sqlitecloud_connection(self): account.port = 8860 client = SqliteCloudClient(cloud_account=account) + connection = client.open_connection() assert isinstance(connection, SQCloudConnect) - yield connection + yield (connection, client) client.disconnect(connection) @@ -50,14 +53,65 @@ def test_connection_with_apikey(self): client.disconnect(conn) def test_connection_without_credentials_and_apikey(self): - #pytest.raises(SQCloudException) - account = SqliteCloudAccount() account.database = os.getenv("SQLITE_DB") account.hostname = os.getenv("SQLITE_HOST") account.port = 8860 - + client = SqliteCloudClient(cloud_account=account) + + with pytest.raises(SQCloudException): + client.open_connection() + + def test_connect_with_string(self): + connection_string = os.getenv("SQLITE_CONNECTION_STRING") + + client = SqliteCloudClient(connection_str=connection_string) - client.open_connection() + conn = client.open_connection() + assert isinstance(conn, SQCloudConnect) + + client.disconnect(conn) + + def test_connect_with_string_with_credentials(self): + connection_string = f"sqlitecloud://{os.getenv('SQLITE_USER')}:{os.getenv('SQLITE_PASSWORD')}@{os.getenv('SQLITE_HOST')}/{os.getenv('SQLITE_DB')}" + + client = SqliteCloudClient(connection_str=connection_string) + + conn = client.open_connection() + assert isinstance(conn, SQCloudConnect) + + client.disconnect(conn) + + def test_select(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + + result = client.exec_query("SELECT 'Hello'", connection) + + assert result + assert False == result.is_result + assert 1 == result.nrows + assert 1 == result.ncols + assert 'Hello' == result.get_value(0, 0) + + def test_rowset_data(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + result = client.exec_query("SELECT AlbumId FROM albums LIMIT 2", connection) + assert result + assert 2 == result.nrows + assert 1 == result.ncols + assert 2 == result.version + + def test_get_value(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + result = client.exec_query("SELECT * FROM albums", connection) + assert result + assert '1' == result.get_value(0, 0) + assert 'For Those About To Rock We Salute You' == result.get_value(0, 1) + assert '2' == result.get_value(1, 0) + def test_get_utf8_value(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + result = client.exec_query("SELECT 'Minha História'", connection) + assert result + assert 'Minha História' == result.get_value(0, 0) diff --git a/src/tests/integration/test_driver.py b/src/tests/integration/test_driver.py index f397f28..3476ee2 100644 --- a/src/tests/integration/test_driver.py +++ b/src/tests/integration/test_driver.py @@ -72,6 +72,7 @@ def test_parse_value(self, value_data): assert expected_len == result.len assert expected_cellsize == result.cellsize + # TODO def test_parse_array(self): driver = Driver() buffer = "=5 +11 Hello World:123456 ,3.1415 _ $10 0123456789" @@ -83,3 +84,15 @@ def test_parse_array(self): # TODO: test compression + def test_parse_rowset_signature(self): + driver = Driver() + buffer = "*35 0:1 1 2 +2 42+7 'hello':42 +5 hello" + + result = driver._internal_parse_rowset_signature(buffer) + + assert 12 == result.start + assert 35 == result.len + assert 0 == result.idx + assert 1 == result.version + assert 1 == result.nrows + assert 2 == result.ncols \ No newline at end of file diff --git a/src/tests/unit/test_client.py b/src/tests/unit/test_client.py new file mode 100644 index 0000000..3f89492 --- /dev/null +++ b/src/tests/unit/test_client.py @@ -0,0 +1,61 @@ +import pytest +from sqlitecloud.client import SqliteCloudClient +from sqlitecloud.types import SQCloudException + + +class TestClient: + def test_parse_connection_string_with_apikey(self): + connection_string = "sqlitecloud://user:pass@host.com:8860/dbname?apikey=abc123&timeout=10&compression=true" + client = SqliteCloudClient(connection_str=connection_string) + + assert not client.config.account.username + assert not client.config.account.password + assert "host.com" == client.config.account.hostname + assert 8860 == client.config.account.port + assert "dbname" == client.config.account.database + assert "abc123" == client.config.account.apikey + assert 10 == client.config.timeout + assert True == client.config.compression + + def test_parse_connection_string_with_credentials(self): + connection_string = "sqlitecloud://user:pass@host.com:8860" + client = SqliteCloudClient(connection_str=connection_string) + + assert "user" == client.config.account.username + assert "pass" == client.config.account.password + assert "host.com" == client.config.account.hostname + assert 8860 == client.config.account.port + assert not client.config.account.database + + def test_parse_connection_string_without_credentials(self): + connection_string = "sqlitecloud://host.com" + client = SqliteCloudClient(connection_str=connection_string) + + assert not client.config.account.username + assert not client.config.account.password + assert "host.com" == client.config.account.hostname + + + def test_parse_connection_string_with_all_parameters(self): + connection_string = "sqlitecloud://host.com:8860/dbname?apikey=abc123&compression=true&zerotext=true&memory=true&create=true&non_linearizable=true&insecure=true&no_verify_certificate=true&root_certificate=rootcert&certificate=cert&certificate_key=certkey&noblob=true&maxdata=10&maxrows=11&maxrowset=12" + + client = SqliteCloudClient(connection_str=connection_string) + + assert "host.com" == client.config.account.hostname + assert 8860 == client.config.account.port + assert "dbname" == client.config.account.database + assert "abc123" == client.config.account.apikey + assert True == client.config.compression + assert True == client.config.zerotext + assert True == client.config.memory + assert True == client.config.create + assert True == client.config.non_linearizable + assert True == client.config.insecure + assert True == client.config.no_verify_certificate + assert "rootcert" == client.config.root_certificate + assert "cert" == client.config.certificate + assert "certkey" == client.config.certificate_key + assert True == client.config.noblob + assert 10 == client.config.maxdata + assert 11 == client.config.maxrows + assert 12 == client.config.maxrowset \ No newline at end of file diff --git a/src/tests/unit/test_resultset.py b/src/tests/unit/test_resultset.py new file mode 100644 index 0000000..db7e6b7 --- /dev/null +++ b/src/tests/unit/test_resultset.py @@ -0,0 +1,91 @@ +import pytest +from sqlitecloud.resultset import SQCloudResult, SqliteCloudResultSet + + +class TestSqCloudResult: + def test_init_data(self): + result = SQCloudResult() + result.init_data(42) + assert 1 == result.nrows + assert 1 == result.ncols + assert [42] == result.data + assert True is result.is_result + + # TODO + def test_init_data_with_array(self): + result = SQCloudResult() + result.init_data([42, 43, 44]) + assert 1 == result.nrows + assert 1 == result.ncols + assert [42, 43, 44] == result.data + assert True is result.is_result + + def test_init_as_dataset(self): + result = SQCloudResult() + assert False is result.is_result + assert 0 == result.nrows + assert 0 == result.ncols + assert 0 == result.version + + +class TestSqliteCloudResultSet: + def test_next(self): + result = SQCloudResult(result=42) + result_set = SqliteCloudResultSet(result) + + assert {"result": 42} == next(result_set) + with pytest.raises(StopIteration): + next(result_set) + + def test_iter_result(self): + result = SQCloudResult(result=42) + result_set = SqliteCloudResultSet(result) + for row in result_set: + assert {"result": 42} == row + + def test_iter_rowset(self): + rowset = SQCloudResult() + rowset.nrows = 2 + rowset.ncols = 2 + rowset.colname = ["name", "age"] + rowset.data = ["John", 42, "Doe", 24] + rowset.version = 2 + result_set = SqliteCloudResultSet(rowset) + + out = [] + for row in result_set: + out.append(row) + + assert 2 == len(out) + assert {"name": "John", "age": 42} == out[0] + assert {"name": "Doe", "age": 24} == out[1] + + def test_get_value_with_rowset(self): + rowset = SQCloudResult() + rowset.nrows = 2 + rowset.ncols = 2 + rowset.colname = ["name", "age"] + rowset.data = ["John", 42, "Doe", 24] + rowset.version = 2 + result_set = SqliteCloudResultSet(rowset) + + assert "John" == result_set.get_value(0, 0) + assert 24 == result_set.get_value(1, 1) + assert None == result_set.get_value(2, 2) + + def test_get_value_array(self): + result = SQCloudResult(result=[1, 2, 3, 4, 5, 6]) + result_set = SqliteCloudResultSet(result) + assert 1 == result_set.get_value(0, 0) + assert 5 == result_set.get_value(1, 2) + assert 4 == result_set.get_value(2, 1) + assert None == result_set.get_value(3, 3) + + def test_get_colname(self): + result = SQCloudResult() + result.ncols = 2 + result.colname = ["name", "age"] + result_set = SqliteCloudResultSet(result) + assert "name" == result_set.get_name(0) + assert "age" == result_set.get_name(1) + assert None == result_set.get_name(2) From 663dac2f6b5cdb877531ed6ba9cd33d19dcff864 Mon Sep 17 00:00:00 2001 From: Daniele Briggi Date: Wed, 15 May 2024 10:15:31 +0000 Subject: [PATCH 4/7] Fix wrong parsing position --- src/sqlitecloud/driver.py | 101 +++++++++++++-------------- src/sqlitecloud/types.py | 7 +- src/tests/integration/test_client.py | 14 ++-- 3 files changed, 59 insertions(+), 63 deletions(-) diff --git a/src/sqlitecloud/driver.py b/src/sqlitecloud/driver.py index 2626943..ee726c1 100644 --- a/src/sqlitecloud/driver.py +++ b/src/sqlitecloud/driver.py @@ -16,6 +16,7 @@ import socket import sys + class Driver: def __init__(self) -> None: # used for parsing chunked rowset @@ -54,7 +55,7 @@ def connect( sock.connect((hostname, port)) except Exception as e: errmsg = f"An error occurred while initializing the socket." - raise SQCloudException(errmsg, -1) from e + raise SQCloudException(errmsg) from e connection = SQCloudConnect() connection.socket = sock @@ -119,15 +120,15 @@ def _internal_config_apply( if len(buffer) > 0: self._internal_run_command(connection, buffer) - def _internal_run_command(self, connection: SQCloudConnect, buffer: str) -> None: - self._internal_socket_write(connection, buffer) + def _internal_run_command(self, connection: SQCloudConnect, command: str) -> None: + self._internal_socket_write(connection, command) return self._internal_socket_read(connection) - def _internal_socket_write(self, connection: SQCloudConnect, buffer: str) -> None: + def _internal_socket_write(self, connection: SQCloudConnect, command: str) -> None: # compute header delimit = "$" if connection.isblob else "+" - bytebuffer = buffer.encode() - buffer_len = len(bytebuffer) + buffer = command.encode() + buffer_len = len(buffer) header = f"{delimit}{buffer_len} " # write header @@ -143,7 +144,7 @@ def _internal_socket_write(self, connection: SQCloudConnect, buffer: str) -> Non if buffer_len == 0: return try: - connection.socket.sendall(buffer.encode()) + connection.socket.sendall(buffer) except Exception as exc: raise SQCloudException( "An error occurred while writing data.", @@ -151,31 +152,30 @@ def _internal_socket_write(self, connection: SQCloudConnect, buffer: str) -> Non ) from exc def _internal_socket_read(self, connection: SQCloudConnect) -> SQCloudResult: - buffer = "" + buffer = b"" buffer_size = 8192 nread = 0 - bytebuffer = b"" + try: while True: data = connection.socket.recv(buffer_size) if not data: - raise SQCloudException('Incomplete response from server.', -1) + raise SQCloudException("Incomplete response from server.") - # the expected data length to read + # the expected data length to read # matches the string size before decoding it nread += len(data) # update buffers - buffer += data.decode() - bytebuffer += data + buffer += data - c = buffer[0] + c = chr(buffer[0]) if ( c == SQCLOUD_CMD.INT.value or c == SQCLOUD_CMD.FLOAT.value or c == SQCLOUD_CMD.NULL.value ): - if not buffer.endswith(' '): + if not buffer.endswith(b" "): continue elif c == SQCLOUD_CMD.ROWSET_CHUNK.value: isEndOfChunk = buffer.endswith(SQCLOUD_ROWSET.CHUNKS_END.value) @@ -202,7 +202,7 @@ def _internal_socket_read(self, connection: SQCloudConnect) -> SQCloudResult: SQCLOUD_INTERNAL_ERRCODE.INTERNAL_ERRCODE_NETWORK, ) from exc - def _internal_parse_number(self, buffer: str, index: int = 1) -> SQCloudNumber: + def _internal_parse_number(self, buffer: bytes, index: int = 1) -> SQCloudNumber: sqcloud_number = SQCloudNumber() sqcloud_number.value = 0 extvalue = 0 @@ -211,7 +211,7 @@ def _internal_parse_number(self, buffer: str, index: int = 1) -> SQCloudNumber: # from 1 to skip the first command type character for i in range(index, blen): - c = buffer[i] + c = chr(buffer[i]) # check for optional extended error code (ERRCODE:EXTERRCODE) if c == ":": @@ -235,7 +235,7 @@ def _internal_parse_number(self, buffer: str, index: int = 1) -> SQCloudNumber: sqcloud_number.value = 0 return sqcloud_number - def _internal_parse_buffer(self, buffer: str, blen: int) -> SQCloudResult: + def _internal_parse_buffer(self, buffer: bytes, blen: int) -> SQCloudResult: # possible return values: # True => OK # False => error @@ -247,18 +247,17 @@ def _internal_parse_buffer(self, buffer: str, blen: int) -> SQCloudResult: # None # check OK value - if buffer == "+2 OK": + if buffer == b"+2 OK": return SQCloudResult(True) - cmd = buffer[0] + cmd = chr(buffer[0]) # check for compressed result if cmd == SQCLOUD_CMD.COMPRESSED.value: buffer = self._internal_uncompress_data(buffer, blen) if buffer is None: raise SQCloudException( - f"An error occurred while decompressing the input buffer of len {blen}.", - -1, + f"An error occurred while decompressing the input buffer of len {blen}." ) # first character contains command type @@ -360,9 +359,10 @@ def _internal_parse_buffer(self, buffer: str, blen: int) -> SQCloudResult: # TODO: isn't implemented in C? return SQCloudResult(None) - return None + # TODO: exception here? + return SQCloudResult(None) - def _internal_uncompress_data(self, buffer: str, blen: int) -> Optional[str]: + def _internal_uncompress_data(self, buffer: bytes, blen: int) -> Optional[str]: """ %LEN COMPRESSED UNCOMPRESSED BUFFER @@ -385,7 +385,7 @@ def _internal_uncompress_data(self, buffer: str, blen: int) -> Optional[str]: start = 1 counter = 0 for i in range(blen): - if buffer[i] != " ": + if buffer[i] != b" ": continue counter += 1 @@ -413,7 +413,7 @@ def _internal_uncompress_data(self, buffer: str, blen: int) -> Optional[str]: start += hlen # perform real decompression - clone = header + str(lz4.block.decompress(buffer[start:])) + clone = header + lz4.block.decompress(buffer[start:]) # sanity check result if len(clone) != ulen + hlen: @@ -421,10 +421,10 @@ def _internal_uncompress_data(self, buffer: str, blen: int) -> Optional[str]: return clone - def _internal_reconnect(self, buffer: str) -> bool: + def _internal_reconnect(self, buffer: bytes) -> bool: return True - def _internal_parse_array(self, buffer: str) -> list: + def _internal_parse_array(self, buffer: bytes) -> list: start = 0 sqlite_number = self._internal_parse_number(buffer, start) n = sqlite_number.value @@ -438,13 +438,14 @@ def _internal_parse_array(self, buffer: str) -> list: return r - def _internal_parse_value(self, buffer: str, index: int = 0) -> SQCloudValue: + def _internal_parse_value(self, buffer: bytes, index: int = 0) -> SQCloudValue: sqcloud_value = SQCloudValue() len = 0 cellsize = 0 # handle special NULL value case - if buffer is None or buffer[index] == SQCLOUD_CMD.NULL.value: + c = chr(buffer[index]) + if buffer is None or c == SQCLOUD_CMD.NULL.value: len = 0 if cellsize is not None: cellsize = 2 @@ -460,36 +461,36 @@ def _internal_parse_value(self, buffer: str, index: int = 0) -> SQCloudValue: # handle decimal/float cases if ( - buffer[index] == SQCLOUD_CMD.INT.value - or buffer[index] == SQCLOUD_CMD.FLOAT.value + c == SQCLOUD_CMD.INT.value + or c == SQCLOUD_CMD.FLOAT.value ): nlen = cstart - index len = nlen - 2 cellsize = nlen - sqcloud_value.value = buffer[index + 1 : index + 1 + len] + sqcloud_value.value = (buffer[index + 1 : index + 1 + len]).decode() sqcloud_value.len sqcloud_value.cellsize = cellsize return sqcloud_value - len = blen - 1 if buffer[index] == SQCLOUD_CMD.ZEROSTRING.value else blen + len = blen - 1 if c == SQCLOUD_CMD.ZEROSTRING.value else blen cellsize = blen + cstart - index - sqcloud_value.value = buffer[cstart : cstart + len] + sqcloud_value.value = (buffer[cstart : cstart + len]).decode() sqcloud_value.len = len sqcloud_value.cellsize = cellsize return sqcloud_value - def _internal_parse_rowset_signature(self, buffer: str) -> SQCloudRowsetSignature: + def _internal_parse_rowset_signature(self, buffer: bytes) -> SQCloudRowsetSignature: # ROWSET: *LEN 0:VERS NROWS NCOLS DATA # ROWSET in CHUNK: /LEN IDX:VERS NROWS NCOLS DATA signature = SQCloudRowsetSignature() # check for end-of-chunk condition - if buffer == SQCLOUD_ROWSET.CHUNKS_END: + if buffer == SQCLOUD_ROWSET.CHUNKS_END.value: signature.version = 0 signature.start = 0 return signature @@ -498,11 +499,11 @@ def _internal_parse_rowset_signature(self, buffer: str) -> SQCloudRowsetSignatur counter = 0 n = len(buffer) for i in range(n): - if buffer[i] != " ": + if chr(buffer[i]) != " ": continue counter += 1 - data = buffer[start:i] + data = (buffer[start:i]).decode() start = i + 1 if counter == 1: @@ -525,11 +526,11 @@ def _internal_parse_rowset_signature(self, buffer: str) -> SQCloudRowsetSignatur return SQCloudRowsetSignature() def _internal_parse_rowset( - self, buffer: str, start: int, idx: int, version: int, nrows: int, ncols: int + self, buffer: bytes, start: int, idx: int, version: int, nrows: int, ncols: int ) -> SQCloudResult: rowset = None n = start - ischunk = buffer[0] == SQCLOUD_CMD.ROWSET_CHUNK.value + ischunk = chr(buffer[0]) == SQCLOUD_CMD.ROWSET_CHUNK.value # idx == 0 means first (and only) chunk for rowset # idx == 1 means first chunk for chunked rowset @@ -555,7 +556,7 @@ def _internal_parse_rowset( return rowset def _internal_parse_rowset_header( - self, rowset: SQCloudResult, buffer: str, start: int + self, rowset: SQCloudResult, buffer: bytes, start: int ) -> int: ncols = rowset.ncols @@ -566,16 +567,14 @@ def _internal_parse_rowset_header( number_len = sqcloud_number.value cstart = sqcloud_number.cstart value = buffer[cstart : cstart + number_len] - rowset.colname.append(value) + rowset.colname.append(value.decode()) start = cstart + number_len if rowset.version == 1: return start if rowset.version != 2: - raise SQCloudException( - f"Rowset version {rowset.version} is not supported.", -1 - ) + raise SQCloudException(f"Rowset version {rowset.version} is not supported.") # parse declared types rowset.decltype = [] @@ -584,7 +583,7 @@ def _internal_parse_rowset_header( number_len = sqcloud_number.value cstart = sqcloud_number.cstart value = buffer[cstart : cstart + number_len] - rowset.decltype.append(value) + rowset.decltype.append(value.decode()) start = cstart + number_len # parse database names @@ -594,7 +593,7 @@ def _internal_parse_rowset_header( number_len = sqcloud_number.value cstart = sqcloud_number.cstart value = buffer[cstart : cstart + number_len] - rowset.dbname.append(value) + rowset.dbname.append(value.decode()) start = cstart + number_len # parse table names @@ -604,7 +603,7 @@ def _internal_parse_rowset_header( number_len = sqcloud_number.value cstart = sqcloud_number.cstart value = buffer[cstart : cstart + number_len] - rowset.tblname.append(value) + rowset.tblname.append(value.decode()) start = cstart + number_len # parse column original names @@ -614,7 +613,7 @@ def _internal_parse_rowset_header( number_len = sqcloud_number.value cstart = sqcloud_number.cstart value = buffer[cstart : cstart + number_len] - rowset.origname.append(value) + rowset.origname.append(value.decode()) start = cstart + number_len # parse not null flags @@ -641,7 +640,7 @@ def _internal_parse_rowset_header( return start def _internal_parse_rowset_values( - self, rowset: SQCloudResult, buffer: str, start: int, bound: int + self, rowset: SQCloudResult, buffer: bytes, start: int, bound: int ): # loop to parse each individual value for i in range(bound): diff --git a/src/sqlitecloud/types.py b/src/sqlitecloud/types.py index 6d974af..fa346a5 100644 --- a/src/sqlitecloud/types.py +++ b/src/sqlitecloud/types.py @@ -44,7 +44,7 @@ class SQCLOUD_CMD(Enum): class SQCLOUD_ROWSET(Enum): - CHUNKS_END = "/6 0 0 0 " + CHUNKS_END = b"/6 0 0 0 " class SQCLOUD_INTERNAL_ERRCODE(Enum): @@ -112,9 +112,6 @@ def __init__(self): # callback: SQCloudPubSubCB - # todo: which is the proper type? - self.data: any - class SQCloudConfig: def __init__(self) -> None: @@ -157,7 +154,7 @@ def __init__(self) -> None: class SQCloudException(Exception): def __init__( - self, message: str, code: int, xerrcode=0 + self, message: str, code: Optional[int] = -1, xerrcode: Optional[int] = 0 ) -> None: self.errmsg = str(message) self.errcode = code diff --git a/src/tests/integration/test_client.py b/src/tests/integration/test_client.py index 3553f70..6f5d607 100644 --- a/src/tests/integration/test_client.py +++ b/src/tests/integration/test_client.py @@ -67,7 +67,7 @@ def test_connect_with_string(self): connection_string = os.getenv("SQLITE_CONNECTION_STRING") client = SqliteCloudClient(connection_str=connection_string) - + conn = client.open_connection() assert isinstance(conn, SQCloudConnect) @@ -77,7 +77,7 @@ def test_connect_with_string_with_credentials(self): connection_string = f"sqlitecloud://{os.getenv('SQLITE_USER')}:{os.getenv('SQLITE_PASSWORD')}@{os.getenv('SQLITE_HOST')}/{os.getenv('SQLITE_DB')}" client = SqliteCloudClient(connection_str=connection_string) - + conn = client.open_connection() assert isinstance(conn, SQCloudConnect) @@ -92,7 +92,7 @@ def test_select(self, sqlitecloud_connection): assert False == result.is_result assert 1 == result.nrows assert 1 == result.ncols - assert 'Hello' == result.get_value(0, 0) + assert "Hello" == result.get_value(0, 0) def test_rowset_data(self, sqlitecloud_connection): connection, client = sqlitecloud_connection @@ -106,12 +106,12 @@ def test_get_value(self, sqlitecloud_connection): connection, client = sqlitecloud_connection result = client.exec_query("SELECT * FROM albums", connection) assert result - assert '1' == result.get_value(0, 0) - assert 'For Those About To Rock We Salute You' == result.get_value(0, 1) - assert '2' == result.get_value(1, 0) + assert "1" == result.get_value(0, 0) + assert "For Those About To Rock We Salute You" == result.get_value(0, 1) + assert "2" == result.get_value(1, 0) def test_get_utf8_value(self, sqlitecloud_connection): connection, client = sqlitecloud_connection result = client.exec_query("SELECT 'Minha História'", connection) assert result - assert 'Minha História' == result.get_value(0, 0) + assert "Minha História" == result.get_value(0, 0) From 96f66baa93e6912ef00de366020f402c20349350 Mon Sep 17 00:00:00 2001 From: Daniele Briggi Date: Wed, 15 May 2024 18:36:29 +0000 Subject: [PATCH 5/7] Tests, cleanup, exception handling --- src/sqlitecloud/client.py | 49 +-- src/sqlitecloud/driver.py | 179 +++++---- src/sqlitecloud/resultset.py | 10 +- src/sqlitecloud/types.py | 39 +- src/tests/integration/test_client.py | 532 ++++++++++++++++++++++++++- src/tests/integration/test_driver.py | 18 +- src/tests/unit/test_resultset.py | 18 +- 7 files changed, 674 insertions(+), 171 deletions(-) diff --git a/src/sqlitecloud/client.py b/src/sqlitecloud/client.py index d281654..032d001 100644 --- a/src/sqlitecloud/client.py +++ b/src/sqlitecloud/client.py @@ -16,7 +16,7 @@ class SqliteCloudClient: """ - Client to connect to SqliteCloud + Client to interact with Sqlite Cloud """ SQLITE_DEFAULT_PORT = 8860 @@ -27,14 +27,11 @@ def __init__( connection_str: Optional[str] = None, # pub_subs: SQCloudPubSubCallback = [], ) -> None: - """Initializes a new instance of the class. + """Initializes a new instance of the class with connection information. Args: connection_str (str): The connection string for the database. - Raises: - ValueError: If the connection string is invalid. - """ self.driver = Driver() @@ -47,7 +44,7 @@ def __init__( elif cloud_account: self.config.account = cloud_account else: - raise Exception("Missing connection parameters") + raise SQCloudException("Missing connection parameters") def open_connection(self) -> SQCloudConnect: """Opens a connection to the SQCloud server. @@ -72,17 +69,14 @@ def open_connection(self) -> SQCloudConnect: def disconnect(self, conn: SQCloudConnect) -> None: """Closes the connection to the database. - This method is used to close the connection to the database. It does not take any arguments and does not return any value. - - Returns: - None: This method does not return any value. + This method is used to close the connection to the database. """ self.driver.disconnect(conn) def exec_query( self, query: str, conn: SQCloudConnect = None ) -> SqliteCloudResultSet: - """Executes a SQL query on the SQLite database. + """Executes a SQL query on the SQLite Cloud database. Args: query (str): The SQL query to be executed. @@ -90,30 +84,25 @@ def exec_query( Returns: SqliteCloudResultSet: The result set of the executed query. """ - if not conn: + provided_connection = conn is not None + if not provided_connection: conn = self.open_connection() result = self.driver.execute(query, conn) + + if not provided_connection: + self.disconnect(conn) + return SqliteCloudResultSet(result) - # def exec_statement( - # self, query: str, values: List[Any], conn: SQCloudConnect = None - # ) -> SqliteCloudResultSet: - # local_conn = conn if conn else self.open_connection() - # result: SQCloudResult = SQCloudExecArray( - # local_conn, - # self._encode_str_to_c(query), - # [SqlParameter(self._encode_str_to_c(str(v)), v) for v in values], - # ) - # if SQCloudResultIsError(result): - # raise Exception( - # "Query error: " + str(SQCloudResultDump(local_conn, result)) - # ) - # return SqliteCloudResultSet(result) - # pass - - def sendblob(self): - pass + def sendblob(self, blob: bytes, conn: SQCloudConnect) -> SqliteCloudResultSet: + """Sends a blob to the SQLite database. + + Args: + blob (bytes): The blob to be sent to the database. + conn (SQCloudConnect): The connection to the database. + """ + return self.driver.sendblob(blob, conn) def _parse_connection_string(self, connection_string) -> SQCloudConfig: # URL STRING FORMAT diff --git a/src/sqlitecloud/driver.py b/src/sqlitecloud/driver.py index ee726c1..b373413 100644 --- a/src/sqlitecloud/driver.py +++ b/src/sqlitecloud/driver.py @@ -1,7 +1,8 @@ import ssl -from typing import Optional +from typing import Optional, Union import lz4.block -from sqlitecloud.resultset import SQCloudResult, SqliteCloudResultSet +from sqlitecloud.lz4_custom import lz4decode +from sqlitecloud.resultset import SQCloudResult from sqlitecloud.types import ( SQCLOUD_CMD, SQCLOUD_INTERNAL_ERRCODE, @@ -14,13 +15,12 @@ SQCloudValue, ) import socket -import sys class Driver: def __init__(self) -> None: - # used for parsing chunked rowset - self._rowset: SqliteCloudResultSet = None + # Used while parsing chunked rowset + self._rowset: SQCloudResult = None def connect( self, hostname: str, port: int, config: SQCloudConfig @@ -29,9 +29,9 @@ def connect( Connects to the SQLite Cloud server. Args: - hostname (str, optional): The hostname of the server. Defaults to "localhost". - port (int, optional): The port number of the server. Defaults to 8860. - config (SQCloudConfig, optional): The configuration for the connection. Defaults to None. + hostname (str): The hostname of the server. + port (int): The port number of the server. + config (SQCloudConfig): The configuration for the connection. Returns: SQCloudConnect: The connection object. @@ -66,15 +66,27 @@ def connect( return connection def disconnect(self, conn: SQCloudConnect): - if conn.socket: - conn.socket.close() - conn.socket = None + try: + if conn.socket: + conn.socket.close() + finally: + conn.socket = None def execute(self, command: str, connection: SQCloudConnect) -> SQCloudResult: return self._internal_run_command(connection, command) - def sendblob(self): - pass + def sendblob(self, blob: bytes, conn: SQCloudConnect) -> SQCloudResult: + try: + conn.isblob = True + return self._internal_run_command(conn, blob) + finally: + conn.isblob = False + + def _internal_reconnect(self, buffer: bytes) -> bool: + return True + + def _internal_setup_pubsub(self, buffer: bytes) -> bool: + return True def _internal_config_apply( self, connection: SQCloudConnect, config: SQCloudConfig @@ -85,7 +97,7 @@ def _internal_config_apply( buffer = "" if config.account.apikey: - buffer += f"AUTH APIKEY {connection.account.apikey};" + buffer += f"AUTH APIKEY {config.account.apikey};" if config.account.username and config.account.password: command = "HASH" if config.account.password_hashed else "PASSWORD" @@ -120,14 +132,18 @@ def _internal_config_apply( if len(buffer) > 0: self._internal_run_command(connection, buffer) - def _internal_run_command(self, connection: SQCloudConnect, command: str) -> None: + def _internal_run_command( + self, connection: SQCloudConnect, command: Union[str, bytes] + ) -> None: self._internal_socket_write(connection, command) return self._internal_socket_read(connection) - def _internal_socket_write(self, connection: SQCloudConnect, command: str) -> None: + def _internal_socket_write( + self, connection: SQCloudConnect, command: Union[str, bytes] + ) -> None: # compute header delimit = "$" if connection.isblob else "+" - buffer = command.encode() + buffer = command.encode() if isinstance(command, str) else command buffer_len = len(buffer) header = f"{delimit}{buffer_len} " @@ -152,55 +168,62 @@ def _internal_socket_write(self, connection: SQCloudConnect, command: str) -> No ) from exc def _internal_socket_read(self, connection: SQCloudConnect) -> SQCloudResult: + """ + Read from the socket and parse the response. + + The buffer is stored as a string of bytes without decoding it with UTF-8. + The dimensions (LEN) specified in the SCSP protocol are in bytes, while + Python counts decoded strings in characters. This can cause issues when + slicing the buffer into parts if there are special characters like "ò". + """ buffer = b"" buffer_size = 8192 nread = 0 - try: - while True: + while True: + try: data = connection.socket.recv(buffer_size) if not data: raise SQCloudException("Incomplete response from server.") + except Exception as exc: + raise SQCloudException( + "An error occurred while reading data from the socket.", + SQCLOUD_INTERNAL_ERRCODE.INTERNAL_ERRCODE_NETWORK, + ) from exc + + # the expected data length to read + # matches the string size before decoding it + nread += len(data) + # update buffers + buffer += data + + c = chr(buffer[0]) + + if ( + c == SQCLOUD_CMD.INT.value + or c == SQCLOUD_CMD.FLOAT.value + or c == SQCLOUD_CMD.NULL.value + ): + if not buffer.endswith(b" "): + continue + elif c == SQCLOUD_CMD.ROWSET_CHUNK.value: + isEndOfChunk = buffer.endswith(SQCLOUD_ROWSET.CHUNKS_END.value) + if not isEndOfChunk: + continue + else: + sqcloud_number = self._internal_parse_number(buffer) + n = sqcloud_number.value + cstart = sqcloud_number.cstart - # the expected data length to read - # matches the string size before decoding it - nread += len(data) - # update buffers - buffer += data - - c = chr(buffer[0]) - - if ( - c == SQCLOUD_CMD.INT.value - or c == SQCLOUD_CMD.FLOAT.value - or c == SQCLOUD_CMD.NULL.value - ): - if not buffer.endswith(b" "): - continue - elif c == SQCLOUD_CMD.ROWSET_CHUNK.value: - isEndOfChunk = buffer.endswith(SQCLOUD_ROWSET.CHUNKS_END.value) - if not isEndOfChunk: - continue - else: - sqcloud_number = self._internal_parse_number(buffer) - n = sqcloud_number.value - cstart = sqcloud_number.cstart - - can_be_zerolength = ( - c == SQCLOUD_CMD.BLOB.value or c == SQCLOUD_CMD.STRING.value - ) - if n == 0 and not can_be_zerolength: - continue - if n + cstart != nread: - continue - - return self._internal_parse_buffer(buffer, len(buffer)) + can_be_zerolength = ( + c == SQCLOUD_CMD.BLOB.value or c == SQCLOUD_CMD.STRING.value + ) + if n == 0 and not can_be_zerolength: + continue + if n + cstart != nread: + continue - except Exception as exc: - raise SQCloudException( - "An error occurred while reading data from the socket.", - SQCLOUD_INTERNAL_ERRCODE.INTERNAL_ERRCODE_NETWORK, - ) from exc + return self._internal_parse_buffer(connection, buffer, len(buffer)) def _internal_parse_number(self, buffer: bytes, index: int = 1) -> SQCloudNumber: sqcloud_number = SQCloudNumber() @@ -235,7 +258,9 @@ def _internal_parse_number(self, buffer: bytes, index: int = 1) -> SQCloudNumber sqcloud_number.value = 0 return sqcloud_number - def _internal_parse_buffer(self, buffer: bytes, blen: int) -> SQCloudResult: + def _internal_parse_buffer( + self, connection: SQCloudConnect, buffer: bytes, blen: int + ) -> SQCloudResult: # possible return values: # True => OK # False => error @@ -271,9 +296,9 @@ def _internal_parse_buffer(self, buffer: bytes, blen: int) -> SQCloudResult: SQCLOUD_CMD.BLOB.value, SQCLOUD_CMD.JSON.value, ]: - cstart = 0 - sqlite_number = self._internal_parse_number(buffer, cstart) + sqlite_number = self._internal_parse_number(buffer) len_ = sqlite_number.value + cstart = sqlite_number.cstart if len_ == 0: return SQCloudResult("") @@ -282,16 +307,16 @@ def _internal_parse_buffer(self, buffer: bytes, blen: int) -> SQCloudResult: clone = buffer[cstart : cstart + len_] if cmd == SQCLOUD_CMD.COMMAND.value: - return SQCloudResult(self._internal_run_command(clone)) + return self._internal_run_command(connection, clone) elif cmd == SQCLOUD_CMD.PUBSUB.value: - # TODO - return self._internal_setup_pubsub(clone) + return SQCloudResult(self._internal_setup_pubsub(clone)) elif cmd == SQCLOUD_CMD.RECONNECT.value: return SQCloudResult(self._internal_reconnect(clone)) elif cmd == SQCLOUD_CMD.ARRAY.value: return SQCloudResult(self._internal_parse_array(clone)) - return clone + clone = clone.decode() if cmd != SQCLOUD_CMD.BLOB.value else clone + return SQCloudResult(clone) elif cmd == SQCLOUD_CMD.ERROR.value: # -LEN ERRCODE:EXTCODE ERRMSG @@ -309,7 +334,7 @@ def _internal_parse_buffer(self, buffer: bytes, blen: int) -> SQCloudResult: len_ -= cstart2 errmsg = clone[cstart2:] - raise SQCloudException(errmsg, errcode, xerrcode) + raise SQCloudException(errmsg.decode(), errcode, xerrcode) elif cmd in [SQCLOUD_CMD.ROWSET.value, SQCLOUD_CMD.ROWSET_CHUNK.value]: # CMD_ROWSET: *LEN 0:VERSION ROWS COLS DATA @@ -337,7 +362,7 @@ def _internal_parse_buffer(self, buffer: bytes, blen: int) -> SQCloudResult: sign_len = rowset_signature.len buffer = buffer[sign_len + len(f"/{sign_len} ") :] if buffer: - return self._internal_parse_buffer(buffer, len(buffer)) + return self._internal_parse_buffer(connection, buffer, len(buffer)) return rowset @@ -345,7 +370,7 @@ def _internal_parse_buffer(self, buffer: bytes, blen: int) -> SQCloudResult: return None elif cmd in [SQCLOUD_CMD.INT.value, SQCLOUD_CMD.FLOAT.value]: - sqcloud_value = self._internal_parse_value(buffer, blen) + sqcloud_value = self._internal_parse_value(buffer) clone = sqcloud_value.value if clone is None: @@ -362,7 +387,7 @@ def _internal_parse_buffer(self, buffer: bytes, blen: int) -> SQCloudResult: # TODO: exception here? return SQCloudResult(None) - def _internal_uncompress_data(self, buffer: bytes, blen: int) -> Optional[str]: + def _internal_uncompress_data(self, buffer: bytes, blen: int) -> Optional[bytes]: """ %LEN COMPRESSED UNCOMPRESSED BUFFER @@ -372,9 +397,6 @@ def _internal_uncompress_data(self, buffer: bytes, blen: int) -> Optional[str]: Returns: str: The uncompressed data. - - Raises: - None """ tlen = 0 # total length clen = 0 # compressed length @@ -385,7 +407,7 @@ def _internal_uncompress_data(self, buffer: bytes, blen: int) -> Optional[str]: start = 1 counter = 0 for i in range(blen): - if buffer[i] != b" ": + if chr(buffer[i]) != " ": continue counter += 1 @@ -413,7 +435,8 @@ def _internal_uncompress_data(self, buffer: bytes, blen: int) -> Optional[str]: start += hlen # perform real decompression - clone = header + lz4.block.decompress(buffer[start:]) + # clone = header + lz4.block.decompress(buffer[start:]) + clone = lz4decode(buffer, start, header) # sanity check result if len(clone) != ulen + hlen: @@ -421,16 +444,13 @@ def _internal_uncompress_data(self, buffer: bytes, blen: int) -> Optional[str]: return clone - def _internal_reconnect(self, buffer: bytes) -> bool: - return True - def _internal_parse_array(self, buffer: bytes) -> list: start = 0 sqlite_number = self._internal_parse_number(buffer, start) n = sqlite_number.value start = sqlite_number.cstart - r = [] + r: str = [] for i in range(n): sqcloud_value = self._internal_parse_value(buffer, start) start += sqcloud_value.cellsize @@ -460,10 +480,7 @@ def _internal_parse_value(self, buffer: bytes, index: int = 0) -> SQCloudValue: cstart = sqcloud_number.cstart # handle decimal/float cases - if ( - c == SQCLOUD_CMD.INT.value - or c == SQCLOUD_CMD.FLOAT.value - ): + if c == SQCLOUD_CMD.INT.value or c == SQCLOUD_CMD.FLOAT.value: nlen = cstart - index len = nlen - 2 cellsize = nlen diff --git a/src/sqlitecloud/resultset.py b/src/sqlitecloud/resultset.py index 1174b4e..b482b2e 100644 --- a/src/sqlitecloud/resultset.py +++ b/src/sqlitecloud/resultset.py @@ -19,13 +19,12 @@ def __init__(self, result: Optional[any] = None) -> None: self.is_result: bool = False - if result: + if result is not None: self.init_data(result) def init_data(self, result: any) -> None: self.nrows = 1 self.ncols = 1 - # TODO: what if result is array? self.data = [result] self.is_result = True @@ -64,13 +63,16 @@ def _compute_index(self, row: int, col: int) -> int: return -1 return row * self._result.ncols + col - def get_value(self, row: int, col: int) -> any: + def get_value(self, row: int, col: int) -> Optional[any]: index = self._compute_index(row, col) if index < 0 or not self._result.data or index >= len(self._result.data): return None return self._result.data[index] - def get_name(self, col: int) -> str: + def get_name(self, col: int) -> Optional[str]: if col < 0 or col >= self._result.ncols: return None return self._result.colname[col] + + def get_result(self) -> Optional[any]: + return self.get_value(0, 0) \ No newline at end of file diff --git a/src/sqlitecloud/types.py b/src/sqlitecloud/types.py index fa346a5..6f183a1 100644 --- a/src/sqlitecloud/types.py +++ b/src/sqlitecloud/types.py @@ -3,27 +3,6 @@ from enum import Enum -class SQCLOUD_VALUE_TYPE(Enum): - VALUE_INTEGER = 1 - VALUE_FLOAT = 2 - VALUE_TEXT = 3 - VALUE_BLOB = 4 - VALUE_NULL = 5 - - -class SQCLOUD_RESULT_TYPE(Enum): - RESULT_OK = 0 - RESULT_FLOAT = 2 - RESULT_STRING = 2 - RESULT_INTEGER = 3 - RESULT_ERROR = 4 - RESULT_ROWSET = 5 - RESULT_ARRAY = 6 - RESULT_NULL = 7 - RESULT_JSON = 8 - RESULT_BLOB = 9 - - class SQCLOUD_CMD(Enum): STRING = "+" ZEROSTRING = "!" @@ -49,15 +28,17 @@ class SQCLOUD_ROWSET(Enum): class SQCLOUD_INTERNAL_ERRCODE(Enum): INTERNAL_ERRCODE_NONE = 0 - # INTERNAL_ERRCODE_GENERIC = 100000 - # INTERNAL_ERRCODE_PUBSUB = 100001 - # INTERNAL_ERRCODE_TLS = 100002 - # INTERNAL_ERRCODE_URL = 100003 - # INTERNAL_ERRCODE_MEMORY = 100004 INTERNAL_ERRCODE_NETWORK = 100005 - # INTERNAL_ERRCODE_FORMAT = 100006 - # INTERNAL_ERRCODE_INDEX = 100007 - # INTERNAL_ERRCODE_SOCKCLOSED = 100008 + + +class SQCLOUD_CLOUD_ERRCODE(Enum): + CLOUD_ERRCODE_MEM = 10000 + CLOUD_ERRCODE_NOTFOUND = 10001 + CLOUD_ERRCODE_COMMAND = 10002 + CLOUD_ERRCODE_INTERNAL = 10003 + CLOUD_ERRCODE_AUTH = 10004 + CLOUD_ERRCODE_GENERIC = 10005 + CLOUD_ERRCODE_RAFT = 10006 class SQCloudRowsetSignature: diff --git a/src/tests/integration/test_client.py b/src/tests/integration/test_client.py index 6f5d607..5a72c63 100644 --- a/src/tests/integration/test_client.py +++ b/src/tests/integration/test_client.py @@ -1,13 +1,27 @@ -from operator import rshift +import json import os -from typing import Union +import sqlite3 +import tempfile +import time import pytest from sqlitecloud.client import SqliteCloudClient -from sqlitecloud.types import SQCloudConnect, SQCloudException, SqliteCloudAccount +from sqlitecloud.types import ( + SQCLOUD_CLOUD_ERRCODE, + SQCLOUD_INTERNAL_ERRCODE, + SQCloudConnect, + SQCloudException, + SqliteCloudAccount, +) class TestClient: + # Will warn if a query or other basic operation is slower than this + WARN_SPEED_MS = 500 + + # Will except queries to be quicker than this + EXPECT_SPEED_MS = 6 * 1000 + @pytest.fixture() def sqlitecloud_connection(self): account = SqliteCloudAccount() @@ -88,16 +102,23 @@ def test_select(self, sqlitecloud_connection): result = client.exec_query("SELECT 'Hello'", connection) - assert result assert False == result.is_result assert 1 == result.nrows assert 1 == result.ncols assert "Hello" == result.get_value(0, 0) + def test_column_not_found(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + with pytest.raises(SQCloudException) as e: + client.exec_query("SELECT not_a_column FROM albums", connection) + + assert e.value.errcode == 1 + assert e.value.errmsg == "no such column: not_a_column" + def test_rowset_data(self, sqlitecloud_connection): connection, client = sqlitecloud_connection result = client.exec_query("SELECT AlbumId FROM albums LIMIT 2", connection) - assert result + assert 2 == result.nrows assert 1 == result.ncols assert 2 == result.version @@ -105,13 +126,508 @@ def test_rowset_data(self, sqlitecloud_connection): def test_get_value(self, sqlitecloud_connection): connection, client = sqlitecloud_connection result = client.exec_query("SELECT * FROM albums", connection) - assert result + assert "1" == result.get_value(0, 0) assert "For Those About To Rock We Salute You" == result.get_value(0, 1) assert "2" == result.get_value(1, 0) - def test_get_utf8_value(self, sqlitecloud_connection): + def test_select_utf8_value_and_column_name(self, sqlitecloud_connection): connection, client = sqlitecloud_connection result = client.exec_query("SELECT 'Minha História'", connection) - assert result + + assert result.nrows == 1 + assert result.ncols == 1 assert "Minha História" == result.get_value(0, 0) + assert "'Minha História'" == result.get_name(0) + + def test_invalid_row_number_for_value(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + result = client.exec_query("SELECT 'one row'", connection) + + assert result.get_value(1, 1) is None + + def test_column_name(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + result = client.exec_query("SELECT * FROM albums", connection) + + assert "AlbumId" == result.get_name(0) + assert "Title" == result.get_name(1) + + def test_invalid_row_number_for_column_name(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + result = client.exec_query("SELECT 'one column'", connection) + + assert result.get_name(2) is None + + def test_long_string(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + + size = 1024 * 1024 + value = "LOOOONG" + while len(value) < size: + value += "a" + + rowset = client.exec_query(f"SELECT '{value}' 'VALUE'", connection) + + assert 1 == rowset.nrows + assert 1 == rowset.ncols + assert "VALUE" == rowset.get_name(0) + assert value == rowset.get_value(0, 0) + + def test_integer(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + result = client.exec_query("TEST INTEGER", connection) + + assert 123456 == result.get_result() + + def test_float(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + result = client.exec_query("TEST FLOAT", connection) + + assert 3.1415926 == result.get_result() + + def test_string(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + result = client.exec_query("TEST STRING", connection) + + assert "Hello World, this is a test string." == result.get_result() + + def test_zero_string(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + result = client.exec_query("TEST ZERO_STRING", connection) + + assert ( + "Hello World, this is a zero-terminated test string." == result.get_result() + ) + + def test_empty_string(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + result = client.exec_query("TEST STRING0", connection) + + assert "" == result.get_result() + + def test_command(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + result = client.exec_query("TEST COMMAND", connection) + + assert "PONG" == result.get_result() + + def test_json(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + result = client.exec_query("TEST JSON", connection) + + assert { + "msg-from": {"class": "soldier", "name": "Wixilav"}, + "msg-to": {"class": "supreme-commander", "name": "[Redacted]"}, + "msg-type": ["0xdeadbeef", "irc log"], + "msg-log": [ + "soldier: Boss there is a slight problem with the piece offering to humans", + "supreme-commander: Explain yourself soldier!", + "soldier: Well they don't seem to move anymore...", + "supreme-commander: Oh snap, I came here to see them twerk!", + ], + } == json.loads(result.get_result()) + + def test_blob(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + result = client.exec_query("TEST BLOB", connection) + + assert 1000 == len(result.get_result()) + + def test_blob0(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + result = client.exec_query("TEST BLOB0", connection) + + assert 0 == len(result.get_result()) + + def test_error(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + + with pytest.raises(SQCloudException) as e: + client.exec_query("TEST ERROR", connection) + + assert 66666 == e.value.errcode + assert "This is a test error message with a devil error code." == e.value.errmsg + + def test_ext_error(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + + with pytest.raises(SQCloudException) as e: + client.exec_query("TEST EXTERROR", connection) + + assert 66666 == e.value.errcode + assert 333 == e.value.xerrcode + assert ( + "This is a test error message with an extcode and a devil error code." + == e.value.errmsg + ) + + def test_array(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + result = client.exec_query("TEST ARRAY", connection) + + result_array = result.get_result() + assert isinstance(result_array, list) + assert len(result_array) == 5 + assert result_array[0] == "Hello World" + assert result_array[1] == "123456" + assert result_array[2] == "3.1415" + assert result_array[3] is None + + def test_rowset(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + result = client.exec_query("TEST ROWSET", connection) + + assert result.nrows >= 30 + assert result.ncols == 2 + assert result.version in [1, 2] + assert result.get_name(0) == "key" + assert result.get_name(1) == "value" + + def test_max_rows_option(self): + account = SqliteCloudAccount() + account.hostname = os.getenv("SQLITE_HOST") + account.database = os.getenv("SQLITE_DB") + account.apikey = os.getenv("SQLITE_API_KEY") + + client = SqliteCloudClient(cloud_account=account) + client.config.maxrows = 1 + + rowset = client.exec_query("TEST ROWSET_CHUNK") + + # maxrows cannot be tested at this level. + # just expect everything is ok + assert rowset.nrows > 100 + + def test_max_rowset_option_to_fail_when_rowset_is_bigger(self): + account = SqliteCloudAccount() + account.hostname = os.getenv("SQLITE_HOST") + account.database = os.getenv("SQLITE_DB") + account.apikey = os.getenv("SQLITE_API_KEY") + + client = SqliteCloudClient(cloud_account=account) + client.config.maxrowset = 1024 + + with pytest.raises(SQCloudException) as e: + client.exec_query("SELECT * FROM albums") + + assert SQCLOUD_CLOUD_ERRCODE.CLOUD_ERRCODE_INTERNAL.value == e.value.errcode + assert "RowSet too big to be sent (limit set to 1024 bytes)." == e.value.errmsg + + def test_max_rowset_option_to_succeed_when_rowset_is_lighter(self): + account = SqliteCloudAccount() + account.hostname = os.getenv("SQLITE_HOST") + account.database = os.getenv("SQLITE_DB") + account.apikey = os.getenv("SQLITE_API_KEY") + + client = SqliteCloudClient(cloud_account=account) + client.config.maxrowset = 1024 + + rowset = client.exec_query("SELECT 'hello world'") + + assert 1 == rowset.nrows + + def test_chunked_rowset(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + + rowset = client.exec_query("TEST ROWSET_CHUNK", connection) + + assert 147 == rowset.nrows + assert 1 == rowset.ncols + assert 147 == len(rowset.data) + assert "key" == rowset.get_name(0) + + def test_chunked_rowset_twice(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + rowset = client.exec_query("TEST ROWSET_CHUNK", connection) + + assert 147 == rowset.nrows + assert 1 == rowset.ncols + assert "key" == rowset.get_name(0) + + rowset = client.exec_query("TEST ROWSET_CHUNK", connection) + + assert 147 == rowset.nrows + assert 1 == rowset.ncols + assert "key" == rowset.get_name(0) + + rowset = client.exec_query("SELECT 1", connection) + + assert 1 == rowset.nrows + + def test_serialized_operations(self, sqlitecloud_connection): + num_queries = 20 + + connection, client = sqlitecloud_connection + + for i in range(num_queries): + rowset = client.exec_query( + f"select {i} as 'count', 'hello' as 'string'", connection + ) + + assert 1 == rowset.nrows + assert 2 == rowset.ncols + assert "count" == rowset.get_name(0) + assert "string" == rowset.get_name(1) + assert str(i) == rowset.get_value(0, 0) + assert rowset.version in [1, 2] + + def test_query_timeout(self): + account = SqliteCloudAccount() + account.hostname = os.getenv("SQLITE_HOST") + account.database = os.getenv("SQLITE_DB") + account.apikey = os.getenv("SQLITE_API_KEY") + + client = SqliteCloudClient(cloud_account=account) + client.config.timeout = 1 # 1 sec + + # this operation should take more than 1 sec + with pytest.raises(SQCloudException) as e: + # just a long running query + client.exec_query( + """ + WITH RECURSIVE r(i) AS ( + VALUES(0) + UNION ALL + SELECT i FROM r + LIMIT 10000000 + ) + SELECT i FROM r WHERE i = 1;""" + ) + + assert e.value.errcode == SQCLOUD_INTERNAL_ERRCODE.INTERNAL_ERRCODE_NETWORK + assert e.value.errmsg == "An error occurred while reading data from the socket." + + def test_XXL_query(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + + xxl_query = 300000 + long_sql = "" + + while len(long_sql) < xxl_query: + for i in range(5000): + long_sql += f"SELECT {len(long_sql)} 'HowLargeIsTooMuch'; " + + rowset = client.exec_query(long_sql, connection) + + assert 1 == rowset.nrows + assert 1 == rowset.ncols + assert len(long_sql) - 50 <= int(rowset.get_value(0, 0)) + + def test_single_XXL_query(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + + xxl_query = 200000 + long_sql = "" + + while len(long_sql) < xxl_query: + long_sql += str(len(long_sql)) + "_" + selected_value = f"start_{long_sql}end" + long_sql = f"SELECT '{selected_value}'" + + rowset = client.exec_query(long_sql, connection) + + assert 1 == rowset.nrows + assert 1 == rowset.ncols + assert f"'{selected_value}'" == rowset.get_name(0) + + def test_metadata(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + rowset = client.exec_query("LIST METADATA", connection) + + assert rowset.nrows >= 32 + assert rowset.ncols == 8 + + def test_select_results_with_no_column_name(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + rowset = client.exec_query("SELECT 42, 'hello'", connection) + + assert rowset.nrows == 1 + assert rowset.ncols == 2 + assert rowset.get_name(0) == "42" + assert rowset.get_name(1) == "'hello'" + assert rowset.get_value(0, 0) == "42" + assert rowset.get_value(0, 1) == "hello" + + def test_select_long_formatted_string(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + + long_string = "x" * 1000 + rowset = client.exec_query( + f"USE DATABASE :memory:; SELECT '{long_string}' AS DDD", connection + ) + + assert rowset.nrows == 1 + assert rowset.ncols == 1 + assert rowset.get_value(0, 0).startswith("xxxxxxxx") + assert len(rowset.get_value(0, 0)) == 1000 + + def test_select_database(self): + account = SqliteCloudAccount() + account.hostname = os.getenv("SQLITE_HOST") + account.database = "" + account.apikey = os.getenv("SQLITE_API_KEY") + + client = SqliteCloudClient(cloud_account=account) + + rowset = client.exec_query("USE DATABASE chinook.sqlite") + + assert rowset.get_result() + + def test_select_tracks_without_limit(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + rowset = client.exec_query("SELECT * FROM tracks", connection) + + assert rowset.nrows >= 3000 + assert rowset.ncols == 9 + + def test_select_tracks_with_limit(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + rowset = client.exec_query("SELECT * FROM tracks LIMIT 10", connection) + + assert rowset.nrows == 10 + assert rowset.ncols == 9 + + def test_stress_test_20x_string_select_individual(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + + num_queries = 20 + completed = 0 + start_time = time.time() + + for i in range(num_queries): + rowset = client.exec_query("TEST STRING", connection) + + assert rowset.get_result() == "Hello World, this is a test string." + + completed += 1 + if completed >= num_queries: + query_ms = round((time.time() - start_time) * 1000 / num_queries) + if query_ms > self.WARN_SPEED_MS: + assert ( + query_ms < self.EXPECT_SPEED_MS + ), f"{num_queries}x test string, {query_ms}ms per query" + + def test_stress_test_20x_individual_select(self, sqlitecloud_connection): + num_queries = 20 + completed = 0 + start_time = time.time() + connection, client = sqlitecloud_connection + + for i in range(num_queries): + rowset = client.exec_query( + "SELECT * FROM albums ORDER BY RANDOM() LIMIT 4", connection + ) + + assert rowset.nrows == 4 + assert rowset.ncols == 3 + + completed += 1 + if completed >= num_queries: + query_ms = round((time.time() - start_time) * 1000 / num_queries) + if query_ms > self.WARN_SPEED_MS: + assert ( + query_ms < self.EXPECT_SPEED_MS + ), f"{num_queries}x individual selects, {query_ms}ms per query" + + def test_stress_test_20x_batched_selects(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + + num_queries = 20 + completed = 0 + start_time = time.time() + + for i in range(num_queries): + rowset = client.exec_query( + "SELECT * FROM albums ORDER BY RANDOM() LIMIT 16; SELECT * FROM albums ORDER BY RANDOM() LIMIT 12; SELECT * FROM albums ORDER BY RANDOM() LIMIT 8; SELECT * FROM albums ORDER BY RANDOM() LIMIT 4", + connection, + ) + + assert rowset.nrows == 4 + assert rowset.ncols == 3 + + completed += 1 + if completed >= num_queries: + query_ms = round((time.time() - start_time) * 1000 / num_queries) + if query_ms > self.WARN_SPEED_MS: + assert ( + query_ms < self.EXPECT_SPEED_MS + ), f"{num_queries}x batched selects, {query_ms}ms per query" + + def test_download_database(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + + rowset = client.exec_query( + "DOWNLOAD DATABASE " + os.getenv("SQLITE_DB"), connection + ) + + result_array = rowset.get_result() + + db_size = int(result_array[0]) + + tot_read = 0 + data: bytes = b"" + while tot_read < db_size: + result = client.exec_query("DOWNLOAD STEP;", connection) + + data += result.get_result() + tot_read += len(data) + + temp_file = tempfile.mkstemp(prefix="chinook")[1] + with open(temp_file, "wb") as f: + f.write(data) + + db = sqlite3.connect(temp_file) + cursor = db.execute("SELECT * FROM albums") + rowset = cursor.fetchall() + + assert cursor.description[0][0] == "AlbumId" + assert cursor.description[1][0] == "Title" + + # TODO + def test_compression(self): + account = SqliteCloudAccount() + account.hostname = os.getenv("SQLITE_HOST") + account.apikey = os.getenv("SQLITE_API_KEY") + account.database = os.getenv("SQLITE_DB") + + client = SqliteCloudClient(cloud_account=account) + client.config.compression = True + + # min compression size for rowset set by default to 20400 bytes + rowset = client.exec_query("SELECT '" + "a" * (1024 * 20) + "' AS DDD") + + assert rowset.nrows == 1 + assert rowset.ncols == 1 + assert rowset.get_value(0, 0).startswith("aaaaa") + assert len(rowset.get_value(0, 0)) == 100 + + # def test_send_blob(self, sqlitecloud_connection): + # connection, client = sqlitecloud_connection + # blob = b"Hello, this is a test blob" + + # result = client.sendblob(blob, connection) + + # assert result.get_result() == True + + # def test_send_empty_blob(self, sqlitecloud_connection): + # connection, client = sqlitecloud_connection + # blob = b"" + # result = client.sendblob(blob, connection) + # assert result is not None + # # Add additional assertions as needed + + # def test_send_large_blob(self, sqlitecloud_connection): + # connection, client = sqlitecloud_connection + # blob = b"A" * 1024 * 1024 # 1MB blob + # result = client.sendblob(blob, connection) + # assert result is not None + # # Add additional assertions as needed + + # def test_send_blob_with_connection_closed(self, sqlitecloud_connection): + # connection, client = sqlitecloud_connection + # client.disconnect(connection) + # blob = b"Hello, this is a test blob" + # with pytest.raises(Exception): + # client.sendblob(blob, connection) + # Add additional assertions as needed diff --git a/src/tests/integration/test_driver.py b/src/tests/integration/test_driver.py index 3476ee2..c631829 100644 --- a/src/tests/integration/test_driver.py +++ b/src/tests/integration/test_driver.py @@ -1,10 +1,5 @@ -from binhex import hexbin -import os -from sqlitecloud.client import SqliteCloudClient -from sqlitecloud.driver import Driver, SQCloudConnect -from sqlitecloud.types import SQCloudConfig, SqliteCloudAccount +from sqlitecloud.driver import Driver import pytest -import binascii class TestDriver: @@ -27,7 +22,7 @@ def number_data(self, request): def test_parse_number(self, number_data): driver = Driver() buffer, expected_value, expected_extcode, expected_cstart = number_data - result = driver._internal_parse_number(buffer) + result = driver._internal_parse_number(buffer.encode()) assert expected_value == result.value assert expected_extcode == result.extcode @@ -66,27 +61,24 @@ def test_parse_value(self, value_data): driver = Driver() buffer, expected_value, expected_len, expected_cellsize = value_data - result = driver._internal_parse_value(buffer) + result = driver._internal_parse_value(buffer.encode()) assert expected_value == result.value assert expected_len == result.len assert expected_cellsize == result.cellsize - # TODO def test_parse_array(self): driver = Driver() - buffer = "=5 +11 Hello World:123456 ,3.1415 _ $10 0123456789" + buffer = b"=5 +11 Hello World:123456 ,3.1415 _ $10 0123456789" expected_list = ["Hello World", "123456", "3.1415", None, "0123456789"] result = driver._internal_parse_array(buffer) assert expected_list == result - # TODO: test compression - def test_parse_rowset_signature(self): driver = Driver() - buffer = "*35 0:1 1 2 +2 42+7 'hello':42 +5 hello" + buffer = b"*35 0:1 1 2 +2 42+7 'hello':42 +5 hello" result = driver._internal_parse_rowset_signature(buffer) diff --git a/src/tests/unit/test_resultset.py b/src/tests/unit/test_resultset.py index db7e6b7..b136ee9 100644 --- a/src/tests/unit/test_resultset.py +++ b/src/tests/unit/test_resultset.py @@ -15,9 +15,10 @@ def test_init_data(self): def test_init_data_with_array(self): result = SQCloudResult() result.init_data([42, 43, 44]) + assert 1 == result.nrows assert 1 == result.ncols - assert [42, 43, 44] == result.data + assert [[42, 43, 44]] == result.data assert True is result.is_result def test_init_as_dataset(self): @@ -74,18 +75,23 @@ def test_get_value_with_rowset(self): assert None == result_set.get_value(2, 2) def test_get_value_array(self): - result = SQCloudResult(result=[1, 2, 3, 4, 5, 6]) + result = SQCloudResult(result=[1, 2, 3]) result_set = SqliteCloudResultSet(result) - assert 1 == result_set.get_value(0, 0) - assert 5 == result_set.get_value(1, 2) - assert 4 == result_set.get_value(2, 1) - assert None == result_set.get_value(3, 3) + + assert [1,2,3] == result_set.get_value(0, 0) def test_get_colname(self): result = SQCloudResult() result.ncols = 2 result.colname = ["name", "age"] result_set = SqliteCloudResultSet(result) + assert "name" == result_set.get_name(0) assert "age" == result_set.get_name(1) assert None == result_set.get_name(2) + + def test_get_result_with_single_value(self): + result = SQCloudResult(result=42) + result_set = SqliteCloudResultSet(result) + + assert 42 == result_set.get_result() From edd5515adf23669c7f9775d2971cd775af835a05 Mon Sep 17 00:00:00 2001 From: Daniele Briggi Date: Thu, 16 May 2024 09:39:22 +0000 Subject: [PATCH 6/7] Implement decompression lz4 --- src/sqlitecloud/driver.py | 68 ++++++++++++---------------- src/sqlitecloud/resultset.py | 2 +- src/tests/integration/test_client.py | 36 +++++++++++---- src/tests/unit/test_resultset.py | 2 +- 4 files changed, 59 insertions(+), 49 deletions(-) diff --git a/src/sqlitecloud/driver.py b/src/sqlitecloud/driver.py index b373413..8e9a612 100644 --- a/src/sqlitecloud/driver.py +++ b/src/sqlitecloud/driver.py @@ -279,12 +279,16 @@ def _internal_parse_buffer( # check for compressed result if cmd == SQCLOUD_CMD.COMPRESSED.value: - buffer = self._internal_uncompress_data(buffer, blen) + buffer = self._internal_uncompress_data(buffer) if buffer is None: raise SQCloudException( f"An error occurred while decompressing the input buffer of len {blen}." ) + # buffer after decompression + blen = len(buffer) + cmd = chr(buffer[0]) + # first character contains command type if cmd in [ SQCLOUD_CMD.ZEROSTRING.value, @@ -338,7 +342,10 @@ def _internal_parse_buffer( elif cmd in [SQCLOUD_CMD.ROWSET.value, SQCLOUD_CMD.ROWSET_CHUNK.value]: # CMD_ROWSET: *LEN 0:VERSION ROWS COLS DATA + # - When decompressed, LEN for ROWSET is *0 + # # CMD_ROWSET_CHUNK: /LEN IDX:VERSION ROWS COLS DATA + # rowset_signature = self._internal_parse_rowset_signature(buffer) if rowset_signature.start < 0: raise SQCloudException("Cannot parse rowset signature") @@ -361,7 +368,7 @@ def _internal_parse_buffer( # continue parsing next chunk in the buffer sign_len = rowset_signature.len buffer = buffer[sign_len + len(f"/{sign_len} ") :] - if buffer: + if cmd == SQCLOUD_CMD.ROWSET_CHUNK.value and buffer: return self._internal_parse_buffer(connection, buffer, len(buffer)) return rowset @@ -387,7 +394,7 @@ def _internal_parse_buffer( # TODO: exception here? return SQCloudResult(None) - def _internal_uncompress_data(self, buffer: bytes, blen: int) -> Optional[bytes]: + def _internal_uncompress_data(self, buffer: bytes) -> Optional[bytes]: """ %LEN COMPRESSED UNCOMPRESSED BUFFER @@ -398,51 +405,34 @@ def _internal_uncompress_data(self, buffer: bytes, blen: int) -> Optional[bytes] Returns: str: The uncompressed data. """ - tlen = 0 # total length - clen = 0 # compressed length - ulen = 0 # uncompressed length - hlen = 0 # raw header length - seek1 = 0 + space_index = buffer.index(b" ") + buffer = buffer[space_index + 1 :] - start = 1 - counter = 0 - for i in range(blen): - if chr(buffer[i]) != " ": - continue - counter += 1 + # extract compressed size + space_index = buffer.index(b" ") + compressed_size = int(buffer[:space_index].decode("utf-8")) + buffer = buffer[space_index + 1 :] - data = buffer[start:i] - start = i + 1 - - if counter == 1: - tlen = int(data) - seek1 = start - elif counter == 2: - clen = int(data) - elif counter == 3: - ulen = int(data) - break - - # sanity check header values - if tlen == 0 or clen == 0 or ulen == 0 or start == 1 or seek1 == 0: - return None + # extract decompressed size + space_index = buffer.index(b" ") + uncompressed_size = int(buffer[:space_index].decode("utf-8")) + buffer = buffer[space_index + 1 :] - # copy raw header - hlen = start - seek1 - header = buffer[start : start + hlen] + # extract data header + header = buffer[:-compressed_size] - # compute index of the first compressed byte - start += hlen + # extract compressed data + compressed_buffer = buffer[-compressed_size:] - # perform real decompression - # clone = header + lz4.block.decompress(buffer[start:]) - clone = lz4decode(buffer, start, header) + decompressed_buffer = header + lz4.block.decompress( + compressed_buffer, uncompressed_size + ) # sanity check result - if len(clone) != ulen + hlen: + if len(decompressed_buffer) != uncompressed_size + len(header): return None - return clone + return decompressed_buffer def _internal_parse_array(self, buffer: bytes) -> list: start = 0 diff --git a/src/sqlitecloud/resultset.py b/src/sqlitecloud/resultset.py index b482b2e..3d35d94 100644 --- a/src/sqlitecloud/resultset.py +++ b/src/sqlitecloud/resultset.py @@ -75,4 +75,4 @@ def get_name(self, col: int) -> Optional[str]: return self._result.colname[col] def get_result(self) -> Optional[any]: - return self.get_value(0, 0) \ No newline at end of file + return self.get_value(0, 0) diff --git a/src/tests/integration/test_client.py b/src/tests/integration/test_client.py index 5a72c63..090e993 100644 --- a/src/tests/integration/test_client.py +++ b/src/tests/integration/test_client.py @@ -584,8 +584,7 @@ def test_download_database(self, sqlitecloud_connection): assert cursor.description[0][0] == "AlbumId" assert cursor.description[1][0] == "Title" - # TODO - def test_compression(self): + def test_compression_single_column(self): account = SqliteCloudAccount() account.hostname = os.getenv("SQLITE_HOST") account.apikey = os.getenv("SQLITE_API_KEY") @@ -595,12 +594,34 @@ def test_compression(self): client.config.compression = True # min compression size for rowset set by default to 20400 bytes - rowset = client.exec_query("SELECT '" + "a" * (1024 * 20) + "' AS DDD") + blob_size = 20 * 1024 + # rowset = client.exec_query("SELECT * from albums inner join albums a2 on albums.AlbumId = a2.AlbumId") + rowset = client.exec_query( + f"SELECT hex(randomblob({blob_size})) AS 'someColumnName'" + ) assert rowset.nrows == 1 assert rowset.ncols == 1 - assert rowset.get_value(0, 0).startswith("aaaaa") - assert len(rowset.get_value(0, 0)) == 100 + assert rowset.get_name(0) == "someColumnName" + assert len(rowset.get_value(0, 0)) == blob_size * 2 + + def test_compression_multiple_columns(self): + account = SqliteCloudAccount() + account.hostname = os.getenv("SQLITE_HOST") + account.apikey = os.getenv("SQLITE_API_KEY") + account.database = os.getenv("SQLITE_DB") + + client = SqliteCloudClient(cloud_account=account) + client.config.compression = True + + # min compression size for rowset set by default to 20400 bytes + rowset = client.exec_query( + "SELECT * from albums inner join albums a2 on albums.AlbumId = a2.AlbumId" + ) + + assert rowset.nrows > 0 + assert rowset.ncols > 0 + assert rowset.get_name(0) == "AlbumId" # def test_send_blob(self, sqlitecloud_connection): # connection, client = sqlitecloud_connection @@ -615,14 +636,14 @@ def test_compression(self): # blob = b"" # result = client.sendblob(blob, connection) # assert result is not None - # # Add additional assertions as needed + # # def test_send_large_blob(self, sqlitecloud_connection): # connection, client = sqlitecloud_connection # blob = b"A" * 1024 * 1024 # 1MB blob # result = client.sendblob(blob, connection) # assert result is not None - # # Add additional assertions as needed + # # def test_send_blob_with_connection_closed(self, sqlitecloud_connection): # connection, client = sqlitecloud_connection @@ -630,4 +651,3 @@ def test_compression(self): # blob = b"Hello, this is a test blob" # with pytest.raises(Exception): # client.sendblob(blob, connection) - # Add additional assertions as needed diff --git a/src/tests/unit/test_resultset.py b/src/tests/unit/test_resultset.py index b136ee9..18b36f6 100644 --- a/src/tests/unit/test_resultset.py +++ b/src/tests/unit/test_resultset.py @@ -78,7 +78,7 @@ def test_get_value_array(self): result = SQCloudResult(result=[1, 2, 3]) result_set = SqliteCloudResultSet(result) - assert [1,2,3] == result_set.get_value(0, 0) + assert [1, 2, 3] == result_set.get_value(0, 0) def test_get_colname(self): result = SQCloudResult() From 36c8ce60ec3fbaf29ba009167aa8528b959e4d6e Mon Sep 17 00:00:00 2001 From: Daniele Briggi Date: Thu, 16 May 2024 09:55:45 +0000 Subject: [PATCH 7/7] Cleanup, update sample file --- .devcontainer/Dockerfile | 2 +- .pylintrc | 5 - Makefile | 23 ---- requirements-dev.txt | 2 +- requirements.txt | 2 +- samples.ipynb | 94 ++++---------- src/sqlitecloud/client.py | 18 +-- src/sqlitecloud/driver.py | 10 +- src/sqlitecloud/pubsub.py | 23 ---- src/sqlitecloud/types.py | 41 +++--- src/sqlitecloud/upload.py | 36 ----- src/sqlitecloud/vm.py | 188 --------------------------- src/tests/conftest.py | 2 +- src/tests/integration/test_client.py | 49 ++----- src/tests/test.db | Bin 16384 -> 0 bytes src/tests/unit/test_client.py | 8 +- 16 files changed, 73 insertions(+), 430 deletions(-) delete mode 100644 .pylintrc delete mode 100644 Makefile delete mode 100644 src/sqlitecloud/pubsub.py delete mode 100644 src/sqlitecloud/upload.py delete mode 100644 src/sqlitecloud/vm.py delete mode 100644 src/tests/test.db diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 5019730..fd10308 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -3,4 +3,4 @@ FROM mcr.microsoft.com/devcontainers/python:3.6-bullseye ADD https://dl.yarnpkg.com/debian/pubkey.gpg /etc/apt/trusted.gpg.d/yarn.asc RUN chmod +r /etc/apt/trusted.gpg.d/*.asc && \ - echo "deb http://dl.yarnpkg.com/debian/ stable main" > /etc/apt/sources.list.d/yarn.list \ No newline at end of file + echo "deb http://dl.yarnpkg.com/debian/ stable main" > /etc/apt/sources.list.d/yarn.list diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index d05bdd6..0000000 --- a/.pylintrc +++ /dev/null @@ -1,5 +0,0 @@ -[pylint] -disable = C0111, C0301, W0511, R0903, C0103, W0703, W0102, W0621, R1732, W1514 - -[MASTER] -extension-pkg-whitelist=cv2 diff --git a/Makefile b/Makefile deleted file mode 100644 index a23be5c..0000000 --- a/Makefile +++ /dev/null @@ -1,23 +0,0 @@ -VENV = venv -PYTHON = $(VENV)/bin/python3 -PIP = $(VENV)/bin/pip - -add_src_to_pypath: - export PYTHONPATH=$$PYTHONPATH:$(pwd)/src - -test: - export PYTHONPATH=$$PYTHONPATH:$(pwd)/src - python3 -m pytest -s ./src - -lint: add_src_to_pypath - pylint src - -freeze: - $(PIP) freeze > requirements.txt - -dependencies: - $(PIP) install -r requirements.txt - -clean: - rm -rf $(VENV) - find . -type f -name '*.pyc' -delete \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index f8b65be..c3c2518 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,4 +6,4 @@ pytest==7.0.1 pytest-mock==3.6.1 black==22.8.0 python-dotenv==0.20.0 -lz4==3.1.10 \ No newline at end of file +lz4==3.1.10 diff --git a/requirements.txt b/requirements.txt index d74333f..ae5cd32 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -lz4==3.1.10 \ No newline at end of file +lz4==3.1.10 diff --git a/samples.ipynb b/samples.ipynb index d8e4d09..e7809fb 100644 --- a/samples.ipynb +++ b/samples.ipynb @@ -13,20 +13,17 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 8, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loading SQLITECLOUD lib from: /Users/sam/projects/codermine/sqlitecloud-sdk/C/libsqcloud.so\n" - ] - } - ], + "outputs": [], "source": [ - "from sqlitecloud.conn_info import user,password,host,db_name,port\n", - "from sqlitecloud.client import SqliteCloudClient, SqliteCloudAccount" + "import sys\n", + "\n", + "sys.path.append('/workspaces/python/src')\n", + "\n", + "from sqlitecloud.conn_info import user, password, host, db_name, port\n", + "from sqlitecloud.client import SqliteCloudClient\n", + "from sqlitecloud.types import SqliteCloudAccount" ] }, { @@ -40,27 +37,27 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ - "account = SqliteCloudAccount(user, password, host, db_name, port)\n", + "account = SqliteCloudAccount(user, password, host, db_name, int(port))\n", "client = SqliteCloudClient(cloud_account=account)\n", "conn = client.open_connection()" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'people'" + "'chinook.sqlite'" ] }, - "execution_count": 3, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -78,17 +75,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 15, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "select * from employees;\n" - ] - } - ], + "outputs": [], "source": [ "query = \"select * from employees;\"\n", "result = client.exec_query(query, conn)" @@ -103,21 +92,21 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'emp_id': 1, 'emp_name': b'Bobby Tables'}\n", - "{'emp_id': 1, 'emp_name': b'Bobby Tables'}\n", - "{'emp_id': 1, 'emp_name': b'Bobby Tables'}\n", - "{'emp_id': 1, 'emp_name': b'Bobby Tables'}\n", - "{'emp_id': 1, 'emp_name': b'Bobby Tables'}\n", - "{'emp_id': 1, 'emp_name': b'Bobby Tables'}\n", - "{'emp_id': 1, 'emp_name': b'Bobby Tables'}\n", - "{'emp_id': 1, 'emp_name': b'Bobby Tables'}\n" + "{'EmployeeId': '1', 'LastName': 'Adams', 'FirstName': 'Andrew', 'Title': 'General Manager', 'ReportsTo': None, 'BirthDate': '1962-02-18 00:00:00', 'HireDate': '2002-08-14 00:00:00', 'Address': '11120 Jasper Ave NW', 'City': 'Edmonton', 'State': 'AB', 'Country': 'Canada', 'PostalCode': 'T5K 2N1', 'Phone': '+1 (780) 428-9482', 'Fax': '+1 (780) 428-3457', 'Email': 'andrew@chinookcorp.com'}\n", + "{'EmployeeId': '2', 'LastName': 'Edwards', 'FirstName': 'Nancy', 'Title': 'Sales Manager', 'ReportsTo': '1', 'BirthDate': '1958-12-08 00:00:00', 'HireDate': '2002-05-01 00:00:00', 'Address': '825 8 Ave SW', 'City': 'Calgary', 'State': 'AB', 'Country': 'Canada', 'PostalCode': 'T2P 2T3', 'Phone': '+1 (403) 262-3443', 'Fax': '+1 (403) 262-3322', 'Email': 'nancy@chinookcorp.com'}\n", + "{'EmployeeId': '3', 'LastName': 'Peacock', 'FirstName': 'Jane', 'Title': 'Sales Support Agent', 'ReportsTo': '2', 'BirthDate': '1973-08-29 00:00:00', 'HireDate': '2002-04-01 00:00:00', 'Address': '1111 6 Ave SW', 'City': 'Calgary', 'State': 'AB', 'Country': 'Canada', 'PostalCode': 'T2P 5M5', 'Phone': '+1 (403) 262-3443', 'Fax': '+1 (403) 262-6712', 'Email': 'jane@chinookcorp.com'}\n", + "{'EmployeeId': '4', 'LastName': 'Park', 'FirstName': 'Margaret', 'Title': 'Sales Support Agent', 'ReportsTo': '2', 'BirthDate': '1947-09-19 00:00:00', 'HireDate': '2003-05-03 00:00:00', 'Address': '683 10 Street SW', 'City': 'Calgary', 'State': 'AB', 'Country': 'Canada', 'PostalCode': 'T2P 5G3', 'Phone': '+1 (403) 263-4423', 'Fax': '+1 (403) 263-4289', 'Email': 'margaret@chinookcorp.com'}\n", + "{'EmployeeId': '5', 'LastName': 'Johnson', 'FirstName': 'Steve', 'Title': 'Sales Support Agent', 'ReportsTo': '2', 'BirthDate': '1965-03-03 00:00:00', 'HireDate': '2003-10-17 00:00:00', 'Address': '7727B 41 Ave', 'City': 'Calgary', 'State': 'AB', 'Country': 'Canada', 'PostalCode': 'T3B 1Y7', 'Phone': '1 (780) 836-9987', 'Fax': '1 (780) 836-9543', 'Email': 'steve@chinookcorp.com'}\n", + "{'EmployeeId': '6', 'LastName': 'Mitchell', 'FirstName': 'Michael', 'Title': 'IT Manager', 'ReportsTo': '1', 'BirthDate': '1973-07-01 00:00:00', 'HireDate': '2003-10-17 00:00:00', 'Address': '5827 Bowness Road NW', 'City': 'Calgary', 'State': 'AB', 'Country': 'Canada', 'PostalCode': 'T3B 0C5', 'Phone': '+1 (403) 246-9887', 'Fax': '+1 (403) 246-9899', 'Email': 'michael@chinookcorp.com'}\n", + "{'EmployeeId': '7', 'LastName': 'King', 'FirstName': 'Robert', 'Title': 'IT Staff', 'ReportsTo': '6', 'BirthDate': '1970-05-29 00:00:00', 'HireDate': '2004-01-02 00:00:00', 'Address': '590 Columbia Boulevard West', 'City': 'Lethbridge', 'State': 'AB', 'Country': 'Canada', 'PostalCode': 'T1K 5N8', 'Phone': '+1 (403) 456-9986', 'Fax': '+1 (403) 456-8485', 'Email': 'robert@chinookcorp.com'}\n", + "{'EmployeeId': '8', 'LastName': 'Callahan', 'FirstName': 'Laura', 'Title': 'IT Staff', 'ReportsTo': '6', 'BirthDate': '1968-01-09 00:00:00', 'HireDate': '2004-03-04 00:00:00', 'Address': '923 7 ST NW', 'City': 'Lethbridge', 'State': 'AB', 'Country': 'Canada', 'PostalCode': 'T1H 1Y8', 'Phone': '+1 (403) 467-3351', 'Fax': '+1 (403) 467-8772', 'Email': 'laura@chinookcorp.com'}\n" ] } ], @@ -135,41 +124,12 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "client.disconnect(conn)\n" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can bind values to parametric queries: you can pass parameters as positional values in an array" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ev type: \n", - "ev type: \n" - ] - } - ], - "source": [ - "new_connection = client.open_connection()\n", - "result = client.exec_statement(\"select * from employees where emp_id = ? and emp_name = ? \", [1,'Bobby Tables'],conn=new_connection)\n", - "for r in result:\n", - " print(r)\n", - "client.disconnect(conn)" - ] } ], "metadata": { @@ -188,7 +148,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.6.15" } }, "nbformat": 4, diff --git a/src/sqlitecloud/client.py b/src/sqlitecloud/client.py index 032d001..bad9894 100644 --- a/src/sqlitecloud/client.py +++ b/src/sqlitecloud/client.py @@ -1,7 +1,7 @@ """ Module to interact with remote SqliteCloud database """ -from typing import Any, List, Optional +from typing import Optional from urllib import parse from sqlitecloud.driver import Driver @@ -37,8 +37,6 @@ def __init__( self.config = SQCloudConfig() - # for pb in pub_subs: - # self._pub_sub_cbs.append(("channel1", SQCloudPubSubCB(pb))) if connection_str: self.config = self._parse_connection_string(connection_str) elif cloud_account: @@ -53,24 +51,16 @@ def open_connection(self) -> SQCloudConnect: SQCloudConnect: An instance of the SQCloudConnect class representing the connection to the SQCloud server. Raises: - Exception: If an error occurs while opening the connection. + SQCloudException: If an error occurs while opening the connection. """ connection = self.driver.connect( self.config.account.hostname, self.config.account.port, self.config ) - # SQCloudExec(connection, f"USE DATABASE {self.dbname};") - - # for cb in self._pub_sub_cbs: - # subscribe_pub_sub(connection, cb[0], cb[1]) - return connection def disconnect(self, conn: SQCloudConnect) -> None: - """Closes the connection to the database. - - This method is used to close the connection to the database. - """ + """Close the connection to the database.""" self.driver.disconnect(conn) def exec_query( @@ -146,7 +136,7 @@ def _parse_connection_string(self, connection_string) -> SQCloudConfig: path = params.path database = path.strip("/") if database: - config.account.database = database + config.account.dbname = database config.account.hostname = params.hostname config.account.port = ( diff --git a/src/sqlitecloud/driver.py b/src/sqlitecloud/driver.py index 8e9a612..fd49e82 100644 --- a/src/sqlitecloud/driver.py +++ b/src/sqlitecloud/driver.py @@ -1,7 +1,6 @@ import ssl from typing import Optional, Union import lz4.block -from sqlitecloud.lz4_custom import lz4decode from sqlitecloud.resultset import SQCloudResult from sqlitecloud.types import ( SQCLOUD_CMD, @@ -48,6 +47,9 @@ def connect( context.load_cert_chain( certfile=config.certificate, keyfile=config.certificate_key ) + if config.no_verify_certificate: + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE sock = context.wrap_socket(sock, server_hostname=hostname) @@ -103,10 +105,10 @@ def _internal_config_apply( command = "HASH" if config.account.password_hashed else "PASSWORD" buffer += f"AUTH USER {config.account.username} {command} {config.account.password};" - if config.account.database: + if config.account.dbname: if config.create and not config.memory: - buffer += f"CREATE DATABASE {config.account.database} IF NOT EXISTS;" - buffer += f"USE DATABASE {config.account.database};" + buffer += f"CREATE DATABASE {config.account.dbname} IF NOT EXISTS;" + buffer += f"USE DATABASE {config.account.dbname};" if config.compression: buffer += "SET CLIENT KEY COMPRESSION TO 1;" diff --git a/src/sqlitecloud/pubsub.py b/src/sqlitecloud/pubsub.py deleted file mode 100644 index 1216e10..0000000 --- a/src/sqlitecloud/pubsub.py +++ /dev/null @@ -1,23 +0,0 @@ -import ctypes -from typing import Any, Callable -from sqlitecloud.driver import ( - SQCloudConnect, - SQCloudExec, - SQCloudSetPubSubCallback, - SQCloudPubSubCB, -) -from sqlitecloud.wrapper_types import SQCloudResult - -SQCloudPubSubCallback = Callable[[SQCloudConnect, SQCloudResult, Any], None] - - -def subscribe_pub_sub( - connection: SQCloudConnect, channel: str, pub_sub_callback: SQCloudPubSubCB -): - SQCloudSetPubSubCallback(connection, pub_sub_callback, None) - SQCloudExec( - connection, - ctypes.c_char_p((f"CREATE CHANNEL {channel} IF NOT EXISTS;".encode("utf-8"))), - ) - SQCloudExec(connection, ctypes.c_char_p((f"LISTEN {channel};".encode("utf-8")))) - print(f"Listening{channel}") diff --git a/src/sqlitecloud/types.py b/src/sqlitecloud/types.py index 6f183a1..118ad1e 100644 --- a/src/sqlitecloud/types.py +++ b/src/sqlitecloud/types.py @@ -56,42 +56,39 @@ def __init__(self) -> None: class SqliteCloudAccount: - def __init__(self): + def __init__( + self, + username: Optional[str] = "", + password: Optional[str] = "", + hostname: Optional[str] = "", + dbname: Optional[str] = "", + port: Optional[int] = 8860, + apikey: Optional[str] = "", + ) -> None: # User name is required unless connectionstring is provided - self.username = "" + self.username = username # Password is required unless connection string is provided - self.password = "" + self.password = password # Password is hashed self.password_hashed = False # API key instead of username and password - self.apikey = "" + self.apikey = apikey # Name of database to open - self.database = "" + self.dbname = dbname # Like mynode.sqlitecloud.io - self.hostname = "" - self.port = 8860 + self.hostname = hostname + self.port = port class SQCloudConnect: - def __init__(self): - self.hostname: str = "" - self.port: int = "" + """ + Represents the connection information. + """ + def __init__(self): self.socket: any = None - self.config: SQCloudConfig - self.isblob: bool = False - self.config_to_free: bool # todo: is this needed? - - # pub/sub - # todo: check uuid type - self.uuid: str - - # todo: - # pubsubfd: int - - # callback: SQCloudPubSubCB class SQCloudConfig: diff --git a/src/sqlitecloud/upload.py b/src/sqlitecloud/upload.py deleted file mode 100644 index a687306..0000000 --- a/src/sqlitecloud/upload.py +++ /dev/null @@ -1,36 +0,0 @@ -import ctypes -import os -from typing import Optional -from sqlitecloud.driver import CallbackFunc, SQCloudConnect, SQCloudUploadDatabase - - -def xCallback(xdata, buffer, blen, ntot, nprogress): # pylint: disable=W0613 - nread = os.read(xdata, blen.contents.value) - if nread == -1: - return -1 - if nread == 0: - print("UPLOAD COMPLETE\n\n") - else: - print(f"{(nprogress + len(nread)) / ntot * 100:.2f}%") - - blen.contents.value = len(nread) - return 0 - - -def upload_db( - connection: SQCloudConnect, dbname: str, key: Optional[str], filename: str -) -> None: - fd_value = os.open(filename, os.O_RDONLY) - fd_void_ptr = ctypes.c_void_p(fd_value) - dbsize = os.path.getsize(filename) - print("dbsize", dbsize) - key_val = key.encode() if key else None - success = SQCloudUploadDatabase( - connection, - dbname.encode(), - key_val, - fd_void_ptr, - dbsize, - CallbackFunc(xCallback), - ) - print("upload_db", success) diff --git a/src/sqlitecloud/vm.py b/src/sqlitecloud/vm.py deleted file mode 100644 index 49e0efa..0000000 --- a/src/sqlitecloud/vm.py +++ /dev/null @@ -1,188 +0,0 @@ -import ctypes - -from sqlitecloud.driver import ( - SQCloudConnect, - SQCloudVM, - SQCloudVMCompile, - SQCloudVMStep, - SQCloudResult, - SQCloudVMResult, - SQCloudVMClose, - SQCloudVMErrorMsg, - SQCloudVMErrorCode, - SQCloudVMIsReadOnly, - SQCloudVMIsExplain, - SQCloudVMIsFinalized, - SQCloudVMBindParameterCount, - SQCloudVMBindParameterIndex, - SQCloudVMBindParameterName, - SQCloudVMColumnCount, - SQCloudVMBindDouble, - SQCloudVMBindInt, - SQCloudVMBindInt64, - SQCloudVMBindNull, - SQCloudVMBindText, - SQCloudVMBindBlob, - SQCloudVMColumnBlob, - SQCloudVMColumnText, - SQCloudVMColumnDouble, - SQCloudVMColumnInt32, - SQCloudVMColumnInt64, - SQCloudVMColumnLen, - SQCloudVMColumnType, - SQCloudVMLastRowID, - SQCloudVMChanges, - SQCloudVMTotalChanges -) -from sqlitecloud.wrapper_types import SQCLOUD_VALUE_TYPE - - -def compile_vm(conn: SQCloudConnect, query: str) -> SQCloudVM: - vm = SQCloudVMCompile(conn, ctypes.c_char_p(query.encode("utf-8")), -1, None) - - return vm - - -def step_vm(vm: SQCloudVMCompile) -> int: - return SQCloudVMStep(vm) - - -def result_vm(vm: SQCloudVMCompile) -> SQCloudResult: - return SQCloudVMResult(vm) - - -def close_vm(vm: SQCloudVMCompile) -> bool: - return SQCloudVMClose(vm) - - -def error_msg_vm(vm: SQCloudVMCompile) -> str | None: - result = SQCloudVMErrorMsg(vm) - - if result is None: - return None - - return ctypes.string_at(result).decode('utf-8') - - -def error_code_vm(vm: SQCloudVMCompile) -> int: - return SQCloudVMErrorCode(vm) - - -def is_read_only_vm(vm: SQCloudVMCompile) -> bool: - return SQCloudVMIsReadOnly(vm) - - -def is_explain_vm(vm: SQCloudVMCompile) -> int: - return SQCloudVMIsExplain(vm) - - -def is_finalized_vm(vm: SQCloudVMCompile) -> bool: - return SQCloudVMIsFinalized(vm) - - -def bind_parameter_count_vm(vm: SQCloudVMCompile) -> int: - return SQCloudVMBindParameterCount(vm) - - -def bind_parameter_index_vm(vm: SQCloudVMCompile, parameter_name: str) -> int: - return SQCloudVMBindParameterIndex(vm, ctypes.c_char_p(parameter_name.encode("utf-8"))) - - -def bind_parameter_name_vm(vm: SQCloudVMCompile, index: int) -> str | None: - result = SQCloudVMBindParameterName(vm, index) - - if result is None: - return None - - return ctypes.string_at(result).decode('utf-8') - - -def column_count_vm(vm: SQCloudVMCompile) -> int: - return SQCloudVMColumnCount(vm) - - -def bind_double_vm(vm: SQCloudVMCompile, index: int, value: float) -> bool: - return SQCloudVMBindDouble(vm, index, ctypes.c_double(value)) - - -def bind_int_vm(vm: SQCloudVMCompile, index: int, value: int) -> bool: - return SQCloudVMBindInt(vm, index, value) - - -def bind_int64_vm(vm: SQCloudVMCompile, index: int, value: int) -> bool: - return SQCloudVMBindInt64(vm, index, ctypes.c_int64(value)) - - -def bind_null_vm(vm: SQCloudVMCompile, index: int) -> bool: - return SQCloudVMBindNull(vm, index) - - -def bind_text_vm(vm: SQCloudVMCompile, index: int, value: str) -> bool: - return SQCloudVMBindText( - vm, - index, - ctypes.c_char_p(value.encode('utf-8')), - len(value.encode('utf-8')) - ) - - -def bind_blob_vm(vm: SQCloudVMCompile, index: int, value) -> bool: - return SQCloudVMBindBlob( - vm, - index, - value, - len(value.encode('utf-8')) - ) - - -def column_type_vm(vm: SQCloudVMCompile, index: int) -> SQCLOUD_VALUE_TYPE: - return SQCloudVMColumnType(vm, index) - - -def column_blob_vm(vm: SQCloudVMCompile, index: int) -> str: - len_holder = ctypes.c_uint32() - - blob_pointer = SQCloudVMColumnBlob(vm, index, ctypes.byref(len_holder)) - - blob_size = len_holder.value - blob_data = ctypes.cast(blob_pointer, ctypes.POINTER(ctypes.c_ubyte * blob_size)).contents - - return bytes(blob_data[:blob_size]).decode('utf-8') - - -def column_text_vm(vm: SQCloudVMCompile, index: int) -> str: - len_holder = ctypes.c_uint32() - - value = SQCloudVMColumnText(vm, index, ctypes.byref(len_holder)) - - text_value = value.decode('utf-8') if value else None - - return text_value[:len_holder.value] - - -def column_double_vm(vm: SQCloudVMCompile, index: int) -> float: - return SQCloudVMColumnDouble(vm, index) - - -def column_int_32_vm(vm: SQCloudVMCompile, index: int) -> int: - return SQCloudVMColumnInt32(vm, index) - - -def column_int_64_vm(vm: SQCloudVMCompile, index: int) -> int: - return SQCloudVMColumnInt64(vm, index) - - -def column_len_vm(vm: SQCloudVMCompile, index: int) -> int: - return SQCloudVMColumnLen(vm, index) - - -def last_row_id_vm(vm: SQCloudVMCompile) -> int: - return SQCloudVMLastRowID(vm) - - -def changes_vm(vm: SQCloudVMCompile) -> int: - return SQCloudVMChanges(vm) - - -def total_changes_vm(vm: SQCloudVMCompile) -> int: - return SQCloudVMTotalChanges(vm) diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 4adbbea..434c04b 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -3,4 +3,4 @@ @pytest.fixture(autouse=True) def load_env_vars(): - load_dotenv(".env") \ No newline at end of file + load_dotenv(".env") diff --git a/src/tests/integration/test_client.py b/src/tests/integration/test_client.py index 090e993..b4f17bb 100644 --- a/src/tests/integration/test_client.py +++ b/src/tests/integration/test_client.py @@ -27,7 +27,7 @@ def sqlitecloud_connection(self): account = SqliteCloudAccount() account.username = os.getenv("SQLITE_USER") account.password = os.getenv("SQLITE_PASSWORD") - account.database = os.getenv("SQLITE_DB") + account.dbname = os.getenv("SQLITE_DB") account.hostname = os.getenv("SQLITE_HOST") account.port = 8860 @@ -44,7 +44,7 @@ def test_connection_with_credentials(self): account = SqliteCloudAccount() account.username = os.getenv("SQLITE_USER") account.password = os.getenv("SQLITE_PASSWORD") - account.database = os.getenv("SQLITE_DB") + account.dbname = os.getenv("SQLITE_DB") account.hostname = os.getenv("SQLITE_HOST") account.port = 8860 @@ -68,7 +68,7 @@ def test_connection_with_apikey(self): def test_connection_without_credentials_and_apikey(self): account = SqliteCloudAccount() - account.database = os.getenv("SQLITE_DB") + account.dbname = os.getenv("SQLITE_DB") account.hostname = os.getenv("SQLITE_HOST") account.port = 8860 @@ -287,7 +287,7 @@ def test_rowset(self, sqlitecloud_connection): def test_max_rows_option(self): account = SqliteCloudAccount() account.hostname = os.getenv("SQLITE_HOST") - account.database = os.getenv("SQLITE_DB") + account.dbname = os.getenv("SQLITE_DB") account.apikey = os.getenv("SQLITE_API_KEY") client = SqliteCloudClient(cloud_account=account) @@ -302,7 +302,7 @@ def test_max_rows_option(self): def test_max_rowset_option_to_fail_when_rowset_is_bigger(self): account = SqliteCloudAccount() account.hostname = os.getenv("SQLITE_HOST") - account.database = os.getenv("SQLITE_DB") + account.dbname = os.getenv("SQLITE_DB") account.apikey = os.getenv("SQLITE_API_KEY") client = SqliteCloudClient(cloud_account=account) @@ -317,7 +317,7 @@ def test_max_rowset_option_to_fail_when_rowset_is_bigger(self): def test_max_rowset_option_to_succeed_when_rowset_is_lighter(self): account = SqliteCloudAccount() account.hostname = os.getenv("SQLITE_HOST") - account.database = os.getenv("SQLITE_DB") + account.dbname = os.getenv("SQLITE_DB") account.apikey = os.getenv("SQLITE_API_KEY") client = SqliteCloudClient(cloud_account=account) @@ -375,7 +375,7 @@ def test_serialized_operations(self, sqlitecloud_connection): def test_query_timeout(self): account = SqliteCloudAccount() account.hostname = os.getenv("SQLITE_HOST") - account.database = os.getenv("SQLITE_DB") + account.dbname = os.getenv("SQLITE_DB") account.apikey = os.getenv("SQLITE_API_KEY") client = SqliteCloudClient(cloud_account=account) @@ -465,7 +465,7 @@ def test_select_long_formatted_string(self, sqlitecloud_connection): def test_select_database(self): account = SqliteCloudAccount() account.hostname = os.getenv("SQLITE_HOST") - account.database = "" + account.dbname = "" account.apikey = os.getenv("SQLITE_API_KEY") client = SqliteCloudClient(cloud_account=account) @@ -588,7 +588,7 @@ def test_compression_single_column(self): account = SqliteCloudAccount() account.hostname = os.getenv("SQLITE_HOST") account.apikey = os.getenv("SQLITE_API_KEY") - account.database = os.getenv("SQLITE_DB") + account.dbname = os.getenv("SQLITE_DB") client = SqliteCloudClient(cloud_account=account) client.config.compression = True @@ -609,7 +609,7 @@ def test_compression_multiple_columns(self): account = SqliteCloudAccount() account.hostname = os.getenv("SQLITE_HOST") account.apikey = os.getenv("SQLITE_API_KEY") - account.database = os.getenv("SQLITE_DB") + account.dbname = os.getenv("SQLITE_DB") client = SqliteCloudClient(cloud_account=account) client.config.compression = True @@ -622,32 +622,3 @@ def test_compression_multiple_columns(self): assert rowset.nrows > 0 assert rowset.ncols > 0 assert rowset.get_name(0) == "AlbumId" - - # def test_send_blob(self, sqlitecloud_connection): - # connection, client = sqlitecloud_connection - # blob = b"Hello, this is a test blob" - - # result = client.sendblob(blob, connection) - - # assert result.get_result() == True - - # def test_send_empty_blob(self, sqlitecloud_connection): - # connection, client = sqlitecloud_connection - # blob = b"" - # result = client.sendblob(blob, connection) - # assert result is not None - # - - # def test_send_large_blob(self, sqlitecloud_connection): - # connection, client = sqlitecloud_connection - # blob = b"A" * 1024 * 1024 # 1MB blob - # result = client.sendblob(blob, connection) - # assert result is not None - # - - # def test_send_blob_with_connection_closed(self, sqlitecloud_connection): - # connection, client = sqlitecloud_connection - # client.disconnect(connection) - # blob = b"Hello, this is a test blob" - # with pytest.raises(Exception): - # client.sendblob(blob, connection) diff --git a/src/tests/test.db b/src/tests/test.db deleted file mode 100644 index 53b44973979fc4904a9c53e70d5d952964e2fc88..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 16384 zcmeI#K~BOz6oBC=5H*cZcdVMsQWGLkF94E>Ay$-95?M7BN^DYEg_^k4rC0I*T*jpb zuoz|`NOVp9CNDGZy_u%-bv3=N6J?=l(`ble z-!5ApqF9-W%G`Rl-Zm|zjsOA(Ab