1use std::borrow::{BorrowMut, Cow};
4use std::marker::PhantomData;
5use std::pin::Pin;
6
7use itertools::Either;
8use pin_project_lite::pin_project;
9use smallvec::SmallVec;
10
11use crate::pull::half_join_state::HalfJoinState;
12use crate::pull::{self, FusedPull, Pull, PullStep};
13use crate::{Context, Toggle};
14
15pin_project! {
16 #[must_use = "`Pull`s do nothing unless polled"]
21 #[derive(Clone, Debug, Default)]
22 pub struct SymmetricHashJoin<Lhs, Rhs, LhsState, RhsState, LhsStateInner, RhsStateInner> {
23 #[pin]
24 lhs: Lhs,
25 #[pin]
26 rhs: Rhs,
27
28 lhs_state: LhsState,
29 rhs_state: RhsState,
30
31 _phantom: PhantomData<(LhsStateInner, RhsStateInner)>,
32 }
33}
34
35impl<Lhs, Rhs, LhsState, RhsState, LhsStateInner, RhsStateInner>
36 SymmetricHashJoin<Lhs, Rhs, LhsState, RhsState, LhsStateInner, RhsStateInner>
37where
38 Self: Pull,
39{
40 pub(crate) const fn new(lhs: Lhs, rhs: Rhs, lhs_state: LhsState, rhs_state: RhsState) -> Self {
42 Self {
43 lhs,
44 rhs,
45 lhs_state,
46 rhs_state,
47 _phantom: PhantomData,
48 }
49 }
50}
51
52impl<Key, Lhs, V1, Rhs, V2, LhsState, RhsState, LhsStateInner, RhsStateInner> Pull
53 for SymmetricHashJoin<Lhs, Rhs, LhsState, RhsState, LhsStateInner, RhsStateInner>
54where
55 Key: Eq + std::hash::Hash + Clone,
56 V1: Clone,
57 V2: Clone,
58 Lhs: FusedPull<Item = (Key, V1), Meta = ()>,
59 Rhs: FusedPull<Item = (Key, V2), Meta = ()>,
60 LhsState: BorrowMut<LhsStateInner>,
61 RhsState: BorrowMut<RhsStateInner>,
62 LhsStateInner: HalfJoinState<Key, V1, V2>,
63 RhsStateInner: HalfJoinState<Key, V2, V1>,
64{
65 type Ctx<'ctx> = <Lhs::Ctx<'ctx> as Context<'ctx>>::Merged<Rhs::Ctx<'ctx>>;
66
67 type Item = (Key, (V1, V2));
68 type Meta = ();
69 type CanPend = <Lhs::CanPend as Toggle>::Or<Rhs::CanPend>;
70 type CanEnd = <Lhs::CanEnd as Toggle>::And<Rhs::CanEnd>;
71
72 fn pull(
73 self: Pin<&mut Self>,
74 ctx: &mut Self::Ctx<'_>,
75 ) -> PullStep<Self::Item, Self::Meta, Self::CanPend, Self::CanEnd> {
76 let mut this = self.project();
77 let lhs_state = this.lhs_state.borrow_mut();
78 let rhs_state = this.rhs_state.borrow_mut();
79
80 loop {
81 if let Some((k, v2, v1)) = lhs_state.pop_match() {
83 return PullStep::Ready((k, (v1, v2)), ());
84 }
85 if let Some((k, v1, v2)) = rhs_state.pop_match() {
86 return PullStep::Ready((k, (v1, v2)), ());
87 }
88
89 let lhs_step = this
91 .lhs
92 .as_mut()
93 .pull(<Lhs::Ctx<'_> as Context<'_>>::unmerge_self(ctx));
94 if let PullStep::Ready((k, v1), _meta) = lhs_step {
95 if lhs_state.build(k.clone(), Cow::Borrowed(&v1))
96 && let Some((k, v1, v2)) = rhs_state.probe(&k, &v1)
97 {
98 return PullStep::Ready((k, (v1, v2)), ());
99 }
100 continue;
101 }
102
103 let rhs_step = this
105 .rhs
106 .as_mut()
107 .pull(<Lhs::Ctx<'_> as Context<'_>>::unmerge_other(ctx));
108 if let PullStep::Ready((k, v2), _meta) = rhs_step {
109 if rhs_state.build(k.clone(), Cow::Borrowed(&v2))
110 && let Some((k, v2, v1)) = lhs_state.probe(&k, &v2)
111 {
112 return PullStep::Ready((k, (v1, v2)), ());
113 }
114 continue;
115 }
116
117 if lhs_step.is_pending() || rhs_step.is_pending() {
118 return PullStep::pending();
119 }
120
121 debug_assert!(lhs_step.is_ended());
123 debug_assert!(rhs_step.is_ended());
124 return PullStep::ended();
125 }
126 }
127
128 fn size_hint(&self) -> (usize, Option<usize>) {
129 (0, None)
131 }
132}
133
134pub struct NewTickJoinIter<'a, Key, V1, V2, LhsState, RhsState> {
136 lhs_state: &'a LhsState,
137 rhs_state: &'a RhsState,
138 lhs_smaller: bool,
139 outer_iter: Option<std::collections::hash_map::Iter<'a, Key, SmallVec<[V1; 1]>>>,
141 outer_iter_rhs: Option<std::collections::hash_map::Iter<'a, Key, SmallVec<[V2; 1]>>>,
142 current_key: Option<&'a Key>,
143 outer_val_iter: Option<std::slice::Iter<'a, V1>>,
144 outer_val_iter_rhs: Option<std::slice::Iter<'a, V2>>,
145 current_outer_val: Option<&'a V1>,
146 current_outer_val_rhs: Option<&'a V2>,
147 inner_val_iter: Option<std::slice::Iter<'a, V2>>,
148 inner_val_iter_rhs: Option<std::slice::Iter<'a, V1>>,
149}
150
151impl<'a, Key, V1, V2, LhsState, RhsState> NewTickJoinIter<'a, Key, V1, V2, LhsState, RhsState>
152where
153 Key: Eq + std::hash::Hash + Clone,
154 V1: Clone,
155 V2: Clone,
156 LhsState: HalfJoinState<Key, V1, V2>,
157 RhsState: HalfJoinState<Key, V2, V1>,
158{
159 fn new_lhs_smaller(lhs_state: &'a LhsState, rhs_state: &'a RhsState) -> Self {
160 Self {
161 lhs_state,
162 rhs_state,
163 lhs_smaller: true,
164 outer_iter: Some(lhs_state.iter()),
165 outer_iter_rhs: None,
166 current_key: None,
167 outer_val_iter: None,
168 outer_val_iter_rhs: None,
169 current_outer_val: None,
170 current_outer_val_rhs: None,
171 inner_val_iter: None,
172 inner_val_iter_rhs: None,
173 }
174 }
175
176 fn new_rhs_smaller(lhs_state: &'a LhsState, rhs_state: &'a RhsState) -> Self {
177 Self {
178 lhs_state,
179 rhs_state,
180 lhs_smaller: false,
181 outer_iter: None,
182 outer_iter_rhs: Some(rhs_state.iter()),
183 current_key: None,
184 outer_val_iter: None,
185 outer_val_iter_rhs: None,
186 current_outer_val: None,
187 current_outer_val_rhs: None,
188 inner_val_iter: None,
189 inner_val_iter_rhs: None,
190 }
191 }
192}
193
194impl<'a, Key, V1, V2, LhsState, RhsState> Iterator
195 for NewTickJoinIter<'a, Key, V1, V2, LhsState, RhsState>
196where
197 Key: Eq + std::hash::Hash + Clone,
198 V1: Clone,
199 V2: Clone,
200 LhsState: HalfJoinState<Key, V1, V2>,
201 RhsState: HalfJoinState<Key, V2, V1>,
202{
203 type Item = (Key, (V1, V2));
204
205 fn next(&mut self) -> Option<Self::Item> {
206 if self.lhs_smaller {
207 self.next_lhs_smaller()
208 } else {
209 self.next_rhs_smaller()
210 }
211 }
212
213 fn size_hint(&self) -> (usize, Option<usize>) {
214 (0, None)
216 }
217}
218
219impl<'a, Key, V1, V2, LhsState, RhsState> NewTickJoinIter<'a, Key, V1, V2, LhsState, RhsState>
220where
221 Key: Eq + std::hash::Hash + Clone,
222 V1: Clone,
223 V2: Clone,
224 LhsState: HalfJoinState<Key, V1, V2>,
225 RhsState: HalfJoinState<Key, V2, V1>,
226{
227 fn next_lhs_smaller(&mut self) -> Option<(Key, (V1, V2))> {
228 loop {
229 if let Some(ref mut v2_iter) = self.inner_val_iter {
231 if let Some(v2) = v2_iter.next() {
232 let key = self.current_key.unwrap();
233 let v1 = self.current_outer_val.unwrap();
234 return Some((key.clone(), (v1.clone(), v2.clone())));
235 }
236 self.inner_val_iter = None;
237 }
238
239 if let Some(ref mut v1_iter) = self.outer_val_iter {
241 if let Some(v1) = v1_iter.next() {
242 self.current_outer_val = Some(v1);
243 let key = self.current_key.unwrap();
244 self.inner_val_iter = Some(self.rhs_state.full_probe(key));
245 continue;
246 }
247 self.outer_val_iter = None;
248 self.current_key = None;
249 }
250
251 if let Some(ref mut lhs_iter) = self.outer_iter {
253 if let Some((key, values)) = lhs_iter.next() {
254 self.current_key = Some(key);
255 self.outer_val_iter = Some(values.iter());
256 continue;
257 }
258 self.outer_iter = None;
259 }
260
261 return None;
262 }
263 }
264
265 fn next_rhs_smaller(&mut self) -> Option<(Key, (V1, V2))> {
266 loop {
267 if let Some(ref mut v1_iter) = self.inner_val_iter_rhs {
269 if let Some(v1) = v1_iter.next() {
270 let key = self.current_key.unwrap();
271 let v2 = self.current_outer_val_rhs.unwrap();
272 return Some((key.clone(), (v1.clone(), v2.clone())));
273 }
274 self.inner_val_iter_rhs = None;
275 }
276
277 if let Some(ref mut v2_iter) = self.outer_val_iter_rhs {
279 if let Some(v2) = v2_iter.next() {
280 self.current_outer_val_rhs = Some(v2);
281 let key = self.current_key.unwrap();
282 self.inner_val_iter_rhs = Some(self.lhs_state.full_probe(key));
283 continue;
284 }
285 self.outer_val_iter_rhs = None;
286 self.current_key = None;
287 }
288
289 if let Some(ref mut rhs_iter) = self.outer_iter_rhs {
291 if let Some((key, values)) = rhs_iter.next() {
292 self.current_key = Some(key);
293 self.outer_val_iter_rhs = Some(values.iter());
294 continue;
295 }
296 self.outer_iter_rhs = None;
297 }
298
299 return None;
300 }
301 }
302}
303
304pub type SymmetricHashJoinEither<'a, Key, V1, V2, Lhs, Rhs, LhsState, RhsState> = Either<
306 pull::Iter<NewTickJoinIter<'a, Key, V1, V2, LhsState, RhsState>>,
307 SymmetricHashJoin<Lhs, Rhs, &'a mut LhsState, &'a mut RhsState, LhsState, RhsState>,
308>;
309
310pub async fn symmetric_hash_join<'a, Key, Lhs, V1, Rhs, V2, LhsState, RhsState>(
317 lhs: Lhs,
318 rhs: Rhs,
319 lhs_state: &'a mut LhsState,
320 rhs_state: &'a mut RhsState,
321 is_new_tick: bool,
322) -> SymmetricHashJoinEither<'a, Key, V1, V2, Lhs, Rhs, LhsState, RhsState>
323where
324 Key: 'a + Eq + std::hash::Hash + Clone,
325 V1: 'a + Clone,
326 V2: 'a + Clone,
327 Lhs: 'a + FusedPull<Item = (Key, V1), Meta = ()>,
328 Rhs: 'a + FusedPull<Item = (Key, V2), Meta = ()>,
329 LhsState: HalfJoinState<Key, V1, V2>,
330 RhsState: HalfJoinState<Key, V2, V1>,
331{
332 if is_new_tick {
333 drain_pull_into_state(std::pin::pin!(lhs), lhs_state).await;
335 drain_pull_into_state(std::pin::pin!(rhs), rhs_state).await;
336
337 let iter = if lhs_state.len() < rhs_state.len() {
338 NewTickJoinIter::new_lhs_smaller(lhs_state, rhs_state)
339 } else {
340 NewTickJoinIter::new_rhs_smaller(lhs_state, rhs_state)
341 };
342 SymmetricHashJoinEither::Left(pull::iter(iter))
343 } else {
344 SymmetricHashJoinEither::Right(SymmetricHashJoin::new(lhs, rhs, lhs_state, rhs_state))
345 }
346}
347
348fn drain_pull_into_state<Key, ValBuild, ValProbe, P, State>(
350 mut pull: Pin<&mut P>,
351 state: &mut State,
352) -> impl Future<Output = ()>
353where
354 Key: Eq + std::hash::Hash + Clone,
355 ValBuild: Clone,
356 P: Pull<Item = (Key, ValBuild)>,
357 State: HalfJoinState<Key, ValBuild, ValProbe>,
358{
359 std::future::poll_fn(move |ctx| {
360 let ctx = Context::from_task(ctx);
361 loop {
362 return match pull.as_mut().pull(ctx) {
363 PullStep::Ready((k, v), _meta) => {
364 state.build(k, Cow::Owned(v));
365 continue;
366 }
367 PullStep::Pending(_) => std::task::Poll::Pending,
368 PullStep::Ended(_) => std::task::Poll::Ready(()),
369 };
370 }
371 })
372}