Skip to content

Commit

Permalink
feat(expressions-compatibility) add FFI interface to get operators al…
Browse files Browse the repository at this point in the history
…ong with fields

The FFI interface introduced by this commit has a corresponding Go wrapper:
```
// ValidateExpression parses the ATC expression atc and validates it against
// the schema s by calling the expression_validate FFI function.
//
// It returns whether the expression is valid, the names of the fields the
// expression references, and the bit set of binary operators it uses. When
// the expression is invalid, the error message written by the FFI call is
// printed and (false, nil, 0) is returned.
func ValidateExpression(atc string, s *Schema) (bool, []string, int64) {
	atcC := unsafe.Pointer(C.CString(atc))
	defer C.free(atcC)

	errLen := C.ulong(1024)
	errBuf := [1024]C.uchar{}

	expr := C.expression_validate((*C.uchar)(atcC), s.s, &errBuf[0], &errLen)
	if expr == nil {
		fmt.Println("Error: ", string(errBuf[:errLen]))
		return false, nil, 0
	}
	// Register the release only once the pointer is known to be non-nil.
	defer C.expression_validate_free_result(expr)

	// expression_validate reports a parse or validation failure through the
	// validate field of the result, not by returning a nil pointer.
	if !bool(expr.validate) {
		fmt.Println("Error: ", string(errBuf[:errLen]))
		return false, nil, 0
	}

	operators := int64(expr.operators)
	flds := make([]string, expr.fields_total)
	flds_slice := unsafe.Slice(expr.fields, expr.fields_total)

	for i := range flds {
		flds[i] = C.GoString((*C.char)(unsafe.Pointer(flds_slice[i])))
	}

	return true, flds, operators
}
```
  • Loading branch information
Oyami-Srk committed Sep 24, 2024
1 parent 3f8f648 commit 67c979c
Show file tree
Hide file tree
Showing 5 changed files with 272 additions and 1 deletion.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ regex = "1"
serde = { version = "1.0", features = ["derive"], optional = true }
serde_regex = { version = "1.1", optional = true }
fnv = "1"
bitflags = "2.6.0"

[lib]
crate-type = ["lib", "cdylib", "staticlib"]

[features]
default = ["ffi"]
default = ["ffi", "expr_validation"]
ffi = []
serde = ["cidr/serde", "dep:serde", "dep:serde_regex"]
expr_validation = []
7 changes: 7 additions & 0 deletions cbindgen.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,10 @@ prefix_with_name = true

[defines]
"feature = ffi" = "DEFINE_ATC_ROUTER_FFI"
"feature = expr_validation" = "DEFINE_ATC_ROUTER_EXPR_VALIDATION"

[macro_expansion]
bitflags = true

[export]
include = ["BinaryOperatorFlags"]
40 changes: 40 additions & 0 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,46 @@ pub enum BinaryOperator {
Contains, // contains
}

// Bit set with one flag per `BinaryOperator` variant; the `From<&BinaryOperator>`
// implementation below maps each variant to its flag.
#[cfg(feature = "expr_validation")]
bitflags::bitflags! {
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(C)]
pub struct BinaryOperatorFlags: u64 /* a u64 can represent at most 64 BinaryOperators */ {
// Each operator owns exactly one bit. The raw bits are handed to FFI callers
// as a plain u64, so changing a bit position changes the value they observe.
const EQUALS = 1 << 0;
const NOT_EQUALS = 1 << 1;
const REGEX = 1 << 2;
const PREFIX = 1 << 3;
const POSTFIX = 1 << 4;
const GREATER = 1 << 5;
const GREATER_OR_EQUAL = 1 << 6;
const LESS = 1 << 7;
const LESS_OR_EQUAL = 1 << 8;
const IN = 1 << 9;
const NOT_IN = 1 << 10;
const CONTAINS = 1 << 11;
}
}

