@@ -500,6 +500,19 @@ private partial def AsyncConsts.findRecTask (aconsts : AsyncConsts) (declName :
500500 let some aconsts := aconsts.get? AsyncConsts | .pure none
501501 AsyncConsts.findRecTask aconsts declName
502502
503+ /-- Like `findRec?` but returns the constant that has `declName` in its `consts`, if any. -/
504+ private partial def AsyncConsts.findRecParent? (aconsts : AsyncConsts) (declName : Name) : Option AsyncConst :=
505+ go none aconsts
506+ where go parent? aconsts := do
507+ let c ← aconsts.findPrefix? declName
508+ if c.constInfo.name == declName then
509+ return (← parent?)
510+ -- If privacy is the only difference between `declName` and `findPrefix?` result, we can assume
511+ -- `declName` does not exist according to the `add` invariant
512+ guard <| privateToUserName c.constInfo.name != privateToUserName declName
513+ let aconsts ← c.consts.get.get? AsyncConsts
514+ go (some c) aconsts
515+
503516/-- Accessibility levels of declarations in `Lean.Environment`. -/
504517private inductive Visibility where
505518 /-- Information private to the module. -/
@@ -1190,6 +1203,11 @@ def instantiateValueLevelParams! (c : ConstantInfo) (ls : List Level) : Expr :=
11901203
11911204end ConstantInfo
11921205
1206+ inductive AsyncBranch where
1207+ | mainEnv
1208+ | asyncEnv
1209+ deriving BEq
1210+
11931211/--
11941212Async access mode for environment extensions used in `EnvExtension.get/set/modifyState`.
11951213When modified in concurrent contexts, extensions may need to switch to a different mode than the
@@ -1250,7 +1268,7 @@ inductive EnvExtension.AsyncMode where
12501268 own constant map works which asserts the same predicate on modification and provides `findAsync?`
12511269 for block-avoiding access.
12521270 -/
1253- | async
1271+ | async (branch : AsyncBranch)
12541272 deriving Inhabited
12551273
12561274abbrev ReplayFn (σ : Type ) :=
@@ -1364,67 +1382,50 @@ def setState {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) :
13641382
13651383-- `unsafe` fails to infer `Nonempty` here
13661384private unsafe def getStateUnsafe {σ : Type } [Inhabited σ] (ext : EnvExtension σ)
1367- (env : Environment) (asyncMode := ext.asyncMode) : σ :=
1385+ (env : Environment) (asyncMode := ext.asyncMode) (asyncDecl : Name := .anonymous) : σ := Id.run do
13681386 -- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
13691387 match asyncMode with
1370- | .sync => ext.getStateImpl env.checked.get.extensions
1371- | .async => panic! "called on `async` extension, use `findStateAsync` \
1372- instead or pass `(asyncMode := .local)` to explicitly access local state"
1388+ | .sync => ext.getStateImpl env.checked.get.extensions
1389+ | .async branch =>
1390+ if asyncDecl.isAnonymous then
1391+ panic! "called on `async` extension, must set `asyncDecl` \
1392+ or pass `(asyncMode := .local)` to explicitly access local state"
1393+ -- analogous structure to `findAsync?`; see there
1394+ -- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
1395+ if env.base.get env |>.constants.contains asyncDecl then
1396+ return ext.getStateImpl env.base.private.extensions
1397+ if let some c := match branch with
1398+ | .asyncEnv => env.asyncConsts.findRec? asyncDecl
1399+ | .mainEnv => env.asyncConsts.findRecParent? asyncDecl then
1400+ if let some exts := c.exts? then
1401+ return ext.getStateImpl exts.get
1402+ -- NOTE: if `exts?` is `none`, we should *not* try the following, more expensive branches that
1403+ -- will just come to the same conclusion
1404+ else if let some c := env.allRealizations.get.find? asyncDecl then
1405+ if let some exts := c.exts? then
1406+ return ext.getStateImpl exts.get
1407+ -- fallback; we could enforce that `findStateAsync` is only used on existing constants but the
1408+ -- upside of doing is unclear
1409+ ext.getStateImpl env.base.private.extensions
13731410 | _ => ext.getStateImpl env.base.private.extensions
13741411
13751412/--
13761413Returns the current extension state. See `AsyncMode` for details on how modifications from
13771414different environment branches are reconciled. Panics if the extension is marked as `async`; see its
13781415documentation for more details. Overriding the extension's default `AsyncMode` is usually not
13791416recommended and should be considered only for important optimizations.
1380- -/
1381- @ [implemented_by getStateUnsafe]
1382- opaque getState {σ : Type } [Inhabited σ] (ext : EnvExtension σ) (env : Environment)
1383- (asyncMode := ext.asyncMode) : σ
1384-
1385- -- `unsafe` fails to infer `Nonempty` here
1386- private unsafe def findStateAsyncUnsafe {σ : Type } [Inhabited σ]
1387- (ext : EnvExtension σ) (env : Environment) (declName : Name) : σ := Id.run do
1388- -- analogous structure to `findAsync?`; see there
1389- -- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
1390- if env.base.get env |>.constants.contains declName then
1391- return ext.getStateImpl env.base.private.extensions
1392- if let some c := env.asyncConsts.find? declName then
1393- if let some exts := c.exts? then
1394- return ext.getStateImpl exts.get
1395- -- NOTE: if `exts?` is `none`, we should *not* try the following, more expensive branches that
1396- -- will just come to the same conclusion
1397- else if let some exts := findRecExts? none env.asyncConsts declName then
1398- return ext.getStateImpl exts.get
1399- else if let some c := env.allRealizations.get.find? declName then
1400- if let some exts := c.exts? then
1401- return ext.getStateImpl exts.get
1402- -- fallback; we could enforce that `findStateAsync` is only used on existing constants but the
1403- -- upside of doing is unclear
1404- ext.getStateImpl env.base.private.extensions
1405- where
1406- /--
1407- Like `AsyncConsts.findRec?`, but if `AsyncConst.exts?` is `none`, returns the extension state of
1408- the surrounding `AsyncConst` instead, which is where state for synchronously added constants is
1409- stored.
1410- -/
1411- findRecExts? (parent? : Option AsyncConst) (aconsts : AsyncConsts) (declName : Name) :
1412- Option (Task (Array EnvExtensionState)) := do
1413- let c ← aconsts.findPrefix? declName
1414- if c.constInfo.name == declName then
1415- return (← c.exts?.or (parent?.bind (·.exts?)))
1416- let aconsts ← c.consts.get.get? AsyncConsts
1417- findRecExts? c aconsts declName
14181417
1418+ Returns the extension state on the environment branch corresponding to the passed declaration name,
1419+ if any, or otherwise the state on the current branch. In other words, at most one environment branch
1420+ will be blocked on.
14191421
1420- /--
1421- Returns the final extension state on the environment branch corresponding to the passed declaration
1422- name, if any, or otherwise the state on the current branch. In other words, at most one environment
1423- branch will be blocked on.
1422+ More specifically, if `afterAsync` is set to `true`, the retrieved state will be the one at the time
1423+ `AddConstAsyncResult.commitConst` was called; if it is `false`, it will instead be taken fro the
1424+ branch that that called `addConstAsync` in the first place.
14241425-/
1425- @ [implemented_by findStateAsyncUnsafe ]
1426- opaque findStateAsync {σ : Type } [Inhabited σ] (ext : EnvExtension σ)
1427- (env : Environment ) (declName : Name) : σ
1426+ @ [implemented_by getStateUnsafe ]
1427+ opaque getState {σ : Type } [Inhabited σ] (ext : EnvExtension σ) (env : Environment )
1428+ (asyncMode := ext.asyncMode ) (asyncDecl : Name := .anonymous ) : σ
14281429
14291430end EnvExtension
14301431
@@ -1593,8 +1594,8 @@ def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : En
15931594
15941595/-- Get the current state of the given extension in the given environment. -/
15951596def getState {α β σ : Type } [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment)
1596- (asyncMode := ext.toEnvExtension.asyncMode) : σ :=
1597- (ext.toEnvExtension.getState (asyncMode := asyncMode) env).state
1597+ (asyncMode := ext.toEnvExtension.asyncMode) (asyncDecl : Name := .anonymous) : σ :=
1598+ (ext.toEnvExtension.getState (asyncMode := asyncMode) (asyncDecl := asyncDecl) env).state
15981599
15991600/-- Set the current state of the given extension in the given environment. -/
16001601def setState {α β σ : Type } (ext : PersistentEnvExtension α β σ) (env : Environment) (s : σ) : Environment :=
@@ -1605,12 +1606,6 @@ def modifyState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env :
16051606 (asyncMode := ext.toEnvExtension.asyncMode) : Environment :=
16061607 ext.toEnvExtension.modifyState (asyncMode := asyncMode) env fun ps => { ps with state := f (ps.state) }
16071608
1608- @ [inherit_doc EnvExtension.findStateAsync]
1609- def findStateAsync {α β σ : Type } [Inhabited σ] (ext : PersistentEnvExtension α β σ)
1610- (env : Environment) (declPrefix : Name) : σ :=
1611- ext.toEnvExtension.findStateAsync env declPrefix |>.state
1612-
1613-
16141609end PersistentEnvExtension
16151610
16161611builtin_initialize persistentEnvExtensionsRef : IO.Ref (Array (PersistentEnvExtension EnvExtensionEntry EnvExtensionEntry EnvExtensionState)) ← IO.mkRef #[]
@@ -1736,7 +1731,7 @@ def mkModuleData (env : Environment) (level : OLeanLevel := .private) : IO Modul
17361731 let entries := pExts.map fun pExt => Id.run do
17371732 -- get state from `checked` at the end if `async`; it would otherwise panic
17381733 let mut asyncMode := pExt.toEnvExtension.asyncMode
1739- if asyncMode matches .async then
1734+ if asyncMode matches .async _ then
17401735 asyncMode := .sync
17411736 let state := pExt.getState (asyncMode := asyncMode) env
17421737 (pExt.name, pExt.exportEntriesFn env state level)
@@ -2266,7 +2261,7 @@ def displayStats (env : Environment) : IO Unit := do
22662261 IO.println ("extension '" ++ toString extDescr.name ++ "'" )
22672262 -- get state from `checked` at the end if `async`; it would otherwise panic
22682263 let mut asyncMode := extDescr.toEnvExtension.asyncMode
2269- if asyncMode matches .async then
2264+ if asyncMode matches .async _ then
22702265 asyncMode := .sync
22712266 let s := extDescr.toEnvExtension.getState (asyncMode := asyncMode) env
22722267 let fmt := extDescr.statsFn s.state
0 commit comments