1#![warn(missing_docs)]
2
3extern crate proc_macro;
4
5use std::collections::{BTreeMap, BTreeSet, VecDeque};
6use std::fmt::Debug;
7use std::iter::FusedIterator;
8
9use itertools::Itertools;
10use proc_macro2::{Ident, Literal, Span, TokenStream};
11use quote::{ToTokens, TokenStreamExt, format_ident, quote, quote_spanned};
12use serde::{Deserialize, Serialize};
13use slotmap::{Key, SecondaryMap, SlotMap, SparseSecondaryMap};
14use syn::spanned::Spanned;
15
16use super::graph_write::{Dot, GraphWrite, Mermaid};
17use super::ops::{
18 DelayType, OPERATORS, OperatorWriteOutput, WriteContextArgs, find_op_op_constraints,
19 null_write_iterator_fn,
20};
21use super::{
22 CONTEXT, Color, DiMulGraph, GRAPH, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId,
23 GraphSubgraphId, HANDOFF_NODE_STR, MODULE_BOUNDARY_NODE_STR, OperatorInstance, PortIndexValue,
24 Varname, change_spans, get_operator_generics,
25};
26use crate::diagnostic::{Diagnostic, Diagnostics, Level};
27use crate::pretty_span::{PrettyRowCol, PrettySpan};
28use crate::process_singletons;
29
/// A DFIR dataflow graph: its nodes and edges, plus the side tables (edge ports, loop
/// nesting, subgraph partitioning, strata, singleton references, variable names) that are
/// populated as the graph is built and partitioned.
#[derive(Default, Debug, Serialize, Deserialize)]
pub struct DfirGraph {
    /// Every node in the graph (operators, handoffs, and module boundaries).
    nodes: SlotMap<GraphNodeId, GraphNode>,

    /// Resolved operator instance for each operator node.
    ///
    /// Skipped by serde, so after deserialization this is empty (its `Default`) until it is
    /// repopulated, e.g. by `DfirGraph::insert_node_op_insts_all`.
    #[serde(skip)]
    operator_instances: SecondaryMap<GraphNodeId, OperatorInstance>,
    /// Optional per-operator tag; when present, code generation uses it instead of a
    /// source-location-derived tag when naming the operator's work functions.
    operator_tag: SecondaryMap<GraphNodeId, String>,
    /// Directed multigraph topology (node ids as vertices, edge ids as edges).
    graph: DiMulGraph<GraphNodeId, GraphEdgeId>,
    /// `(source port, destination port)` for each edge.
    ports: SecondaryMap<GraphEdgeId, (PortIndexValue, PortIndexValue)>,

    /// The loop each node belongs to, for nodes inside a loop.
    node_loops: SecondaryMap<GraphNodeId, GraphLoopId>,
    /// The nodes belonging to each loop (inverse of `node_loops`).
    loop_nodes: SlotMap<GraphLoopId, Vec<GraphNodeId>>,
    /// Parent loop of each nested loop; root loops have no entry.
    loop_parent: SparseSecondaryMap<GraphLoopId, GraphLoopId>,
    /// Loops with no parent; loop code generation starts from these.
    root_loops: Vec<GraphLoopId>,
    /// Child loops nested directly inside each loop.
    loop_children: SecondaryMap<GraphLoopId, Vec<GraphLoopId>>,

    /// The subgraph each node has been partitioned into.
    node_subgraph: SecondaryMap<GraphNodeId, GraphSubgraphId>,

    /// The ordered operator nodes of each subgraph (inverse of `node_subgraph`). The order
    /// matters: code generation splits it into a leading pull half and a trailing push half.
    subgraph_nodes: SlotMap<GraphSubgraphId, Vec<GraphNodeId>>,
    /// Stratum assigned to each subgraph.
    subgraph_stratum: SecondaryMap<GraphSubgraphId, usize>,

    /// For each node, the singleton nodes its arguments reference; `None` entries are
    /// references that have not (yet) been resolved to a node.
    node_singleton_references: SparseSecondaryMap<GraphNodeId, Vec<Option<GraphNodeId>>>,
    /// Variable name bound to each named node.
    node_varnames: SparseSecondaryMap<GraphNodeId, Varname>,

