Skip to content

Commit dfcd792

Browse files
committed
Add a transducer for Flatten
1 parent 3a266b3 commit dfcd792

2 files changed

Lines changed: 25 additions & 16 deletions

File tree

base/reduce.jl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,29 +36,28 @@ mul_prod(x::Real, y::Real)::Real = x * y
3636

3737
## foldl && mapfoldl
3838

39-
mapfoldl_impl(f, op, nt, itr) = foldl_impl(op, nt, Generator(f, itr))
39+
function mapfoldl_impl(f, op, nt, itr)
40+
op′, itr′ = _xfadjoint(BottomRF(op), Generator(f, itr))
41+
return foldl_impl(op′, nt, itr′)
42+
end
4043

4144
function foldl_impl(op, nt, itr)
42-
op′, itr′ = _xfadjoint(BottomRF(op), itr)
43-
return _foldl_impl(op′, nt, itr′)
45+
v = _foldl_impl(op, get(nt, :init, _InitialValue()), itr)
46+
v isa _InitialValue && return reduce_empty_iter(op, itr)
47+
return v
4448
end
4549

46-
function _foldl_impl(op, nt, itr)
47-
init = get(nt, :init, _InitialValue())
50+
function _foldl_impl(op, init, itr)
4851
# Unroll the while loop once; if init is known, the call to op may
4952
# be evaluated at compile time
5053
y = iterate(itr)
51-
if y === nothing
52-
init isa _InitialValue && return reduce_empty_iter(op, itr)
53-
return init
54-
end
54+
y === nothing && return init
5555
v = op(init, y[1])
5656
while true
5757
y = iterate(itr, y[2])
5858
y === nothing && break
5959
v = op(v, y[1])
6060
end
61-
v isa _InitialValue && return reduce_empty_iter(op, itr)
6261
return v
6362
end
6463

@@ -102,6 +101,18 @@ end
102101

103102
@inline (op::FilteringRF)(acc, x) = op.f(x) ? op.rf(acc, x) : acc
104103

104+
"""
105+
FlatteningRF(rf) -> rf′
106+
107+
Create a flattening reducing function that is roughly equivalent to
108+
`rf′(acc, x) = foldl(rf, x; init=acc)`.
109+
"""
110+
struct FlatteningRF{T}
111+
rf::T
112+
end
113+
114+
@inline (op::FlatteningRF)(acc, x) = _foldl_impl(op.rf, acc, x)
115+
105116
"""
106117
_xfadjoint(op, itr) -> op′, itr′
107118
@@ -130,6 +141,8 @@ _xfadjoint(op, itr::Generator) =
130141
end
131142
_xfadjoint(op, itr::Filter) =
132143
_xfadjoint(FilteringRF(itr.flt, op), itr.itr)
144+
_xfadjoint(op, itr::Flatten) =
145+
_xfadjoint(FlatteningRF(op), itr.it)
133146

134147
"""
135148
mapfoldl(f, op, itr; [init])
@@ -162,7 +175,7 @@ foldl(op, itr; kw...) = mapfoldl(identity, op, itr; kw...)
162175

163176
function mapfoldr_impl(f, op, nt, itr)
164177
op′, itr′ = _xfadjoint(BottomRF(FlipArgs(op)), Generator(f, itr))
165-
return _foldl_impl(op′, nt, Iterators.reverse(itr′))
178+
return foldl_impl(op′, nt, Iterators.reverse(itr′))
166179
end
167180

168181
struct FlipArgs{F}

base/tuple.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,7 @@ function map(f, t1::Any16, t2::Any16, ts::Any16...)
204204
(A...,)
205205
end
206206

207-
function _foldl_impl(op, nt, itr::Tuple)
208-
init = get(nt, :init, _InitialValue())
209-
y = afoldl(op, init, itr...)
210-
return y isa _InitialValue ? reduce_empty_iter(op, itr) : y
211-
end
207+
_foldl_impl(op, init, itr::Tuple) = afoldl(op, init, itr...)
212208

213209
# type-stable padding
214210
fill_to_length(t::NTuple{N,Any}, val, ::Val{N}) where {N} = t

0 commit comments

Comments
 (0)