diff --git a/apps/gateway/src/chat/chat.ts b/apps/gateway/src/chat/chat.ts index 52d59a276..adc6e5b28 100644 --- a/apps/gateway/src/chat/chat.ts +++ b/apps/gateway/src/chat/chat.ts @@ -254,18 +254,18 @@ function filterRegionsByAvailableKeys( }); } -function preferConcreteRegionalMappings( +function preferProviderRootMappings( providers: ProviderModelMapping[], ): ProviderModelMapping[] { - const providersWithRegions = new Set( + const providersWithRootMappings = new Set( providers - .filter((mapping) => mapping.region) + .filter((mapping) => !mapping.region) .map((mapping) => mapping.providerId), ); return providers.filter( (mapping) => - !providersWithRegions.has(mapping.providerId) || Boolean(mapping.region), + !providersWithRootMappings.has(mapping.providerId) || !mapping.region, ); } @@ -1691,7 +1691,7 @@ chat.openapi(completions, async (c) => { } const candidateAllowedProviders = candidateIam.allowedProviders; - const candidateProviders = preferConcreteRegionalMappings( + const candidateProviders = preferProviderRootMappings( project.mode === "credits" ? filterRegionsByAvailableKeys( expandAllProviderRegions( @@ -1834,6 +1834,7 @@ chat.openapi(completions, async (c) => { { metricsMap, isStreaming: stream, + includeProviderScoreRegions: false, promptTokens: routingPromptTokens, }, ); @@ -2108,9 +2109,7 @@ chat.openapi(completions, async (c) => { // Attempt to re-route to alternative providers (same pattern as low-uptime fallback) const providerIds = modelInfo.providers - .filter( - (p) => !(p.providerId === usedProvider && p.region === usedRegion), - ) + .filter((p) => p.providerId !== usedProvider) .map((p) => p.providerId); if (providerIds.length > 0) { @@ -2127,39 +2126,18 @@ chat.openapi(completions, async (c) => { .filter((p) => hasProviderEnvironmentToken(p.id as Provider)) .map((p) => p.id); - const availableModelProviders = preferConcreteRegionalMappings( - iamFilteredModelProviders, - ).filter((provider) => { - if (!availableProviders.includes(provider.providerId)) { - return false; - } - if ( - provider.providerId === usedProvider && - provider.region === usedRegion - ) { - return false; - } - if (webSearchTool && provider.webSearch !== true) { - return false; - } - if ( - response_format?.type === "json_object" || - response_format?.type === "json_schema" - ) { - if (provider.jsonOutput !== true) { - return false; - } - } - if (response_format?.type === "json_schema") { - if (provider.jsonOutputSchema !== true) { - return false; - } - } - if (hasImages && provider.vision !== true) { - return false; - } - return true; - }); + const availableModelProviders = filterEligibleModelProviders( + preferProviderRootMappings(expandedIamFilteredModelProviders), + { + allProviderVariants: modelInfo.providers, + availableProviders, + webSearchTool, + responseFormatType: response_format?.type, + hasImages, + maxTokens: max_tokens, + reasoningEffort: reasoning_effort, + }, + ).filter((provider) => provider.providerId !== usedProvider); // Also filter out rate-limited alternatives const rateLimitedAlternatives = await filterRateLimitedProviders( @@ -2206,6 +2184,7 @@ chat.openapi(completions, async (c) => { { metricsMap: allMetricsMap, isStreaming: stream, + includeProviderScoreRegions: false, promptTokens: routingPromptTokens, }, ); @@ -2285,9 +2264,7 @@ chat.openapi(completions, async (c) => { const currentUptime = metrics.uptime; // Get available providers for routing const providerIds = modelInfo.providers - .filter( - (p) => !(p.providerId === usedProvider && p.region === usedRegion), - ) // Exclude the exact low-uptime provider+region pair + .filter((p) => p.providerId !== usedProvider) .map((p) => p.providerId); if (providerIds.length > 0) { @@ -2308,7 +2285,7 @@ chat.openapi(completions, async (c) => { // If web search is requested, also filter to providers that support it // If JSON output is requested, also filter to providers that support it const availableModelProviders = filterEligibleModelProviders( - preferConcreteRegionalMappings(expandedIamFilteredModelProviders), + preferProviderRootMappings(expandedIamFilteredModelProviders), { allProviderVariants: modelInfo.providers, availableProviders, @@ -2318,13 +2295,7 @@ chat.openapi(completions, async (c) => { maxTokens: max_tokens, reasoningEffort: reasoning_effort, }, - ).filter( - (provider) => - !( - provider.providerId === usedProvider && - provider.region === usedRegion - ), - ); + ).filter((provider) => provider.providerId !== usedProvider); if (availableModelProviders.length > 0) { const rawModelForFallback = models.find((m) => m.id === baseModelId); @@ -2388,6 +2359,7 @@ chat.openapi(completions, async (c) => { { metricsMap: allMetricsMap, isStreaming: stream, + includeProviderScoreRegions: false, promptTokens: routingPromptTokens, }, ); @@ -2486,7 +2458,7 @@ chat.openapi(completions, async (c) => { // Filter model providers to only those eligible for this request const availableModelProviders = filterEligibleModelProviders( - preferConcreteRegionalMappings(expandedIamFilteredModelProviders), + preferProviderRootMappings(expandedIamFilteredModelProviders), { allProviderVariants: modelInfo.providers, availableProviders, @@ -2583,6 +2555,7 @@ chat.openapi(completions, async (c) => { { metricsMap, isStreaming: stream, + includeProviderScoreRegions: false, promptTokens: routingPromptTokens, }, ); @@ -2657,7 +2630,10 @@ chat.openapi(completions, async (c) => { selectionReason = "fallback-first-available"; } - let routingMetadataProviders = allModelProviders; + let routingMetadataProviders = + selectionReason === "direct-provider-specified" + ? allModelProviders + : preferProviderRootMappings(allModelProviders); let directProviderRegionWasExplicit = false; if ( @@ -2752,6 +2728,8 @@ chat.openapi(completions, async (c) => { { metricsMap, isStreaming: stream, + includeProviderScoreRegions: + selectionReason === "direct-provider-specified", promptTokens: routingPromptTokens, }, ); @@ -2781,13 +2759,18 @@ chat.openapi(completions, async (c) => { throughput: metrics?.throughput ?? 0, }; }); + const includeRoutingScoreRegions = + selectionReason === "direct-provider-specified"; routingMetadata = addContentFilterRoutingMetadata( { availableProviders: routingMetadataProviders.map((p) => p.providerId), selectedProvider: usedProvider, selectionReason, - providerScores: allProviderScores, + providerScores: allProviderScores.map((score) => ({ + ...score, + region: includeRoutingScoreRegions ? score.region : undefined, + })), ...getNoFallbackRoutingMetadata(noFallback, xNoFallbackHeaderSet), }, contentFilterMatched, @@ -2864,9 +2847,13 @@ chat.openapi(completions, async (c) => { // Create the model mapping values according to new schema let usedModelMapping = usedModel; // Store the original provider model name + const includeUsedModelRegion = + routingMetadata?.selectionReason === "direct-provider-specified"; let usedModelFormatted = formatUsedModelForDisplay( usedProvider, - usedRegion ? `${baseModelName}:${usedRegion}` : baseModelName, + includeUsedModelRegion && usedRegion + ? `${baseModelName}:${usedRegion}` + : baseModelName, customProviderName, ); // Store in LLMGateway format @@ -3396,7 +3383,7 @@ chat.openapi(completions, async (c) => { ); // If region is still unset but the provider supports regions, resolve the - // default region so it appears in logs and metadata. + // default region for request execution. if (!usedRegion) { const providerDef = providers.find((p) => p.id === usedProvider) as | { regionConfig?: { defaultRegion: string } } @@ -3407,7 +3394,7 @@ chat.openapi(completions, async (c) => { } // Re-compute usedModelFormatted now that region may have been resolved - if (usedRegion) { + if (includeUsedModelRegion && usedRegion) { usedModelFormatted = formatUsedModelForDisplay( usedProvider, `${baseModelName}:${usedRegion}`, diff --git a/apps/gateway/src/fallback.spec.ts b/apps/gateway/src/fallback.spec.ts index 507bd92e8..025c4edcb 100644 --- a/apps/gateway/src/fallback.spec.ts +++ b/apps/gateway/src/fallback.spec.ts @@ -1045,16 +1045,17 @@ describe("fallback and error status code handling", () => { const logs = await waitForLogs(1); expect(logs).toHaveLength(1); expect(logs[0].usedProvider).toBe("alibaba"); - expect(logs[0].usedModel).toBe("alibaba/glm-4.6:cn-beijing"); + expect(logs[0].usedModel).toBe("alibaba/glm-4.6"); expect(logs[0].routingMetadata?.selectedProvider).toBe("alibaba"); expect(logs[0].routingMetadata?.selectionReason).toBe( "low-uptime-fallback", ); - expect( - logs[0].routingMetadata?.providerScores?.some( - (score) => score.providerId === "alibaba" && !score.region, - ), - ).toBe(false); + const alibabaScores = + logs[0].routingMetadata?.providerScores?.filter( + (score) => score.providerId === "alibaba", + ) ?? []; + expect(alibabaScores).toHaveLength(1); + expect(alibabaScores[0]?.region).toBeUndefined(); expect( logs[0].routingMetadata?.providerScores?.some( (score) => @@ -1172,12 +1173,13 @@ describe("fallback and error status code handling", () => { logs.find((entry) => entry.requestedModel === "glm-4.6") ?? logs.at(-1); expect(log).toBeTruthy(); expect(log?.usedProvider).toBe("alibaba"); - expect(log?.usedModel).toBe("alibaba/glm-4.6:cn-beijing"); - expect( - log?.routingMetadata?.providerScores?.some( - (score) => score.providerId === "alibaba" && !score.region, - ), - ).toBe(false); + expect(log?.usedModel).toBe("alibaba/glm-4.6"); + const alibabaScores = + log?.routingMetadata?.providerScores?.filter( + (score) => score.providerId === "alibaba", + ) ?? []; + expect(alibabaScores).toHaveLength(1); + expect(alibabaScores[0]?.region).toBeUndefined(); expect( log?.routingMetadata?.providerScores?.some( (score) => @@ -1465,13 +1467,19 @@ describe("fallback and error status code handling", () => { expect(res.status).toBe(200); const logs = await waitForLogs(1); - expect(logs[0].routingMetadata?.providerScores).toContainEqual( + expect(logs[0].usedModel).toBe("alibaba/deepseek-v3.2"); + const alibabaScores = + logs[0].routingMetadata?.providerScores?.filter( + (score) => score.providerId === "alibaba", + ) ?? []; + expect(alibabaScores).toHaveLength(1); + expect(alibabaScores[0]).toEqual( expect.objectContaining({ providerId: "alibaba", - region: "cn-beijing", score: expect.any(Number), }), ); + expect(alibabaScores[0]?.region).toBeUndefined(); expect( logs[0].routingMetadata?.providerScores?.some( (score) => diff --git a/packages/actions/src/get-cheapest-from-available-providers.ts b/packages/actions/src/get-cheapest-from-available-providers.ts index f23dd3f4b..33d31c522 100644 --- a/packages/actions/src/get-cheapest-from-available-providers.ts +++ b/packages/actions/src/get-cheapest-from-available-providers.ts @@ -173,6 +173,7 @@ export interface ProviderSelectionOptions { metricsMap?: Map; isStreaming?: boolean; videoPricing?: VideoPricingContext; + includeProviderScoreRegions?: boolean; /** * Estimated prompt tokens for the request. When provided and at or above * CACHE_PROMPT_TOKEN_THRESHOLD, cache support is factored into the @@ -335,6 +336,8 @@ export function getCheapestFromAvailableProviders< const metricsMap = options?.metricsMap; const isStreaming = options?.isStreaming ?? false; const videoPricing = options?.videoPricing; + const includeProviderScoreRegions = + options?.includeProviderScoreRegions ?? true; const promptTokens = options?.promptTokens; // Use higher price weight for image generation models const isImageModel = modelWithPricing.output?.includes("image") ?? false; @@ -416,7 +419,12 @@ export function getCheapestFromAvailableProviders< // If no metrics provided, fall back to price-only selection if (!metricsMap || metricsMap.size === 0) { - return selectByPriceOnly(stableProviders, modelWithPricing, videoPricing); + return selectByPriceOnly( + stableProviders, + modelWithPricing, + videoPricing, + includeProviderScoreRegions, + ); } // Calculate scores for each provider @@ -559,7 +567,7 @@ export function getCheapestFromAvailableProviders< const priority = providerDef?.priority ?? 1; return { providerId: p.provider.providerId, - region: p.provider.region, + region: includeProviderScoreRegions ? p.provider.region : undefined, score: Number(p.score.toFixed(3)), uptime: p.uptime, latency: p.latency, @@ -584,6 +592,7 @@ function selectByPriceOnly( stableProviders: T[], modelWithPricing: ModelWithPricing & { id: string; output?: string[] }, videoPricing?: VideoPricingContext, + includeProviderScoreRegions = true, ): ProviderSelectionResult { let cheapestProvider = stableProviders[0]; let lowestEffectivePrice = Number.MAX_VALUE; @@ -628,7 +637,7 @@ function selectByPriceOnly( selectionReason: "price-only-no-metrics", providerScores: providerPrices.map((p) => ({ providerId: p.providerId, - region: p.region, + region: includeProviderScoreRegions ? p.region : undefined, score: 0, price: p.price, priority: p.priority, diff --git a/packages/actions/src/models.spec.ts b/packages/actions/src/models.spec.ts index 64634e31d..c848e8120 100644 --- a/packages/actions/src/models.spec.ts +++ b/packages/actions/src/models.spec.ts @@ -630,6 +630,50 @@ describe("getCheapestFromAvailableProviders", () => { expect(avalancheScore?.price).toBeCloseTo(2.56); }); + it("should omit provider score regions when disabled", () => { + const regionalProviders: ProviderModelMapping[] = [ + { + providerId: "alibaba", + modelName: "deepseek-v3.2", + region: "singapore", + inputPrice: 2, + outputPrice: 2, + streaming: true, + }, + { + providerId: "alibaba", + modelName: "deepseek-v3.2", + region: "cn-beijing", + inputPrice: 1, + outputPrice: 1, + streaming: true, + }, + ]; + + const result = getCheapestFromAvailableProviders( + regionalProviders, + { + id: "deepseek-v3.2", + output: ["text"], + providers: regionalProviders, + }, + { + includeProviderScoreRegions: false, + }, + ); + + expect(result?.metadata.providerScores).toEqual([ + expect.objectContaining({ + providerId: "alibaba", + region: undefined, + }), + expect.objectContaining({ + providerId: "alibaba", + region: undefined, + }), + ]); + }); + it("should disable random exploration for vitest processes", () => { const videoModel = models.find( (model) => model.id === "veo-3.1-generate-preview",