use std::collections::HashMap; use convert_case::{Case, Casing}; use proc_macro2::TokenStream; use quote::{ToTokens, quote, quote_spanned}; use syn::spanned::Spanned; use crate::hashlist::HashList; const TARGETS_PATH: &str = "extract"; #[derive(Debug, Clone, PartialEq, Eq, Hash)] enum VariantFieldIndex { Ident(syn::Ident), Numeric(syn::LitInt), None, } impl VariantFieldIndex { const fn is_numeric(&self) -> bool { match self { Self::Numeric(_) => true, _ => false, } } } impl syn::parse::Parse for VariantFieldIndex { fn parse(input: syn::parse::ParseStream) -> syn::Result { if input.is_empty() { Ok(VariantFieldIndex::None) } else if input.peek(syn::LitInt) { Ok(VariantFieldIndex::Numeric(input.parse()?)) } else { Ok(VariantFieldIndex::Ident(input.parse()?)) } } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct VariantAttr { index: VariantFieldIndex, alias: Option, } impl syn::parse::Parse for VariantAttr { fn parse(input: syn::parse::ParseStream) -> syn::Result { let index = if input.peek(syn::token::As) { VariantFieldIndex::None } else { input.parse::()? }; if !input.peek(syn::token::As) { return Ok(Self { index, alias: None }); } input.parse::()?; Ok(Self { index, alias: Some(input.parse()?), }) } } #[derive(Debug, Clone)] struct TargetVariant { attr: VariantAttr, access: TokenStream, ty: syn::Type, variant: syn::Variant, } impl TargetVariant { fn name(&self) -> syn::Ident { if let Some(alias) = self.attr.alias.as_ref() { return alias.clone(); } match &self.attr.index { VariantFieldIndex::Ident(ident) => ident.clone(), VariantFieldIndex::Numeric(_) | VariantFieldIndex::None => self.variant.ident.clone(), } } } #[derive(Debug)] pub struct Targets { name: syn::Ident, targets: Vec, total_fields: usize, } impl Targets { pub fn parse(input: syn::DeriveInput) -> Result { let name = input.ident; let mut targets: Vec = Vec::new(); let total_fields; match &input.data { syn::Data::Enum(data) => { total_fields = data.variants.len(); for field in data.variants.iter() { let field_ident = &field.ident; for attr in field.attrs.iter().filter(|attr| { attr.path() .get_ident() .map(|id| id.to_string().as_str() == TARGETS_PATH) .unwrap_or_default() }) { let attr_span = attr.span(); let attr = attr.parse_args::()?; let (access, ty) = match &field.fields { syn::Fields::Named(fields) => match &attr.index { VariantFieldIndex::Ident(ident) => { let mut matching_val = None; let mut accesses = Vec::new(); for field in fields.named.iter() { let matching = field .ident .as_ref() .map(|i| i.to_string() == ident.to_string()) .unwrap_or_default(); if matching_val.is_some() && matching { panic!("duplicate?") } let ident = field.ident.clone().unwrap(); if matching { matching_val = Some((ident.clone(), field.ty.clone())); accesses.push(quote! { #ident }); } else { accesses.push(quote! { #ident: _ }); } } let (mut matching_ident, matching_ty) = matching_val.ok_or( syn::Error::new(ident.span(), "no such variant field"), )?; matching_ident.set_span(ident.span()); ( quote! { #name::#field_ident { #(#accesses),* } => &#matching_ident, }, matching_ty, ) } VariantFieldIndex::Numeric(num) => { return Err(syn::Error::new( num.span(), "cannot used numeric index for a variant with named fields", )); } VariantFieldIndex::None => { if fields.named.len() == 1 { let field = fields.named.iter().next().unwrap(); let field_ident = field.ident.as_ref().unwrap(); ( quote! { #name::#field_ident { #field_ident } => &#field_ident, }, field.ty.clone(), ) } else { return Err(syn::Error::new( fields.named.span(), "unnamed field index with more than one field", )); } } }, syn::Fields::Unnamed(fields) => { if let VariantFieldIndex::Numeric(num) = &attr.index { let num = num.base10_parse::()? as usize; let field = fields .unnamed .iter() .nth(num) .ok_or(syn::Error::new( attr_span, "field index out of range", ))? .clone(); let left = (0..num).map(|_| quote! {_}); let right = (num..fields.unnamed.len()).map(|_| quote! {_}); ( quote! { #name::#field_ident(#(#left),*, val, #(#right),*) => &val, }, field.ty.clone(), ) } else if fields.unnamed.len() == 1 { ( quote! { #name::#field_ident(val) => &val, }, fields.unnamed.iter().next().unwrap().ty.clone(), ) } else { return Err(syn::Error::new( fields.span(), "unnamed fields without numeric index", )); } } syn::Fields::Unit => { return Err(syn::Error::new( field.span(), "target cannot be a unit field", )); } }; let target = TargetVariant { attr, ty, access, variant: field.clone(), }; // if targets.iter().any(|t| t.name() == target.name()) { // return Err(syn::Error::new(attr_span, "duplicate target field")); // } // target.variant.ident.set_span(attr_span); targets.push(target); } } } _ => todo!(), }; Ok(Self { name, targets, total_fields, }) } } impl ToTokens for Targets { fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { // this is so horrid and haphazard but idc rn // commenting to make the criticizing voice stfu let mut by_target = { let mut out = HashList::new(); for target in self.targets.iter() { out.add(target.name().to_string(), target.clone()); } out }; let fns = by_target.decompose().into_iter().map(|(_, targets)| { if targets.is_empty() { return quote! {}; } let all_variants = targets.len() == self.total_fields; let fn_ret = { let ty = targets.first().unwrap().ty.clone(); if all_variants { quote! {&#ty} } else { quote! {Option<&#ty>} } }; let fn_name = { let name = targets.first().unwrap().name(); syn::Ident::new(name.to_string().to_case(Case::Snake).as_str(), name.span()) }; let raw_accesses = targets.iter().map(|t| { let access = &t.access; quote_spanned! { t.attr.alias.as_ref().unwrap().span() => #access } }); let accesses = if all_variants { quote! { match self { #(#raw_accesses)* } } } else { quote! { Some(match self { #(#raw_accesses)* _ => return None, }) } }; quote! { pub const fn #fn_name(&self) -> #fn_ret { #accesses } } }); let name = &self.name; tokens.extend(quote! { impl #name { #(#fns)* } }); } }