295 lines
11 KiB
Rust
295 lines
11 KiB
Rust
use convert_case::{Case, Casing};
|
|
use proc_macro2::TokenStream;
|
|
use quote::{ToTokens, quote, quote_spanned};
|
|
use syn::spanned::Spanned;
|
|
|
|
use crate::hashlist::HashList;
|
|
|
|
/// Name of the helper attribute (`#[extract(...)]`) that marks an extraction
/// target; each variant attribute's path identifier is compared against it.
const TARGETS_PATH: &str = "extract";
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
|
enum VariantFieldIndex {
|
|
Ident(syn::Ident),
|
|
Numeric(syn::LitInt),
|
|
None,
|
|
}
|
|
|
|
impl syn::parse::Parse for VariantFieldIndex {
    /// Parses the optional leading field selector of an `#[extract(...)]` list.
    ///
    /// Empty input yields `None`, an integer literal yields `Numeric`, and
    /// anything else must be an identifier (`Ident`).
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        if input.is_empty() {
            return Ok(Self::None);
        }

        if input.peek(syn::LitInt) {
            input.parse().map(Self::Numeric)
        } else {
            input.parse().map(Self::Ident)
        }
    }
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
|
struct VariantAttr {
|
|
index: VariantFieldIndex,
|
|
alias: Option<syn::Ident>,
|
|
}
|
|
|
|
impl syn::parse::Parse for VariantAttr {
|
|
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
|
let index = if input.peek(syn::token::As) {
|
|
VariantFieldIndex::None
|
|
} else {
|
|
input.parse::<VariantFieldIndex>()?
|
|
};
|
|
if !input.peek(syn::token::As) {
|
|
return Ok(Self { index, alias: None });
|
|
}
|
|
input.parse::<syn::Token![as]>()?;
|
|
|
|
Ok(Self {
|
|
index,
|
|
alias: Some(input.parse()?),
|
|
})
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
struct TargetVariant {
|
|
attr: VariantAttr,
|
|
access: TokenStream,
|
|
ty: syn::Type,
|
|
variant: syn::Variant,
|
|
}
|
|
|
|
impl TargetVariant {
|
|
fn name(&self) -> syn::Ident {
|
|
if let Some(alias) = self.attr.alias.as_ref() {
|
|
return alias.clone();
|
|
}
|
|
match &self.attr.index {
|
|
VariantFieldIndex::Ident(ident) => ident.clone(),
|
|
VariantFieldIndex::Numeric(_) | VariantFieldIndex::None => self.variant.ident.clone(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct Targets {
|
|
name: syn::Ident,
|
|
targets: Vec<TargetVariant>,
|
|
total_fields: usize,
|
|
}
|
|
|
|
impl Targets {
|
|
pub fn parse(input: syn::DeriveInput) -> Result<Self, syn::Error> {
|
|
let name = input.ident;
|
|
let mut targets: Vec<TargetVariant> = Vec::new();
|
|
let total_fields;
|
|
match &input.data {
|
|
syn::Data::Enum(data) => {
|
|
total_fields = data.variants.len();
|
|
for field in data.variants.iter() {
|
|
let field_ident = &field.ident;
|
|
for attr in field.attrs.iter().filter(|attr| {
|
|
attr.path()
|
|
.get_ident()
|
|
.map(|id| id.to_string().as_str() == TARGETS_PATH)
|
|
.unwrap_or_default()
|
|
}) {
|
|
let attr_span = attr.span();
|
|
let attr = attr.parse_args::<VariantAttr>()?;
|
|
let (access, ty) = match &field.fields {
|
|
syn::Fields::Named(fields) => match &attr.index {
|
|
VariantFieldIndex::Ident(ident) => {
|
|
let mut matching_val = None;
|
|
let mut accesses = Vec::new();
|
|
for field in fields.named.iter() {
|
|
let matching = field
|
|
.ident
|
|
.as_ref()
|
|
.map(|i| i == ident)
|
|
.unwrap_or_default();
|
|
if matching_val.is_some() && matching {
|
|
panic!("duplicate?")
|
|
}
|
|
|
|
let ident = field.ident.clone().unwrap();
|
|
if matching {
|
|
matching_val = Some((ident.clone(), field.ty.clone()));
|
|
accesses.push(quote! {
|
|
#ident
|
|
});
|
|
} else {
|
|
accesses.push(quote! {
|
|
#ident: _
|
|
});
|
|
}
|
|
}
|
|
let (mut matching_ident, matching_ty) = matching_val.ok_or(
|
|
syn::Error::new(ident.span(), "no such variant field"),
|
|
)?;
|
|
matching_ident.set_span(ident.span());
|
|
|
|
(
|
|
quote! {
|
|
#name::#field_ident { #(#accesses),* } => &#matching_ident,
|
|
},
|
|
matching_ty,
|
|
)
|
|
}
|
|
|
|
VariantFieldIndex::Numeric(num) => {
|
|
return Err(syn::Error::new(
|
|
num.span(),
|
|
"cannot used numeric index for a variant with named fields",
|
|
));
|
|
}
|
|
VariantFieldIndex::None => {
|
|
if fields.named.len() == 1 {
|
|
let field = fields.named.iter().next().unwrap();
|
|
let field_ident = field.ident.as_ref().unwrap();
|
|
(
|
|
quote! {
|
|
#name::#field_ident { #field_ident } => &#field_ident,
|
|
},
|
|
field.ty.clone(),
|
|
)
|
|
} else {
|
|
return Err(syn::Error::new(
|
|
fields.named.span(),
|
|
"unnamed field index with more than one field",
|
|
));
|
|
}
|
|
}
|
|
},
|
|
syn::Fields::Unnamed(fields) => {
|
|
if let VariantFieldIndex::Numeric(num) = &attr.index {
|
|
let num = num.base10_parse::<u8>()? as usize;
|
|
let field = fields
|
|
.unnamed
|
|
.iter()
|
|
.nth(num)
|
|
.ok_or(syn::Error::new(
|
|
attr_span,
|
|
"field index out of range",
|
|
))?
|
|
.clone();
|
|
let left = (0..num).map(|_| quote! {_});
|
|
let right = (num..fields.unnamed.len()).map(|_| quote! {_});
|
|
(
|
|
quote! {
|
|
#name::#field_ident(#(#left),*, val, #(#right),*) => &val,
|
|
},
|
|
field.ty.clone(),
|
|
)
|
|
} else if fields.unnamed.len() == 1 {
|
|
(
|
|
quote! {
|
|
#name::#field_ident(val) => &val,
|
|
},
|
|
fields.unnamed.iter().next().unwrap().ty.clone(),
|
|
)
|
|
} else {
|
|
return Err(syn::Error::new(
|
|
fields.span(),
|
|
"unnamed fields without numeric index",
|
|
));
|
|
}
|
|
}
|
|
syn::Fields::Unit => {
|
|
return Err(syn::Error::new(
|
|
field.span(),
|
|
"target cannot be a unit field",
|
|
));
|
|
}
|
|
};
|
|
let target = TargetVariant {
|
|
attr,
|
|
ty,
|
|
access,
|
|
variant: field.clone(),
|
|
};
|
|
// if targets.iter().any(|t| t.name() == target.name()) {
|
|
// return Err(syn::Error::new(attr_span, "duplicate target field"));
|
|
// }
|
|
// target.variant.ident.set_span(attr_span);
|
|
|
|
targets.push(target);
|
|
}
|
|
}
|
|
}
|
|
|
|
_ => todo!(),
|
|
};
|
|
|
|
Ok(Self {
|
|
name,
|
|
targets,
|
|
total_fields,
|
|
})
|
|
}
|
|
}
|
|
|
|
impl ToTokens for Targets {
|
|
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
|
|
// this is so horrid and haphazard but idc rn
|
|
// commenting to make the criticizing voice stfu
|
|
let mut by_target = {
|
|
let mut out = HashList::new();
|
|
for target in self.targets.iter() {
|
|
out.add(target.name().to_string(), target.clone());
|
|
}
|
|
out
|
|
};
|
|
let fns = by_target.decompose().into_iter().map(|(_, targets)| {
|
|
if targets.is_empty() {
|
|
return quote! {};
|
|
}
|
|
let all_variants = targets.len() == self.total_fields;
|
|
let fn_ret = {
|
|
let ty = targets.first().unwrap().ty.clone();
|
|
if all_variants {
|
|
quote! {&#ty}
|
|
} else {
|
|
quote! {Option<&#ty>}
|
|
}
|
|
};
|
|
let fn_name = {
|
|
let name = targets.first().unwrap().name();
|
|
syn::Ident::new(name.to_string().to_case(Case::Snake).as_str(), name.span())
|
|
};
|
|
let raw_accesses = targets.iter().map(|t| {
|
|
let access = &t.access;
|
|
|
|
quote_spanned! { t.attr.alias.as_ref().unwrap().span() =>
|
|
#access
|
|
}
|
|
});
|
|
let accesses = if all_variants {
|
|
quote! {
|
|
match self {
|
|
#(#raw_accesses)*
|
|
}
|
|
}
|
|
} else {
|
|
quote! {
|
|
Some(match self {
|
|
#(#raw_accesses)*
|
|
_ => return None,
|
|
})
|
|
}
|
|
};
|
|
|
|
quote! {
|
|
pub const fn #fn_name(&self) -> #fn_ret {
|
|
#accesses
|
|
}
|
|
}
|
|
});
|
|
let name = &self.name;
|
|
tokens.extend(quote! {
|
|
impl #name {
|
|
#(#fns)*
|
|
}
|
|
});
|
|
}
|
|
}
|