Skip to content

Commit ff88fa4

Browse files
authored
inference: refine PartialStruct lattice tmerge (#44404)
* inference: fix tmerge lattice over issimpleenoughtype Previously we assumed only union type could have complexity that violated the tmerge lattice requirements, but other types can have that too. This lets us fix an issue with the PartialStruct comparison failing for undefined fields, mentioned in #43784. * inference: refine PartialStruct lattice tmerge Be more aggressive about merging fields to greatly accelerate convergence, but also compute anyrefine more correctly as we do now elsewhere (since #42831, a121721) Move the tmeet algorithm, without changes, since it is a precise lattice operation, not a heuristic limit like tmerge. Close #43784
1 parent ceec252 commit ff88fa4

3 files changed

Lines changed: 189 additions & 57 deletions

File tree

base/compiler/typelattice.jl

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ The non-strict partial order over the type inference lattice.
200200
end
201201
for i in 1:nfields(a.val)
202202
# XXX: let's handle varargs later
203-
isdefined(a.val, i) || return false
203+
isdefined(a.val, i) || continue # since ∀ T Union{} ⊑ T
204204
(Const(getfield(a.val, i)), b.fields[i]) || return false
205205
end
206206
return true
@@ -289,6 +289,48 @@ function is_lattice_equal(@nospecialize(a), @nospecialize(b))
289289
return a b && b a
290290
end
291291

292+
# compute typeintersect over the extended inference lattice,
293+
# as precisely as we can,
294+
# where v is in the extended lattice, and t is a Type.
295+
function tmeet(@nospecialize(v), @nospecialize(t))
296+
if isa(v, Const)
297+
if !has_free_typevars(t) && !isa(v.val, t)
298+
return Bottom
299+
end
300+
return v
301+
elseif isa(v, PartialStruct)
302+
has_free_typevars(t) && return v
303+
widev = widenconst(v)
304+
if widev <: t
305+
return v
306+
end
307+
ti = typeintersect(widev, t)
308+
valid_as_lattice(ti) || return Bottom
309+
@assert widev <: Tuple
310+
new_fields = Vector{Any}(undef, length(v.fields))
311+
for i = 1:length(new_fields)
312+
vfi = v.fields[i]
313+
if isvarargtype(vfi)
314+
new_fields[i] = vfi
315+
else
316+
new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i))))
317+
if new_fields[i] === Bottom
318+
return Bottom
319+
end
320+
end
321+
end
322+
return tuple_tfunc(new_fields)
323+
elseif isa(v, Conditional)
324+
if !(Bool <: t)
325+
return Bottom
326+
end
327+
return v
328+
end
329+
ti = typeintersect(widenconst(v), t)
330+
valid_as_lattice(ti) || return Bottom
331+
return ti
332+
end
333+
292334
widenconst(c::AnyConditional) = Bool
293335
widenconst((; val)::Const) = isa(val, Type) ? Type{val} : typeof(val)
294336
widenconst(m::MaybeUndef) = widenconst(m.typ)
@@ -427,3 +469,45 @@ function stupdate1!(state::VarTable, change::StateUpdate)
427469
end
428470
return false
429471
end
472+
473+
# compute typeintersect over the extended inference lattice,
474+
# as precisely as we can,
475+
# where v is in the extended lattice, and t is a Type.
476+
function tmeet(@nospecialize(v), @nospecialize(t))
477+
if isa(v, Const)
478+
if !has_free_typevars(t) && !isa(v.val, t)
479+
return Bottom
480+
end
481+
return v
482+
elseif isa(v, PartialStruct)
483+
has_free_typevars(t) && return v
484+
widev = widenconst(v)
485+
if widev <: t
486+
return v
487+
end
488+
ti = typeintersect(widev, t)
489+
valid_as_lattice(ti) || return Bottom
490+
@assert widev <: Tuple
491+
new_fields = Vector{Any}(undef, length(v.fields))
492+
for i = 1:length(new_fields)
493+
vfi = v.fields[i]
494+
if isvarargtype(vfi)
495+
new_fields[i] = vfi
496+
else
497+
new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i))))
498+
if new_fields[i] === Bottom
499+
return Bottom
500+
end
501+
end
502+
end
503+
return tuple_tfunc(new_fields)
504+
elseif isa(v, Conditional)
505+
if !(Bool <: t)
506+
return Bottom
507+
end
508+
return v
509+
end
510+
ti = typeintersect(widenconst(v), t)
511+
valid_as_lattice(ti) || return Bottom
512+
return ti
513+
end

base/compiler/typelimits.jl

Lines changed: 81 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -298,23 +298,71 @@ union_count_abstract(x::Union) = union_count_abstract(x.a) + union_count_abstrac
298298
union_count_abstract(@nospecialize(x)) = !isdispatchelem(x)
299299

