Skip to main content

dfir_pipes/push/
flatten_stream.rs

1//! [`FlattenStream`] push combinator.
2use core::pin::Pin;
3use core::task::{Context, Poll};
4
5use futures_core::Stream;
6use pin_project_lite::pin_project;
7
8use crate::Yes;
9use crate::push::{Push, PushStep, ready};
10
pin_project! {
    /// State for the stream currently being flattened.
    ///
    /// Holds the stream being drained, at most one item already pulled from it that is
    /// still waiting for the downstream push to become ready, and the metadata that was
    /// sent alongside the stream (copied onto every item the stream yields).
    struct FlattenStreamBuffer<St, Meta> where St: Stream {
        // The stream being flattened; pinned because `Stream::poll_next` takes `Pin<&mut St>`.
        #[pin]
        stream: St,
        // An item already yielded by `stream` but not yet sent downstream, because the
        // downstream push has to be polled ready before it can accept it.
        item: Option<St::Item>,
        // Metadata received with the stream in `start_send`; copied for each item sent.
        meta: Meta,
    }
}
19
pin_project! {
    /// Push combinator that flattens stream items by polling each stream and pushing elements downstream.
    ///
    /// Each [`Stream`] pushed in via `start_send` is buffered and drained by `poll_ready`
    /// (and `poll_flush`), which forwards every yielded element to `next` together with the
    /// metadata the stream was sent with. Before each element is sent, `next` is polled for
    /// readiness. Only one stream is buffered at a time, so a new stream may be sent only
    /// after `poll_ready` has returned done.
    ///
    /// When the inner stream yields `Poll::Pending`, this operator yields as well.
    #[must_use = "`Push`es do nothing unless items are pushed into them"]
    pub struct FlattenStream<Next, St, Meta>
    where
        St: Stream,
    {
        // Downstream push that receives the flattened elements.
        #[pin]
        next: Next,
        // The stream currently being drained, if any; reset to `None` once it is exhausted.
        #[pin]
        buffer: Option<FlattenStreamBuffer<St, Meta>>,
    }
}
35
36impl<Next, St, Meta> FlattenStream<Next, St, Meta>
37where
38    Next: Push<St::Item, Meta>,
39    St: Stream,
40    Meta: Copy,
41{
42    /// Creates with next `push`.
43    pub(crate) const fn new(next: Next) -> Self {
44        Self { next, buffer: None }
45    }
46}
47
48impl<Next, St, Meta> Push<St, Meta> for FlattenStream<Next, St, Meta>
49where
50    Next: Push<St::Item, Meta>,
51    St: Stream,
52    Meta: Copy,
53{
54    type Ctx<'ctx> = Context<'ctx>;
55
56    type CanPend = Yes;
57
58    fn size_hint(self: Pin<&mut Self>, _hint: (usize, Option<usize>)) {
59        let this = self.project();
60        let lower = this
61            .buffer
62            .as_pin_mut()
63            .map(|b| b.project().stream.size_hint().0)
64            .unwrap_or_default();
65        this.next.size_hint((lower, None));
66    }
67
68    fn poll_ready(self: Pin<&mut Self>, ctx: &mut Self::Ctx<'_>) -> PushStep<Self::CanPend> {
69        let mut this = self.project();
70
71        while let Some(buf) = this.buffer.as_mut().as_pin_mut().map(|buf| buf.project()) {
72            if buf.item.is_some() {
73                ready!(
74                    this.next
75                        .as_mut()
76                        .poll_ready(crate::Context::from_task(ctx))
77                );
78                let item = buf.item.take().unwrap();
79                this.next.as_mut().start_send(item, *buf.meta);
80            }
81            debug_assert!(buf.item.is_none());
82
83            match Stream::poll_next(buf.stream, ctx) {
84                Poll::Ready(Some(next_item)) => *buf.item = Some(next_item),
85                Poll::Ready(None) => this.buffer.as_mut().set(None),
86                Poll::Pending => return PushStep::Pending(Yes),
87            }
88        }
89        PushStep::Done
90    }
91
92    fn start_send(self: Pin<&mut Self>, stream: St, meta: Meta) {
93        let mut this = self.project();
94        assert!(
95            this.buffer.is_none(),
96            "FlattenStream: poll_ready must be called before start_send"
97        );
98        this.buffer.set(Some(FlattenStreamBuffer {
99            stream,
100            item: None,
101            meta,
102        }));
103    }
104
105    fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Self::Ctx<'_>) -> PushStep<Self::CanPend> {
106        ready!(self.as_mut().poll_ready(ctx));
107        self.project()
108            .next
109            .poll_flush(crate::Context::from_task(ctx))
110            .convert_into()
111    }
112}
113
#[cfg(test)]
mod tests {
    use core::pin::Pin;
    use core::task::{Context, Waker};

    extern crate alloc;
    use alloc::vec;

    use futures_util::stream;

    use crate::push::Push;
    use crate::push::test_utils::TestPush;

    /// Finite, always-ready stream of `i32`s used as the upstream item type.
    type IterStream = stream::Iter<vec::IntoIter<i32>>;

    /// Stream that never yields, used to exercise pending propagation.
    type PendingStream = stream::Pending<i32>;

    #[test]
    fn flatten_stream_readies_downstream_before_each_send() {
        let mut cx = Context::from_waker(Waker::noop());

        let mut tp = TestPush::no_pend();
        let mut fs = crate::push::flatten_stream::<IterStream, (), _>(&mut tp);
        let mut fs = Pin::new(&mut fs);

        // Nothing is buffered yet, so the combinator is immediately ready.
        assert!(Push::<IterStream, ()>::poll_ready(fs.as_mut(), &mut cx).is_done());

        Push::<IterStream, ()>::start_send(fs.as_mut(), stream::iter(vec![1, 2]), ());

        // Readying again drains the first stream completely into `tp`.
        assert!(Push::<IterStream, ()>::poll_ready(fs.as_mut(), &mut cx).is_done());

        Push::<IterStream, ()>::start_send(fs.as_mut(), stream::iter(vec![3]), ());

        // Flushing drains the second stream before flushing downstream.
        assert!(Push::<IterStream, ()>::poll_flush(fs.as_mut(), &mut cx).is_done());

        assert_eq!(tp.items(), vec![1, 2, 3]);
    }

    #[test]
    fn flatten_stream_pending_propagates() {
        let mut cx = Context::from_waker(Waker::noop());

        let mut tp: TestPush<i32, crate::No, true> = TestPush::new_fused([], []);
        let mut fs = crate::push::flatten_stream::<PendingStream, (), _>(&mut tp);
        let mut fs = Pin::new(&mut fs);

        // Ready initially (no stream buffered).
        assert!(Push::<PendingStream, ()>::poll_ready(fs.as_mut(), &mut cx).is_done());

        // Send a stream that is always pending.
        Push::<PendingStream, ()>::start_send(fs.as_mut(), stream::pending(), ());

        // poll_ready should return Pending since the stream pends.
        assert!(Push::<PendingStream, ()>::poll_ready(fs.as_mut(), &mut cx).is_pending());
    }
}