@@ -36,29 +36,28 @@ mul_prod(x::Real, y::Real)::Real = x * y
3636
3737# # foldl && mapfoldl
3838
39- mapfoldl_impl (f, op, nt, itr) = foldl_impl (op, nt, Generator (f, itr))
39+ function mapfoldl_impl (f, op, nt, itr)
40+ op′, itr′ = _xfadjoint (BottomRF (op), Generator (f, itr))
41+ return foldl_impl (op′, nt, itr′)
42+ end
4043
4144function foldl_impl (op, nt, itr)
42- op′, itr′ = _xfadjoint (BottomRF (op), itr)
43- return _foldl_impl (op′, nt, itr′)
45+ v = _foldl_impl (op, get (nt, :init , _InitialValue ()), itr)
46+ v isa _InitialValue && return reduce_empty_iter (op, itr)
47+ return v
4448end
4549
46- function _foldl_impl (op, nt, itr)
47- init = get (nt, :init , _InitialValue ())
50+ function _foldl_impl (op, init, itr)
4851 # Unroll the while loop once; if init is known, the call to op may
4952 # be evaluated at compile time
5053 y = iterate (itr)
51- if y === nothing
52- init isa _InitialValue && return reduce_empty_iter (op, itr)
53- return init
54- end
54+ y === nothing && return init
5555 v = op (init, y[1 ])
5656 while true
5757 y = iterate (itr, y[2 ])
5858 y === nothing && break
5959 v = op (v, y[1 ])
6060 end
61- v isa _InitialValue && return reduce_empty_iter (op, itr)
6261 return v
6362end
6463
102101
103102@inline (op:: FilteringRF )(acc, x) = op. f (x) ? op. rf (acc, x) : acc
104103
104+ """
105+ FlatteningRF(rf) -> rf′
106+
107+ Create a flattening reducing function that is roughly equivalent to
108+ `rf′(acc, x) = foldl(rf, x; init=acc)`.
109+ """
110+ struct FlatteningRF{T}
111+ rf:: T
112+ end
113+
114+ @inline (op:: FlatteningRF )(acc, x) = _foldl_impl (op. rf, acc, x)
115+
105116"""
106117 _xfadjoint(op, itr) -> op′, itr′
107118
@@ -130,6 +141,8 @@ _xfadjoint(op, itr::Generator) =
130141 end
131142_xfadjoint (op, itr:: Filter ) =
132143 _xfadjoint (FilteringRF (itr. flt, op), itr. itr)
144+ _xfadjoint (op, itr:: Flatten ) =
145+ _xfadjoint (FlatteningRF (op), itr. it)
133146
134147"""
135148 mapfoldl(f, op, itr; [init])
@@ -162,7 +175,7 @@ foldl(op, itr; kw...) = mapfoldl(identity, op, itr; kw...)
162175
163176function mapfoldr_impl (f, op, nt, itr)
164177 op′, itr′ = _xfadjoint (BottomRF (FlipArgs (op)), Generator (f, itr))
165- return _foldl_impl (op′, nt, Iterators. reverse (itr′))
178+ return foldl_impl (op′, nt, Iterators. reverse (itr′))
166179end
167180
168181struct FlipArgs{F}
0 commit comments