Skip to main content

dfir_pipes/pull/
flatten_stream.rs

1//! [`FlattenStream`] pull combinator.
2use core::pin::Pin;
3use core::task::Context;
4
5use futures_core::Stream;
6use pin_project_lite::pin_project;
7
8use crate::Yes;
9use crate::pull::{FusedPull, Pull, PullStep};
10
11pin_project! {
12    /// Pull combinator that flattens items that are streams by polling each inner stream.
13    ///
14    /// When the inner stream yields `Poll::Pending`, this operator yields `Pending` as well.
15    #[must_use = "`Pull`s do nothing unless polled"]
16    pub struct FlattenStream<Prev, St, Meta> where St: Stream {
17        #[pin]
18        prev: Prev,
19        #[pin]
20        current: Option<FlattenStreamCurrent<St, Meta>>,
21    }
22}
23
24pin_project! {
25    struct FlattenStreamCurrent<St, Meta> where St: Stream {
26        #[pin]
27        stream: St,
28        meta: Meta,
29    }
30}
31
32impl<Prev, St, Meta> FlattenStream<Prev, St, Meta>
33where
34    Self: Pull,
35    St: Stream,
36{
37    pub(crate) const fn new(prev: Prev) -> Self {
38        Self {
39            prev,
40            current: None,
41        }
42    }
43}
44
45impl<Prev> Pull for FlattenStream<Prev, Prev::Item, Prev::Meta>
46where
47    Prev: Pull,
48    Prev::Item: Stream,
49{
50    type Ctx<'ctx> = Context<'ctx>;
51
52    type Item = <Prev::Item as Stream>::Item;
53    type Meta = Prev::Meta;
54    type CanPend = Yes;
55    type CanEnd = Prev::CanEnd;
56
57    fn size_hint(&self) -> (usize, Option<usize>) {
58        let current_lower = self
59            .current
60            .as_ref()
61            .map(|c| c.stream.size_hint().0)
62            .unwrap_or_default();
63        (current_lower, None)
64    }
65
66    fn pull(
67        self: Pin<&mut Self>,
68        ctx: &mut Self::Ctx<'_>,
69    ) -> PullStep<Self::Item, Self::Meta, Self::CanPend, Self::CanEnd> {
70        let mut this = self.project();
71        loop {
72            if let Some(cur) = this.current.as_mut().as_pin_mut().map(|c| c.project()) {
73                match Stream::poll_next(cur.stream, ctx) {
74                    core::task::Poll::Ready(Some(item)) => {
75                        return PullStep::Ready(item, *cur.meta);
76                    }
77                    core::task::Poll::Ready(None) => {
78                        this.current.as_mut().set(None);
79                    }
80                    core::task::Poll::Pending => {
81                        return PullStep::Pending(Yes);
82                    }
83                }
84            }
85            debug_assert!(this.current.is_none());
86
87            match this.prev.as_mut().pull(crate::Context::from_task(ctx)) {
88                PullStep::Ready(stream, meta) => {
89                    this.current
90                        .as_mut()
91                        .set(Some(FlattenStreamCurrent { stream, meta }));
92                }
93                PullStep::Pending(_) => {
94                    return PullStep::Pending(Yes);
95                }
96                PullStep::Ended(can_end) => {
97                    return PullStep::Ended(can_end);
98                }
99            }
100        }
101    }
102}
103
104impl<Prev> FusedPull for FlattenStream<Prev, Prev::Item, Prev::Meta>
105where
106    Prev: FusedPull,
107    Prev::Item: Stream,
108{
109}
110
111#[cfg(test)]
112mod tests {
113    use core::pin::Pin;
114    use core::task::{Context, Waker};
115
116    extern crate alloc;
117    use alloc::vec;
118
119    use futures_util::stream;
120
121    use crate::Yes;
122    use crate::pull::{Pull, PullStep};
123
124    #[test]
125    fn flatten_stream_basic() {
126        let waker = Waker::noop();
127        let mut cx = Context::from_waker(waker);
128
129        let mut p = crate::pull::iter(vec![stream::iter(vec![1, 2]), stream::iter(vec![3])])
130            .flatten_stream();
131        let mut p = Pin::new(&mut p);
132
133        assert_eq!(PullStep::Ready(1, ()), p.as_mut().pull(&mut cx));
134        assert_eq!(PullStep::Ready(2, ()), p.as_mut().pull(&mut cx));
135        assert_eq!(PullStep::Ready(3, ()), p.as_mut().pull(&mut cx));
136
137        let step: PullStep<i32, (), Yes, Yes> = p.as_mut().pull(&mut cx);
138        assert!(step.is_ended());
139    }
140
141    #[test]
142    fn flatten_stream_pending_propagates() {
143        let waker = Waker::noop();
144        let mut cx = Context::from_waker(waker);
145
146        let mut p = crate::pull::iter(vec![stream::pending::<i32>()]).flatten_stream();
147        let mut p = Pin::new(&mut p);
148
149        for _ in 0..10 {
150            let step: PullStep<i32, (), Yes, Yes> = p.as_mut().pull(&mut cx);
151            assert!(step.is_pending());
152        }
153    }
154}