Skip to content

Commit 6587832

Browse files
committed
test(live): read gateway provider models
1 parent d47497c commit 6587832

1 file changed

Lines changed: 92 additions & 1 deletion

File tree

src/gateway/gateway-models.profiles.live.test.ts

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1410,6 +1410,83 @@ type LiveModelRegistry = {
14101410
getAll(): Array<Model<Api>>;
14111411
};
14121412

1413+
function toGatewayLiveModel(params: {
1414+
provider: string;
1415+
providerConfig: ModelProviderConfig;
1416+
modelConfig: NonNullable<ModelProviderConfig["models"]>[number];
1417+
}): Model<Api> | null {
1418+
const id = params.modelConfig.id?.trim();
1419+
const api = params.modelConfig.api ?? params.providerConfig.api;
1420+
const baseUrl = params.modelConfig.baseUrl ?? params.providerConfig.baseUrl;
1421+
if (!id || !api || !baseUrl) {
1422+
return null;
1423+
}
1424+
const input = params.modelConfig.input.filter(
1425+
(value): value is "text" | "image" => value === "text" || value === "image",
1426+
);
1427+
return {
1428+
id,
1429+
name: params.modelConfig.name ?? id,
1430+
api: api as Api,
1431+
provider: params.provider,
1432+
baseUrl,
1433+
reasoning: params.modelConfig.reasoning ?? false,
1434+
thinkingLevelMap: params.modelConfig.thinkingLevelMap,
1435+
input: input.length > 0 ? input : ["text"],
1436+
cost: params.modelConfig.cost ?? {
1437+
input: 0,
1438+
output: 0,
1439+
cacheRead: 0,
1440+
cacheWrite: 0,
1441+
},
1442+
contextWindow: params.modelConfig.contextWindow ?? 128_000,
1443+
maxTokens: params.modelConfig.maxTokens ?? 16_384,
1444+
compat: params.modelConfig.compat ?? params.providerConfig.compat,
1445+
};
1446+
}
1447+
1448+
async function loadProviderScopedConfiguredModels(params: {
1449+
agentDir: string;
1450+
providerList: readonly string[];
1451+
}): Promise<Array<Model<Api>>> {
1452+
const modelsPath = path.join(params.agentDir, "models.json");
1453+
let parsed: { providers?: Record<string, ModelProviderConfig> };
1454+
try {
1455+
parsed = JSON.parse(await fs.readFile(modelsPath, "utf8")) as {
1456+
providers?: Record<string, ModelProviderConfig>;
1457+
};
1458+
} catch {
1459+
return [];
1460+
}
1461+
1462+
const providers = parsed.providers ?? {};
1463+
const models: Array<Model<Api>> = [];
1464+
const seen = new Set<string>();
1465+
for (const rawProvider of params.providerList) {
1466+
const normalizedProvider = normalizeProviderId(rawProvider);
1467+
const entry = Object.entries(providers).find(
1468+
([provider]) => normalizeProviderId(provider) === normalizedProvider,
1469+
);
1470+
if (!entry) {
1471+
continue;
1472+
}
1473+
const [provider, providerConfig] = entry;
1474+
for (const modelConfig of providerConfig.models ?? []) {
1475+
const model = toGatewayLiveModel({ provider, providerConfig, modelConfig });
1476+
if (!model) {
1477+
continue;
1478+
}
1479+
const key = `${normalizeProviderId(model.provider)}/${model.id.toLowerCase()}`;
1480+
if (seen.has(key)) {
1481+
continue;
1482+
}
1483+
seen.add(key);
1484+
models.push(model);
1485+
}
1486+
}
1487+
return models;
1488+
}
1489+
14131490
function loadProviderScopedBuiltInModels(providerList: readonly string[]): Array<Model<Api>> {
14141491
const models: Array<Model<Api>> = [];
14151492
const seen = new Set<string>();
@@ -1430,6 +1507,17 @@ function loadProviderScopedBuiltInModels(providerList: readonly string[]): Array
14301507
return models;
14311508
}
14321509

1510+
async function loadProviderScopedModels(params: {
1511+
agentDir: string;
1512+
providerList: readonly string[];
1513+
}): Promise<Array<Model<Api>>> {
1514+
const configured = await loadProviderScopedConfiguredModels(params);
1515+
if (configured.length > 0) {
1516+
return configured;
1517+
}
1518+
return loadProviderScopedBuiltInModels(params.providerList);
1519+
}
1520+
14331521
function createStaticLiveModelRegistry(models: Array<Model<Api>>): LiveModelRegistry {
14341522
return {
14351523
find(provider, modelId) {
@@ -2459,7 +2547,10 @@ describeLive("gateway live (dev agent, profile keys)", () => {
24592547
let all: Array<Model<Api>>;
24602548
if (useProviderScopedBuiltIns) {
24612549
logProgress("[all-models] loading provider-scoped model refs");
2462-
all = loadProviderScopedBuiltInModels(providerList);
2550+
all = await withGatewayLiveSetupTimeout(
2551+
loadProviderScopedModels({ agentDir, providerList }),
2552+
"[all-models] load provider-scoped model refs",
2553+
);
24632554
modelRegistry = createStaticLiveModelRegistry(all);
24642555
} else {
24652556
logProgress("[all-models] loading auth profiles");

0 commit comments

Comments
 (0)