Skip to main content

dfir_pipes/push/
flat_map_stream.rs

1//! [`FlatMapStream`] push combinator.
2use core::pin::Pin;
3use core::task::{Context, Poll};
4
5use futures_core::Stream;
6use pin_project_lite::pin_project;
7
8use crate::Yes;
9use crate::push::{Push, PushStep, ready};
10
pin_project! {
    /// State for the inner stream currently being drained by [`FlatMapStream`].
    ///
    /// Installed by `start_send` from the output of the flat-mapping function, and dropped by
    /// `poll_ready` once the stream returns `Poll::Ready(None)`.
    struct FlatMapStreamBuffer<St, Meta> where St: Stream {
        // The inner stream produced by the flat-mapping function; pinned because it is
        // polled through `Pin<&mut St>`.
        #[pin]
        stream: St,
        // An element already pulled from `stream`, held here until `next` reports ready.
        item: Option<St::Item>,
        // Metadata of the upstream item that produced `stream`; copied onto each element sent.
        meta: Meta,
    }
}
19
pin_project! {
    /// Push combinator that maps each item to a stream and pushes each element downstream.
    ///
    /// Only one inner stream is active at a time. `poll_ready` drains it into the downstream
    /// push and must return `PushStep::Done` before the next `start_send`; `start_send`
    /// panics if an inner stream is still active.
    ///
    /// When the inner stream yields `Poll::Pending`, this operator yields as well.
    #[must_use = "`Push`es do nothing unless items are pushed into them"]
    pub struct FlatMapStream<Next, Func, St, Meta>
    where
        St: Stream,
    {
        // Downstream push that receives every element of every inner stream.
        #[pin]
        next: Next,
        // Maps each upstream item to an inner stream.
        func: Func,
        // The active inner stream and its held element; `None` when ready for a new item.
        #[pin]
        buffer: Option<FlatMapStreamBuffer<St, Meta>>,
    }
}
36
37impl<Next, Func, St, Meta> FlatMapStream<Next, Func, St, Meta>
38where
39    Next: Push<St::Item, Meta>,
40    St: Stream,
41    Meta: Copy,
42{
43    /// Creates with flat-mapping `func` and next `push`.
44    pub(crate) const fn new<In>(func: Func, next: Next) -> Self
45    where
46        Func: FnMut(In) -> St,
47    {
48        Self {
49            next,
50            func,
51            buffer: None,
52        }
53    }
54}
55
56impl<Next, Func, St, In, Meta> Push<In, Meta> for FlatMapStream<Next, Func, St, Meta>
57where
58    Next: Push<St::Item, Meta>,
59    Func: FnMut(In) -> St,
60    St: Stream,
61    Meta: Copy,
62{
63    type Ctx<'ctx> = Context<'ctx>;
64
65    type CanPend = Yes;
66
67    fn size_hint(self: Pin<&mut Self>, _hint: (usize, Option<usize>)) {
68        let this = self.project();
69        let lower = this
70            .buffer
71            .as_pin_mut()
72            .map(|b| b.project().stream.size_hint().0)
73            .unwrap_or_default();
74        this.next.size_hint((lower, None));
75    }
76
77    fn poll_ready(self: Pin<&mut Self>, ctx: &mut Self::Ctx<'_>) -> PushStep<Self::CanPend> {
78        let mut this = self.project();
79
80        while let Some(buf) = this.buffer.as_mut().as_pin_mut().map(|buf| buf.project()) {
81            if buf.item.is_some() {
82                ready!(
83                    this.next
84                        .as_mut()
85                        .poll_ready(crate::Context::from_task(ctx))
86                );
87                let item = buf.item.take().unwrap();
88                this.next.as_mut().start_send(item, *buf.meta);
89            }
90            debug_assert!(buf.item.is_none());
91
92            match Stream::poll_next(buf.stream, ctx) {
93                Poll::Ready(Some(next_item)) => *buf.item = Some(next_item),
94                Poll::Ready(None) => this.buffer.as_mut().set(None),
95                Poll::Pending => return PushStep::Pending(Yes),
96            }
97        }
98        PushStep::Done
99    }
100
101    fn start_send(self: Pin<&mut Self>, item: In, meta: Meta) {
102        let mut this = self.project();
103        assert!(
104            this.buffer.is_none(),
105            "FlatMapStream: poll_ready must be called before start_send"
106        );
107        let stream = (this.func)(item);
108        this.buffer.set(Some(FlatMapStreamBuffer {
109            stream,
110            item: None,
111            meta,
112        }));
113    }
114
115    fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Self::Ctx<'_>) -> PushStep<Self::CanPend> {
116        ready!(self.as_mut().poll_ready(ctx));
117        self.project()
118            .next
119            .poll_flush(crate::Context::from_task(ctx))
120            .convert_into()
121    }
122}
123
#[cfg(test)]
mod tests {
    use core::pin::Pin;
    use core::task::{Context, Waker};

