Skip to main content

dfir_pipes/pull/
symmetric_hash_join.rs

1//! Symmetric hash join combinator for Pull-based streams.
2
3use std::borrow::{BorrowMut, Cow};
4use std::marker::PhantomData;
5use std::pin::Pin;
6
7use itertools::Either;
8use pin_project_lite::pin_project;
9use smallvec::SmallVec;
10
11use crate::pull::half_join_state::HalfJoinState;
12use crate::pull::{self, FusedPull, Pull, PullStep};
13use crate::{Context, Toggle};
14
15pin_project! {
16    /// Pull combinator for symmetric hash join operations.
17    ///
18    /// Joins two pulls on a common key, producing tuples of matched values.
19    /// Items are processed as they arrive, with matches emitted immediately.
20    #[must_use = "`Pull`s do nothing unless polled"]
21    #[derive(Clone, Debug, Default)]
22    pub struct SymmetricHashJoin<Lhs, Rhs, LhsState, RhsState, LhsStateInner, RhsStateInner> {
23        #[pin]
24        lhs: Lhs,
25        #[pin]
26        rhs: Rhs,
27
28        lhs_state: LhsState,
29        rhs_state: RhsState,
30
31        _phantom: PhantomData<(LhsStateInner, RhsStateInner)>,
32    }
33}
34
35impl<Lhs, Rhs, LhsState, RhsState, LhsStateInner, RhsStateInner>
36    SymmetricHashJoin<Lhs, Rhs, LhsState, RhsState, LhsStateInner, RhsStateInner>
37where
38    Self: Pull,
39{
40    /// Creates a new symmetric hash join Pull from two input Pulls and their join states.
41    pub(crate) const fn new(lhs: Lhs, rhs: Rhs, lhs_state: LhsState, rhs_state: RhsState) -> Self {
42        Self {
43            lhs,
44            rhs,
45            lhs_state,
46            rhs_state,
47            _phantom: PhantomData,
48        }
49    }
50}
51
52impl<Key, Lhs, V1, Rhs, V2, LhsState, RhsState, LhsStateInner, RhsStateInner> Pull
53    for SymmetricHashJoin<Lhs, Rhs, LhsState, RhsState, LhsStateInner, RhsStateInner>
54where
55    Key: Eq + std::hash::Hash + Clone,
56    V1: Clone,
57    V2: Clone,
58    Lhs: FusedPull<Item = (Key, V1), Meta = ()>,
59    Rhs: FusedPull<Item = (Key, V2), Meta = ()>,
60    LhsState: BorrowMut<LhsStateInner>,
61    RhsState: BorrowMut<RhsStateInner>,
62    LhsStateInner: HalfJoinState<Key, V1, V2>,
63    RhsStateInner: HalfJoinState<Key, V2, V1>,
64{
65    type Ctx<'ctx> = <Lhs::Ctx<'ctx> as Context<'ctx>>::Merged<Rhs::Ctx<'ctx>>;
66
67    type Item = (Key, (V1, V2));
68    type Meta = ();
69    type CanPend = <Lhs::CanPend as Toggle>::Or<Rhs::CanPend>;
70    type CanEnd = <Lhs::CanEnd as Toggle>::And<Rhs::CanEnd>;
71
72    fn pull(
73        self: Pin<&mut Self>,
74        ctx: &mut Self::Ctx<'_>,
75    ) -> PullStep<Self::Item, Self::Meta, Self::CanPend, Self::CanEnd> {
76        let mut this = self.project();
77        let lhs_state = this.lhs_state.borrow_mut();
78        let rhs_state = this.rhs_state.borrow_mut();
79
80        loop {
81            // First check for any pending matches from previous probes
82            if let Some((k, v2, v1)) = lhs_state.pop_match() {
83                return PullStep::Ready((k, (v1, v2)), ());
84            }
85            if let Some((k, v1, v2)) = rhs_state.pop_match() {
86                return PullStep::Ready((k, (v1, v2)), ());
87            }
88
89            // Try to pull from lhs
90            let lhs_step = this
91                .lhs
92                .as_mut()
93                .pull(<Lhs::Ctx<'_> as Context<'_>>::unmerge_self(ctx));
94            if let PullStep::Ready((k, v1), _meta) = lhs_step {
95                if lhs_state.build(k.clone(), Cow::Borrowed(&v1))
96                    && let Some((k, v1, v2)) = rhs_state.probe(&k, &v1)
97                {
98                    return PullStep::Ready((k, (v1, v2)), ());
99                }
100                continue;
101            }
102
103            // Try to pull from rhs
104            let rhs_step = this
105                .rhs
106                .as_mut()
107                .pull(<Lhs::Ctx<'_> as Context<'_>>::unmerge_other(ctx));
108            if let PullStep::Ready((k, v2), _meta) = rhs_step {
109                if rhs_state.build(k.clone(), Cow::Borrowed(&v2))
110                    && let Some((k, v2, v1)) = lhs_state.probe(&k, &v2)
111                {
112                    return PullStep::Ready((k, (v1, v2)), ());
113                }
114                continue;
115            }
116
117            if lhs_step.is_pending() || rhs_step.is_pending() {
118                return PullStep::pending();
119            }
120
121            // If we get here, both sides have ended.
122            debug_assert!(lhs_step.is_ended());
123            debug_assert!(rhs_step.is_ended());
124            return PullStep::ended();
125        }
126    }
127
128    fn size_hint(&self) -> (usize, Option<usize>) {
129        // TODO(mingwei): actual estimate
130        (0, None)
131    }
132}
133
134/// Iterator for new tick - iterates over all matches after both sides are drained.
135pub struct NewTickJoinIter<'a, Key, V1, V2, LhsState, RhsState> {
136    lhs_state: &'a LhsState,
137    rhs_state: &'a RhsState,
138    lhs_smaller: bool,
139    // State for iteration
140    outer_iter: Option<std::collections::hash_map::Iter<'a, Key, SmallVec<[V1; 1]>>>,
141    outer_iter_rhs: Option<std::collections::hash_map::Iter<'a, Key, SmallVec<[V2; 1]>>>,
142    current_key: Option<&'a Key>,
143    outer_val_iter: Option<std::slice::Iter<'a, V1>>,
144    outer_val_iter_rhs: Option<std::slice::Iter<'a, V2>>,
145    current_outer_val: Option<&'a V1>,
146    current_outer_val_rhs: Option<&'a V2>,
147    inner_val_iter: Option<std::slice::Iter<'a, V2>>,
148    inner_val_iter_rhs: Option<std::slice::Iter<'a, V1>>,
149}
150
151impl<'a, Key, V1, V2, LhsState, RhsState> NewTickJoinIter<'a, Key, V1, V2, LhsState, RhsState>
152where
153    Key: Eq + std::hash::Hash + Clone,
154    V1: Clone,
155    V2: Clone,
156    LhsState: HalfJoinState<Key, V1, V2>,
157    RhsState: HalfJoinState<Key, V2, V1>,
158{
159    fn new_lhs_smaller(lhs_state: &'a LhsState, rhs_state: &'a RhsState) -> Self {
160        Self {
161            lhs_state,
162            rhs_state,
163            lhs_smaller: true,
164            outer_iter: Some(lhs_state.iter()),
165            outer_iter_rhs: None,
166            current_key: None,
167            outer_val_iter: None,
168            outer_val_iter_rhs: None,
169            current_outer_val: None,
170            current_outer_val_rhs: None,
171            inner_val_iter: None,
172            inner_val_iter_rhs: None,
173        }
174    }
175
176    fn new_rhs_smaller(lhs_state: &'a LhsState, rhs_state: &'a RhsState) -> Self {
177        Self {
178            lhs_state,
179            rhs_state,
180            lhs_smaller: false,
181            outer_iter: None,
182            outer_iter_rhs: Some(rhs_state.iter()),
183            current_key: None,
184            outer_val_iter: None,
185            outer_val_iter_rhs: None,
186            current_outer_val: None,
187            current_outer_val_rhs: None,
188            inner_val_iter: None,
189            inner_val_iter_rhs: None,
190        }
191    }
192}
193
194impl<'a, Key, V1, V2, LhsState, RhsState> Iterator
195    for NewTickJoinIter<'a, Key, V1, V2, LhsState, RhsState>
196where
197    Key: Eq + std::hash::Hash + Clone,
198    V1: Clone,
199    V2: Clone,
200    LhsState: HalfJoinState<Key, V1, V2>,
201    RhsState: HalfJoinState<Key, V2, V1>,
202{
203    type Item = (Key, (V1, V2));
204
205    fn next(&mut self) -> Option<Self::Item> {
206        if self.lhs_smaller {
207            self.next_lhs_smaller()
208        } else {
209            self.next_rhs_smaller()
210        }
211    }
212
213    fn size_hint(&self) -> (usize, Option<usize>) {
214        // TODO(mingwei): proper size hint estimate
215        (0, None)
216    }
217}
218
219impl<'a, Key, V1, V2, LhsState, RhsState> NewTickJoinIter<'a, Key, V1, V2, LhsState, RhsState>
220where
221    Key: Eq + std::hash::Hash + Clone,
222    V1: Clone,
223    V2: Clone,
224    LhsState: HalfJoinState<Key, V1, V2>,
225    RhsState: HalfJoinState<Key, V2, V1>,
226{
227    fn next_lhs_smaller(&mut self) -> Option<(Key, (V1, V2))> {
228        loop {
229            // Try to get next v2 for current v1
230            if let Some(ref mut v2_iter) = self.inner_val_iter {
231                if let Some(v2) = v2_iter.next() {
232                    let key = self.current_key.unwrap();
233                    let v1 = self.current_outer_val.unwrap();
234                    return Some((key.clone(), (v1.clone(), v2.clone())));
235                }
236                self.inner_val_iter = None;
237            }
238
239            // Try to get next v1 for current key
240            if let Some(ref mut v1_iter) = self.outer_val_iter {
241                if let Some(v1) = v1_iter.next() {
242                    self.current_outer_val = Some(v1);
243                    let key = self.current_key.unwrap();
244                    self.inner_val_iter = Some(self.rhs_state.full_probe(key));
245                    continue;
246                }
247                self.outer_val_iter = None;
248                self.current_key = None;
249            }
250
251            // Try to get next key from lhs
252            if let Some(ref mut lhs_iter) = self.outer_iter {
253                if let Some((key, values)) = lhs_iter.next() {
254                    self.current_key = Some(key);
255                    self.outer_val_iter = Some(values.iter());
256                    continue;
257                }
258                self.outer_iter = None;
259            }
260
261            return None;
262        }
263    }
264
265    fn next_rhs_smaller(&mut self) -> Option<(Key, (V1, V2))> {
266        loop {
267            // Try to get next v1 for current v2
268            if let Some(ref mut v1_iter) = self.inner_val_iter_rhs {
269                if let Some(v1) = v1_iter.next() {
270                    let key = self.current_key.unwrap();
271                    let v2 = self.current_outer_val_rhs.unwrap();
272                    return Some((key.clone(), (v1.clone(), v2.clone())));
273                }
274                self.inner_val_iter_rhs = None;
275            }
276
277            // Try to get next v2 for current key
278            if let Some(ref mut v2_iter) = self.outer_val_iter_rhs {
279                if let Some(v2) = v2_iter.next() {
280                    self.current_outer_val_rhs = Some(v2);
281                    let key = self.current_key.unwrap();
282                    self.inner_val_iter_rhs = Some(self.lhs_state.full_probe(key));
283                    continue;
284                }
285                self.outer_val_iter_rhs = None;
286                self.current_key = None;
287            }
288
289            // Try to get next key from rhs
290            if let Some(ref mut rhs_iter) = self.outer_iter_rhs {
291                if let Some((key, values)) = rhs_iter.next() {
292                    self.current_key = Some(key);
293                    self.outer_val_iter_rhs = Some(values.iter());
294                    continue;
295                }
296                self.outer_iter_rhs = None;
297            }
298
299            return None;
300        }
301    }
302}
303
304/// Type alias for the `Either` pull returned by [`symmetric_hash_join`].
305pub type SymmetricHashJoinEither<'a, Key, V1, V2, Lhs, Rhs, LhsState, RhsState> = Either<
306    pull::Iter<NewTickJoinIter<'a, Key, V1, V2, LhsState, RhsState>>,
307    SymmetricHashJoin<Lhs, Rhs, &'a mut LhsState, &'a mut RhsState, LhsState, RhsState>,
308>;
309
310/// Creates a symmetric hash join Pull from two input Pulls and their join states.
311///
312/// For `is_new_tick = true`, this first drains both inputs into their respective states,
313/// then returns an iterator over all matches.
314///
315/// For `is_new_tick = false`, this returns a streaming join that processes items as they arrive.
316pub async fn symmetric_hash_join<'a, Key, Lhs, V1, Rhs, V2, LhsState, RhsState>(
317    lhs: Lhs,
318    rhs: Rhs,
319    lhs_state: &'a mut LhsState,
320    rhs_state: &'a mut RhsState,
321    is_new_tick: bool,
322) -> SymmetricHashJoinEither<'a, Key, V1, V2, Lhs, Rhs, LhsState, RhsState>
323where
324    Key: 'a + Eq + std::hash::Hash + Clone,
325    V1: 'a + Clone,
326    V2: 'a + Clone,
327    Lhs: 'a + FusedPull<Item = (Key, V1), Meta = ()>,
328    Rhs: 'a + FusedPull<Item = (Key, V2), Meta = ()>,
329    LhsState: HalfJoinState<Key, V1, V2>,
330    RhsState: HalfJoinState<Key, V2, V1>,
331{
332    if is_new_tick {
333        // Drain both inputs first
334        drain_pull_into_state(std::pin::pin!(lhs), lhs_state).await;
335        drain_pull_into_state(std::pin::pin!(rhs), rhs_state).await;
336
337        let iter = if lhs_state.len() < rhs_state.len() {
338            NewTickJoinIter::new_lhs_smaller(lhs_state, rhs_state)
339        } else {
340            NewTickJoinIter::new_rhs_smaller(lhs_state, rhs_state)
341        };
342        SymmetricHashJoinEither::Left(pull::iter(iter))
343    } else {
344        SymmetricHashJoinEither::Right(SymmetricHashJoin::new(lhs, rhs, lhs_state, rhs_state))
345    }
346}
347
348/// Helper to drain a Pull into state.
349fn drain_pull_into_state<Key, ValBuild, ValProbe, P, State>(
350    mut pull: Pin<&mut P>,
351    state: &mut State,
352) -> impl Future<Output = ()>
353where
354    Key: Eq + std::hash::Hash + Clone,
355    ValBuild: Clone,
356    P: Pull<Item = (Key, ValBuild)>,
357    State: HalfJoinState<Key, ValBuild, ValProbe>,
358{
359    std::future::poll_fn(move |ctx| {
360        let ctx = Context::from_task(ctx);
361        loop {
362            return match pull.as_mut().pull(ctx) {
363                PullStep::Ready((k, v), _meta) => {
364                    state.build(k, Cow::Owned(v));
365                    continue;
366                }
367                PullStep::Pending(_) => std::task::Poll::Pending,
368                PullStep::Ended(_) => std::task::Poll::Ready(()),
369            };
370        }
371    })
372}