lib.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. //! Commonly used macros of libthreema.
  2. use convert_case::{Case, Casing as _};
  3. use proc_macro::TokenStream;
  4. use quote::{ToTokens as _, format_ident, quote};
  5. use syn::{
  6. self, Data, DeriveInput, Expr, Fields, Ident, ItemStruct, LitInt, LitStr, Variant,
  7. parse::{Parse, ParseStream},
  8. parse_macro_input, parse_quote,
  9. punctuated::Punctuated,
  10. };
  11. /// Provides the name of a `struct`, `enum` or `union`.
  12. ///
  13. /// # Examples
  14. ///
  15. /// Given the following:
  16. ///
  17. /// ```nobuild
  18. /// use libthreema_macros::Name;
  19. /// use crate::utils::debug::Name;
  20. ///
  21. /// #[derive(Name)]
  22. /// struct Something;
  23. /// ```
  24. ///
  25. /// the derive macro expands it to:
  26. ///
  27. /// ```nobuild
  28. /// struct Something;
  29. /// impl Name for Something {
  30. /// const NAME: &'static str = "Something";
  31. /// }
  32. /// ```
  33. #[proc_macro_derive(Name)]
  34. pub fn derive_name(input: TokenStream) -> TokenStream {
  35. // Parse the input tokens into a syntax tree.
  36. let input = parse_macro_input!(input as DeriveInput);
  37. // Implement `NAME`
  38. let name = input.ident;
  39. let literal_name = name.to_string();
  40. let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
  41. let expanded = quote! {
  42. impl #impl_generics crate::utils::debug::Name for #name #type_generics #where_clause {
  43. /// The name for debugging purposes
  44. const NAME: &'static str = #literal_name;
  45. }
  46. };
  47. // Generate code
  48. TokenStream::from(expanded)
  49. }
  50. /// Provides variant names for an `enum`.
  51. ///
  52. /// # Examples
  53. ///
  54. /// Given the following:
  55. ///
  56. /// ```
  57. /// use libthreema_macros::VariantNames;
  58. ///
  59. /// #[derive(VariantNames)]
  60. /// enum Something {
  61. /// SomeItem,
  62. /// SomeOtherItem(u64),
  63. /// }
  64. /// ```
  65. ///
  66. /// the derive macro expands it to:
  67. ///
  68. /// ```
  69. /// enum Something {
  70. /// SomeItem,
  71. /// SomeOtherItem(u64),
  72. /// }
  73. ///
  74. /// impl Something {
  75. /// pub const SOME_ITEM: &'static str = "SomeItem";
  76. /// pub const SOME_OTHER_ITEM: &'static str = "SomeOtherItem";
  77. ///
  78. /// pub const fn variant_name(&self) -> &'static str {
  79. /// match self {
  80. /// Self::SomeItem => Self::SOME_ITEM,
  81. /// Self::SomeOtherItem(..) => Self::SOME_OTHER_ITEM,
  82. /// }
  83. /// }
  84. /// }
  85. /// ```
  86. #[proc_macro_derive(VariantNames)]
  87. pub fn derive_variant_names(input: TokenStream) -> TokenStream {
  88. fn get_const_name(variant: &Variant) -> Ident {
  89. format_ident!(
  90. "{}",
  91. variant
  92. .ident
  93. .to_string()
  94. .from_case(Case::Pascal)
  95. .to_case(Case::UpperSnake)
  96. )
  97. }
  98. // Parse the input tokens into a syntax tree.
  99. let input = parse_macro_input!(input as DeriveInput);
  100. let enum_name = input.ident;
  101. let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
  102. // Map each variant to its literal identifier name
  103. let const_variants = match &input.data {
  104. Data::Enum(data) => data.variants.iter().map(|variant| {
  105. let docstring = format!(" Variant name of [`{}::{}`].", enum_name, variant.ident);
  106. let const_name = get_const_name(variant);
  107. let literal_name = variant.ident.to_string();
  108. quote! {
  109. #[doc = #docstring]
  110. pub const #const_name: &'static str = #literal_name;
  111. }
  112. }),
  113. #[expect(clippy::unimplemented, reason = "Only applicable to enums")]
  114. _ => unimplemented!(),
  115. };
  116. let mapped_variants = match &input.data {
  117. Data::Enum(data) => data.variants.iter().map(|variant| {
  118. let variant_name = &variant.ident;
  119. let parameters = match variant.fields {
  120. Fields::Unit => quote! {},
  121. Fields::Unnamed(..) => quote! { (..) },
  122. Fields::Named(..) => quote! { {..} },
  123. };
  124. let const_name = get_const_name(variant);
  125. quote! {
  126. Self::#variant_name #parameters => Self::#const_name
  127. }
  128. }),
  129. #[expect(clippy::unimplemented, reason = "Only applicable to enums")]
  130. _ => unimplemented!(),
  131. };
  132. // Implement for the enum
  133. let expanded = quote! {
  134. impl #impl_generics #enum_name #type_generics #where_clause {
  135. #(#const_variants)*
  136. /// Get the variant name of `self`.
  137. pub const fn variant_name(&self) -> &'static str {
  138. match self {
  139. #(#mapped_variants),*
  140. }
  141. }
  142. }
  143. };
  144. // Generate code
  145. TokenStream::from(expanded)
  146. }
  147. /// Implements [`Debug`] for the provided `enum`. Depends on [`VariantNames`].
  148. ///
  149. /// # Examples
  150. ///
  151. /// Given the following:
  152. ///
  153. /// ```
  154. /// use libthreema_macros::{DebugVariantNames, VariantNames};
  155. ///
  156. /// #[derive(DebugVariantNames, VariantNames)]
  157. /// enum Something {
  158. /// SomeItem,
  159. /// SomeOtherItem(u64),
  160. /// }
  161. /// ```
  162. ///
  163. /// the derive macro expands it to:
  164. ///
  165. /// ```
  166. /// # use libthreema_macros::VariantNames;
  167. /// #
  168. /// # #[derive(VariantNames)]
  169. /// enum Something {
  170. /// SomeItem,
  171. /// SomeOtherItem(u64),
  172. /// }
  173. ///
  174. /// // Omitting expansion of `VariantNames` here.
  175. ///
  176. /// impl std::fmt::Debug for Something {
  177. /// fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  178. /// write!(formatter, "{}::{}", "Something", self.variant_name())
  179. /// }
  180. /// }
  181. /// ```
  182. #[proc_macro_derive(DebugVariantNames)]
  183. pub fn derive_debug_variant_names(input: TokenStream) -> TokenStream {
  184. // Parse the input tokens into a syntax tree.
  185. let input = parse_macro_input!(input as DeriveInput);
  186. // Ensure it's an enum
  187. #[expect(clippy::unimplemented, reason = "Only applicable to enums")]
  188. if !matches!(input.data, Data::Enum(..)) {
  189. unimplemented!()
  190. }
  191. // Implement `Debug` for the enum
  192. let name = input.ident;
  193. let literal_name = name.to_string();
  194. let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
  195. let expanded = quote! {
  196. impl #impl_generics std::fmt::Debug for #name #type_generics #where_clause {
  197. fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  198. write!(formatter, "{}::{}", #literal_name, self.variant_name())
  199. }
  200. }
  201. };
  202. // Generate code
  203. TokenStream::from(expanded)
  204. }
  205. struct Arrays(Punctuated<Expr, syn::Token![,]>);
  206. impl Parse for Arrays {
  207. fn parse(input: ParseStream) -> syn::parse::Result<Arrays> {
  208. let punctuated = Punctuated::parse_terminated(input)?;
  209. Ok(Arrays(punctuated))
  210. }
  211. }
  212. /// Concatenates fixed-size byte arrays into a single large byte array.
  213. ///
  214. /// # Examples
  215. ///
  216. /// ```
  217. /// use libthreema_macros::concat_fixed_bytes;
  218. ///
  219. /// let a = [1u8; 4];
  220. /// let b = [2u8; 3];
  221. /// let c = [3u8; 3];
  222. ///
  223. /// let concatenated: [u8; 10] = concat_fixed_bytes!(a, b, c);
  224. /// assert_eq!(concatenated, [1, 1, 1, 1, 2, 2, 2, 3, 3, 3]);
  225. /// ```
  226. #[proc_macro]
  227. pub fn concat_fixed_bytes(tokens: TokenStream) -> TokenStream {
  228. let input = parse_macro_input!(tokens as Arrays).0.into_iter();
  229. let indices = input.clone().enumerate();
  230. let arrays: Vec<Expr> = input.collect();
  231. let field_length_parameters: Vec<Ident> = indices
  232. .clone()
  233. .map(|(index, _)| format_ident!("T{index}"))
  234. .collect();
  235. let field_names: Vec<Ident> = indices.map(|(index, _)| format_ident!("t{index}")).collect();
  236. let expanded = quote! {{
  237. #[repr(C)]
  238. struct ConcatenatedArrays<#(const #field_length_parameters: usize,)*> {
  239. #(#field_names: [u8; #field_length_parameters],)*
  240. }
  241. let concatenated_arrays = ConcatenatedArrays {
  242. #(#field_names: #arrays,)*
  243. };
  244. unsafe {
  245. let concatenated_bytes = core::mem::transmute(concatenated_arrays);
  246. concatenated_bytes
  247. }
  248. }};
  249. // Generate code
  250. TokenStream::from(expanded)
  251. }
  252. /// Annotates a `prost::Message` in the following way:
  253. ///
  254. /// ## `padding` fields
  255. ///
  256. /// These fields will be marked as deprecated to discourage direct usage of it. Furthermore, the padding tag
  257. /// will be extracted and made available on the message as a `PADDING_TAG` const. See the
  258. /// `ProtobufPaddedMessage` trait.
  259. #[proc_macro_attribute]
  260. pub fn protobuf_annotations(_attribute: TokenStream, input: TokenStream) -> TokenStream {
  261. fn annotate_protobuf_message(mut message: ItemStruct) -> syn::Result<TokenStream> {
  262. let mut padding_tag: Option<LitInt> = None;
  263. for field in &mut message.fields {
  264. let Some(name) = field.ident.as_ref() else {
  265. continue;
  266. };
  267. // Process `padding` fields so that we can use `ProtobufPaddedMessage` on them easily
  268. if name == "padding" {
  269. // Look for the tag value in `#[prost(..., tag = "<tag-value>")]`
  270. for attribute in &field.attrs {
  271. if !attribute.path().is_ident("prost") {
  272. continue;
  273. }
  274. attribute.parse_nested_meta(|meta| {
  275. let value: LitStr = meta.value()?.parse()?;
  276. if meta.path.is_ident("tag") {
  277. padding_tag = Some(value.parse()?);
  278. }
  279. Ok(())
  280. })?;
  281. }
  282. // Ensure nobody uses the field directly by deprecating it
  283. field.attrs.push(parse_quote! {
  284. #[deprecated(note = "Use ProtobufPaddedMessage trait to generate padding")]
  285. });
  286. }
  287. }
  288. let message_name = message.ident.clone();
  289. let mut output = message.into_token_stream();
  290. // Add any padding tag value as a const to the message
  291. if let Some(padding_tag) = padding_tag {
  292. output.extend(quote! {
  293. impl #message_name {
  294. /// Tag value of the padding of this message.
  295. pub const PADDING_TAG: u32 = #padding_tag;
  296. }
  297. });
  298. }
  299. Ok(output.into())
  300. }
  301. annotate_protobuf_message(parse_macro_input!(input as ItemStruct))
  302. .unwrap_or_else(|error| error.into_compile_error().into())
  303. }
  304. /// Implements [`subtle::ConstantTimeEq`] for named and unnamed structs.
  305. ///
  306. /// Moreover, this derives [`PartialEq`] and [`Eq`] using constant time comparison.
  307. ///
  308. /// Note: All fields must implement [`subtle::ConstantTimeEq`].
  309. ///
  310. /// The proc macro was adapted from <https://github.com/dalek-cryptography/subtle/pull/111>
  311. ///
  312. /// # Examples
  313. ///
  314. /// Given the following:
  315. ///
  316. /// ```
  317. /// use libthreema_macros::ConstantTimeEq;
  318. ///
  319. /// #[derive(ConstantTimeEq)]
  320. /// struct MyStruct {
  321. /// first_field: [u8; 32],
  322. /// second_field: u64,
  323. /// }
  324. /// ```
  325. ///
  326. /// the derive macro expands it to:
  327. ///
  328. /// ```
  329. /// struct MyStruct {
  330. /// first_field: [u8; 32],
  331. /// second_field: u64,
  332. /// }
  333. ///
  334. /// impl ::subtle::ConstantTimeEq for MyStruct {
  335. /// #[inline]
  336. /// fn ct_eq(&self, other: &Self) -> ::subtle::Choice {
  337. /// use ::subtle::ConstantTimeEq as _;
  338. /// return { self.first_field }.ct_eq(&{ other.first_field })
  339. /// & { self.second_field }.ct_eq(&{ other.second_field });
  340. /// }
  341. /// }
  342. /// impl PartialEq<Self> for MyStruct {
  343. /// #[inline]
  344. /// fn eq(&self, other: &Self) -> bool {
  345. /// use ::subtle::ConstantTimeEq as _;
  346. /// bool::from(self.ct_eq(other))
  347. /// }
  348. /// }
  349. /// impl Eq for MyStruct {}
  350. /// ```
  351. #[proc_macro_derive(ConstantTimeEq)]
  352. pub fn constant_time_eq(input: TokenStream) -> TokenStream {
  353. let input = parse_macro_input!(input as DeriveInput);
  354. #[expect(
  355. clippy::unimplemented,
  356. reason = "Only applicable to named and unnamed structs"
  357. )]
  358. let Data::Struct(data_struct) = input.data else {
  359. unimplemented!()
  360. };
  361. let constant_time_eq_stream = match &data_struct.fields {
  362. Fields::Named(fields_named) => {
  363. let mut token_stream = quote! {};
  364. let mut fields = fields_named.named.iter().peekable();
  365. while let Some(field) = fields.next() {
  366. let ident = &field.ident;
  367. token_stream.extend(quote! { {self.#ident}.ct_eq(&{other.#ident}) });
  368. if fields.peek().is_some() {
  369. token_stream.extend(quote! { & });
  370. }
  371. }
  372. token_stream
  373. },
  374. Fields::Unnamed(unnamed_fields) => {
  375. let mut token_stream = quote! {};
  376. let mut fields = unnamed_fields.unnamed.iter().enumerate().peekable();
  377. while let Some(field) = fields.next() {
  378. let index = syn::Index::from(field.0);
  379. token_stream.extend(quote! { {self.#index}.ct_eq(&{other.#index}) });
  380. if fields.peek().is_some() {
  381. token_stream.extend(quote! { & });
  382. }
  383. }
  384. token_stream
  385. },
  386. #[expect(clippy::unimplemented, reason = "Not applicable to unit-like structs")]
  387. Fields::Unit => unimplemented!(),
  388. };
  389. let name = &input.ident;
  390. let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
  391. let expanded = quote! {
  392. impl #impl_generics ::subtle::ConstantTimeEq for #name #type_generics #where_clause {
  393. #[inline]
  394. fn ct_eq(&self, other: &Self) -> ::subtle::Choice {
  395. use ::subtle::ConstantTimeEq as _;
  396. return #constant_time_eq_stream
  397. }
  398. }
  399. impl #impl_generics PartialEq<Self> for #name #type_generics #where_clause {
  400. #[inline]
  401. fn eq(&self, other: &Self) -> bool{
  402. use ::subtle::ConstantTimeEq as _;
  403. bool::from(self.ct_eq(other))
  404. }
  405. }
  406. impl #impl_generics Eq for #name #type_generics #where_clause {}
  407. };
  408. TokenStream::from(expanded)
  409. }
  410. // Avoids dependencies to be picked up by the linter.
  411. mod external_crate_false_positives {
  412. use subtle as _;
  413. }
  // Marks the test dependencies `rstest` and `trybuild` as used in test builds, so the linter
  // (presumably the `unused_crate_dependencies` lint) does not report them as unused there.
  #[cfg(test)]
  mod external_crate_false_positives_test_feature {
      use ::rstest as _;
      use ::trybuild as _;
  }