Skip to main content

dfir_macro/
lib.rs

1#![cfg_attr(
2    nightly,
3    feature(proc_macro_diagnostic, proc_macro_span, proc_macro_def_site)
4)]
5
6use dfir_lang::diagnostic::Level;
7use dfir_lang::graph::{
8    BuildDfirCodeOutput, FlatGraphBuilder, FlatGraphBuilderOutput, build_dfir_code, partition_graph,
9};
10use dfir_lang::parse::DfirCode;
11use proc_macro2::{Ident, Literal, Span};
12use quote::{format_ident, quote, quote_spanned};
13use syn::spanned::Spanned;
14use syn::{
15    Attribute, Fields, GenericParam, ItemEnum, Variant, WherePredicate, parse_macro_input,
16    parse_quote,
17};
18
19/// Create a runnable graph instance using DFIR's custom syntax.
20///
21/// For example usage, take a look at the [`surface_*` tests in the `tests` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/tests)
22/// or the [`examples` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/examples)
23/// in the [Hydro repo](https://github.com/hydro-project/hydro).
24// TODO(mingwei): rustdoc examples inline.
25#[proc_macro]
26pub fn dfir_syntax(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
27    dfir_syntax_internal(input, Some(Level::Help))
28}
29
30/// [`dfir_syntax!`] but will not emit any diagnostics (errors, warnings, etc.).
31///
32/// Used for testing, users will want to use [`dfir_syntax!`] instead.
33#[proc_macro]
34pub fn dfir_syntax_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
35    dfir_syntax_internal(input, None)
36}
37
38fn root() -> proc_macro2::TokenStream {
39    use std::env::{VarError, var as env_var};
40
41    let root_crate_name = format!(
42        "{}_rs",
43        env!("CARGO_PKG_NAME").strip_suffix("_macro").unwrap()
44    );
45    let root_crate_ident = root_crate_name.replace('-', "_");
46    let root_crate = proc_macro_crate::crate_name(&root_crate_name)
47        .unwrap_or_else(|_| panic!("{root_crate_name} should be present in `Cargo.toml`"));
48    match root_crate {
49        proc_macro_crate::FoundCrate::Itself => {
50            if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
51                && Err(VarError::NotPresent) != env_var("CARGO_PRIMARY_PACKAGE")
52                && Ok(&*root_crate_ident) == env_var("CARGO_CRATE_NAME").as_deref()
53            {
54                // In the crate itself, including unit tests.
55                quote! { crate }
56            } else {
57                // In an integration test, example, bench, etc.
58                let ident: Ident = Ident::new(&root_crate_ident, Span::call_site());
59                quote! { ::#ident }
60            }
61        }
62        proc_macro_crate::FoundCrate::Name(name) => {
63            let ident = Ident::new(&name, Span::call_site());
64            quote! { ::#ident }
65        }
66    }
67}
68
69fn dfir_syntax_internal(
70    input: proc_macro::TokenStream,
71    retain_diagnostic_level: Option<Level>,
72) -> proc_macro::TokenStream {
73    let input = parse_macro_input!(input as DfirCode);
74    let root = root();
75
76    let (code, mut diagnostics) = match build_dfir_code(input, &root) {
77        Ok(BuildDfirCodeOutput {
78            partitioned_graph: _,
79            code,
80            diagnostics,
81        }) => (code, diagnostics),
82        Err(diagnostics) => (quote! { #root::scheduled::graph::Dfir::new() }, diagnostics),
83    };
84
85    let diagnostic_tokens = retain_diagnostic_level.and_then(|level| {
86        diagnostics.retain_level(level);
87        diagnostics.try_emit_all().err()
88    });
89
90    quote! {
91        {
92            #diagnostic_tokens
93            #code
94        }
95    }
96    .into()
97}
98
99/// Parse DFIR syntax without emitting code.
100///
101/// Used for testing, users will want to use [`dfir_syntax!`] instead.
102#[proc_macro]
103pub fn dfir_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
104    let input = parse_macro_input!(input as DfirCode);
105
106    let flat_graph_builder = FlatGraphBuilder::from_dfir(input);
107    let err_diagnostics = 'err: {
108        let (mut flat_graph, mut diagnostics) = match flat_graph_builder.build() {
109            Ok(FlatGraphBuilderOutput {
110                flat_graph,
111                uses: _,
112                diagnostics,
113            }) => (flat_graph, diagnostics),
114            Err(diagnostics) => {
115                break 'err diagnostics;
116            }
117        };
118
119        if let Err(diagnostic) = flat_graph.merge_modules() {
120            diagnostics.push(diagnostic);
121            break 'err diagnostics;
122        }
123
124        let flat_mermaid = flat_graph.mermaid_string_flat();
125
126        let part_graph = partition_graph(flat_graph).unwrap();
127        let part_mermaid = part_graph.to_mermaid(&Default::default());
128
129        let lit0 = Literal::string(&flat_mermaid);
130        let lit1 = Literal::string(&part_mermaid);
131
132        return quote! {
133            {
134                println!("{}\n\n{}\n", #lit0, #lit1);
135            }
136        }
137        .into();
138    };
139
140    err_diagnostics
141        .try_emit_all()
142        .err()
143        .unwrap_or_default()
144        .into()
145}
146
147fn wrap_localset(item: proc_macro::TokenStream, attribute: Attribute) -> proc_macro::TokenStream {
148    use quote::ToTokens;
149
150    let root = root();
151
152    let mut input: syn::ItemFn = match syn::parse(item) {
153        Ok(it) => it,
154        Err(e) => return e.into_compile_error().into(),
155    };
156
157    let statements = input.block.stmts;
158
159    input.block.stmts = parse_quote!(
160        #root::tokio::task::LocalSet::new().run_until(async {
161            #( #statements )*
162        }).await
163    );
164
165    input.attrs.push(attribute);
166
167    input.into_token_stream().into()
168}
169
170/// Checks that the given closure is a morphism. For now does nothing.
171#[proc_macro]
172pub fn morphism(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
173    // TODO(mingwei): some sort of code analysis?
174    item
175}
176
177/// Checks that the given closure is a monotonic function. For now does nothing.
178#[proc_macro]
179pub fn monotonic_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
180    // TODO(mingwei): some sort of code analysis?
181    item
182}
183
184#[proc_macro_attribute]
185pub fn dfir_test(
186    args: proc_macro::TokenStream,
187    item: proc_macro::TokenStream,
188) -> proc_macro::TokenStream {
189    let root = root();
190    let args_2: proc_macro2::TokenStream = args.into();
191
192    wrap_localset(
193        item,
194        parse_quote!(
195            #[#root::tokio::test(flavor = "current_thread", #args_2)]
196        ),
197    )
198}
199
200#[proc_macro_attribute]
201pub fn dfir_main(
202    _: proc_macro::TokenStream,
203    item: proc_macro::TokenStream,
204) -> proc_macro::TokenStream {
205    let root = root();
206
207    wrap_localset(
208        item,
209        parse_quote!(
210            #[#root::tokio::main(flavor = "current_thread")]
211        ),
212    )
213}
214
215#[proc_macro_derive(DemuxEnum)]
216pub fn derive_demux_enum(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
217    let root = root();
218
219    let ItemEnum {
220        ident: item_ident,
221        generics,
222        variants,
223        ..
224    } = parse_macro_input!(item as ItemEnum);
225
226    // Sort variants alphabetically.
227    let mut variants = variants.into_iter().collect::<Vec<_>>();
228    variants.sort_by(|a, b| a.ident.cmp(&b.ident));
229
230    // Return type for each variant.
231    let variant_output_types = variants
232        .iter()
233        .map(|variant| match &variant.fields {
234            Fields::Named(fields) => {
235                let field_types = fields.named.iter().map(|field| &field.ty);
236                quote! {
237                    ( #( #field_types, )* )
238                }
239            }
240            Fields::Unnamed(fields) => {
241                let field_types = fields.unnamed.iter().map(|field| &field.ty);
242                quote! {
243                    ( #( #field_types, )* )
244                }
245            }
246            Fields::Unit => quote!(()),
247        })
248        .collect::<Vec<_>>();
249
250    let variant_generics_sink = variants
251        .iter()
252        .map(|variant| format_ident!("__Sink{}", variant.ident))
253        .collect::<Vec<_>>();
254    let variant_generics_pinned_sink = variant_generics_sink.iter().map(|ident| {
255        quote_spanned! {ident.span()=>
256            ::std::pin::Pin::<&mut #ident>
257        }
258    });
259    let variant_generics_pinned_sink_all = quote! {
260        ( #( #variant_generics_pinned_sink, )* )
261    };
262    let variant_localvars_sink = variants
263        .iter()
264        .map(|variant| {
265            format_ident!(
266                "__sink_{}",
267                variant.ident.to_string().to_lowercase(),
268                span = variant.ident.span()
269            )
270        })
271        .collect::<Vec<_>>();
272
273    let mut full_generics_sink = generics.clone();
274    full_generics_sink.params.extend(
275        variant_generics_sink
276            .iter()
277            .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
278    );
279    full_generics_sink.make_where_clause().predicates.extend(
280        variant_generics_sink
281            .iter()
282            .zip(variant_output_types.iter())
283            .map::<WherePredicate, _>(|(sink_generic, output_type)| {
284                parse_quote! {
285                    // TODO(mingwei): generic error types?
286                    #sink_generic: #root::futures::sink::Sink<#output_type, Error = #root::Never>
287                }
288            }),
289    );
290
291    let variant_pats_sink_start_send = variants.iter().zip(variant_localvars_sink.iter()).map(
292        |(variant, sinkvar)| {
293            let Variant { ident, fields, .. } = variant;
294            let (fields_pat, push_item) = field_pattern_item(fields);
295            quote! {
296                Self::#ident #fields_pat => ::std::pin::Pin::as_mut(#sinkvar).start_send(#push_item)
297            }
298        },
299    );
300
301    let (impl_generics_item, ty_generics, where_clause_item) = generics.split_for_impl();
302    let (impl_generics_sink, _ty_generics_sink, where_clause_sink) =
303        full_generics_sink.split_for_impl();
304
305    let variant_generics_push = variants
306        .iter()
307        .map(|variant| format_ident!("__Push{}", variant.ident))
308        .collect::<Vec<_>>();
309    let variant_generics_pinned_push = variant_generics_push.iter().map(|ident| {
310        quote_spanned! {ident.span()=>
311            ::std::pin::Pin::<&mut #ident>
312        }
313    });
314    let variant_generics_pinned_push_all = quote! {
315        ( #( #variant_generics_pinned_push, )* )
316    };
317    let variant_localvars_push = variants
318        .iter()
319        .map(|variant| {
320            format_ident!(
321                "__push_{}",
322                variant.ident.to_string().to_lowercase(),
323                span = variant.ident.span()
324            )
325        })
326        .collect::<Vec<_>>();
327
328    let mut full_generics_push = generics.clone();
329    full_generics_push.params.extend(
330        variant_generics_push
331            .iter()
332            .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
333    );
334    // Each push just needs Push<Item = VariantOutput, Meta = ()>.
335    full_generics_push.make_where_clause().predicates.extend(
336        variant_generics_push
337            .iter()
338            .zip(variant_output_types.iter())
339            .map::<WherePredicate, _>(|(push_generic, output_type)| {
340                parse_quote! {
341                    #push_generic: #root::dfir_pipes::push::Push<#output_type, ()>
342                }
343            }),
344    );
345
346    // Build the recursive Merged Ctx type:
347    // For 0 pushes: `()
348    // For 1 push: `Push0::Ctx<'__ctx>`
349    // For 2 pushes: `<Push0::Ctx<'__ctx> as Context<'__ctx>>::Merged<Push1::Ctx<'__ctx>>`
350    // For 3 pushes: `<Push0::Ctx<'__ctx> as Context<'__ctx>>::Merged<<Push1::Ctx<'__ctx> as Context<'__ctx>>::Merged<Push2::Ctx<'__ctx>>>`
351    let ctx_type = variant_generics_push
352        .iter()
353        .zip(variant_output_types.iter())
354        .rev()
355        .map(|(push_generic, output_type)| {
356            quote_spanned! {push_generic.span()=>
357                <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::Ctx<'__ctx>
358            }
359        })
360        .reduce(|rest, next| {
361            quote_spanned! {next.span()=>
362                <#next as #root::dfir_pipes::Context<'__ctx>>::Merged<#rest>
363            }
364        })
365        .unwrap_or_else(|| quote!(()));
366
367    let can_pend = variant_generics_push
368        .iter()
369        .zip(variant_output_types.iter())
370        .rev()
371        .map(|(push_generic, output_type)| {
372            quote_spanned! {push_generic.span()=>
373                <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::CanPend
374            }
375        })
376        .reduce(|rest, next| {
377            quote_spanned! {next.span()=>
378                <#next as #root::dfir_pipes::Toggle>::Or<#rest>
379            }
380        })
381        .unwrap_or_else(|| quote!(#root::dfir_pipes::No));
382
383    // Generate `Ctx`: `unmerge_self` for each push, `unmerge_other` to get remaining `__ctx`.
384    // For the last push, just pass `__ctx` directly (no unmerge needed).
385    let push_poll_unwrap_context = |method_name: Ident| {
386        variant_localvars_push.split_last().map(|(lastvar, headvar)| {
387            // `#( ... )*` zips all iterators to shortest; `headvar` (all-but-last) is shortest, so
388            // `variant_generics_push` and `variant_output_types` are naturally truncated to match.
389            quote! {
390                #(
391                    let #headvar = {
392                        let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_self(__ctx);
393                        #root::dfir_pipes::push::Push::#method_name(::std::pin::Pin::as_mut(#headvar), __ctx)
394                    };
395                    let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_other(__ctx);
396                )*
397                let #lastvar = #root::dfir_pipes::push::Push::#method_name(::std::pin::Pin::as_mut(#lastvar), __ctx);
398                // If any are pending, return pending.
399                #(
400                    if #variant_localvars_push.is_pending() {
401                        return #root::dfir_pipes::push::PushStep::pending();
402                    }
403                )*
404            }
405        })
406    };
407    let push_poll_ready_body = (push_poll_unwrap_context)(format_ident!("poll_ready"));
408    let push_poll_flush_body = (push_poll_unwrap_context)(format_ident!("poll_flush"));
409
410    let variant_pats_push_send =
411        variants
412            .iter()
413            .zip(variant_localvars_push.iter())
414            .map(|(variant, pushvar)| {
415                let Variant { ident, fields, .. } = variant;
416                let (fields_pat, push_item) = field_pattern_item(fields);
417                quote! {
418                    Self::#ident #fields_pat => { #root::dfir_pipes::push::Push::start_send(#pushvar.as_mut(), #push_item, __meta); }
419                }
420            });
421
422    let (impl_generics_push, _ty_generics_push, where_clause_push) =
423        full_generics_push.split_for_impl();
424
425    let single_impl = (1 == variants.len()).then(|| {
426        let Variant { ident, fields, .. } = variants.first().unwrap();
427        let (fields_pat, push_item) = field_pattern_item(fields);
428        let out_type = variant_output_types.first().unwrap();
429        quote! {
430            impl #impl_generics_item #root::util::demux_enum::SingleVariant
431                for #item_ident #ty_generics #where_clause_item
432            {
433                type Output = #out_type;
434                fn single_variant(self) -> Self::Output {
435                    match self {
436                        Self::#ident #fields_pat => #push_item,
437                    }
438                }
439            }
440        }
441    });
442
443    quote! {
444        impl #impl_generics_sink #root::util::demux_enum::DemuxEnumSink<#variant_generics_pinned_sink_all>
445            for #item_ident #ty_generics #where_clause_sink
446        {
447            type Error = #root::Never;
448
449            fn poll_ready(
450                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
451                __cx: &mut ::std::task::Context<'_>,
452            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
453                // Ready all sinks simultaneously.
454                #(
455                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_ready(__cx)?;
456                )*
457                #(
458                    ::std::task::ready!(#variant_localvars_sink);
459                )*
460                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
461            }
462
463            fn start_send(
464                self,
465                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
466            ) -> ::std::result::Result<(), Self::Error> {
467                match self {
468                    #( #variant_pats_sink_start_send, )*
469                }
470            }
471
472            fn poll_flush(
473                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
474                __cx: &mut ::std::task::Context<'_>,
475            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
476                // Flush all sinks simultaneously.
477                #(
478                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_flush(__cx)?;
479                )*
480                #(
481                    ::std::task::ready!(#variant_localvars_sink);
482                )*
483                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
484            }
485
486            fn poll_close(
487                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
488                __cx: &mut ::std::task::Context<'_>,
489            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
490                // Close all sinks simultaneously.
491                #(
492                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_close(__cx)?;
493                )*
494                #(
495                    ::std::task::ready!(#variant_localvars_sink);
496                )*
497                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
498            }
499        }
500
501        impl #impl_generics_push #root::util::demux_enum::DemuxEnumPush<#variant_generics_pinned_push_all, ()>
502            for #item_ident #ty_generics #where_clause_push
503        {
504            type Ctx<'__ctx> = #ctx_type;
505            type CanPend = #can_pend;
506
507            fn poll_ready(
508                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
509                __ctx: &mut Self::Ctx<'_>,
510            ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
511                #push_poll_ready_body
512                #root::dfir_pipes::push::PushStep::Done
513            }
514
515            fn start_send(
516                self,
517                __meta: (),
518                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
519            ) {
520                match self {
521                    #( #variant_pats_push_send, )*
522                }
523            }
524
525            fn poll_flush(
526                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
527                __ctx: &mut Self::Ctx<'_>,
528            ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
529                #push_poll_flush_body
530                #root::dfir_pipes::push::PushStep::Done
531            }
532
533            fn size_hint(
534                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
535                __size_hint: (usize, ::std::option::Option<usize>),
536            ) {
537                #(
538                    #root::dfir_pipes::push::Push::size_hint(
539                        ::std::pin::Pin::as_mut(#variant_localvars_push),
540                        __size_hint,
541                    );
542                )*
543            }
544        }
545
546        impl #impl_generics_item #root::util::demux_enum::DemuxEnumBase
547            for #item_ident #ty_generics #where_clause_item {}
548
549        #single_impl
550    }
551    .into()
552}
553
554/// (fields pattern, push item expr)
555fn field_pattern_item(fields: &Fields) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
556    let idents = fields
557        .iter()
558        .enumerate()
559        .map(|(i, field)| {
560            field
561                .ident
562                .clone()
563                .unwrap_or_else(|| format_ident!("_{}", i))
564        })
565        .collect::<Vec<_>>();
566    let (fields_pat, push_item) = match fields {
567        Fields::Named(_) => (quote!( { #( #idents, )* } ), quote!( ( #( #idents, )* ) )),
568        Fields::Unnamed(_) => (quote!( ( #( #idents ),* ) ), quote!( ( #( #idents, )* ) )),
569        Fields::Unit => (quote!(), quote!(())),
570    };
571    (fields_pat, push_item)
572}