1717
1818use crate :: {
1919 trait_bounds,
20- utils:: { codec_crate_path, custom_mel_trait_bound, has_dumb_trait_bound, should_skip} ,
20+ utils:: { self , codec_crate_path, custom_mel_trait_bound, has_dumb_trait_bound, should_skip} ,
2121} ;
2222use quote:: { quote, quote_spanned} ;
23- use syn:: { parse_quote, spanned:: Spanned , Data , DeriveInput , Fields , Type } ;
23+ use syn:: { parse_quote, spanned:: Spanned , Data , DeriveInput , Field , Fields } ;
2424
2525/// impl for `#[derive(MaxEncodedLen)]`
2626pub fn derive_max_encoded_len ( input : proc_macro:: TokenStream ) -> proc_macro:: TokenStream {
@@ -43,13 +43,13 @@ pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::Tok
4343 parse_quote ! ( #crate_path:: MaxEncodedLen ) ,
4444 None ,
4545 has_dumb_trait_bound ( & input. attrs ) ,
46- & crate_path
46+ & crate_path,
4747 ) {
4848 return e. to_compile_error ( ) . into ( )
4949 }
5050 let ( impl_generics, ty_generics, where_clause) = input. generics . split_for_impl ( ) ;
5151
52- let data_expr = data_length_expr ( & input. data ) ;
52+ let data_expr = data_length_expr ( & input. data , & crate_path ) ;
5353
5454 quote:: quote!(
5555 const _: ( ) = {
@@ -64,22 +64,22 @@ pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::Tok
6464}
6565
6666/// generate an expression to sum up the max encoded length from several fields
67- fn fields_length_expr ( fields : & Fields ) -> proc_macro2:: TokenStream {
68- let type_iter : Box < dyn Iterator < Item = & Type > > = match fields {
69- Fields :: Named ( ref fields) => Box :: new (
70- fields . named . iter ( ) . filter_map ( |field| if should_skip ( & field. attrs ) {
67+ fn fields_length_expr ( fields : & Fields , crate_path : & syn :: Path ) -> proc_macro2:: TokenStream {
68+ let fields_iter : Box < dyn Iterator < Item = & Field > > = match fields {
69+ Fields :: Named ( ref fields) => Box :: new ( fields . named . iter ( ) . filter_map ( |field| {
70+ if should_skip ( & field. attrs ) {
7171 None
7272 } else {
73- Some ( & field. ty )
74- } )
75- ) ,
76- Fields :: Unnamed ( ref fields) => Box :: new (
77- fields . unnamed . iter ( ) . filter_map ( |field| if should_skip ( & field. attrs ) {
73+ Some ( field)
74+ }
75+ } ) ) ,
76+ Fields :: Unnamed ( ref fields) => Box :: new ( fields . unnamed . iter ( ) . filter_map ( |field| {
77+ if should_skip ( & field. attrs ) {
7878 None
7979 } else {
80- Some ( & field. ty )
81- } )
82- ) ,
80+ Some ( field)
81+ }
82+ } ) ) ,
8383 Fields :: Unit => Box :: new ( std:: iter:: empty ( ) ) ,
8484 } ;
8585 // expands to an expression like
@@ -92,9 +92,16 @@ fn fields_length_expr(fields: &Fields) -> proc_macro2::TokenStream {
9292 // `max_encoded_len` call. This way, if one field's type doesn't implement
9393 // `MaxEncodedLen`, the compiler's error message will underline which field
9494 // caused the issue.
95- let expansion = type_iter. map ( |ty| {
96- quote_spanned ! {
97- ty. span( ) => . saturating_add( <#ty>:: max_encoded_len( ) )
95+ let expansion = fields_iter. map ( |field| {
96+ let ty = & field. ty ;
97+ if utils:: is_compact ( & field) {
98+ quote_spanned ! {
99+ ty. span( ) => . saturating_add( <#crate_path:: Compact :: <#ty> as #crate_path:: MaxEncodedLen >:: max_encoded_len( ) )
100+ }
101+ } else {
102+ quote_spanned ! {
103+ ty. span( ) => . saturating_add( <#ty as #crate_path:: MaxEncodedLen >:: max_encoded_len( ) )
104+ }
98105 }
99106 } ) ;
100107 quote ! {
@@ -103,9 +110,9 @@ fn fields_length_expr(fields: &Fields) -> proc_macro2::TokenStream {
103110}
104111
105112// generate an expression to sum up the max encoded length of each field
106- fn data_length_expr ( data : & Data ) -> proc_macro2:: TokenStream {
113+ fn data_length_expr ( data : & Data , crate_path : & syn :: Path ) -> proc_macro2:: TokenStream {
107114 match * data {
108- Data :: Struct ( ref data) => fields_length_expr ( & data. fields ) ,
115+ Data :: Struct ( ref data) => fields_length_expr ( & data. fields , crate_path ) ,
109116 Data :: Enum ( ref data) => {
110117 // We need an expression expanded for each variant like
111118 //
@@ -121,7 +128,7 @@ fn data_length_expr(data: &Data) -> proc_macro2::TokenStream {
121128 // Each variant expression's sum is computed the way an equivalent struct's would be.
122129
123130 let expansion = data. variants . iter ( ) . map ( |variant| {
124- let variant_expression = fields_length_expr ( & variant. fields ) ;
131+ let variant_expression = fields_length_expr ( & variant. fields , crate_path ) ;
125132 quote ! {
126133 . max( #variant_expression)
127134 }
0 commit comments