1use core::future::Future;
4use core::pin::Pin;
5use core::task::Poll;
6use std::collections::hash_map::Entry;
7use std::hash::{BuildHasher, Hash};
8
9use pin_project_lite::pin_project;
10
11use crate::Context;
12use crate::pull::{Pull, PullStep};
13
14pub trait Accumulator<ValAccum, ValIn> {
16 fn accumulate<Key>(&mut self, entry: Entry<'_, Key, ValAccum>, item: ValIn);
18}
19
20#[derive(Clone, Debug)]
22pub struct Fold<InitFn, FoldFn> {
23 init_fn: InitFn,
24 fold_fn: FoldFn,
25}
26
27impl<InitFn, FoldFn> Fold<InitFn, FoldFn> {
28 pub const fn new<Accum, Item>(init_fn: InitFn, fold_fn: FoldFn) -> Self
30 where
31 Self: Accumulator<Accum, Item>,
32 {
33 Self { init_fn, fold_fn }
34 }
35}
36
37impl<InitFn, FoldFn, Accum, Item> Accumulator<Accum, Item> for Fold<InitFn, FoldFn>
38where
39 InitFn: Fn() -> Accum,
40 FoldFn: Fn(&mut Accum, Item),
41{
42 fn accumulate<Key>(&mut self, entry: Entry<'_, Key, Accum>, item: Item) {
43 let prev_item = entry.or_insert_with(|| (self.init_fn)());
44 let () = (self.fold_fn)(prev_item, item);
45 }
46}
47
48#[derive(Clone, Debug)]
50pub struct Reduce<ReduceFn> {
51 reduce_fn: ReduceFn,
52}
53
54impl<ReduceFn> Reduce<ReduceFn> {
55 pub const fn new<Item>(reduce_fn: ReduceFn) -> Self
57 where
58 Self: Accumulator<Item, Item>,
59 {
60 Self { reduce_fn }
61 }
62}
63
64impl<ReduceFn, Item> Accumulator<Item, Item> for Reduce<ReduceFn>
65where
66 ReduceFn: Fn(&mut Item, Item),
67{
68 fn accumulate<Key>(&mut self, entry: Entry<'_, Key, Item>, item: Item) {
69 match entry {
70 Entry::Vacant(entry) => {
71 entry.insert(item);
72 }
73 Entry::Occupied(mut entry) => {
74 let prev_item = entry.get_mut();
75 let () = (self.reduce_fn)(prev_item, item);
76 }
77 }
78 }
79}
80
81#[derive(Clone, Debug)]
83pub struct FoldFrom<InitFn, FoldFn> {
84 init_fn: InitFn,
85 fold_fn: FoldFn,
86}
87
88impl<InitFn, FoldFn> FoldFrom<InitFn, FoldFn> {
89 pub const fn new<Accum, Item>(init_fn: InitFn, fold_fn: FoldFn) -> Self
91 where
92 Self: Accumulator<Accum, Item>,
93 {
94 Self { init_fn, fold_fn }
95 }
96}
97
98impl<InitFn, FoldFn, Accum, Item> Accumulator<Accum, Item> for FoldFrom<InitFn, FoldFn>
99where
100 InitFn: Fn(Item) -> Accum,
101 FoldFn: Fn(&mut Accum, Item),
102{
103 fn accumulate<Key>(&mut self, entry: Entry<'_, Key, Accum>, item: Item) {
104 match entry {
105 Entry::Vacant(entry) => {
106 entry.insert((self.init_fn)(item));
107 }
108 Entry::Occupied(mut entry) => {
109 let prev_item = entry.get_mut();
110 let () = (self.fold_fn)(prev_item, item);
111 }
112 }
113 }
114}
115
116pin_project! {
117 #[must_use = "futures do nothing unless polled"]
119 pub struct AccumulateAll<'a, Prev, Accum, Key, ValAccum, ValIn, S> {
120 #[pin]
121 prev: Prev,
122 accum: &'a mut Accum,
123 hash_map: &'a mut std::collections::HashMap<Key, ValAccum, S>,
124 _marker: core::marker::PhantomData<ValIn>,
125 }
126}
127
128impl<'a, Prev, Accum, Key, ValAccum, ValIn, S>
129 AccumulateAll<'a, Prev, Accum, Key, ValAccum, ValIn, S>
130where
131 Self: Future,
132{
133 pub(crate) const fn new(
134 prev: Prev,
135 accum: &'a mut Accum,
136 hash_map: &'a mut std::collections::HashMap<Key, ValAccum, S>,
137 ) -> Self {
138 Self {
139 prev,
140 accum,
141 hash_map,
142 _marker: core::marker::PhantomData,
143 }
144 }
145}
146
147impl<'a, Prev, Accum, Key, ValAccum, ValIn, S> Future
148 for AccumulateAll<'a, Prev, Accum, Key, ValAccum, ValIn, S>
149where
150 Prev: Pull<Item = (Key, ValIn)>,
151 Accum: Accumulator<ValAccum, ValIn>,
152 Key: Eq + Hash,
153 S: BuildHasher,
154 for<'ctx> Prev::Ctx<'ctx>: Context<'ctx>,
155{
156 type Output = ();
157
158 fn poll(self: Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
159 let mut this = self.project();
160 let ctx = <Prev::Ctx<'_> as Context<'_>>::from_task(cx);
161 loop {
162 return match this.prev.as_mut().pull(ctx) {
163 PullStep::Ready((key, item), _meta) => {
164 this.accum.accumulate(this.hash_map.entry(key), item);
165 continue;
166 }
167 PullStep::Pending(_) => Poll::Pending,
168 PullStep::Ended(_) => Poll::Ready(()),
169 };
170 }
171 }
172}
173
174pub const fn accumulate_all<'a, Key, ValAccum, ValIn, Accum, S, Prev>(
176 accum: &'a mut Accum,
177 hash_map: &'a mut std::collections::HashMap<Key, ValAccum, S>,
178 prev: Prev,
179) -> AccumulateAll<'a, Prev, Accum, Key, ValAccum, ValIn, S>
180where
181 Key: Eq + Hash,
182 Accum: Accumulator<ValAccum, ValIn>,
183 Prev: Pull<Item = (Key, ValIn)>,
184 S: BuildHasher,
185{
186 AccumulateAll::new(prev, accum, hash_map)
187}