300300
function issimpleenoughtype(@nospecialize t)
301-
t = ignorelimited(t)
302301
return unionlen(t) + union_count_abstract(t) <= MAX_TYPEUNION_LENGTH &&
303302
unioncomplexity(t) <= MAX_TYPEUNION_COMPLEXITY
304303
end
305304

305+
# A simplified type_more_complex query over the extended lattice
306+
# (assumes typeb ⊑ typea)
307+
function issimplertype(@nospecialize(typea), @nospecialize(typeb))
308+
typea = ignorelimited(typea)
309+
typeb = ignorelimited(typeb)
310+
typea isa MaybeUndef && (typea = typea.typ) # n.b. does not appear in inference
311+
typeb isa MaybeUndef && (typeb = typeb.typ) # n.b. does not appear in inference
312+
typea === typeb && return true
313+
if typea isa PartialStruct
314+
aty = widenconst(typea)
315+
for i = 1:length(typea.fields)
316+
ai = typea.fields[i]
317+
bi = fieldtype(aty, i)
318+
is_lattice_equal(ai, bi) && continue
319+
tni = _typename(widenconst(ai))
320+
if tni isa Const
321+
bi = (tni.val::Core.TypeName).wrapper
322+
is_lattice_equal(ai, bi) && continue
323+
end
324+
bi = getfield_tfunc(typeb, Const(i))
325+
is_lattice_equal(ai, bi) && continue
326+
# It is not enough for ai to be simpler than bi: it must exactly equal
327+
# (for this, an invariant struct field, by contrast to
328+
# type_more_complex above which handles covariant tuples).
329+
return false
330+
end
331+
elseif typea isa Type
332+
return issimpleenoughtype(typea)
333+
# elseif typea isa Const # fall-through good
334+
elseif typea isa Conditional # follow issubconditional query
335+
typeb isa Const && return true
336+
typeb isa Conditional || return false
337+
is_same_conditionals(typea, typeb) || return false
338+
issimplertype(typea.vtype, typeb.vtype) || return false
339+
issimplertype(typea.elsetype, typeb.elsetype) || return false
340+
elseif typea isa InterConditional # ibid
341+
typeb isa Const && return true
342+
typeb isa InterConditional || return false
343+
is_same_conditionals(typea, typeb) || return false
344+
issimplertype(typea.vtype, typeb.vtype) || return false
345+
issimplertype(typea.elsetype, typeb.elsetype) || return false
346+
elseif typea isa PartialOpaque
347+
# TODO
348+
end
349+
return true
350+
end
351+
306352
# pick a wider type that contains both typea and typeb,
307353
# with some limits on how "large" it can get,
308354
# but without losing too much precision in common cases
309355
# and also trying to be mostly associative and commutative
310356
function tmerge(@nospecialize(typea), @nospecialize(typeb))
311357
typea === Union{} && return typeb
312358
typeb === Union{} && return typea
359+
typea === typeb && return typea
360+
313361
suba = typea typeb
314-
suba && issimpleenoughtype(typeb) && return typeb
362+
suba && issimplertype(typeb, typea) && return typeb
315363
subb = typeb typea
316364
suba && subb && return typea
317-
subb && issimpleenoughtype(typea) && return typea
365+
subb && issimplertype(typea, typeb) && return typea
318366