#[cfg(feature = "expr_validation")]
impl From<&BinaryOperator> for BinaryOperatorFlags {
    /// Returns the single flag bit that represents `op`.
    ///
    /// The match is exhaustive on purpose: adding a new `BinaryOperator`
    /// variant without assigning it a flag is a compile error.
    fn from(op: &BinaryOperator) -> Self {
        match *op {
            BinaryOperator::Equals => BinaryOperatorFlags::EQUALS,
            BinaryOperator::NotEquals => BinaryOperatorFlags::NOT_EQUALS,
            BinaryOperator::Regex => BinaryOperatorFlags::REGEX,
            BinaryOperator::Prefix => BinaryOperatorFlags::PREFIX,
            BinaryOperator::Postfix => BinaryOperatorFlags::POSTFIX,
            BinaryOperator::Greater => BinaryOperatorFlags::GREATER,
            BinaryOperator::GreaterOrEqual => BinaryOperatorFlags::GREATER_OR_EQUAL,
            BinaryOperator::Less => BinaryOperatorFlags::LESS,
            BinaryOperator::LessOrEqual => BinaryOperatorFlags::LESS_OR_EQUAL,
            BinaryOperator::In => BinaryOperatorFlags::IN,
            BinaryOperator::NotIn => BinaryOperatorFlags::NOT_IN,
            BinaryOperator::Contains => BinaryOperatorFlags::CONTAINS,
        }
    }
}

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub enum Value {
Expand Down
215 changes: 215 additions & 0 deletions src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,164 @@ pub unsafe extern "C" fn context_get_result(
.unwrap()
}

/// Outcome of [`expression_validate`], laid out with `#[repr(C)]` so it can be
/// read from C.
///
/// Instances are heap-allocated by [`expression_validate`] and must be released
/// with [`expression_validate_free_result`].
#[cfg(feature = "expr_validation")]
#[derive(Debug)]
#[repr(C)]
pub struct ExpressionValidationResult {
/// `true` when the expression parsed and validated against the schema.
/// When `false`, none of the following fields carry meaningful data:
/// `fields` is null and `fields_total` and `operators` are `0`.
validate: bool,
/// Array of `fields_total` NUL-terminated strings, one per distinct field
/// used by the expression. The order of the entries is unspecified.
fields: *mut *mut c_char,
/// Number of entries in `fields`.
fields_total: usize,
/// Raw bits of the `BinaryOperatorFlags` set describing which binary
/// operators the expression uses.
operators: u64,
}

/// Validate the ATC expression with the schema.
///
/// On success the result also reports every field referenced by the
/// expression and the set of binary operators it uses.
///
/// # Arguments
///
/// - `atc`: the C-style string representing the ATC expression.
/// - `schema`: a valid pointer to the [`Schema`] object returned by [`schema_new`].
/// - `errbuf`: a buffer to store the error message.
/// - `errbuf_len`: a pointer to the length of the error message buffer.
///
/// # Returns
///
/// Returns a pointer to a heap-allocated `ExpressionValidationResult`. The
/// pointer is never null and must be released with
/// [`expression_validate_free_result`].
///
/// If the expression is valid, the `validate` field is `true`, `fields` points
/// to `fields_total` C-style strings naming the fields used by the expression
/// (in unspecified order), and `operators` holds the bits of the
/// `BinaryOperatorFlags` for the operators used. `errbuf` and `errbuf_len` are
/// left untouched.
///
/// If the expression is not valid, the `validate` field will be `false`,
/// and the error message will be stored in the `errbuf` (truncated to the
/// capacity given in `errbuf_len` and not NUL-terminated),
/// and the length of the error message will be stored in `errbuf_len`.
///
/// # Panics
///
/// This function will panic when:
///
/// - `atc` doesn't point to a valid UTF-8 encoded C-style string.
/// - a field name used by the expression contains a '\0' byte.
///
/// # Safety
///
/// Violating any of the following constraints will result in undefined behavior:
///
/// - `atc` must be a valid pointer to a C-style string, must be properly aligned,
///   and must not have '\0' in the middle.
/// - `schema` must be a valid pointer returned by [`schema_new`].
/// - `errbuf` must be valid to read and write for `errbuf_len * size_of::<u8>()` bytes,
///   and it must be properly aligned.
/// - `errbuf_len` must be valid to read and write for `size_of::<usize>()` bytes,
///   and it must be properly aligned.
#[cfg(feature = "expr_validation")]
#[no_mangle]
pub unsafe extern "C" fn expression_validate(
    atc: *const u8,
    schema: &Schema,
    errbuf: *mut u8,
    errbuf_len: *mut usize,
) -> *mut ExpressionValidationResult {
    use std::collections::HashMap;

    use crate::ast::{BinaryOperatorFlags, Expression, LogicalExpression};
    use crate::parser::parse;
    use crate::semantics::FieldCounter;
    use crate::semantics::Validate;

    /// Copies `msg` into the caller's error buffer, truncated to the capacity
    /// the caller announced in `*errbuf_len`, and stores the written length back.
    ///
    /// Only the bytes that are actually written are turned into a slice, so the
    /// function never touches memory beyond what the caller declared.
    unsafe fn report_error(errbuf: *mut u8, errbuf_len: *mut usize, msg: &str) {
        let errlen = min(msg.len(), *errbuf_len);
        from_raw_parts_mut(errbuf, errlen).copy_from_slice(&msg.as_bytes()[..errlen]);
        *errbuf_len = errlen;
    }

    /// Records the operator of every predicate reachable from `expr` in `ops`.
    fn visit(expr: &Expression, ops: &mut BinaryOperatorFlags) {
        match expr {
            Expression::Logical(logic_expression) => match logic_expression.as_ref() {
                LogicalExpression::And(lhs, rhs) | LogicalExpression::Or(lhs, rhs) => {
                    visit(lhs, ops);
                    visit(rhs, ops);
                }
                LogicalExpression::Not(rhs) => visit(rhs, ops),
            },
            Expression::Predicate(predict) => {
                ops.insert(BinaryOperatorFlags::from(&predict.op));
            }
        }
    }

    let atc = ffi::CStr::from_ptr(atc as *const c_char).to_str().unwrap();

    // Starts out describing a failed validation; only filled in on success.
    let mut validation_result = Box::new(ExpressionValidationResult {
        validate: false,
        fields: std::ptr::null_mut(),
        fields_total: 0,
        operators: 0,
    });

    // Parse the expression
    let ast = match parse(atc) {
        Ok(ast) => ast,
        Err(e) => {
            report_error(errbuf, errbuf_len, &e.to_string());
            return Box::into_raw(validation_result);
        }
    };

    // Validate expression with schema
    if let Err(e) = ast.validate(schema) {
        report_error(errbuf, errbuf_len, &e.to_string());
        return Box::into_raw(validation_result);
    }

    // Get used fields; every name is leaked as a C string that
    // `expression_validate_free_result` reclaims later.
    let mut expr_fields = HashMap::new();
    ast.add_to_counter(&mut expr_fields);
    let fields: Box<[*mut c_char]> = expr_fields
        .into_keys()
        .map(|name| ffi::CString::new(name).unwrap().into_raw())
        .collect();

    // Get used operators
    let mut ops = BinaryOperatorFlags::empty();
    visit(&ast, &mut ops);

    validation_result.validate = true;
    validation_result.operators = ops.bits();
    validation_result.fields_total = fields.len();
    validation_result.fields = Box::into_raw(fields).cast(); // Leak the boxed slice

    Box::into_raw(validation_result) // Leak the Box
}

/// Deallocate the ExpressionValidationResult object.
///
/// Releases the result itself, the array of field names and every field name
/// string. Passing a null pointer is a no-op, and results of a failed
/// validation (whose `fields` pointer is null) are handled as well.
///
/// # Errors
///
/// This function never fails.
///
/// # Safety
///
/// Violating any of the following constraints will result in undefined behavior:
///
/// - `result` must be null or a valid pointer returned by [`expression_validate`]
///   that has not been passed to this function before.
#[cfg(feature = "expr_validation")]
#[no_mangle]
pub unsafe extern "C" fn expression_validate_free_result(result: *mut ExpressionValidationResult) {
    if result.is_null() {
        return;
    }
    let result = Box::from_raw(result);

    // A failed validation carries a null `fields` pointer. Building a slice or a
    // `Box` from a null pointer is undefined behavior even for length zero, so
    // only reclaim the array when one was actually allocated.
    if !result.fields.is_null() {
        let fields = std::ptr::slice_from_raw_parts_mut(result.fields, result.fields_total);
        for ptr in Box::from_raw(fields).into_vec() {
            drop(ffi::CString::from_raw(ptr)); // Drop the leaked CString
        }
    }
    // `result` goes out of scope here, which frees the Box itself.
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -312,4 +470,61 @@ mod tests {
assert!(errbuf_len < ERR_BUF_MAX_LEN);
}
}

/// Validates an expression that uses four fields and five distinct operators,
/// then checks the reported validity, field names and operator bits.
#[cfg(feature = "expr_validation")]
#[test]
fn test_expression_validate() {
    use crate::ast::BinaryOperatorFlags;

    let mut schema = Schema::default();
    schema.add_field("net.protocol", Type::String);
    schema.add_field("net.dst.port", Type::Int);
    schema.add_field("net.src.ip", Type::IpAddr);
    schema.add_field("http.path", Type::String);

    let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##;
    let atc = ffi::CString::new(atc).unwrap();
    let mut errbuf = vec![b'X'; ERR_BUF_MAX_LEN];
    let mut errbuf_len = ERR_BUF_MAX_LEN;

    let expected_operators = (BinaryOperatorFlags::EQUALS
        | BinaryOperatorFlags::REGEX
        | BinaryOperatorFlags::IN
        | BinaryOperatorFlags::NOT_IN
        | BinaryOperatorFlags::CONTAINS)
        .bits();

    unsafe {
        let result = expression_validate(
            // The callee reads up to and including the NUL terminator, so the
            // pointer must come from a slice that contains it.
            atc.as_bytes_with_nul().as_ptr(),
            &schema,
            errbuf.as_mut_ptr(),
            &mut errbuf_len,
        );

        assert!((*result).validate, "Validation failed");
        assert_eq!((*result).fields_total, 4, "Fields count mismatch");
        assert_eq!((*result).operators, expected_operators, "Operators mismatch");

        // The order of the returned fields is unspecified, so sort before comparing.
        let mut fields: Vec<String> = (0..(*result).fields_total)
            .map(|i| {
                let field = *(*result).fields.add(i);
                ffi::CStr::from_ptr(field).to_str().unwrap().to_string()
            })
            .collect();
        fields.sort();
        assert_eq!(
            fields,
            vec![
                "http.path".to_string(),
                "net.dst.port".to_string(),
                "net.protocol".to_string(),
                "net.src.ip".to_string()
            ],
            "Fields mismatch"
        );

        expression_validate_free_result(result);
    }
}
}

0 comments on commit 67c979c

Please sign in to comment.