From d5a48df43fcf85749cf0f766ba833f1011bace8f Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 11:20:52 +0200 Subject: [PATCH 1/5] feat: add Model Serving connector and plugin Add the core Model Serving plugin that provides an authenticated proxy to Databricks Model Serving endpoints. Includes the connector layer (SDK client wrapper) and the plugin layer (Express routes for invoke/stream). Also adds UPSTREAM_ERROR SSE error code for propagating API errors. Signed-off-by: Pawel Kosiec --- .../api/appkit/Interface.EndpointConfig.md | 21 ++ .../appkit/Interface.ServingEndpointEntry.md | 27 ++ .../Interface.ServingEndpointRegistry.md | 5 + .../api/appkit/TypeAlias.ServingFactory.md | 19 + docs/docs/api/appkit/index.md | 4 + docs/docs/api/appkit/typedoc-sidebar.ts | 20 ++ docs/static/appkit-ui/styles.gen.css | 28 +- .../appkit/src/connectors/serving/client.ts | 223 ++++++++++++ .../connectors/serving/tests/client.test.ts | 303 ++++++++++++++++ .../appkit/src/connectors/serving/types.ts | 4 + packages/appkit/src/index.ts | 8 +- packages/appkit/src/plugins/index.ts | 1 + .../appkit/src/plugins/serving/defaults.ts | 26 ++ packages/appkit/src/plugins/serving/index.ts | 2 + .../appkit/src/plugins/serving/manifest.json | 54 +++ .../src/plugins/serving/schema-filter.ts | 127 +++++++ .../appkit/src/plugins/serving/serving.ts | 303 ++++++++++++++++ .../serving/tests/schema-filter.test.ts | 141 ++++++++ .../src/plugins/serving/tests/serving.test.ts | 339 ++++++++++++++++++ packages/appkit/src/plugins/serving/types.ts | 67 ++++ packages/appkit/src/stream/stream-manager.ts | 8 + packages/appkit/src/stream/types.ts | 1 + 22 files changed, 1727 insertions(+), 4 deletions(-) create mode 100644 docs/docs/api/appkit/Interface.EndpointConfig.md create mode 100644 docs/docs/api/appkit/Interface.ServingEndpointEntry.md create mode 100644 docs/docs/api/appkit/Interface.ServingEndpointRegistry.md create mode 100644 docs/docs/api/appkit/TypeAlias.ServingFactory.md create mode 100644 packages/appkit/src/connectors/serving/client.ts create mode 100644 packages/appkit/src/connectors/serving/tests/client.test.ts create mode 100644 packages/appkit/src/connectors/serving/types.ts create mode 100644 packages/appkit/src/plugins/serving/defaults.ts create mode 100644 packages/appkit/src/plugins/serving/index.ts create mode 100644 packages/appkit/src/plugins/serving/manifest.json create mode 100644 packages/appkit/src/plugins/serving/schema-filter.ts create mode 100644 packages/appkit/src/plugins/serving/serving.ts create mode 100644 packages/appkit/src/plugins/serving/tests/schema-filter.test.ts create mode 100644 packages/appkit/src/plugins/serving/tests/serving.test.ts create mode 100644 packages/appkit/src/plugins/serving/types.ts diff --git a/docs/docs/api/appkit/Interface.EndpointConfig.md b/docs/docs/api/appkit/Interface.EndpointConfig.md new file mode 100644 index 00000000..6ee94aa3 --- /dev/null +++ b/docs/docs/api/appkit/Interface.EndpointConfig.md @@ -0,0 +1,21 @@ +# Interface: EndpointConfig + +## Properties + +### env + +```ts +env: string; +``` + +Environment variable holding the endpoint name. + +*** + +### servedModel? + +```ts +optional servedModel: string; +``` + +Target a specific served model (bypasses traffic routing). diff --git a/docs/docs/api/appkit/Interface.ServingEndpointEntry.md b/docs/docs/api/appkit/Interface.ServingEndpointEntry.md new file mode 100644 index 00000000..fa054c3f --- /dev/null +++ b/docs/docs/api/appkit/Interface.ServingEndpointEntry.md @@ -0,0 +1,27 @@ +# Interface: ServingEndpointEntry + +Shape of a single registry entry. + +## Properties + +### chunk + +```ts +chunk: unknown; +``` + +*** + +### request + +```ts +request: Record; +``` + +*** + +### response + +```ts +response: unknown; +``` diff --git a/docs/docs/api/appkit/Interface.ServingEndpointRegistry.md b/docs/docs/api/appkit/Interface.ServingEndpointRegistry.md new file mode 100644 index 00000000..defe5270 --- /dev/null +++ b/docs/docs/api/appkit/Interface.ServingEndpointRegistry.md @@ -0,0 +1,5 @@ +# Interface: ServingEndpointRegistry + +Registry interface for serving endpoint type generation. +Empty by default — augmented by the Vite type generator's `.d.ts` output via module augmentation. +When populated, provides autocomplete for alias names and typed request/response/chunk per endpoint. diff --git a/docs/docs/api/appkit/TypeAlias.ServingFactory.md b/docs/docs/api/appkit/TypeAlias.ServingFactory.md new file mode 100644 index 00000000..9ccafef5 --- /dev/null +++ b/docs/docs/api/appkit/TypeAlias.ServingFactory.md @@ -0,0 +1,19 @@ +# Type Alias: ServingFactory + +```ts +type ServingFactory = keyof ServingEndpointRegistry extends never ? (alias?: string) => ServingEndpointMethods : (alias: K) => ServingEndpointMethods; +``` + +Factory function returned by `AppKit.serving`. + +This is a conditional type that adapts based on whether `ServingEndpointRegistry` +has been populated via module augmentation (generated by `appKitServingTypesPlugin()`): + +- **Registry empty (default):** `(alias?: string) => ServingEndpointMethods` — + accepts any alias string with untyped request/response/chunk. +- **Registry populated:** `(alias: K) => ServingEndpointMethods<...>` — + restricts `alias` to known endpoint keys and infers typed request/response/chunk + from the registry entry. + +Run `appKitServingTypesPlugin()` in your Vite config to generate the registry +augmentation and enable full type safety. diff --git a/docs/docs/api/appkit/index.md b/docs/docs/api/appkit/index.md index b5fb7ce0..f4685e04 100644 --- a/docs/docs/api/appkit/index.md +++ b/docs/docs/api/appkit/index.md @@ -33,6 +33,7 @@ plugin architecture, and React integration. | [BasePluginConfig](Interface.BasePluginConfig.md) | Base configuration interface for AppKit plugins | | [CacheConfig](Interface.CacheConfig.md) | Configuration for the CacheInterceptor. Controls TTL, size limits, storage backend, and probabilistic cleanup. | | [DatabaseCredential](Interface.DatabaseCredential.md) | Database credentials with OAuth token for Postgres connection | +| [EndpointConfig](Interface.EndpointConfig.md) | - | | [GenerateDatabaseCredentialRequest](Interface.GenerateDatabaseCredentialRequest.md) | Request parameters for generating database OAuth credentials | | [ITelemetry](Interface.ITelemetry.md) | Plugin-facing interface for OpenTelemetry instrumentation. Provides a thin abstraction over OpenTelemetry APIs for plugins. | | [LakebasePoolConfig](Interface.LakebasePoolConfig.md) | Configuration for creating a Lakebase connection pool | @@ -42,6 +43,8 @@ plugin architecture, and React integration. | [ResourceEntry](Interface.ResourceEntry.md) | Internal representation of a resource in the registry. Extends ResourceRequirement with resolution state and plugin ownership. | | [ResourceFieldEntry](Interface.ResourceFieldEntry.md) | Defines a single field for a resource. Each field has its own environment variable and optional description. Single-value types use one key (e.g. id); multi-value types (database, secret) use multiple (e.g. instance_name, database_name or scope, key). | | [ResourceRequirement](Interface.ResourceRequirement.md) | Declares a resource requirement for a plugin. Can be defined statically in a manifest or dynamically via getResourceRequirements(). Narrows the generated base: type → ResourceType enum, permission → ResourcePermission union. | +| [ServingEndpointEntry](Interface.ServingEndpointEntry.md) | Shape of a single registry entry. | +| [ServingEndpointRegistry](Interface.ServingEndpointRegistry.md) | Registry interface for serving endpoint type generation. Empty by default — augmented by the Vite type generator's `.d.ts` output via module augmentation. When populated, provides autocomplete for alias names and typed request/response/chunk per endpoint. | | [StreamExecutionSettings](Interface.StreamExecutionSettings.md) | Execution settings for streaming endpoints. Extends PluginExecutionSettings with SSE stream configuration. | | [TelemetryConfig](Interface.TelemetryConfig.md) | OpenTelemetry configuration for AppKit applications | | [ValidationResult](Interface.ValidationResult.md) | Result of validating all registered resources against the environment. | @@ -54,6 +57,7 @@ plugin architecture, and React integration. | [IAppRouter](TypeAlias.IAppRouter.md) | Express router type for plugin route registration | | [PluginData](TypeAlias.PluginData.md) | Tuple of plugin class, config, and name. Created by `toPlugin()` and passed to `createApp()`. | | [ResourcePermission](TypeAlias.ResourcePermission.md) | Union of all possible permission levels across all resource types. | +| [ServingFactory](TypeAlias.ServingFactory.md) | Factory function returned by `AppKit.serving`. | | [ToPlugin](TypeAlias.ToPlugin.md) | Factory function type returned by `toPlugin()`. Accepts optional config and returns a PluginData tuple. | ## Variables diff --git a/docs/docs/api/appkit/typedoc-sidebar.ts b/docs/docs/api/appkit/typedoc-sidebar.ts index 2f17b1d2..91815e3d 100644 --- a/docs/docs/api/appkit/typedoc-sidebar.ts +++ b/docs/docs/api/appkit/typedoc-sidebar.ts @@ -97,6 +97,11 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/Interface.DatabaseCredential", label: "DatabaseCredential" }, + { + type: "doc", + id: "api/appkit/Interface.EndpointConfig", + label: "EndpointConfig" + }, { type: "doc", id: "api/appkit/Interface.GenerateDatabaseCredentialRequest", @@ -142,6 +147,16 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/Interface.ResourceRequirement", label: "ResourceRequirement" }, + { + type: "doc", + id: "api/appkit/Interface.ServingEndpointEntry", + label: "ServingEndpointEntry" + }, + { + type: "doc", + id: "api/appkit/Interface.ServingEndpointRegistry", + label: "ServingEndpointRegistry" + }, { type: "doc", id: "api/appkit/Interface.StreamExecutionSettings", @@ -183,6 +198,11 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/TypeAlias.ResourcePermission", label: "ResourcePermission" }, + { + type: "doc", + id: "api/appkit/TypeAlias.ServingFactory", + label: "ServingFactory" + }, { type: "doc", id: "api/appkit/TypeAlias.ToPlugin", diff --git a/docs/static/appkit-ui/styles.gen.css b/docs/static/appkit-ui/styles.gen.css index 9a9a38eb..a2192039 100644 --- a/docs/static/appkit-ui/styles.gen.css +++ b/docs/static/appkit-ui/styles.gen.css @@ -831,9 +831,6 @@ .max-w-\[calc\(100\%-2rem\)\] { max-width: calc(100% - 2rem); } - .max-w-full { - max-width: 100%; - } .max-w-max { max-width: max-content; } @@ -4514,6 +4511,11 @@ width: calc(var(--spacing) * 5); } } + .\[\&_\[data-slot\=scroll-area-viewport\]\>div\]\:\!block { + & [data-slot=scroll-area-viewport]>div { + display: block !important; + } + } .\[\&_a\]\:underline { & a { text-decoration-line: underline; @@ -4637,11 +4639,26 @@ color: var(--muted-foreground); } } + .\[\&_table\]\:block { + & table { + display: block; + } + } + .\[\&_table\]\:max-w-full { + & table { + max-width: 100%; + } + } .\[\&_table\]\:border-collapse { & table { border-collapse: collapse; } } + .\[\&_table\]\:overflow-x-auto { + & table { + overflow-x: auto; + } + } .\[\&_table\]\:text-xs { & table { font-size: var(--text-xs); @@ -4851,6 +4868,11 @@ width: 100%; } } + .\[\&\>\*\]\:min-w-0 { + &>* { + min-width: calc(var(--spacing) * 0); + } + } .\[\&\>\*\]\:focus-visible\:relative { &>* { &:focus-visible { diff --git a/packages/appkit/src/connectors/serving/client.ts b/packages/appkit/src/connectors/serving/client.ts new file mode 100644 index 00000000..6254426d --- /dev/null +++ b/packages/appkit/src/connectors/serving/client.ts @@ -0,0 +1,223 @@ +import { ApiError, type WorkspaceClient } from "@databricks/sdk-experimental"; +import { createLogger } from "../../logging/logger"; +import type { ServingInvokeOptions } from "./types"; + +const logger = createLogger("connectors:serving"); + +/** + * Builds the invocation URL for a serving endpoint. + * Uses `/served-models/{model}/invocations` when servedModel is specified, + * otherwise `/serving-endpoints/{name}/invocations`. + */ +function buildInvocationUrl( + host: string, + endpointName: string, + servedModel?: string, +): string { + const base = host.startsWith("http") ? host : `https://${host}`; + const encodedName = encodeURIComponent(endpointName); + const path = servedModel + ? `/serving-endpoints/${encodedName}/served-models/${encodeURIComponent(servedModel)}/invocations` + : `/serving-endpoints/${encodedName}/invocations`; + return new URL(path, base).toString(); +} + +/** + * Maps upstream Databricks error status codes to appropriate proxy responses. + */ +function mapUpstreamError( + status: number, + body: string, + headers: Headers, +): ApiError { + const safeMessage = body.length > 500 ? `${body.slice(0, 500)}...` : body; + + let parsed: { message?: string; error?: string } = {}; + try { + parsed = JSON.parse(body); + } catch { + // body is not JSON + } + + const message = parsed.message || parsed.error || safeMessage; + + switch (true) { + case status === 400: + return new ApiError(message, "BAD_REQUEST", 400, undefined, []); + case status === 401 || status === 403: + logger.warn("Authentication failure from serving endpoint: %s", message); + return new ApiError(message, "AUTH_FAILURE", status, undefined, []); + case status === 404: + return new ApiError(message, "NOT_FOUND", 404, undefined, []); + case status === 429: { + const retryAfter = headers.get("retry-after"); + const retryMessage = retryAfter + ? `${message} (retry-after: ${retryAfter})` + : message; + return new ApiError(retryMessage, "RATE_LIMITED", 429, undefined, []); + } + case status === 503: + return new ApiError( + "Endpoint loading, retry shortly", + "SERVICE_UNAVAILABLE", + 503, + undefined, + [], + ); + case status >= 500: + return new ApiError(message, "BAD_GATEWAY", 502, undefined, []); + default: + return new ApiError(message, "UNKNOWN", status, undefined, []); + } +} + +/** + * Invokes a serving endpoint and returns the parsed JSON response. + */ +export async function invoke( + client: WorkspaceClient, + endpointName: string, + body: Record, + options?: ServingInvokeOptions, +): Promise { + const host = client.config.host; + if (!host) { + throw new Error( + "Databricks host is not configured. Set DATABRICKS_HOST or configure client.config.host.", + ); + } + + const url = buildInvocationUrl(host, endpointName, options?.servedModel); + + // Always strip `stream` from the body — the connector controls this + const { stream: _stream, ...cleanBody } = body; + + const headers = new Headers({ + "Content-Type": "application/json", + Accept: "application/json", + }); + await client.config.authenticate(headers); + + logger.debug("Invoking endpoint %s at %s", endpointName, url); + + const res = await fetch(url, { + method: "POST", + headers, + body: JSON.stringify(cleanBody), + signal: options?.signal, + }); + + if (!res.ok) { + const text = await res.text(); + throw mapUpstreamError(res.status, text, res.headers); + } + + return res.json(); +} + +/** + * Invokes a serving endpoint with streaming enabled. + * Yields parsed JSON chunks from the NDJSON SSE response. + */ +export async function* stream( + client: WorkspaceClient, + endpointName: string, + body: Record, + options?: ServingInvokeOptions, +): AsyncGenerator { + const host = client.config.host; + if (!host) { + throw new Error( + "Databricks host is not configured. Set DATABRICKS_HOST or configure client.config.host.", + ); + } + + const url = buildInvocationUrl(host, endpointName, options?.servedModel); + + // Strip any user-provided `stream` and inject `stream: true` + const { stream: _stream, ...cleanBody } = body; + const streamBody = { ...cleanBody, stream: true }; + + const headers = new Headers({ + "Content-Type": "application/json", + Accept: "text/event-stream", + }); + await client.config.authenticate(headers); + + logger.debug("Streaming from endpoint %s at %s", endpointName, url); + + const res = await fetch(url, { + method: "POST", + headers, + body: JSON.stringify(streamBody), + signal: options?.signal, + }); + + if (!res.ok) { + const text = await res.text(); + throw mapUpstreamError(res.status, text, res.headers); + } + + if (!res.body) { + throw new Error("Response body is null — streaming not supported"); + } + + const reader = res.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + const MAX_BUFFER_SIZE = 1024 * 1024; // 1 MB + + try { + while (true) { + if (options?.signal?.aborted) break; + + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + + if (buffer.length > MAX_BUFFER_SIZE) { + logger.warn( + "Stream buffer exceeded %d bytes, discarding incomplete data", + MAX_BUFFER_SIZE, + ); + buffer = ""; + } + + // Process complete lines from the buffer + const lines = buffer.split("\n"); + // Keep the last (potentially incomplete) line in the buffer + buffer = lines.pop() ?? ""; + + for (const line of lines) { + const trimmed = line.trim(); + if (!trimmed || trimmed.startsWith(":")) continue; // skip empty lines and SSE comments + if (trimmed === "data: [DONE]") return; + + if (trimmed.startsWith("data: ")) { + const jsonStr = trimmed.slice(6); + try { + yield JSON.parse(jsonStr); + } catch { + logger.warn("Failed to parse streaming chunk: %s", jsonStr); + } + } + } + } + + // Process any remaining data in the buffer + if (buffer.trim() && !options?.signal?.aborted) { + const trimmed = buffer.trim(); + if (trimmed.startsWith("data: ") && trimmed !== "data: [DONE]") { + try { + yield JSON.parse(trimmed.slice(6)); + } catch { + logger.warn("Failed to parse final streaming chunk: %s", trimmed); + } + } + } + } finally { + reader.cancel().catch(() => {}); + reader.releaseLock(); + } +} diff --git a/packages/appkit/src/connectors/serving/tests/client.test.ts b/packages/appkit/src/connectors/serving/tests/client.test.ts new file mode 100644 index 00000000..6af859ae --- /dev/null +++ b/packages/appkit/src/connectors/serving/tests/client.test.ts @@ -0,0 +1,303 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { invoke, stream } from "../client"; + +const mockAuthenticate = vi.fn(); + +function createMockClient(host = "https://test.databricks.com") { + return { + config: { + host, + authenticate: mockAuthenticate, + }, + } as any; +} + +describe("Serving Connector", () => { + beforeEach(() => { + mockAuthenticate.mockResolvedValue(undefined); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("invoke", () => { + test("constructs correct URL for endpoint invocation", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke(client, "my-endpoint", { messages: [] }); + + expect(fetchSpy).toHaveBeenCalledWith( + "https://test.databricks.com/serving-endpoints/my-endpoint/invocations", + expect.objectContaining({ method: "POST" }), + ); + }); + + test("constructs correct URL with servedModel override", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke( + client, + "my-endpoint", + { messages: [] }, + { servedModel: "llama-v2" }, + ); + + expect(fetchSpy).toHaveBeenCalledWith( + "https://test.databricks.com/serving-endpoints/my-endpoint/served-models/llama-v2/invocations", + expect.objectContaining({ method: "POST" }), + ); + }); + + test("authenticates request headers", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke(client, "my-endpoint", { messages: [] }); + + expect(mockAuthenticate).toHaveBeenCalledWith(expect.any(Headers)); + }); + + test("strips stream property from body", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke(client, "my-endpoint", { + messages: [], + stream: true, + temperature: 0.7, + }); + + const body = JSON.parse(fetchSpy.mock.calls[0][1]?.body as string); + expect(body).toEqual({ messages: [], temperature: 0.7 }); + expect(body.stream).toBeUndefined(); + }); + + test("returns parsed JSON response", async () => { + const responseData = { choices: [{ message: { content: "Hello" } }] }; + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(responseData), { status: 200 }), + ); + + const client = createMockClient(); + const result = await invoke(client, "my-endpoint", { messages: [] }); + + expect(result).toEqual(responseData); + }); + + test("throws ApiError on 400 response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ message: "Invalid params" }), { + status: 400, + }), + ); + + const client = createMockClient(); + await expect( + invoke(client, "my-endpoint", { messages: [] }), + ).rejects.toThrow("Invalid params"); + }); + + test("throws ApiError on 404 response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ message: "Endpoint not found" }), { + status: 404, + }), + ); + + const client = createMockClient(); + await expect( + invoke(client, "my-endpoint", { messages: [] }), + ).rejects.toThrow("Endpoint not found"); + }); + + test("maps 5xx to 502 bad gateway", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ message: "Internal error" }), { + status: 500, + }), + ); + + const client = createMockClient(); + try { + await invoke(client, "my-endpoint", { messages: [] }); + expect.unreachable("Should have thrown"); + } catch (err: any) { + expect(err.statusCode).toBe(502); + } + }); + + test("forwards AbortSignal", async () => { + const controller = new AbortController(); + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke( + client, + "my-endpoint", + { messages: [] }, + { signal: controller.signal }, + ); + + expect(fetchSpy.mock.calls[0][1]?.signal).toBe(controller.signal); + }); + + test("throws when host is not configured", async () => { + const client = { + config: { + host: "", + authenticate: mockAuthenticate, + }, + } as any; + await expect( + invoke(client, "my-endpoint", { messages: [] }), + ).rejects.toThrow("Databricks host is not configured"); + }); + + test("prepends https:// to host without protocol", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient("test.databricks.com"); + await invoke(client, "my-endpoint", { messages: [] }); + + expect(fetchSpy.mock.calls[0][0]).toContain( + "https://test.databricks.com", + ); + }); + }); + + describe("stream", () => { + function createSSEResponse(chunks: string[]) { + const body = `${chunks.join("\n")}\n`; + return new Response(body, { + status: 200, + headers: { "Content-Type": "text/event-stream" }, + }); + } + + test("yields parsed NDJSON chunks", async () => { + const chunks = [ + 'data: {"choices":[{"delta":{"content":"Hello"}}]}', + 'data: {"choices":[{"delta":{"content":" world"}}]}', + "data: [DONE]", + ]; + + vi.spyOn(globalThis, "fetch").mockResolvedValue( + createSSEResponse(chunks), + ); + + const client = createMockClient(); + const results: unknown[] = []; + for await (const chunk of stream(client, "my-endpoint", { + messages: [], + })) { + results.push(chunk); + } + + expect(results).toEqual([ + { choices: [{ delta: { content: "Hello" } }] }, + { choices: [{ delta: { content: " world" } }] }, + ]); + }); + + test("injects stream: true into body", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue(createSSEResponse(["data: [DONE]"])); + + const client = createMockClient(); + // Consume the generator + for await (const _ of stream(client, "my-endpoint", { messages: [] })) { + // noop + } + + const body = JSON.parse(fetchSpy.mock.calls[0][1]?.body as string); + expect(body.stream).toBe(true); + }); + + test("strips user-provided stream and re-injects", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue(createSSEResponse(["data: [DONE]"])); + + const client = createMockClient(); + for await (const _ of stream(client, "my-endpoint", { + messages: [], + stream: false, + })) { + // noop + } + + const body = JSON.parse(fetchSpy.mock.calls[0][1]?.body as string); + expect(body.stream).toBe(true); + }); + + test("skips SSE comments and empty lines", async () => { + const chunks = [ + ": this is a comment", + "", + 'data: {"choices":[{"delta":{"content":"Hi"}}]}', + "", + "data: [DONE]", + ]; + + vi.spyOn(globalThis, "fetch").mockResolvedValue( + createSSEResponse(chunks), + ); + + const client = createMockClient(); + const results: unknown[] = []; + for await (const chunk of stream(client, "my-endpoint", { + messages: [], + })) { + results.push(chunk); + } + + expect(results).toHaveLength(1); + expect(results[0]).toEqual({ choices: [{ delta: { content: "Hi" } }] }); + }); + + test("throws on non-OK response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ message: "Rate limited" }), { + status: 429, + headers: { "Retry-After": "5" }, + }), + ); + + const client = createMockClient(); + try { + for await (const _ of stream(client, "my-endpoint", { messages: [] })) { + // noop + } + expect.unreachable("Should have thrown"); + } catch (err: any) { + expect(err.statusCode).toBe(429); + } + }); + }); +}); diff --git a/packages/appkit/src/connectors/serving/types.ts b/packages/appkit/src/connectors/serving/types.ts new file mode 100644 index 00000000..6dd1acba --- /dev/null +++ b/packages/appkit/src/connectors/serving/types.ts @@ -0,0 +1,4 @@ +export interface ServingInvokeOptions { + servedModel?: string; + signal?: AbortSignal; +} diff --git a/packages/appkit/src/index.ts b/packages/appkit/src/index.ts index 8db7f1d7..662a9178 100644 --- a/packages/appkit/src/index.ts +++ b/packages/appkit/src/index.ts @@ -48,7 +48,13 @@ export { } from "./errors"; // Plugin authoring export { Plugin, type ToPlugin, toPlugin } from "./plugin"; -export { analytics, files, genie, lakebase, server } from "./plugins"; +export { analytics, files, genie, lakebase, server, serving } from "./plugins"; +export type { + EndpointConfig, + ServingEndpointEntry, + ServingEndpointRegistry, + ServingFactory, +} from "./plugins/serving/types"; // Registry types and utilities for plugin manifests export type { ConfigSchema, diff --git a/packages/appkit/src/plugins/index.ts b/packages/appkit/src/plugins/index.ts index 7caa040f..4d58082f 100644 --- a/packages/appkit/src/plugins/index.ts +++ b/packages/appkit/src/plugins/index.ts @@ -3,3 +3,4 @@ export * from "./files"; export * from "./genie"; export * from "./lakebase"; export * from "./server"; +export * from "./serving"; diff --git a/packages/appkit/src/plugins/serving/defaults.ts b/packages/appkit/src/plugins/serving/defaults.ts new file mode 100644 index 00000000..1fea64c2 --- /dev/null +++ b/packages/appkit/src/plugins/serving/defaults.ts @@ -0,0 +1,26 @@ +import type { StreamExecutionSettings } from "shared"; + +export const servingInvokeDefaults = { + cache: { + enabled: false, + }, + retry: { + enabled: false, + }, + timeout: 120_000, +}; + +export const servingStreamDefaults: StreamExecutionSettings = { + default: { + cache: { + enabled: false, + }, + retry: { + enabled: false, + }, + timeout: 120_000, + }, + stream: { + bufferSize: 200, + }, +}; diff --git a/packages/appkit/src/plugins/serving/index.ts b/packages/appkit/src/plugins/serving/index.ts new file mode 100644 index 00000000..85caf33b --- /dev/null +++ b/packages/appkit/src/plugins/serving/index.ts @@ -0,0 +1,2 @@ +export * from "./serving"; +export * from "./types"; diff --git a/packages/appkit/src/plugins/serving/manifest.json b/packages/appkit/src/plugins/serving/manifest.json new file mode 100644 index 00000000..9ac0845f --- /dev/null +++ b/packages/appkit/src/plugins/serving/manifest.json @@ -0,0 +1,54 @@ +{ + "$schema": "https://databricks.github.io/appkit/schemas/plugin-manifest.schema.json", + "name": "serving", + "displayName": "Model Serving Plugin", + "description": "Authenticated proxy to Databricks Model Serving endpoints", + "resources": { + "required": [ + { + "type": "serving_endpoint", + "alias": "Serving Endpoint", + "resourceKey": "serving-endpoint", + "description": "Model Serving endpoint for inference", + "permission": "CAN_QUERY", + "fields": { + "name": { + "env": "DATABRICKS_SERVING_ENDPOINT", + "description": "Serving endpoint name" + } + } + } + ], + "optional": [] + }, + "config": { + "schema": { + "type": "object", + "properties": { + "endpoints": { + "type": "object", + "description": "Map of alias names to endpoint configurations", + "additionalProperties": { + "type": "object", + "properties": { + "env": { + "type": "string", + "description": "Environment variable holding the endpoint name" + }, + "servedModel": { + "type": "string", + "description": "Target a specific served model (bypasses traffic routing)" + } + }, + "required": ["env"] + } + }, + "timeout": { + "type": "number", + "default": 120000, + "description": "Request timeout in ms. Default: 120000 (2 min)" + } + } + } + } +} diff --git a/packages/appkit/src/plugins/serving/schema-filter.ts b/packages/appkit/src/plugins/serving/schema-filter.ts new file mode 100644 index 00000000..6e52294a --- /dev/null +++ b/packages/appkit/src/plugins/serving/schema-filter.ts @@ -0,0 +1,127 @@ +import fs from "node:fs/promises"; +import { createLogger } from "../../logging/logger"; + +const CACHE_VERSION = "1"; + +interface ServingCacheEntry { + hash: string; + requestType: string; + responseType: string; + chunkType: string | null; +} + +interface ServingCache { + version: string; + endpoints: Record; +} + +const logger = createLogger("serving:schema-filter"); + +function isValidCache(data: unknown): data is ServingCache { + return ( + typeof data === "object" && + data !== null && + "version" in data && + (data as ServingCache).version === CACHE_VERSION && + "endpoints" in data && + typeof (data as ServingCache).endpoints === "object" + ); +} + +/** + * Loads endpoint schemas from the type generation cache file. + * Returns a map of alias → allowed parameter keys. + */ +export async function loadEndpointSchemas( + cacheFile: string, +): Promise>> { + const allowlists = new Map>(); + + try { + const raw = await fs.readFile(cacheFile, "utf8"); + const parsed: unknown = JSON.parse(raw); + if (!isValidCache(parsed)) { + logger.warn("Serving types cache has invalid structure, skipping"); + return allowlists; + } + const cache = parsed; + + for (const [alias, entry] of Object.entries(cache.endpoints)) { + // Extract property keys from the requestType string + // The requestType is a TypeScript object type like "{ messages: ...; temperature: ...; }" + const keys = extractPropertyKeys(entry.requestType); + if (keys.size > 0) { + allowlists.set(alias, keys); + } + } + } catch (err) { + if ((err as NodeJS.ErrnoException).code !== "ENOENT") { + logger.warn( + "Failed to load serving types cache: %s", + (err as Error).message, + ); + } + // No cache → no filtering, passthrough mode + } + + return allowlists; +} + +/** + * Extracts top-level property keys from a TypeScript object type string. + * Matches patterns like `key:` or `key?:` at the first nesting level. + */ +function extractPropertyKeys(typeStr: string): Set { + const keys = new Set(); + // Match property names at the top level of the object type + // Looking for patterns: ` propertyName:` or ` propertyName?:` + const propRegex = /^\s{2}(?:\/\*\*[^*]*\*\/\s*)?(\w+)\??:/gm; + for ( + let match = propRegex.exec(typeStr); + match !== null; + match = propRegex.exec(typeStr) + ) { + keys.add(match[1]); + } + return keys; +} + +/** + * Filters a request body against the allowed keys for an endpoint alias. + * Returns the filtered body and logs a warning for stripped params. + * + * If no allowlist exists for the alias, returns the body unchanged (passthrough). + */ +export function filterRequestBody( + body: Record, + allowlists: Map>, + alias: string, + filterMode: "strip" | "reject" = "strip", +): Record { + const allowed = allowlists.get(alias); + if (!allowed) return body; + + const stripped: string[] = []; + const filtered: Record = {}; + + for (const [key, value] of Object.entries(body)) { + if (allowed.has(key)) { + filtered[key] = value; + } else { + stripped.push(key); + } + } + + if (stripped.length > 0) { + if (filterMode === "reject") { + throw new Error(`Unknown request parameters: ${stripped.join(", ")}`); + } + logger.warn( + "Stripped unknown params from '%s': %s", + alias, + stripped.join(", "), + ); + } + + return filtered; +} diff --git a/packages/appkit/src/plugins/serving/serving.ts b/packages/appkit/src/plugins/serving/serving.ts new file mode 100644 index 00000000..e868cc02 --- /dev/null +++ b/packages/appkit/src/plugins/serving/serving.ts @@ -0,0 +1,303 @@ +import { randomUUID } from "node:crypto"; +import path from "node:path"; +import type express from "express"; +import type { IAppRouter, StreamExecutionSettings } from "shared"; +import * as servingConnector from "../../connectors/serving/client"; +import { getWorkspaceClient } from "../../context"; +import { createLogger } from "../../logging"; +import { Plugin, toPlugin } from "../../plugin"; +import type { PluginManifest, ResourceRequirement } from "../../registry"; +import { ResourceType } from "../../registry"; +import { servingInvokeDefaults, servingStreamDefaults } from "./defaults"; +import manifest from "./manifest.json"; +import { filterRequestBody, loadEndpointSchemas } from "./schema-filter"; +import type { EndpointConfig, IServingConfig, ServingFactory } from "./types"; + +const logger = createLogger("serving"); + +class EndpointNotFoundError extends Error { + constructor(alias: string) { + super(`Unknown endpoint alias: ${alias}`); + } +} + +class EndpointNotConfiguredError extends Error { + constructor(alias: string, envVar: string) { + super( + `Endpoint '${alias}' is not configured: env var '${envVar}' is not set`, + ); + } +} + +interface ResolvedEndpoint { + name: string; + servedModel?: string; +} + +export class ServingPlugin extends Plugin { + static manifest = manifest as PluginManifest<"serving">; + + protected static description = + "Authenticated proxy to Databricks Model Serving endpoints"; + protected declare config: IServingConfig; + + private readonly endpoints: Record; + private readonly isNamedMode: boolean; + private schemaAllowlists = new Map>(); + + constructor(config: IServingConfig) { + super(config); + this.config = config; + + if (config.endpoints) { + this.endpoints = config.endpoints; + this.isNamedMode = true; + } else { + this.endpoints = { + default: { env: "DATABRICKS_SERVING_ENDPOINT" }, + }; + this.isNamedMode = false; + } + } + + async setup(): Promise { + const cacheFile = path.join( + process.cwd(), + "node_modules", + ".databricks", + "appkit", + ".appkit-serving-types-cache.json", + ); + this.schemaAllowlists = await loadEndpointSchemas(cacheFile); + if (this.schemaAllowlists.size > 0) { + logger.debug( + "Loaded schema allowlists for %d endpoint(s)", + this.schemaAllowlists.size, + ); + } + } + + static getResourceRequirements( + config: IServingConfig, + ): ResourceRequirement[] { + const endpoints = config.endpoints ?? { + default: { env: "DATABRICKS_SERVING_ENDPOINT" }, + }; + + return Object.entries(endpoints).map(([alias, endpointConfig]) => ({ + type: ResourceType.SERVING_ENDPOINT, + alias: `serving-${alias}`, + resourceKey: `serving-${alias}`, + description: `Model Serving endpoint for "${alias}" inference`, + permission: "CAN_QUERY" as const, + fields: { + name: { + env: endpointConfig.env, + description: `Serving endpoint name for "${alias}"`, + }, + }, + required: true, + })); + } + + private resolveAndFilter( + alias: string, + body: Record, + ): { endpoint: ResolvedEndpoint; filteredBody: Record } { + const config = this.endpoints[alias]; + if (!config) { + throw new EndpointNotFoundError(alias); + } + + const name = process.env[config.env]; + if (!name) { + throw new EndpointNotConfiguredError(alias, config.env); + } + + const endpoint: ResolvedEndpoint = { + name, + servedModel: config.servedModel, + }; + const filteredBody = filterRequestBody( + body, + this.schemaAllowlists, + alias, + this.config.filterMode, + ); + return { endpoint, filteredBody }; + } + + injectRoutes(router: IAppRouter) { + if (this.isNamedMode) { + this.route(router, { + name: "invoke", + method: "post", + path: "/:alias/invoke", + handler: async (req: express.Request, res: express.Response) => { + await this.asUser(req)._handleInvoke(req, res); + }, + }); + + this.route(router, { + name: "stream", + method: "post", + path: "/:alias/stream", + handler: async (req: express.Request, res: express.Response) => { + await this.asUser(req)._handleStream(req, res); + }, + }); + } else { + this.route(router, { + name: "invoke", + method: "post", + path: "/invoke", + handler: async (req: express.Request, res: express.Response) => { + req.params.alias = "default"; + await this.asUser(req)._handleInvoke(req, res); + }, + }); + + this.route(router, { + name: "stream", + method: "post", + path: "/stream", + handler: async (req: express.Request, res: express.Response) => { + req.params.alias = "default"; + await this.asUser(req)._handleStream(req, res); + }, + }); + } + } + + async _handleInvoke( + req: express.Request, + res: express.Response, + ): Promise { + const { alias } = req.params; + const rawBody = req.body as Record; + + try { + const result = await this.invoke(alias, rawBody); + if (result === undefined) { + res.status(502).json({ error: "Invocation returned no result" }); + return; + } + res.json(result); + } catch (err) { + const message = err instanceof Error ? err.message : "Invocation failed"; + if (err instanceof EndpointNotFoundError) { + res.status(404).json({ error: message }); + } else if ( + err instanceof EndpointNotConfiguredError || + message.startsWith("Unknown request parameters:") + ) { + res.status(400).json({ error: message }); + } else { + res.status(502).json({ error: message }); + } + } + } + + async _handleStream( + req: express.Request, + res: express.Response, + ): Promise { + const { alias } = req.params; + const rawBody = req.body as Record; + + let endpoint: ResolvedEndpoint; + let filteredBody: Record; + try { + ({ endpoint, filteredBody } = this.resolveAndFilter(alias, rawBody)); + } catch (err) { + const message = err instanceof Error ? err.message : "Invalid request"; + const status = err instanceof EndpointNotFoundError ? 404 : 400; + res.status(status).json({ error: message }); + return; + } + + const timeout = this.config.timeout ?? 120_000; + const requestId = + (typeof req.query.requestId === "string" && req.query.requestId) || + randomUUID(); + + const streamSettings: StreamExecutionSettings = { + ...servingStreamDefaults, + default: { + ...servingStreamDefaults.default, + timeout, + }, + stream: { + ...servingStreamDefaults.stream, + streamId: requestId, + }, + }; + + const workspaceClient = getWorkspaceClient(); + if (!workspaceClient.config.host) { + res.status(500).json({ error: "Databricks host not configured" }); + return; + } + + await this.executeStream( + res, + () => + servingConnector.stream(workspaceClient, endpoint.name, filteredBody, { + servedModel: endpoint.servedModel, + }), + streamSettings, + ); + } + + async invoke(alias: string, body: Record): Promise { + const { endpoint, filteredBody } = this.resolveAndFilter(alias, body); + const workspaceClient = getWorkspaceClient(); + const timeout = this.config.timeout ?? 120_000; + + return this.execute( + () => + servingConnector.invoke(workspaceClient, endpoint.name, filteredBody, { + servedModel: endpoint.servedModel, + }), + { + default: { + ...servingInvokeDefaults, + timeout, + }, + }, + ); + } + + async *stream( + alias: string, + body: Record, + ): AsyncGenerator { + const { endpoint, filteredBody } = this.resolveAndFilter(alias, body); + const workspaceClient = getWorkspaceClient(); + + yield* servingConnector.stream( + workspaceClient, + endpoint.name, + filteredBody, + { servedModel: endpoint.servedModel }, + ); + } + + async shutdown(): Promise { + this.streamManager.abortAll(); + } + + exports(): ServingFactory { + return ((alias?: string) => ({ + invoke: (body: Record) => + this.invoke(alias ?? "default", body), + stream: (body: Record) => + this.stream(alias ?? "default", body), + })) as ServingFactory; + } +} + +/** + * @internal + */ +export const serving = toPlugin(ServingPlugin); diff --git a/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts b/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts new file mode 100644 index 00000000..948b47f9 --- /dev/null +++ b/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts @@ -0,0 +1,141 @@ +import { describe, expect, test, vi } from "vitest"; +import { filterRequestBody, loadEndpointSchemas } from "../schema-filter"; + +vi.mock("node:fs/promises", () => ({ + default: { + readFile: vi.fn(), + }, +})); + +describe("schema-filter", () => { + describe("filterRequestBody", () => { + test("strips unknown keys when allowlist exists", () => { + const allowlists = new Map([ + ["default", new Set(["messages", "temperature"])], + ]); + + const result = filterRequestBody( + { messages: [], temperature: 0.7, unknown_param: true }, + allowlists, + "default", + ); + + expect(result).toEqual({ messages: [], temperature: 0.7 }); + }); + + test("preserves all keys when no allowlist for alias", () => { + const allowlists = new Map>(); + + const body = { messages: [], custom: "value" }; + const result = filterRequestBody(body, allowlists, "default"); + + expect(result).toBe(body); // Same reference, no filtering + }); + + test("returns empty object when all keys are unknown", () => { + const allowlists = new Map([["default", new Set(["messages"])]]); + + const result = filterRequestBody( + { bad1: 1, bad2: 2 }, + allowlists, + "default", + ); + + expect(result).toEqual({}); + }); + + test("returns full body when all keys are allowed", () => { + const allowlists = new Map([["default", new Set(["a", "b", "c"])]]); + + const result = filterRequestBody( + { a: 1, b: 2, c: 3 }, + allowlists, + "default", + ); + + expect(result).toEqual({ a: 1, b: 2, c: 3 }); + }); + + test("throws in reject mode when unknown keys are present", () => { + const allowlists = new Map([["default", new Set(["messages"])]]); + + expect(() => + filterRequestBody( + { messages: [], unknown_param: true }, + allowlists, + "default", + "reject", + ), + ).toThrow("Unknown request parameters: unknown_param"); + }); + + test("does not throw in reject mode when all keys are allowed", () => { + const allowlists = new Map([ + ["default", new Set(["messages", "temperature"])], + ]); + + const result = filterRequestBody( + { messages: [], temperature: 0.7 }, + allowlists, + "default", + "reject", + ); + + expect(result).toEqual({ messages: [], temperature: 0.7 }); + }); + + test("strips in default mode (strip)", () => { + const allowlists = new Map([["default", new Set(["messages"])]]); + + const result = filterRequestBody( + { messages: [], extra: true }, + allowlists, + "default", + "strip", + ); + + expect(result).toEqual({ messages: [] }); + }); + }); + + describe("loadEndpointSchemas", () => { + test("returns empty map when cache file does not exist", async () => { + const fs = (await import("node:fs/promises")).default; + vi.mocked(fs.readFile).mockRejectedValue( + Object.assign(new Error("ENOENT"), { code: "ENOENT" }), + ); + + const result = await loadEndpointSchemas("/nonexistent/path"); + expect(result.size).toBe(0); + }); + + test("extracts property keys from cached types", async () => { + const fs = (await import("node:fs/promises")).default; + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + version: "1", + endpoints: { + default: { + hash: "abc", + requestType: `{ + messages: string[]; + temperature?: number | null; + max_tokens: number; +}`, + responseType: "{}", + chunkType: null, + }, + }, + }), + ); + + const result = await loadEndpointSchemas("/some/path"); + expect(result.size).toBe(1); + const keys = result.get("default"); + expect(keys).toBeDefined(); + expect(keys?.has("messages")).toBe(true); + expect(keys?.has("temperature")).toBe(true); + expect(keys?.has("max_tokens")).toBe(true); + }); + }); +}); diff --git a/packages/appkit/src/plugins/serving/tests/serving.test.ts b/packages/appkit/src/plugins/serving/tests/serving.test.ts new file mode 100644 index 00000000..1a953b77 --- /dev/null +++ b/packages/appkit/src/plugins/serving/tests/serving.test.ts @@ -0,0 +1,339 @@ +import { + createMockRequest, + createMockResponse, + createMockRouter, + mockServiceContext, + setupDatabricksEnv, +} from "@tools/test-helpers"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { ServiceContext } from "../../../context/service-context"; +import { ServingPlugin, serving } from "../serving"; +import type { IServingConfig } from "../types"; + +// Mock CacheManager singleton +const { mockCacheInstance } = vi.hoisted(() => { + const instance = { + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + getOrExecute: vi + .fn() + .mockImplementation( + async (_key: unknown[], fn: () => Promise) => { + return await fn(); + }, + ), + generateKey: vi.fn((...args: unknown[]) => JSON.stringify(args)), + }; + return { mockCacheInstance: instance }; +}); + +vi.mock("../../../cache", () => ({ + CacheManager: { + getInstanceSync: vi.fn(() => mockCacheInstance), + }, +})); + +// Mock the serving connector +const mockInvoke = vi.fn(); +const mockStream = vi.fn(); + +vi.mock("../../../connectors/serving/client", () => ({ + invoke: (...args: any[]) => mockInvoke(...args), + stream: (...args: any[]) => mockStream(...args), +})); + +describe("Serving Plugin", () => { + let serviceContextMock: Awaited>; + + beforeEach(async () => { + setupDatabricksEnv(); + process.env.DATABRICKS_SERVING_ENDPOINT = "test-endpoint"; + ServiceContext.reset(); + + serviceContextMock = await mockServiceContext(); + }); + + afterEach(() => { + serviceContextMock?.restore(); + delete process.env.DATABRICKS_SERVING_ENDPOINT; + vi.restoreAllMocks(); + }); + + test("serving factory should have correct name", () => { + const pluginData = serving(); + expect(pluginData.name).toBe("serving"); + }); + + test("serving factory with config should have correct name", () => { + const pluginData = serving({ + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + }); + expect(pluginData.name).toBe("serving"); + }); + + describe("default mode", () => { + test("reads DATABRICKS_SERVING_ENDPOINT", () => { + const plugin = new ServingPlugin({}); + const api = (plugin.exports() as any)(); + expect(api.invoke).toBeDefined(); + expect(api.stream).toBeDefined(); + }); + + test("injects /invoke and /stream routes", () => { + const plugin = new ServingPlugin({}); + const { router, handlers } = createMockRouter(); + + plugin.injectRoutes(router); + + expect(handlers["POST:/invoke"]).toBeDefined(); + expect(handlers["POST:/stream"]).toBeDefined(); + }); + + test("exports returns a factory that provides invoke and stream", () => { + const plugin = new ServingPlugin({}); + const factory = plugin.exports() as any; + const api = factory(); + + expect(typeof api.invoke).toBe("function"); + expect(typeof api.stream).toBe("function"); + }); + }); + + describe("named mode", () => { + const namedConfig: IServingConfig = { + endpoints: { + llm: { env: "DATABRICKS_SERVING_ENDPOINT" }, + embedder: { env: "DATABRICKS_SERVING_ENDPOINT_EMBEDDING" }, + }, + }; + + test("injects /:alias/invoke and /:alias/stream routes", () => { + const plugin = new ServingPlugin(namedConfig); + const { router, handlers } = createMockRouter(); + + plugin.injectRoutes(router); + + expect(handlers["POST:/:alias/invoke"]).toBeDefined(); + expect(handlers["POST:/:alias/stream"]).toBeDefined(); + }); + + test("exports factory returns invoke and stream for named aliases", () => { + const plugin = new ServingPlugin(namedConfig); + const factory = plugin.exports() as any; + + expect(typeof factory("llm").invoke).toBe("function"); + expect(typeof factory("llm").stream).toBe("function"); + expect(typeof factory("embedder").invoke).toBe("function"); + expect(typeof factory("embedder").stream).toBe("function"); + }); + }); + + describe("route handlers", () => { + test("_handleInvoke returns 404 for unknown alias", async () => { + const plugin = new ServingPlugin({ + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + }); + + const req = createMockRequest({ + params: { alias: "unknown" }, + body: { messages: [] }, + }); + const res = createMockResponse(); + + await plugin._handleInvoke(req as any, res as any); + + expect(res.status).toHaveBeenCalledWith(404); + expect(res.json).toHaveBeenCalledWith({ + error: "Unknown endpoint alias: unknown", + }); + }); + + test("_handleInvoke calls connector with correct endpoint", async () => { + mockInvoke.mockResolvedValue({ choices: [] }); + + const plugin = new ServingPlugin({}); + const req = createMockRequest({ + params: { alias: "default" }, + body: { messages: [{ role: "user", content: "Hello" }] }, + }); + const res = createMockResponse(); + + await plugin._handleInvoke(req as any, res as any); + + expect(mockInvoke).toHaveBeenCalledWith( + expect.anything(), + "test-endpoint", + { messages: [{ role: "user", content: "Hello" }] }, + { servedModel: undefined }, + ); + expect(res.json).toHaveBeenCalledWith({ choices: [] }); + }); + + test("_handleInvoke returns 400 with descriptive message when env var is not set", async () => { + delete process.env.DATABRICKS_SERVING_ENDPOINT; + + const plugin = new ServingPlugin({}); + const req = createMockRequest({ + params: { alias: "default" }, + body: { messages: [] }, + }); + const res = createMockResponse(); + + await plugin._handleInvoke(req as any, res as any); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith({ + error: + "Endpoint 'default' is not configured: env var 'DATABRICKS_SERVING_ENDPOINT' is not set", + }); + }); + + test("_handleInvoke does not throw when connector fails", async () => { + mockInvoke.mockRejectedValue(new Error("Connection refused")); + + const plugin = new ServingPlugin({}); + const req = createMockRequest({ + params: { alias: "default" }, + body: { messages: [] }, + }); + const res = createMockResponse(); + + // Should not throw — execute() handles the error internally + await expect( + plugin._handleInvoke(req as any, res as any), + ).resolves.not.toThrow(); + }); + + test("_handleStream returns 404 for unknown alias", async () => { + const plugin = new ServingPlugin({ + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + }); + + const req = createMockRequest({ + params: { alias: "unknown" }, + body: { messages: [] }, + query: {}, + }); + const res = createMockResponse(); + + await plugin._handleStream(req as any, res as any); + + expect(res.status).toHaveBeenCalledWith(404); + expect(res.json).toHaveBeenCalledWith({ + error: "Unknown endpoint alias: unknown", + }); + }); + + test("_handleStream returns 400 when env var is not set", async () => { + delete process.env.DATABRICKS_SERVING_ENDPOINT; + + const plugin = new ServingPlugin({}); + const req = createMockRequest({ + params: { alias: "default" }, + body: { messages: [] }, + query: {}, + }); + const res = createMockResponse(); + + await plugin._handleStream(req as any, res as any); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith({ + error: + "Endpoint 'default' is not configured: env var 'DATABRICKS_SERVING_ENDPOINT' is not set", + }); + }); + }); + + describe("getResourceRequirements", () => { + test("generates requirements for default mode", () => { + const reqs = ServingPlugin.getResourceRequirements({}); + expect(reqs).toHaveLength(1); + expect(reqs[0]).toMatchObject({ + type: "serving_endpoint", + alias: "serving-default", + permission: "CAN_QUERY", + fields: { + name: { + env: "DATABRICKS_SERVING_ENDPOINT", + }, + }, + }); + }); + + test("generates requirements for named mode", () => { + const reqs = ServingPlugin.getResourceRequirements({ + endpoints: { + llm: { env: "LLM_ENDPOINT" }, + embedder: { env: "EMBED_ENDPOINT" }, + }, + }); + expect(reqs).toHaveLength(2); + expect(reqs[0].fields.name.env).toBe("LLM_ENDPOINT"); + expect(reqs[1].fields.name.env).toBe("EMBED_ENDPOINT"); + }); + }); + + describe("programmatic API", () => { + test("invoke calls connector correctly", async () => { + mockInvoke.mockResolvedValue({ + choices: [{ message: { content: "Hi" } }], + }); + + const plugin = new ServingPlugin({}); + const result = await plugin.invoke("default", { messages: [] }); + + expect(mockInvoke).toHaveBeenCalledWith( + expect.anything(), + "test-endpoint", + { messages: [] }, + { servedModel: undefined }, + ); + expect(result).toEqual({ choices: [{ message: { content: "Hi" } }] }); + }); + + test("invoke throws for unknown alias", async () => { + const plugin = new ServingPlugin({ + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + }); + + await expect(plugin.invoke("unknown", { messages: [] })).rejects.toThrow( + "Unknown endpoint alias: unknown", + ); + }); + + test("stream yields chunks from connector", async () => { + const chunks = [ + { choices: [{ delta: { content: "Hello" } }] }, + { choices: [{ delta: { content: " world" } }] }, + ]; + + mockStream.mockImplementation(async function* () { + for (const chunk of chunks) { + yield chunk; + } + }); + + const plugin = new ServingPlugin({}); + const results: unknown[] = []; + for await (const chunk of plugin.stream("default", { messages: [] })) { + results.push(chunk); + } + + expect(results).toEqual(chunks); + }); + }); + + describe("shutdown", () => { + test("calls streamManager.abortAll", async () => { + const plugin = new ServingPlugin({}); + // Accessing the protected streamManager through the plugin + const abortSpy = vi.spyOn((plugin as any).streamManager, "abortAll"); + + await plugin.shutdown(); + + expect(abortSpy).toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/appkit/src/plugins/serving/types.ts b/packages/appkit/src/plugins/serving/types.ts new file mode 100644 index 00000000..9a2dd230 --- /dev/null +++ b/packages/appkit/src/plugins/serving/types.ts @@ -0,0 +1,67 @@ +import type { BasePluginConfig } from "shared"; + +export interface EndpointConfig { + /** Environment variable holding the endpoint name. */ + env: string; + /** Target a specific served model (bypasses traffic routing). */ + servedModel?: string; +} + +export interface IServingConfig extends BasePluginConfig { + /** Map of alias → endpoint config. Defaults to { default: { env: "DATABRICKS_SERVING_ENDPOINT" } } if omitted. */ + endpoints?: Record; + /** Request timeout in ms. Default: 120000 (2 min) */ + timeout?: number; + /** How to handle unknown request parameters. 'strip' silently removes them (default). 'reject' returns 400. */ + filterMode?: "strip" | "reject"; +} + +/** + * Registry interface for serving endpoint type generation. + * Empty by default — augmented by the Vite type generator's `.d.ts` output via module augmentation. + * When populated, provides autocomplete for alias names and typed request/response/chunk per endpoint. + */ +// biome-ignore lint/suspicious/noEmptyInterface: intentionally empty — populated via module augmentation +export interface ServingEndpointRegistry {} + +/** Shape of a single registry entry. */ +export interface ServingEndpointEntry { + request: Record; + response: unknown; + chunk: unknown; +} + +/** Typed invoke/stream methods for a serving endpoint. */ +export interface ServingEndpointMethods< + TRequest extends Record = Record, + TResponse = unknown, + TChunk = unknown, +> { + invoke: (body: TRequest) => Promise; + stream: (body: TRequest) => AsyncGenerator; +} + +/** + * Factory function returned by `AppKit.serving`. + * + * This is a conditional type that adapts based on whether `ServingEndpointRegistry` + * has been populated via module augmentation (generated by `appKitServingTypesPlugin()`): + * + * - **Registry empty (default):** `(alias?: string) => ServingEndpointMethods` — + * accepts any alias string with untyped request/response/chunk. + * - **Registry populated:** `(alias: K) => ServingEndpointMethods<...>` — + * restricts `alias` to known endpoint keys and infers typed request/response/chunk + * from the registry entry. + * + * Run `appKitServingTypesPlugin()` in your Vite config to generate the registry + * augmentation and enable full type safety. + */ +export type ServingFactory = keyof ServingEndpointRegistry extends never + ? (alias?: string) => ServingEndpointMethods + : ( + alias: K, + ) => ServingEndpointMethods< + ServingEndpointRegistry[K]["request"], + ServingEndpointRegistry[K]["response"], + ServingEndpointRegistry[K]["chunk"] + >; diff --git a/packages/appkit/src/stream/stream-manager.ts b/packages/appkit/src/stream/stream-manager.ts index 41764772..8b511fac 100644 --- a/packages/appkit/src/stream/stream-manager.ts +++ b/packages/appkit/src/stream/stream-manager.ts @@ -374,6 +374,14 @@ export class StreamManager { if (error.name === "AbortError") { return SSEErrorCode.STREAM_ABORTED; } + + // Detect upstream API errors (e.g., from Databricks SDK ApiError) + if ( + "statusCode" in error && + typeof (error as any).statusCode === "number" + ) { + return SSEErrorCode.UPSTREAM_ERROR; + } } return SSEErrorCode.INTERNAL_ERROR; diff --git a/packages/appkit/src/stream/types.ts b/packages/appkit/src/stream/types.ts index 0fd862ba..3841bfd1 100644 --- a/packages/appkit/src/stream/types.ts +++ b/packages/appkit/src/stream/types.ts @@ -16,6 +16,7 @@ export const SSEErrorCode = { INVALID_REQUEST: "INVALID_REQUEST", STREAM_ABORTED: "STREAM_ABORTED", STREAM_EVICTED: "STREAM_EVICTED", + UPSTREAM_ERROR: "UPSTREAM_ERROR", } as const satisfies Record; export type SSEErrorCode = (typeof SSEErrorCode)[keyof typeof SSEErrorCode]; From 3df1517898fb060165b8d5db1400bec6aa49a0d9 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 14:20:34 +0200 Subject: [PATCH 2/5] fix: pass abort signal to serving connector in stream handler The serving plugin was not forwarding the abort signal to the serving connector, unlike the genie plugin. Without the signal, the connector's fetch request cannot be cancelled and the abort-check loop never triggers. Signed-off-by: Pawel Kosiec --- packages/appkit/src/plugins/serving/serving.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/appkit/src/plugins/serving/serving.ts b/packages/appkit/src/plugins/serving/serving.ts index e868cc02..e3547bcf 100644 --- a/packages/appkit/src/plugins/serving/serving.ts +++ b/packages/appkit/src/plugins/serving/serving.ts @@ -241,9 +241,10 @@ export class ServingPlugin extends Plugin { await this.executeStream( res, - () => + (signal) => servingConnector.stream(workspaceClient, endpoint.name, filteredBody, { servedModel: endpoint.servedModel, + signal, }), streamSettings, ); From 5b1ca1e55125790f10f1fa6203d0488ee85bcdbb Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Thu, 9 Apr 2026 16:19:28 +0200 Subject: [PATCH 3/5] fix: address PR review feedback for serving connector and plugin - Use SDK servingEndpoints.query() for invoke instead of raw fetch - Use SDK apiClient.request({ raw: true }) for streaming SSE - Fix exports() to support asUser via files plugin pattern - Rename DATABRICKS_SERVING_ENDPOINT to DATABRICKS_SERVING_ENDPOINT_NAME - Throw error on SSE buffer overflow instead of silent discard - Add OBO rationale comment in injectRoutes - Add SSE spec comments for empty line handling - Add ServingEndpointHandle type with asUser support Signed-off-by: Pawel Kosiec --- .../appkit/src/connectors/serving/client.ts | 140 +++------ .../connectors/serving/tests/client.test.ts | 285 +++++++----------- .../appkit/src/connectors/serving/types.ts | 3 +- .../appkit/src/plugins/serving/manifest.json | 2 +- .../appkit/src/plugins/serving/serving.ts | 54 ++-- .../src/plugins/serving/tests/serving.test.ts | 30 +- packages/appkit/src/plugins/serving/types.ts | 21 +- template/appkit.plugins.json | 24 ++ 8 files changed, 227 insertions(+), 332 deletions(-) diff --git a/packages/appkit/src/connectors/serving/client.ts b/packages/appkit/src/connectors/serving/client.ts index 6254426d..5773db17 100644 --- a/packages/appkit/src/connectors/serving/client.ts +++ b/packages/appkit/src/connectors/serving/client.ts @@ -1,35 +1,15 @@ +import type { serving } from "@databricks/sdk-experimental"; import { ApiError, type WorkspaceClient } from "@databricks/sdk-experimental"; import { createLogger } from "../../logging/logger"; -import type { ServingInvokeOptions } from "./types"; +import type { ServingStreamOptions } from "./types"; const logger = createLogger("connectors:serving"); -/** - * Builds the invocation URL for a serving endpoint. - * Uses `/served-models/{model}/invocations` when servedModel is specified, - * otherwise `/serving-endpoints/{name}/invocations`. - */ -function buildInvocationUrl( - host: string, - endpointName: string, - servedModel?: string, -): string { - const base = host.startsWith("http") ? host : `https://${host}`; - const encodedName = encodeURIComponent(endpointName); - const path = servedModel - ? `/serving-endpoints/${encodedName}/served-models/${encodeURIComponent(servedModel)}/invocations` - : `/serving-endpoints/${encodedName}/invocations`; - return new URL(path, base).toString(); -} - /** * Maps upstream Databricks error status codes to appropriate proxy responses. + * Used for raw API responses where the SDK doesn't handle errors automatically. */ -function mapUpstreamError( - status: number, - body: string, - headers: Headers, -): ApiError { +function mapUpstreamError(status: number, body: string): ApiError { const safeMessage = body.length > 500 ? `${body.slice(0, 500)}...` : body; let parsed: { message?: string; error?: string } = {}; @@ -49,13 +29,8 @@ function mapUpstreamError( return new ApiError(message, "AUTH_FAILURE", status, undefined, []); case status === 404: return new ApiError(message, "NOT_FOUND", 404, undefined, []); - case status === 429: { - const retryAfter = headers.get("retry-after"); - const retryMessage = retryAfter - ? `${message} (retry-after: ${retryAfter})` - : message; - return new ApiError(retryMessage, "RATE_LIMITED", 429, undefined, []); - } + case status === 429: + return new ApiError(message, "RATE_LIMITED", 429, undefined, []); case status === 503: return new ApiError( "Endpoint loading, retry shortly", @@ -72,97 +47,60 @@ function mapUpstreamError( } /** - * Invokes a serving endpoint and returns the parsed JSON response. + * Invokes a serving endpoint using the SDK's high-level query API. + * Returns a typed QueryEndpointResponse. */ export async function invoke( client: WorkspaceClient, endpointName: string, body: Record, - options?: ServingInvokeOptions, -): Promise { - const host = client.config.host; - if (!host) { - throw new Error( - "Databricks host is not configured. Set DATABRICKS_HOST or configure client.config.host.", - ); - } - - const url = buildInvocationUrl(host, endpointName, options?.servedModel); - - // Always strip `stream` from the body — the connector controls this +): Promise { + // Strip `stream` from the body — the connector controls this const { stream: _stream, ...cleanBody } = body; - const headers = new Headers({ - "Content-Type": "application/json", - Accept: "application/json", - }); - await client.config.authenticate(headers); - - logger.debug("Invoking endpoint %s at %s", endpointName, url); + logger.debug("Invoking endpoint %s", endpointName); - const res = await fetch(url, { - method: "POST", - headers, - body: JSON.stringify(cleanBody), - signal: options?.signal, - }); - - if (!res.ok) { - const text = await res.text(); - throw mapUpstreamError(res.status, text, res.headers); - } - - return res.json(); + return client.servingEndpoints.query({ + name: endpointName, + ...cleanBody, + } as serving.QueryEndpointInput); } /** * Invokes a serving endpoint with streaming enabled. - * Yields parsed JSON chunks from the NDJSON SSE response. + * Yields parsed JSON chunks from the SSE response. + * + * Uses the SDK's low-level `apiClient.request({ raw: true })` because + * the high-level `servingEndpoints.query()` returns `Promise` + * and does not support SSE streaming. */ export async function* stream( client: WorkspaceClient, endpointName: string, body: Record, - options?: ServingInvokeOptions, + options?: ServingStreamOptions, ): AsyncGenerator { - const host = client.config.host; - if (!host) { - throw new Error( - "Databricks host is not configured. Set DATABRICKS_HOST or configure client.config.host.", - ); - } - - const url = buildInvocationUrl(host, endpointName, options?.servedModel); - // Strip any user-provided `stream` and inject `stream: true` const { stream: _stream, ...cleanBody } = body; - const streamBody = { ...cleanBody, stream: true }; - - const headers = new Headers({ - "Content-Type": "application/json", - Accept: "text/event-stream", - }); - await client.config.authenticate(headers); - logger.debug("Streaming from endpoint %s at %s", endpointName, url); + logger.debug("Streaming from endpoint %s", endpointName); - const res = await fetch(url, { + const response = (await client.apiClient.request({ + path: `/serving-endpoints/${encodeURIComponent(endpointName)}/invocations`, method: "POST", - headers, - body: JSON.stringify(streamBody), - signal: options?.signal, - }); - - if (!res.ok) { - const text = await res.text(); - throw mapUpstreamError(res.status, text, res.headers); - } - - if (!res.body) { + headers: new Headers({ + "Content-Type": "application/json", + Accept: "text/event-stream", + }), + payload: { ...cleanBody, stream: true }, + raw: true, + })) as { contents: ReadableStream }; + + if (!response.contents) { throw new Error("Response body is null — streaming not supported"); } - const reader = res.body.getReader(); + const reader = response.contents.getReader(); const decoder = new TextDecoder(); let buffer = ""; const MAX_BUFFER_SIZE = 1024 * 1024; // 1 MB @@ -177,11 +115,9 @@ export async function* stream( buffer += decoder.decode(value, { stream: true }); if (buffer.length > MAX_BUFFER_SIZE) { - logger.warn( - "Stream buffer exceeded %d bytes, discarding incomplete data", - MAX_BUFFER_SIZE, + throw new Error( + `Stream buffer exceeded ${MAX_BUFFER_SIZE} bytes — possible non-SSE response`, ); - buffer = ""; } // Process complete lines from the buffer @@ -191,7 +127,9 @@ export async function* stream( for (const line of lines) { const trimmed = line.trim(); - if (!trimmed || trimmed.startsWith(":")) continue; // skip empty lines and SSE comments + // Per SSE spec: empty lines are event delimiters, + // lines starting with ":" are comments (often used as heartbeats). + if (!trimmed || trimmed.startsWith(":")) continue; if (trimmed === "data: [DONE]") return; if (trimmed.startsWith("data: ")) { diff --git a/packages/appkit/src/connectors/serving/tests/client.test.ts b/packages/appkit/src/connectors/serving/tests/client.test.ts index 6af859ae..b87ba427 100644 --- a/packages/appkit/src/connectors/serving/tests/client.test.ts +++ b/packages/appkit/src/connectors/serving/tests/client.test.ts @@ -1,201 +1,90 @@ -import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { afterEach, describe, expect, test, vi } from "vitest"; import { invoke, stream } from "../client"; -const mockAuthenticate = vi.fn(); - function createMockClient(host = "https://test.databricks.com") { return { - config: { - host, - authenticate: mockAuthenticate, + config: { host }, + servingEndpoints: { + query: vi.fn(), + }, + apiClient: { + request: vi.fn(), }, } as any; } describe("Serving Connector", () => { - beforeEach(() => { - mockAuthenticate.mockResolvedValue(undefined); - }); - afterEach(() => { vi.restoreAllMocks(); }); describe("invoke", () => { - test("constructs correct URL for endpoint invocation", async () => { - const fetchSpy = vi - .spyOn(globalThis, "fetch") - .mockResolvedValue( - new Response(JSON.stringify({ result: "ok" }), { status: 200 }), - ); - - const client = createMockClient(); - await invoke(client, "my-endpoint", { messages: [] }); - - expect(fetchSpy).toHaveBeenCalledWith( - "https://test.databricks.com/serving-endpoints/my-endpoint/invocations", - expect.objectContaining({ method: "POST" }), - ); - }); - - test("constructs correct URL with servedModel override", async () => { - const fetchSpy = vi - .spyOn(globalThis, "fetch") - .mockResolvedValue( - new Response(JSON.stringify({ result: "ok" }), { status: 200 }), - ); - + test("calls servingEndpoints.query with endpoint name and body", async () => { const client = createMockClient(); - await invoke( - client, - "my-endpoint", - { messages: [] }, - { servedModel: "llama-v2" }, - ); + const mockResponse = { choices: [{ message: { content: "Hello" } }] }; + client.servingEndpoints.query.mockResolvedValue(mockResponse); - expect(fetchSpy).toHaveBeenCalledWith( - "https://test.databricks.com/serving-endpoints/my-endpoint/served-models/llama-v2/invocations", - expect.objectContaining({ method: "POST" }), - ); - }); - - test("authenticates request headers", async () => { - vi.spyOn(globalThis, "fetch").mockResolvedValue( - new Response(JSON.stringify({ result: "ok" }), { status: 200 }), - ); - - const client = createMockClient(); - await invoke(client, "my-endpoint", { messages: [] }); + const result = await invoke(client, "my-endpoint", { + messages: [{ role: "user", content: "Hi" }], + temperature: 0.7, + }); - expect(mockAuthenticate).toHaveBeenCalledWith(expect.any(Headers)); + expect(client.servingEndpoints.query).toHaveBeenCalledWith({ + name: "my-endpoint", + messages: [{ role: "user", content: "Hi" }], + temperature: 0.7, + }); + expect(result).toEqual(mockResponse); }); test("strips stream property from body", async () => { - const fetchSpy = vi - .spyOn(globalThis, "fetch") - .mockResolvedValue( - new Response(JSON.stringify({ result: "ok" }), { status: 200 }), - ); - const client = createMockClient(); + client.servingEndpoints.query.mockResolvedValue({}); + await invoke(client, "my-endpoint", { messages: [], stream: true, temperature: 0.7, }); - const body = JSON.parse(fetchSpy.mock.calls[0][1]?.body as string); - expect(body).toEqual({ messages: [], temperature: 0.7 }); - expect(body.stream).toBeUndefined(); + const queryArg = client.servingEndpoints.query.mock.calls[0][0]; + expect(queryArg.stream).toBeUndefined(); + expect(queryArg.temperature).toBe(0.7); }); - test("returns parsed JSON response", async () => { - const responseData = { choices: [{ message: { content: "Hello" } }] }; - vi.spyOn(globalThis, "fetch").mockResolvedValue( - new Response(JSON.stringify(responseData), { status: 200 }), - ); - + test("returns typed QueryEndpointResponse", async () => { const client = createMockClient(); - const result = await invoke(client, "my-endpoint", { messages: [] }); + const responseData = { + choices: [{ message: { content: "Hello" } }], + model: "test-model", + }; + client.servingEndpoints.query.mockResolvedValue(responseData); + const result = await invoke(client, "my-endpoint", { messages: [] }); expect(result).toEqual(responseData); }); - test("throws ApiError on 400 response", async () => { - vi.spyOn(globalThis, "fetch").mockResolvedValue( - new Response(JSON.stringify({ message: "Invalid params" }), { - status: 400, - }), - ); - + test("propagates SDK errors", async () => { const client = createMockClient(); - await expect( - invoke(client, "my-endpoint", { messages: [] }), - ).rejects.toThrow("Invalid params"); - }); - - test("throws ApiError on 404 response", async () => { - vi.spyOn(globalThis, "fetch").mockResolvedValue( - new Response(JSON.stringify({ message: "Endpoint not found" }), { - status: 404, - }), + client.servingEndpoints.query.mockRejectedValue( + new Error("Endpoint not found"), ); - const client = createMockClient(); await expect( invoke(client, "my-endpoint", { messages: [] }), ).rejects.toThrow("Endpoint not found"); }); - - test("maps 5xx to 502 bad gateway", async () => { - vi.spyOn(globalThis, "fetch").mockResolvedValue( - new Response(JSON.stringify({ message: "Internal error" }), { - status: 500, - }), - ); - - const client = createMockClient(); - try { - await invoke(client, "my-endpoint", { messages: [] }); - expect.unreachable("Should have thrown"); - } catch (err: any) { - expect(err.statusCode).toBe(502); - } - }); - - test("forwards AbortSignal", async () => { - const controller = new AbortController(); - const fetchSpy = vi - .spyOn(globalThis, "fetch") - .mockResolvedValue( - new Response(JSON.stringify({ result: "ok" }), { status: 200 }), - ); - - const client = createMockClient(); - await invoke( - client, - "my-endpoint", - { messages: [] }, - { signal: controller.signal }, - ); - - expect(fetchSpy.mock.calls[0][1]?.signal).toBe(controller.signal); - }); - - test("throws when host is not configured", async () => { - const client = { - config: { - host: "", - authenticate: mockAuthenticate, - }, - } as any; - await expect( - invoke(client, "my-endpoint", { messages: [] }), - ).rejects.toThrow("Databricks host is not configured"); - }); - - test("prepends https:// to host without protocol", async () => { - const fetchSpy = vi - .spyOn(globalThis, "fetch") - .mockResolvedValue( - new Response(JSON.stringify({ result: "ok" }), { status: 200 }), - ); - - const client = createMockClient("test.databricks.com"); - await invoke(client, "my-endpoint", { messages: [] }); - - expect(fetchSpy.mock.calls[0][0]).toContain( - "https://test.databricks.com", - ); - }); }); describe("stream", () => { - function createSSEResponse(chunks: string[]) { + function createSSEStream(chunks: string[]) { const body = `${chunks.join("\n")}\n`; - return new Response(body, { - status: 200, - headers: { "Content-Type": "text/event-stream" }, + const encoder = new TextEncoder(); + return new ReadableStream({ + start(controller) { + controller.enqueue(encoder.encode(body)); + controller.close(); + }, }); } @@ -206,11 +95,11 @@ describe("Serving Connector", () => { "data: [DONE]", ]; - vi.spyOn(globalThis, "fetch").mockResolvedValue( - createSSEResponse(chunks), - ); - const client = createMockClient(); + client.apiClient.request.mockResolvedValue({ + contents: createSSEStream(chunks), + }); + const results: unknown[] = []; for await (const chunk of stream(client, "my-endpoint", { messages: [], @@ -224,27 +113,32 @@ describe("Serving Connector", () => { ]); }); - test("injects stream: true into body", async () => { - const fetchSpy = vi - .spyOn(globalThis, "fetch") - .mockResolvedValue(createSSEResponse(["data: [DONE]"])); - + test("sends stream: true in payload via apiClient.request", async () => { const client = createMockClient(); - // Consume the generator + client.apiClient.request.mockResolvedValue({ + contents: createSSEStream(["data: [DONE]"]), + }); + for await (const _ of stream(client, "my-endpoint", { messages: [] })) { // noop } - const body = JSON.parse(fetchSpy.mock.calls[0][1]?.body as string); - expect(body.stream).toBe(true); + expect(client.apiClient.request).toHaveBeenCalledWith( + expect.objectContaining({ + path: "/serving-endpoints/my-endpoint/invocations", + method: "POST", + raw: true, + payload: expect.objectContaining({ stream: true }), + }), + ); }); test("strips user-provided stream and re-injects", async () => { - const fetchSpy = vi - .spyOn(globalThis, "fetch") - .mockResolvedValue(createSSEResponse(["data: [DONE]"])); - const client = createMockClient(); + client.apiClient.request.mockResolvedValue({ + contents: createSSEStream(["data: [DONE]"]), + }); + for await (const _ of stream(client, "my-endpoint", { messages: [], stream: false, @@ -252,8 +146,8 @@ describe("Serving Connector", () => { // noop } - const body = JSON.parse(fetchSpy.mock.calls[0][1]?.body as string); - expect(body.stream).toBe(true); + const payload = client.apiClient.request.mock.calls[0][0].payload; + expect(payload.stream).toBe(true); }); test("skips SSE comments and empty lines", async () => { @@ -265,11 +159,11 @@ describe("Serving Connector", () => { "data: [DONE]", ]; - vi.spyOn(globalThis, "fetch").mockResolvedValue( - createSSEResponse(chunks), - ); - const client = createMockClient(); + client.apiClient.request.mockResolvedValue({ + contents: createSSEStream(chunks), + }); + const results: unknown[] = []; for await (const chunk of stream(client, "my-endpoint", { messages: [], @@ -281,22 +175,45 @@ describe("Serving Connector", () => { expect(results[0]).toEqual({ choices: [{ delta: { content: "Hi" } }] }); }); - test("throws on non-OK response", async () => { - vi.spyOn(globalThis, "fetch").mockResolvedValue( - new Response(JSON.stringify({ message: "Rate limited" }), { - status: 429, - headers: { "Retry-After": "5" }, - }), - ); + test("throws when response has no contents", async () => { + const client = createMockClient(); + client.apiClient.request.mockResolvedValue({ contents: null }); + try { + for await (const _ of stream(client, "my-endpoint", { + messages: [], + })) { + // noop + } + expect.unreachable("Should have thrown"); + } catch (err: any) { + expect(err.message).toContain("streaming not supported"); + } + }); + + test("throws when buffer exceeds max size", async () => { const client = createMockClient(); + const largeData = "x".repeat(1024 * 1024 + 1); + const encoder = new TextEncoder(); + const largeStream = new ReadableStream({ + start(controller) { + controller.enqueue(encoder.encode(largeData)); + controller.close(); + }, + }); + client.apiClient.request.mockResolvedValue({ + contents: largeStream, + }); + try { - for await (const _ of stream(client, "my-endpoint", { messages: [] })) { + for await (const _ of stream(client, "my-endpoint", { + messages: [], + })) { // noop } expect.unreachable("Should have thrown"); } catch (err: any) { - expect(err.statusCode).toBe(429); + expect(err.message).toContain("Stream buffer exceeded"); } }); }); diff --git a/packages/appkit/src/connectors/serving/types.ts b/packages/appkit/src/connectors/serving/types.ts index 6dd1acba..8c6c7f74 100644 --- a/packages/appkit/src/connectors/serving/types.ts +++ b/packages/appkit/src/connectors/serving/types.ts @@ -1,4 +1,3 @@ -export interface ServingInvokeOptions { - servedModel?: string; +export interface ServingStreamOptions { signal?: AbortSignal; } diff --git a/packages/appkit/src/plugins/serving/manifest.json b/packages/appkit/src/plugins/serving/manifest.json index 9ac0845f..7fcacd37 100644 --- a/packages/appkit/src/plugins/serving/manifest.json +++ b/packages/appkit/src/plugins/serving/manifest.json @@ -13,7 +13,7 @@ "permission": "CAN_QUERY", "fields": { "name": { - "env": "DATABRICKS_SERVING_ENDPOINT", + "env": "DATABRICKS_SERVING_ENDPOINT_NAME", "description": "Serving endpoint name" } } diff --git a/packages/appkit/src/plugins/serving/serving.ts b/packages/appkit/src/plugins/serving/serving.ts index e3547bcf..981a0da5 100644 --- a/packages/appkit/src/plugins/serving/serving.ts +++ b/packages/appkit/src/plugins/serving/serving.ts @@ -11,7 +11,12 @@ import { ResourceType } from "../../registry"; import { servingInvokeDefaults, servingStreamDefaults } from "./defaults"; import manifest from "./manifest.json"; import { filterRequestBody, loadEndpointSchemas } from "./schema-filter"; -import type { EndpointConfig, IServingConfig, ServingFactory } from "./types"; +import type { + EndpointConfig, + IServingConfig, + ServingEndpointMethods, + ServingFactory, +} from "./types"; const logger = createLogger("serving"); @@ -31,7 +36,6 @@ class EndpointNotConfiguredError extends Error { interface ResolvedEndpoint { name: string; - servedModel?: string; } export class ServingPlugin extends Plugin { @@ -54,7 +58,7 @@ export class ServingPlugin extends Plugin { this.isNamedMode = true; } else { this.endpoints = { - default: { env: "DATABRICKS_SERVING_ENDPOINT" }, + default: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" }, }; this.isNamedMode = false; } @@ -81,7 +85,7 @@ export class ServingPlugin extends Plugin { config: IServingConfig, ): ResourceRequirement[] { const endpoints = config.endpoints ?? { - default: { env: "DATABRICKS_SERVING_ENDPOINT" }, + default: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" }, }; return Object.entries(endpoints).map(([alias, endpointConfig]) => ({ @@ -114,10 +118,7 @@ export class ServingPlugin extends Plugin { throw new EndpointNotConfiguredError(alias, config.env); } - const endpoint: ResolvedEndpoint = { - name, - servedModel: config.servedModel, - }; + const endpoint: ResolvedEndpoint = { name }; const filteredBody = filterRequestBody( body, this.schemaAllowlists, @@ -127,6 +128,8 @@ export class ServingPlugin extends Plugin { return { endpoint, filteredBody }; } + // All serving routes use OBO (On-Behalf-Of) by default, consistent with the + // Genie and Files plugins. This ensures per-user CAN_QUERY permissions are enforced. injectRoutes(router: IAppRouter) { if (this.isNamedMode) { this.route(router, { @@ -234,16 +237,11 @@ export class ServingPlugin extends Plugin { }; const workspaceClient = getWorkspaceClient(); - if (!workspaceClient.config.host) { - res.status(500).json({ error: "Databricks host not configured" }); - return; - } await this.executeStream( res, (signal) => servingConnector.stream(workspaceClient, endpoint.name, filteredBody, { - servedModel: endpoint.servedModel, signal, }), streamSettings, @@ -257,9 +255,7 @@ export class ServingPlugin extends Plugin { return this.execute( () => - servingConnector.invoke(workspaceClient, endpoint.name, filteredBody, { - servedModel: endpoint.servedModel, - }), + servingConnector.invoke(workspaceClient, endpoint.name, filteredBody), { default: { ...servingInvokeDefaults, @@ -280,7 +276,6 @@ export class ServingPlugin extends Plugin { workspaceClient, endpoint.name, filteredBody, - { servedModel: endpoint.servedModel }, ); } @@ -288,13 +283,26 @@ export class ServingPlugin extends Plugin { this.streamManager.abortAll(); } + protected createEndpointAPI(alias: string): ServingEndpointMethods { + return { + invoke: (body: Record) => this.invoke(alias, body), + stream: (body: Record) => this.stream(alias, body), + }; + } + exports(): ServingFactory { - return ((alias?: string) => ({ - invoke: (body: Record) => - this.invoke(alias ?? "default", body), - stream: (body: Record) => - this.stream(alias ?? "default", body), - })) as ServingFactory; + const resolveEndpoint = (alias?: string) => { + const resolved = alias ?? "default"; + const spApi = this.createEndpointAPI(resolved); + return { + ...spApi, + asUser: (req: express.Request) => { + const userPlugin = this.asUser(req) as ServingPlugin; + return userPlugin.createEndpointAPI(resolved); + }, + }; + }; + return resolveEndpoint as ServingFactory; } } diff --git a/packages/appkit/src/plugins/serving/tests/serving.test.ts b/packages/appkit/src/plugins/serving/tests/serving.test.ts index 1a953b77..bb7f89ae 100644 --- a/packages/appkit/src/plugins/serving/tests/serving.test.ts +++ b/packages/appkit/src/plugins/serving/tests/serving.test.ts @@ -48,7 +48,7 @@ describe("Serving Plugin", () => { beforeEach(async () => { setupDatabricksEnv(); - process.env.DATABRICKS_SERVING_ENDPOINT = "test-endpoint"; + process.env.DATABRICKS_SERVING_ENDPOINT_NAME = "test-endpoint"; ServiceContext.reset(); serviceContextMock = await mockServiceContext(); @@ -56,7 +56,7 @@ describe("Serving Plugin", () => { afterEach(() => { serviceContextMock?.restore(); - delete process.env.DATABRICKS_SERVING_ENDPOINT; + delete process.env.DATABRICKS_SERVING_ENDPOINT_NAME; vi.restoreAllMocks(); }); @@ -67,13 +67,13 @@ describe("Serving Plugin", () => { test("serving factory with config should have correct name", () => { const pluginData = serving({ - endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" } }, }); expect(pluginData.name).toBe("serving"); }); describe("default mode", () => { - test("reads DATABRICKS_SERVING_ENDPOINT", () => { + test("reads DATABRICKS_SERVING_ENDPOINT_NAME", () => { const plugin = new ServingPlugin({}); const api = (plugin.exports() as any)(); expect(api.invoke).toBeDefined(); @@ -103,8 +103,8 @@ describe("Serving Plugin", () => { describe("named mode", () => { const namedConfig: IServingConfig = { endpoints: { - llm: { env: "DATABRICKS_SERVING_ENDPOINT" }, - embedder: { env: "DATABRICKS_SERVING_ENDPOINT_EMBEDDING" }, + llm: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" }, + embedder: { env: "DATABRICKS_SERVING_ENDPOINT_NAME_EMBEDDING" }, }, }; @@ -132,7 +132,7 @@ describe("Serving Plugin", () => { describe("route handlers", () => { test("_handleInvoke returns 404 for unknown alias", async () => { const plugin = new ServingPlugin({ - endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" } }, }); const req = createMockRequest({ @@ -165,13 +165,12 @@ describe("Serving Plugin", () => { expect.anything(), "test-endpoint", { messages: [{ role: "user", content: "Hello" }] }, - { servedModel: undefined }, ); expect(res.json).toHaveBeenCalledWith({ choices: [] }); }); test("_handleInvoke returns 400 with descriptive message when env var is not set", async () => { - delete process.env.DATABRICKS_SERVING_ENDPOINT; + delete process.env.DATABRICKS_SERVING_ENDPOINT_NAME; const plugin = new ServingPlugin({}); const req = createMockRequest({ @@ -185,7 +184,7 @@ describe("Serving Plugin", () => { expect(res.status).toHaveBeenCalledWith(400); expect(res.json).toHaveBeenCalledWith({ error: - "Endpoint 'default' is not configured: env var 'DATABRICKS_SERVING_ENDPOINT' is not set", + "Endpoint 'default' is not configured: env var 'DATABRICKS_SERVING_ENDPOINT_NAME' is not set", }); }); @@ -207,7 +206,7 @@ describe("Serving Plugin", () => { test("_handleStream returns 404 for unknown alias", async () => { const plugin = new ServingPlugin({ - endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" } }, }); const req = createMockRequest({ @@ -226,7 +225,7 @@ describe("Serving Plugin", () => { }); test("_handleStream returns 400 when env var is not set", async () => { - delete process.env.DATABRICKS_SERVING_ENDPOINT; + delete process.env.DATABRICKS_SERVING_ENDPOINT_NAME; const plugin = new ServingPlugin({}); const req = createMockRequest({ @@ -241,7 +240,7 @@ describe("Serving Plugin", () => { expect(res.status).toHaveBeenCalledWith(400); expect(res.json).toHaveBeenCalledWith({ error: - "Endpoint 'default' is not configured: env var 'DATABRICKS_SERVING_ENDPOINT' is not set", + "Endpoint 'default' is not configured: env var 'DATABRICKS_SERVING_ENDPOINT_NAME' is not set", }); }); }); @@ -256,7 +255,7 @@ describe("Serving Plugin", () => { permission: "CAN_QUERY", fields: { name: { - env: "DATABRICKS_SERVING_ENDPOINT", + env: "DATABRICKS_SERVING_ENDPOINT_NAME", }, }, }); @@ -288,14 +287,13 @@ describe("Serving Plugin", () => { expect.anything(), "test-endpoint", { messages: [] }, - { servedModel: undefined }, ); expect(result).toEqual({ choices: [{ message: { content: "Hi" } }] }); }); test("invoke throws for unknown alias", async () => { const plugin = new ServingPlugin({ - endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" } }, }); await expect(plugin.invoke("unknown", { messages: [] })).rejects.toThrow( diff --git a/packages/appkit/src/plugins/serving/types.ts b/packages/appkit/src/plugins/serving/types.ts index 9a2dd230..fd728819 100644 --- a/packages/appkit/src/plugins/serving/types.ts +++ b/packages/appkit/src/plugins/serving/types.ts @@ -8,7 +8,7 @@ export interface EndpointConfig { } export interface IServingConfig extends BasePluginConfig { - /** Map of alias → endpoint config. Defaults to { default: { env: "DATABRICKS_SERVING_ENDPOINT" } } if omitted. */ + /** Map of alias → endpoint config. Defaults to { default: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" } } if omitted. */ endpoints?: Record; /** Request timeout in ms. Default: 120000 (2 min) */ timeout?: number; @@ -41,15 +41,26 @@ export interface ServingEndpointMethods< stream: (body: TRequest) => AsyncGenerator; } +/** Endpoint handle with asUser support, returned by the exports factory. */ +export type ServingEndpointHandle< + TRequest extends Record = Record, + TResponse = unknown, + TChunk = unknown, +> = ServingEndpointMethods & { + asUser: ( + req: import("express").Request, + ) => ServingEndpointMethods; +}; + /** * Factory function returned by `AppKit.serving`. * * This is a conditional type that adapts based on whether `ServingEndpointRegistry` * has been populated via module augmentation (generated by `appKitServingTypesPlugin()`): * - * - **Registry empty (default):** `(alias?: string) => ServingEndpointMethods` — + * - **Registry empty (default):** `(alias?: string) => ServingEndpointHandle` — * accepts any alias string with untyped request/response/chunk. - * - **Registry populated:** `(alias: K) => ServingEndpointMethods<...>` — + * - **Registry populated:** `(alias: K) => ServingEndpointHandle<...>` — * restricts `alias` to known endpoint keys and infers typed request/response/chunk * from the registry entry. * @@ -57,10 +68,10 @@ export interface ServingEndpointMethods< * augmentation and enable full type safety. */ export type ServingFactory = keyof ServingEndpointRegistry extends never - ? (alias?: string) => ServingEndpointMethods + ? (alias?: string) => ServingEndpointHandle : ( alias: K, - ) => ServingEndpointMethods< + ) => ServingEndpointHandle< ServingEndpointRegistry[K]["request"], ServingEndpointRegistry[K]["response"], ServingEndpointRegistry[K]["chunk"] diff --git a/template/appkit.plugins.json b/template/appkit.plugins.json index cf60a8af..d1420d2e 100644 --- a/template/appkit.plugins.json +++ b/template/appkit.plugins.json @@ -149,6 +149,30 @@ "optional": [] }, "requiredByTemplate": true + }, + "serving": { + "name": "serving", + "displayName": "Model Serving Plugin", + "description": "Authenticated proxy to Databricks Model Serving endpoints", + "package": "@databricks/appkit", + "resources": { + "required": [ + { + "type": "serving_endpoint", + "alias": "Serving Endpoint", + "resourceKey": "serving-endpoint", + "description": "Model Serving endpoint for inference", + "permission": "CAN_QUERY", + "fields": { + "name": { + "env": "DATABRICKS_SERVING_ENDPOINT_NAME", + "description": "Serving endpoint name" + } + } + } + ], + "optional": [] + } } } } From 2ed0e11fe447fc5ce2878cca18c109b4c59e6a3a Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Thu, 9 Apr 2026 16:45:52 +0200 Subject: [PATCH 4/5] fix: only cancel stream reader on abort to prevent double-close The SDK's readableToWeb() closes the controller on "end" event. Calling reader.cancel() unconditionally in the finally block causes a "Controller is already closed" error on subsequent requests. Only cancel when the signal was actually aborted (early termination). Signed-off-by: Pawel Kosiec --- packages/appkit/src/connectors/serving/client.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/appkit/src/connectors/serving/client.ts b/packages/appkit/src/connectors/serving/client.ts index 5773db17..25947677 100644 --- a/packages/appkit/src/connectors/serving/client.ts +++ b/packages/appkit/src/connectors/serving/client.ts @@ -155,7 +155,9 @@ export async function* stream( } } } finally { - reader.cancel().catch(() => {}); + if (options?.signal?.aborted) { + reader.cancel().catch(() => {}); + } reader.releaseLock(); } } From 1996ad60a295199337cc0f12187a308162107105 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Thu, 9 Apr 2026 19:39:53 +0200 Subject: [PATCH 5/5] docs: update generated ServingFactory type docs for ServingEndpointHandle Signed-off-by: Pawel Kosiec --- docs/docs/api/appkit/TypeAlias.ServingFactory.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/docs/api/appkit/TypeAlias.ServingFactory.md b/docs/docs/api/appkit/TypeAlias.ServingFactory.md index 9ccafef5..a7022895 100644 --- a/docs/docs/api/appkit/TypeAlias.ServingFactory.md +++ b/docs/docs/api/appkit/TypeAlias.ServingFactory.md @@ -1,7 +1,7 @@ # Type Alias: ServingFactory ```ts -type ServingFactory = keyof ServingEndpointRegistry extends never ? (alias?: string) => ServingEndpointMethods : (alias: K) => ServingEndpointMethods; +type ServingFactory = keyof ServingEndpointRegistry extends never ? (alias?: string) => ServingEndpointHandle : (alias: K) => ServingEndpointHandle; ``` Factory function returned by `AppKit.serving`. @@ -9,9 +9,9 @@ Factory function returned by `AppKit.serving`. This is a conditional type that adapts based on whether `ServingEndpointRegistry` has been populated via module augmentation (generated by `appKitServingTypesPlugin()`): -- **Registry empty (default):** `(alias?: string) => ServingEndpointMethods` — +- **Registry empty (default):** `(alias?: string) => ServingEndpointHandle` — accepts any alias string with untyped request/response/chunk. -- **Registry populated:** `(alias: K) => ServingEndpointMethods<...>` — +- **Registry populated:** `(alias: K) => ServingEndpointHandle<...>` — restricts `alias` to known endpoint keys and infers typed request/response/chunk from the registry entry.