test_ffi/
├── Cargo.toml
├── dll_api_wrapper_derive
│   ├── Cargo.toml
│   └── src
│       └── lib.rs
└── src
    ├── lib.rs
    └── main.rs
[package]
name = "test_ffi"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
dll_api_wrapper_derive = { path = "dll_api_wrapper_derive", version = "0.1.0" }
libc = "0.2.149"
pub use dll_api_wrapper_derive::DllApiWrapper as DllApiWrapperDerive;
use libc::{dlclose, dlerror, dlopen, RTLD_LAZY, RTLD_LOCAL};
use std::ffi::{c_int, c_void, CStr};
use std::ops::{Deref, DerefMut};
/// Flags passed to `dlopen` by [`DllContainer::load`]: lazy symbol binding,
/// and the library's symbols are not made available to later-loaded libraries.
const DEFAULT_FLAGS: c_int = RTLD_LAZY | RTLD_LOCAL;
/// A table of symbols that can be resolved from an already opened shared
/// library. It can be implemented by hand or generated with the
/// `DllApiWrapperDerive` derive macro re-exported by this crate.
pub trait DllApiWrapper: Sized {
    /// Resolves every symbol of `Self` from `handle`.
    ///
    /// # Errors
    /// Returns a human-readable message when the table cannot be built
    /// (for example when a symbol is missing).
    ///
    /// # Safety
    /// `handle` must be a live handle obtained from `dlopen`. The caller also
    /// vouches that each resolved symbol really has the type declared for it
    /// in `Self`; this cannot be checked at runtime.
    unsafe fn load(handle: *mut c_void) -> Result<Self, String>;
}
/// An open shared-library handle together with the API table resolved from it.
///
/// The handle is passed to `dlclose` when the container is dropped, and the
/// API table is reachable through `Deref`/`DerefMut`.
pub struct DllContainer<T: DllApiWrapper> {
    /// Handle returned by `dlopen`.
    handle: *mut c_void,
    /// Symbols resolved from `handle` by `T::load`.
    api: T,
}
impl<T> DllContainer<T>
where
T: DllApiWrapper,
{
pub unsafe fn load(name: &[u8]) -> Result<DllContainer<T>, String> {
let mut v: Vec<u8> = Vec::new();
let cstr = if name.len() > 0 && name[name.len() - 1] == 0 {
CStr::from_bytes_with_nul_unchecked(name)
} else {
v.extend_from_slice(name);
v.push(0);
CStr::from_bytes_with_nul_unchecked(v.as_slice())
};
let handle = dlopen(cstr.as_ptr(), DEFAULT_FLAGS);
if handle.is_null() {
return Err(CStr::from_ptr(dlerror()).to_string_lossy().to_string());
}
let api = T::load(handle)?;
Ok(Self { handle, api })
}
}
impl<T> Deref for DllContainer<T>
where
T: DllApiWrapper,
{
type Target = T;
fn deref(&self) -> &T {
&self.api
}
}
impl<T> DerefMut for DllContainer<T>
where
T: DllApiWrapper,
{
fn deref_mut(&mut self) -> &mut T {
&mut self.api
}
}
impl<T> Drop for DllContainer<T>
where
    T: DllApiWrapper,
{
    /// Closes the library handle with `dlclose`.
    ///
    /// # Panics
    /// Panics if `dlclose` reports failure, unless the thread is already
    /// unwinding from another panic. A second panic during unwinding would
    /// abort the whole process.
    fn drop(&mut self) {
        if self.handle.is_null() {
            return;
        }
        // SAFETY: `handle` came from a successful `dlopen` in `load` (the only
        // constructor) and is closed exactly once, here.
        let result = unsafe { dlclose(self.handle) };
        self.handle = std::ptr::null_mut();
        if result != 0 && !std::thread::panicking() {
            panic!("call to `dlclose()` failed");
        }
    }
}
use libc::{c_char, c_int};
use std::ffi::CString;
use test_ffi::{DllApiWrapper, DllApiWrapperDerive, DllContainer};
pub struct LibcDLL {
puts: extern "C" fn(input: *const c_char) -> c_int,
atoi: extern "C" fn(input: *const c_char) -> c_int,
strlen: extern "C" fn(input: *const c_char) -> c_int,
}
/// Loads libc at runtime and calls `puts`, `atoi` and `strlen` through the
/// generated wrapper methods.
fn main() {
    // SAFETY: the signatures declared in `LibcDLL` are assumed to match the C
    // prototypes of `puts`, `atoi` and `strlen` exported by libc.so.6.
    let libc: DllContainer<LibcDLL> =
        unsafe { DllContainer::load(b"libc.so.6") }.expect("failed to load libc.so.6");

    // Call `puts`.
    let greeting = CString::new("Hello, World!").expect("literal has no interior NUL");
    libc.puts(greeting.as_ptr());

    let number = CString::new("123").expect("literal has no interior NUL");
    let i = libc.atoi(number.as_ptr());
    println!("i = {}", i);

    let text = CString::new("123456").expect("literal has no interior NUL");
    let l = libc.strlen(text.as_ptr());
    println!("l = {}", l);
}
[package]
name = "dll_api_wrapper_derive"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
libc = "0.2.149"
quote = "1.0.33"
syn = "2.0.38"
proc-macro2 = "1.0.68"
[lib]
proc-macro = true
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{
parse_macro_input, BareFnArg, Data, DeriveInput, Expr, Field, Fields, FieldsNamed, Lit, Meta,
Type, Visibility,
};
/// Name of the derived trait, interpolated into the derive's error messages.
const TRAIT_NAME: &str = "DllApiWrapper";
pub fn dll_api_wrapper(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);
impl_dll_api_wrapper(&ast)
}
/// Builds the `DllApiWrapper` impl for `ast`. It also builds an inherent impl
/// holding the accessor/forwarding methods produced by `field_to_wrapper`.
///
/// # Panics
/// Panics, which surfaces as a compile error at the derive site, in these
/// cases:
/// - the input is not a struct with named fields;
/// - any field is not private;
/// - any field fails the checks in `field_to_tokens`.
fn impl_dll_api_wrapper(ast: &DeriveInput) -> TokenStream {
    let struct_name = &ast.ident;
    // `split_for_impl` yields three pieces: the parameter list with bounds for
    // `impl<...>`, the bare argument list for the type, and the `where` clause.
    // Interpolating `ast.generics` directly breaks on bounds, defaults and
    // `where` clauses.
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
    let fields = get_fields(ast, TRAIT_NAME);

    // Every field must be named and private; otherwise abort the derive.
    for field in fields.named.iter() {
        let _ = field
            .ident
            .as_ref()
            .expect("all fields of struct need to be identifiable");
        if !matches!(field.vis, Visibility::Inherited) {
            panic!("all fields of struct need to be private");
        }
    }

    let field_iter = fields.named.iter().map(field_to_tokens);
    let wrapper_iter = fields.named.iter().filter_map(field_to_wrapper);

    // All std items in the generated code are fully qualified so user-side
    // shadowing of `Ok`, `String`, `panic!`, ... cannot break the expansion.
    let expanded = quote! {
        impl #impl_generics DllApiWrapper for #struct_name #ty_generics #where_clause {
            unsafe fn load(
                handle: *mut ::std::ffi::c_void,
            ) -> ::std::result::Result<Self, ::std::string::String> {
                unsafe fn symbol_cstr<T>(
                    handle: *mut ::std::ffi::c_void,
                    name: &::std::ffi::CStr,
                    allow_null: bool,
                    mutability: bool,
                ) -> ::std::result::Result<T, ::std::string::String> {
                    if ::std::mem::size_of::<T>() != ::std::mem::size_of::<*mut ()>() {
                        ::std::panic!("the type has a different size than a pointer - cannot transmute");
                    }
                    // Clear any stale error so that a null result can be told
                    // apart from a symbol whose address really is null.
                    let _ = ::libc::dlerror();
                    let symbol = ::libc::dlsym(handle, name.as_ptr());
                    if symbol.is_null() {
                        let msg = ::libc::dlerror();
                        if !msg.is_null() {
                            let msg = ::std::ffi::CStr::from_ptr(msg).to_string_lossy().to_string();
                            return ::std::result::Result::Err(msg);
                        }
                        if allow_null {
                            if mutability {
                                return ::std::result::Result::Ok(
                                    ::std::mem::transmute_copy::<*mut ::std::ffi::c_void, T>(&::std::ptr::null_mut())
                                );
                            } else {
                                return ::std::result::Result::Ok(
                                    ::std::mem::transmute_copy::<*const ::std::ffi::c_void, T>(&::std::ptr::null())
                                );
                            }
                        }
                        let name = ::std::string::String::from_utf8_lossy(name.to_bytes());
                        return ::std::result::Result::Err(::std::format!("symbol `{}` not found", name));
                    }
                    ::std::result::Result::Ok(
                        ::std::mem::transmute_copy::<*mut ::std::ffi::c_void, T>(&symbol)
                    )
                }
                ::std::result::Result::Ok(Self {
                    #(#field_iter),*
                })
            }
        }
        impl #impl_generics #struct_name #ty_generics #where_clause {
            #(#wrapper_iter)*
        }
    };
    expanded.into()
}
/// Generates the `field: symbol_cstr(...)?` initialiser for one field. The
/// initialiser is used inside the generated `load`.
///
/// # Panics
/// - The field type is not a bare function, reference or raw pointer.
/// - `#[dlopen_allow_null]` is placed on a non-pointer field.
/// - The symbol name contains a NUL byte.
fn field_to_tokens(field: &Field) -> TokenStream2 {
    let allow_null = has_marker_attr(field, "dlopen_allow_null");
    // Only `*mut` pointers are flagged as mutable; functions and references
    // never are.
    let mutability = match &field.ty {
        Type::BareFn(_) | Type::Reference(_) => {
            if allow_null {
                panic!("only pointers can have `dlopen_allow_null` attribute assigned");
            }
            false
        }
        Type::Ptr(ptr) => ptr.mutability.is_some(),
        _ => panic!("only bare functions, references and pointers are allowed"),
    };
    let field_name = &field.ident;
    let symbol_name = symbol_name(field);
    // The generated code builds the C string with an unchecked conversion,
    // which is only sound if the name carries no NUL byte of its own.
    if symbol_name.contains('\0') {
        panic!(
            "symbol name `{}` must not contain a NUL byte",
            symbol_name.escape_default()
        );
    }
    quote! {
        #field_name: symbol_cstr(
            handle,
            ::std::ffi::CStr::from_bytes_with_nul_unchecked(::std::concat!(#symbol_name, "\0").as_bytes()),
            #allow_null,
            #mutability,
        )?
    }
}
fn field_to_wrapper(field: &Field) -> Option<TokenStream2> {
let ident = &field.ident;
match &field.ty {
&Type::BareFn(ref fun) => {
if fun.variadic.is_some() {
None
} else {
let output = &fun.output;
let unsafety = &fun.unsafety;
let arg_iter = fun
.inputs
.iter()
.map(|a| fun_arg_to_tokens(a, &ident.as_ref().unwrap().to_string()));
let arg_names = fun.inputs.iter().map(|a| match a.name {
Some((ref arg_name, _)) => arg_name,
None => panic!("this should never happen"),
});
Some(quote! {
pub #unsafety fn #ident (&self, #(#arg_iter),* ) #output {
(self.#ident)(#(#arg_names),*)
}
})
}
}
&Type::Reference(ref ref_ty) => {
let ty = &ref_ty.elem;
let mut_acc = match ref_ty.mutability {
Some(_token) => {
let mut_ident = &format!("{}_mut", ident.as_ref().unwrap().to_string());
let method_name = syn::Ident::new(mut_ident, ident.as_ref().unwrap().span());
Some(quote! {
pub fn #method_name (&mut self) -> &mut #ty {
self.#ident
}
})
}
None => None,
};
let const_acc = quote! {
pub fn #ident (&self) -> & #ty {
self.#ident
}
};
Some(quote! {
#const_acc
#mut_acc
})
}
&Type::Ptr(_) => None,
_ => panic!("unknown field type, this should not happen"),
}
}
/// Returns the named fields of the struct being derived.
///
/// # Panics
/// Panics, naming `trait_name` in the message, if `ast` is an enum, a union,
/// or a struct without named fields.
fn get_fields<'a>(ast: &'a DeriveInput, trait_name: &str) -> &'a FieldsNamed {
    match &ast.data {
        Data::Struct(data) => match &data.fields {
            Fields::Named(named) => named,
            Fields::Unnamed(_) | Fields::Unit => {
                panic!(
                    "{} can only be derived by structures with named fields",
                    trait_name
                )
            }
        },
        Data::Enum(_) | Data::Union(_) => {
            panic!("{} can only be derived by structures", trait_name)
        }
    }
}
/// Reports whether `field` carries a bare path attribute `#[<attr_name>]`.
/// Only single-identifier paths with no arguments and no value count.
fn has_marker_attr(field: &Field, attr_name: &str) -> bool {
    field.attrs.iter().any(|attr| match &attr.meta {
        Meta::Path(path) => path.is_ident(attr_name),
        _ => false,
    })
}
fn symbol_name(field: &Field) -> String {
match find_str_attr_val(field, "dlopen_name") {
Some(val) => val,
None => match field.ident {
Some(ref val) => val.to_string(),
None => panic!("all struct fields need to be identifiable"),
},
}
}
/// Returns the string value of the first `#[<attr_name> = "..."]` attribute
/// on `field`, or `None` if there is no such attribute.
///
/// # Panics
/// Panics if the first attribute with that name has a non-string value.
fn find_str_attr_val(field: &Field, attr_name: &str) -> Option<String> {
    field.attrs.iter().find_map(|attr| match &attr.meta {
        Meta::NameValue(name_value) if name_value.path.is_ident(attr_name) => {
            match &name_value.value {
                Expr::Lit(expr_lit) => match &expr_lit.lit {
                    Lit::Str(lit_str) => Some(lit_str.value()),
                    _ => panic!("{} attribute must be a string", attr_name),
                },
                _ => panic!("{} attribute must be a string", attr_name),
            }
        }
        _ => None,
    })
}
fn fun_arg_to_tokens(arg: &BareFnArg, function_name: &str) -> TokenStream2 {
let arg_name = match arg.name {
Some(ref val) => &val.0,
None => panic!("function `{}` has an unnamed argument", function_name),
};
let ty = &arg.ty;
quote! {
#arg_name: #ty
}
}