    /// Whether each subgraph is lazy; subgraphs without an entry are treated as not lazy.
    subgraph_laziness: SecondaryMap<GraphSubgraphId, bool>,
}
84
85impl DfirGraph {
87 pub fn new() -> Self {
89 Default::default()
90 }
91}
92
93impl DfirGraph {
95 pub fn node(&self, node_id: GraphNodeId) -> &GraphNode {
97 self.nodes.get(node_id).expect("Node not found.")
98 }
99
100 pub fn node_op_inst(&self, node_id: GraphNodeId) -> Option<&OperatorInstance> {
105 self.operator_instances.get(node_id)
106 }
107
108 pub fn node_varname(&self, node_id: GraphNodeId) -> Option<&Varname> {
110 self.node_varnames.get(node_id)
111 }
112
113 pub fn node_subgraph(&self, node_id: GraphNodeId) -> Option<GraphSubgraphId> {
115 self.node_subgraph.get(node_id).copied()
116 }
117
118 pub fn node_degree_in(&self, node_id: GraphNodeId) -> usize {
120 self.graph.degree_in(node_id)
121 }
122
123 pub fn node_degree_out(&self, node_id: GraphNodeId) -> usize {
125 self.graph.degree_out(node_id)
126 }
127
128 pub fn node_successors(
130 &self,
131 src: GraphNodeId,
132 ) -> impl '_
133 + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
134 + ExactSizeIterator
135 + FusedIterator
136 + Clone
137 + Debug {
138 self.graph.successors(src)
139 }
140
141 pub fn node_predecessors(
143 &self,
144 dst: GraphNodeId,
145 ) -> impl '_
146 + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
147 + ExactSizeIterator
148 + FusedIterator
149 + Clone
150 + Debug {
151 self.graph.predecessors(dst)
152 }
153
154 pub fn node_successor_edges(
156 &self,
157 src: GraphNodeId,
158 ) -> impl '_
159 + DoubleEndedIterator<Item = GraphEdgeId>
160 + ExactSizeIterator
161 + FusedIterator
162 + Clone
163 + Debug {
164 self.graph.successor_edges(src)
165 }
166
167 pub fn node_predecessor_edges(
169 &self,
170 dst: GraphNodeId,
171 ) -> impl '_
172 + DoubleEndedIterator<Item = GraphEdgeId>
173 + ExactSizeIterator
174 + FusedIterator
175 + Clone
176 + Debug {
177 self.graph.predecessor_edges(dst)
178 }
179
180 pub fn node_successor_nodes(
182 &self,
183 src: GraphNodeId,
184 ) -> impl '_
185 + DoubleEndedIterator<Item = GraphNodeId>
186 + ExactSizeIterator
187 + FusedIterator
188 + Clone
189 + Debug {
190 self.graph.successor_vertices(src)
191 }
192
193 pub fn node_predecessor_nodes(
195 &self,
196 dst: GraphNodeId,
197 ) -> impl '_
198 + DoubleEndedIterator<Item = GraphNodeId>
199 + ExactSizeIterator
200 + FusedIterator
201 + Clone
202 + Debug {
203 self.graph.predecessor_vertices(dst)
204 }
205
206 pub fn node_ids(&self) -> slotmap::basic::Keys<'_, GraphNodeId, GraphNode> {
208 self.nodes.keys()
209 }
210
211 pub fn nodes(&self) -> slotmap::basic::Iter<'_, GraphNodeId, GraphNode> {
213 self.nodes.iter()
214 }
215
216 pub fn insert_node(
218 &mut self,
219 node: GraphNode,
220 varname_opt: Option<Ident>,
221 loop_opt: Option<GraphLoopId>,
222 ) -> GraphNodeId {
223 let node_id = self.nodes.insert(node);
224 if let Some(varname) = varname_opt {
225 self.node_varnames.insert(node_id, Varname(varname));
226 }
227 if let Some(loop_id) = loop_opt {
228 self.node_loops.insert(node_id, loop_id);
229 self.loop_nodes[loop_id].push(node_id);
230 }
231 node_id
232 }
233
234 pub fn insert_node_op_inst(&mut self, node_id: GraphNodeId, op_inst: OperatorInstance) {
236 assert!(matches!(
237 self.nodes.get(node_id),
238 Some(GraphNode::Operator(_))
239 ));
240 let old_inst = self.operator_instances.insert(node_id, op_inst);
241 assert!(old_inst.is_none());
242 }
243
244 pub fn insert_node_op_insts_all(&mut self, diagnostics: &mut Diagnostics) {
246 let mut op_insts = Vec::new();
247 for (node_id, node) in self.nodes() {
248 let GraphNode::Operator(operator) = node else {
249 continue;
250 };
251 if self.node_op_inst(node_id).is_some() {
252 continue;
253 };
254
255 let Some(op_constraints) = find_op_op_constraints(operator) else {
257 diagnostics.push(Diagnostic::spanned(
258 operator.path.span(),
259 Level::Error,
260 format!("Unknown operator `{}`", operator.name_string()),
261 ));
262 continue;
263 };
264
265 let (input_ports, output_ports) = {
267 let mut input_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
268 .node_predecessors(node_id)
269 .map(|(edge_id, pred_id)| (self.edge_ports(edge_id).1, pred_id))
270 .collect();
271 input_edges.sort();
273 let input_ports: Vec<PortIndexValue> = input_edges
274 .into_iter()
275 .map(|(port, _pred)| port)
276 .cloned()
277 .collect();
278
279 let mut output_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
281 .node_successors(node_id)
282 .map(|(edge_id, succ)| (self.edge_ports(edge_id).0, succ))
283 .collect();
284 output_edges.sort();
286 let output_ports: Vec<PortIndexValue> = output_edges
287 .into_iter()
288 .map(|(port, _succ)| port)
289 .cloned()
290 .collect();
291
292 (input_ports, output_ports)
293 };
294
295 let generics = get_operator_generics(diagnostics, operator);
297 {
299 let generics_span = generics
301 .generic_args
302 .as_ref()
303 .map(Spanned::span)
304 .unwrap_or_else(|| operator.path.span());
305
306 if !op_constraints
307 .persistence_args
308 .contains(&generics.persistence_args.len())
309 {
310 diagnostics.push(Diagnostic::spanned(
311 generics.persistence_args_span().unwrap_or(generics_span),
312 Level::Error,
313 format!(
314 "`{}` should have {} persistence lifetime arguments, actually has {}.",
315 op_constraints.name,
316 op_constraints.persistence_args.human_string(),
317 generics.persistence_args.len()
318 ),
319 ));
320 }
321 if !op_constraints.type_args.contains(&generics.type_args.len()) {
322 diagnostics.push(Diagnostic::spanned(
323 generics.type_args_span().unwrap_or(generics_span),
324 Level::Error,
325 format!(
326 "`{}` should have {} generic type arguments, actually has {}.",
327 op_constraints.name,
328 op_constraints.type_args.human_string(),
329 generics.type_args.len()
330 ),
331 ));
332 }
333 }
334
335 op_insts.push((
336 node_id,
337 OperatorInstance {
338 op_constraints,
339 input_ports,
340 output_ports,
341 singletons_referenced: operator.singletons_referenced.clone(),
342 generics,
343 arguments_pre: operator.args.clone(),
344 arguments_raw: operator.args_raw.clone(),
345 },
346 ));
347 }
348
349 for (node_id, op_inst) in op_insts {
350 self.insert_node_op_inst(node_id, op_inst);
351 }
352 }
353
354 pub fn insert_intermediate_node(
366 &mut self,
367 edge_id: GraphEdgeId,
368 new_node: GraphNode,
369 ) -> (GraphNodeId, GraphEdgeId) {
370 let span = Some(new_node.span());
371
372 let op_inst_opt = 'oc: {
374 let GraphNode::Operator(operator) = &new_node else {
375 break 'oc None;
376 };
377 let Some(op_constraints) = find_op_op_constraints(operator) else {
378 break 'oc None;
379 };
380 let (input_port, output_port) = self.ports.get(edge_id).cloned().unwrap();
381
382 let mut dummy_diagnostics = Diagnostics::new();
383 let generics = get_operator_generics(&mut dummy_diagnostics, operator);
384 assert!(dummy_diagnostics.is_empty());
385
386 Some(OperatorInstance {
387 op_constraints,
388 input_ports: vec![input_port],
389 output_ports: vec![output_port],
390 singletons_referenced: operator.singletons_referenced.clone(),
391 generics,
392 arguments_pre: operator.args.clone(),
393 arguments_raw: operator.args_raw.clone(),
394 })
395 };
396
397 let node_id = self.nodes.insert(new_node);
399 if let Some(op_inst) = op_inst_opt {
401 self.operator_instances.insert(node_id, op_inst);
402 }
403 let (e0, e1) = self
405 .graph
406 .insert_intermediate_vertex(node_id, edge_id)
407 .unwrap();
408
409 let (src_idx, dst_idx) = self.ports.remove(edge_id).unwrap();
411 self.ports
412 .insert(e0, (src_idx, PortIndexValue::Elided(span)));
413 self.ports
414 .insert(e1, (PortIndexValue::Elided(span), dst_idx));
415
416 (node_id, e1)
417 }
418
419 pub fn remove_intermediate_node(&mut self, node_id: GraphNodeId) {
422 assert_eq!(
423 1,
424 self.node_degree_in(node_id),
425 "Removed intermediate node must have one predecessor"
426 );
427 assert_eq!(
428 1,
429 self.node_degree_out(node_id),
430 "Removed intermediate node must have one successor"
431 );
432 assert!(
433 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
434 "Should not remove intermediate node after subgraph partitioning"
435 );
436
437 assert!(self.nodes.remove(node_id).is_some());
438 let (new_edge_id, (pred_edge_id, succ_edge_id)) =
439 self.graph.remove_intermediate_vertex(node_id).unwrap();
440 self.operator_instances.remove(node_id);
441 self.node_varnames.remove(node_id);
442
443 let (src_port, _) = self.ports.remove(pred_edge_id).unwrap();
444 let (_, dst_port) = self.ports.remove(succ_edge_id).unwrap();
445 self.ports.insert(new_edge_id, (src_port, dst_port));
446 }
447
448 pub(crate) fn node_color(&self, node_id: GraphNodeId) -> Option<Color> {
454 if matches!(self.node(node_id), GraphNode::Handoff { .. }) {
455 return Some(Color::Hoff);
456 }
457
458 if let GraphNode::Operator(op) = self.node(node_id)
460 && (op.name_string() == "resolve_futures_blocking"
461 || op.name_string() == "resolve_futures_blocking_ordered")
462 {
463 return Some(Color::Push);
464 }
465
466 let inn_degree = self.node_predecessor_nodes(node_id).count();
468 let out_degree = self.node_successor_nodes(node_id).count();
470
471 match (inn_degree, out_degree) {
472 (0, 0) => None, (0, 1) => Some(Color::Pull),
474 (1, 0) => Some(Color::Push),
475 (1, 1) => None, (_many, 0 | 1) => Some(Color::Pull),
477 (0 | 1, _many) => Some(Color::Push),
478 (_many, _to_many) => Some(Color::Comp),
479 }
480 }
481
482 pub fn set_operator_tag(&mut self, node_id: GraphNodeId, tag: String) {
484 self.operator_tag.insert(node_id, tag);
485 }
486}
487
488impl DfirGraph {
490 pub fn set_node_singleton_references(
493 &mut self,
494 node_id: GraphNodeId,
495 singletons_referenced: Vec<Option<GraphNodeId>>,
496 ) -> Option<Vec<Option<GraphNodeId>>> {
497 self.node_singleton_references
498 .insert(node_id, singletons_referenced)
499 }
500
501 pub fn node_singleton_references(&self, node_id: GraphNodeId) -> &[Option<GraphNodeId>] {
504 self.node_singleton_references
505 .get(node_id)
506 .map(std::ops::Deref::deref)
507 .unwrap_or_default()
508 }
509}
510
511impl DfirGraph {
513 pub fn merge_modules(&mut self) -> Result<(), Diagnostic> {
521 let mod_bound_nodes = self
522 .nodes()
523 .filter(|(_nid, node)| matches!(node, GraphNode::ModuleBoundary { .. }))
524 .map(|(nid, _node)| nid)
525 .collect::<Vec<_>>();
526
527 for mod_bound_node in mod_bound_nodes {
528 self.remove_module_boundary(mod_bound_node)?;
529 }
530
531 Ok(())
532 }
533
534 fn remove_module_boundary(&mut self, mod_bound_node: GraphNodeId) -> Result<(), Diagnostic> {
538 assert!(
539 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
540 "Should not remove intermediate node after subgraph partitioning"
541 );
542
543 let mut mod_pred_ports = BTreeMap::new();
544 let mut mod_succ_ports = BTreeMap::new();
545
546 for mod_out_edge in self.node_predecessor_edges(mod_bound_node) {
547 let (pred_port, succ_port) = self.edge_ports(mod_out_edge);
548 mod_pred_ports.insert(succ_port.clone(), (mod_out_edge, pred_port.clone()));
549 }
550
551 for mod_inn_edge in self.node_successor_edges(mod_bound_node) {
552 let (pred_port, succ_port) = self.edge_ports(mod_inn_edge);
553 mod_succ_ports.insert(pred_port.clone(), (mod_inn_edge, succ_port.clone()));
554 }
555
556 if mod_pred_ports.keys().collect::<BTreeSet<_>>()
557 != mod_succ_ports.keys().collect::<BTreeSet<_>>()
558 {
559 let GraphNode::ModuleBoundary { input, import_expr } = self.node(mod_bound_node) else {
561 panic!();
562 };
563
564 if *input {
565 return Err(Diagnostic {
566 span: *import_expr,
567 level: Level::Error,
568 message: format!(
569 "The ports into the module did not match. input: {:?}, expected: {:?}",
570 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
571 mod_succ_ports.keys().map(|x| x.to_string()).join(", ")
572 ),
573 });
574 } else {
575 return Err(Diagnostic {
576 span: *import_expr,
577 level: Level::Error,
578 message: format!(
579 "The ports out of the module did not match. output: {:?}, expected: {:?}",
580 mod_succ_ports.keys().map(|x| x.to_string()).join(", "),
581 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
582 ),
583 });
584 }
585 }
586
587 for (port, (pred_edge, pred_port)) in mod_pred_ports {
588 let (succ_edge, succ_port) = mod_succ_ports.remove(&port).unwrap();
589
590 let (src, _) = self.edge(pred_edge);
591 let (_, dst) = self.edge(succ_edge);
592 self.remove_edge(pred_edge);
593 self.remove_edge(succ_edge);
594
595 let new_edge_id = self.graph.insert_edge(src, dst);
596 self.ports.insert(new_edge_id, (pred_port, succ_port));
597 }
598
599 self.graph.remove_vertex(mod_bound_node);
600 self.nodes.remove(mod_bound_node);
601
602 Ok(())
603 }
604}
605
606impl DfirGraph {
608 pub fn edge(&self, edge_id: GraphEdgeId) -> (GraphNodeId, GraphNodeId) {
610 let (src, dst) = self.graph.edge(edge_id).expect("Edge not found.");
611 (src, dst)
612 }
613
614 pub fn edge_ports(&self, edge_id: GraphEdgeId) -> (&PortIndexValue, &PortIndexValue) {
616 let (src_port, dst_port) = self.ports.get(edge_id).expect("Edge not found.");
617 (src_port, dst_port)
618 }
619
620 pub fn edge_ids(&self) -> slotmap::basic::Keys<'_, GraphEdgeId, (GraphNodeId, GraphNodeId)> {
622 self.graph.edge_ids()
623 }
624
625 pub fn edges(
627 &self,
628 ) -> impl '_
629 + ExactSizeIterator<Item = (GraphEdgeId, (GraphNodeId, GraphNodeId))>
630 + FusedIterator
631 + Clone
632 + Debug {
633 self.graph.edges()
634 }
635
636 pub fn insert_edge(
638 &mut self,
639 src: GraphNodeId,
640 src_port: PortIndexValue,
641 dst: GraphNodeId,
642 dst_port: PortIndexValue,
643 ) -> GraphEdgeId {
644 let edge_id = self.graph.insert_edge(src, dst);
645 self.ports.insert(edge_id, (src_port, dst_port));
646 edge_id
647 }
648
649 pub fn remove_edge(&mut self, edge: GraphEdgeId) {
651 let (_src, _dst) = self.graph.remove_edge(edge).unwrap();
652 let (_src_port, _dst_port) = self.ports.remove(edge).unwrap();
653 }
654}
655
656impl DfirGraph {
658 pub fn subgraph(&self, subgraph_id: GraphSubgraphId) -> &Vec<GraphNodeId> {
660 self.subgraph_nodes
661 .get(subgraph_id)
662 .expect("Subgraph not found.")
663 }
664
665 pub fn subgraph_ids(&self) -> slotmap::basic::Keys<'_, GraphSubgraphId, Vec<GraphNodeId>> {
667 self.subgraph_nodes.keys()
668 }
669
670 pub fn subgraphs(&self) -> slotmap::basic::Iter<'_, GraphSubgraphId, Vec<GraphNodeId>> {
672 self.subgraph_nodes.iter()
673 }
674
675 pub fn insert_subgraph(
677 &mut self,
678 node_ids: Vec<GraphNodeId>,
679 ) -> Result<GraphSubgraphId, (GraphNodeId, GraphSubgraphId)> {
680 for &node_id in node_ids.iter() {
682 if let Some(&old_sg_id) = self.node_subgraph.get(node_id) {
683 return Err((node_id, old_sg_id));
684 }
685 }
686 let subgraph_id = self.subgraph_nodes.insert_with_key(|sg_id| {
687 for &node_id in node_ids.iter() {
688 self.node_subgraph.insert(node_id, sg_id);
689 }
690 node_ids
691 });
692
693 Ok(subgraph_id)
694 }
695
696 pub fn remove_from_subgraph(&mut self, node_id: GraphNodeId) -> bool {
698 if let Some(old_sg_id) = self.node_subgraph.remove(node_id) {
699 self.subgraph_nodes[old_sg_id].retain(|&other_node_id| other_node_id != node_id);
700 true
701 } else {
702 false
703 }
704 }
705
706 pub fn subgraph_stratum(&self, sg_id: GraphSubgraphId) -> Option<usize> {
708 self.subgraph_stratum.get(sg_id).copied()
709 }
710
711 pub fn set_subgraph_stratum(
713 &mut self,
714 sg_id: GraphSubgraphId,
715 stratum: usize,
716 ) -> Option<usize> {
717 self.subgraph_stratum.insert(sg_id, stratum)
718 }
719
720 fn subgraph_laziness(&self, sg_id: GraphSubgraphId) -> bool {
722 self.subgraph_laziness.get(sg_id).copied().unwrap_or(false)
723 }
724
725 pub fn set_subgraph_laziness(&mut self, sg_id: GraphSubgraphId, lazy: bool) -> bool {
727 self.subgraph_laziness.insert(sg_id, lazy).unwrap_or(false)
728 }
729
730 pub fn max_stratum(&self) -> Option<usize> {
732 self.subgraph_stratum.values().copied().max()
733 }
734
735 fn find_pull_to_push_idx(&self, subgraph_nodes: &[GraphNodeId]) -> usize {
737 subgraph_nodes
738 .iter()
739 .position(|&node_id| {
740 self.node_color(node_id)
741 .is_some_and(|color| Color::Pull != color)
742 })
743 .unwrap_or(subgraph_nodes.len())
744 }
745}
746
747impl DfirGraph {
749 fn node_as_ident(&self, node_id: GraphNodeId, is_pred: bool) -> Ident {
751 let name = match &self.nodes[node_id] {
752 GraphNode::Operator(_) => format!("op_{:?}", node_id.data()),
753 GraphNode::Handoff { .. } => format!(
754 "hoff_{:?}_{}",
755 node_id.data(),
756 if is_pred { "recv" } else { "send" }
757 ),
758 GraphNode::ModuleBoundary { .. } => panic!(),
759 };
760 let span = match (is_pred, &self.nodes[node_id]) {
761 (_, GraphNode::Operator(operator)) => operator.span(),
762 (true, &GraphNode::Handoff { src_span, .. }) => src_span,
763 (false, &GraphNode::Handoff { dst_span, .. }) => dst_span,
764 (_, GraphNode::ModuleBoundary { .. }) => panic!(),
765 };
766 Ident::new(&name, span)
767 }
768
769 fn node_as_singleton_ident(&self, node_id: GraphNodeId, span: Span) -> Ident {
771 Ident::new(&format!("singleton_op_{:?}", node_id.data()), span)
772 }
773
774 fn helper_resolve_singletons(&self, node_id: GraphNodeId, span: Span) -> Vec<Ident> {
776 self.node_singleton_references(node_id)
777 .iter()
778 .map(|singleton_node_id| {
779 self.node_as_singleton_ident(
781 singleton_node_id
782 .expect("Expected singleton to be resolved but was not, this is a bug."),
783 span,
784 )
785 })
786 .collect::<Vec<_>>()
787 }
788
789 fn helper_collect_subgraph_handoffs(
792 &self,
793 ) -> SecondaryMap<GraphSubgraphId, (Vec<GraphNodeId>, Vec<GraphNodeId>)> {
794 let mut subgraph_handoffs: SecondaryMap<
796 GraphSubgraphId,
797 (Vec<GraphNodeId>, Vec<GraphNodeId>),
798 > = self
799 .subgraph_nodes
800 .keys()
801 .map(|k| (k, Default::default()))
802 .collect();
803
804 for (hoff_id, node) in self.nodes() {
806 if !matches!(node, GraphNode::Handoff { .. }) {
807 continue;
808 }
809 for (_edge, succ_id) in self.node_successors(hoff_id) {
811 let succ_sg = self.node_subgraph(succ_id).unwrap();
812 subgraph_handoffs[succ_sg].0.push(hoff_id);
813 }
814 for (_edge, pred_id) in self.node_predecessors(hoff_id) {
816 let pred_sg = self.node_subgraph(pred_id).unwrap();
817 subgraph_handoffs[pred_sg].1.push(hoff_id);
818 }
819 }
820
821 subgraph_handoffs
822 }
823
824 fn codegen_nested_loops(&self, df: &Ident) -> TokenStream {
826 let mut out = TokenStream::new();
828 let mut queue = VecDeque::from_iter(self.root_loops.iter().copied());
829 while let Some(loop_id) = queue.pop_front() {
830 let parent_opt = self
831 .loop_parent(loop_id)
832 .map(|loop_id| loop_id.as_ident(Span::call_site()))
833 .map(|ident| quote! { Some(#ident) })
834 .unwrap_or_else(|| quote! { None });
835 let loop_name = loop_id.as_ident(Span::call_site());
836 out.append_all(quote! {
837 let #loop_name = #df.add_loop(#parent_opt);
838 });
839 queue.extend(self.loop_children.get(loop_id).into_iter().flatten());
840 }
841 out
842 }
843
844 pub fn as_code(
848 &self,
849 root: &TokenStream,
850 include_type_guards: bool,
851 prefix: TokenStream,
852 diagnostics: &mut Diagnostics,
853 ) -> Result<TokenStream, Diagnostics> {
854 let df = Ident::new(GRAPH, Span::call_site());
855 let context = Ident::new(CONTEXT, Span::call_site());
856
857 let handoff_code = self
859 .nodes
860 .iter()
861 .filter_map(|(node_id, node)| match node {
862 GraphNode::Operator(_) => None,
863 &GraphNode::Handoff { src_span, dst_span } => Some((node_id, (src_span, dst_span))),
864 GraphNode::ModuleBoundary { .. } => panic!(),
865 })
866 .map(|(node_id, (src_span, dst_span))| {
867 let ident_send = Ident::new(&format!("hoff_{:?}_send", node_id.data()), dst_span);
868 let ident_recv = Ident::new(&format!("hoff_{:?}_recv", node_id.data()), src_span);
869 let span = src_span.join(dst_span).unwrap_or(src_span);
870 let mut hoff_name = Literal::string(&format!("handoff {:?}", node_id));
871 hoff_name.set_span(span);
872 let hoff_type = quote_spanned! (span=> #root::scheduled::handoff::VecHandoff<_>);
873 quote_spanned! {span=>
874 let (#ident_send, #ident_recv) =
875 #df.make_edge::<_, #hoff_type>(#hoff_name);
876 }
877 });
878
879 let subgraph_handoffs = self.helper_collect_subgraph_handoffs();
880
881 let (subgraphs_without_preds, subgraphs_with_preds) = self
883 .subgraph_nodes
884 .iter()
885 .partition::<Vec<_>, _>(|(_, nodes)| {
886 nodes
887 .iter()
888 .any(|&node_id| self.node_degree_in(node_id) == 0)
889 });
890
891 let mut op_prologue_code = Vec::new();
892 let mut op_prologue_after_code = Vec::new();
893 let mut subgraphs = Vec::new();
894 {
895 for &(subgraph_id, subgraph_nodes) in subgraphs_without_preds
896 .iter()
897 .chain(subgraphs_with_preds.iter())
898 {
899 let (recv_hoffs, send_hoffs) = &subgraph_handoffs[subgraph_id];
900 let recv_ports: Vec<Ident> = recv_hoffs
901 .iter()
902 .map(|&hoff_id| self.node_as_ident(hoff_id, true))
903 .collect();
904 let send_ports: Vec<Ident> = send_hoffs
905 .iter()
906 .map(|&hoff_id| self.node_as_ident(hoff_id, false))
907 .collect();
908
909 let recv_port_code = recv_ports.iter().map(|ident| {
910 quote_spanned! {ident.span()=>
911 let mut #ident = #ident.borrow_mut_swap();
912 let #ident = #root::dfir_pipes::pull::iter(#ident.drain(..));
913 }
914 });
915 let send_port_code = send_ports.iter().map(|ident| {
916 quote_spanned! {ident.span()=>
917 let mut #ident = #ident.borrow_mut_give();
918 let #ident = #root::dfir_pipes::push::vec_push(&mut *#ident);
919 }
920 });
921
922 let loop_id = self
923 .node_loop(subgraph_nodes[0]);
925
926 let mut subgraph_op_iter_code = Vec::new();
927 let mut subgraph_op_iter_after_code = Vec::new();
928 {
929 let pull_to_push_idx = self.find_pull_to_push_idx(subgraph_nodes);
930
931 let (pull_half, push_half) = subgraph_nodes.split_at(pull_to_push_idx);
932 let nodes_iter = pull_half.iter().chain(push_half.iter().rev());
933
934 for (idx, &node_id) in nodes_iter.enumerate() {
935 let node = &self.nodes[node_id];
936 assert!(
937 matches!(node, GraphNode::Operator(_)),
938 "Handoffs are not part of subgraphs."
939 );
940 let op_inst = &self.operator_instances[node_id];
941
942 let op_span = node.span();
943 let op_name = op_inst.op_constraints.name;
944 let root = change_spans(root.clone(), op_span);
946 let op_constraints = OPERATORS
948 .iter()
949 .find(|op| op_name == op.name)
950 .unwrap_or_else(|| panic!("Failed to find op: {}", op_name));
951
952 let ident = self.node_as_ident(node_id, false);
953
954 {
955 let mut input_edges = self
958 .graph
959 .predecessor_edges(node_id)
960 .map(|edge_id| (self.edge_ports(edge_id).1, edge_id))
961 .collect::<Vec<_>>();
962 input_edges.sort();
964
965 let inputs = input_edges
966 .iter()
967 .map(|&(_port, edge_id)| {
968 let (pred, _) = self.edge(edge_id);
969 self.node_as_ident(pred, true)
970 })
971 .collect::<Vec<_>>();
972
973 let mut output_edges = self
975 .graph
976 .successor_edges(node_id)
977 .map(|edge_id| (&self.ports[edge_id].0, edge_id))
978 .collect::<Vec<_>>();
979 output_edges.sort();
981
982 let outputs = output_edges
983 .iter()
984 .map(|&(_port, edge_id)| {
985 let (_, succ) = self.edge(edge_id);
986 self.node_as_ident(succ, false)
987 })
988 .collect::<Vec<_>>();
989
990 let is_pull = idx < pull_to_push_idx;
991
992 let singleton_output_ident = &if op_constraints.has_singleton_output {
993 self.node_as_singleton_ident(node_id, op_span)
994 } else {
995 Ident::new(&format!("{}_has_no_singleton_output", op_name), op_span)
997 };
998
999 let df_local = &Ident::new(GRAPH, op_span.resolved_at(df.span()));
1008 let context = &Ident::new(CONTEXT, op_span.resolved_at(context.span()));
1009
1010 let singletons_resolved =
1011 self.helper_resolve_singletons(node_id, op_span);
1012 let arguments = &process_singletons::postprocess_singletons(
1013 op_inst.arguments_raw.clone(),
1014 singletons_resolved.clone(),
1015 context,
1016 );
1017 let arguments_handles =
1018 &process_singletons::postprocess_singletons_handles(
1019 op_inst.arguments_raw.clone(),
1020 singletons_resolved.clone(),
1021 );
1022
1023 let source_tag = 'a: {
1024 if let Some(tag) = self.operator_tag.get(node_id).cloned() {
1025 break 'a tag;
1026 }
1027
1028 #[cfg(nightly)]
1029 if proc_macro::is_available() {
1030 let op_span = op_span.unwrap();
1031 break 'a format!(
1032 "loc_{}_{}_{}_{}_{}",
1033 crate::pretty_span::make_source_path_relative(
1034 &op_span.file()
1035 )
1036 .display()
1037 .to_string()
1038 .replace(|x: char| !x.is_ascii_alphanumeric(), "_"),
1039 op_span.start().line(),
1040 op_span.start().column(),
1041 op_span.end().line(),
1042 op_span.end().column(),
1043 );
1044 }
1045
1046 format!(
1047 "loc_nopath_{}_{}_{}_{}",
1048 op_span.start().line,
1049 op_span.start().column,
1050 op_span.end().line,
1051 op_span.end().column
1052 )
1053 };
1054
1055 let work_fn = format_ident!(
1056 "{}__{}__{}",
1057 ident,
1058 op_name,
1059 source_tag,
1060 span = op_span
1061 );
1062 let work_fn_async = format_ident!("{}__async", work_fn, span = op_span);
1063
1064 let context_args = WriteContextArgs {
1065 root: &root,
1066 df_ident: df_local,
1067 context,
1068 subgraph_id,
1069 node_id,
1070 loop_id,
1071 op_span,
1072 op_tag: self.operator_tag.get(node_id).cloned(),
1073 work_fn: &work_fn,
1074 work_fn_async: &work_fn_async,
1075 ident: &ident,
1076 is_pull,
1077 inputs: &inputs,
1078 outputs: &outputs,
1079 singleton_output_ident,
1080 op_name,
1081 op_inst,
1082 arguments,
1083 arguments_handles,
1084 };
1085
1086 let write_result =
1087 (op_constraints.write_fn)(&context_args, diagnostics);
1088 let OperatorWriteOutput {
1089 write_prologue,
1090 write_prologue_after,
1091 write_iterator,
1092 write_iterator_after,
1093 } = write_result.unwrap_or_else(|()| {
1094 assert!(
1095 diagnostics.has_error(),
1096 "Operator `{}` returned `Err` but emitted no diagnostics, this is a bug.",
1097 op_name,
1098 );
1099 OperatorWriteOutput { write_iterator: null_write_iterator_fn(&context_args), ..Default::default() }
1100 });
1101
1102 op_prologue_code.push(syn::parse_quote! {
1103 #[allow(non_snake_case)]
1104 #[inline(always)]
1105 fn #work_fn<T>(thunk: impl ::std::ops::FnOnce() -> T) -> T {
1106 thunk()
1107 }
1108
1109 #[allow(non_snake_case)]
1110 #[inline(always)]
1111 async fn #work_fn_async<T>(thunk: impl ::std::future::Future<Output = T>) -> T {
1112 thunk.await
1113 }
1114 });
1115 op_prologue_code.push(write_prologue);
1116 op_prologue_after_code.push(write_prologue_after);
1117 subgraph_op_iter_code.push(write_iterator);
1118
1119 if include_type_guards {
1120 let type_guard = if is_pull {
1121 quote_spanned! {op_span=>
1122 let #ident = {
1123 #[allow(non_snake_case)]
1124 #[inline(always)]
1125 pub fn #work_fn<Item, Input>(input: Input)
1126 -> impl #root::dfir_pipes::pull::Pull<Item = Item, Meta = (), CanPend = Input::CanPend, CanEnd = Input::CanEnd>
1127 where
1128 Input: #root::dfir_pipes::pull::Pull<Item = Item, Meta = ()>,
1129 {
1130 #root::pin_project_lite::pin_project! {
1131 #[repr(transparent)]
1132 struct Pull<Item, Input: #root::dfir_pipes::pull::Pull<Item = Item>> {
1133 #[pin]
1134 inner: Input
1135 }
1136 }
1137
1138 impl<Item, Input> #root::dfir_pipes::pull::Pull for Pull<Item, Input>
1139 where
1140 Input: #root::dfir_pipes::pull::Pull<Item = Item>,
1141 {
1142 type Ctx<'ctx> = Input::Ctx<'ctx>;
1143
1144 type Item = Item;
1145 type Meta = Input::Meta;
1146 type CanPend = Input::CanPend;
1147 type CanEnd = Input::CanEnd;
1148
1149 #[inline(always)]
1150 fn pull(
1151 self: ::std::pin::Pin<&mut Self>,
1152 ctx: &mut Self::Ctx<'_>,
1153 ) -> #root::dfir_pipes::pull::PullStep<Self::Item, Self::Meta, Self::CanPend, Self::CanEnd> {
1154 #root::dfir_pipes::pull::Pull::pull(self.project().inner, ctx)
1155 }
1156
1157 #[inline(always)]
1158 fn size_hint(&self) -> (usize, Option<usize>) {
1159 #root::dfir_pipes::pull::Pull::size_hint(&self.inner)
1160 }
1161 }
1162
1163 Pull {
1164 inner: input
1165 }
1166 }
1167 #work_fn::<_, _>( #ident )
1168 };
1169 }
1170 } else {
1171 quote_spanned! {op_span=>
1172 let #ident = {
1173 #[allow(non_snake_case)]
1174 #[inline(always)]
1175 pub fn #work_fn<Item, Psh>(psh: Psh) -> impl #root::dfir_pipes::push::Push<Item, (), CanPend = Psh::CanPend>
1176 where
1177 Psh: #root::dfir_pipes::push::Push<Item, ()>
1178 {
1179 #root::pin_project_lite::pin_project! {
1180 #[repr(transparent)]
1181 struct PushGuard<Psh> {
1182 #[pin]
1183 inner: Psh,
1184 }
1185 }
1186
1187 impl<Item, Psh> #root::dfir_pipes::push::Push<Item, ()> for PushGuard<Psh>
1188 where
1189 Psh: #root::dfir_pipes::push::Push<Item, ()>,
1190 {
1191 type Ctx<'ctx> = Psh::Ctx<'ctx>;
1192
1193 type CanPend = Psh::CanPend;
1194
1195 #[inline(always)]
1196 fn poll_ready(
1197 self: ::std::pin::Pin<&mut Self>,
1198 ctx: &mut Self::Ctx<'_>,
1199 ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1200 #root::dfir_pipes::push::Push::poll_ready(self.project().inner, ctx)
1201 }
1202
1203 #[inline(always)]
1204 fn start_send(
1205 self: ::std::pin::Pin<&mut Self>,
1206 item: Item,
1207 meta: (),
1208 ) {
1209 #root::dfir_pipes::push::Push::start_send(self.project().inner, item, meta)
1210 }
1211
1212 #[inline(always)]
1213 fn poll_flush(
1214 self: ::std::pin::Pin<&mut Self>,
1215 ctx: &mut Self::Ctx<'_>,
1216 ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1217 #root::dfir_pipes::push::Push::poll_flush(self.project().inner, ctx)
1218 }
1219
1220 #[inline(always)]
1221 fn size_hint(
1222 self: ::std::pin::Pin<&mut Self>,
1223 hint: (usize, Option<usize>),
1224 ) {
1225 #root::dfir_pipes::push::Push::size_hint(self.project().inner, hint)
1226 }
1227 }
1228
1229 PushGuard {
1230 inner: psh
1231 }
1232 }
1233 #work_fn( #ident )
1234 };
1235 }
1236 };
1237 subgraph_op_iter_code.push(type_guard);
1238 }
1239 subgraph_op_iter_after_code.push(write_iterator_after);
1240 }
1241 }
1242
1243 {
1244 let pull_ident = if 0 < pull_to_push_idx {
1246 self.node_as_ident(subgraph_nodes[pull_to_push_idx - 1], false)
1247 } else {
1248 recv_ports[0].clone()
1250 };
1251
1252 #[rustfmt::skip]
1253 let push_ident = if let Some(&node_id) =
1254 subgraph_nodes.get(pull_to_push_idx)
1255 {
1256 self.node_as_ident(node_id, false)
1257 } else if 1 == send_ports.len() {
1258 send_ports[0].clone()
1260 } else {
1261 diagnostics.push(Diagnostic::spanned(
1262 pull_ident.span(),
1263 Level::Error,
1264 "Degenerate subgraph detected, is there a disconnected `null()` or other degenerate pipeline somewhere?",
1265 ));
1266 continue;
1267 };
1268
1269 let pivot_span = pull_ident
1271 .span()
1272 .join(push_ident.span())
1273 .unwrap_or_else(|| push_ident.span());
1274 let pivot_fn_ident =
1275 Ident::new(&format!("pivot_run_sg_{:?}", subgraph_id.0), pivot_span);
1276 let root = change_spans(root.clone(), pivot_span);
1277 subgraph_op_iter_code.push(quote_spanned! {pivot_span=>
1278 #[inline(always)]
1279 fn #pivot_fn_ident<Pul, Psh, Item>(pull: Pul, push: Psh)
1280 -> impl ::std::future::Future<Output = ()>
1281 where
1282 Pul: #root::dfir_pipes::pull::Pull<Item = Item>,
1283 Psh: #root::dfir_pipes::push::Push<Item, Pul::Meta>,
1284 {
1285 #root::dfir_pipes::pull::Pull::send_push(pull, push)
1286 }
1287 (#pivot_fn_ident)(#pull_ident, #push_ident).await;
1288 });
1289 }
1290 };
1291
1292 let subgraph_name = Literal::string(&format!("Subgraph {:?}", subgraph_id));
1293 let stratum = Literal::usize_unsuffixed(
1294 self.subgraph_stratum.get(subgraph_id).cloned().unwrap_or(0),
1295 );
1296 let laziness = self.subgraph_laziness(subgraph_id);
1297
1298 let loop_id_opt = loop_id
1300 .map(|loop_id| loop_id.as_ident(Span::call_site()))
1301 .map(|ident| quote! { Some(#ident) })
1302 .unwrap_or_else(|| quote! { None });
1303
1304 let sg_ident = subgraph_id.as_ident(Span::call_site());
1305
1306 subgraphs.push(quote! {
1307 let #sg_ident = #df.add_subgraph_full(
1308 #subgraph_name,
1309 #stratum,
1310 var_expr!( #( #recv_ports ),* ),
1311 var_expr!( #( #send_ports ),* ),
1312 #laziness,
1313 #loop_id_opt,
1314 async move |#context, var_args!( #( #recv_ports ),* ), var_args!( #( #send_ports ),* )| {
1315 #( #recv_port_code )*
1316 #( #send_port_code )*
1317 #( #subgraph_op_iter_code )*
1318 #( #subgraph_op_iter_after_code )*
1319 },
1320 );
1321 });
1322 }
1323 }
1324
1325 if diagnostics.has_error() {
1326 return Err(std::mem::take(diagnostics));
1327 }
1328 let _ = diagnostics; let loop_code = self.codegen_nested_loops(&df);
1331
1332 let code = quote! {
1337 #( #handoff_code )*
1338 #loop_code
1339 #( #op_prologue_code )*
1340 #( #subgraphs )*
1341 #( #op_prologue_after_code )*
1342 };
1343
1344 let meta_graph_json = serde_json::to_string(&self).unwrap();
1345 let meta_graph_json = Literal::string(&meta_graph_json);
1346
1347 let serde_diagnostics: Vec<_> = diagnostics.iter().map(Diagnostic::to_serde).collect();
1348 let diagnostics_json = serde_json::to_string(&*serde_diagnostics).unwrap();
1349 let diagnostics_json = Literal::string(&diagnostics_json);
1350
1351 Ok(quote! {
1352 {
1353 #[allow(unused_qualifications, clippy::await_holding_refcell_ref)]
1354 {
1355 #prefix
1356
1357 use #root::{var_expr, var_args};
1358
1359 let mut #df = #root::scheduled::graph::Dfir::new();
1360 #df.__assign_meta_graph(#meta_graph_json);
1361 #df.__assign_diagnostics(#diagnostics_json);
1362
1363 #code
1364
1365 #df
1366 }
1367 }
1368 })
1369 }
1370
1371 pub fn node_color_map(&self) -> SparseSecondaryMap<GraphNodeId, Color> {
1374 let mut node_color_map: SparseSecondaryMap<GraphNodeId, Color> = self
1375 .node_ids()
1376 .filter_map(|node_id| {
1377 let op_color = self.node_color(node_id)?;
1378 Some((node_id, op_color))
1379 })
1380 .collect();
1381
1382 for sg_nodes in self.subgraph_nodes.values() {
1384 let pull_to_push_idx = self.find_pull_to_push_idx(sg_nodes);
1385
1386 for (idx, node_id) in sg_nodes.iter().copied().enumerate() {
1387 let is_pull = idx < pull_to_push_idx;
1388 node_color_map.insert(node_id, if is_pull { Color::Pull } else { Color::Push });
1389 }
1390 }
1391
1392 node_color_map
1393 }
1394
1395 pub fn to_mermaid(&self, write_config: &WriteConfig) -> String {
1397 let mut output = String::new();
1398 self.write_mermaid(&mut output, write_config).unwrap();
1399 output
1400 }
1401
1402 pub fn write_mermaid(
1404 &self,
1405 output: impl std::fmt::Write,
1406 write_config: &WriteConfig,
1407 ) -> std::fmt::Result {
1408 let mut graph_write = Mermaid::new(output);
1409 self.write_graph(&mut graph_write, write_config)
1410 }
1411
1412 pub fn to_dot(&self, write_config: &WriteConfig) -> String {
1414 let mut output = String::new();
1415 let mut graph_write = Dot::new(&mut output);
1416 self.write_graph(&mut graph_write, write_config).unwrap();
1417 output
1418 }
1419
1420 pub fn write_dot(
1422 &self,
1423 output: impl std::fmt::Write,
1424 write_config: &WriteConfig,
1425 ) -> std::fmt::Result {
1426 let mut graph_write = Dot::new(output);
1427 self.write_graph(&mut graph_write, write_config)
1428 }
1429
1430 pub(crate) fn write_graph<W>(
1432 &self,
1433 mut graph_write: W,
1434 write_config: &WriteConfig,
1435 ) -> Result<(), W::Err>
1436 where
1437 W: GraphWrite,
1438 {
1439 fn helper_edge_label(
1440 src_port: &PortIndexValue,
1441 dst_port: &PortIndexValue,
1442 ) -> Option<String> {
1443 let src_label = match src_port {
1444 PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1445 PortIndexValue::Int(index) => Some(index.value.to_string()),
1446 _ => None,
1447 };
1448 let dst_label = match dst_port {
1449 PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1450 PortIndexValue::Int(index) => Some(index.value.to_string()),
1451 _ => None,
1452 };
1453 let label = match (src_label, dst_label) {
1454 (Some(l1), Some(l2)) => Some(format!("{}\n{}", l1, l2)),
1455 (Some(l1), None) => Some(l1),
1456 (None, Some(l2)) => Some(l2),
1457 (None, None) => None,
1458 };
1459 label
1460 }
1461
1462 let node_color_map = self.node_color_map();
1464
1465 graph_write.write_prologue()?;
1467
1468 let mut skipped_handoffs = BTreeSet::new();
1470 let mut subgraph_handoffs = <BTreeMap<GraphSubgraphId, Vec<GraphNodeId>>>::new();
1471 for (node_id, node) in self.nodes() {
1472 if matches!(node, GraphNode::Handoff { .. }) {
1473 if write_config.no_handoffs {
1474 skipped_handoffs.insert(node_id);
1475 continue;
1476 } else {
1477 let pred_node = self.node_predecessor_nodes(node_id).next().unwrap();
1478 let pred_sg = self.node_subgraph(pred_node);
1479 let succ_node = self.node_successor_nodes(node_id).next().unwrap();
1480 let succ_sg = self.node_subgraph(succ_node);
1481 if let Some((pred_sg, succ_sg)) = pred_sg.zip(succ_sg)
1482 && pred_sg == succ_sg
1483 {
1484 subgraph_handoffs.entry(pred_sg).or_default().push(node_id);
1485 }
1486 }
1487 }
1488 graph_write.write_node_definition(
1489 node_id,
1490 &if write_config.op_short_text {
1491 node.to_name_string()
1492 } else if write_config.op_text_no_imports {
1493 let full_text = node.to_pretty_string();
1495 let mut output = String::new();
1496 for sentence in full_text.split('\n') {
1497 if sentence.trim().starts_with("use") {
1498 continue;
1499 }
1500 output.push('\n');
1501 output.push_str(sentence);
1502 }
1503 output.into()
1504 } else {
1505 node.to_pretty_string()
1506 },
1507 if write_config.no_pull_push {
1508 None
1509 } else {
1510 node_color_map.get(node_id).copied()
1511 },
1512 )?;
1513 }
1514
1515 for (edge_id, (src_id, mut dst_id)) in self.edges() {
1517 if skipped_handoffs.contains(&src_id) {
1519 continue;
1520 }
1521
1522 let (src_port, mut dst_port) = self.edge_ports(edge_id);
1523 if skipped_handoffs.contains(&dst_id) {
1524 let mut handoff_succs = self.node_successors(dst_id);
1525 assert_eq!(1, handoff_succs.len());
1526 let (succ_edge, succ_node) = handoff_succs.next().unwrap();
1527 dst_id = succ_node;
1528 dst_port = self.edge_ports(succ_edge).1;
1529 }
1530
1531 let label = helper_edge_label(src_port, dst_port);
1532 let delay_type = self
1533 .node_op_inst(dst_id)
1534 .and_then(|op_inst| (op_inst.op_constraints.input_delaytype_fn)(dst_port));
1535 graph_write.write_edge(src_id, dst_id, delay_type, label.as_deref(), false)?;
1536 }
1537
1538 if !write_config.no_references {
1540 for dst_id in self.node_ids() {
1541 for src_ref_id in self
1542 .node_singleton_references(dst_id)
1543 .iter()
1544 .copied()
1545 .flatten()
1546 {
1547 let delay_type = Some(DelayType::Stratum);
1548 let label = None;
1549 graph_write.write_edge(src_ref_id, dst_id, delay_type, label, true)?;
1550 }
1551 }
1552 }
1553
1554 let loop_subgraphs = self.subgraph_ids().map(|sg_id| {
1565 let loop_id = if write_config.no_loops {
1566 None
1567 } else {
1568 self.subgraph_loop(sg_id)
1569 };
1570 (loop_id, sg_id)
1571 });
1572 let loop_subgraphs = into_group_map(loop_subgraphs);
1573 for (loop_id, subgraph_ids) in loop_subgraphs {
1574 if let Some(loop_id) = loop_id {
1575 graph_write.write_loop_start(loop_id)?;
1576 }
1577
1578 let subgraph_varnames_nodes = subgraph_ids.into_iter().flat_map(|sg_id| {
1580 self.subgraph(sg_id).iter().copied().map(move |node_id| {
1581 let opt_sg_id = if write_config.no_subgraphs {
1582 None
1583 } else {
1584 Some(sg_id)
1585 };
1586 (opt_sg_id, (self.node_varname(node_id), node_id))
1587 })
1588 });
1589 let subgraph_varnames_nodes = into_group_map(subgraph_varnames_nodes);
1590 for (sg_id, varnames) in subgraph_varnames_nodes {
1591 if let Some(sg_id) = sg_id {
1592 let stratum = self.subgraph_stratum(sg_id).unwrap();
1593 graph_write.write_subgraph_start(sg_id, stratum)?;
1594 }
1595
1596 let varname_nodes = varnames.into_iter().map(|(varname, node)| {
1598 let varname = if write_config.no_varnames {
1599 None
1600 } else {
1601 varname
1602 };
1603 (varname, node)
1604 });
1605 let varname_nodes = into_group_map(varname_nodes);
1606 for (varname, node_ids) in varname_nodes {
1607 if let Some(varname) = varname {
1608 graph_write.write_varname_start(&varname.0.to_string(), sg_id)?;
1609 }
1610
1611 for node_id in node_ids {
1613 graph_write.write_node(node_id)?;
1614 }
1615
1616 if varname.is_some() {
1617 graph_write.write_varname_end()?;
1618 }
1619 }
1620
1621 if sg_id.is_some() {
1622 graph_write.write_subgraph_end()?;
1623 }
1624 }
1625
1626 if loop_id.is_some() {
1627 graph_write.write_loop_end()?;
1628 }
1629 }
1630
1631 graph_write.write_epilogue()?;
1633
1634 Ok(())
1635 }
1636
1637 pub fn surface_syntax_string(&self) -> String {
1639 let mut string = String::new();
1640 self.write_surface_syntax(&mut string).unwrap();
1641 string
1642 }
1643
1644 pub fn write_surface_syntax(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1646 for (key, node) in self.nodes.iter() {
1647 match node {
1648 GraphNode::Operator(op) => {
1649 writeln!(write, "{:?} = {};", key.data(), op.to_token_stream())?;
1650 }
1651 GraphNode::Handoff { .. } => {
1652 writeln!(write, "// {:?} = <handoff>;", key.data())?;
1653 }
1654 GraphNode::ModuleBoundary { .. } => panic!(),
1655 }
1656 }
1657 writeln!(write)?;
1658 for (_e, (src_key, dst_key)) in self.graph.edges() {
1659 writeln!(write, "{:?} -> {:?};", src_key.data(), dst_key.data())?;
1660 }
1661 Ok(())
1662 }
1663
1664 pub fn mermaid_string_flat(&self) -> String {
1666 let mut string = String::new();
1667 self.write_mermaid_flat(&mut string).unwrap();
1668 string
1669 }
1670
1671 pub fn write_mermaid_flat(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1673 writeln!(write, "flowchart TB")?;
1674 for (key, node) in self.nodes.iter() {
1675 match node {
1676 GraphNode::Operator(operator) => writeln!(
1677 write,
1678 " %% {span}\n {id:?}[\"{row_col} <tt>{code}</tt>\"]",
1679 span = PrettySpan(node.span()),
1680 id = key.data(),
1681 row_col = PrettyRowCol(node.span()),
1682 code = operator
1683 .to_token_stream()
1684 .to_string()
1685 .replace('&', "&")
1686 .replace('<', "<")
1687 .replace('>', ">")
1688 .replace('"', """)
1689 .replace('\n', "<br>"),
1690 ),
1691 GraphNode::Handoff { .. } => {
1692 writeln!(write, r#" {:?}{{"{}"}}"#, key.data(), HANDOFF_NODE_STR)
1693 }
1694 GraphNode::ModuleBoundary { .. } => {
1695 writeln!(
1696 write,
1697 r#" {:?}{{"{}"}}"#,
1698 key.data(),
1699 MODULE_BOUNDARY_NODE_STR
1700 )
1701 }
1702 }?;
1703 }
1704 writeln!(write)?;
1705 for (_e, (src_key, dst_key)) in self.graph.edges() {
1706 writeln!(write, " {:?}-->{:?}", src_key.data(), dst_key.data())?;
1707 }
1708 Ok(())
1709 }
1710}
1711
1712impl DfirGraph {
1714 pub fn loop_ids(&self) -> slotmap::basic::Keys<'_, GraphLoopId, Vec<GraphNodeId>> {
1716 self.loop_nodes.keys()
1717 }
1718
1719 pub fn loops(&self) -> slotmap::basic::Iter<'_, GraphLoopId, Vec<GraphNodeId>> {
1721 self.loop_nodes.iter()
1722 }
1723
1724 pub fn insert_loop(&mut self, parent_loop: Option<GraphLoopId>) -> GraphLoopId {
1726 let loop_id = self.loop_nodes.insert(Vec::new());
1727 self.loop_children.insert(loop_id, Vec::new());
1728 if let Some(parent_loop) = parent_loop {
1729 self.loop_parent.insert(loop_id, parent_loop);
1730 self.loop_children
1731 .get_mut(parent_loop)
1732 .unwrap()
1733 .push(loop_id);
1734 } else {
1735 self.root_loops.push(loop_id);
1736 }
1737 loop_id
1738 }
1739
1740 pub fn node_loop(&self, node_id: GraphNodeId) -> Option<GraphLoopId> {
1742 self.node_loops.get(node_id).copied()
1743 }
1744
1745 pub fn subgraph_loop(&self, subgraph_id: GraphSubgraphId) -> Option<GraphLoopId> {
1747 let &node_id = self.subgraph(subgraph_id).first().unwrap();
1748 let out = self.node_loop(node_id);
1749 debug_assert!(
1750 self.subgraph(subgraph_id)
1751 .iter()
1752 .all(|&node_id| self.node_loop(node_id) == out),
1753 "Subgraph nodes should all have the same loop context."
1754 );
1755 out
1756 }
1757
1758 pub fn loop_parent(&self, loop_id: GraphLoopId) -> Option<GraphLoopId> {
1760 self.loop_parent.get(loop_id).copied()
1761 }
1762
1763 pub fn loop_children(&self, loop_id: GraphLoopId) -> &Vec<GraphLoopId> {
1765 self.loop_children.get(loop_id).unwrap()
1766 }
1767}
1768
#[derive(Clone, Debug, Default)]
/// Options controlling how [`DfirGraph::write_mermaid`] and [`DfirGraph::write_dot`] render
/// the graph. Every flag defaults to `false`, i.e. the most detailed output.
#[cfg_attr(feature = "clap-derive", derive(clap::Args))]
pub struct WriteConfig {
    #[cfg_attr(feature = "clap-derive", arg(long))]
    /// Do not group nodes into subgraph blocks.
    pub no_subgraphs: bool,
    #[cfg_attr(feature = "clap-derive", arg(long))]
    /// Do not group nodes by their variable names.
    pub no_varnames: bool,
    #[cfg_attr(feature = "clap-derive", arg(long))]
    /// Do not color nodes as pull or push.
    pub no_pull_push: bool,
    #[cfg_attr(feature = "clap-derive", arg(long))]
    /// Hide handoff nodes, drawing edges that pass through them as direct edges.
    pub no_handoffs: bool,
    #[cfg_attr(feature = "clap-derive", arg(long))]
    /// Do not draw singleton-reference edges.
    pub no_references: bool,
    #[cfg_attr(feature = "clap-derive", arg(long))]
    /// Do not group subgraphs into loop blocks.
    pub no_loops: bool,

    #[cfg_attr(feature = "clap-derive", arg(long))]
    /// Label nodes with their name (`to_name_string`) instead of their full pretty-printed text.
    pub op_short_text: bool,
    #[cfg_attr(feature = "clap-derive", arg(long))]
    /// Omit `use` import lines from the full pretty-printed node text.
    pub op_text_no_imports: bool,
}
1799
#[derive(Copy, Clone, Debug)]
/// Output format to render a graph in.
#[cfg_attr(feature = "clap-derive", derive(clap::Parser, clap::ValueEnum))]
pub enum WriteGraphType {
    /// Mermaid text (see [`DfirGraph::write_mermaid`]).
    Mermaid,
    /// Graphviz DOT text (see [`DfirGraph::write_dot`]).
    Dot,
}
1809
/// Groups `(key, value)` pairs by key.
///
/// Within each group, values keep their input order. The returned `BTreeMap` iterates keys
/// in sorted order, which keeps rendered output deterministic.
fn into_group_map<K, V>(iter: impl IntoIterator<Item = (K, V)>) -> BTreeMap<K, Vec<V>>
where
    K: Ord,
{
    iter.into_iter()
        .fold(BTreeMap::<K, Vec<V>>::new(), |mut groups, (key, value)| {
            groups.entry(key).or_default().push(value);
            groups
        })
}