diff --git a/apps/gateway/src/api.spec.ts b/apps/gateway/src/api.spec.ts index e57ef00607..804141c9b0 100644 --- a/apps/gateway/src/api.spec.ts +++ b/apps/gateway/src/api.spec.ts @@ -1441,6 +1441,671 @@ describe("api", () => { } }); + test("/v1/chat/completions blocks with custom content filter method", async () => { + await db.insert(tables.apiKey).values({ + id: "token-id", + token: "real-token", + projectId: "project-id", + description: "Test API Key", + createdBy: "user-id", + }); + + await db.insert(tables.providerKey).values({ + id: "provider-key-id", + token: "sk-test-key", + provider: "llmgateway", + organizationId: "org-id", + baseUrl: mockServerUrl, + }); + + const previousContentFilterMode = process.env.LLM_CONTENT_FILTER_MODE; + const previousContentFilterMethod = process.env.LLM_CONTENT_FILTER_METHOD; + const previousContentFilterModels = process.env.LLM_CONTENT_FILTER_MODELS; + const previousCustomApiKey = process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY; + const previousCustomBaseUrl = + process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL; + const previousCustomModel = process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL; + const previousGatewayUrl = process.env.GATEWAY_URL; + const requestId = "chat-custom-content-filter-request-id"; + const debugSpy = vi.spyOn(logger, "debug").mockImplementation(() => {}); + const originalFetch = globalThis.fetch; + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockImplementation(async (input, init) => { + const url = + typeof input === "string" + ? input + : input instanceof URL + ? input.toString() + : input.url; + const headers = new Headers(init?.headers); + + if ( + url === `${mockServerUrl}/v1/chat/completions` && + headers.get("authorization") === "Bearer llmgateway-moderation-key" + ) { + expect(headers.get("x-client-request-id")).toBe(requestId); + + const body = JSON.parse(String(init?.body ?? 
"{}")); + expect(body.model).toBe("openai/gpt-5-mini"); + expect(body.response_format).toEqual({ + type: "json_schema", + json_schema: expect.objectContaining({ + name: "gateway_content_filter", + strict: true, + }), + }); + expect(body.plugins).toEqual([{ id: "response-healing" }]); + expect(body.messages[1]?.content).toContain( + "I want to attack someone.", + ); + + return new Response( + JSON.stringify({ + id: "chatcmpl-moderation", + model: "openai/gpt-5-mini", + choices: [ + { + message: { + content: JSON.stringify({ + flagged: true, + categories: { + violence: true, + }, + category_scores: { + violence: 0.98, + }, + reason: "Explicit violent intent.", + }), + }, + }, + ], + }), + { + status: 200, + headers: { + "Content-Type": "application/json", + "x-request-id": "upstream-custom-request-id", + }, + }, + ); + } + + return await originalFetch(input as RequestInfo | URL, init); + }); + + try { + process.env.LLM_CONTENT_FILTER_MODE = "enabled"; + process.env.LLM_CONTENT_FILTER_METHOD = "custom"; + process.env.LLM_CONTENT_FILTER_MODELS = "custom"; + process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY = + "llmgateway-moderation-key"; + process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL = mockServerUrl; + process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL = "openai/gpt-5-mini"; + process.env.GATEWAY_URL = mockServerUrl; + + const res = await app.request("/v1/chat/completions", { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: "Bearer real-token", + "x-request-id": requestId, + }, + body: JSON.stringify({ + model: "llmgateway/custom", + messages: [ + { + role: "user", + content: "I want to attack someone.", + }, + ], + }), + }); + + expect(res.status).toBe(200); + + const json = await res.json(); + expect(json.choices[0].message.content).toBeNull(); + expect(json.choices[0].finish_reason).toBe("content_filter"); + + expect(debugSpy).toHaveBeenCalledWith( + "gateway_content_filter", + expect.objectContaining({ + durationMs: expect.any(Number), + mode: "custom", + requestId, + organizationId: "org-id", + projectId: "project-id", + apiKeyId: "token-id", + flagged: true, + model: "openai/gpt-5-mini", + upstreamRequestId: "upstream-custom-request-id", + }), + ); + + const logs = await waitForLogs(1); + const blockedLog = logs.find((log) => log.requestId === requestId); + + expect(blockedLog).toBeTruthy(); + expect(blockedLog?.finishReason).toBe("llmgateway_content_filter"); + expect(blockedLog?.unifiedFinishReason).toBe("content_filter"); + expect(blockedLog?.gatewayContentFilterResponse).toEqual([ + { + id: "chatcmpl-moderation", + model: "openai/gpt-5-mini", + results: [ + { + flagged: true, + categories: { + violence: true, + }, + category_scores: { + violence: 0.98, + }, + reason: "Explicit violent intent.", + }, + ], + }, + ]); + } finally { + fetchSpy.mockRestore(); + debugSpy.mockRestore(); + if (previousContentFilterMode === undefined) { + delete process.env.LLM_CONTENT_FILTER_MODE; + } else { + process.env.LLM_CONTENT_FILTER_MODE = previousContentFilterMode; + } + if (previousContentFilterMethod === undefined) { + delete process.env.LLM_CONTENT_FILTER_METHOD; + } else { + process.env.LLM_CONTENT_FILTER_METHOD = previousContentFilterMethod; + } + if (previousContentFilterModels === undefined) { + delete process.env.LLM_CONTENT_FILTER_MODELS; + } else { + process.env.LLM_CONTENT_FILTER_MODELS = previousContentFilterModels; + } + if (previousCustomApiKey === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY; + } else { + 
process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY = previousCustomApiKey; + } + if (previousCustomBaseUrl === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL = previousCustomBaseUrl; + } + if (previousCustomModel === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL = previousCustomModel; + } + if (previousGatewayUrl === undefined) { + delete process.env.GATEWAY_URL; + } else { + process.env.GATEWAY_URL = previousGatewayUrl; + } + } + }); + + test("/v1/chat/completions monitors with custom content filter method", async () => { + await db.insert(tables.apiKey).values({ + id: "token-id", + token: "real-token", + projectId: "project-id", + description: "Test API Key", + createdBy: "user-id", + }); + + await db.insert(tables.providerKey).values({ + id: "provider-key-id", + token: "sk-test-key", + provider: "llmgateway", + organizationId: "org-id", + baseUrl: mockServerUrl, + }); + + const previousContentFilterMode = process.env.LLM_CONTENT_FILTER_MODE; + const previousContentFilterMethod = process.env.LLM_CONTENT_FILTER_METHOD; + const previousContentFilterModels = process.env.LLM_CONTENT_FILTER_MODELS; + const previousCustomApiKey = process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY; + const previousCustomBaseUrl = + process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL; + const previousCustomModel = process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL; + const previousGatewayUrl = process.env.GATEWAY_URL; + const requestId = "chat-custom-content-filter-monitor-request-id"; + const debugSpy = vi.spyOn(logger, "debug").mockImplementation(() => {}); + const originalFetch = globalThis.fetch; + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockImplementation(async (input, init) => { + const url = + typeof input === "string" + ? input + : input instanceof URL + ? 
input.toString() + : input.url; + const headers = new Headers(init?.headers); + + if ( + url === `${mockServerUrl}/v1/chat/completions` && + headers.get("authorization") === "Bearer llmgateway-moderation-key" + ) { + return new Response( + JSON.stringify({ + id: "chatcmpl-moderation", + model: "openai/gpt-5-mini", + choices: [ + { + message: { + content: JSON.stringify({ + flagged: true, + categories: { + violence: true, + }, + category_scores: { + violence: 0.98, + }, + reason: "Explicit violent intent.", + }), + }, + }, + ], + }), + { + status: 200, + headers: { + "Content-Type": "application/json", + "x-request-id": "upstream-custom-request-id", + }, + }, + ); + } + + return await originalFetch(input as RequestInfo | URL, init); + }); + + try { + process.env.LLM_CONTENT_FILTER_MODE = "monitor"; + process.env.LLM_CONTENT_FILTER_METHOD = "custom"; + process.env.LLM_CONTENT_FILTER_MODELS = "custom"; + process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY = + "llmgateway-moderation-key"; + process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL = mockServerUrl; + process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL = "openai/gpt-5-mini"; + process.env.GATEWAY_URL = mockServerUrl; + + const res = await app.request("/v1/chat/completions", { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: "Bearer real-token", + "x-request-id": requestId, + }, + body: JSON.stringify({ + model: "llmgateway/custom", + messages: [ + { + role: "user", + content: "I want to attack someone.", + }, + ], + }), + }); + + expect(res.status).toBe(200); + + const json = await res.json(); + expect(json.choices[0].message.content).toContain( + "I want to attack someone.", + ); + + expect(debugSpy).toHaveBeenCalledWith( + "gateway_content_filter", + expect.objectContaining({ + durationMs: expect.any(Number), + mode: "custom", + requestId, + flagged: true, + }), + ); + + const logs = await waitForLogs(1); + const completedLog = logs.find((log) => log.requestId === requestId); + + expect(completedLog).toBeTruthy(); + expect(completedLog?.finishReason).toBe("stop"); + expect(completedLog?.internalContentFilter).toBe(true); + expect(completedLog?.gatewayContentFilterResponse).toEqual([ + { + id: "chatcmpl-moderation", + model: "openai/gpt-5-mini", + results: [ + { + flagged: true, + categories: { + violence: true, + }, + category_scores: { + violence: 0.98, + }, + reason: "Explicit violent intent.", + }, + ], + }, + ]); + } finally { + fetchSpy.mockRestore(); + debugSpy.mockRestore(); + if (previousContentFilterMode === undefined) { + delete process.env.LLM_CONTENT_FILTER_MODE; + } else { + process.env.LLM_CONTENT_FILTER_MODE = previousContentFilterMode; + } + if (previousContentFilterMethod === undefined) { + delete process.env.LLM_CONTENT_FILTER_METHOD; + } else { + process.env.LLM_CONTENT_FILTER_METHOD = previousContentFilterMethod; + } + if (previousContentFilterModels === undefined) { + delete process.env.LLM_CONTENT_FILTER_MODELS; + } else { + process.env.LLM_CONTENT_FILTER_MODELS = previousContentFilterModels; + } + if (previousCustomApiKey === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY = previousCustomApiKey; + } + if (previousCustomBaseUrl === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL = previousCustomBaseUrl; + } + if (previousCustomModel === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL; + } else { + 
process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL = previousCustomModel; + } + if (previousGatewayUrl === undefined) { + delete process.env.GATEWAY_URL; + } else { + process.env.GATEWAY_URL = previousGatewayUrl; + } + } + }); + + test("/v1/chat/completions ignores missing custom moderation config", async () => { + await db.insert(tables.apiKey).values({ + id: "token-id", + token: "real-token", + projectId: "project-id", + description: "Test API Key", + createdBy: "user-id", + }); + + await db.insert(tables.providerKey).values({ + id: "provider-key-id", + token: "sk-test-key", + provider: "llmgateway", + organizationId: "org-id", + baseUrl: mockServerUrl, + }); + + const previousContentFilterMode = process.env.LLM_CONTENT_FILTER_MODE; + const previousContentFilterMethod = process.env.LLM_CONTENT_FILTER_METHOD; + const previousContentFilterModels = process.env.LLM_CONTENT_FILTER_MODELS; + const previousCustomApiKey = process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY; + const previousCustomBaseUrl = + process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL; + const previousCustomModel = process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL; + const previousGatewayUrl = process.env.GATEWAY_URL; + const requestId = "chat-custom-content-filter-missing-config-request-id"; + const errorSpy = vi.spyOn(logger, "error").mockImplementation(() => {}); + + try { + process.env.LLM_CONTENT_FILTER_MODE = "enabled"; + process.env.LLM_CONTENT_FILTER_METHOD = "custom"; + process.env.LLM_CONTENT_FILTER_MODELS = "custom"; + delete process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY; + process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL = mockServerUrl; + delete process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL; + process.env.GATEWAY_URL = mockServerUrl; + + const res = await app.request("/v1/chat/completions", { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: "Bearer real-token", + "x-request-id": requestId, + }, + body: JSON.stringify({ + model: "llmgateway/custom", + messages: [ + { + role: "user", + content: "Hello!", + }, + ], + }), + }); + + expect(res.status).toBe(200); + + const json = await res.json(); + expect(json.choices[0].message.content).toContain("Hello!"); + + expect(errorSpy).toHaveBeenCalledWith( + "gateway_content_filter_error", + expect.objectContaining({ + mode: "custom", + requestId, + organizationId: "org-id", + projectId: "project-id", + apiKeyId: "token-id", + error: + "LLM_CONTENT_FILTER_CUSTOM_API_KEY environment variable is required for custom content filter", + }), + expect.any(Error), + ); + + const logs = await waitForLogs(1); + const completedLog = logs.find((log) => log.requestId === requestId); + + expect(completedLog).toBeTruthy(); + expect(completedLog?.finishReason).toBe("stop"); + expect(completedLog?.gatewayContentFilterResponse).toBeNull(); + } finally { + errorSpy.mockRestore(); + if (previousContentFilterMode === undefined) { + delete process.env.LLM_CONTENT_FILTER_MODE; + } else { + process.env.LLM_CONTENT_FILTER_MODE = previousContentFilterMode; + } + if (previousContentFilterMethod === undefined) { + delete process.env.LLM_CONTENT_FILTER_METHOD; + } else { + process.env.LLM_CONTENT_FILTER_METHOD = previousContentFilterMethod; + } + if (previousContentFilterModels === undefined) { + delete process.env.LLM_CONTENT_FILTER_MODELS; + } else { + process.env.LLM_CONTENT_FILTER_MODELS = previousContentFilterModels; + } + if (previousCustomApiKey === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY = 
previousCustomApiKey; + } + if (previousCustomBaseUrl === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL = previousCustomBaseUrl; + } + if (previousCustomModel === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL = previousCustomModel; + } + if (previousGatewayUrl === undefined) { + delete process.env.GATEWAY_URL; + } else { + process.env.GATEWAY_URL = previousGatewayUrl; + } + } + }); + + test("/v1/chat/completions skips custom content filter for non-targeted models", async () => { + await db.insert(tables.apiKey).values({ + id: "token-id", + token: "real-token", + projectId: "project-id", + description: "Test API Key", + createdBy: "user-id", + }); + + await db.insert(tables.providerKey).values({ + id: "provider-key-id", + token: "sk-test-key", + provider: "llmgateway", + organizationId: "org-id", + baseUrl: mockServerUrl, + }); + + const previousContentFilterMode = process.env.LLM_CONTENT_FILTER_MODE; + const previousContentFilterMethod = process.env.LLM_CONTENT_FILTER_METHOD; + const previousContentFilterModels = process.env.LLM_CONTENT_FILTER_MODELS; + const previousCustomApiKey = process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY; + const previousCustomBaseUrl = + process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL; + const previousCustomModel = process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL; + const previousGatewayUrl = process.env.GATEWAY_URL; + const requestId = "chat-custom-content-filter-model-skip-request-id"; + const debugSpy = vi.spyOn(logger, "debug").mockImplementation(() => {}); + const originalFetch = globalThis.fetch; + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockImplementation(async (input, init) => { + const url = + typeof input === "string" + ? input + : input instanceof URL + ? input.toString() + : input.url; + const headers = new Headers(init?.headers); + + if ( + url === `${mockServerUrl}/v1/chat/completions` && + headers.get("authorization") === "Bearer llmgateway-moderation-key" + ) { + throw new Error("custom moderation should not be called"); + } + + return await originalFetch(input as RequestInfo | URL, init); + }); + + try { + process.env.LLM_CONTENT_FILTER_MODE = "monitor"; + process.env.LLM_CONTENT_FILTER_METHOD = "custom"; + process.env.LLM_CONTENT_FILTER_MODELS = "gpt-4o-mini"; + process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY = + "llmgateway-moderation-key"; + process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL = mockServerUrl; + process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL = "openai/gpt-5-mini"; + process.env.GATEWAY_URL = mockServerUrl; + + const res = await app.request("/v1/chat/completions", { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: "Bearer real-token", + "x-request-id": requestId, + }, + body: JSON.stringify({ + model: "llmgateway/custom", + messages: [ + { + role: "user", + content: "I want to attack someone.", + }, + ], + }), + }); + + expect(res.status).toBe(200); + + const json = await res.json(); + expect(json.choices[0].message.content).toContain( + "I want to attack someone.", + ); + expect( + fetchSpy.mock.calls.some(([input, init]) => { + const url = + typeof input === "string" + ? input + : input instanceof URL + ? 
input.toString() + : input.url; + const headers = new Headers(init?.headers); + return ( + url === `${mockServerUrl}/v1/chat/completions` && + headers.get("authorization") === "Bearer llmgateway-moderation-key" + ); + }), + ).toBe(false); + expect(debugSpy).not.toHaveBeenCalledWith( + "gateway_content_filter", + expect.anything(), + ); + + const logs = await waitForLogs(1); + const completedLog = logs.find((log) => log.requestId === requestId); + + expect(completedLog).toBeTruthy(); + expect(completedLog?.finishReason).toBe("stop"); + expect(completedLog?.internalContentFilter).toBeNull(); + expect(completedLog?.gatewayContentFilterResponse).toBeNull(); + } finally { + fetchSpy.mockRestore(); + debugSpy.mockRestore(); + if (previousContentFilterMode === undefined) { + delete process.env.LLM_CONTENT_FILTER_MODE; + } else { + process.env.LLM_CONTENT_FILTER_MODE = previousContentFilterMode; + } + if (previousContentFilterMethod === undefined) { + delete process.env.LLM_CONTENT_FILTER_METHOD; + } else { + process.env.LLM_CONTENT_FILTER_METHOD = previousContentFilterMethod; + } + if (previousContentFilterModels === undefined) { + delete process.env.LLM_CONTENT_FILTER_MODELS; + } else { + process.env.LLM_CONTENT_FILTER_MODELS = previousContentFilterModels; + } + if (previousCustomApiKey === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY = previousCustomApiKey; + } + if (previousCustomBaseUrl === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL = previousCustomBaseUrl; + } + if (previousCustomModel === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL = previousCustomModel; + } + if (previousGatewayUrl === undefined) { + delete process.env.GATEWAY_URL; + } else { + process.env.GATEWAY_URL = previousGatewayUrl; + } + } + }); + test("Reasoning effort error for unsupported model", async () => { await db.insert(tables.apiKey).values({ id: "token-id", diff --git a/apps/gateway/src/chat/chat.ts b/apps/gateway/src/chat/chat.ts index de20ea3a5d..984c619c6f 100644 --- a/apps/gateway/src/chat/chat.ts +++ b/apps/gateway/src/chat/chat.ts @@ -107,6 +107,7 @@ import { import { convertImagesToBase64 } from "./tools/convert-images-to-base64.js"; import { countInputImages } from "./tools/count-input-images.js"; import { createLogEntry } from "./tools/create-log-entry.js"; +import { checkCustomContentFilter } from "./tools/custom-content-filter.js"; import { estimateTokensFromContent } from "./tools/estimate-tokens-from-content.js"; import { estimateTokens } from "./tools/estimate-tokens.js"; import { @@ -1909,9 +1910,23 @@ chat.openapi(completions, async (c) => { c.req.raw.signal, ) : null; + const customContentFilterResult = + shouldApplyGatewayContentFilter && contentFilterMethod === "custom" + ? 
await checkCustomContentFilter( + messages as BaseMessage[], + { + requestId, + organizationId: project.organizationId, + projectId: project.id, + apiKeyId: apiKey.id, + }, + c.req.raw.signal, + ) + : null; const contentFilterMatched = keywordContentFilterMatch !== null || - openAIContentFilterResult?.flagged === true; + openAIContentFilterResult?.flagged === true || + customContentFilterResult?.flagged === true; const shouldRerouteContentFilter = contentFilterMode === "enabled" && contentFilterMatched; let contentFilterRoutingExcludedProviders: ProviderModelMapping[] = []; @@ -3003,9 +3018,12 @@ chat.openapi(completions, async (c) => { const shouldTagContentFilter = (contentFilterMode === "monitor" && contentFilterMatched) || contentFilterRoutingApplied; - const gatewayContentFilterResponse = openAIContentFilterResult?.responses - .length - ? openAIContentFilterResult.responses + const gatewayContentFilterResponses = [ + ...(openAIContentFilterResult?.responses ?? []), + ...(customContentFilterResult?.responses ?? []), + ]; + const gatewayContentFilterResponse = gatewayContentFilterResponses.length + ? gatewayContentFilterResponses : null; const insertLog = ( logData: Parameters[0], diff --git a/apps/gateway/src/chat/tools/check-content-filter.spec.ts b/apps/gateway/src/chat/tools/check-content-filter.spec.ts index 950fb2d3d5..8578d857c5 100644 --- a/apps/gateway/src/chat/tools/check-content-filter.spec.ts +++ b/apps/gateway/src/chat/tools/check-content-filter.spec.ts @@ -2,6 +2,7 @@ import { describe, it, expect, afterEach } from "vitest"; import { checkContentFilter, + getCustomContentFilterConfig, getContentFilterModels, getContentFilterMethod, getContentFilterMode, @@ -208,6 +209,111 @@ describe("getContentFilterMethod", () => { delete process.env.LLM_CONTENT_FILTER_METHOD; expect(getContentFilterMethod()).toBe("openai"); }); + + it("returns custom when method is set to custom", () => { + process.env.LLM_CONTENT_FILTER_METHOD = "custom"; + expect(getContentFilterMethod()).toBe("custom"); + }); +}); + +describe("getCustomContentFilterConfig", () => { + const originalApiKey = process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY; + const originalModel = process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL; + const originalCustomBaseUrl = process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL; + const originalIncludeImages = + process.env.LLM_CONTENT_FILTER_CUSTOM_INCLUDE_IMAGES; + const originalGatewayUrl = process.env.GATEWAY_URL; + + afterEach(() => { + if (originalApiKey === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY = originalApiKey; + } + + if (originalModel === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL = originalModel; + } + + if (originalCustomBaseUrl === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL = originalCustomBaseUrl; + } + + if (originalIncludeImages === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_INCLUDE_IMAGES; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_INCLUDE_IMAGES = + originalIncludeImages; + } + + if (originalGatewayUrl === undefined) { + delete process.env.GATEWAY_URL; + } else { + process.env.GATEWAY_URL = originalGatewayUrl; + } + }); + + it("returns the configured api key and model", () => { + process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY = "custom-key"; + process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL = 
"openai/gpt-5-mini"; + process.env.GATEWAY_URL = "https://gateway.example.com/v1"; + + expect(getCustomContentFilterConfig()).toEqual({ + apiKey: "custom-key", + model: "openai/gpt-5-mini", + baseUrl: "https://gateway.example.com", + includeImages: true, + }); + }); + + it("uses the custom base url override when configured", () => { + process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY = "custom-key"; + process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL = "openai/gpt-5-mini"; + process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL = + "https://moderation.example.com/internal/v1"; + + expect(getCustomContentFilterConfig()).toEqual({ + apiKey: "custom-key", + model: "openai/gpt-5-mini", + baseUrl: "https://moderation.example.com/internal", + includeImages: true, + }); + }); + + it("disables image inclusion when configured", () => { + process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY = "custom-key"; + process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL = "openai/gpt-5-mini"; + process.env.LLM_CONTENT_FILTER_CUSTOM_INCLUDE_IMAGES = "false"; + + expect(getCustomContentFilterConfig()).toEqual({ + apiKey: "custom-key", + model: "openai/gpt-5-mini", + baseUrl: "http://localhost:4001", + includeImages: false, + }); + }); + + it("throws when the custom api key is missing", () => { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY; + process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL = "openai/gpt-5-mini"; + + expect(() => getCustomContentFilterConfig()).toThrow( + "LLM_CONTENT_FILTER_CUSTOM_API_KEY environment variable is required for custom content filter", + ); + }); + + it("throws when the custom model is missing", () => { + process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY = "custom-key"; + delete process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL; + + expect(() => getCustomContentFilterConfig()).toThrow( + "LLM_CONTENT_FILTER_CUSTOM_MODEL environment variable is required for custom content filter", + ); + }); }); describe("getContentFilterModels", () => { diff --git a/apps/gateway/src/chat/tools/check-content-filter.ts b/apps/gateway/src/chat/tools/check-content-filter.ts index 5d931aeb32..276fb7bf24 100644 --- a/apps/gateway/src/chat/tools/check-content-filter.ts +++ b/apps/gateway/src/chat/tools/check-content-filter.ts @@ -1,7 +1,16 @@ +import { getGatewayPublicBaseUrl } from "@llmgateway/shared"; + import type { BaseMessage, MessageContent } from "@llmgateway/models"; export type ContentFilterMode = "disabled" | "monitor" | "enabled"; -export type ContentFilterMethod = "keywords" | "openai"; +export type ContentFilterMethod = "keywords" | "openai" | "custom"; + +export interface CustomContentFilterConfig { + apiKey: string; + model: string; + baseUrl: string; + includeImages: boolean; +} /** * Returns the content filter mode from LLM_CONTENT_FILTER_MODE env var. @@ -30,6 +39,7 @@ export function getContentFilterMode(): ContentFilterMode { * Returns the content filter method from LLM_CONTENT_FILTER_METHOD env var. 
* - "keywords" (default): use the configured keyword list * - "openai": use OpenAI's moderation endpoint + * - "custom": use LLMGateway chat completions with a moderation prompt * * Legacy compatibility: * - LLM_CONTENT_FILTER_MODE=openai implies method=openai @@ -42,9 +52,50 @@ export function getContentFilterMethod(): ContentFilterMethod { return "openai"; } + if (envValue === "custom") { + return "custom"; + } + return "keywords"; } +function getRequiredEnvValue(envVarName: string): string { + const envValue = process.env[envVarName]?.trim(); + + if (envValue) { + return envValue; + } + + throw new Error( + `${envVarName} environment variable is required for custom content filter`, + ); +} + +function normalizeCustomContentFilterBaseUrl(url: string): string { + return url.replace(/\/v1\/?$/, "").replace(/\/$/, ""); +} + +export function getCustomContentFilterBaseUrl(): string { + const configuredBaseUrl = + process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL?.trim(); + + return configuredBaseUrl + ? normalizeCustomContentFilterBaseUrl(configuredBaseUrl) + : getGatewayPublicBaseUrl(); +} + +export function getCustomContentFilterConfig(): CustomContentFilterConfig { + const includeImagesEnvValue = + process.env.LLM_CONTENT_FILTER_CUSTOM_INCLUDE_IMAGES?.trim().toLowerCase(); + + return { + apiKey: getRequiredEnvValue("LLM_CONTENT_FILTER_CUSTOM_API_KEY"), + model: getRequiredEnvValue("LLM_CONTENT_FILTER_CUSTOM_MODEL"), + baseUrl: getCustomContentFilterBaseUrl(), + includeImages: includeImagesEnvValue !== "false", + }; +} + /** * Returns the list of canonical model names that should have content filtering * applied from LLM_CONTENT_FILTER_MODELS. diff --git a/apps/gateway/src/chat/tools/custom-content-filter.spec.ts b/apps/gateway/src/chat/tools/custom-content-filter.spec.ts new file mode 100644 index 0000000000..a5cd9eb31e --- /dev/null +++ b/apps/gateway/src/chat/tools/custom-content-filter.spec.ts @@ -0,0 +1,394 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; + +import { logger } from "@llmgateway/logger"; + +import { checkCustomContentFilter } from "./custom-content-filter.js"; + +describe("checkCustomContentFilter", () => { + const originalApiKey = process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY; + const originalModel = process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL; + const originalCustomBaseUrl = process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL; + const originalIncludeImages = + process.env.LLM_CONTENT_FILTER_CUSTOM_INCLUDE_IMAGES; + const originalGatewayUrl = process.env.GATEWAY_URL; + + afterEach(() => { + vi.restoreAllMocks(); + + if (originalApiKey === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY = originalApiKey; + } + + if (originalModel === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL = originalModel; + } + + if (originalCustomBaseUrl === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL = originalCustomBaseUrl; + } + + if (originalIncludeImages === undefined) { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_INCLUDE_IMAGES; + } else { + process.env.LLM_CONTENT_FILTER_CUSTOM_INCLUDE_IMAGES = + originalIncludeImages; + } + + if (originalGatewayUrl === undefined) { + delete process.env.GATEWAY_URL; + } else { + process.env.GATEWAY_URL = originalGatewayUrl; + } + }); + + it("calls llmgateway chat completions with the configured key and 
model", async () => { + process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY = "custom-api-key"; + process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL = "anthropic/claude-sonnet-4-5"; + process.env.GATEWAY_URL = "https://gateway.example.com/v1"; + + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockImplementation(async (input, init) => { + const url = + typeof input === "string" + ? input + : input instanceof URL + ? input.toString() + : input.url; + expect(url).toBe("https://gateway.example.com/v1/chat/completions"); + + const headers = new Headers(init?.headers); + expect(headers.get("authorization")).toBe("Bearer custom-api-key"); + expect(headers.get("x-client-request-id")).toBe("request-id"); + + const body = JSON.parse(String(init?.body ?? "{}")); + expect(body.model).toBe("anthropic/claude-sonnet-4-5"); + expect(body.temperature).toBe(0); + expect(body.max_tokens).toBe(300); + expect(body.response_format).toEqual({ + type: "json_schema", + json_schema: expect.objectContaining({ + name: "gateway_content_filter", + strict: true, + }), + }); + expect(body.plugins).toEqual([{ id: "response-healing" }]); + expect(body.messages[0]?.role).toBe("system"); + expect(body.messages[1]?.content).toContain( + "I want to attack someone.", + ); + expect(body.messages[1]?.content).toContain( + "Image references:\nremote-image: https://example.com/image.png", + ); + + return new Response( + JSON.stringify({ + id: "chatcmpl-moderation", + model: "anthropic/claude-sonnet-4-5", + choices: [ + { + message: { + content: JSON.stringify({ + flagged: true, + categories: { + violence: true, + }, + category_scores: { + violence: 0.93, + }, + reason: "Threatening violence.", + }), + }, + }, + ], + }), + { + status: 200, + headers: { + "Content-Type": "application/json", + "x-request-id": "upstream-custom-request-id", + }, + }, + ); + }); + + const result = await checkCustomContentFilter( + [ + { + role: "user", + content: [ + { + type: "text", + text: "I want to attack someone.", + }, + { + type: "image_url", + image_url: { + url: "https://example.com/image.png", + }, + }, + ], + }, + ], + { + requestId: "request-id", + organizationId: "org-id", + projectId: "project-id", + apiKeyId: "api-key-id", + }, + ); + + expect(fetchSpy).toHaveBeenCalledOnce(); + expect(result.flagged).toBe(true); + expect(result.model).toBe("anthropic/claude-sonnet-4-5"); + expect(result.upstreamRequestId).toBe("upstream-custom-request-id"); + expect(result.responses).toEqual([ + { + id: "chatcmpl-moderation", + model: "anthropic/claude-sonnet-4-5", + results: [ + { + flagged: true, + categories: { + violence: true, + }, + category_scores: { + violence: 0.93, + }, + reason: "Threatening violence.", + }, + ], + }, + ]); + }); + + it("uses the custom base url override when configured", async () => { + process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY = "custom-api-key"; + process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL = "anthropic/claude-sonnet-4-5"; + process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL = + "https://moderation.example.com/internal/v1"; + process.env.GATEWAY_URL = "https://gateway.example.com/v1"; + + const fetchSpy = vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response( + JSON.stringify({ + id: "chatcmpl-moderation", + model: "anthropic/claude-sonnet-4-5", + choices: [ + { + message: { + content: JSON.stringify({ + flagged: false, + categories: { + violence: false, + }, + category_scores: { + violence: 0.02, + }, + reason: "Safe.", + }), + }, + }, + ], + }), + { + status: 200, + headers: { + "Content-Type": "application/json", + 
"x-request-id": "upstream-custom-request-id", + }, + }, + ), + ); + + await checkCustomContentFilter( + [ + { + role: "user", + content: "Hello!", + }, + ], + { + requestId: "request-id", + organizationId: "org-id", + projectId: "project-id", + apiKeyId: "api-key-id", + }, + ); + + expect(fetchSpy).toHaveBeenCalledWith( + "https://moderation.example.com/internal/v1/chat/completions", + expect.any(Object), + ); + }); + + it("omits image references when image inclusion is disabled", async () => { + process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY = "custom-api-key"; + process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL = "anthropic/claude-sonnet-4-5"; + process.env.LLM_CONTENT_FILTER_CUSTOM_INCLUDE_IMAGES = "false"; + process.env.GATEWAY_URL = "https://gateway.example.com/v1"; + + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockImplementation(async (_input, init) => { + const body = JSON.parse(String(init?.body ?? "{}")); + expect(body.messages[1]?.content).toContain("Hello!"); + expect(body.messages[1]?.content).not.toContain("Image references:"); + expect(body.messages[1]?.content).not.toContain( + "https://example.com/image.png", + ); + + return new Response( + JSON.stringify({ + id: "chatcmpl-moderation", + model: "anthropic/claude-sonnet-4-5", + choices: [ + { + message: { + content: JSON.stringify({ + flagged: false, + categories: { + violence: false, + }, + category_scores: { + violence: 0.01, + }, + reason: "Safe.", + }), + }, + }, + ], + }), + { + status: 200, + headers: { + "Content-Type": "application/json", + "x-request-id": "upstream-custom-request-id", + }, + }, + ); + }); + + await checkCustomContentFilter( + [ + { + role: "user", + content: [ + { + type: "text", + text: "Hello!", + }, + { + type: "image_url", + image_url: { + url: "https://example.com/image.png", + }, + }, + ], + }, + ], + { + requestId: "request-id", + organizationId: "org-id", + projectId: "project-id", + apiKeyId: "api-key-id", + }, + ); + + expect(fetchSpy).toHaveBeenCalledOnce(); + }); + + it("parses JSON verdicts wrapped in code fences", async () => { + process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY = "custom-api-key"; + process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL = "openai/gpt-5-mini"; + process.env.GATEWAY_URL = "https://gateway.example.com"; + + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response( + JSON.stringify({ + id: "chatcmpl-moderation", + model: "openai/gpt-5-mini", + choices: [ + { + message: { + content: + '```json\n{"flagged":false,"categories":{"violence":false},"category_scores":{"violence":0.12},"reason":"Safe."}\n```', + }, + }, + ], + }), + { + status: 200, + headers: { + "Content-Type": "application/json", + "x-request-id": "upstream-custom-request-id", + }, + }, + ), + ); + + const result = await checkCustomContentFilter( + [ + { + role: "user", + content: "Hello!", + }, + ], + { + requestId: "request-id", + organizationId: "org-id", + projectId: "project-id", + apiKeyId: "api-key-id", + }, + ); + + expect(result.flagged).toBe(false); + expect(result.responses[0]?.results?.[0]?.category_scores).toEqual({ + violence: 0.12, + }); + }); + + it("fails open when custom moderation config is missing", async () => { + delete process.env.LLM_CONTENT_FILTER_CUSTOM_API_KEY; + delete process.env.LLM_CONTENT_FILTER_CUSTOM_MODEL; + process.env.GATEWAY_URL = "https://gateway.example.com"; + + const fetchSpy = vi.spyOn(globalThis, "fetch"); + const errorSpy = vi.spyOn(logger, "error").mockImplementation(() => {}); + + const result = await checkCustomContentFilter( + [ + { + role: "user", + 
content: "Hello!", + }, + ], + { + requestId: "request-id", + organizationId: "org-id", + projectId: "project-id", + apiKeyId: "api-key-id", + }, + ); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(result.flagged).toBe(false); + expect(result.responses).toEqual([]); + expect(errorSpy).toHaveBeenCalledWith( + "gateway_content_filter_error", + expect.objectContaining({ + mode: "custom", + requestId: "request-id", + error: + "LLM_CONTENT_FILTER_CUSTOM_API_KEY environment variable is required for custom content filter", + }), + expect.any(Error), + ); + }); +}); diff --git a/apps/gateway/src/chat/tools/custom-content-filter.ts b/apps/gateway/src/chat/tools/custom-content-filter.ts new file mode 100644 index 0000000000..e19de73f35 --- /dev/null +++ b/apps/gateway/src/chat/tools/custom-content-filter.ts @@ -0,0 +1,655 @@ +import { z } from "zod"; + +import { isCancellationError, isTimeoutError } from "@/lib/timeout-config.js"; + +import { logger } from "@llmgateway/logger"; + +import { + getCustomContentFilterBaseUrl, + getCustomContentFilterConfig, + type CustomContentFilterConfig, +} from "./check-content-filter.js"; +import { extractErrorCause } from "./extract-error-cause.js"; +import { buildOpenAIContentFilterTextInput } from "./openai-content-filter.js"; + +import type { ModerationApiPayload } from "@llmgateway/db"; +import type { BaseMessage, MessageContent } from "@llmgateway/models"; + +interface GatewayContentFilterContext { + requestId: string; + organizationId: string; + projectId: string; + apiKeyId: string; +} + +interface ErrorWithCode extends Error { + code?: string; + cause?: unknown; +} + +export interface CustomContentFilterCheckResult { + flagged: boolean; + model: string; + upstreamRequestId: string | null; + results: ModerationApiPayload["results"]; + responses: ModerationApiPayload[]; +} + +const CUSTOM_CONTENT_FILTER_TIMEOUT_MS = 60_000; +const CUSTOM_CONTENT_FILTER_MAX_TOKENS = 300; +const CUSTOM_CONTENT_FILTER_CATEGORIES = [ + "sexual", + "sexual/minors", + "harassment", + "harassment/threatening", + "hate", + "hate/threatening", + "illicit", + "illicit/violent", + "self-harm", + "self-harm/intent", + "self-harm/instructions", + "violence", + "violence/graphic", +] as const; +const CUSTOM_CONTENT_FILTER_SYSTEM_PROMPT = `You are the safety moderation layer for an LLM gateway. + +Classify the supplied conversation against these categories: +${CUSTOM_CONTENT_FILTER_CATEGORIES.map((category) => `- ${category}`).join("\n")} + +Return only valid JSON with this exact shape: +{ + "flagged": boolean, + "categories": { + "sexual": boolean, + "sexual/minors": boolean, + "harassment": boolean, + "harassment/threatening": boolean, + "hate": boolean, + "hate/threatening": boolean, + "illicit": boolean, + "illicit/violent": boolean, + "self-harm": boolean, + "self-harm/intent": boolean, + "self-harm/instructions": boolean, + "violence": boolean, + "violence/graphic": boolean + }, + "category_scores": { + "": number + }, + "reason": string +} + +Rules: +- Scores must be numbers from 0 to 1. +- Set flagged to true when the content is unsafe enough to block. +- If evidence is insufficient, set flagged to false and keep scores low. 
+- Do not wrap the JSON in markdown.`; +const CUSTOM_CONTENT_FILTER_JSON_SCHEMA = { + name: "gateway_content_filter", + strict: true, + schema: { + type: "object", + additionalProperties: false, + required: ["flagged", "categories", "category_scores", "reason"], + properties: { + flagged: { + type: "boolean", + }, + categories: { + type: "object", + additionalProperties: { + type: "boolean", + }, + }, + category_scores: { + type: "object", + additionalProperties: { + type: "number", + }, + }, + reason: { + type: "string", + }, + }, + }, +} as const; + +const customContentFilterVerdictSchema = z + .object({ + flagged: z.boolean(), + categories: z.record(z.boolean()).optional(), + category_scores: z.record(z.number()).optional(), + reason: z.string().optional(), + }) + .passthrough(); + +const gatewayChatCompletionResponseSchema = z + .object({ + id: z.string().optional(), + model: z.string().optional(), + choices: z + .array( + z.object({ + message: z + .object({ + content: z + .union([ + z.string(), + z + .array( + z + .object({ + text: z.string().optional(), + }) + .passthrough(), + ) + .optional(), + ]) + .nullable() + .optional(), + }) + .passthrough() + .optional(), + }), + ) + .min(1), + }) + .passthrough(); + +function getCustomContentFilterUrl(baseUrl: string): string { + return new URL("v1/chat/completions", `${baseUrl}/`).toString(); +} + +function getImageReference(part: MessageContent): string | null { + if (part.type === "image_url") { + return `remote-image: ${part.image_url.url}`; + } + + if (part.type === "image") { + return `inline-image: media_type=${part.source.media_type}, bytes=${part.source.data.length}`; + } + + return null; +} + +function buildCustomContentFilterInput( + messages: BaseMessage[], + includeImages: boolean, +): string { + const textSummary = buildOpenAIContentFilterTextInput(messages); + const imageReferences: string[] = []; + + if (includeImages) { + for (const message of messages) { + if (!Array.isArray(message.content)) { + continue; + } + + for (const part of message.content) { + const imageReference = getImageReference(part); + if (imageReference) { + imageReferences.push(imageReference); + } + } + } + } + + const sections = [ + `Conversation:\n${textSummary.length > 0 ? textSummary : "[no text content]"}`, + ]; + + if (imageReferences.length > 0) { + sections.push(`Image references:\n${imageReferences.join("\n")}`); + } + + return sections.join("\n\n"); +} + +function extractAssistantText(content: unknown): string { + if (typeof content === "string") { + return content; + } + + if (!Array.isArray(content)) { + return ""; + } + + return content + .map((part) => { + if ( + typeof part === "object" && + part !== null && + "text" in part && + typeof part.text === "string" + ) { + return part.text; + } + + return ""; + }) + .join(""); +} + +function extractJsonObject(text: string): string | null { + const trimmed = text.trim(); + if (trimmed.length === 0) { + return null; + } + + const fencedMatch = trimmed.match(/^```(?:json)?\s*([\s\S]*?)\s*```$/i); + const unwrapped = fencedMatch?.[1]?.trim() ?? 
trimmed; + const firstBraceIndex = unwrapped.indexOf("{"); + const lastBraceIndex = unwrapped.lastIndexOf("}"); + + if ( + firstBraceIndex === -1 || + lastBraceIndex === -1 || + lastBraceIndex <= firstBraceIndex + ) { + return null; + } + + return unwrapped.slice(firstBraceIndex, lastBraceIndex + 1); +} + +function getErrorCode(error: unknown): string | undefined { + if (!(error instanceof Error)) { + return undefined; + } + + const directCode = + typeof (error as ErrorWithCode).code === "string" + ? (error as ErrorWithCode).code + : undefined; + if (directCode) { + return directCode; + } + + const visited = new Set([error]); + let current = (error as ErrorWithCode).cause; + for (let depth = 0; depth < 5; depth++) { + if (!(current instanceof Error) || visited.has(current)) { + return undefined; + } + visited.add(current); + + if (typeof (current as ErrorWithCode).code === "string") { + return (current as ErrorWithCode).code; + } + + current = (current as ErrorWithCode).cause; + } + + return undefined; +} + +function buildModerationErrorDetails(error: unknown): Record<string, unknown> { + if (!(error instanceof Error)) { + return { + error: String(error), + errorName: + error === null + ? "NullThrownValue" + : typeof error === "undefined" + ? "UndefinedThrownValue" + : typeof error, + }; + } + + const errorCause = extractErrorCause(error); + const errorCode = getErrorCode(error); + + return { + error: error.message, + errorName: error.name, + ...(errorCause ? { errorCause } : {}), + ...(errorCode ? { errorCode } : {}), + }; +} + +function getFlaggedCategories(payload: ModerationApiPayload): string[] { + const result = payload.results?.[0]; + if (!result) { + return []; + } + + const categories = new Set<string>(); + + for (const [category, isFlagged] of Object.entries(result.categories ?? {})) { + if (isFlagged) { + categories.add(category); + } + } + + for (const [category, score] of Object.entries( + result.category_scores ?? {}, + )) { + if (score > 0.5) { + categories.add(category); + } + } + + return [...categories]; +} + +function createFailedCustomContentFilterResult( + model: string, + upstreamRequestId: string | null = null, +): CustomContentFilterCheckResult { + return { + flagged: false, + model, + upstreamRequestId, + results: [], + responses: [], + }; +} + +function logModerationResult( + context: GatewayContentFilterContext, + payload: Record<string, unknown>, +) { + logger.debug("gateway_content_filter", { + provider: "llmgateway", + mode: "custom", + requestId: context.requestId, + organizationId: context.organizationId, + projectId: context.projectId, + apiKeyId: context.apiKeyId, + ...payload, + }); +} + +function logModerationError( + context: GatewayContentFilterContext, + payload: Record<string, unknown>, + error?: unknown, +) { + const logPayload = { + provider: "llmgateway", + mode: "custom", + requestId: context.requestId, + organizationId: context.organizationId, + projectId: context.projectId, + apiKeyId: context.apiKeyId, + ...payload, + ...buildModerationErrorDetails(error), + }; + + if (error instanceof Error) { + logger.error("gateway_content_filter_error", logPayload, error); + return; + } + + logger.error("gateway_content_filter_error", logPayload); +} + +function buildCustomModerationPayload( + responseId: string | undefined, + responseModel: string | undefined, + verdict: z.infer<typeof customContentFilterVerdictSchema>, +): ModerationApiPayload { + const categories = verdict.categories ?? {}; + const categoryScores = verdict.category_scores ?? 
{}; + const inferredFlagged = + verdict.flagged || + Object.values(categories).some(Boolean) || + Object.values(categoryScores).some((score) => score > 0.5); + + return { + ...(responseId ? { id: responseId } : {}), + ...(responseModel ? { model: responseModel } : {}), + results: [ + { + flagged: inferredFlagged, + categories, + category_scores: categoryScores, + ...(verdict.reason ? { reason: verdict.reason } : {}), + }, + ], + }; +} + +async function runCustomContentFilterRequest( + messages: BaseMessage[], + context: GatewayContentFilterContext, + config: CustomContentFilterConfig, + signal: AbortSignal, +): Promise { + const startTime = Date.now(); + let upstreamResponse: Response; + let upstreamText: string; + + try { + upstreamResponse = await fetch(getCustomContentFilterUrl(config.baseUrl), { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${config.apiKey}`, + "X-Client-Request-Id": context.requestId, + }, + body: JSON.stringify({ + model: config.model, + temperature: 0, + max_tokens: CUSTOM_CONTENT_FILTER_MAX_TOKENS, + response_format: { + type: "json_schema", + json_schema: CUSTOM_CONTENT_FILTER_JSON_SCHEMA, + }, + plugins: [ + { + id: "response-healing", + }, + ], + messages: [ + { + role: "system", + content: CUSTOM_CONTENT_FILTER_SYSTEM_PROMPT, + }, + { + role: "user", + content: buildCustomContentFilterInput( + messages, + config.includeImages, + ), + }, + ], + }), + signal, + }); + upstreamText = await upstreamResponse.text(); + } catch (error) { + if (signal.aborted || isCancellationError(error)) { + throw error; + } + + logModerationError( + context, + { + durationMs: Date.now() - startTime, + url: getCustomContentFilterUrl(config.baseUrl), + timeout: isTimeoutError(error), + }, + error, + ); + + return createFailedCustomContentFilterResult(config.model); + } + + let responseJson: unknown = null; + if (upstreamText.length > 0) { + try { + responseJson = JSON.parse(upstreamText); + } catch { + responseJson = upstreamText; + } + } + + const upstreamRequestId = upstreamResponse.headers.get("x-request-id"); + if (!upstreamResponse.ok) { + logModerationError(context, { + durationMs: Date.now() - startTime, + status: upstreamResponse.status, + statusText: upstreamResponse.statusText, + upstreamRequestId, + response: responseJson, + }); + + return createFailedCustomContentFilterResult( + config.model, + upstreamRequestId, + ); + } + + const parsedResponse = + gatewayChatCompletionResponseSchema.safeParse(responseJson); + if (!parsedResponse.success) { + logModerationError(context, { + durationMs: Date.now() - startTime, + upstreamRequestId, + response: responseJson, + }); + + return createFailedCustomContentFilterResult( + config.model, + upstreamRequestId, + ); + } + + const assistantText = extractAssistantText( + parsedResponse.data.choices[0]?.message?.content, + ); + const verdictJson = extractJsonObject(assistantText); + if (!verdictJson) { + logModerationError(context, { + durationMs: Date.now() - startTime, + upstreamRequestId, + response: responseJson, + assistantText, + }); + + return createFailedCustomContentFilterResult( + config.model, + upstreamRequestId, + ); + } + + let verdictJsonValue: unknown; + try { + verdictJsonValue = JSON.parse(verdictJson); + } catch (error) { + logModerationError( + context, + { + durationMs: Date.now() - startTime, + upstreamRequestId, + response: responseJson, + assistantText, + }, + error, + ); + + return createFailedCustomContentFilterResult( + config.model, + upstreamRequestId, + ); + } + 
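+ // The verdict text parsed as JSON above, but it is still untrusted model
+ // output, so validate its shape with the zod schema before reading fields.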
+ const parsedVerdict = + customContentFilterVerdictSchema.safeParse(verdictJsonValue); + if (!parsedVerdict.success) { + logModerationError(context, { + durationMs: Date.now() - startTime, + upstreamRequestId, + response: responseJson, + assistantText, + }); + + return createFailedCustomContentFilterResult( + config.model, + upstreamRequestId, + ); + } + + const moderationPayload = buildCustomModerationPayload( + parsedResponse.data.id, + parsedResponse.data.model ?? config.model, + parsedVerdict.data, + ); + + logModerationResult(context, { + durationMs: Date.now() - startTime, + flagged: moderationPayload.results?.[0]?.flagged === true, + model: moderationPayload.model ?? config.model, + upstreamRequestId, + flaggedCategories: getFlaggedCategories(moderationPayload), + results: moderationPayload.results ?? [], + }); + + return { + flagged: moderationPayload.results?.[0]?.flagged === true, + model: moderationPayload.model ?? config.model, + upstreamRequestId, + results: moderationPayload.results ?? [], + responses: [moderationPayload], + }; +} + +export async function checkCustomContentFilter( + messages: BaseMessage[], + context: GatewayContentFilterContext, + requestSignal?: AbortSignal, +): Promise { + const startTime = Date.now(); + let config: CustomContentFilterConfig; + + try { + config = getCustomContentFilterConfig(); + } catch (error) { + logModerationError( + context, + { + durationMs: Date.now() - startTime, + url: getCustomContentFilterUrl(getCustomContentFilterBaseUrl()), + timeout: false, + }, + error, + ); + + return createFailedCustomContentFilterResult("custom-content-filter"); + } + + const signal = requestSignal + ? AbortSignal.any([ + AbortSignal.timeout(CUSTOM_CONTENT_FILTER_TIMEOUT_MS), + requestSignal, + ]) + : AbortSignal.timeout(CUSTOM_CONTENT_FILTER_TIMEOUT_MS); + + try { + return await runCustomContentFilterRequest( + messages, + context, + config, + signal, + ); + } catch (error) { + if (requestSignal?.aborted || isCancellationError(error)) { + throw error; + } + + logModerationError( + context, + { + durationMs: Date.now() - startTime, + url: getCustomContentFilterUrl(getCustomContentFilterBaseUrl()), + timeout: isTimeoutError(error), + }, + error, + ); + + return createFailedCustomContentFilterResult(config.model); + } +}
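Reviewer note: a minimal TypeScript sketch of how the base-URL handling in this diff composes, using the real helpers added above but a hypothetical moderation host. normalizeCustomContentFilterBaseUrl strips a trailing "/v1", and getCustomContentFilterUrl re-appends "v1/chat/completions" via WHATWG URL resolution, so a configured ".../internal/v1" cannot produce a doubled "/v1/v1" path.

// Sketch only — "moderation.example.com" is a hypothetical host.
import { getCustomContentFilterBaseUrl } from "./check-content-filter.js";

process.env.LLM_CONTENT_FILTER_CUSTOM_BASE_URL =
	"https://moderation.example.com/internal/v1";

const baseUrl = getCustomContentFilterBaseUrl();
// => "https://moderation.example.com/internal" (trailing "/v1" stripped)

// The same join getCustomContentFilterUrl performs before the moderation call:
const url = new URL("v1/chat/completions", `${baseUrl}/`).toString();
// => "https://moderation.example.com/internal/v1/chat/completions"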