@@ -26,7 +26,7 @@ register_builtin_option synthInstance.maxSize : Nat := {
2626}
2727
2828register_builtin_option synthInstance.etaExperiment : Bool := {
29- defValue := false
29+ defValue := true
3030 descr := "[DO NOT USE EXCEPT FOR TESTING] enable structure eta for type-classes during type-class search"
3131}
3232
@@ -48,11 +48,15 @@ structure GeneratorNode where
4848 currInstanceIdx : Nat
4949 deriving Inhabited
5050
51+ inductive Subgoal where
52+ | tc (cls : Expr)
53+ | defEq (a b : Expr)
54+
5155structure ConsumerNode where
5256 mvar : Expr
5357 key : Expr
5458 mctx : MetavarContext
55- subgoals : List Expr
59+ subgoals : List Subgoal
5660 size : Nat -- instance size so far
5761 deriving Inhabited
5862
@@ -328,10 +332,40 @@ def preprocessOutParam (type : Expr) : MetaM Expr :=
328332 Given `getSubgoalsAux args j subgoals instVal type`,
329333 we have that `type.instantiateRevRange j args.size args` does not have loose bound variables. -/
330334structure SubgoalsResult where
331- subgoals : List Expr
335+ subgoals : Array Subgoal
332336 instVal : Expr
333337 instTypeBody : Expr
334338
339+ partial def generalizeTypeClassInstances (e : Expr) (outerLCtx : LocalContext) (outerLocalInsts : LocalInstances) :
340+ MetaM (Expr × Array Subgoal) := do
341+ let subgoals ← IO.mkRef #[]
342+ let rec
343+ go (e : Expr) : MonadCacheT ExprStructEq Expr MetaM Expr :=
344+ checkCache (ExprStructEq.mk e) fun _ => withIncRecDepth do
345+ let e := e.headBeta -- (fun x => t) 1
346+ if let some e := e.etaExpandedStrict? then go e else
347+ if (← isClass? (← inferType e)).isSome then
348+ if !e.getAppFn.isMVar && e.hasExprMVar then
349+ return ← mkMVarForExpr e
350+ match e with
351+ | .lam .. | .letE .. => lambdaLetTelescope e fun xs e' => do mkLambdaFVars xs (← go e')
352+ | .forallE .. => forallTelescope e fun xs e' => do mkForallFVars xs (← go e')
353+ | .mdata d e => return .mdata d (← go e)
354+ | _ => e.withApp fun f as => do
355+ return mkAppN f (← as.mapM go),
356+ mkMVarForExpr (e : Expr) := do
357+ let (_, usedFVars) ← e.collectFVars.run {}
358+ let usedFVars ← usedFVars.addDependencies
359+ let mut args := #[]
360+ for ldecl in ← getLCtx do
361+ if !outerLCtx.contains ldecl.fvarId && usedFVars.fvarSet.contains ldecl.fvarId then
362+ args := args.push (.fvar ldecl.fvarId)
363+ let mvarType ← mkForallFVars args (← go (← inferType e))
364+ let mvar ← mkFreshExprMVarAt outerLCtx outerLocalInsts mvarType
365+ subgoals.modify (·.push (.defEq (← mkLambdaFVars args e) mvar))
366+ return mkAppN mvar args
367+ return (← (go e).run, ← subgoals.get)
368+
335369/--
336370 `getSubgoals lctx localInsts xs inst` creates the subgoals for the instance `inst`.
337371 The subgoals are in the context of the free variables `xs`, and
@@ -365,29 +399,29 @@ def getSubgoals (lctx : LocalContext) (localInsts : LocalInstances) (xs : Array
365399 instVal := instVal.instantiateRev subst
366400 subst := #[]
367401 unless instType.isForall do break
402+ let (instType', defEqSubgoals) ← generalizeTypeClassInstances (instType.instantiateRev subst) lctx localInsts
368403 return {
369404 instVal := instVal.instantiateRev subst
370- instTypeBody := instType.instantiateRev subst
371- subgoals := inst.synthOrder.map (mvars[·]!) |>.toList
405+ instTypeBody := instType'
406+ subgoals := inst.synthOrder.map (Subgoal.tc mvars[·]!) ++ defEqSubgoals
372407 }
373408
374409/--
375410 Try to synthesize metavariable `mvar` using the instance `inst`.
376411 Remark: `mctx` is set using `withMCtx`.
377412 If it succeeds, the result is a new updated metavariable context and a new list of subgoals.
378413 A subgoal is created for each instance implicit parameter of `inst`. -/
379- def tryResolve (mvar : Expr) (inst : Instance) : MetaM (Option (MetavarContext × List Expr )) := do
414+ def tryResolve (mvar : Expr) (inst : Instance) : MetaM (Option (MetavarContext × List Subgoal )) := do
380415 let mvarType ← inferType mvar
381416 let lctx ← getLCtx
382417 let localInsts ← getLocalInstances
383418 forallTelescopeReducing mvarType fun xs mvarTypeBody => do
384- let ⟨ subgoals, instVal, instTypeBody⟩ ← getSubgoals lctx localInsts xs inst
419+ let { subgoals, instTypeBody, instVal} ← getSubgoals lctx localInsts xs inst
385420 withTraceNode `Meta.synthInstance.tryResolve (withMCtx (← getMCtx) do
386421 return m! "{ exceptOptionEmoji ·} { ← instantiateMVars mvarTypeBody} ≟ { ← instantiateMVars instTypeBody} " ) do
387422 if (← isDefEq mvarTypeBody instTypeBody) then
388423 let instVal ← mkLambdaFVars xs instVal
389- if (← isDefEq mvar instVal) then
390- return some ((← getMCtx), subgoals)
424+ return some ((← getMCtx), (subgoals.push (.defEq mvar instVal)).toList)
391425 return none
392426
393427/--
@@ -463,11 +497,11 @@ def inprocessOutParams (mvar : Expr) : MetaM Expr := do
463497 let type ← instantiateMVars (← inferType mvar)
464498 let outerLCtx ← getLCtx
465499 let outerLocalInsts ← getLocalInstances
466- forallTelescopeReducing type fun xs typeBody => do
467- let typeBody ← whnf typeBody
500+ forallTelescopeReducing type fun xs typeBody0 => do
501+ let typeBody ← whnf typeBody0
468502 let c@(.const className _) := typeBody.getAppFn | return mvar
469- let some outParamsPos := getOutParamPositions? (← getEnv) className | return mvar
470- if outParamsPos.isEmpty then return mvar
503+ let outParamsPos := ( getOutParamPositions? (← getEnv) className).getD #[]
504+ if outParamsPos.isEmpty && typeBody == typeBody0 then return mvar
471505 let args := typeBody.getAppArgs
472506 let cType ← inferType c
473507 let args ← preprocessArgs cType 0 args outParamsPos
@@ -488,7 +522,15 @@ def removeUnusedArguments (mvar : Expr) : MetaM Expr := do
488522 return mvar'
489523
490524/-- Process the next subgoal in the given consumer node. -/
491- def consume (cNode : ConsumerNode) : SynthM Unit := do
525+ partial def consume (cNode : ConsumerNode) : SynthM Unit := do
526+ match cNode.subgoals with
527+ | [] => addAnswer cNode
528+ | .defEq a b :: subgoals =>
529+ if let (true , mctx) ← withMCtx cNode.mctx do return (← withDefault (isDefEq a b), ← getMCtx) then
530+ consume { cNode with subgoals, mctx }
531+ else
532+ pure ()
533+ | .tc mvar :: mvars =>
492534 /- Filter out subgoals that have already been assigned when solving typing constraints.
493535 This may happen when a local instance type depends on other local instances.
494536 For example, in Mathlib, we have
@@ -500,18 +542,12 @@ def consume (cNode : ConsumerNode) : SynthM Unit := do
500542 SetLike (@Submodule R M _inst_1 _inst_2 _inst_3) M
501543 ```
502544 -/
503- let cNode := { cNode with
504- subgoals := ← withMCtx cNode.mctx do
505- cNode.subgoals.filterM (not <$> ·.mvarId!.isAssigned)
506- }
507- match cNode.subgoals with
508- | [] => addAnswer cNode
509- | mvar::mvars =>
545+ if ← mvar.mvarId!.isAssigned then consume { cNode with subgoals := mvars } else
510546 let (mvar, mvar', mctx) ← withMCtx cNode.mctx do
511547 let mvar ← removeUnusedArguments mvar
512548 let mvar' ← inprocessOutParams mvar
513549 return (mvar, mvar', ← getMCtx)
514- let cNode := { cNode with subgoals := mvar:: mvars, mctx }
550+ let cNode := { cNode with subgoals := .tc mvar :: mvars, mctx }
515551 let waiter := Waiter.consumerNode cNode
516552 let key ← mkTableKeyFor cNode.mctx mvar'
517553 match (← findEntry? key) with
@@ -559,7 +595,8 @@ def resume : SynthM Unit := do
559595 let (cNode, answer) ← getNextToResume
560596 match cNode.subgoals with
561597 | [] => panic! "resume found no remaining subgoals"
562- | mvar::rest =>
598+ | .defEq .. :: _ => panic! "resume found defeq constraint"
599+ | .tc mvar :: rest =>
563600 match (← tryAnswer cNode.mctx mvar answer) with
564601 | none => return ()
565602 | some mctx =>
0 commit comments