diff --git a/scylla-rust-wrapper/src/prepared.rs b/scylla-rust-wrapper/src/prepared.rs index df70c86e..a51f136d 100644 --- a/scylla-rust-wrapper/src/prepared.rs +++ b/scylla-rust-wrapper/src/prepared.rs @@ -18,7 +18,7 @@ pub unsafe extern "C" fn cass_prepared_free(prepared_raw: *const CassPrepared) { pub unsafe extern "C" fn cass_prepared_bind( prepared_raw: *const CassPrepared, ) -> *mut CassStatement { - let prepared: Arc<_> = Arc::from_raw(prepared_raw); + let prepared: Arc<_> = clone_arced(prepared_raw); let bound_values_size = prepared.get_metadata().col_count; // cloning prepared statement's arc, because creating CassStatement should not invalidate diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index 216825b9..073b333e 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -2,12 +2,14 @@ use crate::argconv::*; use crate::cass_error::CassError; use crate::cass_types::{cass_data_type_type, CassDataType, CassValueType}; use crate::inet::CassInet; +use crate::statement::CassStatement; use crate::types::*; use crate::uuid::CassUuid; use scylla::frame::response::result::{ColumnSpec, CqlValue}; -use scylla::Bytes; +use scylla::{Bytes, BytesMut}; use std::convert::TryInto; use std::os::raw::c_char; +use std::slice; use std::sync::Arc; pub struct CassResult { @@ -402,6 +404,30 @@ pub unsafe extern "C" fn cass_value_get_string( CassError::CASS_OK } +#[no_mangle] +pub unsafe extern "C" fn cass_value_get_bytes( + value: *const CassValue, + output: *mut *const cass_byte_t, + output_size: *mut size_t, +) -> CassError { + if value.is_null() { + return CassError::CASS_ERROR_LIB_NULL_VALUE; + } + + let value_from_raw: &CassValue = ptr_to_ref(value); + + match &value_from_raw.value { + Some(CqlValue::Blob(bytes)) => { + *output = bytes.as_ptr() as *const cass_byte_t; + *output_size = bytes.len() as u64; + } + Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, + None => return CassError::CASS_ERROR_LIB_NULL_VALUE, + } + + CassError::CASS_OK +} + #[no_mangle] pub unsafe extern "C" fn cass_value_is_null(value: *const CassValue) -> cass_bool_t { let val: &CassValue = ptr_to_ref(value); @@ -437,6 +463,63 @@ pub unsafe extern "C" fn cass_result_first_row(result_raw: *const CassResult) -> std::ptr::null() } +#[no_mangle] +pub unsafe extern "C" fn cass_result_paging_state_token( + result: *const CassResult, + paging_state: *mut *const c_char, + paging_state_size: *mut size_t, +) -> CassError { + if cass_result_has_more_pages(result) == cass_false { + return CassError::CASS_ERROR_LIB_NO_PAGING_STATE; + } + + let result_from_raw = ptr_to_ref(result); + + if result_from_raw.metadata.paging_state.is_none() { + *paging_state_size = 0; + *paging_state = std::ptr::null(); + } else { + *paging_state_size = result_from_raw + .metadata + .paging_state + .as_ref() + .unwrap() + .len() as u64; + *paging_state = result_from_raw + .metadata + .paging_state + .clone() + .unwrap() + .as_ptr() as *const c_char; + } + + CassError::CASS_OK +} + +#[no_mangle] +pub unsafe extern "C" fn cass_statement_set_paging_state_token( + statement: *mut CassStatement, + paging_state: *const c_char, + paging_state_size: size_t, +) -> CassError { + let statement_from_raw = ptr_to_ref_mut(statement); + + if paging_state.is_null() { + statement_from_raw.paging_state = None; + return CassError::CASS_ERROR_LIB_NULL_VALUE; + } + + let paging_state_usize: usize = paging_state_size.try_into().unwrap(); + let mut b = BytesMut::from(slice::from_raw_parts( + paging_state as *const u8, + paging_state_usize, + )); + b.extend_from_slice(b"\0"); + statement_from_raw.paging_state = Some(b.freeze()); + + CassError::CASS_OK +} + // CassResult functions: /* extern "C" { diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 114f15d5..5f3f16da 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -185,3 +185,22 @@ pub unsafe extern "C" fn cass_session_prepare_n( pub unsafe extern "C" fn cass_session_free(session_raw: *mut CassSession) { free_arced(session_raw); } + +#[no_mangle] +pub unsafe extern "C" fn cass_session_close(session: *mut CassSession) -> *const CassFuture { + let session_opt = ptr_to_ref(session); + + CassFuture::make_raw(async move { + let mut session_guard = session_opt.write().await; + if session_guard.is_none() { + return Err(( + CassError::CASS_ERROR_LIB_UNABLE_TO_CLOSE, + "Already closing or closed".msg(), + )); + } + + *session_guard = None; + + Ok(CassResultValue::Empty) + }) +} diff --git a/src/testing_unimplemented.cpp b/src/testing_unimplemented.cpp index 02c0e89d..04d847be 100644 --- a/src/testing_unimplemented.cpp +++ b/src/testing_unimplemented.cpp @@ -477,19 +477,6 @@ cass_prepared_parameter_data_type_by_name(const CassPrepared* prepared, const char* name){ throw std::runtime_error("UNIMPLEMENTED cass_prepared_parameter_data_type_by_name\n"); } -CASS_EXPORT CassError -cass_result_column_name(const CassResult *result, - size_t index, - const char** name, - size_t* name_length){ - throw std::runtime_error("UNIMPLEMENTED cass_result_column_name\n"); -} -CASS_EXPORT CassError -cass_result_paging_state_token(const CassResult* result, - const char** paging_state, - size_t* paging_state_size){ - throw std::runtime_error("UNIMPLEMENTED cass_result_paging_state_token\n"); -} CASS_EXPORT CassRetryPolicy* cass_retry_policy_default_new(){ throw std::runtime_error("UNIMPLEMENTED cass_retry_policy_default_new\n"); @@ -529,10 +516,6 @@ cass_schema_meta_version(const CassSchemaMeta* schema_meta){ throw std::runtime_error("UNIMPLEMENTED cass_schema_meta_version\n"); } CASS_EXPORT CassFuture* -cass_session_close(CassSession* session){ - throw std::runtime_error("UNIMPLEMENTED cass_session_close\n"); -} -CASS_EXPORT CassFuture* cass_session_connect_keyspace(CassSession* session, const CassCluster* cluster, const char* keyspace){ @@ -677,12 +660,6 @@ cass_statement_set_node(CassStatement* statement, throw std::runtime_error("UNIMPLEMENTED cass_statement_set_node\n"); } CASS_EXPORT CassError -cass_statement_set_paging_state_token(CassStatement* statement, - const char* paging_state, - size_t paging_state_size){ - throw std::runtime_error("UNIMPLEMENTED cass_statement_set_paging_state_token\n"); -} -CASS_EXPORT CassError cass_statement_set_request_timeout(CassStatement* statement, cass_uint64_t timeout_ms){ throw std::runtime_error("UNIMPLEMENTED cass_statement_set_request_timeout\n"); @@ -808,12 +785,6 @@ cass_user_type_set_duration_by_name(CassUserType* user_type, throw std::runtime_error("UNIMPLEMENTED cass_user_type_set_duration_by_name\n"); } CASS_EXPORT CassError -cass_value_get_bytes(const CassValue* value, - const cass_byte_t** output, - size_t* output_size){ - throw std::runtime_error("UNIMPLEMENTED cass_value_get_bytes\n"); -} -CASS_EXPORT CassError cass_value_get_decimal(const CassValue* value, const cass_byte_t** varint, size_t* varint_size, diff --git a/tests/src/integration/tests/test_basics.cpp b/tests/src/integration/tests/test_basics.cpp index 132d056a..e05a4668 100644 --- a/tests/src/integration/tests/test_basics.cpp +++ b/tests/src/integration/tests/test_basics.cpp @@ -254,7 +254,7 @@ CASSANDRA_INTEGRATION_TEST_F(BasicsTests, UnsetParameters) { // Execute the insert statement and validate the error code Result result = session_.execute(insert_statement, false); - if (server_version_ >= "2.2.0") { + if (server_version_ >= "release:2.2.0") { // Cassandra v2.2+ uses the value UNSET; making this a no-op ASSERT_EQ(CASS_OK, result.error_code()); } else {