Skip to content

Commit 182e343

Browse files
committed
feat: tc: generalize instances when applying
1 parent 7dba372 commit 182e343

1 file changed

Lines changed: 60 additions & 23 deletions

File tree

src/Lean/Meta/SynthInstance.lean

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ register_builtin_option synthInstance.maxSize : Nat := {
2626
}
2727

2828
register_builtin_option synthInstance.etaExperiment : Bool := {
29-
defValue := false
29+
defValue := true
3030
descr := "[DO NOT USE EXCEPT FOR TESTING] enable structure eta for type-classes during type-class search"
3131
}
3232

@@ -48,11 +48,15 @@ structure GeneratorNode where
4848
currInstanceIdx : Nat
4949
deriving Inhabited
5050

51+
inductive Subgoal where
52+
| tc (cls : Expr)
53+
| defEq (a b : Expr)
54+
5155
structure ConsumerNode where
5256
mvar : Expr
5357
key : Expr
5458
mctx : MetavarContext
55-
subgoals : List Expr
59+
subgoals : List Subgoal
5660
size : Nat -- instance size so far
5761
deriving Inhabited
5862

@@ -328,10 +332,40 @@ def preprocessOutParam (type : Expr) : MetaM Expr :=
328332
Given `getSubgoalsAux args j subgoals instVal type`,
329333
we have that `type.instantiateRevRange j args.size args` does not have loose bound variables. -/
330334
structure SubgoalsResult where
331-
subgoals : List Expr
335+
subgoals : Array Subgoal
332336
instVal : Expr
333337
instTypeBody : Expr
334338

