// werewolves/werewolves-macros/src/targets.rs
//
// 295 lines
// 11 KiB
// Rust

use convert_case::{Case, Casing};
use proc_macro2::TokenStream;
use quote::{ToTokens, quote, quote_spanned};
use syn::spanned::Spanned;
use crate::hashlist::HashList;
/// Path of the variant attribute this derive looks for, i.e. `#[extract(...)]`.
/// Only single-segment paths equal to this string are recognised.
const TARGETS_PATH: &str = "extract";
/// Field selector given at the start of an `extract(...)` variant attribute.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum VariantFieldIndex {
/// A named field selector, e.g. `extract(field)`.
Ident(syn::Ident),
/// A positional (tuple-field) selector, e.g. `extract(0)`.
Numeric(syn::LitInt),
/// No selector given; only accepted when the variant has exactly one field.
None,
}
impl syn::parse::Parse for VariantFieldIndex {
    /// Parses the optional field selector of an `extract(...)` attribute.
    ///
    /// An exhausted stream yields `None`, an integer literal yields `Numeric`,
    /// and anything else must parse as an identifier (`Ident`).
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        if input.is_empty() {
            return Ok(Self::None);
        }
        let selector = if input.peek(syn::LitInt) {
            Self::Numeric(input.parse()?)
        } else {
            Self::Ident(input.parse()?)
        };
        Ok(selector)
    }
}
/// Parsed arguments of one `extract(...)` attribute: `[selector] [as alias]`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct VariantAttr {
/// Which field of the variant the target refers to.
index: VariantFieldIndex,
/// Optional `as <alias>` name; when present it overrides the derived target name.
alias: Option<syn::Ident>,
}
impl syn::parse::Parse for VariantAttr {
    /// Parses `extract(...)` arguments: an optional field selector followed by an
    /// optional `as <alias>` clause.
    ///
    /// A leading `as` means no selector was written, so the index is
    /// `VariantFieldIndex::None`.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let starts_with_as = input.peek(syn::token::As);
        let index = if starts_with_as {
            VariantFieldIndex::None
        } else {
            input.parse::<VariantFieldIndex>()?
        };

        let alias = if input.peek(syn::token::As) {
            input.parse::<syn::Token![as]>()?;
            Some(input.parse::<syn::Ident>()?)
        } else {
            None
        };

