Skip to content

Commit 707a181

Browse files
committed
some refactors
1 parent eb2a86a commit 707a181

4 files changed

Lines changed: 80 additions & 50 deletions

File tree

TypedSyntax/src/node.jl

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,9 @@ const no_default_value = NoDefaultValue()
2424
# These are TypedSyntaxNode constructor helpers
2525
# Call these directly if you want both the TypedSyntaxNode and the `mappings` list,
2626
# where `mappings[i]` corresponds to the list of nodes matching `(src::CodeInfo).code[i]`.
27-
function tsn_and_mappings(@nospecialize(f), @nospecialize(t); kwargs...)
28-
m = which(f, t)
29-
src, rt, mi = getsrc(f, t)
30-
tsn_and_mappings(mi, src, rt; kwargs...)
27+
function tsn_and_mappings(@nospecialize(f), @nospecialize(tt=Base.default_tt(f)); kwargs...)
28+
inferred_result = get_inferred_result(f, tt)
29+
return tsn_and_mappings(inferred_result.mi, inferred_result.src, inferred_result.rt; kwargs...)
3130
end
3231

3332
function tsn_and_mappings(mi::MethodInstance, src::CodeInfo, @nospecialize(rt); warn::Bool=true, strip_macros::Bool=false, kwargs...)
@@ -58,13 +57,18 @@ function tsn_and_mappings(mi::MethodInstance, src::CodeInfo, @nospecialize(rt),
5857
return node, mappings
5958
end
6059

61-
TypedSyntaxNode(@nospecialize(f), @nospecialize(t); kwargs...) = tsn_and_mappings(f, t; kwargs...)[1]
60+
TypedSyntaxNode(@nospecialize(f), @nospecialize(tt=Base.default_tt(f)); kwargs...) = tsn_and_mappings(f, tt; kwargs...)[1]
6261

6362
function TypedSyntaxNode(mi::MethodInstance; kwargs...)
64-
src, rt = getsrc(mi)
63+
src, rt = code_typed1_tsn(mi)
6564
tsn_and_mappings(mi, src, rt; kwargs...)[1]
6665
end
6766

67+
function TypedSyntaxNode(rootnode::SyntaxNode, @nospecialize(f), @nospecialize(tt=Base.default_tt(f)); kwargs...)
68+
inferred_result = get_inferred_result(f, tt)
69+
TypedSyntaxNode(rootnode, inferred_result.src, inferred_result.mi; kwargs...)
70+
end
71+
6872
TypedSyntaxNode(rootnode::SyntaxNode, src::CodeInfo, mi::MethodInstance, Δline::Integer=0) =
6973
TypedSyntaxNode(rootnode, src, map_ssas_to_source(src, mi, rootnode, Δline)...)
7074

@@ -305,24 +309,57 @@ function sparam_name(mi::MethodInstance, i::Int)
305309
return sig.var.name
306310
end
307311

308-
function getsrc(@nospecialize(f), @nospecialize(t))
309-
srcrts = code_typed(f, t; debuginfo=:source, optimize=false)
310-
src, rt = only(srcrts)
311-
if src.parent !== nothing
312-
mi = src.parent
313-
else
314-
mi = Base.method_instance(f, t)
315-
end
316-
return src, rt, mi
317-
end
318-
319-
function getsrc(mi::MethodInstance)
320-
cis = Base.code_typed_by_type(mi.specTypes; debuginfo=:source, optimize=false)
321-
isempty(cis) && error("no applicable type-inferred code found for ", mi)
322-
length(cis) == 1 || error("got $(length(cis)) possible type-inferred results for ", mi,
323-
", you may need a more specialized signature")
324-
src::CodeInfo, rt = cis[1]
325-
return src, rt
312+
@static if isdefined(Base, :method_instances)
313+
using Base: method_instances
314+
else
315+
function method_instances(@nospecialize(f), @nospecialize(t), world::UInt)
316+
tt = Base.signature_type(f, t)
317+
results = Core.MethodInstance[]
318+
# this make a better error message than the typeassert that follows
319+
world == typemax(UInt) && error("code reflection cannot be used from generated functions")
320+
for match in Base._methods_by_ftype(tt, -1, world)::Vector
321+
instance = Core.Compiler.specialize_method(match)
322+
push!(results, instance)
323+
end
324+
return results
325+
end
326+
end
327+
328+
struct InferredResult
329+
mi::MethodInstance
330+
src::CodeInfo
331+
rt
332+
InferredResult(mi::MethodInstance, src::CodeInfo, @nospecialize(rt)) = new(mi, src, rt)
333+
end
334+
function get_inferred_result(@nospecialize(f), @nospecialize(tt=Base.default_tt(f)),
335+
world::UInt=Base.get_world_counter())
336+
mis = method_instances(f, tt, world)
337+
if isempty(mis)
338+
sig = sprint(Base.show_tuple_as_call, Symbol(""), Base.signature_type(f, tt))
339+
error("no applicable type-inferred code found for ", sig)
340+
elseif length(mis) 1
341+
sig = sprint(Base.show_tuple_as_call, Symbol(""), Base.signature_type(f, tt))
342+
error("got $(length(mis)) possible type-inferred results for ", sig,
343+
", you may need a more specialized signature")
344+
end
345+
mi = only(mis)
346+
return InferredResult(mi, code_typed1_tsn(mi)...)
347+
end
348+
349+
code_typed1_tsn(mi::MethodInstance) = code_typed1_by_method_instance(mi; optimize=false, debuginfo=:source)
350+
351+
function code_typed1_by_method_instance(mi::MethodInstance;
352+
optimize::Bool=true,
353+
debuginfo::Symbol=:default,
354+
world::UInt=Base.get_world_counter(),
355+
interp::Core.Compiler.AbstractInterpreter=Core.Compiler.NativeInterpreter(world))
356+
(ccall(:jl_is_in_pure_context, Bool, ()) || world == typemax(UInt)) &&
357+
error("code reflection should not be used from generated functions")
358+
debuginfo = Base.IRShow.debuginfo(debuginfo)
359+
code, rt = Core.Compiler.typeinf_code(interp, mi.def::Method, mi.specTypes, mi.sparam_vals, optimize)
360+
code isa CodeInfo || error("no code is available for ", mi)
361+
debuginfo === :none && Base.remove_linenums!(code)
362+
return Pair{CodeInfo,Any}(code, rt)
326363
end
327364

328365
function is_function_def(node) # this is not `Base.is_function_def`
@@ -435,7 +472,7 @@ function map_ssas_to_source(src::CodeInfo, mi::MethodInstance, rootnode::SyntaxN
435472
# (Essentially `copy!(mapped, filter(predicate, targets))`)
436473
function append_targets_for_line!(mapped#=::Vector{nodes}=#, i::Int, targets#=::Vector{nodes}=#)
437474
j = src.codelocs[i]
438-
lt = src.linetable::Vector{Any}
475+
lt = src.linetable::Vector
439476
start = getline(lt, j) + Δline
440477
stop = getnextline(lt, j, Δline) - 1
441478
linerange = start : stop

TypedSyntax/test/runtests.jl

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using JuliaSyntax: JuliaSyntax, SyntaxNode, children, child, sourcetext, kind, @K_str
2-
using TypedSyntax: TypedSyntax, TypedSyntaxNode, getsrc
2+
using TypedSyntax: TypedSyntax, TypedSyntaxNode
33
using Dates, InteractiveUtils, Test
44

55
has_name_typ(node, name::Symbol, @nospecialize(T)) = kind(node) == K"Identifier" && node.val === name && node.typ === T
@@ -15,8 +15,7 @@ include("test_module.jl")
1515
"""
1616
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN1.jl")
1717
TSN.eval(Expr(rootnode))
18-
src, _, mi = getsrc(TSN.f, (Float32, Int, Float64))
19-
tsn = TypedSyntaxNode(rootnode, src, mi)
18+
tsn = TypedSyntaxNode(rootnode, TSN.f, (Float32, Int, Float64))
2019
sig, body = children(tsn)
2120
@test children(sig)[2].typ === Float32
2221
@test children(sig)[3].typ === Int
@@ -33,8 +32,7 @@ include("test_module.jl")
3332
"""
3433
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN2.jl")
3534
TSN.eval(Expr(rootnode))
36-
src, _, mi = getsrc(TSN.g, (Int16, Int16, Int32))
37-
tsn = TypedSyntaxNode(rootnode, src, mi)
35+
tsn = TypedSyntaxNode(rootnode, TSN.g, (Int16, Int16, Int32))
3836
sig, body = children(tsn)
3937
@test length(children(sig)) == 4
4038
@test children(body)[2].typ === Int32
@@ -46,8 +44,7 @@ include("test_module.jl")
4644
st = "math(x) = x + sin(x + π / 4)"
4745
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN2.jl")
4846
TSN.eval(Expr(rootnode))
49-
src, _, mi = getsrc(TSN.math, (Int,))
50-
tsn = TypedSyntaxNode(rootnode, src, mi)
47+
tsn = TypedSyntaxNode(rootnode, TSN.math, (Int,))
5148
sig, body = children(tsn)
5249
@test has_name_typ(child(body, 1), :x, Int)
5350
@test has_name_typ(child(body, 3, 2, 1), :x, Int)
@@ -70,8 +67,7 @@ include("test_module.jl")
7067
st = "math2(x) = sin(x) + sin(x)"
7168
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN2.jl")
7269
TSN.eval(Expr(rootnode))
73-
src, _, mi = getsrc(TSN.math2, (Int,))
74-
tsn = TypedSyntaxNode(rootnode, src, mi)
70+
tsn = TypedSyntaxNode(rootnode, TSN.math2, (Int,))
7571
sig, body = children(tsn)
7672
@test body.typ === Float64
7773
@test_broken child(body, 1).typ === Float64
@@ -91,8 +87,7 @@ include("test_module.jl")
9187
)
9288
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN3.jl")
9389
TSN.eval(Expr(rootnode))
94-
src, _, mi = getsrc(TSN.firstfirst, (Vector{Vector{Real}},))
95-
tsn = TypedSyntaxNode(rootnode, src, mi)
90+
tsn = TypedSyntaxNode(rootnode, TSN.firstfirst, (Vector{Vector{Real}},))
9691
sig, body = children(tsn)
9792
@test child(body, idxsinner...).typ === nothing
9893
@test child(body, idxsouter...).typ === Vector{Real}
@@ -150,8 +145,7 @@ include("test_module.jl")
150145
"""
151146
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN4.jl")
152147
TSN.eval(Expr(rootnode))
153-
src, rt, mi = getsrc(TSN.setlist!, (Vector{Vector{Float32}}, Vector{Vector{UInt8}}, Int, Int))
154-
tsn = TypedSyntaxNode(rootnode, src, mi)
148+
tsn = TypedSyntaxNode(rootnode, TSN.setlist!, (Vector{Vector{Float32}}, Vector{Vector{UInt8}}, Int, Int))
155149
sig, body = children(tsn)
156150
nodelist = child(body, 1, 2, 1, 1) # `listget`
157151
@test sourcetext(nodelist) == "listget" && nodelist.typ === Vector{Vector{UInt8}}
@@ -175,8 +169,7 @@ include("test_module.jl")
175169
"""
176170
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN5.jl")
177171
TSN.eval(Expr(rootnode))
178-
src, rt, mi = getsrc(TSN.callfindmin, (Vector{Float64},))
179-
tsn = TypedSyntaxNode(rootnode, src, mi)
172+
tsn = TypedSyntaxNode(rootnode, TSN.callfindmin, (Vector{Float64},))
180173
sig, body = children(tsn)
181174
t = child(body, 1, 1)
182175
@test kind(t) == K"tuple"
@@ -280,7 +273,8 @@ include("test_module.jl")
280273
"""
281274
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN6.jl")
282275
TSN.eval(Expr(rootnode))
283-
src, rt, mi = getsrc(TSN.avoidzero, (Int,))
276+
inferred_result = TypedSyntax.get_inferred_result(TSN.avoidzero, (Int,))
277+
src, rt, mi = inferred_result.src, inferred_result.rt, inferred_result.mi
284278
# src looks like this:
285279
# %1 = Main.TSN.:(var"#avoidzero#6")(true, #self#, x)::Float64
286280
# return %1
@@ -290,8 +284,7 @@ include("test_module.jl")
290284
@test rt === Float64
291285
# Try the kwbodyfunc
292286
m = which(TSN.avoidzero, (Int,))
293-
src, rt, mi = getsrc(Base.bodyfunction(m), (Bool, typeof(TSN.avoidzero), Int,))
294-
tsn = TypedSyntaxNode(rootnode, src, mi)
287+
tsn = TypedSyntaxNode(rootnode, Base.bodyfunction(m), (Bool, typeof(TSN.avoidzero), Int,))
295288
sig, body = children(tsn)
296289
isz = child(body, 2, 1, 1)
297290
@test kind(isz) == K"call" && child(isz, 1).val == :iszero
@@ -520,8 +513,7 @@ include("test_module.jl")
520513
@test_broken body.typ == Int
521514

522515
# Construction from MethodInstance
523-
src, rt, mi = TypedSyntax.getsrc(TSN.myoftype, (Float64, Int))
524-
tsn = TypedSyntaxNode(mi)
516+
tsn = TypedSyntaxNode(TSN.myoftype, (Float64, Int))
525517
sig, body = children(tsn)
526518
node = child(body, 1)
527519
@test node.typ === Type{Float64}
@@ -641,7 +633,8 @@ include("test_module.jl")
641633
@test isa(tsnc, TypedSyntaxNode)
642634

643635
# issue 487
644-
src, rt, mi = getsrc(TSN.f487, (Int,))
636+
inferred_result = TypedSyntax.get_inferred_result(TSN.f487, (Int,))
637+
src, mi = inferred_result.src, inferred_result.mi
645638
rt = Core.Const(1)
646639
tsn, _ = TypedSyntax.tsn_and_mappings(mi, src, rt)
647640
@test_nowarn str = sprint(tsn; context=:color=>false) do io, obj

src/reflection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ function get_typed_sourcetext(mi::MethodInstance, src::CodeInfo, @nospecialize(r
355355
end
356356

357357
function get_typed_sourcetext(mi::MethodInstance, ::IRCode, @nospecialize(rt); kwargs...)
358-
src, rt = TypedSyntax.getsrc(mi)
358+
src, rt = TypedSyntax.code_typed1_tsn(mi)
359359
return get_typed_sourcetext(mi, src, rt; kwargs...)
360360
end
361361

test/test_Cthulhu.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ end
5252
# Callsite handling in source-view mode: for kwarg functions, strip the body, and use "typed" callsites
5353
for m in (@which(anykwargs("animals")), @which(anykwargs("animals"; cat=1, dog=2)))
5454
mi = first_specialization(m)
55-
src, rt = Cthulhu.TypedSyntax.getsrc(mi)
55+
src, rt = Cthulhu.TypedSyntax.code_typed1_tsn(mi)
5656
tsn, mappings = Cthulhu.get_typed_sourcetext(mi, src, rt; warn=false)
5757
str = sprint(printstyled, tsn)
5858
@test occursin("anykwargs", str) && occursin("kwargs...", str) && !occursin("println", str)
@@ -61,15 +61,15 @@ end
6161
# Likewise for methods that fill in default positional arguments
6262
m = @which hasdefaultargs(1)
6363
mi = first_specialization(m)
64-
src, rt = Cthulhu.TypedSyntax.getsrc(mi)
64+
src, rt = Cthulhu.TypedSyntax.code_typed1_tsn(mi)
6565
tsn, mappings = Cthulhu.get_typed_sourcetext(mi, src, rt; warn=false)
6666
str = sprint(printstyled, tsn)
6767
@test occursin("hasdefaultargs(a, b=2)", str)
6868
@test !occursin("a + b", str)
6969
@test isempty(mappings)
7070
m = @which hasdefaultargs(1, 5)
7171
mi = first_specialization(m)
72-
src, rt = Cthulhu.TypedSyntax.getsrc(mi)
72+
src, rt = Cthulhu.TypedSyntax.code_typed1_tsn(mi)
7373
tsn, mappings = Cthulhu.get_typed_sourcetext(mi, src, rt; warn=false)
7474
str = sprint(printstyled, tsn)
7575
@test occursin("hasdefaultargs(a, b=2)", str)

0 commit comments

Comments
 (0)