Skip to content

Commit ec2372e

Browse files
committed
feat: add Specialize.lean skeleton
1 parent 44c67f7 commit ec2372e

2 files changed

Lines changed: 90 additions & 1 deletion

File tree

src/Lean/Compiler/LCNF/Passes.lean

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import Lean.Compiler.LCNF.Simp
1010
import Lean.Compiler.LCNF.PullFunDecls
1111
import Lean.Compiler.LCNF.ReduceJpArity
1212
import Lean.Compiler.LCNF.JoinPoints
13+
import Lean.Compiler.LCNF.Specialize
1314

1415
namespace Lean.Compiler.LCNF
1516

@@ -21,7 +22,8 @@ namespace Lean.Compiler.LCNF
2122
pullFunDecls,
2223
findJoinPoints,
2324
reduceJpArity,
24-
simp { etaPoly := true, inlinePartial := true, implementedBy := true } (occurence := 1)
25+
simp { etaPoly := true, inlinePartial := true, implementedBy := true } (occurence := 1),
26+
specialize
2527
]
2628

2729
end Lean.Compiler.LCNF
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/-
2+
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Leonardo de Moura
5+
-/
6+
import Lean.Compiler.Specialize
7+
import Lean.Compiler.LCNF.Simp
8+
import Lean.Compiler.LCNF.SpecInfo
9+
10+
namespace Lean.Compiler.LCNF
11+
namespace Specialize
12+
13+
structure Context where
14+
scope : FVarIdSet := {}
15+
16+
structure State where
17+
decls : Array Decl := #[]
18+
19+
abbrev SpecializeM := ReaderT Context $ StateRefT State CompilerM
20+
21+
@[inline] def withParams (ps : Array Param) (x : SpecializeM α) : SpecializeM α :=
22+
withReader (fun ctx => { ctx with scope := ps.foldl (init := ctx.scope) fun s p => s.insert p.fvarId }) x
23+
24+
@[inline] def withFVar (fvarId : FVarId) (x : SpecializeM α) : SpecializeM α :=
25+
withReader (fun ctx => { ctx with scope := ctx.scope.insert fvarId }) x
26+
27+
def specializeApp? (letDecl : LetDecl) (_k : Code) : SpecializeM (Option Code) := do
28+
unless letDecl.value.isApp do return none
29+
let .const declName _us := letDecl.value.getAppFn | return none
30+
let some paramsInfo ← getSpecParamInfo? declName | return none
31+
let some _decl ← getStage1Decl? declName | return none
32+
trace[Compiler.specialize.candidate] "{letDecl.value}, {paramsInfo}"
33+
-- TODO
34+
return none
35+
36+
mutual
37+
partial def visitFunDecl (funDecl : FunDecl) : SpecializeM FunDecl := do
38+
let value ← withParams funDecl.params <| visitCode funDecl.value
39+
funDecl.update' funDecl.type value
40+
41+
partial def visitCode (code : Code) : SpecializeM Code := do
42+
match code with
43+
| .let decl k =>
44+
if let some k ← specializeApp? decl k then
45+
visitCode k
46+
else
47+
let k ← withFVar decl.fvarId <| visitCode k
48+
return code.updateLet! decl k
49+
| .fun decl k | .jp decl k =>
50+
let decl ← visitFunDecl decl
51+
let k ← withFVar decl.fvarId <| visitCode k
52+
return code.updateFun! decl k
53+
| .cases c =>
54+
let alts ← c.alts.mapMonoM fun alt =>
55+
match alt with
56+
| .default k => return alt.updateCode (← visitCode k)
57+
| .alt _ ps k => withParams ps do return alt.updateCode (← visitCode k)
58+
return code.updateAlts! alts
59+
| .unreach .. | .jmp .. | .return .. => return code
60+
61+
end
62+
63+
def main (decl : Decl) : SpecializeM Decl := do
64+
if (← isTemplateLike decl) then
65+
return decl
66+
else
67+
let value ← withParams decl.params <| visitCode decl.value
68+
return { decl with value }
69+
70+
end Specialize
71+
72+
partial def Decl.specialize (decl : Decl) : CompilerM (Array Decl) := do
73+
let (decl, s) ← Specialize.main decl |>.run {} |>.run {}
74+
return s.decls.push decl
75+
76+
def specialize : Pass where
77+
phase := .base
78+
name := `specialize
79+
run := fun decls => do
80+
saveSpecParamInfo decls
81+
decls.foldlM (init := #[]) fun decls decl => return decls ++ (← decl.specialize)
82+
83+
builtin_initialize
84+
registerTraceClass `Compiler.specialize (inherited := true)
85+
registerTraceClass `Compiler.specialize.candidate
86+
87+
end Lean.Compiler.LCNF

0 commit comments

Comments
 (0)