Skip to content

Commit 698beed

Browse files
authored
Precompile correct invoke-targets (#46907)
This fixes backedge-based invalidation when a precompiled `invoke` is followed by loading a package that adds new specializations for the `invoke`d method. An example is LowRankApprox.jl, where FillArrays adds a specialization to `unique`.
1 parent 6fdcfd4 commit 698beed

3 files changed

Lines changed: 80 additions & 33 deletions

File tree

src/dump.c

Lines changed: 58 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,6 +1366,7 @@ static void jl_collect_backedges(jl_array_t *edges, jl_array_t *ext_targets)
13661366
jl_value_t *invokeTypes;
13671367
jl_method_instance_t *c;
13681368
size_t i;
1369+
size_t world = jl_get_world_counter();
13691370
void **table = edges_map.table; // edges is caller => callees
13701371
size_t table_size = edges_map.size;
13711372
for (i = 0; i < table_size; i += 2) {
@@ -1408,15 +1409,28 @@ static void jl_collect_backedges(jl_array_t *edges, jl_array_t *ext_targets)
14081409
size_t min_valid = 0;
14091410
size_t max_valid = ~(size_t)0;
14101411
int ambig = 0;
1411-
jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_atomic_load_acquire(&jl_world_counter), &min_valid, &max_valid, &ambig);
1412-
if (matches == jl_false) {
1413-
valid = 0;
1414-
break;
1415-
}
1416-
size_t k;
1417-
for (k = 0; k < jl_array_len(matches); k++) {
1418-
jl_method_match_t *match = (jl_method_match_t *)jl_array_ptr_ref(matches, k);
1419-
jl_array_ptr_set(matches, k, match->method);
1412+
jl_value_t *matches;
1413+
if (mode == 2 && callee && jl_is_method_instance(callee) && jl_is_type(sig)) {
1414+
// invoke, use subtyping
1415+
jl_methtable_t *mt = jl_method_get_table(((jl_method_instance_t*)callee)->def.method);
1416+
size_t min_world, max_world;
1417+
matches = jl_gf_invoke_lookup_worlds(sig, (jl_value_t*)mt, world, &min_world, &max_world);
1418+
if (matches == jl_nothing) {
1419+
valid = 0;
1420+
break;
1421+
}
1422+
matches = (jl_value_t*)((jl_method_match_t*)matches)->method;
1423+
} else {
1424+
matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_atomic_load_acquire(&jl_world_counter), &min_valid, &max_valid, &ambig);
1425+
if (matches == jl_false) {
1426+
valid = 0;
1427+
break;
1428+
}
1429+
size_t k;
1430+
for (k = 0; k < jl_array_len(matches); k++) {
1431+
jl_method_match_t *match = (jl_method_match_t *)jl_array_ptr_ref(matches, k);
1432+
jl_array_ptr_set(matches, k, match->method);
1433+
}
14201434
}
14211435
jl_array_ptr_1d_push(ext_targets, mode == 1 ? NULL : sig);
14221436
jl_array_ptr_1d_push(ext_targets, callee);
@@ -2542,8 +2556,10 @@ static void jl_verify_edges(jl_array_t *targets, jl_array_t **pvalids)
25422556
jl_array_t *valids = jl_alloc_array_1d(jl_array_uint8_type, l);
25432557
memset(jl_array_data(valids), 1, l);
25442558
jl_value_t *loctag = NULL, *matches = NULL;
2545-
JL_GC_PUSH2(&loctag, &matches);
2559+
jl_methtable_t *mt = NULL;
2560+
JL_GC_PUSH3(&loctag, &matches, &mt);
25462561
*pvalids = valids;
2562+
size_t world = jl_get_world_counter();
25472563
for (i = 0; i < l; i++) {
25482564
jl_value_t *invokesig = jl_array_ptr_ref(targets, i * 3);
25492565
jl_value_t *callee = jl_array_ptr_ref(targets, i * 3 + 1);
@@ -2555,33 +2571,43 @@ static void jl_verify_edges(jl_array_t *targets, jl_array_t **pvalids)
25552571
else {
25562572
sig = callee == NULL ? invokesig : callee;
25572573
}
2558-
jl_array_t *expected = (jl_array_t*)jl_array_ptr_ref(targets, i * 3 + 2);
2559-
assert(jl_is_array(expected));
2574+
jl_value_t *expected = jl_array_ptr_ref(targets, i * 3 + 2);
25602575
int valid = 1;
25612576
size_t min_valid = 0;
25622577
size_t max_valid = ~(size_t)0;
25632578
int ambig = 0;
2564-
// TODO: possibly need to included ambiguities too (for the optimizer correctness)?
2565-
matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_atomic_load_acquire(&jl_world_counter), &min_valid, &max_valid, &ambig);
2566-
if (matches == jl_false || jl_array_len(matches) != jl_array_len(expected)) {
2567-
valid = 0;
2568-
}
2569-
else {
2570-
size_t j, k, l = jl_array_len(expected);
2571-
for (k = 0; k < jl_array_len(matches); k++) {
2572-
jl_method_match_t *match = (jl_method_match_t*)jl_array_ptr_ref(matches, k);
2573-
jl_method_t *m = match->method;
2574-
for (j = 0; j < l; j++) {
2575-
if (m == (jl_method_t*)jl_array_ptr_ref(expected, j))
2579+
int use_invoke = invokesig == NULL || callee == NULL ? 0 : 1;
2580+
if (!use_invoke) {
2581+
// TODO: possibly need to included ambiguities too (for the optimizer correctness)?
2582+
matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_atomic_load_acquire(&jl_world_counter), &min_valid, &max_valid, &ambig);
2583+
if (matches == jl_false || jl_array_len(matches) != jl_array_len(expected)) {
2584+
valid = 0;
2585+
}
2586+
else {
2587+
assert(jl_is_array(expected));
2588+
size_t j, k, l = jl_array_len(expected);
2589+
for (k = 0; k < jl_array_len(matches); k++) {
2590+
jl_method_match_t *match = (jl_method_match_t*)jl_array_ptr_ref(matches, k);
2591+
jl_method_t *m = match->method;
2592+
for (j = 0; j < l; j++) {
2593+
if (m == (jl_method_t*)jl_array_ptr_ref(expected, j))
2594+
break;
2595+
}
2596+
if (j == l) {
2597+
// intersection has a new method or a method was
2598+
// deleted--this is now probably no good, just invalidate
2599+
// everything about it now
2600+
valid = 0;
25762601
break;
2602+
}
25772603
}
2578-
if (j == l) {
2579-
// intersection has a new method or a method was
2580-
// deleted--this is now probably no good, just invalidate
2581-
// everything about it now
2582-
valid = 0;
2583-
break;
2584-
}
2604+
}
2605+
} else {
2606+
mt = jl_method_get_table(((jl_method_instance_t*)callee)->def.method);
2607+
size_t min_world, max_world;
2608+
matches = jl_gf_invoke_lookup_worlds(invokesig, (jl_value_t*)mt, world, &min_world, &max_world);
2609+
if (matches == jl_nothing || expected != (jl_value_t*)((jl_method_match_t*)matches)->method) {
2610+
valid = 0;
25852611
}
25862612
}
25872613
jl_array_uint8_set(valids, i, valid);
@@ -2593,7 +2619,7 @@ static void jl_verify_edges(jl_array_t *targets, jl_array_t **pvalids)
25932619
jl_array_ptr_1d_push(_jl_debug_method_invalidation, loctag);
25942620
loctag = jl_box_uint64(jl_worklist_key(serializer_worklist));
25952621
jl_array_ptr_1d_push(_jl_debug_method_invalidation, loctag);
2596-
if (matches != jl_false) {
2622+
if (!use_invoke && matches != jl_false) {
25972623
// setdiff!(matches, expected)
25982624
size_t j, k, ins = 0;
25992625
for (j = 0; j < jl_array_len(matches); j++) {

src/julia_internal.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -716,13 +716,14 @@ jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t *gf, jl_value
716716
jl_value_t *jl_gf_invoke(jl_value_t *types, jl_value_t *f, jl_value_t **args, size_t nargs);
717717
JL_DLLEXPORT jl_value_t *jl_matching_methods(jl_tupletype_t *types, jl_value_t *mt, int lim, int include_ambiguous,
718718
size_t world, size_t *min_valid, size_t *max_valid, int *ambig);
719+
JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup_worlds(jl_value_t *types, jl_value_t *mt, size_t world, size_t *min_world, size_t *max_world);
719720

720721
JL_DLLEXPORT jl_datatype_t *jl_first_argument_datatype(jl_value_t *argtypes JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT;
721722
JL_DLLEXPORT jl_value_t *jl_argument_datatype(jl_value_t *argt JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT;
722723
JL_DLLEXPORT jl_methtable_t *jl_method_table_for(
723724
jl_value_t *argtypes JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT;
724725
JL_DLLEXPORT jl_methtable_t *jl_method_get_table(
725-
jl_method_t *method) JL_NOTSAFEPOINT;
726+
jl_method_t *method JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT;
726727
jl_methtable_t *jl_argument_method_table(jl_value_t *argt JL_PROPAGATES_ROOT);
727728

728729
JL_DLLEXPORT int jl_pointer_egal(jl_value_t *t);

test/precompile.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,7 @@ precompile_test_harness("invoke") do dir
931931
module $InvokeModule
932932
export f, g, h, q, fnc, gnc, hnc, qnc # nc variants do not infer to a Const
933933
export f44320, g44320
934+
export getlast
934935
# f is for testing invoke that occurs within a dependency
935936
f(x::Real) = 0
936937
f(x::Int) = x < 5 ? 1 : invoke(f, Tuple{Real}, x)
@@ -954,6 +955,16 @@ precompile_test_harness("invoke") do dir
954955
f44320(::Any) = 2
955956
g44320() = invoke(f44320, Tuple{Any}, 0)
956957
g44320()
958+
959+
# Adding new specializations should not invalidate `invoke`s
960+
function getlast(itr)
961+
x = nothing
962+
for y in itr
963+
x = y
964+
end
965+
return x
966+
end
967+
getlast(a::AbstractArray) = invoke(getlast, Tuple{Any}, a)
957968
end
958969
""")
959970
write(joinpath(dir, "$CallerModule.jl"),
@@ -981,6 +992,8 @@ precompile_test_harness("invoke") do dir
981992
# Issue #44320
982993
f44320(::Real) = 3
983994
995+
call_getlast(x) = getlast(x)
996+
984997
# force precompilation
985998
begin
986999
Base.Experimental.@force_compile
@@ -996,6 +1009,7 @@ precompile_test_harness("invoke") do dir
9961009
callqnci(3)
9971010
internal(3)
9981011
internalnc(3)
1012+
call_getlast([1,2,3])
9991013
end
10001014
10011015
# Now that we've precompiled, invalidate with a new method that overrides the `invoke` dispatch
@@ -1007,6 +1021,9 @@ precompile_test_harness("invoke") do dir
10071021
end
10081022
""")
10091023
Base.compilecache(Base.PkgId(string(CallerModule)))
1024+
@eval using $InvokeModule: $InvokeModule
1025+
MI = getfield(@__MODULE__, InvokeModule)
1026+
@eval $MI.getlast(a::UnitRange) = a.stop
10101027
@eval using $CallerModule
10111028
M = getfield(@__MODULE__, CallerModule)
10121029

@@ -1060,6 +1077,9 @@ precompile_test_harness("invoke") do dir
10601077
m = only(methods(M.g44320))
10611078
@test m.specializations[1].cache.max_world == typemax(UInt)
10621079

1080+
m = which(MI.getlast, (Any,))
1081+
@test m.specializations[1].cache.max_world == typemax(UInt)
1082+
10631083
# Precompile specific methods for arbitrary arg types
10641084
invokeme(x) = 1
10651085
invokeme(::Int) = 2

0 commit comments

Comments
 (0)