    extern crate alloc;
    use alloc::vec;

    use futures_util::stream;

    use crate::push::Push;
    use crate::push::test_utils::TestPush;

    /// Each inner stream is drained in order before the next upstream item is accepted.
    #[test]
    fn flat_map_stream_readies_downstream_before_each_send() {
        let waker = Waker::noop();
        let mut cx = Context::from_waker(waker);

        let mut tp = TestPush::no_pend();
        let mut fms = crate::push::flat_map_stream::<_, _, stream::Iter<vec::IntoIter<i32>>, (), _>(
            |x: i32| stream::iter(vec![x, x + 10]),
            &mut tp,
        );
        let mut fms = Pin::new(&mut fms);

        let result = Push::<i32, ()>::poll_ready(fms.as_mut(), &mut cx);
        assert!(result.is_done());

        Push::<i32, ()>::start_send(fms.as_mut(), 1, ());

        let result = Push::<i32, ()>::poll_ready(fms.as_mut(), &mut cx);
        assert!(result.is_done());

        Push::<i32, ()>::start_send(fms.as_mut(), 2, ());

        let result = Push::<i32, ()>::poll_flush(fms.as_mut(), &mut cx);
        assert!(result.is_done());

        assert_eq!(tp.items(), vec![1, 11, 2, 12]);
    }

    /// An inner stream that yields nothing is dropped immediately and emits no items.
    #[test]
    fn flat_map_stream_empty_inner_stream() {
        let waker = Waker::noop();
        let mut cx = Context::from_waker(waker);

        let mut tp = TestPush::no_pend();
        let mut fms = crate::push::flat_map_stream::<_, _, stream::Iter<vec::IntoIter<i32>>, (), _>(
            |x: i32| stream::iter(if x % 2 == 0 { vec![] } else { vec![x, x + 10] }),
            &mut tp,
        );
        let mut fms = Pin::new(&mut fms);

        let result = Push::<i32, ()>::poll_ready(fms.as_mut(), &mut cx);
        assert!(result.is_done());

        // `2` maps to an empty stream, which is exhausted on the next `poll_ready`.
        Push::<i32, ()>::start_send(fms.as_mut(), 2, ());

        let result = Push::<i32, ()>::poll_ready(fms.as_mut(), &mut cx);
        assert!(result.is_done());

        Push::<i32, ()>::start_send(fms.as_mut(), 3, ());

        let result = Push::<i32, ()>::poll_flush(fms.as_mut(), &mut cx);
        assert!(result.is_done());

        assert_eq!(tp.items(), vec![3, 13]);
    }

    /// A pending inner stream makes `poll_ready` pending.
    #[test]
    fn flat_map_stream_pending_propagates() {
        let waker = Waker::noop();
        let mut cx = Context::from_waker(waker);

        let mut tp: TestPush<i32, crate::No, true> = TestPush::new_fused([], []);
        let mut fms = crate::push::flat_map_stream::<_, _, stream::Pending<i32>, (), _>(
            |_: i32| stream::pending(),
            &mut tp,
        );
        let mut fms = Pin::new(&mut fms);

        let result = Push::<i32, ()>::poll_ready(fms.as_mut(), &mut cx);
        assert!(result.is_done());

        Push::<i32, ()>::start_send(fms.as_mut(), 42, ());

        let result = Push::<i32, ()>::poll_ready(fms.as_mut(), &mut cx);
        assert!(result.is_pending());
    }
}