Skip to main content

dfir_pipes/pull/
accumulator.rs

1//! Accumulator trait and implementors.
2
3use core::future::Future;
4use core::pin::Pin;
5use core::task::Poll;
6use std::collections::hash_map::Entry;
7use std::hash::{BuildHasher, Hash};
8
9use pin_project_lite::pin_project;
10
11use crate::Context;
12use crate::pull::{Pull, PullStep};
13
14/// Generalization of fold, reduce, etc.
15pub trait Accumulator<ValAccum, ValIn> {
16    /// Accumulates a value into an either occupied or vacant table entry.
17    fn accumulate<Key>(&mut self, entry: Entry<'_, Key, ValAccum>, item: ValIn);
18}
19
20/// Fold with an initialization and fold function.
21#[derive(Clone, Debug)]
22pub struct Fold<InitFn, FoldFn> {
23    init_fn: InitFn,
24    fold_fn: FoldFn,
25}
26
27impl<InitFn, FoldFn> Fold<InitFn, FoldFn> {
28    /// Create a `Fold` [`Accumulator`] with the given `InitFn` and `FoldFn`.
29    pub const fn new<Accum, Item>(init_fn: InitFn, fold_fn: FoldFn) -> Self
30    where
31        Self: Accumulator<Accum, Item>,
32    {
33        Self { init_fn, fold_fn }
34    }
35}
36
37impl<InitFn, FoldFn, Accum, Item> Accumulator<Accum, Item> for Fold<InitFn, FoldFn>
38where
39    InitFn: Fn() -> Accum,
40    FoldFn: Fn(&mut Accum, Item),
41{
42    fn accumulate<Key>(&mut self, entry: Entry<'_, Key, Accum>, item: Item) {
43        let prev_item = entry.or_insert_with(|| (self.init_fn)());
44        let () = (self.fold_fn)(prev_item, item);
45    }
46}
47
48/// Reduce with a reduce function.
49#[derive(Clone, Debug)]
50pub struct Reduce<ReduceFn> {
51    reduce_fn: ReduceFn,
52}
53
54impl<ReduceFn> Reduce<ReduceFn> {
55    /// Create a `Reduce` [`Accumulator`] with the given `ReduceFn`.
56    pub const fn new<Item>(reduce_fn: ReduceFn) -> Self
57    where
58        Self: Accumulator<Item, Item>,
59    {
60        Self { reduce_fn }
61    }
62}
63
64impl<ReduceFn, Item> Accumulator<Item, Item> for Reduce<ReduceFn>
65where
66    ReduceFn: Fn(&mut Item, Item),
67{
68    fn accumulate<Key>(&mut self, entry: Entry<'_, Key, Item>, item: Item) {
69        match entry {
70            Entry::Vacant(entry) => {
71                entry.insert(item);
72            }
73            Entry::Occupied(mut entry) => {
74                let prev_item = entry.get_mut();
75                let () = (self.reduce_fn)(prev_item, item);
76            }
77        }
78    }
79}
80
81/// Fold but with initialization by converting the first received item.
82#[derive(Clone, Debug)]
83pub struct FoldFrom<InitFn, FoldFn> {
84    init_fn: InitFn,
85    fold_fn: FoldFn,
86}
87
88impl<InitFn, FoldFn> FoldFrom<InitFn, FoldFn> {
89    /// Create a `FoldFrom` [`Accumulator`] with the given `InitFn` and `FoldFn`.
90    pub const fn new<Accum, Item>(init_fn: InitFn, fold_fn: FoldFn) -> Self
91    where
92        Self: Accumulator<Accum, Item>,
93    {
94        Self { init_fn, fold_fn }
95    }
96}
97
98impl<InitFn, FoldFn, Accum, Item> Accumulator<Accum, Item> for FoldFrom<InitFn, FoldFn>
99where
100    InitFn: Fn(Item) -> Accum,
101    FoldFn: Fn(&mut Accum, Item),
102{
103    fn accumulate<Key>(&mut self, entry: Entry<'_, Key, Accum>, item: Item) {
104        match entry {
105            Entry::Vacant(entry) => {
106                entry.insert((self.init_fn)(item));
107            }
108            Entry::Occupied(mut entry) => {
109                let prev_item = entry.get_mut();
110                let () = (self.fold_fn)(prev_item, item);
111            }
112        }
113    }
114}
115
116pin_project! {
117    /// Future for [`accumulate_all`].
118    #[must_use = "futures do nothing unless polled"]
119        pub struct AccumulateAll<'a, Prev, Accum, Key, ValAccum, ValIn, S> {
120        #[pin]
121        prev: Prev,
122        accum: &'a mut Accum,
123        hash_map: &'a mut std::collections::HashMap<Key, ValAccum, S>,
124        _marker: core::marker::PhantomData<ValIn>,
125    }
126}
127
128impl<'a, Prev, Accum, Key, ValAccum, ValIn, S>
129    AccumulateAll<'a, Prev, Accum, Key, ValAccum, ValIn, S>
130where
131    Self: Future,
132{
133    pub(crate) const fn new(
134        prev: Prev,
135        accum: &'a mut Accum,
136        hash_map: &'a mut std::collections::HashMap<Key, ValAccum, S>,
137    ) -> Self {
138        Self {
139            prev,
140            accum,
141            hash_map,
142            _marker: core::marker::PhantomData,
143        }
144    }
145}
146
147impl<'a, Prev, Accum, Key, ValAccum, ValIn, S> Future
148    for AccumulateAll<'a, Prev, Accum, Key, ValAccum, ValIn, S>
149where
150    Prev: Pull<Item = (Key, ValIn)>,
151    Accum: Accumulator<ValAccum, ValIn>,
152    Key: Eq + Hash,
153    S: BuildHasher,
154    for<'ctx> Prev::Ctx<'ctx>: Context<'ctx>,
155{
156    type Output = ();
157
158    fn poll(self: Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
159        let mut this = self.project();
160        let ctx = <Prev::Ctx<'_> as Context<'_>>::from_task(cx);
161        loop {
162            return match this.prev.as_mut().pull(ctx) {
163                PullStep::Ready((key, item), _meta) => {
164                    this.accum.accumulate(this.hash_map.entry(key), item);
165                    continue;
166                }
167                PullStep::Pending(_) => Poll::Pending,
168                PullStep::Ended(_) => Poll::Ready(()),
169            };
170        }
171    }
172}
173
174/// Use the accumulator `accum` to accumulate all entries in the `Pull` `prev` into the `hash_map`.
175pub const fn accumulate_all<'a, Key, ValAccum, ValIn, Accum, S, Prev>(
176    accum: &'a mut Accum,
177    hash_map: &'a mut std::collections::HashMap<Key, ValAccum, S>,
178    prev: Prev,
179) -> AccumulateAll<'a, Prev, Accum, Key, ValAccum, ValIn, S>
180where
181    Key: Eq + Hash,
182    Accum: Accumulator<ValAccum, ValIn>,
183    Prev: Pull<Item = (Key, ValIn)>,
184    S: BuildHasher,
185{
186    AccumulateAll::new(prev, accum, hash_map)
187}