From 8e8196db0e07983de13dca6e737bceb5ca092e09 Mon Sep 17 00:00:00 2001 From: Gor Stepanyan Date: Wed, 20 Jul 2022 10:32:26 +0200 Subject: [PATCH] Add necessary implementation to pass all BasicsTests Added implementation of cass_value_get_bytes to retrieve bytes of a Blob value. Added paging state token setter for CassStatement. Added paging state token getter from query result. The database version is set to 3.0.8 if the provided version in tests has prefix `release:`. This will skip some of the tests which are also skipped in the cpp-driver tests and are not disabled. --- .github/workflows/build.yml | 2 +- scylla-rust-wrapper/src/prepared.rs | 2 +- scylla-rust-wrapper/src/query_result.rs | 81 +++++++++++++++++++++- scylla-rust-wrapper/src/session.rs | 19 +++++ src/testing_unimplemented.cpp | 29 -------- tests/src/integration/ccm/cass_version.hpp | 4 ++ 6 files changed, 105 insertions(+), 32 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 25b90fee..be43aedf 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -33,4 +33,4 @@ jobs: run: cmake -DCASS_BUILD_INTEGRATION_TESTS=ON . && make - name: Run integration tests on Scylla 5.0.0 - run: valgrind --error-exitcode=123 ./cassandra-integration-tests --version=release:5.0.0 --category=CASSANDRA --verbose=ccm --gtest_filter="ClusterTests.*:BasicsTests.*RowsInRowsOut" + run: valgrind --error-exitcode=123 ./cassandra-integration-tests --version=release:5.0.0 --category=CASSANDRA --verbose=ccm --gtest_filter="ClusterTests.*:BasicsTests.*" diff --git a/scylla-rust-wrapper/src/prepared.rs b/scylla-rust-wrapper/src/prepared.rs index 357c2d11..81dcf0d1 100644 --- a/scylla-rust-wrapper/src/prepared.rs +++ b/scylla-rust-wrapper/src/prepared.rs @@ -18,7 +18,7 @@ pub unsafe extern "C" fn cass_prepared_free(prepared_raw: *const CassPrepared) { pub unsafe extern "C" fn cass_prepared_bind( prepared_raw: *const CassPrepared, ) -> *mut CassStatement { - let prepared: Arc<_> = Arc::from_raw(prepared_raw); + let prepared: Arc<_> = clone_arced(prepared_raw); let bound_values_size = prepared.get_prepared_metadata().col_count; // cloning prepared statement's arc, because creating CassStatement should not invalidate diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index 65860c65..acfe8bca 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -2,12 +2,14 @@ use crate::argconv::*; use crate::cass_error::CassError; use crate::cass_types::{cass_data_type_type, CassDataType, CassValueType}; use crate::inet::CassInet; +use crate::statement::CassStatement; use crate::types::*; use crate::uuid::CassUuid; use scylla::frame::response::result::{ColumnSpec, CqlValue}; -use scylla::Bytes; +use scylla::{BufMut, Bytes, BytesMut}; use std::convert::TryInto; use std::os::raw::c_char; +use std::slice; use std::sync::Arc; pub struct CassResult { @@ -401,6 +403,32 @@ pub unsafe extern "C" fn cass_value_get_string( CassError::CASS_OK } +#[no_mangle] +pub unsafe extern "C" fn cass_value_get_bytes( + value: *const CassValue, + output: *mut *const cass_byte_t, + output_size: *mut size_t, +) -> CassError { + if value.is_null() { + return CassError::CASS_ERROR_LIB_NULL_VALUE; + } + + let value_from_raw: &CassValue = ptr_to_ref(value); + + // FIXME: This should be implemented for all CQL types + // Note: currently rust driver does not allow to get raw bytes of the CQL value. + match &value_from_raw.value { + Some(CqlValue::Blob(bytes)) => { + *output = bytes.as_ptr() as *const cass_byte_t; + *output_size = bytes.len() as u64; + } + Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, + None => return CassError::CASS_ERROR_LIB_NULL_VALUE, + } + + CassError::CASS_OK +} + #[no_mangle] pub unsafe extern "C" fn cass_value_is_null(value: *const CassValue) -> cass_bool_t { let val: &CassValue = ptr_to_ref(value); @@ -436,6 +464,57 @@ pub unsafe extern "C" fn cass_result_first_row(result_raw: *const CassResult) -> std::ptr::null() } +#[no_mangle] +pub unsafe extern "C" fn cass_result_paging_state_token( + result: *const CassResult, + paging_state: *mut *const c_char, + paging_state_size: *mut size_t, +) -> CassError { + if cass_result_has_more_pages(result) == cass_false { + return CassError::CASS_ERROR_LIB_NO_PAGING_STATE; + } + + let result_from_raw = ptr_to_ref(result); + + match &result_from_raw.metadata.paging_state { + Some(result_paging_state) => { + *paging_state_size = result_paging_state.len() as u64; + *paging_state = result_paging_state.as_ptr() as *const c_char; + } + None => { + *paging_state_size = 0; + *paging_state = std::ptr::null(); + } + } + + CassError::CASS_OK +} + +#[no_mangle] +pub unsafe extern "C" fn cass_statement_set_paging_state_token( + statement: *mut CassStatement, + paging_state: *const c_char, + paging_state_size: size_t, +) -> CassError { + let statement_from_raw = ptr_to_ref_mut(statement); + + if paging_state.is_null() { + statement_from_raw.paging_state = None; + return CassError::CASS_ERROR_LIB_NULL_VALUE; + } + + let paging_state_usize: usize = paging_state_size.try_into().unwrap(); + let mut b = BytesMut::with_capacity(paging_state_usize + 1); + b.put_slice(slice::from_raw_parts( + paging_state as *const u8, + paging_state_usize, + )); + b.extend_from_slice(b"\0"); + statement_from_raw.paging_state = Some(b.freeze()); + + CassError::CASS_OK +} + // CassResult functions: /* extern "C" { diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 5aaec2f6..c7a4a4c6 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -182,3 +182,22 @@ pub unsafe extern "C" fn cass_session_prepare_n( pub unsafe extern "C" fn cass_session_free(session_raw: *mut CassSession) { free_arced(session_raw); } + +#[no_mangle] +pub unsafe extern "C" fn cass_session_close(session: *mut CassSession) -> *const CassFuture { + let session_opt = ptr_to_ref(session); + + CassFuture::make_raw(async move { + let mut session_guard = session_opt.write().await; + if session_guard.is_none() { + return Err(( + CassError::CASS_ERROR_LIB_UNABLE_TO_CLOSE, + "Already closing or closed".msg(), + )); + } + + *session_guard = None; + + Ok(CassResultValue::Empty) + }) +} diff --git a/src/testing_unimplemented.cpp b/src/testing_unimplemented.cpp index 02c0e89d..04d847be 100644 --- a/src/testing_unimplemented.cpp +++ b/src/testing_unimplemented.cpp @@ -477,19 +477,6 @@ cass_prepared_parameter_data_type_by_name(const CassPrepared* prepared, const char* name){ throw std::runtime_error("UNIMPLEMENTED cass_prepared_parameter_data_type_by_name\n"); } -CASS_EXPORT CassError -cass_result_column_name(const CassResult *result, - size_t index, - const char** name, - size_t* name_length){ - throw std::runtime_error("UNIMPLEMENTED cass_result_column_name\n"); -} -CASS_EXPORT CassError -cass_result_paging_state_token(const CassResult* result, - const char** paging_state, - size_t* paging_state_size){ - throw std::runtime_error("UNIMPLEMENTED cass_result_paging_state_token\n"); -} CASS_EXPORT CassRetryPolicy* cass_retry_policy_default_new(){ throw std::runtime_error("UNIMPLEMENTED cass_retry_policy_default_new\n"); @@ -529,10 +516,6 @@ cass_schema_meta_version(const CassSchemaMeta* schema_meta){ throw std::runtime_error("UNIMPLEMENTED cass_schema_meta_version\n"); } CASS_EXPORT CassFuture* -cass_session_close(CassSession* session){ - throw std::runtime_error("UNIMPLEMENTED cass_session_close\n"); -} -CASS_EXPORT CassFuture* cass_session_connect_keyspace(CassSession* session, const CassCluster* cluster, const char* keyspace){ @@ -677,12 +660,6 @@ cass_statement_set_node(CassStatement* statement, throw std::runtime_error("UNIMPLEMENTED cass_statement_set_node\n"); } CASS_EXPORT CassError -cass_statement_set_paging_state_token(CassStatement* statement, - const char* paging_state, - size_t paging_state_size){ - throw std::runtime_error("UNIMPLEMENTED cass_statement_set_paging_state_token\n"); -} -CASS_EXPORT CassError cass_statement_set_request_timeout(CassStatement* statement, cass_uint64_t timeout_ms){ throw std::runtime_error("UNIMPLEMENTED cass_statement_set_request_timeout\n"); @@ -808,12 +785,6 @@ cass_user_type_set_duration_by_name(CassUserType* user_type, throw std::runtime_error("UNIMPLEMENTED cass_user_type_set_duration_by_name\n"); } CASS_EXPORT CassError -cass_value_get_bytes(const CassValue* value, - const cass_byte_t** output, - size_t* output_size){ - throw std::runtime_error("UNIMPLEMENTED cass_value_get_bytes\n"); -} -CASS_EXPORT CassError cass_value_get_decimal(const CassValue* value, const cass_byte_t** varint, size_t* varint_size, diff --git a/tests/src/integration/ccm/cass_version.hpp b/tests/src/integration/ccm/cass_version.hpp index 1db95726..ca13dc02 100644 --- a/tests/src/integration/ccm/cass_version.hpp +++ b/tests/src/integration/ccm/cass_version.hpp @@ -278,7 +278,11 @@ class CassVersion { */ void from_string(const std::string& version_string) { // Clean up the string for tokens + std::string scylla_version_prefix = "release:"; std::string version(version_string); + if (version.compare(0, scylla_version_prefix.size(), scylla_version_prefix) == 0) { + version = "3.0.8"; + } std::replace(version.begin(), version.end(), '.', ' '); std::size_t found = version.find("-"); if (found != std::string::npos) {