Skip to content

Commit 1f05cbc

Browse files
committed
perf: clarify and granularize access to async env ext state
1 parent 95e753c commit 1f05cbc

10 files changed

Lines changed: 76 additions & 96 deletions

File tree

src/Lean/Attributes.lean

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -206,15 +206,7 @@ def setTag [Monad m] [MonadError m] [MonadEnv m] (attr : TagAttribute) (decl :
206206
def hasTag (attr : TagAttribute) (env : Environment) (decl : Name) : Bool :=
207207
match env.getModuleIdxFor? decl with
208208
| some modIdx => (attr.ext.getModuleEntries env modIdx).binSearchContains decl Name.quickLt
209-
| none =>
210-
if attr.ext.toEnvExtension.asyncMode matches .async then
211-
-- It seems that the env extension API doesn't quite allow querying attributes in a way
212-
-- that works for realizable constants, but without waiting on proofs to finish.
213-
-- Until then, we use the following overapproximation, to be refined later:
214-
(attr.ext.findStateAsync env decl).contains decl ||
215-
(attr.ext.getState env (asyncMode := .local)).contains decl
216-
else
217-
(attr.ext.getState env).contains decl
209+
| none => (attr.ext.getState (asyncDecl := decl) env).contains decl
218210

219211
end TagAttribute
220212

@@ -303,7 +295,7 @@ def registerEnumAttributes (attrDescrs : List (Name × String × α))
303295
statsFn := fun s => "enumeration attribute extension" ++ Format.line ++ "number of local entries: " ++ format s.size
304296
-- We assume (and check below) that, if used asynchronously, enum attributes are set only in the
305297
-- same context in which the tagged declaration was created
306-
asyncMode := .async
298+
asyncMode := .async .mainEnv
307299
replay? := some fun _ newState consts st => consts.foldl (init := st) fun st c =>
308300
match newState.find? c with
309301
| some v => st.insert c v
@@ -335,15 +327,15 @@ def getValue [Inhabited α] (attr : EnumAttributes α) (env : Environment) (decl
335327
match (attr.ext.getModuleEntries env modIdx).binSearch (decl, default) (fun a b => Name.quickLt a.1 b.1) with
336328
| some (_, val) => some val
337329
| none => none
338-
| none => (attr.ext.findStateAsync env decl).find? decl
330+
| none => (attr.ext.getState (asyncDecl := decl) env).find? decl
339331

340332
def setValue (attrs : EnumAttributes α) (env : Environment) (decl : Name) (val : α) : Except String Environment := do
341333
let pfx := s!"Internal error calling `{attrs.ext.name}.setValue` for `{decl}`"
342334
if (env.getModuleIdxFor? decl).isSome then
343335
throw s!"{pfx}: Declaration is in an imported module"
344336
if !env.asyncMayContain decl then
345337
throw s!"{pfx}: Declaration is not from this async context `{env.asyncPrefix?}`"
346-
if ((attrs.ext.findStateAsync env decl).find? decl).isSome then
338+
if ((attrs.ext.getState (asyncDecl := decl) env).find? decl).isSome then
347339
throw s!"{pfx}: Attribute has already been set"
348340
return attrs.ext.addEntry env (decl, val)
349341

src/Lean/Compiler/MetaAttr.lean

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ public section
1212

1313
namespace Lean
1414

15-
builtin_initialize metaExt : TagDeclarationExtension ← mkTagDeclarationExtension (asyncMode := .async)
15+
builtin_initialize metaExt : TagDeclarationExtension ←
16+
mkTagDeclarationExtension (asyncMode := .async .mainEnv)
1617

1718
/-- Marks in the environment extension that the given declaration has been declared by the user as `meta`. -/
1819
def addMeta (env : Environment) (declName : Name) : Environment :=

src/Lean/DefEqAttrib.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ need to be unfolded to prove the theorem are exported and exposed.
6868
builtin_initialize defeqAttr : TagAttribute ←
6969
registerTagAttribute `defeq "mark theorem as a definitional equality, to be used by `dsimp`"
7070
(validate := validateDefEqAttr) (applicationTime := .afterTypeChecking)
71-
(asyncMode := .async)
71+
(asyncMode := .async .mainEnv)
7272

7373
private partial def isRflProofCore (type : Expr) (proof : Expr) : CoreM Bool := do
7474
match type with

src/Lean/EnvExtension.lean

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ def getEntries {α σ : Type} [Inhabited σ] (ext : SimplePersistentEnvExtension
7171

7272
/-- Get the current state of the given `SimplePersistentEnvExtension`. -/
7373
def getState {α σ : Type} [Inhabited σ] (ext : SimplePersistentEnvExtension α σ) (env : Environment)
74-
(asyncMode := ext.toEnvExtension.asyncMode) : σ :=
75-
(PersistentEnvExtension.getState (asyncMode := asyncMode) ext env).2
74+
(asyncMode := ext.toEnvExtension.asyncMode) (asyncDecl : Name := .anonymous) : σ :=
75+
(PersistentEnvExtension.getState (asyncMode := asyncMode) (asyncDecl := asyncDecl) ext env).2
7676

7777
/-- Set the current state of the given `SimplePersistentEnvExtension`. This change is *not* persisted across files. -/
7878
def setState {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : Environment) (s : σ) : Environment :=
@@ -82,11 +82,6 @@ def setState {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : En
8282
def modifyState {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : Environment) (f : σ → σ) : Environment :=
8383
PersistentEnvExtension.modifyState ext env (fun ⟨entries, s⟩ => (entries, f s))
8484

85-
@[inherit_doc PersistentEnvExtension.findStateAsync]
86-
def findStateAsync {α σ : Type} [Inhabited σ] (ext : SimplePersistentEnvExtension α σ)
87-
(env : Environment) (declPrefix : Name) : σ :=
88-
PersistentEnvExtension.findStateAsync ext env declPrefix |>.2
89-
9085
end SimplePersistentEnvExtension
9186

9287
/-- Environment extension for tagging declarations.
@@ -117,10 +112,7 @@ def tag (ext : TagDeclarationExtension) (env : Environment) (declName : Name) :
117112
def isTagged (ext : TagDeclarationExtension) (env : Environment) (declName : Name) : Bool :=
118113
match env.getModuleIdxFor? declName with
119114
| some modIdx => (ext.getModuleEntries env modIdx).binSearchContains declName Name.quickLt
120-
| none => if ext.toEnvExtension.asyncMode matches .async then
121-
(ext.findStateAsync env declName).contains declName
122-
else
123-
(ext.getState env).contains declName
115+
| none => (ext.getState (asyncDecl := declName) env).contains declName
124116

125117
end TagDeclarationExtension
126118

@@ -140,7 +132,7 @@ def mkMapDeclarationExtension (name : Name := by exact decl_name%)
140132
addImportedFn := fun _ => pure {}
141133
addEntryFn := fun s (n, v) => s.insert n v
142134
exportEntriesFnEx env s level := exportEntriesFn env s level
143-
asyncMode := .async
135+
asyncMode := .async .mainEnv
144136
replay? := some fun _ newState newConsts s =>
145137
newConsts.foldl (init := s) fun s c =>
146138
if let some a := newState.find? c then
@@ -165,11 +157,11 @@ def find? [Inhabited α] (ext : MapDeclarationExtension α) (env : Environment)
165157
match (ext.getModuleEntries (level := level) env modIdx).binSearch (declName, default) (fun a b => Name.quickLt a.1 b.1) with
166158
| some e => some e.2
167159
| none => none
168-
| none => (ext.findStateAsync env declName).find? declName
160+
| none => (ext.getState (asyncDecl := declName) env).find? declName
169161

170162
def contains [Inhabited α] (ext : MapDeclarationExtension α) (env : Environment) (declName : Name) : Bool :=
171163
match env.getModuleIdxFor? declName with
172164
| some modIdx => (ext.getModuleEntries env modIdx).binSearchContains (declName, default) (fun a b => Name.quickLt a.1 b.1)
173-
| none => (ext.findStateAsync env declName).contains declName
165+
| none => (ext.getState (asyncDecl := declName) env).contains declName
174166

175167
end MapDeclarationExtension

src/Lean/Environment.lean

Lines changed: 55 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,19 @@ private partial def AsyncConsts.findRecTask (aconsts : AsyncConsts) (declName :
500500
let some aconsts := aconsts.get? AsyncConsts | .pure none
501501
AsyncConsts.findRecTask aconsts declName
502502

503+
/-- Like `findRec?` but returns the constant that has `declName` in its `consts`, if any. -/
504+
private partial def AsyncConsts.findRecParent? (aconsts : AsyncConsts) (declName : Name) : Option AsyncConst :=
505+
go none aconsts
506+
where go parent? aconsts := do
507+
let c ← aconsts.findPrefix? declName
508+
if c.constInfo.name == declName then
509+
return (← parent?)
510+
-- If privacy is the only difference between `declName` and `findPrefix?` result, we can assume
511+
-- `declName` does not exist according to the `add` invariant
512+
guard <| privateToUserName c.constInfo.name != privateToUserName declName
513+
let aconsts ← c.consts.get.get? AsyncConsts
514+
go (some c) aconsts
515+
503516
/-- Accessibility levels of declarations in `Lean.Environment`. -/
504517
private inductive Visibility where
505518
/-- Information private to the module. -/
@@ -1190,6 +1203,11 @@ def instantiateValueLevelParams! (c : ConstantInfo) (ls : List Level) : Expr :=
11901203

11911204
end ConstantInfo
11921205

1206+
inductive AsyncBranch where
1207+
| mainEnv
1208+
| asyncEnv
1209+
deriving BEq
1210+
11931211
/--
11941212
Async access mode for environment extensions used in `EnvExtension.get/set/modifyState`.
11951213
When modified in concurrent contexts, extensions may need to switch to a different mode than the
@@ -1250,7 +1268,7 @@ inductive EnvExtension.AsyncMode where
12501268
own constant map works which asserts the same predicate on modification and provides `findAsync?`
12511269
for block-avoiding access.
12521270
-/
1253-
| async
1271+
| async (branch : AsyncBranch)
12541272
deriving Inhabited
12551273

12561274
abbrev ReplayFn (σ : Type) :=
@@ -1364,67 +1382,50 @@ def setState {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) :
13641382

13651383
-- `unsafe` fails to infer `Nonempty` here
13661384
private unsafe def getStateUnsafe {σ : Type} [Inhabited σ] (ext : EnvExtension σ)
1367-
(env : Environment) (asyncMode := ext.asyncMode) : σ :=
1385+
(env : Environment) (asyncMode := ext.asyncMode) (asyncDecl : Name := .anonymous) : σ := Id.run do
13681386
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
13691387
match asyncMode with
1370-
| .sync => ext.getStateImpl env.checked.get.extensions
1371-
| .async => panic! "called on `async` extension, use `findStateAsync` \
1372-
instead or pass `(asyncMode := .local)` to explicitly access local state"
1388+
| .sync => ext.getStateImpl env.checked.get.extensions
1389+
| .async branch =>
1390+
if asyncDecl.isAnonymous then
1391+
panic! "called on `async` extension, must set `asyncDecl` \
1392+
or pass `(asyncMode := .local)` to explicitly access local state"
1393+
-- analogous structure to `findAsync?`; see there
1394+
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
1395+
if env.base.get env |>.constants.contains asyncDecl then
1396+
return ext.getStateImpl env.base.private.extensions
1397+
if let some c := match branch with
1398+
| .asyncEnv => env.asyncConsts.findRec? asyncDecl
1399+
| .mainEnv => env.asyncConsts.findRecParent? asyncDecl then
1400+
if let some exts := c.exts? then
1401+
return ext.getStateImpl exts.get
1402+
-- NOTE: if `exts?` is `none`, we should *not* try the following, more expensive branches that
1403+
-- will just come to the same conclusion
1404+
else if let some c := env.allRealizations.get.find? asyncDecl then
1405+
if let some exts := c.exts? then
1406+
return ext.getStateImpl exts.get
1407+
-- fallback; we could enforce that `findStateAsync` is only used on existing constants but the
1408+
-- upside of doing is unclear
1409+
ext.getStateImpl env.base.private.extensions
13731410
| _ => ext.getStateImpl env.base.private.extensions
13741411

13751412
/--
13761413
Returns the current extension state. See `AsyncMode` for details on how modifications from
13771414
different environment branches are reconciled. Panics if the extension is marked as `async`; see its
13781415
documentation for more details. Overriding the extension's default `AsyncMode` is usually not
13791416
recommended and should be considered only for important optimizations.
1380-
-/
1381-
@[implemented_by getStateUnsafe]
1382-
opaque getState {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment)
1383-
(asyncMode := ext.asyncMode) : σ
1384-
1385-
-- `unsafe` fails to infer `Nonempty` here
1386-
private unsafe def findStateAsyncUnsafe {σ : Type} [Inhabited σ]
1387-
(ext : EnvExtension σ) (env : Environment) (declName : Name) : σ := Id.run do
1388-
-- analogous structure to `findAsync?`; see there
1389-
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
1390-
if env.base.get env |>.constants.contains declName then
1391-
return ext.getStateImpl env.base.private.extensions
1392-
if let some c := env.asyncConsts.find? declName then
1393-
if let some exts := c.exts? then
1394-
return ext.getStateImpl exts.get
1395-
-- NOTE: if `exts?` is `none`, we should *not* try the following, more expensive branches that
1396-
-- will just come to the same conclusion
1397-
else if let some exts := findRecExts? none env.asyncConsts declName then
1398-
return ext.getStateImpl exts.get
1399-
else if let some c := env.allRealizations.get.find? declName then
1400-
if let some exts := c.exts? then
1401-
return ext.getStateImpl exts.get
1402-
-- fallback; we could enforce that `findStateAsync` is only used on existing constants but the
1403-
-- upside of doing is unclear
1404-
ext.getStateImpl env.base.private.extensions
1405-
where
1406-
/--
1407-
Like `AsyncConsts.findRec?`, but if `AsyncConst.exts?` is `none`, returns the extension state of
1408-
the surrounding `AsyncConst` instead, which is where state for synchronously added constants is
1409-
stored.
1410-
-/
1411-
findRecExts? (parent? : Option AsyncConst) (aconsts : AsyncConsts) (declName : Name) :
1412-
Option (Task (Array EnvExtensionState)) := do
1413-
let c ← aconsts.findPrefix? declName
1414-
if c.constInfo.name == declName then
1415-
return (← c.exts?.or (parent?.bind (·.exts?)))
1416-
let aconsts ← c.consts.get.get? AsyncConsts
1417-
findRecExts? c aconsts declName
14181417
1418+
Returns the extension state on the environment branch corresponding to the passed declaration name,
1419+
if any, or otherwise the state on the current branch. In other words, at most one environment branch
1420+
will be blocked on.
14191421
1420-
/--
1421-
Returns the final extension state on the environment branch corresponding to the passed declaration
1422-
name, if any, or otherwise the state on the current branch. In other words, at most one environment
1423-
branch will be blocked on.
1422+
More specifically, if `afterAsync` is set to `true`, the retrieved state will be the one at the time
1423+
`AddConstAsyncResult.commitConst` was called; if it is `false`, it will instead be taken fro the
1424+
branch that that called `addConstAsync` in the first place.
14241425
-/
1425-
@[implemented_by findStateAsyncUnsafe]
1426-
opaque findStateAsync {σ : Type} [Inhabited σ] (ext : EnvExtension σ)
1427-
(env : Environment) (declName : Name) : σ
1426+
@[implemented_by getStateUnsafe]
1427+
opaque getState {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment)
1428+
(asyncMode := ext.asyncMode) (asyncDecl : Name := .anonymous) : σ
14281429

14291430
end EnvExtension
14301431

@@ -1593,8 +1594,8 @@ def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : En
15931594

15941595
/-- Get the current state of the given extension in the given environment. -/
15951596
def getState {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment)
1596-
(asyncMode := ext.toEnvExtension.asyncMode) : σ :=
1597-
(ext.toEnvExtension.getState (asyncMode := asyncMode) env).state
1597+
(asyncMode := ext.toEnvExtension.asyncMode) (asyncDecl : Name := .anonymous) : σ :=
1598+
(ext.toEnvExtension.getState (asyncMode := asyncMode) (asyncDecl := asyncDecl) env).state
15981599

15991600
/-- Set the current state of the given extension in the given environment. -/
16001601
def setState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (s : σ) : Environment :=
@@ -1605,12 +1606,6 @@ def modifyState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env :
16051606
(asyncMode := ext.toEnvExtension.asyncMode) : Environment :=
16061607
ext.toEnvExtension.modifyState (asyncMode := asyncMode) env fun ps => { ps with state := f (ps.state) }
16071608

1608-
@[inherit_doc EnvExtension.findStateAsync]
1609-
def findStateAsync {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ)
1610-
(env : Environment) (declPrefix : Name) : σ :=
1611-
ext.toEnvExtension.findStateAsync env declPrefix |>.state
1612-
1613-
16141609
end PersistentEnvExtension
16151610

16161611
builtin_initialize persistentEnvExtensionsRef : IO.Ref (Array (PersistentEnvExtension EnvExtensionEntry EnvExtensionEntry EnvExtensionState)) ← IO.mkRef #[]
@@ -1736,7 +1731,7 @@ def mkModuleData (env : Environment) (level : OLeanLevel := .private) : IO Modul
17361731
let entries := pExts.map fun pExt => Id.run do
17371732
-- get state from `checked` at the end if `async`; it would otherwise panic
17381733
let mut asyncMode := pExt.toEnvExtension.asyncMode
1739-
if asyncMode matches .async then
1734+
if asyncMode matches .async _ then
17401735
asyncMode := .sync
17411736
let state := pExt.getState (asyncMode := asyncMode) env
17421737
(pExt.name, pExt.exportEntriesFn env state level)
@@ -2266,7 +2261,7 @@ def displayStats (env : Environment) : IO Unit := do
22662261
IO.println ("extension '" ++ toString extDescr.name ++ "'")
22672262
-- get state from `checked` at the end if `async`; it would otherwise panic
22682263
let mut asyncMode := extDescr.toEnvExtension.asyncMode
2269-
if asyncMode matches .async then
2264+
if asyncMode matches .async _ then
22702265
asyncMode := .sync
22712266
let s := extDescr.toEnvExtension.getState (asyncMode := asyncMode) env
22722267
let fmt := extDescr.statsFn s.state

src/Lean/Meta/Eqns.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ This information is populated by the `PreDefinition` module, but the simplifier
4949
uses when unfolding declarations.
5050
-/
5151
builtin_initialize recExt : TagDeclarationExtension ←
52-
mkTagDeclarationExtension `recExt (asyncMode := .async)
52+
mkTagDeclarationExtension `recExt (asyncMode := .async .asyncEnv)
5353

5454
/--
5555
Marks the given declaration as recursive.

src/Lean/Meta/Match/MatchEqs.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ def getEquationsForImpl (matchDeclName : Name) : MetaM MatchEqns := do
746746
-- `realizeConst` as well as for looking up the resultant environment extension state via
747747
-- `findStateAsync`.
748748
realizeConst matchDeclName splitterName (go baseName splitterName)
749-
return matchEqnsExt.findStateAsync (← getEnv) splitterName |>.map.find! matchDeclName
749+
return matchEqnsExt.getState (asyncMode := .async .asyncEnv) (asyncDecl := splitterName) (← getEnv) |>.map.find! matchDeclName
750750
where go baseName splitterName := withConfig (fun c => { c with etaStruct := .none }) do
751751
let constInfo ← getConstInfo matchDeclName
752752
let us := constInfo.levelParams.map mkLevelParam
@@ -866,7 +866,7 @@ def genMatchCongrEqns (matchDeclName : Name) : MetaM (Array Name) := do
866866
let baseName := mkPrivateName (← getEnv) matchDeclName
867867
let firstEqnName := .str baseName congrEqn1ThmSuffix
868868
realizeConst matchDeclName firstEqnName (go baseName)
869-
return matchCongrEqnsExt.findStateAsync (← getEnv) firstEqnName |>.find! matchDeclName
869+
return matchCongrEqnsExt.getState (asyncMode := .async .asyncEnv) (asyncDecl := firstEqnName) (← getEnv) |>.find! matchDeclName
870870
where go baseName := withConfig (fun c => { c with etaStruct := .none }) do
871871
withConfig (fun c => { c with etaStruct := .none }) do
872872
let constInfo ← getConstInfo matchDeclName

src/Lean/Meta/Match/MatchEqsExt.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,6 @@ def isMatchEqnTheorem (env : Environment) (declName : Name) : Bool := Id.run do
5454
let .str _ s := declName.eraseMacroScopes | return false
5555
if !isEqnLikeSuffix s then
5656
return false
57-
(matchEqnsExt.findStateAsync env declName).eqns.contains declName
57+
(matchEqnsExt.getState (asyncMode := .async .asyncEnv) (asyncDecl := declName) env).eqns.contains declName
5858

5959
end Lean.Meta.Match

src/Lean/Meta/Match/MatcherInfo.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ builtin_initialize extension : SimplePersistentEnvExtension Entry State ←
8686
registerSimplePersistentEnvExtension {
8787
addEntryFn := State.addEntry
8888
addImportedFn := fun es => (mkStateFromImportedEntries State.addEntry {} es).switch
89-
asyncMode := .async
89+
asyncMode := .async .mainEnv
9090
}
9191

9292
def addMatcherInfo (env : Environment) (matcherName : Name) (info : MatcherInfo) : Environment :=
@@ -98,7 +98,7 @@ def getMatcherInfo? (env : Environment) (declName : Name) : Option MatcherInfo :
9898
-- avoid blocking on async decls whose names look nothing like matchers
9999
let .str _ s := declName.eraseMacroScopes | none
100100
guard <| s.startsWith "match_"
101-
(extension.findStateAsync env declName).map.find? declName
101+
(extension.getState (asyncDecl := declName) env).map.find? declName
102102

103103
end Extension
104104

0 commit comments

Comments
 (0)