Skip to main content

dfir_pipes/pull/
take.rs

1use core::pin::Pin;
2
3use pin_project_lite::pin_project;
4
5use crate::Yes;
6use crate::pull::{FusedPull, Pull, PullStep, fuse_self};
7
8pin_project! {
9    /// Pull combinator that yields the first `n` items.
10    #[must_use = "`Pull`s do nothing unless polled"]
11    #[derive(Clone, Debug, Default)]
12    pub struct Take<Prev> {
13        #[pin]
14        prev: Prev,
15        remaining: usize,
16    }
17}
18
19impl<Prev> Take<Prev>
20where
21    Self: Pull,
22{
23    pub(crate) const fn new(prev: Prev, n: usize) -> Self {
24        Self { prev, remaining: n }
25    }
26}
27
28impl<Prev> Pull for Take<Prev>
29where
30    Prev: Pull,
31{
32    type Ctx<'ctx> = Prev::Ctx<'ctx>;
33
34    type Item = Prev::Item;
35    type Meta = Prev::Meta;
36    type CanPend = Prev::CanPend;
37    type CanEnd = Yes;
38
39    fn pull(
40        self: Pin<&mut Self>,
41        ctx: &mut Self::Ctx<'_>,
42    ) -> PullStep<Self::Item, Self::Meta, Self::CanPend, Self::CanEnd> {
43        let this = self.project();
44
45        if 0 == *this.remaining {
46            return PullStep::Ended(Yes);
47        }
48
49        match this.prev.pull(ctx) {
50            PullStep::Ready(item, meta) => {
51                *this.remaining -= 1;
52                PullStep::Ready(item, meta)
53            }
54            PullStep::Pending(can_pend) => PullStep::Pending(can_pend),
55            PullStep::Ended(_) => {
56                *this.remaining = 0;
57                PullStep::Ended(Yes)
58            }
59        }
60    }
61
62    fn size_hint(&self) -> (usize, Option<usize>) {
63        let (lower, upper) = self.prev.size_hint();
64        let remaining = self.remaining;
65        (
66            lower.min(remaining),
67            upper.map(|u| u.min(remaining)).or(Some(remaining)),
68        )
69    }
70
71    fuse_self!();
72}
73
74impl<Prev> FusedPull for Take<Prev> where Prev: Pull {}
75
76#[cfg(test)]
77mod tests {
78    use core::pin::pin;
79
80    use crate::pull::Pull;
81    use crate::pull::test_utils::{TestPull, assert_fused_runtime};
82
83    #[test]
84    fn take_fused_shields_upstream() {
85        let p = pin!(TestPull::items(0..2).take(1));
86        assert_fused_runtime(p);
87    }
88}