Skip to content

Commit 9f643a8

Browse files
committed
Remove huggingface_hub dependency — SwiftLM downloads models natively via HubApi
1 parent 391cb43 commit 9f643a8

1 file changed

Lines changed: 5 additions & 20 deletions

File tree

scripts/profiling/profile_runner.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -111,32 +111,17 @@ def extract_os_ram(log_path):
111111
except: pass
112112
return "N/A"
113113

114-
def download_model(repo_id, models_dir):
115-
try:
116-
from huggingface_hub import snapshot_download
117-
except ImportError:
118-
print("Error: huggingface_hub is not installed.")
119-
print("Please install it via: pip install huggingface_hub")
120-
sys.exit(1)
121-
122-
if "/" not in repo_id:
123-
repo_id = f"mlx-community/{repo_id}"
124-
125-
local_path = os.path.abspath(os.path.join(models_dir, repo_id))
126-
print(f"Downloading/verifying model '{repo_id}' to '{local_path}'...\n")
127-
snapshot_download(repo_id=repo_id, local_dir=local_path)
128-
return local_path
129-
130114
def main():
131115
parser = argparse.ArgumentParser(description="Aegis-AI Physical Model Profiler")
132116
parser.add_argument("--model", required=True, help="Model ID (e.g. gemma-4-26b-a4b-it-4bit)")
133117
parser.add_argument("--out", default="./profiling_results.md", help="Output markdown file path")
134118
parser.add_argument("--contexts", default="512", help="Comma-separated list of context lengths to test (e.g. 512,40000,100000)")
135-
parser.add_argument("--models-dir", default="./models", help="Local directory to store downloaded models")
136119
args = parser.parse_args()
137120

138-
# Ensure model is downloaded
139-
model_path = download_model(args.model, args.models_dir)
121+
# SwiftLM handles model downloading natively via HubApi.
122+
# Just pass the model ID directly — prepend mlx-community/ if no org is specified.
123+
model_id = args.model if "/" in args.model else f"mlx-community/{args.model}"
124+
140125

141126
context_sizes = [int(x.strip()) for x in args.contexts.split(",") if x.strip()]
142127
results = []
@@ -155,7 +140,7 @@ def main():
155140

156141
log_path = "./tmp/profile_server.log"
157142
os.makedirs(os.path.dirname(log_path), exist_ok=True)
158-
cmd = [SWIFTLM_PATH, "--model", model_path] + config["flags"]
143+
cmd = [SWIFTLM_PATH, "--model", model_id] + config["flags"]
159144

160145
with open(log_path, "w") as root_log:
161146
server_proc = subprocess.Popen(cmd, stdout=root_log, stderr=subprocess.STDOUT)

0 commit comments

Comments
 (0)