1+ /-
2+ Copyright (c) 2022 Microsoft Corporation. All rights reserved.
3+ Released under Apache 2.0 license as described in the file LICENSE.
4+ Authors: Leonardo de Moura
5+ -/
6+ import Lean.Compiler.Specialize
7+ import Lean.Compiler.LCNF.Simp
8+ import Lean.Compiler.LCNF.SpecInfo
9+
10+ namespace Lean.Compiler.LCNF
11+ namespace Specialize
12+
13+ structure Context where
14+ scope : FVarIdSet := {}
15+
16+ structure State where
17+ decls : Array Decl := #[]
18+
19+ abbrev SpecializeM := ReaderT Context $ StateRefT State CompilerM
20+
21+ @[inline] def withParams (ps : Array Param) (x : SpecializeM α) : SpecializeM α :=
22+ withReader (fun ctx => { ctx with scope := ps.foldl (init := ctx.scope) fun s p => s.insert p.fvarId }) x
23+
24+ @[inline] def withFVar (fvarId : FVarId) (x : SpecializeM α) : SpecializeM α :=
25+ withReader (fun ctx => { ctx with scope := ctx.scope.insert fvarId }) x
26+
27+ def specializeApp? (letDecl : LetDecl) (_k : Code) : SpecializeM (Option Code) := do
28+ unless letDecl.value.isApp do return none
29+ let .const declName _us := letDecl.value.getAppFn | return none
30+ let some paramsInfo ← getSpecParamInfo? declName | return none
31+ let some _decl ← getStage1Decl? declName | return none
32+ trace[Compiler.specialize.candidate] "{ letDecl.value} , { paramsInfo} "
33+ -- TODO
34+ return none
35+
36+ mutual
37+ partial def visitFunDecl (funDecl : FunDecl) : SpecializeM FunDecl := do
38+ let value ← withParams funDecl.params <| visitCode funDecl.value
39+ funDecl.update' funDecl.type value
40+
41+ partial def visitCode (code : Code) : SpecializeM Code := do
42+ match code with
43+ | .let decl k =>
44+ if let some k ← specializeApp? decl k then
45+ visitCode k
46+ else
47+ let k ← withFVar decl.fvarId <| visitCode k
48+ return code.updateLet! decl k
49+ | .fun decl k | .jp decl k =>
50+ let decl ← visitFunDecl decl
51+ let k ← withFVar decl.fvarId <| visitCode k
52+ return code.updateFun! decl k
53+ | .cases c =>
54+ let alts ← c.alts.mapMonoM fun alt =>
55+ match alt with
56+ | .default k => return alt.updateCode (← visitCode k)
57+ | .alt _ ps k => withParams ps do return alt.updateCode (← visitCode k)
58+ return code.updateAlts! alts
59+ | .unreach .. | .jmp .. | .return .. => return code
60+
61+ end
62+
63+ def main (decl : Decl) : SpecializeM Decl := do
64+ if (← isTemplateLike decl) then
65+ return decl
66+ else
67+ let value ← withParams decl.params <| visitCode decl.value
68+ return { decl with value }
69+
70+ end Specialize
71+
72+ partial def Decl.specialize (decl : Decl) : CompilerM (Array Decl) := do
73+ let (decl, s) ← Specialize.main decl |>.run {} |>.run {}
74+ return s.decls.push decl
75+
76+ def specialize : Pass where
77+ phase := .base
78+ name := `specialize
79+ run := fun decls => do
80+ saveSpecParamInfo decls
81+ decls.foldlM (init := #[]) fun decls decl => return decls ++ (← decl.specialize)
82+
83+ builtin_initialize
84+ registerTraceClass `Compiler.specialize (inherited := true )
85+ registerTraceClass `Compiler.specialize.candidate
86+
87+ end Lean.Compiler.LCNF
0 commit comments