Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions packages/app/src/components/provider-diagnostic-sheet.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -600,13 +600,22 @@ export function ProviderDiagnosticSheet({
const modelsRefreshing = isRefreshing || providerSnapshotRefreshing;

const stableDiscoveredRef = useRef<AgentModelDefinition[]>([]);
if (providerEntry?.models && providerEntry.models.length > 0) {
stableDiscoveredRef.current = providerEntry.models;
}
const discoveredModels =
providerEntry?.models && providerEntry.models.length > 0
? providerEntry.models
: stableDiscoveredRef.current;
const currentModels = providerEntry?.models;
useEffect(() => {
if (currentModels && currentModels.length > 0) {
stableDiscoveredRef.current = currentModels;
}
}, [currentModels]);

const discoveredModels = useMemo(() => {
if (currentModels && currentModels.length > 0) {
return currentModels;
}
if (providerSnapshotRefreshing) {
return stableDiscoveredRef.current;
}
return [];
}, [currentModels, providerSnapshotRefreshing]);

const [clockTick, setClockTick] = useState(0);
useEffect(() => {
Expand Down
16 changes: 16 additions & 0 deletions packages/server/src/server/agent/agent-sdk-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,16 @@ export interface ListModesOptions {
force: boolean;
}

export interface FetchCatalogOptions {
cwd: string;
force: boolean;
}

export interface ProviderCatalog {
models: AgentModelDefinition[];
modes: AgentMode[];
}

export interface AgentClient {
readonly provider: AgentProvider;
readonly capabilities: AgentCapabilityFlags;
Expand All @@ -661,6 +671,12 @@ export interface AgentClient {
): Promise<AgentSession>;
listModels(options: ListModelsOptions): Promise<AgentModelDefinition[]>;
listModes?(options: ListModesOptions): Promise<AgentMode[]>;
/**
* Discover models and modes together when the provider supports a single
* catalog probe. Implementations should spawn at most one runtime process.
* The registry is responsible for merging configured model overrides.
*/
fetchCatalog?(options: FetchCatalogOptions): Promise<ProviderCatalog>;
resolveCreateConfig?(input: ResolveAgentCreateConfigInput): ResolveAgentCreateConfigResult;
isCreateConfigUnattended?(input: AgentCreateConfigUnattendedInput): boolean;
listCommands?(config: AgentSessionConfig): Promise<AgentSlashCommand[]>;
Expand Down
116 changes: 115 additions & 1 deletion packages/server/src/server/agent/provider-registry.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { beforeEach, describe, expect, test, vi } from "vitest";

import { createTestLogger } from "../../test-utils/test-logger.js";
import type { AgentModelDefinition } from "./agent-sdk-types.js";
import type { AgentClient, AgentModelDefinition, AgentMode } from "./agent-sdk-types.js";

const mockState = vi.hoisted(() => {
interface ConstructorEntry {
Expand Down Expand Up @@ -1286,3 +1286,117 @@ describe("model merging", () => {
expect(models.find((model) => model.isDefault)?.id).toBe("MiniMax-M3");
});
});

describe("fetchCatalog", () => {
test("returns merged models and modes from listModels/listModes fallback", async () => {
mockState.runtimeModels.set("codex", [
{ provider: "codex", id: "codex-runtime", label: "Codex Runtime" },
]);

const registry = buildProviderRegistry(logger);
const catalog = await registry.codex.fetchCatalog({
cwd: "/tmp/catalog",
force: false,
});

expect(catalog.models.map((model) => model.id)).toEqual(["codex-runtime"]);
expect(catalog.modes).toEqual([]);
});

test("replacement models skip runtime model discovery but preserve additionalModels", async () => {
mockState.runtimeModels.set("codex", [
{ provider: "codex", id: "codex-runtime", label: "Codex Runtime" },
]);

const registry = buildProviderRegistry(logger, {
providerOverrides: {
codex: {
models: [{ id: "profile-model", label: "Profile Model" }],
additionalModels: [{ id: "extra-model", label: "Extra Model" }],
},
},
});

const catalog = await registry.codex.fetchCatalog({
cwd: "/tmp/catalog",
force: false,
});

expect(catalog.models.map((model) => model.id)).toEqual(["profile-model", "extra-model"]);
});

test("additionalModels can override replacement model fields", async () => {
const registry = buildProviderRegistry(logger, {
providerOverrides: {
codex: {
models: [{ id: "shared-model", label: "Profile Label" }],
additionalModels: [{ id: "shared-model", label: "Additional Label" }],
},
},
});

const catalog = await registry.codex.fetchCatalog({
cwd: "/tmp/catalog",
force: false,
});

expect(catalog.models).toEqual([
{
provider: "codex",
id: "shared-model",
label: "Additional Label",
},
]);
});

test("uses injected client instead of base client when provided", async () => {
const injectedModels: AgentModelDefinition[] = [
{ provider: "codex", id: "injected-model", label: "Injected Model" },
];
const injectedModes: AgentMode[] = [{ id: "agent", label: "Agent" }];
const injectedClient = {
provider: "codex",
capabilities: {},
listModels: vi.fn(async () => injectedModels),
listModes: vi.fn(async () => injectedModes),
isAvailable: vi.fn(async () => true),
} satisfies Partial<AgentClient> as AgentClient;

const registry = buildProviderRegistry(logger);
const catalog = await registry.codex.fetchCatalog(
{ cwd: "/tmp/catalog", force: false },
injectedClient,
);

expect(injectedClient.listModels).toHaveBeenCalledTimes(1);
expect(injectedClient.listModes).toHaveBeenCalledTimes(1);
expect(catalog.models.map((model) => model.id)).toEqual(["injected-model"]);
expect(catalog.modes).toEqual(injectedModes);
});

test("uses injected client fetchCatalog when available", async () => {
const injectedClient = {
provider: "codex",
capabilities: {},
fetchCatalog: vi.fn(async () => ({
models: [{ provider: "codex", id: "catalog-model", label: "Catalog Model" }],
modes: [{ id: "ask", label: "Ask" }],
})),
listModels: vi.fn(async () => []),
listModes: vi.fn(async () => []),
isAvailable: vi.fn(async () => true),
} satisfies Partial<AgentClient> as AgentClient;

const registry = buildProviderRegistry(logger);
const catalog = await registry.codex.fetchCatalog(
{ cwd: "/tmp/catalog", force: false },
injectedClient,
);

expect(injectedClient.fetchCatalog).toHaveBeenCalledTimes(1);
expect(injectedClient.listModels).not.toHaveBeenCalled();
expect(injectedClient.listModes).not.toHaveBeenCalled();
expect(catalog.models.map((model) => model.id)).toEqual(["catalog-model"]);
expect(catalog.modes.map((mode) => mode.id)).toEqual(["ask"]);
});
});
117 changes: 94 additions & 23 deletions packages/server/src/server/agent/provider-registry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ import type {
AgentRuntimeInfo,
AgentSession,
AgentStreamEvent,
FetchCatalogOptions,
ListModelsOptions,
ListModesOptions,
ProviderCatalog,
ResolveAgentCreateConfigInput,
ResolveAgentCreateConfigResult,
} from "./agent-sdk-types.js";
Expand Down Expand Up @@ -66,6 +68,11 @@ export interface ProviderDefinition extends AgentProviderDefinition {
isCreateConfigUnattended: (input: AgentCreateConfigUnattendedInput) => boolean;
fetchModels: (options: ListModelsOptions) => Promise<AgentModelDefinition[]>;
fetchModes: (options: ListModesOptions) => Promise<AgentMode[]>;
/**
* Single catalog discovery call used by ProviderSnapshotManager. Should spawn
* at most one provider runtime process and return both models and modes.
*/
fetchCatalog: (options: FetchCatalogOptions, client?: AgentClient) => Promise<ProviderCatalog>;
}

export interface BuildProviderRegistryOptions {
Expand Down Expand Up @@ -429,6 +436,17 @@ function wrapClientProvider(
profileModelsAreAdditive,
}),
listModes: inner.listModes?.bind(inner),
fetchCatalog: inner.fetchCatalog
? async (options) => {
const catalog = await inner.fetchCatalog!(options);
return {
models: mergeModels(provider, profileModels, additionalModels, catalog.models, {
profileModelsAreAdditive,
}),
modes: catalog.modes,
};
}
: undefined,
resolveCreateConfig: inner.resolveCreateConfig?.bind(inner),
isCreateConfigUnattended: inner.isCreateConfigUnattended?.bind(inner),
listImportableSessions: listImportableSessions
Expand Down Expand Up @@ -473,6 +491,46 @@ function createRegistryEntry(
resolved: ResolvedProvider,
): ProviderDefinition {
const modelClient = resolved.createBaseClient(logger);
const hasReplacementModels =
resolved.profileModels.length > 0 && !resolved.profileModelsAreAdditive;
const replacementModels = hasReplacementModels
? resolved.profileModels.map((model) => mapModel(provider, model))
: [];

const decorateModes = (modes: AgentMode[]): AgentMode[] =>
modes.map((mode) => {
if (mode.icon && mode.colorTier) return mode;
const definitionMode = resolved.definition.modes.find((d) => d.id === mode.id);
if (!definitionMode) return mode;
return Object.assign({}, mode, {
icon: mode.icon ?? definitionMode.icon,
colorTier: mode.colorTier ?? definitionMode.colorTier,
});
});

const fetchModelsFromClient = async (
options: ListModelsOptions,
catalogClient: AgentClient = modelClient,
) =>
mergeModels(
provider,
resolved.profileModels,
resolved.additionalModels,
await catalogClient.listModels(options),
{
profileModelsAreAdditive: resolved.profileModelsAreAdditive,
},
);

const fetchModesFromClient = async (
options: ListModesOptions,
catalogClient: AgentClient = modelClient,
) => {
const modes = catalogClient.listModes
? await catalogClient.listModes(options)
: resolved.definition.modes;
return decorateModes(modes);
};

return {
...resolved.definition,
Expand All @@ -483,29 +541,42 @@ function createRegistryEntry(
resolveCreateConfig: modelClient.resolveCreateConfig ?? resolveDefaultAgentCreateConfig,
isCreateConfigUnattended:
modelClient.isCreateConfigUnattended ?? isDefaultAgentCreateConfigUnattended,
fetchModels: async (options: ListModelsOptions) =>
mergeModels(
provider,
resolved.profileModels,
resolved.additionalModels,
await modelClient.listModels(options),
{
profileModelsAreAdditive: resolved.profileModelsAreAdditive,
},
),
fetchModes: async (options: ListModesOptions) => {
const modes = modelClient.listModes
? await modelClient.listModes(options)
: resolved.definition.modes;
return modes.map((mode) => {
if (mode.icon && mode.colorTier) return mode;
const definitionMode = resolved.definition.modes.find((d) => d.id === mode.id);
if (!definitionMode) return mode;
return Object.assign({}, mode, {
icon: mode.icon ?? definitionMode.icon,
colorTier: mode.colorTier ?? definitionMode.colorTier,
});
});
fetchModels: fetchModelsFromClient,
fetchModes: fetchModesFromClient,
fetchCatalog: async (options: FetchCatalogOptions, client?: AgentClient) => {
const catalogClient = client ?? modelClient;
if (hasReplacementModels) {
// Replacement models skip runtime model discovery, but additionalModels
// must still be merged on top. If modes are dynamic, probe for modes only;
// otherwise use static/empty modes with no runtime.
const models = mergeModelAdditions(provider, replacementModels, resolved.additionalModels);
if (!catalogClient.listModes) {
return {
models,
modes: decorateModes(resolved.definition.modes),
};
}
return {
models,
modes: await fetchModesFromClient(options, catalogClient),
};
}

if (catalogClient.fetchCatalog) {
const catalog = await catalogClient.fetchCatalog(options);
return {
models: mergeModels(provider, [], resolved.additionalModels, catalog.models, {
profileModelsAreAdditive: true,
}),
modes: decorateModes(catalog.modes),
};
}
Comment thread
greptile-apps[bot] marked this conversation as resolved.

const [models, modes] = await Promise.all([
fetchModelsFromClient(options, catalogClient),
fetchModesFromClient(options, catalogClient),
]);
return { models, modes };
},
};
}
Expand Down
Loading
Loading