319367
# type-lattice for LimitedAccuracy wrapper
320368
# the merge create a slightly narrower type than needed, but we can't
@@ -404,6 +452,7 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
404452
aty = widenconst(typea)
405453
bty = widenconst(typeb)
406454
if aty === bty
455+
# must have egal here, since we do not create PartialStruct for non-concrete types
407456
typea_nfields = nfields_tfunc(typea)
408457
typeb_nfields = nfields_tfunc(typeb)
409458
isa(typea_nfields, Const) || return aty
@@ -412,18 +461,40 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
412461
type_nfields === typeb_nfields.val::Int || return aty
413462
type_nfields == 0 && return aty
414463
fields = Vector{Any}(undef, type_nfields)
415-
anyconst = false
464+
anyrefine = false
416465
for i = 1:type_nfields
417466
ai = getfield_tfunc(typea, Const(i))
418467
bi = getfield_tfunc(typeb, Const(i))
419-
ity = tmerge(ai, bi)
420-
if ai === Union{} || bi === Union{}
421-
ity = widenconst(ity)
468+
ft = fieldtype(aty, i)
469+
if is_lattice_equal(ai, bi) || is_lattice_equal(ai, ft)
470+
# Since ai===bi, the given type has no restrictions on complexity.
471+
# and can be used to refine ft
472+
tyi = ai
473+
elseif is_lattice_equal(bi, ft)
474+
tyi = bi
475+
else
476+
# Otherwise choose between using the fieldtype or some other simple merged type.
477+
# The wrapper type never has restrictions on complexity,
478+
# so try to use that to refine the estimated type too.
479+
tni = _typename(widenconst(ai))
480+
if tni isa Const && tni === _typename(widenconst(bi))
481+
# A tmeet call may cause tyi to become complex, but since the inputs were
482+
# strictly limited to being egal, this has no restrictions on complexity.
483+
# (Otherwise, we would need to use <: and take the narrower one without
484+
# intersection. See the similar comment in abstract_call_method.)
485+
tyi = typeintersect(ft, (tni.val::Core.TypeName).wrapper)
486+
else
487+
# Since aty===bty, the fieldtype has no restrictions on complexity.
488+
tyi = ft
489+
end
490+
end
491+
fields[i] = tyi
492+
if !anyrefine
493+
anyrefine = has_nontrivial_const_info(tyi) || # constant information
494+
tyi ft # just a type-level information, but more precise than the declared type
422495
end
423-
fields[i] = ity
424-
anyconst |= has_nontrivial_const_info(ity)
425496
end
426-
return anyconst ? PartialStruct(aty, fields) : aty
497+
return anyrefine ? PartialStruct(aty, fields) : aty
427498
end
428499
end
429500
if isa(typea, PartialOpaque) && isa(typeb, PartialOpaque) && widenconst(typea) == widenconst(typeb)
@@ -610,44 +681,3 @@ function tuplemerge(a::DataType, b::DataType)
610681
end
611682
return Tuple{p...}
612683
end
613-
614-
# compute typeintersect over the extended inference lattice
615-
# where v is in the extended lattice, and t is a Type
616-
function tmeet(@nospecialize(v), @nospecialize(t))
617-
if isa(v, Const)
618-
if !has_free_typevars(t) && !isa(v.val, t)
619-
return Bottom
620-
end
621-
return v
622-
elseif isa(v, PartialStruct)
623-
has_free_typevars(t) && return v
624-
widev = widenconst(v)
625-
if widev <: t
626-
return v
627-
end
628-
ti = typeintersect(widev, t)
629-
valid_as_lattice(ti) || return Bottom
630-
@assert widev <: Tuple
631-
new_fields = Vector{Any}(undef, length(v.fields))
632-
for i = 1:length(new_fields)
633-
vfi = v.fields[i]
634-
if isvarargtype(vfi)
635-
new_fields[i] = vfi
636-
else
637-
new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i))))
638-
if new_fields[i] === Bottom
639-
return Bottom
640-
end
641-
end
642-
end
643-
return tuple_tfunc(new_fields)
644-
elseif isa(v, Conditional)
645-
if !(Bool <: t)
646-
return Bottom
647-
end
648-
return v
649-
end
650-
ti = typeintersect(widenconst(v), t)
651-
valid_as_lattice(ti) || return Bottom
652-
return ti
653-
end

test/compiler/inference.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3999,15 +3999,33 @@ end
39993999
@test (a, c)
40004000
@test (b, c)
40014001

4002-
@test @eval Module() begin
4003-
const ginit = Base.ImmutableDict{Any,Any}()
4004-
Base.return_types() do
4005-
g = ginit
4002+
init = Base.ImmutableDict{Number,Number}()
4003+
a = Const(init)
4004+
b = Core.PartialStruct(typeof(init), Any[Const(init), Any, ComplexF64])
4005+
c = Core.Compiler.tmerge(a, b)
4006+
@test (a, c) && (b, c)
4007+
@test c === typeof(init)
4008+
4009+
a = Core.PartialStruct(typeof(init), Any[Const(init), ComplexF64, ComplexF64])
4010+
c = Core.Compiler.tmerge(a, b)
4011+
@test (a, c) && (b, c)
4012+
@test c.fields[2] === Any # or Number
4013+
@test c.fields[3] === ComplexF64
4014+
4015+
b = Core.PartialStruct(typeof(init), Any[Const(init), ComplexF32, Union{ComplexF32,ComplexF64}])
4016+
c = Core.Compiler.tmerge(a, b)
4017+
@test (a, c)
4018+
@test (b, c)
4019+
@test c.fields[2] === Complex
4020+
@test c.fields[3] === Complex
4021+
4022+
global const ginit43784 = Base.ImmutableDict{Any,Any}()
4023+
@test Base.return_types() do
4024+
g = ginit43784
40064025
while true
40074026
g = Base.ImmutableDict(g, 1=>2)
40084027
end
40094028
end |> only === Union{}
4010-
end
40114029
end
40124030

40134031
# Test that purity modeling doesn't accidentally introduce new world age issues

0 commit comments

Comments
 (0)