        Ok(Self { index, alias })
    }
}
/// One resolved `extract` attribute on one enum variant.
#[derive(Debug, Clone)]
struct TargetVariant {
/// The parsed attribute arguments.
attr: VariantAttr,
/// A complete match arm (`Enum::Variant(..) => &field,`) that borrows the selected field.
access: TokenStream,
/// Declared type of the selected field.
ty: syn::Type,
/// The variant the attribute was found on.
variant: syn::Variant,
}
impl TargetVariant {
    /// Name under which this target is grouped and its accessor is generated.
    ///
    /// Precedence: the explicit `as` alias, then a named-field selector, and
    /// finally the variant's own identifier (for numeric or absent selectors).
    fn name(&self) -> syn::Ident {
        self.attr.alias.clone().unwrap_or_else(|| match &self.attr.index {
            VariantFieldIndex::Ident(field) => field.clone(),
            VariantFieldIndex::Numeric(_) | VariantFieldIndex::None => {
                self.variant.ident.clone()
            }
        })
    }
}
/// All `extract` targets found on a derive-input enum; rendered into accessor
/// methods through its `ToTokens` impl.
#[derive(Debug)]
pub struct Targets {
/// Identifier of the enum being derived on.
name: syn::Ident,
/// Every resolved target, in variant order, then attribute order within a variant.
targets: Vec<TargetVariant>,
/// Number of *variants* in the enum (despite the name). Used to decide whether an
/// accessor can return `&T` rather than `Option<&T>`.
total_fields: usize,
}
impl Targets {
pub fn parse(input: syn::DeriveInput) -> Result<Self, syn::Error> {
let name = input.ident;
let mut targets: Vec<TargetVariant> = Vec::new();
let total_fields;
match &input.data {
syn::Data::Enum(data) => {
total_fields = data.variants.len();
for field in data.variants.iter() {
let field_ident = &field.ident;
for attr in field.attrs.iter().filter(|attr| {
attr.path()
.get_ident()
.map(|id| id.to_string().as_str() == TARGETS_PATH)
.unwrap_or_default()
}) {
let attr_span = attr.span();
let attr = attr.parse_args::<VariantAttr>()?;
let (access, ty) = match &field.fields {
syn::Fields::Named(fields) => match &attr.index {
VariantFieldIndex::Ident(ident) => {
let mut matching_val = None;
let mut accesses = Vec::new();
for field in fields.named.iter() {
let matching = field
.ident
.as_ref()
.map(|i| i == ident)
.unwrap_or_default();
if matching_val.is_some() && matching {
panic!("duplicate?")
}
let ident = field.ident.clone().unwrap();
if matching {
matching_val = Some((ident.clone(), field.ty.clone()));
accesses.push(quote! {
#ident
});
} else {
accesses.push(quote! {
#ident: _
});
}
}
let (mut matching_ident, matching_ty) = matching_val.ok_or(
syn::Error::new(ident.span(), "no such variant field"),
)?;
matching_ident.set_span(ident.span());
(
quote! {
#name::#field_ident { #(#accesses),* } => &#matching_ident,
},
matching_ty,
)
}
VariantFieldIndex::Numeric(num) => {
return Err(syn::Error::new(
num.span(),
"cannot used numeric index for a variant with named fields",
));
}
VariantFieldIndex::None => {
if fields.named.len() == 1 {
let field = fields.named.iter().next().unwrap();
let field_ident = field.ident.as_ref().unwrap();
(
quote! {
#name::#field_ident { #field_ident } => &#field_ident,
},
field.ty.clone(),
)
} else {
return Err(syn::Error::new(
fields.named.span(),
"unnamed field index with more than one field",
));
}
}
},
syn::Fields::Unnamed(fields) => {
if let VariantFieldIndex::Numeric(num) = &attr.index {
let num = num.base10_parse::<u8>()? as usize;
let field = fields
.unnamed
.iter()
.nth(num)
.ok_or(syn::Error::new(
attr_span,
"field index out of range",
))?
.clone();
let left = (0..num).map(|_| quote! {_});
let right = (num..fields.unnamed.len()).map(|_| quote! {_});
(
quote! {
#name::#field_ident(#(#left),*, val, #(#right),*) => &val,
},
field.ty.clone(),
)
} else if fields.unnamed.len() == 1 {
(
quote! {
#name::#field_ident(val) => &val,
},
fields.unnamed.iter().next().unwrap().ty.clone(),
)
} else {
return Err(syn::Error::new(
fields.span(),
"unnamed fields without numeric index",
));
}
}
syn::Fields::Unit => {
return Err(syn::Error::new(
field.span(),
"target cannot be a unit field",
));
}
};
let target = TargetVariant {
attr,
ty,
access,
variant: field.clone(),
};
// if targets.iter().any(|t| t.name() == target.name()) {
// return Err(syn::Error::new(attr_span, "duplicate target field"));
// }
// target.variant.ident.set_span(attr_span);
targets.push(target);
}
}
}
_ => todo!(),
};
Ok(Self {
name,
targets,
total_fields,
})
}
}
impl ToTokens for Targets {
    /// Emits `impl <Enum> { ... }` with one `pub const fn` accessor per distinct target name.
    ///
    /// Targets are grouped by `TargetVariant::name`, and each accessor is named after its
    /// group converted to snake_case.
    ///
    /// If the group covers every variant of the enum, the accessor returns `&T`. Otherwise
    /// it returns `Option<&T>` and yields `None` for uncovered variants. `T` is the field
    /// type of the first target in the group, so all targets sharing a name are expected to
    /// have the same field type.
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        let mut by_target = HashList::new();
        for target in self.targets.iter() {
            by_target.add(target.name().to_string(), target.clone());
        }

        let fns = by_target.decompose().into_iter().map(|(_, targets)| {
            let first = match targets.first() {
                Some(first) => first,
                None => return quote! {},
            };

            // Count distinct variants rather than targets. One variant may carry several
            // attributes resolving to the same name; counting them twice could wrongly
            // report full coverage and emit a non-exhaustive match.
            let covered = targets
                .iter()
                .map(|t| t.variant.ident.to_string())
                .collect::<std::collections::HashSet<_>>()
                .len();
            let all_variants = covered == self.total_fields;

            let ty = &first.ty;
            let fn_ret = if all_variants {
                quote! { &#ty }
            } else {
                quote! { Option<&#ty> }
            };

            let fn_name = {
                let name = first.name();
                syn::Ident::new(name.to_string().to_case(Case::Snake).as_str(), name.span())
            };

            let raw_accesses = targets.iter().map(|t| {
                let access = &t.access;
                // Span on the target's name (alias, field or variant ident). The alias is
                // optional, so unwrapping it here used to panic for alias-less targets.
                quote_spanned! { t.name().span() =>
                    #access
                }
            });

            let accesses = if all_variants {
                quote! {
                    match self {
                        #(#raw_accesses)*
                    }
                }
            } else {
                quote! {
                    Some(match self {
                        #(#raw_accesses)*
                        _ => return None,
                    })
                }
            };

            quote! {
                pub const fn #fn_name(&self) -> #fn_ret {
                    #accesses
                }
            }
        });

        let name = &self.name;
        tokens.extend(quote! {
            impl #name {
                #(#fns)*
            }
        });
    }
}