test_ffi/
├── Cargo.toml
├── dll_api_wrapper_derive
│   ├── Cargo.toml
│   └── src
│       └── lib.rs
└── src
    ├── lib.rs
    └── main.rs
[package]
name = "test_ffi"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]dll_api_wrapper_derive = { path = "dll_api_wrapper_derive", version = "0.1.0" }libc = "0.2.149"pub use dll_api_wrapper_derive::DllApiWrapper as DllApiWrapperDerive;use libc::{dlclose, dlerror, dlopen, RTLD_LAZY, RTLD_LOCAL};use std::ffi::{c_int, c_void, CStr};use std::ops::{Deref, DerefMut};
/// Flags passed to every `dlopen` call made by [`DllContainer::load`]:
/// lazy symbol binding, with the library's symbols kept local to this handle.
const DEFAULT_FLAGS: c_int = RTLD_LAZY | RTLD_LOCAL;
/// A table of symbols resolved from a library opened with `dlopen`.
///
/// Implementations can be generated with the `DllApiWrapper` derive from the
/// companion `dll_api_wrapper_derive` crate (re-exported here as
/// `DllApiWrapperDerive`).
pub trait DllApiWrapper: Sized {
    /// Resolves every symbol the implementor needs from `handle` and builds `Self`.
    ///
    /// # Errors
    /// Returns a human-readable message when a required symbol cannot be resolved.
    ///
    /// # Safety
    /// `handle` must be a live handle obtained from `dlopen` that stays open for as
    /// long as the returned value is used, and every resolved symbol must really
    /// have the type the implementation treats it as.
    unsafe fn load(handle: *mut c_void) -> Result<Self, String>;
}
/// Owns a `dlopen` handle together with the API table resolved from it.
///
/// The handle is released with `dlclose` when the container is dropped, and the
/// API table is reachable through `Deref`/`DerefMut`.
pub struct DllContainer<T: DllApiWrapper> {
    /// Library handle returned by `dlopen`.
    handle: *mut c_void,
    /// Symbols resolved from `handle` by `T::load`.
    api: T,
}
impl<T> DllContainer<T>where T: DllApiWrapper,{ pub unsafe fn load(name: &[u8]) -> Result<DllContainer<T>, String> { let mut v: Vec<u8> = Vec::new(); let cstr = if name.len() > 0 && name[name.len() - 1] == 0 { CStr::from_bytes_with_nul_unchecked(name) } else { v.extend_from_slice(name); v.push(0); CStr::from_bytes_with_nul_unchecked(v.as_slice()) }; let handle = dlopen(cstr.as_ptr(), DEFAULT_FLAGS); if handle.is_null() { return Err(CStr::from_ptr(dlerror()).to_string_lossy().to_string()); } let api = T::load(handle)?; Ok(Self { handle, api }) }}
/// Lets the resolved API table's methods be called directly on the container.
impl<T: DllApiWrapper> Deref for DllContainer<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let api: &T = &self.api;
        api
    }
}
/// Mutable counterpart of the `Deref` impl: hands out the resolved API table.
impl<T: DllApiWrapper> DerefMut for DllContainer<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let api: &mut T = &mut self.api;
        api
    }
}
impl<T> Drop for DllContainer<T>where T: DllApiWrapper,{ fn drop(&mut self) { let result = unsafe { dlclose(self.handle) }; if result != 0 { panic!("call to `dlclose()` failed"); } self.handle = std::ptr::null_mut(); }}use libc::{c_char, c_int};use std::ffi::CString;use test_ffi::{DllApiWrapper, DllApiWrapperDerive, DllContainer};
pub struct LibcDLL { puts: extern "C" fn(input: *const c_char) -> c_int, atoi: extern "C" fn(input: *const c_char) -> c_int, strlen: extern "C" fn(input: *const c_char) -> c_int,}
fn main() { unsafe { let a: DllContainer<LibcDLL> = DllContainer::load("libc.so.6".as_bytes()).unwrap(); // 调用 puts 函数 let str = CString::new("Hello, World!").unwrap(); a.puts(str.as_ptr()); let str = CString::new("123").unwrap(); let i = a.atoi(str.as_ptr()); println!("i = {}", i); let str = CString::new("123456").unwrap(); let l = a.strlen(str.as_ptr()); println!("l = {}", l); }}[package]name = "dll_api_wrapper_derive"version = "0.1.0"edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
libc = "0.2.149"
quote = "1.0.33"
syn = "2.0.38"
proc-macro2 = "1.0.68"
[lib]proc-macro = trueuse proc_macro::TokenStream;use proc_macro2::TokenStream as TokenStream2;use quote::quote;use syn::{ parse_macro_input, BareFnArg, Data, DeriveInput, Expr, Field, Fields, FieldsNamed, Lit, Meta, Type, Visibility,};
const TRAIT_NAME: &str = "DllApiWrapper";
pub fn dll_api_wrapper(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as DeriveInput); impl_dll_api_wrapper(&ast)}
fn impl_dll_api_wrapper(ast: &DeriveInput) -> TokenStream { let struct_name = &ast.ident; let generics = &ast.generics; let fields = get_fields(ast, TRAIT_NAME); // 确保所有字段都是可识别的,否则 panic; // 确保所有字段都是私有的,否则 panic; for field in fields.named.iter() { let _ = field .ident .as_ref() .expect("all fields of struct need to be identifiable"); match field.vis { Visibility::Inherited => {} _ => panic!("all fields of struct need to be private"), } }
let field_iter = fields.named.iter().map(field_to_tokens); let wrapper_iter = fields.named.iter().filter_map(field_to_wrapper);
let gen = quote! { impl #generics DllApiWrapper for #struct_name #generics { unsafe fn load(handle: *mut ::std::ffi::c_void) -> ::std::result::Result<Self, ::std::string::String> { unsafe fn symbol_cstr<T>( handle: *mut ::std::ffi::c_void, name: & ::std::ffi::CStr, allow_null: bool, mutability: bool) -> ::std::result::Result<T, String> { if ::std::mem::size_of::<T>() != ::std::mem::size_of::<*mut ()>() { panic!("the type has a different size than a pointer - cannot transmute"); } let _ = ::libc::dlerror(); let symbol = ::libc::dlsym(handle, name.as_ptr()); if symbol.is_null() { let msg = ::libc::dlerror(); if !msg.is_null() { let msg = ::std::ffi::CStr::from_ptr(msg).to_string_lossy().to_string(); return ::std::result::Result::Err(msg); } if allow_null { if mutability { return ::std::result::Result::Ok( ::std::mem::transmute_copy::<*mut ::std::ffi::c_void, T>(&::std::ptr::null_mut()) ); } else { return ::std::result::Result::Ok( ::std::mem::transmute_copy::<*const ::std::ffi::c_void, T>(&::std::ptr::null()) ); } } let name = ::std::string::String::from_utf8_lossy(name.to_bytes()); return ::std::result::Result::Err(::std::format!("symbol `{}` not found", name)); } ::std::result::Result::Ok( ::std::mem::transmute_copy::<*mut ::std::ffi::c_void, T>(&symbol) ) } Ok(Self{ #(#field_iter),* }) } }
impl #generics #struct_name #generics { #(#wrapper_iter)* } }; gen.into()}
/// Generates the `field: symbol_cstr(...)?` initializer that resolves one field
/// inside the generated `load`.
///
/// # Panics
/// Panics when:
/// - the field is not a bare function, reference or raw pointer;
/// - `#[dlopen_allow_null]` is placed on a non-pointer field;
/// - the resolved symbol name contains a NUL byte.
fn field_to_tokens(field: &Field) -> TokenStream2 {
    let allow_null = has_marker_attr(field, "dlopen_allow_null");
    let mutability = match field.ty {
        Type::BareFn(_) | Type::Reference(_) => {
            if allow_null {
                panic!("only pointers can have `dlopen_allow_null` attribute assigned");
            }
            false
        }
        Type::Ptr(ref ptr) => ptr.mutability.is_some(),
        _ => panic!("only bare functions, references and pointers are allowed"),
    };

    let field_name = &field.ident;
    let symbol_name = symbol_name(field);
    // The generated code builds the C string with `from_bytes_with_nul_unchecked`.
    // An embedded NUL (possible via `#[dlopen_name = "..."]`) would make that
    // undefined behaviour, so reject it here at expansion time.
    if symbol_name.contains('\0') {
        panic!(
            "symbol name `{}` must not contain a NUL byte",
            symbol_name.escape_default()
        );
    }
    quote! {
        #field_name: symbol_cstr(
            handle,
            ::std::ffi::CStr::from_bytes_with_nul_unchecked(::std::concat!(#symbol_name, "\0").as_bytes()),
            #allow_null,
            #mutability,
        )?
    }
}
/// Builds the inherent method(s) that expose one field, or `None` when the
/// field gets no accessor.
///
/// * Non-variadic bare functions get `pub [unsafe] fn name(&self, args..) -> R`,
///   which calls through the stored function pointer.
/// * References get `name(&self) -> &T`, plus `name_mut(&mut self) -> &mut T`
///   when the reference is `&mut`.
/// * Raw pointers and variadic functions get nothing.
///
/// # Panics
/// Panics on any other field type and on unnamed function arguments.
fn field_to_wrapper(field: &Field) -> Option<TokenStream2> {
    let ident = &field.ident;
    match &field.ty {
        Type::BareFn(fun) if fun.variadic.is_some() => None,
        Type::BareFn(fun) => {
            let fn_name = ident.as_ref().unwrap().to_string();
            let params = fun
                .inputs
                .iter()
                .map(|arg| fun_arg_to_tokens(arg, &fn_name));
            let call_args = fun.inputs.iter().map(|arg| match &arg.name {
                Some((arg_name, _)) => arg_name,
                None => panic!("this should never happen"),
            });
            let unsafety = &fun.unsafety;
            let output = &fun.output;
            Some(quote! {
                pub #unsafety fn #ident (&self, #(#params),* ) #output {
                    (self.#ident)(#(#call_args),*)
                }
            })
        }
        Type::Reference(reference) => {
            let target = &reference.elem;
            let mut_getter = reference.mutability.as_ref().map(|_| {
                let base = ident.as_ref().unwrap();
                let mut_name = syn::Ident::new(&format!("{}_mut", base), base.span());
                quote! {
                    pub fn #mut_name (&mut self) -> &mut #target {
                        self.#ident
                    }
                }
            });
            let getter = quote! {
                pub fn #ident (&self) -> & #target {
                    self.#ident
                }
            };
            Some(quote! { #getter #mut_getter })
        }
        Type::Ptr(_) => None,
        _ => panic!("unknown field type, this should not happen"),
    }
}
/// Returns the named fields of the derive input.
///
/// # Panics
/// Panics with a message naming `trait_name` for enums, unions, tuple structs
/// and unit structs.
fn get_fields<'a>(ast: &'a DeriveInput, trait_name: &str) -> &'a FieldsNamed {
    match &ast.data {
        Data::Struct(data_struct) => match &data_struct.fields {
            Fields::Named(named) => named,
            Fields::Unnamed(_) | Fields::Unit => {
                panic!("{} can only be derived by structures", trait_name)
            }
        },
        Data::Enum(_) | Data::Union(_) => {
            panic!("{} can only be derived by structures", trait_name)
        }
    }
}
/// Returns `true` when `field` carries the bare marker attribute `#[attr_name]`.
fn has_marker_attr(field: &Field, attr_name: &str) -> bool {
    field.attrs.iter().any(|attr| match &attr.meta {
        Meta::Path(path) => path.is_ident(attr_name),
        _ => false,
    })
}
fn symbol_name(field: &Field) -> String { match find_str_attr_val(field, "dlopen_name") { Some(val) => val, None => match field.ident { Some(ref val) => val.to_string(), None => panic!("all struct fields need to be identifiable"), }, }}
/// Returns the string value of the first `#[attr_name = "..."]` attribute on
/// `field`, if any.
///
/// # Panics
/// Panics when a matching attribute's value is not a string literal.
fn find_str_attr_val(field: &Field, attr_name: &str) -> Option<String> {
    field.attrs.iter().find_map(|attr| match &attr.meta {
        Meta::NameValue(name_value) if name_value.path.is_ident(attr_name) => {
            if let Expr::Lit(expr_lit) = &name_value.value {
                if let Lit::Str(lit_str) = &expr_lit.lit {
                    return Some(lit_str.value());
                }
            }
            panic!("{} attribute must be a string", attr_name)
        }
        _ => None,
    })
}
fn fun_arg_to_tokens(arg: &BareFnArg, function_name: &str) -> TokenStream2 { let arg_name = match arg.name { Some(ref val) => &val.0, None => panic!("function `{}` has an unnamed argument", function_name), }; let ty = &arg.ty; quote! { #arg_name: #ty }}