diff --git a/apps/gateway/src/api.spec.ts b/apps/gateway/src/api.spec.ts index 097eaeb701..dc92829b7b 100644 --- a/apps/gateway/src/api.spec.ts +++ b/apps/gateway/src/api.spec.ts @@ -2783,11 +2783,11 @@ describe("api", () => { expect(logs[0].hasError).toBe(true); expect(logs[0].errorDetails?.statusCode).toBe(401); - expect(isTrackedKeyHealthy("provider-key-id-stream-auth-error")).toBe( - false, - ); expect( - getTrackedKeyMetrics("provider-key-id-stream-auth-error"), + isTrackedKeyHealthy("provider-key-id-stream-auth-error", "custom"), + ).toBe(false); + expect( + getTrackedKeyMetrics("provider-key-id-stream-auth-error", "custom"), ).toMatchObject({ permanentlyBlacklisted: true, totalRequests: 1, diff --git a/apps/gateway/src/chat/chat.ts b/apps/gateway/src/chat/chat.ts index 68ad2d2b83..8df77dcf7c 100644 --- a/apps/gateway/src/chat/chat.ts +++ b/apps/gateway/src/chat/chat.ts @@ -99,7 +99,7 @@ import { type WebSearchTool, expandAllProviderRegions, getProviderDefinition, - getRegionSpecificEnvValue, + getRegionSpecificEnvVarName, stripRegionFromModelName, } from "@llmgateway/models"; @@ -156,6 +156,7 @@ import { MAX_RETRIES, providerRetryKey, selectNextProvider, + shouldRetryAlternateKey, shouldRetryRequest, } from "./tools/retry-with-fallback.js"; import { @@ -1585,7 +1586,6 @@ chat.openapi(completions, async (c) => { const customProviderKey = await findCustomProviderKey( project.organizationId, customProviderName, - requestId, ); if (!customProviderKey) { throw new HTTPException(400, { @@ -1968,7 +1968,7 @@ chat.openapi(completions, async (c) => { const providerKey = await findProviderKey( project.organizationId, usedProvider, - requestId, + modelInfo.id || stripRegionFromModelName(usedModel, usedRegion), ); lockedRegion = providerKey ? 
resolveExplicitRegionFromProviderKey(providerKey) @@ -2668,7 +2668,7 @@ chat.openapi(completions, async (c) => { const providerKey = await findProviderKey( project.organizationId, requestedProvider, - requestId, + modelInfo.id || stripRegionFromModelName(usedModel, usedRegion), ); explicitDirectRegion = providerKey ? resolveExplicitRegionFromProviderKey(providerKey) @@ -2897,6 +2897,12 @@ chat.openapi(completions, async (c) => { let usedApiKeyHash: string | undefined; let configIndex = 0; // Index for round-robin environment variables let envVarName: string | undefined; // Environment variable name for health tracking + // ID for tracked-key health attribution. Equal to providerKey.id when the + // DB-provided key is what's actually sent. Cleared when a region-specific + // env var override replaces the token, so health failures route to the env + // credential via envVarName instead of blaming an unused DB key. Endpoint + // and option resolution still use providerKey for BYOK base URLs/options. + let trackedKeyHealthId: string | undefined; if ( project.mode === "credits" && (usedProvider === "custom" || usedProvider === "llmgateway") @@ -2913,13 +2919,13 @@ chat.openapi(completions, async (c) => { providerKey = await findCustomProviderKey( project.organizationId, customProviderName, - requestId, + baseModelName, ); } else { providerKey = await findProviderKey( project.organizationId, usedProvider, - requestId, + baseModelName, ); } @@ -2934,12 +2940,25 @@ chat.openapi(completions, async (c) => { } usedToken = providerKey.token; + trackedKeyHealthId = providerKey.id; usedRegion ??= resolveRegionFromProviderKey(providerKey); - // Override with region-specific env var if the DB key doesn't match the requested region + // Override with region-specific env var if the DB key doesn't match the requested region. + // When we do override, route health attribution to the regional env credential. 
+ // providerKey stays set so endpoint/options/baseUrl construction keeps the BYOK context; + // only trackedKeyHealthId is cleared so reportTrackedKey* doesn't blame the unused DB key. if (usedRegion) { - const regionToken = getRegionSpecificEnvValue(usedProvider, usedRegion); - if (regionToken && regionToken !== usedToken) { - usedToken = regionToken; + const regionEnvVarName = getRegionSpecificEnvVarName( + usedProvider, + usedRegion, + ); + if (regionEnvVarName) { + const regionToken = process.env[regionEnvVarName]; + if (regionToken && regionToken !== usedToken) { + usedToken = regionToken; + envVarName = regionEnvVarName; + configIndex = 0; + trackedKeyHealthId = undefined; + } } } } else if (project.mode === "credits") { @@ -2979,16 +2998,27 @@ chat.openapi(completions, async (c) => { }); } - const envResult = getProviderEnv(usedProvider); + const envResult = getProviderEnv(usedProvider, { + selectionScope: baseModelName, + }); usedToken = envResult.token; configIndex = envResult.configIndex; envVarName = envResult.envVarName; - // Override with region-specific env var if a non-default region is selected + // Override with region-specific env var if a non-default region is selected. + // Health attribution must follow the credential we actually send. 
if (usedRegion) { - const regionToken = getRegionSpecificEnvValue(usedProvider, usedRegion); - if (regionToken) { - usedToken = regionToken; + const regionEnvVarName = getRegionSpecificEnvVarName( + usedProvider, + usedRegion, + ); + if (regionEnvVarName) { + const regionToken = process.env[regionEnvVarName]; + if (regionToken) { + usedToken = regionToken; + envVarName = regionEnvVarName; + configIndex = 0; + } } } } else if (project.mode === "hybrid") { @@ -2997,24 +3027,36 @@ chat.openapi(completions, async (c) => { providerKey = await findCustomProviderKey( project.organizationId, customProviderName, - requestId, + baseModelName, ); } else { providerKey = await findProviderKey( project.organizationId, usedProvider, - requestId, + baseModelName, ); } if (providerKey) { usedToken = providerKey.token; + trackedKeyHealthId = providerKey.id; usedRegion ??= resolveRegionFromProviderKey(providerKey); - // Override with region-specific env var if the DB key doesn't match the requested region + // Override with region-specific env var if the DB key doesn't match the requested region. + // Route health attribution to the env credential while keeping providerKey for + // endpoint/options resolution (BYOK base URLs and provider options). 
if (usedRegion) { - const regionToken = getRegionSpecificEnvValue(usedProvider, usedRegion); - if (regionToken && regionToken !== usedToken) { - usedToken = regionToken; + const regionEnvVarName = getRegionSpecificEnvVarName( + usedProvider, + usedRegion, + ); + if (regionEnvVarName) { + const regionToken = process.env[regionEnvVarName]; + if (regionToken && regionToken !== usedToken) { + usedToken = regionToken; + envVarName = regionEnvVarName; + configIndex = 0; + trackedKeyHealthId = undefined; + } } } } else { @@ -3053,16 +3095,27 @@ chat.openapi(completions, async (c) => { }); } - const envResult = getProviderEnv(usedProvider); + const envResult = getProviderEnv(usedProvider, { + selectionScope: baseModelName, + }); usedToken = envResult.token; configIndex = envResult.configIndex; envVarName = envResult.envVarName; - // Override with region-specific env var if a non-default region is selected + // Override with region-specific env var if a non-default region is selected. + // Health attribution must follow the credential we actually send. 
if (usedRegion) { - const regionToken = getRegionSpecificEnvValue(usedProvider, usedRegion); - if (regionToken) { - usedToken = regionToken; + const regionEnvVarName = getRegionSpecificEnvVarName( + usedProvider, + usedRegion, + ); + if (regionEnvVarName) { + const regionToken = process.env[regionEnvVarName]; + if (regionToken) { + usedToken = regionToken; + envVarName = regionEnvVarName; + configIndex = 0; + } } } } @@ -4184,6 +4237,7 @@ chat.openapi(completions, async (c) => { usedToken = ctx.usedToken; usedApiKeyHash = ctx.usedApiKeyHash; providerKey = ctx.providerKey; + trackedKeyHealthId = ctx.trackedKeyHealthId; configIndex = ctx.configIndex; envVarName = ctx.envVarName; url = ctx.url; @@ -5047,10 +5101,21 @@ chat.openapi(completions, async (c) => { // Report key health for the selected token source if (envVarName !== undefined) { - reportKeyError(envVarName, configIndex, 0); + reportKeyError( + envVarName, + configIndex, + 0, + undefined, + baseModelName, + ); } - if (providerKey?.id) { - reportTrackedKeyError(providerKey.id, 0); + if (trackedKeyHealthId) { + reportTrackedKeyError( + trackedKeyHealthId, + 0, + undefined, + baseModelName, + ); } if (willRetrySameProvider && sameProviderRetryContext) { @@ -5159,7 +5224,13 @@ chat.openapi(completions, async (c) => { let sameProviderRetryContext: Awaited< ReturnType > | null = null; - if (isRetryableErrorType(finishReason)) { + if ( + shouldRetryAlternateKey( + finishReason, + res.status, + errorResponseText, + ) + ) { rememberFailedKey(usedProvider, usedRegion, { envVarName, configIndex, @@ -5283,13 +5354,15 @@ chat.openapi(completions, async (c) => { configIndex, res.status, errorResponseText, + baseModelName, ); } - if (providerKey?.id && finishReason !== "content_filter") { + if (trackedKeyHealthId && finishReason !== "content_filter") { reportTrackedKeyError( - providerKey.id, + trackedKeyHealthId, res.status, errorResponseText, + baseModelName, ); } @@ -5428,7 +5501,13 @@ chat.openapi(completions, async 
(c) => { let sameProviderRetryContext: Awaited< ReturnType > | null = null; - if (isRetryableErrorType(errorType)) { + if ( + shouldRetryAlternateKey( + errorType, + inferredStatusCode, + errorResponseText, + ) + ) { rememberFailedKey(usedProvider, usedRegion, { envVarName, configIndex, @@ -5537,13 +5616,15 @@ chat.openapi(completions, async (c) => { configIndex, inferredStatusCode, errorResponseText, + baseModelName, ); } - if (providerKey?.id && errorType !== "content_filter") { + if (trackedKeyHealthId && errorType !== "content_filter") { reportTrackedKeyError( - providerKey.id, + trackedKeyHealthId, inferredStatusCode, errorResponseText, + baseModelName, ); } @@ -7937,16 +8018,27 @@ chat.openapi(completions, async (c) => { // Report key health for the selected token source if (envVarName !== undefined) { if (streamingError !== null) { - reportKeyError(envVarName, configIndex, streamingErrorStatusCode); + reportKeyError( + envVarName, + configIndex, + streamingErrorStatusCode, + undefined, + baseModelName, + ); } else { - reportKeySuccess(envVarName, configIndex); + reportKeySuccess(envVarName, configIndex, baseModelName); } } - if (providerKey?.id) { + if (trackedKeyHealthId) { if (streamingError !== null) { - reportTrackedKeyError(providerKey.id, streamingErrorStatusCode); + reportTrackedKeyError( + trackedKeyHealthId, + streamingErrorStatusCode, + undefined, + baseModelName, + ); } else { - reportTrackedKeySuccess(providerKey.id); + reportTrackedKeySuccess(trackedKeyHealthId, baseModelName); } } @@ -8290,10 +8382,10 @@ chat.openapi(completions, async (c) => { // Report key health for the selected token source if (envVarName !== undefined) { - reportKeyError(envVarName, configIndex, 0); + reportKeyError(envVarName, configIndex, 0, undefined, baseModelName); } - if (providerKey?.id) { - reportTrackedKeyError(providerKey.id, 0); + if (trackedKeyHealthId) { + reportTrackedKeyError(trackedKeyHealthId, 0, undefined, baseModelName); } if (willRetrySameProvider && 
sameProviderRetryContext) { @@ -8649,7 +8741,9 @@ chat.openapi(completions, async (c) => { let sameProviderRetryContext: Awaited< ReturnType > | null = null; - if (isRetryableErrorType(finishReason)) { + if ( + shouldRetryAlternateKey(finishReason, res.status, errorResponseText) + ) { rememberFailedKey(usedProvider, usedRegion, { envVarName, configIndex, @@ -8774,10 +8868,21 @@ chat.openapi(completions, async (c) => { // Report key health for the selected token source // Don't report content_filter as a key error - it's intentional provider behavior if (envVarName !== undefined && finishReason !== "content_filter") { - reportKeyError(envVarName, configIndex, res.status, errorResponseText); + reportKeyError( + envVarName, + configIndex, + res.status, + errorResponseText, + baseModelName, + ); } - if (providerKey?.id && finishReason !== "content_filter") { - reportTrackedKeyError(providerKey.id, res.status, errorResponseText); + if (trackedKeyHealthId && finishReason !== "content_filter") { + reportTrackedKeyError( + trackedKeyHealthId, + res.status, + errorResponseText, + baseModelName, + ); } if (willRetrySameProvider && sameProviderRetryContext) { @@ -9668,10 +9773,10 @@ chat.openapi(completions, async (c) => { // Report key health for the selected token source // Note: We don't report empty responses as key errors since they're not upstream errors if (envVarName !== undefined) { - reportKeySuccess(envVarName, configIndex); + reportKeySuccess(envVarName, configIndex, baseModelName); } - if (providerKey?.id) { - reportTrackedKeySuccess(providerKey.id); + if (trackedKeyHealthId) { + reportTrackedKeySuccess(trackedKeyHealthId, baseModelName); } if (cachingEnabled && cacheKey && !stream && !hasEmptyNonStreamingResponse) { diff --git a/apps/gateway/src/chat/tools/get-finish-reason-from-error.spec.ts b/apps/gateway/src/chat/tools/get-finish-reason-from-error.spec.ts index 280bfc9baf..92a90274ab 100644 --- 
a/apps/gateway/src/chat/tools/get-finish-reason-from-error.spec.ts +++ b/apps/gateway/src/chat/tools/get-finish-reason-from-error.spec.ts @@ -140,6 +140,30 @@ describe("getFinishReasonFromError", () => { expect(getFinishReasonFromError(403)).toBe("gateway_error"); }); + it("returns gateway_error for 400 invalid API key payloads", () => { + expect( + getFinishReasonFromError( + 400, + '{"error":{"message":"API key not valid. Please pass a valid API key.","type":"authentication_error","code":"invalid_api_key"}}', + ), + ).toBe("gateway_error"); + }); + + it("returns gateway_error for invalid_api_key code only", () => { + expect( + getFinishReasonFromError( + 400, + '{"error":{"message":"Some unfamiliar wording","code":"invalid_api_key"}}', + ), + ).toBe("gateway_error"); + }); + + it("returns gateway_error for 'Incorrect API key provided' wording", () => { + expect( + getFinishReasonFromError(401, "Incorrect API key provided: sk-test***"), + ).toBe("gateway_error"); + }); + it("returns client_error when no error text provided for other 4xx", () => { expect(getFinishReasonFromError(400)).toBe("client_error"); expect(getFinishReasonFromError(422)).toBe("client_error"); diff --git a/apps/gateway/src/chat/tools/get-finish-reason-from-error.ts b/apps/gateway/src/chat/tools/get-finish-reason-from-error.ts index 5cf2608a5d..ae63d95e42 100644 --- a/apps/gateway/src/chat/tools/get-finish-reason-from-error.ts +++ b/apps/gateway/src/chat/tools/get-finish-reason-from-error.ts @@ -1,3 +1,5 @@ +import { hasInvalidProviderCredentialError } from "@/lib/provider-auth-errors.js"; + /** * Determines the appropriate finish reason based on HTTP status code and error message * 5xx status codes indicate upstream provider errors @@ -66,8 +68,12 @@ export function getFinishReasonFromError( return "content_filter"; } - // 401/403 usually indicate invalid or unauthorized provider credentials - if (statusCode === 401 || statusCode === 403) { + // 401/403 and known provider credential payloads 
indicate bad provider keys. + if ( + statusCode === 401 || + statusCode === 403 || + hasInvalidProviderCredentialError(errorText) + ) { return "gateway_error"; } diff --git a/apps/gateway/src/chat/tools/get-provider-env.spec.ts b/apps/gateway/src/chat/tools/get-provider-env.spec.ts index 8e0d5a9dbd..3ed9428887 100644 --- a/apps/gateway/src/chat/tools/get-provider-env.spec.ts +++ b/apps/gateway/src/chat/tools/get-provider-env.spec.ts @@ -1,5 +1,6 @@ import { afterEach, beforeEach, describe, expect, it } from "vitest"; +import { reportKeyError, resetKeyHealth } from "@/lib/api-key-health.js"; import { resetRoundRobinCounters } from "@/lib/round-robin-env.js"; import { getProviderEnv } from "./get-provider-env.js"; @@ -9,6 +10,7 @@ describe("getProviderEnv", () => { beforeEach(() => { resetRoundRobinCounters(); + resetKeyHealth(); process.env.LLM_OPENAI_API_KEY = "sk-openai-a,sk-openai-b,sk-openai-c"; }); @@ -55,4 +57,20 @@ describe("getProviderEnv", () => { expect(thirdKey.token).toBe("sk-openai-c"); expect(thirdKey.configIndex).toBe(2); }); + + it("passes selection scope through to env key health", () => { + reportKeyError("LLM_OPENAI_API_KEY", 0, 500, undefined, "gpt-4"); + reportKeyError("LLM_OPENAI_API_KEY", 0, 500, undefined, "gpt-4"); + reportKeyError("LLM_OPENAI_API_KEY", 0, 500, undefined, "gpt-4"); + + const gpt4Selection = getProviderEnv("openai", { + selectionScope: "gpt-4", + }); + const claudeSelection = getProviderEnv("openai", { + selectionScope: "claude-3-5-sonnet", + }); + + expect(gpt4Selection.configIndex).toBe(1); + expect(claudeSelection.configIndex).toBe(0); + }); }); diff --git a/apps/gateway/src/chat/tools/get-provider-env.ts b/apps/gateway/src/chat/tools/get-provider-env.ts index 9bda2c3069..08e67e18bd 100644 --- a/apps/gateway/src/chat/tools/get-provider-env.ts +++ b/apps/gateway/src/chat/tools/get-provider-env.ts @@ -20,6 +20,7 @@ export interface ProviderEnvResult { interface GetProviderEnvOptions { advanceRoundRobin?: boolean; 
excludedIndices?: ReadonlySet; + selectionScope?: string; } /** @@ -62,9 +63,10 @@ export function getProviderEnv( const advanceRoundRobin = options.advanceRoundRobin ?? true; const excludedIndices = options.excludedIndices; + const selectionScope = options.selectionScope; const result = advanceRoundRobin - ? getRoundRobinValue(envVar, envValue, excludedIndices) - : peekRoundRobinValue(envVar, envValue, excludedIndices); + ? getRoundRobinValue(envVar, envValue, selectionScope, excludedIndices) + : peekRoundRobinValue(envVar, envValue, selectionScope, excludedIndices); return { token: result.value, configIndex: result.index, envVarName: envVar }; } diff --git a/apps/gateway/src/chat/tools/resolve-provider-context.ts b/apps/gateway/src/chat/tools/resolve-provider-context.ts index 707cec8440..d9c50edd58 100644 --- a/apps/gateway/src/chat/tools/resolve-provider-context.ts +++ b/apps/gateway/src/chat/tools/resolve-provider-context.ts @@ -40,6 +40,13 @@ export interface ProviderContext { usedToken: string; usedApiKeyHash: string; providerKey: InferSelectModel | undefined; + /** + * Provider-key id to attribute health failures to via reportTrackedKey*. + * Equal to `providerKey.id` when the BYOK key is the credential actually + * sent, undefined when a regional env-var override replaces the token + * (in which case `envVarName` carries the health attribution). 
+ */ + trackedKeyHealthId: string | undefined; configIndex: number; envVarName: string | undefined; url: string; @@ -201,14 +208,14 @@ export async function resolveProviderContext( providerKey = await findCustomProviderKey( project.organizationId, options.customProviderName, - options.requestId, + baseModelName, options.excludedProviderKeyIds, ); } else { providerKey = await findProviderKey( project.organizationId, usedProvider, - options.requestId, + baseModelName, options.excludedProviderKeyIds, ); } @@ -224,6 +231,7 @@ export async function resolveProviderContext( assertOrganizationHasCreditsForEnvFallback(organization, modelInfo); const envResult = getProviderEnv(usedProvider as Provider, { excludedIndices: options.excludedEnvKeyIndices, + selectionScope: baseModelName, }); usedToken = envResult.token; configIndex = envResult.configIndex; @@ -233,14 +241,14 @@ export async function resolveProviderContext( providerKey = await findCustomProviderKey( project.organizationId, options.customProviderName, - options.requestId, + baseModelName, options.excludedProviderKeyIds, ); } else { providerKey = await findProviderKey( project.organizationId, usedProvider, - options.requestId, + baseModelName, options.excludedProviderKeyIds, ); } @@ -251,6 +259,7 @@ export async function resolveProviderContext( assertOrganizationHasCreditsForEnvFallback(organization, modelInfo); const envResult = getProviderEnv(usedProvider as Provider, { excludedIndices: options.excludedEnvKeyIndices, + selectionScope: baseModelName, }); usedToken = envResult.token; configIndex = envResult.configIndex; @@ -492,6 +501,7 @@ export async function resolveProviderContext( usedToken, usedApiKeyHash, providerKey, + trackedKeyHealthId: providerKey?.id, configIndex, envVarName, url, diff --git a/apps/gateway/src/chat/tools/retry-with-fallback.spec.ts b/apps/gateway/src/chat/tools/retry-with-fallback.spec.ts index e72050f479..2f894b9679 100644 --- a/apps/gateway/src/chat/tools/retry-with-fallback.spec.ts 
+++ b/apps/gateway/src/chat/tools/retry-with-fallback.spec.ts @@ -2,6 +2,7 @@ import { describe, it, expect } from "vitest"; import { isRetryableErrorType, + shouldRetryAlternateKey, shouldRetryRequest, selectNextProvider, getErrorType, @@ -125,6 +126,33 @@ describe("shouldRetryRequest", () => { }); }); +describe("shouldRetryAlternateKey", () => { + it("retries alternate keys for retryable upstream failures", () => { + expect(shouldRetryAlternateKey("upstream_error", 500)).toBe(true); + expect(shouldRetryAlternateKey("network_error", 0)).toBe(true); + }); + + it("retries alternate keys for auth failures on the current provider", () => { + expect(shouldRetryAlternateKey("gateway_error", 401)).toBe(true); + expect(shouldRetryAlternateKey("gateway_error", 403)).toBe(true); + }); + + it("retries alternate keys for invalid API key payloads without 401/403", () => { + expect( + shouldRetryAlternateKey( + "gateway_error", + 400, + "API key not valid. Please pass a valid API key.", + ), + ).toBe(true); + }); + + it("does not retry alternate keys for non-retryable failure types", () => { + expect(shouldRetryAlternateKey("client_error", 400)).toBe(false); + expect(shouldRetryAlternateKey("content_filter", 403)).toBe(false); + }); +}); + describe("selectNextProvider", () => { const modelProviders = [ { providerId: "openai", modelName: "gpt-4o" }, diff --git a/apps/gateway/src/chat/tools/retry-with-fallback.ts b/apps/gateway/src/chat/tools/retry-with-fallback.ts index 5055469f08..cbbb7ea772 100644 --- a/apps/gateway/src/chat/tools/retry-with-fallback.ts +++ b/apps/gateway/src/chat/tools/retry-with-fallback.ts @@ -1,3 +1,5 @@ +import { hasInvalidProviderCredentialError } from "@/lib/provider-auth-errors.js"; + export const MAX_RETRIES = 2; export type RetryableErrorType = @@ -33,6 +35,28 @@ export function isRetryableErrorType(errorType: string): boolean { ); } +/** + * Determines whether a failed request should be retried against another key + * for the same provider. 
+ * + * Credential failures (401/403, or an invalid-key payload with any status) + * are not eligible for cross-provider fallback, but should still rotate to + * another key for the current provider: they are often isolated to one key. + */ +export function shouldRetryAlternateKey( + errorType: string, + statusCode?: number, + errorText?: string, +): boolean { + return ( + isRetryableErrorType(errorType) || + (errorType === "gateway_error" && + (statusCode === 401 || + statusCode === 403 || + hasInvalidProviderCredentialError(errorText))) + ); +} + /** * Determines whether a failed request should be retried with a different provider. * Only retries when no specific provider was requested, the error is retryable, diff --git a/apps/gateway/src/fallback.spec.ts b/apps/gateway/src/fallback.spec.ts index 507bd92e82..7f10af3fba 100644 --- a/apps/gateway/src/fallback.spec.ts +++ b/apps/gateway/src/fallback.spec.ts @@ -14,6 +14,11 @@ import { getProviderDefinition } from "@llmgateway/models"; import { app } from "./app.js"; import { getApiKeyFingerprint } from "./lib/api-key-fingerprint.js"; +import { + isTrackedKeyHealthy, + reportTrackedKeyError, + resetKeyHealth, +} from "./lib/api-key-health.js"; import { startMockServer, stopMockServer, @@ -89,6 +94,7 @@ describe("fallback and error status code handling", () => { async function resetTestState() { resetFailOnceCounter(); + resetKeyHealth(); await clearCache(); await db.update(tables.modelProviderMapping).set({ routingUptime: null, @@ -238,6 +244,41 @@ describe("fallback and error status code handling", () => { ]); } + async function setupSingleProviderWithRegionalKeys(provider = "alibaba") { + await ensureBaseFixtures(); + + await db.insert(tables.apiKey).values({ + id: "token-id", + token: "real-token", + projectId: "project-id", + description: "Test API Key", + createdBy: "user-id", + }); + + await db.insert(tables.providerKey).values([ + { + id: `${provider}-key-singapore`, + token: 
`${provider}-singapore-token`, + provider, + organizationId: "org-id", + baseUrl: mockServerUrl, + options: { + alibaba_region: "singapore", + }, + }, + { + id: `${provider}-key-beijing`, + token: `${provider}-beijing-token`, + provider, + organizationId: "org-id", + baseUrl: mockServerUrl, + options: { + alibaba_region: "cn-beijing", + }, + }, + ]); + } + async function setRoutingMetrics( modelId: string, providerId: string, @@ -1432,6 +1473,57 @@ describe("fallback and error status code handling", () => { ]); }); + test("direct provider selection follows the scoped key region after failover", async () => { + await setupSingleProviderWithRegionalKeys("alibaba"); + await ensureRegionalMapping("deepseek-v3.2", "alibaba", "singapore"); + await ensureRegionalMapping("deepseek-v3.2", "alibaba", "cn-beijing"); + + reportTrackedKeyError( + "alibaba-key-singapore", + 500, + undefined, + "deepseek-v3.2", + ); + reportTrackedKeyError( + "alibaba-key-singapore", + 500, + undefined, + "deepseek-v3.2", + ); + reportTrackedKeyError( + "alibaba-key-singapore", + 500, + undefined, + "deepseek-v3.2", + ); + + const res = await app.request("/v1/chat/completions", { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: "Bearer real-token", + }, + body: JSON.stringify({ + model: "alibaba/deepseek-v3.2", + messages: [{ role: "user", content: "Hello!" 
}], + }), + }); + + expect(res.status).toBe(200); + + const logs = await waitForLogs(1); + expect(logs[0].usedModel).toBe("alibaba/deepseek-v3.2:cn-beijing"); + expect(logs[0].routingMetadata?.routing).toEqual([ + expect.objectContaining({ + provider: "alibaba", + model: "deepseek-v3.2", + region: "cn-beijing", + status_code: 200, + succeeded: true, + }), + ]); + }); + test("provider-agnostic routing keeps regional mappings aggregated", async () => { await setupKeys("alibaba"); @@ -2234,6 +2326,87 @@ describe("fallback and error status code handling", () => { }); }); + test("non-streaming: retries another key for auth failures on the same explicit provider", async () => { + await setupSingleProviderWithMultipleKeys("together-ai"); + + const res = await app.request("/v1/chat/completions", { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: "Bearer real-token", + }, + body: JSON.stringify({ + model: "together-ai/glm-4.7", + messages: [{ role: "user", content: "TRIGGER_STATUS_401" }], + }), + }); + + expect(res.status).toBe(500); + const json = await res.json(); + expect(json.error.type).toBe("gateway_error"); + + const logs = await waitForLogs(2); + const authLogs = logs.filter( + (log: Log) => log.errorDetails?.statusCode === 401, + ); + expect(authLogs).toHaveLength(2); + expect(authLogs.some((log: Log) => log.retried)).toBe(true); + expect(isTrackedKeyHealthy("together-ai-key-primary", "glm-4.7")).toBe( + false, + ); + expect(isTrackedKeyHealthy("together-ai-key-secondary", "glm-4.7")).toBe( + false, + ); + }); + + test("non-streaming: retries another key for invalid API key payloads", async () => { + await setupSingleProviderWithMultipleKeys("together-ai"); + + const res = await app.request("/v1/chat/completions", { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: "Bearer real-token", + }, + body: JSON.stringify({ + model: "together-ai/glm-4.7", + messages: [ + { role: "user", content: 
"TRIGGER_FAIL_ONCE_INVALID_KEY" }, + ], + }), + }); + + expect(res.status).toBe(200); + const json = await res.json(); + expect(json.metadata.routing).toHaveLength(2); + expect(json.metadata.routing[0]).toMatchObject({ + provider: "together-ai", + status_code: 400, + succeeded: false, + }); + expect(json.metadata.routing[1]).toMatchObject({ + provider: "together-ai", + succeeded: true, + }); + + const logs = await waitForLogs(2); + const failedLog = logs.find( + (log: Log) => log.errorDetails?.statusCode === 400, + ); + const successLog = logs.find( + (log: Log) => log.finishReason === "stop" || !log.hasError, + ); + expect(failedLog?.finishReason).toBe("gateway_error"); + expect(failedLog?.retried).toBe(true); + expect(successLog?.routingMetadata?.routing).toHaveLength(2); + expect(isTrackedKeyHealthy("together-ai-key-primary", "glm-4.7")).toBe( + false, + ); + expect(isTrackedKeyHealthy("together-ai-key-secondary", "glm-4.7")).toBe( + true, + ); + }); + test("streaming: retries on 500 and delivers response on fallback provider", async () => { await setupMultiProviderKeys(); @@ -2352,7 +2525,7 @@ describe("fallback and error status code handling", () => { "together-ai-secondary-token", ); const originalStreamingTimeout = process.env.AI_STREAMING_TIMEOUT_MS; - process.env.AI_STREAMING_TIMEOUT_MS = "10"; + process.env.AI_STREAMING_TIMEOUT_MS = "75"; try { const res = await app.request("/v1/chat/completions", { diff --git a/apps/gateway/src/graceful-shutdown.spec.ts b/apps/gateway/src/graceful-shutdown.spec.ts index 2fe17d3adc..f4de206f27 100644 --- a/apps/gateway/src/graceful-shutdown.spec.ts +++ b/apps/gateway/src/graceful-shutdown.spec.ts @@ -28,6 +28,18 @@ const closeServer = (server: ServerType): Promise => { }); }; +const waitForServerListening = (server: ServerType): Promise => { + const httpServer = server as Server; + + if (httpServer.listening) { + return Promise.resolve(); + } + + return new Promise((resolve) => { + httpServer.once("listening", () => 
resolve()); + }); +}; + // Recommended implementation: server.close() + periodic idle connection drain + timeout const closeServerWithTimeout = ( server: ServerType, @@ -109,6 +121,7 @@ describe("graceful shutdown", () => { }); const server = serve({ fetch: app.fetch, port: 0 }); + await waitForServerListening(server); const address = server.address(); const port = typeof address === "object" && address !== null ? address.port : 0; @@ -157,6 +170,7 @@ describe("graceful shutdown", () => { }); const server = serve({ fetch: app.fetch, port: 0 }); + await waitForServerListening(server); const address = server.address(); const port = typeof address === "object" && address !== null ? address.port : 0; @@ -200,6 +214,7 @@ describe("graceful shutdown", () => { app.get("/fast", (c) => c.json({ message: "ok" })); const server = serve({ fetch: app.fetch, port: 0 }); + await waitForServerListening(server); const address = server.address(); const port = typeof address === "object" && address !== null ? address.port : 0; @@ -238,6 +253,7 @@ describe("graceful shutdown", () => { }); const server = serve({ fetch: app.fetch, port: 0 }); + await waitForServerListening(server); const address = server.address(); const port = typeof address === "object" && address !== null ? address.port : 0; @@ -287,6 +303,7 @@ describe("graceful shutdown", () => { }); const server = serve({ fetch: app.fetch, port: 0 }); + await waitForServerListening(server); const address = server.address(); const port = typeof address === "object" && address !== null ? address.port : 0; @@ -333,6 +350,7 @@ describe("graceful shutdown", () => { }); const server = serve({ fetch: app.fetch, port: 0 }); + await waitForServerListening(server); const address = server.address(); const port = typeof address === "object" && address !== null ? 
address.port : 0; @@ -386,6 +404,7 @@ describe("graceful shutdown", () => { }); const server = serve({ fetch: app.fetch, port: 0 }); + await waitForServerListening(server); const address = server.address(); const port = typeof address === "object" && address !== null ? address.port : 0; diff --git a/apps/gateway/src/lib/api-key-health.spec.ts b/apps/gateway/src/lib/api-key-health.spec.ts index c7e98f28e6..6b8c587ab5 100644 --- a/apps/gateway/src/lib/api-key-health.spec.ts +++ b/apps/gateway/src/lib/api-key-health.spec.ts @@ -2,11 +2,14 @@ import { describe, it, expect, beforeEach } from "vitest"; import { isKeyHealthy, + isTrackedKeyHealthy, reportKeySuccess, reportKeyError, + reportTrackedKeyError, getKeyHealth, getKeyMetrics, getAllKeyMetrics, + getTrackedKeyMetrics, calculateUptimePenalty, resetKeyHealth, UPTIME_PENALTY_THRESHOLD, @@ -59,6 +62,85 @@ describe("api-key-health", () => { expect(isKeyHealthy("LLM_OPENAI_API_KEY", 1)).toBe(true); expect(isKeyHealthy("LLM_ANTHROPIC_API_KEY", 0)).toBe(true); }); + + it("should propagate auth blacklist across selection scopes", () => { + reportKeyError("LLM_OPENAI_API_KEY", 0, 401, undefined, "gpt-4"); + expect(isKeyHealthy("LLM_OPENAI_API_KEY", 0, "gpt-4")).toBe(false); + expect(isKeyHealthy("LLM_OPENAI_API_KEY", 0, "claude-3-5-sonnet")).toBe( + false, + ); + expect(isKeyHealthy("LLM_OPENAI_API_KEY", 0)).toBe(false); + expect(isKeyHealthy("LLM_OPENAI_API_KEY", 1, "gpt-4")).toBe(true); + }); + + it("should propagate invalid-key payload blacklist across scopes", () => { + reportKeyError( + "LLM_OPENAI_API_KEY", + 0, + 400, + "API key not valid. 
Please pass a valid API key.", + "gpt-4", + ); + expect(isKeyHealthy("LLM_OPENAI_API_KEY", 0, "gpt-4")).toBe(false); + expect(isKeyHealthy("LLM_OPENAI_API_KEY", 0, "claude-3-5-sonnet")).toBe( + false, + ); + }); + + it("should keep non-auth uptime penalties scoped per model", () => { + reportKeyError("LLM_OPENAI_API_KEY", 0, 500, undefined, "gpt-4"); + reportKeyError("LLM_OPENAI_API_KEY", 0, 500, undefined, "gpt-4"); + reportKeyError("LLM_OPENAI_API_KEY", 0, 500, undefined, "gpt-4"); + + expect(isKeyHealthy("LLM_OPENAI_API_KEY", 0, "gpt-4")).toBe(false); + expect(isKeyHealthy("LLM_OPENAI_API_KEY", 0, "claude-3-5-sonnet")).toBe( + true, + ); + }); + + it("should propagate tracked-key auth blacklist across scopes", () => { + reportTrackedKeyError("provider-key-1", 401, undefined, "gpt-4"); + expect(isTrackedKeyHealthy("provider-key-1", "gpt-4")).toBe(false); + expect(isTrackedKeyHealthy("provider-key-1", "claude-3-5-sonnet")).toBe( + false, + ); + expect(isTrackedKeyHealthy("provider-key-1")).toBe(false); + expect(isTrackedKeyHealthy("provider-key-2", "gpt-4")).toBe(true); + }); + + it("should surface unscoped blacklist in metrics for new scopes", () => { + reportTrackedKeyError("provider-key-1", 401, undefined, "gpt-4"); + const metrics = getTrackedKeyMetrics( + "provider-key-1", + "claude-3-5-sonnet", + ); + expect(metrics.permanentlyBlacklisted).toBe(true); + expect(metrics.uptime).toBe(0); + }); + + it("should report zero uptime for blacklisted keys even with prior successes in another scope", () => { + reportKeySuccess("LLM_OPENAI_API_KEY", 0, "gpt-4"); + reportKeySuccess("LLM_OPENAI_API_KEY", 0, "gpt-4"); + reportKeyError( + "LLM_OPENAI_API_KEY", + 0, + 401, + undefined, + "claude-3-5-sonnet", + ); + + const successScopeMetrics = getKeyMetrics( + "LLM_OPENAI_API_KEY", + 0, + "gpt-4", + ); + expect(successScopeMetrics.permanentlyBlacklisted).toBe(true); + expect(successScopeMetrics.uptime).toBe(0); + + const unscopedMetrics = getKeyMetrics("LLM_OPENAI_API_KEY", 
0); + expect(unscopedMetrics.permanentlyBlacklisted).toBe(true); + expect(unscopedMetrics.uptime).toBe(0); + }); }); describe("reportKeySuccess", () => { @@ -161,6 +243,22 @@ describe("api-key-health", () => { permanentlyBlacklisted: true, }); }); + + it("should permanently blacklist ignored 4xx with invalid key text", () => { + reportKeyError( + "LLM_OPENAI_API_KEY", + 0, + 400, + "API key not valid. Please pass a valid API key.", + ); + + expect(getKeyMetrics("LLM_OPENAI_API_KEY", 0)).toMatchObject({ + uptime: 0, + totalRequests: 1, + consecutiveErrors: 0, + permanentlyBlacklisted: true, + }); + }); }); describe("getKeyHealth", () => { diff --git a/apps/gateway/src/lib/api-key-health.ts b/apps/gateway/src/lib/api-key-health.ts index baed6e4b40..c03a4f0a84 100644 --- a/apps/gateway/src/lib/api-key-health.ts +++ b/apps/gateway/src/lib/api-key-health.ts @@ -1,3 +1,5 @@ +import { hasInvalidProviderCredentialError } from "./provider-auth-errors.js"; + /** * In-memory API key health tracking for uptime-aware routing * Tracks historical error rates per API key using a sliding window approach @@ -77,27 +79,59 @@ const PERMANENT_ERROR_CODES = [401, 403]; */ const UPTIME_RELEVANT_4XX_CODES = new Set([...PERMANENT_ERROR_CODES, 404, 429]); -/** - * Error messages that indicate permanent key issues - */ -const PERMANENT_ERROR_MESSAGES = [ - "API Key not found. Please pass a valid API key.", -]; - /** * Uptime threshold below which exponential penalty kicks in */ export const UPTIME_PENALTY_THRESHOLD = 95; +function appendSelectionScope( + baseKey: string, + selectionScope?: string, +): string { + return selectionScope ? 
`${baseKey}:${selectionScope}` : baseKey; +} + /** * Get the health key identifier for a specific API key */ -function getHealthKey(envVarName: string, keyIndex: number): string { - return `${envVarName}:${keyIndex}`; +function getHealthKey( + envVarName: string, + keyIndex: number, + selectionScope?: string, +): string { + return appendSelectionScope(`${envVarName}:${keyIndex}`, selectionScope); +} + +function getTrackedHealthKey(keyId: string, selectionScope?: string): string { + return appendSelectionScope(`tracked:${keyId}`, selectionScope); +} + +/** + * Auth validity is provider-wide, not scope-specific. A 401/403 or + * invalid-key payload on one model does not become valid on another. + * The permanent blacklist is therefore stored on an unscoped record so + * health checks for any future scope still see the key as bad. + */ +function markUnscopedPermanentBlacklist( + unscopedKey: string, + now: number, +): void { + let unscoped = keyHealthMap.get(unscopedKey); + if (!unscoped) { + unscoped = { + consecutiveErrors: 0, + lastErrorTime: 0, + permanentlyBlacklisted: false, + history: [], + }; + keyHealthMap.set(unscopedKey, unscoped); + } + unscoped.permanentlyBlacklisted = true; + unscoped.lastErrorTime = now; } -function getTrackedHealthKey(keyId: string): string { - return `tracked:${keyId}`; +function isUnscopedPermanentlyBlacklisted(unscopedKey: string): boolean { + return keyHealthMap.get(unscopedKey)?.permanentlyBlacklisted === true; } /** @@ -157,8 +191,16 @@ export function calculateUptimePenalty(uptime: number): number { * @param keyIndex The index of the key in the comma-separated list * @returns true if the key is healthy, false if it should be skipped */ -export function isKeyHealthy(envVarName: string, keyIndex: number): boolean { - const healthKey = getHealthKey(envVarName, keyIndex); +export function isKeyHealthy( + envVarName: string, + keyIndex: number, + selectionScope?: string, +): boolean { + if 
(isUnscopedPermanentlyBlacklisted(getHealthKey(envVarName, keyIndex))) { + return false; + } + + const healthKey = getHealthKey(envVarName, keyIndex, selectionScope); const health = keyHealthMap.get(healthKey); if (!health) { @@ -182,8 +224,15 @@ export function isKeyHealthy(envVarName: string, keyIndex: number): boolean { return true; } -export function isTrackedKeyHealthy(keyId: string): boolean { - const healthKey = getTrackedHealthKey(keyId); +export function isTrackedKeyHealthy( + keyId: string, + selectionScope?: string, +): boolean { + if (isUnscopedPermanentlyBlacklisted(getTrackedHealthKey(keyId))) { + return false; + } + + const healthKey = getTrackedHealthKey(keyId, selectionScope); const health = keyHealthMap.get(healthKey); if (!health) { @@ -215,51 +264,65 @@ export function isTrackedKeyHealthy(keyId: string): boolean { export function getKeyMetrics( envVarName: string, keyIndex: number, + selectionScope?: string, ): KeyMetrics { - const healthKey = getHealthKey(envVarName, keyIndex); + const unscopedBlacklisted = isUnscopedPermanentlyBlacklisted( + getHealthKey(envVarName, keyIndex), + ); + const healthKey = getHealthKey(envVarName, keyIndex, selectionScope); const health = keyHealthMap.get(healthKey); if (!health) { return { - uptime: 100, + uptime: unscopedBlacklisted ? 0 : 100, totalRequests: 0, consecutiveErrors: 0, - permanentlyBlacklisted: false, + permanentlyBlacklisted: unscopedBlacklisted, }; } const now = Date.now(); pruneHistory(health, now); + const permanentlyBlacklisted = + health.permanentlyBlacklisted || unscopedBlacklisted; return { - uptime: calculateUptime(health, now), + uptime: permanentlyBlacklisted ? 
0 : calculateUptime(health, now), totalRequests: health.history.length, consecutiveErrors: health.consecutiveErrors, - permanentlyBlacklisted: health.permanentlyBlacklisted, + permanentlyBlacklisted, }; } -export function getTrackedKeyMetrics(keyId: string): KeyMetrics { - const healthKey = getTrackedHealthKey(keyId); +export function getTrackedKeyMetrics( + keyId: string, + selectionScope?: string, +): KeyMetrics { + const unscopedBlacklisted = isUnscopedPermanentlyBlacklisted( + getTrackedHealthKey(keyId), + ); + const healthKey = getTrackedHealthKey(keyId, selectionScope); const health = keyHealthMap.get(healthKey); if (!health) { return { - uptime: 100, + uptime: unscopedBlacklisted ? 0 : 100, totalRequests: 0, consecutiveErrors: 0, - permanentlyBlacklisted: false, + permanentlyBlacklisted: unscopedBlacklisted, }; } const now = Date.now(); pruneHistory(health, now); + const permanentlyBlacklisted = + health.permanentlyBlacklisted || unscopedBlacklisted; return { - uptime: calculateUptime(health, now), + uptime: permanentlyBlacklisted ? 
0 : calculateUptime(health, now), totalRequests: health.history.length, consecutiveErrors: health.consecutiveErrors, - permanentlyBlacklisted: health.permanentlyBlacklisted, + permanentlyBlacklisted, }; } @@ -272,10 +335,11 @@ export function getTrackedKeyMetrics(keyId: string): KeyMetrics { export function getAllKeyMetrics( envVarName: string, keyCount: number, + selectionScope?: string, ): KeyMetrics[] { const metrics: KeyMetrics[] = []; for (let i = 0; i < keyCount; i++) { - metrics.push(getKeyMetrics(envVarName, i)); + metrics.push(getKeyMetrics(envVarName, i, selectionScope)); } return metrics; } @@ -284,8 +348,12 @@ export function getAllKeyMetrics( * Report a successful request for an API key * Resets the consecutive error counter and adds to history */ -export function reportKeySuccess(envVarName: string, keyIndex: number): void { - const healthKey = getHealthKey(envVarName, keyIndex); +export function reportKeySuccess( + envVarName: string, + keyIndex: number, + selectionScope?: string, +): void { + const healthKey = getHealthKey(envVarName, keyIndex, selectionScope); let health = keyHealthMap.get(healthKey); const now = Date.now(); @@ -309,8 +377,11 @@ export function reportKeySuccess(envVarName: string, keyIndex: number): void { pruneHistory(health, now); } -export function reportTrackedKeySuccess(keyId: string): void { - const healthKey = getTrackedHealthKey(keyId); +export function reportTrackedKeySuccess( + keyId: string, + selectionScope?: string, +): void { + const healthKey = getTrackedHealthKey(keyId, selectionScope); let health = keyHealthMap.get(healthKey); const now = Date.now(); @@ -344,8 +415,9 @@ export function reportKeyError( keyIndex: number, statusCode?: number, errorText?: string, + selectionScope?: string, ): void { - const healthKey = getHealthKey(envVarName, keyIndex); + const healthKey = getHealthKey(envVarName, keyIndex, selectionScope); let health = keyHealthMap.get(healthKey); const now = Date.now(); @@ -360,9 +432,7 @@ export 
function reportKeyError( keyHealthMap.set(healthKey, health); } - const isPermanentErrorMessage = - errorText !== undefined && - PERMANENT_ERROR_MESSAGES.some((msg) => errorText.includes(msg)); + const isPermanentErrorMessage = hasInvalidProviderCredentialError(errorText); // Most upstream 4xx responses are client-side request issues and should not // degrade provider uptime or influence routing decisions. @@ -376,19 +446,15 @@ export function reportKeyError( return; } - // Check for permanent auth errors by status code - if (statusCode && PERMANENT_ERROR_CODES.includes(statusCode)) { - health.permanentlyBlacklisted = true; - // Still add to history for metrics visibility - health.history.push({ timestamp: now, success: false }); - pruneHistory(health, now); - return; - } - - // Check for permanent auth errors by error message - if (isPermanentErrorMessage) { + // Check for permanent auth errors by status code or payload. Auth validity + // is provider-wide, so the blacklist is recorded on the unscoped record so + // future scopes (other models) also skip this key. 
+ if ( + (statusCode && PERMANENT_ERROR_CODES.includes(statusCode)) || + isPermanentErrorMessage + ) { + markUnscopedPermanentBlacklist(getHealthKey(envVarName, keyIndex), now); health.permanentlyBlacklisted = true; - // Still add to history for metrics visibility health.history.push({ timestamp: now, success: false }); pruneHistory(health, now); return; @@ -406,8 +472,9 @@ export function reportTrackedKeyError( keyId: string, statusCode?: number, errorText?: string, + selectionScope?: string, ): void { - const healthKey = getTrackedHealthKey(keyId); + const healthKey = getTrackedHealthKey(keyId, selectionScope); let health = keyHealthMap.get(healthKey); const now = Date.now(); @@ -422,9 +489,7 @@ export function reportTrackedKeyError( keyHealthMap.set(healthKey, health); } - const isPermanentErrorMessage = - errorText !== undefined && - PERMANENT_ERROR_MESSAGES.some((msg) => errorText.includes(msg)); + const isPermanentErrorMessage = hasInvalidProviderCredentialError(errorText); if ( statusCode !== undefined && @@ -436,14 +501,11 @@ export function reportTrackedKeyError( return; } - if (statusCode && PERMANENT_ERROR_CODES.includes(statusCode)) { - health.permanentlyBlacklisted = true; - health.history.push({ timestamp: now, success: false }); - pruneHistory(health, now); - return; - } - - if (isPermanentErrorMessage) { + if ( + (statusCode && PERMANENT_ERROR_CODES.includes(statusCode)) || + isPermanentErrorMessage + ) { + markUnscopedPermanentBlacklist(getTrackedHealthKey(keyId), now); health.permanentlyBlacklisted = true; health.history.push({ timestamp: now, success: false }); pruneHistory(health, now); @@ -462,8 +524,9 @@ export function reportTrackedKeyError( export function getKeyHealth( envVarName: string, keyIndex: number, + selectionScope?: string, ): KeyHealth | undefined { - return keyHealthMap.get(getHealthKey(envVarName, keyIndex)); + return keyHealthMap.get(getHealthKey(envVarName, keyIndex, selectionScope)); } /** diff --git 
a/apps/gateway/src/lib/cached-queries.spec.ts b/apps/gateway/src/lib/cached-queries.spec.ts index 897ffa6106..460733c787 100644 --- a/apps/gateway/src/lib/cached-queries.spec.ts +++ b/apps/gateway/src/lib/cached-queries.spec.ts @@ -273,6 +273,22 @@ describe("Cached Queries - Gateway Database Access", () => { expect(result?.id).toBe("test-provider-key-cached-queries-2"); }); + it("should keep tracked key health isolated per model scope", async () => { + reportTrackedKeyError(testProviderKeyId, 500, undefined, "gpt-4"); + reportTrackedKeyError(testProviderKeyId, 500, undefined, "gpt-4"); + reportTrackedKeyError(testProviderKeyId, 500, undefined, "gpt-4"); + + const gpt4Selection = await findProviderKey(testOrgId, "openai", "gpt-4"); + const claudeSelection = await findProviderKey( + testOrgId, + "openai", + "claude-3-5-sonnet", + ); + + expect(gpt4Selection?.id).toBe("test-provider-key-cached-queries-2"); + expect(claudeSelection?.id).toBe(testProviderKeyId); + }); + it("should select the next provider key when the current one is excluded", async () => { const result = await findProviderKey( testOrgId, diff --git a/apps/gateway/src/lib/cached-queries.ts b/apps/gateway/src/lib/cached-queries.ts index 2f64941798..039efacd5f 100644 --- a/apps/gateway/src/lib/cached-queries.ts +++ b/apps/gateway/src/lib/cached-queries.ts @@ -65,6 +65,7 @@ const userOrganizationTableName = getTableName(userOrganizationTable); function selectProviderKeyWithFailover( items: T[], + selectionScope?: string, excludedKeyIds: ReadonlySet = new Set(), ): T | undefined { const availableItems = items.filter((item) => !excludedKeyIds.has(item.id)); @@ -81,9 +82,9 @@ function selectProviderKeyWithFailover( .map((item, index) => ({ item, index, - metrics: getTrackedKeyMetrics(item.id), + metrics: getTrackedKeyMetrics(item.id, selectionScope), })) - .filter(({ item }) => isTrackedKeyHealthy(item.id)); + .filter(({ item }) => isTrackedKeyHealthy(item.id, selectionScope)); if (healthyItems.length === 0) 
{ return availableItems[0]; @@ -207,7 +208,7 @@ export async function findOrganizationById( export async function findCustomProviderKey( organizationId: string, customProviderName: string, - _selectionKey?: string, + selectionScope?: string, excludedKeyIds?: ReadonlySet, ): Promise { const results = await swrWrap( @@ -227,7 +228,7 @@ export async function findCustomProviderKey( ) .orderBy(asc(providerKeyTable.createdAt), asc(providerKeyTable.id)), ); - return selectProviderKeyWithFailover(results, excludedKeyIds); + return selectProviderKeyWithFailover(results, selectionScope, excludedKeyIds); } /** @@ -236,7 +237,7 @@ export async function findCustomProviderKey( export async function findProviderKey( organizationId: string, provider: string, - _selectionKey?: string, + selectionScope?: string, excludedKeyIds?: ReadonlySet, filter?: (key: ProviderKey) => boolean, ): Promise { @@ -257,7 +258,11 @@ export async function findProviderKey( .orderBy(asc(providerKeyTable.createdAt), asc(providerKeyTable.id)), ); const filtered = filter ? 
results.filter(filter) : results; - return selectProviderKeyWithFailover(filtered, excludedKeyIds); + return selectProviderKeyWithFailover( + filtered, + selectionScope, + excludedKeyIds, + ); } /** diff --git a/apps/gateway/src/lib/provider-auth-errors.ts b/apps/gateway/src/lib/provider-auth-errors.ts new file mode 100644 index 0000000000..36e77da10d --- /dev/null +++ b/apps/gateway/src/lib/provider-auth-errors.ts @@ -0,0 +1,17 @@ +const INVALID_PROVIDER_CREDENTIAL_PATTERNS = [ + /\binvalid_api_key\b/i, + /incorrect api key provided/i, + /api key not valid/i, + /api key not found/i, + /please pass a valid api key/i, +]; + +export function hasInvalidProviderCredentialError(errorText?: string): boolean { + if (!errorText) { + return false; + } + + return INVALID_PROVIDER_CREDENTIAL_PATTERNS.some((pattern) => + pattern.test(errorText), + ); +} diff --git a/apps/gateway/src/lib/round-robin-env.spec.ts b/apps/gateway/src/lib/round-robin-env.spec.ts index 5cfe84f981..acccbb8bf8 100644 --- a/apps/gateway/src/lib/round-robin-env.spec.ts +++ b/apps/gateway/src/lib/round-robin-env.spec.ts @@ -188,6 +188,26 @@ describe("round-robin-env", () => { expect(result2.index).toBe(0); }); + it("should keep env key health isolated per model scope", () => { + reportKeyError("TEST_VAR", 0, 500, undefined, "gpt-4"); + reportKeyError("TEST_VAR", 0, 500, undefined, "gpt-4"); + reportKeyError("TEST_VAR", 0, 500, undefined, "gpt-4"); + + const gpt4Selection = getRoundRobinValue( + "TEST_VAR", + "value1,value2", + "gpt-4", + ); + const claudeSelection = getRoundRobinValue( + "TEST_VAR", + "value1,value2", + "claude-3-5-sonnet", + ); + + expect(gpt4Selection.index).toBe(1); + expect(claudeSelection.index).toBe(0); + }); + it("should keep using the primary key when uptimes are identical", () => { // Key 0 and 1 both have 100% uptime for (let i = 0; i < 5; i++) { diff --git a/apps/gateway/src/lib/round-robin-env.ts b/apps/gateway/src/lib/round-robin-env.ts index 418ea7497f..322bb8b926 100644 --- 
a/apps/gateway/src/lib/round-robin-env.ts +++ b/apps/gateway/src/lib/round-robin-env.ts @@ -38,6 +38,7 @@ function selectRoundRobinValue( envVarName: string, value: string, _advanceCounter: boolean, + selectionScope?: string, excludedIndices: ReadonlySet = new Set(), ): RoundRobinResult { const values = parseCommaSeparatedEnv(value); @@ -69,7 +70,7 @@ function selectRoundRobinValue( continue; } - const metrics = getKeyMetrics(envVarName, i); + const metrics = getKeyMetrics(envVarName, i, selectionScope); // Skip permanently blacklisted keys entirely if (metrics.permanentlyBlacklisted) { @@ -77,7 +78,7 @@ function selectRoundRobinValue( } // Check if temporarily unhealthy (consecutive errors threshold) - if (!isKeyHealthy(envVarName, i)) { + if (!isKeyHealthy(envVarName, i, selectionScope)) { continue; } @@ -128,9 +129,16 @@ function selectRoundRobinValue( export function getRoundRobinValue( envVarName: string, value: string, + selectionScope?: string, excludedIndices?: ReadonlySet, ): RoundRobinResult { - return selectRoundRobinValue(envVarName, value, true, excludedIndices); + return selectRoundRobinValue( + envVarName, + value, + true, + selectionScope, + excludedIndices, + ); } /** @@ -140,9 +148,16 @@ export function getRoundRobinValue( export function peekRoundRobinValue( envVarName: string, value: string, + selectionScope?: string, excludedIndices?: ReadonlySet, ): RoundRobinResult { - return selectRoundRobinValue(envVarName, value, false, excludedIndices); + return selectRoundRobinValue( + envVarName, + value, + false, + selectionScope, + excludedIndices, + ); } /** diff --git a/apps/gateway/src/moderations/moderations.ts b/apps/gateway/src/moderations/moderations.ts index baa9f66839..8f05c843c1 100644 --- a/apps/gateway/src/moderations/moderations.ts +++ b/apps/gateway/src/moderations/moderations.ts @@ -386,7 +386,7 @@ moderations.openapi(createModeration, async (c): Promise => { providerKey = await findProviderKey( project.organizationId, "openai", - 
requestId, + upstreamModel, ); if (!providerKey) { throw new HTTPException(400, { @@ -396,7 +396,9 @@ moderations.openapi(createModeration, async (c): Promise => { } usedToken = providerKey.token; } else if (project.mode === "credits") { - const envResult = getProviderEnv("openai"); + const envResult = getProviderEnv("openai", { + selectionScope: upstreamModel, + }); usedToken = envResult.token; configIndex = envResult.configIndex; envVarName = envResult.envVarName; @@ -404,12 +406,14 @@ moderations.openapi(createModeration, async (c): Promise => { providerKey = await findProviderKey( project.organizationId, "openai", - requestId, + upstreamModel, ); if (providerKey) { usedToken = providerKey.token; } else { - const envResult = getProviderEnv("openai"); + const envResult = getProviderEnv("openai", { + selectionScope: upstreamModel, + }); usedToken = envResult.token; configIndex = envResult.configIndex; envVarName = envResult.envVarName; diff --git a/apps/gateway/src/test-utils/mock-openai-server.ts b/apps/gateway/src/test-utils/mock-openai-server.ts index 429ec885f6..ce991489bc 100644 --- a/apps/gateway/src/test-utils/mock-openai-server.ts +++ b/apps/gateway/src/test-utils/mock-openai-server.ts @@ -767,6 +767,20 @@ mockOpenAIServer.post("/v1/chat/completions", async (c) => { }); } // Subsequent requests succeed - fall through to normal response + } else if (userMessage.includes("TRIGGER_FAIL_ONCE_INVALID_KEY")) { + failOnceCounter++; + if (failOnceCounter === 1) { + c.status(400); + return c.json({ + error: { + message: "API key not valid. 
Please pass a valid API key.", + type: "authentication_error", + param: null, + code: "invalid_api_key", + }, + }); + } + // Subsequent requests succeed - fall through to normal response } else if (userMessage.includes("TRIGGER_FAIL_ONCE_403")) { failOnceCounter++; if (failOnceCounter === 1) { diff --git a/apps/gateway/src/videos/videos.ts b/apps/gateway/src/videos/videos.ts index 4b8ccb2a08..ac3909ae60 100644 --- a/apps/gateway/src/videos/videos.ts +++ b/apps/gateway/src/videos/videos.ts @@ -1055,7 +1055,8 @@ async function resolveProviderContext( providerId: Provider, project: InferSelectModel, organizationId: string, - selectionKey: string, + requestId: string, + selectionScope: string, ): Promise { const defaultBaseUrl = getDefaultVideoProviderBaseUrl(providerId); const sharedVertexProjectId = isGoogleVertexVideoProvider(providerId) @@ -1070,7 +1071,7 @@ async function resolveProviderContext( const providerKey = await findProviderKey( organizationId, providerId, - selectionKey, + selectionScope, undefined, getVideoProviderKeyFilter(providerId), ); @@ -1100,7 +1101,7 @@ async function resolveProviderContext( providerId, baseUrl, token: providerKey.token, - requestId: selectionKey, + requestId, usedMode: "api-keys", configIndex: null, vertexProjectId: sharedVertexProjectId, @@ -1117,6 +1118,7 @@ async function resolveProviderContext( if (project.mode === "credits") { const env = getProviderEnv(providerId, { excludedIndices: getVideoExcludedConfigIndices(providerId), + selectionScope, }); const baseUrl = getProviderEnvValue(providerId, "baseUrl", env.configIndex) ?? 
@@ -1149,7 +1151,7 @@ async function resolveProviderContext( providerId, baseUrl, token: env.token, - requestId: selectionKey, + requestId, usedMode: "credits", configIndex: env.configIndex, vertexProjectId, @@ -1170,7 +1172,7 @@ async function resolveProviderContext( const providerKey = await findProviderKey( organizationId, providerId, - selectionKey, + selectionScope, undefined, getVideoProviderKeyFilter(providerId), ); @@ -1195,7 +1197,7 @@ async function resolveProviderContext( providerId, baseUrl, token: providerKey.token, - requestId: selectionKey, + requestId, usedMode: "api-keys", configIndex: null, vertexProjectId: sharedVertexProjectId, @@ -1249,7 +1251,7 @@ async function resolveProviderContext( providerId, baseUrl, token: env.token, - requestId: selectionKey, + requestId, usedMode: "credits", configIndex: env.configIndex, vertexProjectId, @@ -1609,6 +1611,7 @@ async function resolveVideoExecution( project, organizationId, requestId, + modelInfo.id, ); return { providerMapping, @@ -2131,7 +2134,7 @@ async function resolveVideoJobProviderContext(job: VideoJobRecord): Promise<{ const providerKey = await findProviderKey( job.organizationId, providerId, - job.requestId, + job.usedModel, undefined, getVideoProviderKeyFilter(providerId), ); @@ -2161,6 +2164,7 @@ async function resolveVideoJobProviderContext(job: VideoJobRecord): Promise<{ const env = getProviderEnv(providerId, { excludedIndices: getVideoExcludedConfigIndices(providerId), + selectionScope: job.usedModel, }); const baseUrl = getProviderEnvValue(providerId, "baseUrl", env.configIndex) ?? 
@@ -3164,6 +3168,7 @@ videos.openapi(createVideo, async (c) => { project, organization.id, requestId, + modelInfo.id, ); selectedUpstreamModelName = getVideoUpstreamModelName( nextMapping.providerId as Provider, @@ -3219,6 +3224,7 @@ videos.openapi(createVideo, async (c) => { project, organization.id, requestId, + modelInfo.id, ); selectedUpstreamModelName = getVideoUpstreamModelName( nextMapping.providerId as Provider, @@ -3315,6 +3321,7 @@ videos.openapi(createVideo, async (c) => { project, organization.id, requestId, + modelInfo.id, ); selectedUpstreamModelName = getVideoUpstreamModelName( nextMapping.providerId as Provider, diff --git a/apps/worker/src/services/stats-calculator.spec.ts b/apps/worker/src/services/stats-calculator.spec.ts index a4e9831f79..a34a095684 100644 --- a/apps/worker/src/services/stats-calculator.spec.ts +++ b/apps/worker/src/services/stats-calculator.spec.ts @@ -403,6 +403,360 @@ describe("stats-calculator", () => { expect(beijingHistory?.totalCost).toBeCloseTo(0.21); }); + it("should exclude same-provider recovered retries from health stats", async () => { + const previousMinuteStart = new Date("2024-01-01T12:29:00.000Z"); + + await db.insert(modelProviderMapping).values({ + id: "mapping-3", + modelId: "gpt-4", + providerId: "anthropic", + modelName: "gpt-4-on-anthropic", + status: "active", + }); + + await db.insert(log).values([ + { + id: "log-same-provider-failed", + requestId: "req-same-provider", + organizationId: "org-1", + projectId: "proj-1", + apiKeyId: "key-1", + duration: 600, + requestedModel: "gpt-4", + requestedProvider: "openai", + usedModel: "openai/gpt-4", + usedProvider: "openai", + responseSize: 0, + hasError: true, + unifiedFinishReason: "upstream_error", + mode: "api-keys", + usedMode: "api-keys", + retried: true, + retriedByLogId: "log-same-provider-success", + createdAt: new Date(previousMinuteStart.getTime() + 5000), + }, + { + id: "log-same-provider-success", + requestId: "req-same-provider", + organizationId: 
"org-1", + projectId: "proj-1", + apiKeyId: "key-1", + duration: 1000, + requestedModel: "gpt-4", + requestedProvider: "openai", + usedModel: "openai/gpt-4", + usedProvider: "openai", + responseSize: 120, + hasError: false, + promptTokens: "80", + completionTokens: "100", + totalTokens: "180", + unifiedFinishReason: "completed", + mode: "api-keys", + usedMode: "api-keys", + createdAt: new Date(previousMinuteStart.getTime() + 10000), + }, + { + id: "log-provider-fallback-failed", + requestId: "req-provider-fallback", + organizationId: "org-1", + projectId: "proj-1", + apiKeyId: "key-1", + duration: 700, + requestedModel: "gpt-4", + requestedProvider: "openai", + usedModel: "openai/gpt-4", + usedProvider: "openai", + responseSize: 0, + hasError: true, + unifiedFinishReason: "upstream_error", + mode: "api-keys", + usedMode: "api-keys", + retried: true, + retriedByLogId: "log-provider-fallback-success", + createdAt: new Date(previousMinuteStart.getTime() + 15000), + }, + { + id: "log-provider-fallback-success", + requestId: "req-provider-fallback", + organizationId: "org-1", + projectId: "proj-1", + apiKeyId: "key-1", + duration: 900, + requestedModel: "gpt-4", + requestedProvider: "openai", + usedModel: "anthropic/gpt-4", + usedProvider: "anthropic", + responseSize: 140, + hasError: false, + promptTokens: "70", + completionTokens: "120", + totalTokens: "190", + unifiedFinishReason: "completed", + mode: "api-keys", + usedMode: "api-keys", + createdAt: new Date(previousMinuteStart.getTime() + 20000), + }, + ]); + + await calculateMinutelyHistory(); + await calculateAggregatedStatistics(); + + const mappingHistoryRecords = await db + .select() + .from(modelProviderMappingHistory) + .where( + eq(modelProviderMappingHistory.minuteTimestamp, previousMinuteStart), + ); + const openaiHistory = mappingHistoryRecords.find( + (record) => record.modelProviderMappingId === "mapping-1", + ); + const anthropicHistory = mappingHistoryRecords.find( + (record) => 
record.modelProviderMappingId === "mapping-3", + ); + + expect(openaiHistory?.logsCount).toBe(2); + expect(openaiHistory?.errorsCount).toBe(1); + expect(anthropicHistory?.logsCount).toBe(1); + expect(anthropicHistory?.errorsCount).toBe(0); + + const gpt4ModelHistory = (await db.select().from(modelHistory)).find( + (record) => + record.modelId === "gpt-4" && + record.minuteTimestamp.getTime() === previousMinuteStart.getTime(), + ); + expect(gpt4ModelHistory?.logsCount).toBe(3); + expect(gpt4ModelHistory?.errorsCount).toBe(1); + + const mappings = await db + .select() + .from(modelProviderMapping) + .where(eq(modelProviderMapping.modelId, "gpt-4")); + const openaiMapping = mappings.find( + (mapping) => mapping.id === "mapping-1", + ); + const anthropicMapping = mappings.find( + (mapping) => mapping.id === "mapping-3", + ); + + expect(openaiMapping?.routingUptime).toBeCloseTo(50); + expect(openaiMapping?.routingTotalRequests).toBe(2); + expect(anthropicMapping?.routingUptime).toBeCloseTo(100); + expect(anthropicMapping?.routingTotalRequests).toBe(1); + }); + + it("should keep failed regional attempts in mapping stats when recovery switches regions", async () => { + const previousMinuteStart = new Date("2024-01-01T12:29:00.000Z"); + + await db.insert(provider).values({ + id: "alibaba", + name: "Alibaba", + description: "Alibaba provider", + streaming: true, + cancellation: false, + color: "#ff6a00", + website: "https://www.alibabacloud.com", + status: "active", + }); + + await db.insert(model).values({ + id: "deepseek-v3.2", + name: "DeepSeek V3.2", + family: "deepseek", + status: "active", + }); + + await db.insert(modelProviderMapping).values([ + { + id: "mapping-aggregate-region-retry", + modelId: "deepseek-v3.2", + providerId: "alibaba", + modelName: "deepseek-v3.2", + status: "active", + }, + { + id: "mapping-region-singapore", + modelId: "deepseek-v3.2", + providerId: "alibaba", + modelName: "deepseek-v3.2:singapore", + region: "singapore", + status: "active", + 
}, + { + id: "mapping-region-beijing", + modelId: "deepseek-v3.2", + providerId: "alibaba", + modelName: "deepseek-v3.2:cn-beijing", + region: "cn-beijing", + status: "active", + }, + ]); + + await db.insert(log).values([ + { + id: "log-region-retry-failed", + requestId: "req-region-retry", + organizationId: "org-1", + projectId: "proj-1", + apiKeyId: "key-1", + duration: 600, + requestedModel: "alibaba/deepseek-v3.2", + requestedProvider: "alibaba", + usedModel: "alibaba/deepseek-v3.2:singapore", + usedProvider: "alibaba", + responseSize: 0, + hasError: true, + unifiedFinishReason: "upstream_error", + mode: "api-keys", + usedMode: "api-keys", + retried: true, + retriedByLogId: "log-region-retry-success", + createdAt: new Date(previousMinuteStart.getTime() + 5000), + }, + { + id: "log-region-retry-success", + requestId: "req-region-retry", + organizationId: "org-1", + projectId: "proj-1", + apiKeyId: "key-1", + duration: 900, + requestedModel: "alibaba/deepseek-v3.2", + requestedProvider: "alibaba", + usedModel: "alibaba/deepseek-v3.2:cn-beijing", + usedProvider: "alibaba", + responseSize: 140, + hasError: false, + promptTokens: "70", + completionTokens: "120", + totalTokens: "190", + unifiedFinishReason: "completed", + mode: "api-keys", + usedMode: "api-keys", + createdAt: new Date(previousMinuteStart.getTime() + 10000), + }, + ]); + + await calculateMinutelyHistory(); + await calculateAggregatedStatistics(); + + const mappingHistoryRecords = await db + .select() + .from(modelProviderMappingHistory) + .where( + eq(modelProviderMappingHistory.minuteTimestamp, previousMinuteStart), + ); + const aggregateHistory = mappingHistoryRecords.find( + (record) => + record.modelProviderMappingId === "mapping-aggregate-region-retry", + ); + const singaporeHistory = mappingHistoryRecords.find( + (record) => + record.modelProviderMappingId === "mapping-region-singapore", + ); + const beijingHistory = mappingHistoryRecords.find( + (record) => record.modelProviderMappingId === 
"mapping-region-beijing", + ); + + expect(aggregateHistory?.logsCount).toBe(2); + expect(aggregateHistory?.errorsCount).toBe(1); + expect(singaporeHistory?.logsCount).toBe(1); + expect(singaporeHistory?.errorsCount).toBe(1); + expect(beijingHistory?.logsCount).toBe(1); + expect(beijingHistory?.errorsCount).toBe(0); + + const regionalMappings = await db + .select() + .from(modelProviderMapping) + .where(eq(modelProviderMapping.modelId, "deepseek-v3.2")); + const singaporeMapping = regionalMappings.find( + (mapping) => mapping.id === "mapping-region-singapore", + ); + const beijingMapping = regionalMappings.find( + (mapping) => mapping.id === "mapping-region-beijing", + ); + + expect(singaporeMapping?.routingUptime).toBeCloseTo(0); + expect(singaporeMapping?.routingTotalRequests).toBe(1); + expect(beijingMapping?.routingUptime).toBeCloseTo(100); + expect(beijingMapping?.routingTotalRequests).toBe(1); + + const deepseekModelHistory = (await db.select().from(modelHistory)).find( + (record) => + record.modelId === "deepseek-v3.2" && + record.minuteTimestamp.getTime() === previousMinuteStart.getTime(), + ); + expect(deepseekModelHistory?.logsCount).toBe(2); + expect(deepseekModelHistory?.errorsCount).toBe(1); + }); + + it("should keep failed attempts when the same-provider retry also failed", async () => { + const previousMinuteStart = new Date("2024-01-01T12:29:00.000Z"); + + await db.insert(log).values([ + { + id: "log-same-provider-failed-1", + requestId: "req-same-provider-all-failed", + organizationId: "org-1", + projectId: "proj-1", + apiKeyId: "key-1", + duration: 600, + requestedModel: "gpt-4", + requestedProvider: "openai", + usedModel: "openai/gpt-4", + usedProvider: "openai", + responseSize: 0, + hasError: true, + unifiedFinishReason: "upstream_error", + mode: "api-keys", + usedMode: "api-keys", + retried: true, + retriedByLogId: "log-same-provider-failed-2", + createdAt: new Date(previousMinuteStart.getTime() + 5000), + }, + { + id: 
"log-same-provider-failed-2", + requestId: "req-same-provider-all-failed", + organizationId: "org-1", + projectId: "proj-1", + apiKeyId: "key-1", + duration: 700, + requestedModel: "gpt-4", + requestedProvider: "openai", + usedModel: "openai/gpt-4", + usedProvider: "openai", + responseSize: 0, + hasError: true, + unifiedFinishReason: "upstream_error", + mode: "api-keys", + usedMode: "api-keys", + createdAt: new Date(previousMinuteStart.getTime() + 10000), + }, + ]); + + await calculateMinutelyHistory(); + + const mappingHistoryRecords = await db + .select() + .from(modelProviderMappingHistory) + .where( + eq(modelProviderMappingHistory.minuteTimestamp, previousMinuteStart), + ); + const openaiHistory = mappingHistoryRecords.find( + (record) => record.modelProviderMappingId === "mapping-1", + ); + + expect(openaiHistory?.logsCount).toBe(2); + expect(openaiHistory?.errorsCount).toBe(2); + + const gpt4ModelHistory = (await db.select().from(modelHistory)).find( + (record) => + record.modelId === "gpt-4" && + record.minuteTimestamp.getTime() === previousMinuteStart.getTime(), + ); + expect(gpt4ModelHistory?.logsCount).toBe(2); + expect(gpt4ModelHistory?.errorsCount).toBe(2); + }); + it("should handle cached requests correctly by ignoring tokens but counting requests", async () => { const previousMinuteStart = new Date("2024-01-01T12:29:00.000Z"); diff --git a/apps/worker/src/services/stats-calculator.ts b/apps/worker/src/services/stats-calculator.ts index 65115ca2fe..0b7f05adca 100644 --- a/apps/worker/src/services/stats-calculator.ts +++ b/apps/worker/src/services/stats-calculator.ts @@ -26,6 +26,24 @@ const usedRegionSql = sql< string | null >`nullif(split_part(${usedModelWithRegionSql}, ':', 2), '')`; +function excludeRecoveredSameProviderRegionRetry() { + return sql`not ( + coalesce(${log.hasError}, false) = true + and coalesce(${log.retried}, false) = true + and exists ( + select 1 + from "log" as final_retry_log + where final_retry_log.id = ${log.retriedByLogId} + 
and final_retry_log.used_provider = ${log.usedProvider} + and coalesce(final_retry_log.has_error, false) = false + and nullif( + split_part(split_part(final_retry_log.used_model, '/', 2), ':', 2), + '' + ) is not distinct from ${usedRegionSql} + ) + )`; +} + interface MappingMinuteStats { modelId: string | null; providerId: string | null; @@ -202,6 +220,7 @@ async function calculateModelHistoryForMinute(targetMinute: Date) { and( gte(log.createdAt, roundedTargetMinute), lt(log.createdAt, minuteEnd), + excludeRecoveredSameProviderRegionRetry(), ), ) .groupBy(usedBaseModelSql); @@ -380,6 +399,7 @@ async function calculateHistoryForMinute(targetMinute: Date) { and( gte(log.createdAt, roundedTargetMinute), lt(log.createdAt, minuteEnd), + excludeRecoveredSameProviderRegionRetry(), ), ) .groupBy(usedBaseModelSql, log.usedProvider, usedRegionSql); diff --git a/packages/models/src/provider.ts b/packages/models/src/provider.ts index c9b955156f..9948e6082d 100644 --- a/packages/models/src/provider.ts +++ b/packages/models/src/provider.ts @@ -126,6 +126,25 @@ export function getRegionSpecificEnvValue( ); } +/** + * Get the region-specific env var name only when that var is actually set. + * Returns `{BASE_ENV_VAR}__{REGION}` when the regional override exists, else + * undefined. Use this when you need to attribute health to the regional + * credential rather than the base env var. + */ +export function getRegionSpecificEnvVarName( + provider: Provider, + region: string, +): string | undefined { + const baseEnvVar = getProviderEnvVar(provider); + if (!baseEnvVar) { + return undefined; + } + const regionSuffix = region.toUpperCase().replace(/-/g, "_"); + const regionalName = `${baseEnvVar}__${regionSuffix}`; + return process.env[regionalName] ? regionalName : undefined; +} + /** * Check whether an env var exists for a specific region. * Returns true if a region-specific env var (`{BASE_ENV_VAR}__{REGION}`) exists,