From ede36e1c562df1cc475e69a7049a11daa90d9330 Mon Sep 17 00:00:00 2001 From: Robert Bragg Date: Sun, 23 Jul 2023 17:13:46 +0100 Subject: [PATCH] Turn JNINativeInterface + JNIInvokeInterface_ into unions This implements a `#[jni_to_union]` procmacro that lets us declare what version of the JNI spec each function was added and then declare a union that only exposes functions the specific version being referenced. So instead of a struct like: ```rust struct JNIInvokeInterface_ { pub reserved0: *mut c_void, .. pub GetVersion: unsafe extern "system" fn(env: *mut JNIEnv) -> jint, .. pub NewLocalRef: unsafe extern "system" fn(env: *mut JNIEnv, ref_: jobject) -> jobject, } ``` we have a union like: ``` union JNIInvokeInterface_ { v1_1: JNIInvokeInterface__1_1, v1_2: JNIInvokeInterface__1_2, reserved: JNIInvokeInterface__reserved, } ``` And would access `GetVersion` like: `env.v1_1.GetVersion` and access `NewLocalRef` like: `env.v1_2.NewLocalRef`. Each version struct includes all functions for that version and lower, so it's also possible to access GetVersion like `env.v1_2.GetVersion`. This way it's more explicit when you're accessing functions that aren't part of JNI 1.1 which require you to have checked the version of JNI the JVM supports. --- Cargo.toml | 9 +- jni-to-union-macro/Cargo.toml | 11 + jni-to-union-macro/src/lib.rs | 195 ++++++++++++++++++ src/lib.rs | 57 +++-- systest/build.rs | 12 +- tests/jni-to-union.rs | 51 +++++ .../trybuild/01-jni-fail-read-1-2-from-1-1.rs | 9 + .../01-jni-fail-read-1-2-from-1-1.stderr | 7 + tests/trybuild/01-jni-fail-reserved-read.rs | 9 + .../trybuild/01-jni-fail-reserved-read.stderr | 7 + tests/trybuild/01-jni-to-union-basic-pass.rs | 21 ++ 11 files changed, 367 insertions(+), 21 deletions(-) create mode 100644 jni-to-union-macro/Cargo.toml create mode 100644 jni-to-union-macro/src/lib.rs create mode 100644 tests/jni-to-union.rs create mode 100644 tests/trybuild/01-jni-fail-read-1-2-from-1-1.rs create mode 100644 tests/trybuild/01-jni-fail-read-1-2-from-1-1.stderr create mode 100644 tests/trybuild/01-jni-fail-reserved-read.rs create mode 100644 tests/trybuild/01-jni-fail-reserved-read.stderr create mode 100644 tests/trybuild/01-jni-to-union-basic-pass.rs diff --git a/Cargo.toml b/Cargo.toml index 02e97ce..9cefb69 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,13 @@ keywords = ["java", "jni"] edition = "2021" [dependencies] +jni-to-union-macro = { path = "jni-to-union-macro" } + +[dev-dependencies] +trybuild = "1" [workspace] -members = ["systest"] +members = [ + "systest", + "jni-to-union-macro" +] diff --git a/jni-to-union-macro/Cargo.toml b/jni-to-union-macro/Cargo.toml new file mode 100644 index 0000000..e68611b --- /dev/null +++ b/jni-to-union-macro/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "jni-to-union-macro" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "1", features = ["full"] } +quote = "1" diff --git a/jni-to-union-macro/src/lib.rs b/jni-to-union-macro/src/lib.rs new file mode 100644 index 0000000..b6eda2d --- /dev/null +++ b/jni-to-union-macro/src/lib.rs @@ -0,0 +1,195 @@ +extern crate proc_macro; + +use std::{cmp::Ordering, collections::HashSet}; + +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse_macro_input, spanned::Spanned, Data, DeriveInput, Fields, Ident, LitStr}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +struct JniVersion { + major: u16, + minor: u16, +} +impl Default for JniVersion { + fn default() -> Self { + Self { major: 1, minor: 1 } + } +} +impl Ord for JniVersion { + fn cmp(&self, other: &Self) -> Ordering { + match self.major.cmp(&other.major) { + Ordering::Equal => self.minor.cmp(&other.minor), + major_order => major_order, + } + } +} +impl PartialOrd for JniVersion { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl syn::parse::Parse for JniVersion { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let version: LitStr = input.parse()?; + let version = version.value(); + if version == "reserved" { + // We special case version 999 later instead of making JniVersion an enum + return Ok(JniVersion { + major: 999, + minor: 0, + }); + } + let mut split = version.splitn(2, '.'); + const EXPECTED_MSG: &str = "Expected \"major.minor\" version number or \"reserved\""; + let major = split + .next() + .ok_or(syn::Error::new(input.span(), EXPECTED_MSG))?; + let major = major + .parse::() + .map_err(|_| syn::Error::new(input.span(), EXPECTED_MSG))?; + let minor = split + .next() + .unwrap_or("0") + .parse::() + .map_err(|_| syn::Error::new(input.span(), EXPECTED_MSG))?; + Ok(JniVersion { major, minor }) + } +} + +fn jni_to_union_impl(input: DeriveInput) -> syn::Result { + let original_name = &input.ident; + let original_visibility = &input.vis; + + let mut versions = HashSet::new(); + let mut versioned_fields = vec![]; + + if let Data::Struct(data) = &input.data { + if let Fields::Named(fields) = &data.fields { + for field in &fields.named { + // Default to version 1.1 + let mut min_version = JniVersion::default(); + + let mut field = field.clone(); + + let mut jni_added_attr = None; + field.attrs.retain(|attr| { + if attr.path.is_ident("jni_added") { + jni_added_attr = Some(attr.clone()); + false + } else { + true + } + }); + if let Some(attr) = jni_added_attr { + let version = attr.parse_args::()?; + min_version = version; + } + + versions.insert(min_version); + versioned_fields.push((min_version, field.clone())); + } + + // Quote structs and union + let mut expanded = quote! {}; + + let mut union_members = quote!(); + + let mut versions: Vec<_> = versions.into_iter().collect(); + versions.sort(); + + for version in versions { + let (struct_ident, version_ident, version_suffix) = if version.major == 999 { + ( + Ident::new(&format!("{}_reserved", original_name), original_name.span()), + Ident::new("reserved", original_name.span()), + "reserved".to_string(), + ) + } else if version.minor == 0 { + ( + Ident::new( + &format!("{}_{}", original_name, version.major), + original_name.span(), + ), + Ident::new(&format!("v{}", version.major), original_name.span()), + format!("{}", version.major), + ) + } else { + let struct_ident = Ident::new( + &format!("{}_{}_{}", original_name, version.major, version.minor), + original_name.span(), + ); + let version_ident = Ident::new( + &format!("v{}_{}", version.major, version.minor), + original_name.span(), + ); + ( + struct_ident, + version_ident, + format!("{}_{}", version.major, version.minor), + ) + }; + + let last = versioned_fields + .iter() + .rposition(|(v, _f)| v <= &version) + .unwrap_or(versioned_fields.len()); + let mut padding_idx = 0u32; + + let mut version_field_tokens = quote!(); + for (i, (field_min_version, field)) in versioned_fields.iter().enumerate() { + if i > last { + break; + } + if field_min_version > &version { + let reserved_ident = format_ident!("_padding_{}", padding_idx); + padding_idx += 1; + version_field_tokens.extend(quote! { #reserved_ident: *mut c_void, }); + } else { + version_field_tokens.extend(quote! { #field, }); + } + } + expanded.extend(quote! { + #[allow(non_snake_case, non_camel_case_types)] + #[repr(C)] + #[derive(Copy, Clone, Debug)] + #original_visibility struct #struct_ident { + #version_field_tokens + } + }); + + let api_comment = + format!("API when JNI version >= `JNI_VERSION_{}`", version_suffix); + union_members.extend(quote! { + #[doc = #api_comment] + #original_visibility #version_ident: #struct_ident, + }); + } + + expanded.extend(quote! { + #[repr(C)] + #original_visibility union #original_name { + #union_members + } + }); + + return Ok(TokenStream::from(expanded)); + } + } + + Err(syn::Error::new( + input.span(), + "Expected a struct with fields", + )) +} + +#[proc_macro_attribute] +pub fn jni_to_union(_attr: TokenStream, item: TokenStream) -> TokenStream { + let input = parse_macro_input!(item as DeriveInput); + + match jni_to_union_impl(input) { + Ok(tokens) => tokens, + Err(err) => err.into_compile_error().into(), + } +} diff --git a/src/lib.rs b/src/lib.rs index 3e9f6fa..09c48ae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,9 @@ use std::os::raw::c_char; use std::os::raw::c_void; +extern crate jni_to_union_macro; +use jni_to_union_macro::jni_to_union; + // FIXME is this sufficiently correct? pub type va_list = *mut c_void; @@ -59,7 +62,7 @@ pub type jfieldID = *mut _jfieldID; pub enum _jmethodID {} pub type jmethodID = *mut _jmethodID; -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug)] #[repr(C)] pub enum jobjectRefType { JNIInvalidRefType = 0, @@ -109,12 +112,18 @@ pub type JavaVM = *const JNIInvokeInterface_; #[repr(C)] #[non_exhaustive] -#[derive(Copy)] +#[jni_to_union] +#[derive(Copy, Clone, Debug)] pub struct JNINativeInterface_ { + #[jni_added("reserved")] pub reserved0: *mut c_void, + #[jni_added("reserved")] pub reserved1: *mut c_void, + #[jni_added("reserved")] pub reserved2: *mut c_void, + #[jni_added("reserved")] pub reserved3: *mut c_void, + #[jni_added("1.1")] pub GetVersion: unsafe extern "system" fn(env: *mut JNIEnv) -> jint, pub DefineClass: unsafe extern "system" fn( env: *mut JNIEnv, @@ -124,9 +133,12 @@ pub struct JNINativeInterface_ { len: jsize, ) -> jclass, pub FindClass: unsafe extern "system" fn(env: *mut JNIEnv, name: *const c_char) -> jclass, + #[jni_added("1.2")] pub FromReflectedMethod: unsafe extern "system" fn(env: *mut JNIEnv, method: jobject) -> jmethodID, + #[jni_added("1.2")] pub FromReflectedField: unsafe extern "system" fn(env: *mut JNIEnv, field: jobject) -> jfieldID, + #[jni_added("1.2")] pub ToReflectedMethod: unsafe extern "system" fn( env: *mut JNIEnv, cls: jclass, @@ -136,6 +148,7 @@ pub struct JNINativeInterface_ { pub GetSuperclass: unsafe extern "system" fn(env: *mut JNIEnv, sub: jclass) -> jclass, pub IsAssignableFrom: unsafe extern "system" fn(env: *mut JNIEnv, sub: jclass, sup: jclass) -> jboolean, + #[jni_added("1.2")] pub ToReflectedField: unsafe extern "system" fn( env: *mut JNIEnv, cls: jclass, @@ -149,14 +162,18 @@ pub struct JNINativeInterface_ { pub ExceptionDescribe: unsafe extern "system" fn(env: *mut JNIEnv), pub ExceptionClear: unsafe extern "system" fn(env: *mut JNIEnv), pub FatalError: unsafe extern "system" fn(env: *mut JNIEnv, msg: *const c_char) -> !, + #[jni_added("1.2")] pub PushLocalFrame: unsafe extern "system" fn(env: *mut JNIEnv, capacity: jint) -> jint, + #[jni_added("1.2")] pub PopLocalFrame: unsafe extern "system" fn(env: *mut JNIEnv, result: jobject) -> jobject, pub NewGlobalRef: unsafe extern "system" fn(env: *mut JNIEnv, lobj: jobject) -> jobject, pub DeleteGlobalRef: unsafe extern "system" fn(env: *mut JNIEnv, gref: jobject), pub DeleteLocalRef: unsafe extern "system" fn(env: *mut JNIEnv, obj: jobject), pub IsSameObject: unsafe extern "system" fn(env: *mut JNIEnv, obj1: jobject, obj2: jobject) -> jboolean, + #[jni_added("1.2")] pub NewLocalRef: unsafe extern "system" fn(env: *mut JNIEnv, ref_: jobject) -> jobject, + #[jni_added("1.2")] pub EnsureLocalCapacity: unsafe extern "system" fn(env: *mut JNIEnv, capacity: jint) -> jint, pub AllocObject: unsafe extern "system" fn(env: *mut JNIEnv, clazz: jclass) -> jobject, pub NewObject: @@ -1192,6 +1209,7 @@ pub struct JNINativeInterface_ { pub MonitorEnter: unsafe extern "system" fn(env: *mut JNIEnv, obj: jobject) -> jint, pub MonitorExit: unsafe extern "system" fn(env: *mut JNIEnv, obj: jobject) -> jint, pub GetJavaVM: unsafe extern "system" fn(env: *mut JNIEnv, vm: *mut *mut JavaVM) -> jint, + #[jni_added("1.2")] pub GetStringRegion: unsafe extern "system" fn( env: *mut JNIEnv, str: jstring, @@ -1200,6 +1218,7 @@ pub struct JNINativeInterface_ { buf: *mut jchar, ), + #[jni_added("1.2")] pub GetStringUTFRegion: unsafe extern "system" fn( env: *mut JNIEnv, str: jstring, @@ -1208,46 +1227,52 @@ pub struct JNINativeInterface_ { buf: *mut c_char, ), + #[jni_added("1.2")] pub GetPrimitiveArrayCritical: unsafe extern "system" fn( env: *mut JNIEnv, array: jarray, isCopy: *mut jboolean, ) -> *mut c_void, + #[jni_added("1.2")] pub ReleasePrimitiveArrayCritical: unsafe extern "system" fn(env: *mut JNIEnv, array: jarray, carray: *mut c_void, mode: jint), + #[jni_added("1.2")] pub GetStringCritical: unsafe extern "system" fn( env: *mut JNIEnv, string: jstring, isCopy: *mut jboolean, ) -> *const jchar, + #[jni_added("1.2")] pub ReleaseStringCritical: unsafe extern "system" fn(env: *mut JNIEnv, string: jstring, cstring: *const jchar), + #[jni_added("1.2")] pub NewWeakGlobalRef: unsafe extern "system" fn(env: *mut JNIEnv, obj: jobject) -> jweak, + #[jni_added("1.2")] pub DeleteWeakGlobalRef: unsafe extern "system" fn(env: *mut JNIEnv, ref_: jweak), + #[jni_added("1.2")] pub ExceptionCheck: unsafe extern "system" fn(env: *mut JNIEnv) -> jboolean, + #[jni_added("1.4")] pub NewDirectByteBuffer: unsafe extern "system" fn( env: *mut JNIEnv, address: *mut c_void, capacity: jlong, ) -> jobject, + #[jni_added("1.4")] pub GetDirectBufferAddress: unsafe extern "system" fn(env: *mut JNIEnv, buf: jobject) -> *mut c_void, + #[jni_added("1.4")] pub GetDirectBufferCapacity: unsafe extern "system" fn(env: *mut JNIEnv, buf: jobject) -> jlong, + #[jni_added("1.6")] pub GetObjectRefType: unsafe extern "system" fn(env: *mut JNIEnv, obj: jobject) -> jobjectRefType, + #[jni_added("9")] pub GetModule: unsafe extern "system" fn(env: *mut JNIEnv, clazz: jclass) -> jobject, } -impl Clone for JNINativeInterface_ { - fn clone(&self) -> Self { - *self - } -} - #[repr(C)] #[derive(Copy)] pub struct JNIEnv_ { @@ -1303,10 +1328,15 @@ impl Clone for JavaVMAttachArgs { } #[repr(C)] -#[derive(Copy)] +#[jni_to_union] +#[non_exhaustive] +#[derive(Copy, Clone, Debug)] pub struct JNIInvokeInterface_ { + #[jni_added("reserved")] pub reserved0: *mut c_void, + #[jni_added("reserved")] pub reserved1: *mut c_void, + #[jni_added("reserved")] pub reserved2: *mut c_void, pub DestroyJavaVM: unsafe extern "system" fn(vm: *mut JavaVM) -> jint, pub AttachCurrentThread: unsafe extern "system" fn( @@ -1316,9 +1346,12 @@ pub struct JNIInvokeInterface_ { ) -> jint, pub DetachCurrentThread: unsafe extern "system" fn(vm: *mut JavaVM) -> jint, + + #[jni_added("1.2")] pub GetEnv: unsafe extern "system" fn(vm: *mut JavaVM, penv: *mut *mut c_void, version: jint) -> jint, + #[jni_added("1.4")] pub AttachCurrentThreadAsDaemon: unsafe extern "system" fn( vm: *mut JavaVM, penv: *mut *mut c_void, @@ -1326,12 +1359,6 @@ pub struct JNIInvokeInterface_ { ) -> jint, } -impl Clone for JNIInvokeInterface_ { - fn clone(&self) -> Self { - *self - } -} - extern "system" { pub fn JNI_GetDefaultJavaVMInitArgs(args: *mut c_void) -> jint; pub fn JNI_CreateJavaVM( diff --git a/systest/build.rs b/systest/build.rs index 2e15712..c759da2 100644 --- a/systest/build.rs +++ b/systest/build.rs @@ -35,7 +35,11 @@ fn main() { .include(include_dir.join(platform_dir)); cfg.skip_type(|s| s == "va_list"); - cfg.skip_field(|s, field| s == "jvalue" && field == "_data"); + cfg.skip_field(|s, field| { + (s == "jvalue" && field == "_data") + || s == "JNINativeInterface_" + || s == "JNIInvokeInterface_" // ctest2 isn't able to test these unions + }); cfg.type_name(|s, is_struct, _is_union| { if is_struct && s.ends_with('_') { format!("struct {}", s) @@ -74,10 +78,8 @@ fn main() { windows }); cfg.skip_roundtrip(|s| { - matches!( - s, - "jboolean" // We don't need to be able to roundtrip all possible u8 values for a jboolean, since only 0 are 1 are considered valid. - ) + s == "jboolean" || // We don't need to be able to roundtrip all possible u8 values for a jboolean, since only 0 are 1 are considered valid. + s == "JNINativeInterface_" || s == "JNIInvokeInterface_" // ctest2 isn't able to test these unions }); cfg.header("jni.h").generate("../src/lib.rs", "all.rs"); } diff --git a/tests/jni-to-union.rs b/tests/jni-to-union.rs new file mode 100644 index 0000000..64b8e90 --- /dev/null +++ b/tests/jni-to-union.rs @@ -0,0 +1,51 @@ +extern crate jni_to_union_macro; +use jni_sys::{jint, JNIEnv, JNINativeInterface_}; +use jni_to_union_macro::jni_to_union; +use std::os::raw::c_void; + +#[test] +fn jni_to_union_trybuilds() { + let t = trybuild::TestCases::new(); + t.pass("tests/trybuild/01-jni-to-union-basic-pass.rs"); + t.compile_fail("tests/trybuild/01-jni-fail-reserved-read.rs"); + t.compile_fail("tests/trybuild/01-jni-fail-read-1-2-from-1-1.rs"); +} + +#[test] +fn jni_to_union() { + #[repr(C)] + #[jni_to_union] + pub struct MyStruct { + #[jni_added("reserved")] + pub reserved0: *mut c_void, + + #[jni_added("1.1")] + pub GetVersion: unsafe extern "system" fn(env: *mut JNIEnv) -> jint, + + #[jni_added("1.2")] + pub FunctionA: unsafe extern "system" fn(env: *mut JNIEnv) -> jint, + + pub FunctionB: unsafe extern "system" fn(env: *mut JNIEnv) -> jint, + + #[jni_added("1.3")] + pub FunctionC: unsafe extern "system" fn(env: *mut JNIEnv) -> jint, + } + + assert_eq!( + std::mem::size_of::(), + std::mem::size_of::<*mut c_void>() * 5 + ); + assert_eq!( + std::mem::size_of::(), + std::mem::size_of::<*mut c_void>() * 4 + ); +} + +const NUM_JNI_ENV_MEMBERS: usize = 234; +#[test] +fn jni_env_union() { + assert_eq!( + std::mem::size_of::(), + std::mem::size_of::<*mut c_void>() * NUM_JNI_ENV_MEMBERS + ); +} diff --git a/tests/trybuild/01-jni-fail-read-1-2-from-1-1.rs b/tests/trybuild/01-jni-fail-read-1-2-from-1-1.rs new file mode 100644 index 0000000..2114cd0 --- /dev/null +++ b/tests/trybuild/01-jni-fail-read-1-2-from-1-1.rs @@ -0,0 +1,9 @@ + +use jni_sys::JNINativeInterface_; + +pub fn main() { + unsafe { + let jni = std::mem::zeroed::(); + let _1_2_function = jni.v1_1.FromReflectedMethod; + } +} diff --git a/tests/trybuild/01-jni-fail-read-1-2-from-1-1.stderr b/tests/trybuild/01-jni-fail-read-1-2-from-1-1.stderr new file mode 100644 index 0000000..8518ecf --- /dev/null +++ b/tests/trybuild/01-jni-fail-read-1-2-from-1-1.stderr @@ -0,0 +1,7 @@ +error[E0609]: no field `FromReflectedMethod` on type `JNINativeInterface__1_1` + --> tests/trybuild/01-jni-fail-read-1-2-from-1-1.rs:7:38 + | +7 | let _1_2_function = jni.v1_1.FromReflectedMethod; + | ^^^^^^^^^^^^^^^^^^^ unknown field + | + = note: available fields are: `GetVersion`, `DefineClass`, `FindClass`, `GetSuperclass`, `IsAssignableFrom` ... and 203 others diff --git a/tests/trybuild/01-jni-fail-reserved-read.rs b/tests/trybuild/01-jni-fail-reserved-read.rs new file mode 100644 index 0000000..b536fec --- /dev/null +++ b/tests/trybuild/01-jni-fail-reserved-read.rs @@ -0,0 +1,9 @@ + +use jni_sys::JNINativeInterface_; + +pub fn main() { + unsafe { + let jni = std::mem::zeroed::(); + let _reserved = jni.v1_1.reserved0; + } +} diff --git a/tests/trybuild/01-jni-fail-reserved-read.stderr b/tests/trybuild/01-jni-fail-reserved-read.stderr new file mode 100644 index 0000000..2f34421 --- /dev/null +++ b/tests/trybuild/01-jni-fail-reserved-read.stderr @@ -0,0 +1,7 @@ +error[E0609]: no field `reserved0` on type `JNINativeInterface__1_1` + --> tests/trybuild/01-jni-fail-reserved-read.rs:7:34 + | +7 | let _reserved = jni.v1_1.reserved0; + | ^^^^^^^^^ unknown field + | + = note: available fields are: `GetVersion`, `DefineClass`, `FindClass`, `GetSuperclass`, `IsAssignableFrom` ... and 203 others diff --git a/tests/trybuild/01-jni-to-union-basic-pass.rs b/tests/trybuild/01-jni-to-union-basic-pass.rs new file mode 100644 index 0000000..a808ca0 --- /dev/null +++ b/tests/trybuild/01-jni-to-union-basic-pass.rs @@ -0,0 +1,21 @@ +use jni_sys::{jint, JNIEnv}; +use jni_to_union_macro::jni_to_union; +use std::os::raw::c_void; + +#[jni_to_union] +pub struct MyStruct { + #[jni_added("reserved")] + pub reserved0: *mut c_void, + + #[jni_added("1.1")] + pub GetVersion: unsafe extern "system" fn(env: *mut JNIEnv) -> jint, + + pub FunctionA: unsafe extern "system" fn(env: *mut JNIEnv) -> jint, + + #[jni_added("1.2")] + pub FunctionB: unsafe extern "system" fn(env: *mut JNIEnv) -> jint, +} + +pub fn main() { + assert_eq!(std::mem::size_of::(), std::mem::size_of::<*mut c_void>() * 4); +}