
Commit 170501a

feat: extend profile_runner.py parameterization to test extreme contexts
- Add --contexts flag to loop through a list of context scale factors
- Refactor script to output an extended markdown matrix covering each context depth
- Enable sequential TTFT scaling tests with prompts up to 100k tokens
1 parent 2cd373f commit 170501a
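Based on the new flag's help text, a full sweep across the three example depths could be launched roughly like this (the model ID and output path mirror the examples and defaults in the argparse help, not a command recorded in the commit):

    python scripts/profiling/profile_runner.py \
        --model gemma-4-26b-a4b-it-4bit \
        --contexts 512,40000,100000 \
        --out ./profiling_results.md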

1 file changed: scripts/profiling/profile_runner.py (39 additions, 25 deletions)
@@ -31,6 +31,7 @@ def poll_health(port=5413, timeout=30):
     return False

 def make_request_stream(prompt_len, max_tokens, port=5413):
+    # To prevent blowing up python memory when generating 100k prompts, build efficiently
     prompt = "apple " * int(prompt_len * 0.75)
     data = json.dumps({
         "messages": [{"role": "user", "content": prompt}],
@@ -49,7 +50,8 @@ def make_request_stream(prompt_len, max_tokens, port=5413):
     start = time.time()
     tokens = 0
     try:
-        with urllib.request.urlopen(req, timeout=120) as response:
+        # Extreme context testing requires a very large socket timeout
+        with urllib.request.urlopen(req, timeout=900) as response:
             for line in response:
                 line = line.decode('utf-8').strip()
                 if line.startswith("data: ") and line != "data: [DONE]":
@@ -87,54 +89,66 @@ def main():
     parser = argparse.ArgumentParser(description="Aegis-AI Physical Model Profiler")
     parser.add_argument("--model", required=True, help="Model ID (e.g. gemma-4-26b-a4b-it-4bit)")
     parser.add_argument("--out", default="./profiling_results.md", help="Output markdown file path")
+    parser.add_argument("--contexts", default="512", help="Comma-separated list of context lengths to test (e.g. 512,40000,100000)")
     args = parser.parse_args()

+    context_sizes = [int(x.strip()) for x in args.contexts.split(",") if x.strip()]
     results = []
+
     subprocess.run(["killall", "SwiftLM"], stderr=subprocess.DEVNULL)

     for config in CONFIGS:
-        print(f"\n--- Profiling {args.model} [{config['name']}] ---")
-        model_path = f"/Users/simba/.aegis-ai/models/mlx_models/mlx-community/{args.model}"
+        print(f"\n==============================================")
+        print(f"--- Profiling {args.model} [{config['name']}] ---")
+        print(f"==============================================")

+        model_path = f"/Users/simba/.aegis-ai/models/mlx_models/mlx-community/{args.model}"
         log_path = "./tmp/profile_server.log"
         cmd = [SWIFTLM_PATH, "--model", model_path] + config["flags"]

         with open(log_path, "w") as root_log:
             server_proc = subprocess.Popen(cmd, stdout=root_log, stderr=subprocess.STDOUT)

-        if not poll_health():
+        if not poll_health(timeout=60):
             print("Server failed to start.")
             server_proc.terminate()
             continue

         static_mem = extract_base_memory(log_path)

-        print("Running 20-token test (prefill ~512, max ~20)...")
-        ok, ttft, tps = make_request_stream(prompt_len=512, max_tokens=20)
-
+        for ctx_size in context_sizes:
+            print(f"\n>> Running {ctx_size}-token context test (max generation ~20)...")
+            ok, ttft, tps = make_request_stream(prompt_len=ctx_size, max_tokens=20)
+
+            real_mem = extract_real_memory(log_path)
+
+            if ok:
+                results.append({
+                    "config": config["name"],
+                    "context": ctx_size,
+                    "ttft_20": f"{ttft:.2f}",
+                    "tps_20": f"{tps:.2f}",
+                    "static_mem": static_mem,
+                    "real_mem": real_mem
+                })
+                print(f"Result [{config['name']} | Ctx: {ctx_size}]: TTFT={ttft:.2f}s TPS={tps:.2f} BaseRAM={static_mem} PhysRAM={real_mem}")
+            else:
+                print(f"Result [{config['name']} | Ctx: {ctx_size}]: FAILED / OOM")
+
+        # Teardown after finishing all context sizes for this config
         server_proc.send_signal(subprocess.signal.SIGTERM)
-        server_proc.wait(timeout=10)
-
-        real_mem = extract_real_memory(log_path)
-
-        if ok:
-            results.append({
-                "config": config["name"],
-                "ttft_20": f"{ttft:.2f}",
-                "tps_20": f"{tps:.2f}",
-                "static_mem": static_mem,
-                "real_mem": real_mem
-            })
-            print(f"Result [{config['name']}]: TTFT={ttft:.2f}s TPS={tps:.2f} BaseRAM={static_mem} PhysRAM={real_mem}")
+        server_proc.wait(timeout=20)
+        time.sleep(2) # Give OS memory manager a breather to reap active wires

     with open(args.out, "w") as f:
-        f.write(f"### `{args.model}` - Throughput & OS Memory Profile\n\n")
-        f.write("| Configuration | Time To First Token | Generation Speed | Theoretical Reservation | Physical OS Footprint (RAM) |\n")
-        f.write("|---|---|---|---|---|\n")
+        f.write(f"### `{args.model}` - Extreme Context & Footprint Profile\n\n")
+        f.write(f"Tested Context Lengths: {args.contexts}\n\n")
+        f.write("| Configuration | Context Size | Time To First Token | Generation Speed | Theoretical Reservation | Peak OS Footprint (Active RAM) |\n")
+        f.write("|---|---|---|---|---|---|\n")
         for r in results:
-            f.write(f"| {r['config']} | {r['ttft_20']}s | {r['tps_20']} tok/s | {r['static_mem']} | {r['real_mem']} |\n")
+            f.write(f"| {r['config']} | {r['context']} | {r['ttft_20']}s | {r['tps_20']} tok/s | {r['static_mem']} | {r['real_mem']} |\n")

-    print(f"\nDone. Results saved to {args.out}")
+    print(f"\nDone. Matrix saved to {args.out}")

 if __name__ == "__main__":
     main()
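The diff stops before make_request_stream computes its return values, so the measurement logic behind the (ok, ttft, tps) tuple is not visible here. The sketch below is a reconstruction under stated assumptions, not the script's actual code: the endpoint path, the payload fields beyond "messages", and the one-token-per-SSE-chunk counting are guesses, while the prompt heuristic, the 900-second socket timeout, and the "data: " line filtering come from the diff itself.

    # Hedged sketch of a streaming TTFT/TPS probe in the style of make_request_stream.
    # Assumptions: OpenAI-style /v1/chat/completions endpoint, stream=True payload,
    # and one generated token per SSE "data: " chunk.
    import json
    import time
    import urllib.request


    def probe_stream(prompt_len, max_tokens, port=5413):
        # Build the prompt cheaply; ~0.75 words per token mirrors the script's heuristic.
        prompt = "apple " * int(prompt_len * 0.75)
        payload = json.dumps({
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
            "stream": True,
        }).encode("utf-8")
        req = urllib.request.Request(
            f"http://localhost:{port}/v1/chat/completions",  # assumed endpoint path
            data=payload,
            headers={"Content-Type": "application/json"},
        )

        start = time.time()
        first_chunk_at = None
        tokens = 0
        try:
            # Very large socket timeout so a 100k-token prefill can complete.
            with urllib.request.urlopen(req, timeout=900) as response:
                for line in response:
                    line = line.decode("utf-8").strip()
                    if line.startswith("data: ") and line != "data: [DONE]":
                        if first_chunk_at is None:
                            first_chunk_at = time.time()  # first streamed chunk marks TTFT
                        tokens += 1
        except Exception as exc:
            print(f"request failed: {exc}")
            return False, 0.0, 0.0

        if first_chunk_at is None or tokens == 0:
            return False, 0.0, 0.0

        ttft = first_chunk_at - start
        gen_time = max(time.time() - first_chunk_at, 1e-6)
        tps = tokens / gen_time  # decode speed only, excluding the prefill phase
        return True, ttft, tps

Measured this way, TTFT absorbs the entire prefill of the supplied context, which is exactly the quantity the new --contexts sweep is meant to push toward 100k tokens.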