339+
partial def generalizeTypeClassInstances (e : Expr) (outerLCtx : LocalContext) (outerLocalInsts : LocalInstances) :
340+
MetaM (Expr × Array Subgoal) := do
341+
let subgoals ← IO.mkRef #[]
342+
let rec
343+
go (e : Expr) : MonadCacheT ExprStructEq Expr MetaM Expr :=
344+
checkCache (ExprStructEq.mk e) fun _ => withIncRecDepth do
345+
let e := e.headBeta -- (fun x => t) 1
346+
if let some e := e.etaExpandedStrict? then go e else
347+
if (← isClass? (← inferType e)).isSome then
348+
if !e.getAppFn.isMVar && e.hasExprMVar then
349+
return ← mkMVarForExpr e
350+
match e with
351+
| .lam .. | .letE .. => lambdaLetTelescope e fun xs e' => do mkLambdaFVars xs (← go e')
352+
| .forallE .. => forallTelescope e fun xs e' => do mkForallFVars xs (← go e')
353+
| .mdata d e => return .mdata d (← go e)
354+
| _ => e.withApp fun f as => do
355+
return mkAppN f (← as.mapM go),
356+
mkMVarForExpr (e : Expr) := do
357+
let (_, usedFVars) ← e.collectFVars.run {}
358+
let usedFVars ← usedFVars.addDependencies
359+
let mut args := #[]
360+
for ldecl in ← getLCtx do
361+
if !outerLCtx.contains ldecl.fvarId && usedFVars.fvarSet.contains ldecl.fvarId then
362+
args := args.push (.fvar ldecl.fvarId)
363+
let mvarType ← mkForallFVars args (← go (← inferType e))
364+
let mvar ← mkFreshExprMVarAt outerLCtx outerLocalInsts mvarType
365+
subgoals.modify (·.push (.defEq (← mkLambdaFVars args e) mvar))
366+
return mkAppN mvar args
367+
return (← (go e).run, ← subgoals.get)
368+
335369
/--
336370
`getSubgoals lctx localInsts xs inst` creates the subgoals for the instance `inst`.
337371
The subgoals are in the context of the free variables `xs`, and
@@ -365,29 +399,29 @@ def getSubgoals (lctx : LocalContext) (localInsts : LocalInstances) (xs : Array
365399
instVal := instVal.instantiateRev subst
366400
subst := #[]
367401
unless instType.isForall do break
402+
let (instType', defEqSubgoals) ← generalizeTypeClassInstances (instType.instantiateRev subst) lctx localInsts
368403
return {
369404
instVal := instVal.instantiateRev subst
370-
instTypeBody := instType.instantiateRev subst
371-
subgoals := inst.synthOrder.map (mvars[·]!) |>.toList
405+
instTypeBody := instType'
406+
subgoals := inst.synthOrder.map (Subgoal.tc mvars[·]!) ++ defEqSubgoals
372407
}
373408

374409
/--
375410
Try to synthesize metavariable `mvar` using the instance `inst`.
376411
Remark: `mctx` is set using `withMCtx`.
377412
If it succeeds, the result is a new updated metavariable context and a new list of subgoals.
378413
A subgoal is created for each instance implicit parameter of `inst`. -/
379-
def tryResolve (mvar : Expr) (inst : Instance) : MetaM (Option (MetavarContext × List Expr)) := do
414+
def tryResolve (mvar : Expr) (inst : Instance) : MetaM (Option (MetavarContext × List Subgoal)) := do
380415
let mvarType ← inferType mvar
381416
let lctx ← getLCtx
382417
let localInsts ← getLocalInstances
383418
forallTelescopeReducing mvarType fun xs mvarTypeBody => do
384-
let subgoals, instVal, instTypeBody⟩ ← getSubgoals lctx localInsts xs inst
419+
let {subgoals, instTypeBody, instVal} ← getSubgoals lctx localInsts xs inst
385420
withTraceNode `Meta.synthInstance.tryResolve (withMCtx (← getMCtx) do
386421
return m!"{exceptOptionEmoji ·} {← instantiateMVars mvarTypeBody}{← instantiateMVars instTypeBody}") do
387422
if (← isDefEq mvarTypeBody instTypeBody) then
388423
let instVal ← mkLambdaFVars xs instVal
389-
if (← isDefEq mvar instVal) then
390-
return some ((← getMCtx), subgoals)
424+
return some ((← getMCtx), (subgoals.push (.defEq mvar instVal)).toList)
391425
return none
392426

393427
/--
@@ -463,11 +497,11 @@ def inprocessOutParams (mvar : Expr) : MetaM Expr := do
463497
let type ← instantiateMVars (← inferType mvar)
464498
let outerLCtx ← getLCtx
465499
let outerLocalInsts ← getLocalInstances
466-
forallTelescopeReducing type fun xs typeBody => do
467-
let typeBody ← whnf typeBody
500+
forallTelescopeReducing type fun xs typeBody0 => do
501+
let typeBody ← whnf typeBody0
468502
let c@(.const className _) := typeBody.getAppFn | return mvar
469-
let some outParamsPos := getOutParamPositions? (← getEnv) className | return mvar
470-
if outParamsPos.isEmpty then return mvar
503+
let outParamsPos := (getOutParamPositions? (← getEnv) className).getD #[]
504+
if outParamsPos.isEmpty && typeBody == typeBody0 then return mvar
471505
let args := typeBody.getAppArgs
472506
let cType ← inferType c
473507
let args ← preprocessArgs cType 0 args outParamsPos
@@ -488,7 +522,15 @@ def removeUnusedArguments (mvar : Expr) : MetaM Expr := do
488522
return mvar'
489523

490524
/-- Process the next subgoal in the given consumer node. -/
491-
def consume (cNode : ConsumerNode) : SynthM Unit := do
525+
partial def consume (cNode : ConsumerNode) : SynthM Unit := do
526+
match cNode.subgoals with
527+
| [] => addAnswer cNode
528+
| .defEq a b :: subgoals =>
529+
if let (true, mctx) ← withMCtx cNode.mctx do return (← withDefault (isDefEq a b), ← getMCtx) then
530+
consume { cNode with subgoals, mctx }
531+
else
532+
pure ()
533+
| .tc mvar :: mvars =>
492534
/- Filter out subgoals that have already been assigned when solving typing constraints.
493535
This may happen when a local instance type depends on other local instances.
494536
For example, in Mathlib, we have
@@ -500,18 +542,12 @@ def consume (cNode : ConsumerNode) : SynthM Unit := do
500542
SetLike (@Submodule R M _inst_1 _inst_2 _inst_3) M
501543
```
502544
-/
503-
let cNode := { cNode with
504-
subgoals := ← withMCtx cNode.mctx do
505-
cNode.subgoals.filterM (not <$> ·.mvarId!.isAssigned)
506-
}
507-
match cNode.subgoals with
508-
| [] => addAnswer cNode
509-
| mvar::mvars =>
545+
if ← mvar.mvarId!.isAssigned then consume { cNode with subgoals := mvars } else
510546
let (mvar, mvar', mctx) ← withMCtx cNode.mctx do
511547
let mvar ← removeUnusedArguments mvar
512548
let mvar' ← inprocessOutParams mvar
513549
return (mvar, mvar', ← getMCtx)
514-
let cNode := { cNode with subgoals := mvar::mvars, mctx }
550+
let cNode := { cNode with subgoals := .tc mvar :: mvars, mctx }
515551
let waiter := Waiter.consumerNode cNode
516552
let key ← mkTableKeyFor cNode.mctx mvar'
517553
match (← findEntry? key) with
@@ -559,7 +595,8 @@ def resume : SynthM Unit := do
559595
let (cNode, answer) ← getNextToResume
560596
match cNode.subgoals with
561597
| [] => panic! "resume found no remaining subgoals"
562-
| mvar::rest =>
598+
| .defEq .. :: _ => panic! "resume found defeq constraint"
599+
| .tc mvar :: rest =>
563600
match (← tryAnswer cNode.mctx mvar answer) with
564601
| none => return ()
565602
| some mctx =>

0 commit comments

Comments
 (0)