Skip to main content

dfir_pipes/pull/
fuse.rs

1use core::pin::Pin;
2
3use pin_project_lite::pin_project;
4
5use crate::pull::{FusedPull, Pull, PullStep, fuse_self};
6
7pin_project! {
8    /// Pull combinator that guarantees [`PullStep::Ended`] is returned forever after the first end.
9    #[must_use = "`Pull`s do nothing unless polled"]
10    #[derive(Clone, Debug, Default)]
11    #[project_replace = FuseReplace]
12    pub struct Fuse<Prev> {
13        #[pin]
14        prev: Option<Prev>,
15    }
16}
17
18impl<Prev> Fuse<Prev>
19where
20    Self: Pull,
21{
22    pub(crate) const fn new(prev: Prev) -> Self {
23        Self { prev: Some(prev) }
24    }
25}
26
27impl<Prev> Pull for Fuse<Prev>
28where
29    Prev: Pull,
30{
31    type Ctx<'ctx> = Prev::Ctx<'ctx>;
32
33    type Item = Prev::Item;
34    type Meta = Prev::Meta;
35    type CanPend = Prev::CanPend;
36    type CanEnd = Prev::CanEnd;
37
38    fn pull(
39        mut self: Pin<&mut Self>,
40        ctx: &mut Self::Ctx<'_>,
41    ) -> PullStep<Self::Item, Self::Meta, Self::CanPend, Self::CanEnd> {
42        let this = self.as_mut().project();
43
44        if let Some(prev) = this.prev.as_pin_mut() {
45            match prev.pull(ctx) {
46                PullStep::Ready(item, meta) => PullStep::Ready(item, meta),
47                PullStep::Pending(can_pend) => PullStep::Pending(can_pend),
48                PullStep::Ended(_) => {
49                    let _ = self.project_replace(Self { prev: None });
50                    PullStep::ended()
51                }
52            }
53        } else {
54            PullStep::ended()
55        }
56    }
57
58    fn size_hint(&self) -> (usize, Option<usize>) {
59        if let Some(prev) = self.prev.as_ref() {
60            prev.size_hint()
61        } else {
62            (0, Some(0))
63        }
64    }
65
66    fuse_self!();
67}
68
69impl<Prev> FusedPull for Fuse<Prev> where Prev: Pull {}
70
71#[cfg(test)]
72mod tests {
73    use core::pin::pin;
74
75    use crate::pull::Pull;
76    use crate::pull::test_utils::{TestPull, assert_fused_runtime};
77
78    #[test]
79    fn fuse_shields_upstream() {
80        let p = pin!(TestPull::items(0..2).fuse());
81        assert_fused_runtime(p);
82    }
83}