-
Notifications
You must be signed in to change notification settings - Fork 528
Expand file tree
/
Copy pathgoodput.py
More file actions
128 lines (109 loc) · 4.52 KB
/
goodput.py
File metadata and controls
128 lines (109 loc) · 4.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for monitoring and recording a job's goodput.
This module provides methods to monitor and record goodput metrics
to various logging platforms, including cloud logging and TensorBoard.
"""
import contextlib
import jax
from enum import Enum
from maxtext.utils import max_logging
from maxtext.common.gcloud_stub import goodput_modules
goodput, monitoring, _GOODPUT_STUB = goodput_modules()
class GoodputEvent(Enum):
  """Phases of a job whose start/end times can be sent to a goodput recorder.

  A member's value is the `<value>` segment of the recorder method names
  `record_<value>_start_time` and `record_<value>_end_time`, which
  `maybe_record_goodput` builds from `event_name.value`.
  """

  # Whole-job span.
  JOB = "job"
  TPU_INIT = "tpu_init"
  TRAINING_PREPARATION = "training_preparation"
  DATA_LOADING = "data_loading"
  # A single training step.
  STEP = "step"
# Recorder method names used when a caller marks the whole job's start and end
# itself via record_goodput instead of the maybe_record_goodput context manager.
# They are computed from GoodputEvent.JOB's value, so they follow any change to it.
RECORD_JOB_START_TIME = "record_{}_start_time".format(GoodputEvent.JOB.value)
RECORD_JOB_END_TIME = "record_{}_end_time".format(GoodputEvent.JOB.value)
@contextlib.contextmanager
def maybe_monitor_goodput(config):
  """Run the background Goodput uploader for the duration of a `with` block.

  The uploader is started only when `config.monitor_goodput` is set and this
  process is the lead host (`jax.process_index() == 0`). In every other case,
  and always when the goodput package is the decoupled stub, the manager just
  yields, so the wrapped block runs unchanged. If the stub is active and
  monitoring was requested on the lead host, a no-op notice is logged first.

  Side effect: when `config.report_performance_metric_for_gcp_monitoring` is
  set, `config.enable_gcp_step_deviation_metrics` is overwritten with False
  before the GCP options are built.

  The uploader is stopped in a `finally`, so it is shut down whether the block
  exits normally or by raising.

  Args:
    config: Run configuration. Reads `monitor_goodput`, `run_name`,
      `tensorboard_dir`, `goodput_upload_interval_seconds`,
      `enable_pathways_goodput`, `monitor_step_time_deviation`,
      `step_deviation_interval_seconds`, `enable_gcp_goodput_metrics`,
      `enable_gcp_step_deviation_metrics` and
      `report_performance_metric_for_gcp_monitoring`.
  """
  # `and` short-circuits, so jax.process_index() is only queried when
  # monitoring was actually requested.
  requested_here = config.monitor_goodput and jax.process_index() == 0

  if _GOODPUT_STUB and requested_here:
    max_logging.log("[GOODPUT NO-OP] monitoring disabled (decoupled stub).")
  if _GOODPUT_STUB or not requested_here:
    # Nothing to monitor on this host: act as a pass-through block.
    yield
    return

  monitor = None
  try:
    if config.report_performance_metric_for_gcp_monitoring:
      config.enable_gcp_step_deviation_metrics = False
    options = monitoring.GCPOptions(
        enable_gcp_goodput_metrics=config.enable_gcp_goodput_metrics,
        enable_gcp_step_deviation_metrics=config.enable_gcp_step_deviation_metrics,
    )
    monitor = monitoring.GoodputMonitor(
        job_name=config.run_name,
        logger_name=f"goodput_{config.run_name}",
        tensorboard_dir=config.tensorboard_dir,
        upload_interval=config.goodput_upload_interval_seconds,
        monitoring_enabled=True,
        pathway_enabled=config.enable_pathways_goodput,
        include_badput_breakdown=True,
        include_step_deviation=config.monitor_step_time_deviation,
        step_deviation_interval_seconds=config.step_deviation_interval_seconds,
        gcp_options=options,
    )
    # `monitor` is bound before starting, so the `finally` below can stop a
    # partially started uploader if start_goodput_uploader() raises.
    monitor.start_goodput_uploader()
    max_logging.log("Started Goodput upload to Tensorboard & GCM in the background!")
    yield
  finally:
    if monitor:
      monitor.stop_goodput_uploader()
      max_logging.log("Flushed final metrics and safe exited from Goodput monitoring.")
@contextlib.contextmanager
def maybe_record_goodput(recorder, event_name, *args):
  """Bracket a `with` block with start/end goodput records for `event_name`.

  On entry this calls `record_goodput` with the method name
  `record_<value>_start_time`, where `<value>` is `event_name.value`. The
  matching `record_<value>_end_time` call is made only when the block finishes
  without raising, i.e. only when the event truly completed.

  Callers that need explicit end-time control should call `record_goodput`
  directly rather than use this context manager. One example is
  GoodputEvent.JOB under elastic training, where the elastic manager may
  suppress the JAX exception internally.

  Args:
    recorder: Goodput recorder passed through to `record_goodput`; may be None.
    event_name: A GoodputEvent member (anything exposing a `.value`).
    *args: Extra positional arguments forwarded to both recorder calls.
  """
  method_prefix = f"record_{event_name.value}"
  record_goodput(recorder, method_prefix + "_start_time", *args)
  # No try/finally on purpose. Under @contextlib.contextmanager, an exception
  # raised inside the managed block is re-raised at this `yield`, so the
  # end-time line below runs only when the block completed normally.
  yield
  record_goodput(recorder, method_prefix + "_end_time", *args)
def record_goodput(recorder, event_name, *args):
  """Invoke a named recording method on a goodput recorder (e.g. for cloud logging).

  `event_name` is the recorder method's name, such as
  "record_step_start_time", not a GoodputEvent member. The call is skipped
  silently when `recorder` is falsy (for example None because recording is
  disabled), or when `recorder` has no attribute of that name or the attribute
  is falsy.

  Args:
    recorder: Goodput recorder object, or None.
    event_name: Name of the method to call on `recorder`.
    *args: Positional arguments forwarded to that method.

  Returns:
    None; the recorder method's return value is discarded.
  """
  if not recorder:
    return
  method = getattr(recorder, event_name, None)
  if not method:
    return
  method(*args)
def create_goodput_recorder(config):
  """Build a goodput recorder when `config.enable_goodput_recording` is set.

  Returns None when recording is disabled. It also returns None when the
  goodput package is the decoupled stub; in that case the lead host logs a
  no-op notice.

  Otherwise it returns `goodput.GoodputRecorder(run_name,
  "goodput_<run_name>", is_lead_host)`, where `is_lead_host` is
  `jax.process_index() == 0`. That flag is presumably the recorder's
  "logging enabled" switch; TODO confirm against the goodput library.

  Args:
    config: Run configuration; reads `enable_goodput_recording` and `run_name`.

  Returns:
    The recorder instance, or None.
  """
  if not config.enable_goodput_recording:
    return None
  on_lead_host = jax.process_index() == 0
  if _GOODPUT_STUB:
    if on_lead_host:
      max_logging.log("[GOODPUT NO-OP] recorder skipped (decoupled stub).")
    return None
  return goodput.GoodputRecorder(config.run_name, f"goodput_{config.run_name}", on_lead_host)