diff --git a/.gitignore b/.gitignore index 3b6cc969..4c51d5b1 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,6 @@ coverage *.tsbuildinfo .turbo + +# AppKit type generator caches +.databricks diff --git a/apps/dev-playground/.env.dist b/apps/dev-playground/.env.dist index 23c3265a..e1d0b207 100644 --- a/apps/dev-playground/.env.dist +++ b/apps/dev-playground/.env.dist @@ -9,6 +9,7 @@ OTEL_SERVICE_NAME='dev-playground' DATABRICKS_VOLUME_PLAYGROUND= DATABRICKS_VOLUME_OTHER= DATABRICKS_GENIE_SPACE_ID= +DATABRICKS_SERVING_ENDPOINT_NAME= LAKEBASE_ENDPOINT='' # Run: databricks postgres list-endpoints projects/{project-id}/branches/{branch-id} — use the `name` field from the output PGHOST= PGUSER= diff --git a/apps/dev-playground/client/.gitignore b/apps/dev-playground/client/.gitignore index a547bf36..267b28f3 100644 --- a/apps/dev-playground/client/.gitignore +++ b/apps/dev-playground/client/.gitignore @@ -12,6 +12,9 @@ dist dist-ssr *.local +# Auto-generated types (endpoint-specific, varies per developer) +src/appKitServingTypes.d.ts + # Editor directories and files .vscode/* !.vscode/extensions.json diff --git a/apps/dev-playground/client/src/appKitServingTypes.d.ts b/apps/dev-playground/client/src/appKitServingTypes.d.ts new file mode 100644 index 00000000..28f610b4 --- /dev/null +++ b/apps/dev-playground/client/src/appKitServingTypes.d.ts @@ -0,0 +1,114 @@ +// Auto-generated by AppKit - DO NOT EDIT +// Generated from serving endpoint OpenAPI schemas +import "@databricks/appkit"; +import "@databricks/appkit-ui/react"; + +declare module "@databricks/appkit" { + interface ServingEndpointRegistry { + default: { + request: { + messages?: { + role?: "user" | "assistant"; + content?: string; + }[]; + /** @openapi integer, nullable */ + n?: number | null; + max_tokens?: number; + /** @openapi double, nullable */ + top_p?: number | null; + reasoning_effort?: "low" | "medium" | "high" | null; + /** @openapi double, nullable */ + temperature?: number | null; + stop?: 
string | string[] | null; + }; + response: { + model?: string; + choices?: { + index?: number; + message?: { + role?: "user" | "assistant"; + content?: string; + }; + finish_reason?: string; + }[]; + usage?: { + prompt_tokens?: number; + completion_tokens?: number; + total_tokens?: number; + } | null; + object?: string; + id?: string; + created?: number; + }; + chunk: { + model?: string; + choices?: { + index?: number; + delta?: { + role?: "user" | "assistant"; + content?: string; + }; + finish_reason?: string | null; + }[]; + object?: string; + id?: string; + created?: number; + }; + }; + } +} + +declare module "@databricks/appkit-ui/react" { + interface ServingEndpointRegistry { + default: { + request: { + messages?: { + role?: "user" | "assistant"; + content?: string; + }[]; + /** @openapi integer, nullable */ + n?: number | null; + max_tokens?: number; + /** @openapi double, nullable */ + top_p?: number | null; + reasoning_effort?: "low" | "medium" | "high" | null; + /** @openapi double, nullable */ + temperature?: number | null; + stop?: string | string[] | null; + }; + response: { + model?: string; + choices?: { + index?: number; + message?: { + role?: "user" | "assistant"; + content?: string; + }; + finish_reason?: string; + }[]; + usage?: { + prompt_tokens?: number; + completion_tokens?: number; + total_tokens?: number; + } | null; + object?: string; + id?: string; + created?: number; + }; + chunk: { + model?: string; + choices?: { + index?: number; + delta?: { + role?: "user" | "assistant"; + content?: string; + }; + finish_reason?: string | null; + }[]; + object?: string; + id?: string; + created?: number; + }; + }; + } +} diff --git a/apps/dev-playground/client/src/routeTree.gen.ts b/apps/dev-playground/client/src/routeTree.gen.ts index c4c38d14..99ac75fc 100644 --- a/apps/dev-playground/client/src/routeTree.gen.ts +++ b/apps/dev-playground/client/src/routeTree.gen.ts @@ -12,6 +12,7 @@ import { Route as rootRouteImport } from './routes/__root' import { 
Route as TypeSafetyRouteRouteImport } from './routes/type-safety.route' import { Route as TelemetryRouteRouteImport } from './routes/telemetry.route' import { Route as SqlHelpersRouteRouteImport } from './routes/sql-helpers.route' +import { Route as ServingRouteRouteImport } from './routes/serving.route' import { Route as ReconnectRouteRouteImport } from './routes/reconnect.route' import { Route as LakebaseRouteRouteImport } from './routes/lakebase.route' import { Route as GenieRouteRouteImport } from './routes/genie.route' @@ -37,6 +38,11 @@ const SqlHelpersRouteRoute = SqlHelpersRouteRouteImport.update({ path: '/sql-helpers', getParentRoute: () => rootRouteImport, } as any) +const ServingRouteRoute = ServingRouteRouteImport.update({ + id: '/serving', + path: '/serving', + getParentRoute: () => rootRouteImport, +} as any) const ReconnectRouteRoute = ReconnectRouteRouteImport.update({ id: '/reconnect', path: '/reconnect', @@ -93,6 +99,7 @@ export interface FileRoutesByFullPath { '/genie': typeof GenieRouteRoute '/lakebase': typeof LakebaseRouteRoute '/reconnect': typeof ReconnectRouteRoute + '/serving': typeof ServingRouteRoute '/sql-helpers': typeof SqlHelpersRouteRoute '/telemetry': typeof TelemetryRouteRoute '/type-safety': typeof TypeSafetyRouteRoute @@ -107,6 +114,7 @@ export interface FileRoutesByTo { '/genie': typeof GenieRouteRoute '/lakebase': typeof LakebaseRouteRoute '/reconnect': typeof ReconnectRouteRoute + '/serving': typeof ServingRouteRoute '/sql-helpers': typeof SqlHelpersRouteRoute '/telemetry': typeof TelemetryRouteRoute '/type-safety': typeof TypeSafetyRouteRoute @@ -122,6 +130,7 @@ export interface FileRoutesById { '/genie': typeof GenieRouteRoute '/lakebase': typeof LakebaseRouteRoute '/reconnect': typeof ReconnectRouteRoute + '/serving': typeof ServingRouteRoute '/sql-helpers': typeof SqlHelpersRouteRoute '/telemetry': typeof TelemetryRouteRoute '/type-safety': typeof TypeSafetyRouteRoute @@ -138,6 +147,7 @@ export interface FileRouteTypes { 
| '/genie' | '/lakebase' | '/reconnect' + | '/serving' | '/sql-helpers' | '/telemetry' | '/type-safety' @@ -152,6 +162,7 @@ export interface FileRouteTypes { | '/genie' | '/lakebase' | '/reconnect' + | '/serving' | '/sql-helpers' | '/telemetry' | '/type-safety' @@ -166,6 +177,7 @@ export interface FileRouteTypes { | '/genie' | '/lakebase' | '/reconnect' + | '/serving' | '/sql-helpers' | '/telemetry' | '/type-safety' @@ -181,6 +193,7 @@ export interface RootRouteChildren { GenieRouteRoute: typeof GenieRouteRoute LakebaseRouteRoute: typeof LakebaseRouteRoute ReconnectRouteRoute: typeof ReconnectRouteRoute + ServingRouteRoute: typeof ServingRouteRoute SqlHelpersRouteRoute: typeof SqlHelpersRouteRoute TelemetryRouteRoute: typeof TelemetryRouteRoute TypeSafetyRouteRoute: typeof TypeSafetyRouteRoute @@ -209,6 +222,13 @@ declare module '@tanstack/react-router' { preLoaderRoute: typeof SqlHelpersRouteRouteImport parentRoute: typeof rootRouteImport } + '/serving': { + id: '/serving' + path: '/serving' + fullPath: '/serving' + preLoaderRoute: typeof ServingRouteRouteImport + parentRoute: typeof rootRouteImport + } '/reconnect': { id: '/reconnect' path: '/reconnect' @@ -285,6 +305,7 @@ const rootRouteChildren: RootRouteChildren = { GenieRouteRoute: GenieRouteRoute, LakebaseRouteRoute: LakebaseRouteRoute, ReconnectRouteRoute: ReconnectRouteRoute, + ServingRouteRoute: ServingRouteRoute, SqlHelpersRouteRoute: SqlHelpersRouteRoute, TelemetryRouteRoute: TelemetryRouteRoute, TypeSafetyRouteRoute: TypeSafetyRouteRoute, diff --git a/apps/dev-playground/client/src/routes/__root.tsx b/apps/dev-playground/client/src/routes/__root.tsx index 5cf74ce3..35a2282b 100644 --- a/apps/dev-playground/client/src/routes/__root.tsx +++ b/apps/dev-playground/client/src/routes/__root.tsx @@ -104,6 +104,14 @@ function RootComponent() { Files + + + diff --git a/apps/dev-playground/client/src/routes/index.tsx b/apps/dev-playground/client/src/routes/index.tsx index e331d93c..934b1467 100644 --- 
a/apps/dev-playground/client/src/routes/index.tsx +++ b/apps/dev-playground/client/src/routes/index.tsx @@ -218,6 +218,24 @@ function IndexRoute() { + + +
+

+ Model Serving +

+

+ Chat with a Databricks Model Serving endpoint using streaming + completions with real-time SSE responses. +

+ +
+
diff --git a/apps/dev-playground/client/src/routes/serving.route.tsx b/apps/dev-playground/client/src/routes/serving.route.tsx new file mode 100644 index 00000000..ab980b51 --- /dev/null +++ b/apps/dev-playground/client/src/routes/serving.route.tsx @@ -0,0 +1,148 @@ +import { useServingStream } from "@databricks/appkit-ui/react"; +import { createFileRoute } from "@tanstack/react-router"; +import { useEffect, useRef, useState } from "react"; + +export const Route = createFileRoute("/serving")({ + component: ServingRoute, +}); + +interface Message { + id: string; + role: "user" | "assistant"; + content: string; +} + +function extractContent(chunk: unknown): string { + return ( + (chunk as { choices?: { delta?: { content?: string } }[] })?.choices?.[0] + ?.delta?.content ?? "" + ); +} + +function ServingRoute() { + const [input, setInput] = useState(""); + const [messages, setMessages] = useState([]); + + const { stream, chunks, streaming, error, reset } = useServingStream({ + messages: [], + }); + + const streamingContent = chunks.map(extractContent).join(""); + + // Commit assistant message when streaming transitions from true → false + const prevStreamingRef = useRef(false); + useEffect(() => { + if (prevStreamingRef.current && !streaming && streamingContent) { + setMessages((prev) => [ + ...prev, + { + id: crypto.randomUUID(), + role: "assistant", + content: streamingContent, + }, + ]); + reset(); + } + prevStreamingRef.current = streaming; + }, [streaming, streamingContent, reset]); + + function handleSubmit(e: React.FormEvent) { + e.preventDefault(); + if (!input.trim() || streaming) return; + + const userMessage: Message = { + id: crypto.randomUUID(), + role: "user", + content: input.trim(), + }; + + const fullMessages = [ + ...messages.map(({ role, content }) => ({ role, content })), + { role: "user" as const, content: userMessage.content }, + ]; + + setMessages((prev) => [...prev, userMessage]); + setInput(""); + reset(); + stream({ messages: fullMessages }); 
+ } + + return ( +
+
+
+
+

+ Model Serving +

+

+ Chat with a Databricks Model Serving endpoint. Set{" "} + + DATABRICKS_SERVING_ENDPOINT_NAME + {" "} + to enable. +

+
+ +
+ {/* Messages area */} +
+ {messages.map((msg) => ( +
+
+

{msg.content}

+
+
+ ))} + + {/* Streaming response */} + {streaming && ( +
+
+

+ {streamingContent || "..."} +

+
+
+ )} + + {error && ( +
+ Error: {error} +
+ )} +
+ + {/* Input area */} +
+ setInput(e.target.value)} + placeholder="Send a message..." + className="flex-1 rounded-md border px-3 py-2 text-sm bg-background" + disabled={streaming} + /> + +
+
+
+
+
+ ); +} diff --git a/apps/dev-playground/client/vite.config.ts b/apps/dev-playground/client/vite.config.ts index f892c62f..5f37880b 100644 --- a/apps/dev-playground/client/vite.config.ts +++ b/apps/dev-playground/client/vite.config.ts @@ -1,4 +1,5 @@ import path from "node:path"; +import { appKitServingTypesPlugin } from "@databricks/appkit"; import { tanstackRouter } from "@tanstack/router-plugin/vite"; import react from "@vitejs/plugin-react"; import { defineConfig } from "vite"; @@ -11,6 +12,7 @@ export default defineConfig({ target: "react", autoCodeSplitting: process.env.NODE_ENV !== "development", }), + appKitServingTypesPlugin(), ], server: { hmr: { diff --git a/apps/dev-playground/server/index.ts b/apps/dev-playground/server/index.ts index a4b6a2c6..af05b11f 100644 --- a/apps/dev-playground/server/index.ts +++ b/apps/dev-playground/server/index.ts @@ -1,5 +1,12 @@ import "reflect-metadata"; -import { analytics, createApp, files, genie, server } from "@databricks/appkit"; +import { + analytics, + createApp, + files, + genie, + server, + serving, +} from "@databricks/appkit"; import { WorkspaceClient } from "@databricks/sdk-experimental"; import { lakebaseExamples } from "./lakebase-examples-plugin"; import { reconnect } from "./reconnect-plugin"; @@ -26,6 +33,7 @@ createApp({ }), lakebaseExamples(), files(), + serving(), ], ...(process.env.APPKIT_E2E_TEST && { client: createMockClient() }), }).then((appkit) => { diff --git a/docs/docs/api/appkit/Function.appKitServingTypesPlugin.md b/docs/docs/api/appkit/Function.appKitServingTypesPlugin.md new file mode 100644 index 00000000..e53e5bd3 --- /dev/null +++ b/docs/docs/api/appkit/Function.appKitServingTypesPlugin.md @@ -0,0 +1,24 @@ +# Function: appKitServingTypesPlugin() + +```ts +function appKitServingTypesPlugin(options?: AppKitServingTypesPluginOptions): Plugin$1; +``` + +Vite plugin to generate TypeScript types for AppKit serving endpoints. 
+Fetches OpenAPI schemas from Databricks and generates a .d.ts with +ServingEndpointRegistry module augmentation. + +Endpoint discovery order: +1. Explicit `endpoints` option (override) +2. AST extraction from server file (server/index.ts or server/server.ts) +3. DATABRICKS_SERVING_ENDPOINT_NAME env var (single default endpoint) + +## Parameters + +| Parameter | Type | +| ------ | ------ | +| `options?` | `AppKitServingTypesPluginOptions` | + +## Returns + +`Plugin$1` diff --git a/docs/docs/api/appkit/Function.extractServingEndpoints.md b/docs/docs/api/appkit/Function.extractServingEndpoints.md new file mode 100644 index 00000000..24a5b00d --- /dev/null +++ b/docs/docs/api/appkit/Function.extractServingEndpoints.md @@ -0,0 +1,24 @@ +# Function: extractServingEndpoints() + +```ts +function extractServingEndpoints(serverFilePath: string): + | Record + | null; +``` + +Extract serving endpoint config from a server file by AST-parsing it. +Looks for `serving({ endpoints: { alias: { env: "..." }, ... } })` calls +and extracts the endpoint alias names and their environment variable mappings. + +## Parameters + +| Parameter | Type | Description | +| ------ | ------ | ------ | +| `serverFilePath` | `string` | Absolute path to the server entry file | + +## Returns + + \| `Record`\<`string`, [`EndpointConfig`](Interface.EndpointConfig.md)\> + \| `null` + +Extracted endpoint config, or null if not found or not extractable diff --git a/docs/docs/api/appkit/Function.findServerFile.md b/docs/docs/api/appkit/Function.findServerFile.md new file mode 100644 index 00000000..2ed4e268 --- /dev/null +++ b/docs/docs/api/appkit/Function.findServerFile.md @@ -0,0 +1,19 @@ +# Function: findServerFile() + +```ts +function findServerFile(basePath: string): string | null; +``` + +Find the server entry file by checking candidate paths in order. 
+ +## Parameters + +| Parameter | Type | Description | +| ------ | ------ | ------ | +| `basePath` | `string` | Project root directory to search from | + +## Returns + +`string` \| `null` + +Absolute path to the server file, or null if none found diff --git a/docs/docs/api/appkit/index.md b/docs/docs/api/appkit/index.md index f4685e04..faadf237 100644 --- a/docs/docs/api/appkit/index.md +++ b/docs/docs/api/appkit/index.md @@ -70,9 +70,12 @@ plugin architecture, and React integration. | Function | Description | | ------ | ------ | +| [appKitServingTypesPlugin](Function.appKitServingTypesPlugin.md) | Vite plugin to generate TypeScript types for AppKit serving endpoints. Fetches OpenAPI schemas from Databricks and generates a .d.ts with ServingEndpointRegistry module augmentation. | | [appKitTypesPlugin](Function.appKitTypesPlugin.md) | Vite plugin to generate types for AppKit queries. Calls generateFromEntryPoint under the hood. | | [createApp](Function.createApp.md) | Bootstraps AppKit with the provided configuration. | | [createLakebasePool](Function.createLakebasePool.md) | Create a Lakebase pool with appkit's logger integration. Telemetry automatically uses appkit's OpenTelemetry configuration via global registry. | +| [extractServingEndpoints](Function.extractServingEndpoints.md) | Extract serving endpoint config from a server file by AST-parsing it. Looks for `serving({ endpoints: { alias: { env: "..." }, ... } })` calls and extracts the endpoint alias names and their environment variable mappings. | +| [findServerFile](Function.findServerFile.md) | Find the server entry file by checking candidate paths in order. | | [generateDatabaseCredential](Function.generateDatabaseCredential.md) | Generate OAuth credentials for Postgres database connection using the proper Postgres API. | | [getExecutionContext](Function.getExecutionContext.md) | Get the current execution context. 
| | [getLakebaseOrmConfig](Function.getLakebaseOrmConfig.md) | Get Lakebase connection configuration for ORMs that don't accept pg.Pool directly. | diff --git a/docs/docs/api/appkit/typedoc-sidebar.ts b/docs/docs/api/appkit/typedoc-sidebar.ts index 91815e3d..1d498d1a 100644 --- a/docs/docs/api/appkit/typedoc-sidebar.ts +++ b/docs/docs/api/appkit/typedoc-sidebar.ts @@ -225,6 +225,11 @@ const typedocSidebar: SidebarsConfig = { type: "category", label: "Functions", items: [ + { + type: "doc", + id: "api/appkit/Function.appKitServingTypesPlugin", + label: "appKitServingTypesPlugin" + }, { type: "doc", id: "api/appkit/Function.appKitTypesPlugin", @@ -240,6 +245,16 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/Function.createLakebasePool", label: "createLakebasePool" }, + { + type: "doc", + id: "api/appkit/Function.extractServingEndpoints", + label: "extractServingEndpoints" + }, + { + type: "doc", + id: "api/appkit/Function.findServerFile", + label: "findServerFile" + }, { type: "doc", id: "api/appkit/Function.generateDatabaseCredential", diff --git a/docs/docs/plugins/serving.md b/docs/docs/plugins/serving.md new file mode 100644 index 00000000..82eac240 --- /dev/null +++ b/docs/docs/plugins/serving.md @@ -0,0 +1,227 @@ +--- +sidebar_position: 7 +--- + +# Serving plugin + +Provides an authenticated proxy to [Databricks Model Serving](https://docs.databricks.com/aws/en/machine-learning/model-serving) endpoints, with invoke and streaming support. 
+ +**Key features:** +- Named endpoint aliases for multiple serving endpoints +- Non-streaming (`invoke`) and SSE streaming (`stream`) invocation +- Automatic OpenAPI type generation for request/response schemas +- Request body filtering based on endpoint schema +- On-behalf-of (OBO) user execution + +## Basic usage + +```ts +import { createApp, server, serving } from "@databricks/appkit"; + +await createApp({ + plugins: [ + server(), + serving(), + ], +}); +``` + +With no configuration, the plugin reads `DATABRICKS_SERVING_ENDPOINT_NAME` from the environment and registers it under the `default` alias. + +## Configuration options + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `endpoints` | `Record` | `{ default: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" } }` | Map of alias names to endpoint configs | +| `timeout` | `number` | `120000` | Request timeout in ms | + +### Endpoint aliases + +Endpoint aliases let you reference multiple serving endpoints by name: + +```ts +serving({ + endpoints: { + llm: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" }, + classifier: { env: "DATABRICKS_SERVING_ENDPOINT_CLASSIFIER" }, + }, +}) +``` + +Each alias maps to an environment variable holding the actual endpoint name. If an endpoint serves multiple models, you can use `servedModel` to bypass traffic routing and target a specific model directly: + +```ts +serving({ + endpoints: { + llm: { env: "DATABRICKS_SERVING_ENDPOINT_NAME", servedModel: "llama-v2" }, + }, +}) +``` + +## Type generation + +The `appKitServingTypesPlugin()` Vite plugin generates TypeScript types from your serving endpoints' OpenAPI schemas. 
Add it to your `vite.config.ts`: + +```ts +import { appKitServingTypesPlugin } from "@databricks/appkit"; + +export default defineConfig({ + plugins: [ + appKitServingTypesPlugin(), + ], +}); +``` + +The plugin auto-discovers endpoint configuration from your server file (`server/index.ts` or `server/server.ts`) — no manual config passing needed. + +Generated types provide: +- **Alias autocomplete** in both backend (`AppKit.serving("alias")`) and frontend hooks (`useServingStream`, `useServingInvoke`) +- **Typed request/response/chunk** per endpoint based on OpenAPI schemas + +If an endpoint's OpenAPI schema is unavailable (not deployed, env var not set), the plugin generates generic fallback types. The endpoint is still usable — just without typed request/response. + +:::note +Endpoints that don't define a streaming response schema in their OpenAPI spec will have `chunk: unknown`. For these endpoints, use `useServingInvoke` instead of `useServingStream` — the `response` type will still be properly typed. +::: + +## Environment variables + +| Variable | Description | +|----------|-------------| +| `DATABRICKS_SERVING_ENDPOINT_NAME` | Default endpoint name (used when `endpoints` config is omitted) | + +When using named endpoints, define a custom environment variable per alias (e.g. `DATABRICKS_SERVING_ENDPOINT_CLASSIFIER`). + +## Execution context + +All serving routes execute on behalf of the authenticated user (OBO) by default, consistent with the Genie and Files plugins. This ensures per-user `CAN_QUERY` permissions are enforced on the serving endpoint. 
+ +For programmatic access via `exports()`, use `.asUser(req)` to run in user context: + +```ts +// Service principal context (default) +const result = await AppKit.serving("llm").invoke({ messages }); + +// User context (recommended in route handlers) +const result = await AppKit.serving("llm").asUser(req).invoke({ messages }); +``` + +## HTTP endpoints + +### Named mode (with `endpoints` config) + +- `POST /api/serving/:alias/invoke` — Non-streaming invocation +- `POST /api/serving/:alias/stream` — SSE streaming invocation + +### Default mode (no `endpoints` config) + +- `POST /api/serving/invoke` — Non-streaming invocation +- `POST /api/serving/stream` — SSE streaming invocation + +### Request format + +``` +POST /api/serving/:alias/invoke +Content-Type: application/json + +{ + "messages": [ + { "role": "user", "content": "Hello" } + ] +} +``` + +## Programmatic access + +The plugin exports `invoke` and `stream` methods for server-side use: + +```ts +const AppKit = await createApp({ + plugins: [ + server(), + serving({ + endpoints: { + llm: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" }, + }, + }), + ], +}); + +// Non-streaming +const result = await AppKit.serving("llm").invoke({ + messages: [{ role: "user", content: "Hello" }], +}); + +// Streaming +for await (const chunk of AppKit.serving("llm").stream({ + messages: [{ role: "user", content: "Hello" }], +})) { + console.log(chunk); +} +``` + +## Frontend hooks + +The `@databricks/appkit-ui` package provides React hooks for serving endpoints: + +### useServingStream + +Streaming invocation via SSE: + +```tsx +import { useServingStream } from "@databricks/appkit-ui/react"; + +function ChatStream() { + const { stream, chunks, streaming, error, reset } = useServingStream( + { messages: [{ role: "user", content: "Hello" }] }, + { + alias: "llm", + onComplete: (finalChunks) => { + // Called with all accumulated chunks when the stream finishes + console.log("Stream done, got", finalChunks.length, "chunks"); + }, + }, + 
); + + return ( + <> + + + {chunks.map((chunk, i) =>
{JSON.stringify(chunk)}
)} + {error &&

{error}

} + + ); +} +``` + +### useServingInvoke + +Non-streaming invocation. `invoke()` returns a promise with the response data (or `null` on error): + +```tsx +import { useServingInvoke } from "@databricks/appkit-ui/react"; + +function Classify() { + const { invoke, data, loading, error } = useServingInvoke( + { inputs: ["sample text"] }, + { alias: "classifier" }, + ); + + async function handleClick() { + const result = await invoke(); + if (result) { + console.log("Classification result:", result); + } + } + + return ( + <> + + {data &&
{JSON.stringify(data)}
} + {error &&

{error}

} + + ); +} +``` + +Both hooks accept `autoStart: true` to invoke automatically on mount. diff --git a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-invoke.test.ts b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-invoke.test.ts new file mode 100644 index 00000000..6d5f159f --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-invoke.test.ts @@ -0,0 +1,117 @@ +import { act, renderHook, waitFor } from "@testing-library/react"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { useServingInvoke } from "../use-serving-invoke"; + +describe("useServingInvoke", () => { + beforeEach(() => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ choices: [] }), { status: 200 }), + ); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + test("initial state is idle", () => { + const { result } = renderHook(() => useServingInvoke({ messages: [] })); + + expect(result.current.data).toBeNull(); + expect(result.current.loading).toBe(false); + expect(result.current.error).toBeNull(); + expect(typeof result.current.invoke).toBe("function"); + }); + + test("calls fetch to correct URL on invoke", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + const { result } = renderHook(() => + useServingInvoke({ messages: [{ role: "user", content: "Hello" }] }), + ); + + act(() => { + result.current.invoke(); + }); + + await waitFor(() => { + expect(fetchSpy).toHaveBeenCalledWith( + "/api/serving/invoke", + expect.objectContaining({ + method: "POST", + body: JSON.stringify({ + messages: [{ role: "user", content: "Hello" }], + }), + }), + ); + }); + }); + + test("uses alias in URL when provided", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + const { result } = renderHook(() => + useServingInvoke({ messages: [] }, { alias: "llm" }), + ); + + act(() => { + result.current.invoke(); + }); + + await waitFor(() => { + 
expect(fetchSpy).toHaveBeenCalledWith( + "/api/serving/llm/invoke", + expect.any(Object), + ); + }); + }); + + test("sets data on successful response", async () => { + const responseData = { + choices: [{ message: { content: "Hi" } }], + }; + + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(responseData), { status: 200 }), + ); + + const { result } = renderHook(() => useServingInvoke({ messages: [] })); + + act(() => { + result.current.invoke(); + }); + + await waitFor(() => { + expect(result.current.data).toEqual(responseData); + expect(result.current.loading).toBe(false); + }); + }); + + test("sets error on failed response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ error: "Not found" }), { status: 404 }), + ); + + const { result } = renderHook(() => useServingInvoke({ messages: [] })); + + await act(async () => { + result.current.invoke(); + // Wait for the fetch promise chain to resolve + await new Promise((r) => setTimeout(r, 10)); + }); + + await waitFor(() => { + expect(result.current.error).toBe("Not found"); + expect(result.current.loading).toBe(false); + }); + }); + + test("auto starts when autoStart is true", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + renderHook(() => useServingInvoke({ messages: [] }, { autoStart: true })); + + await waitFor(() => { + expect(fetchSpy).toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts new file mode 100644 index 00000000..1ab0bf44 --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts @@ -0,0 +1,291 @@ +import { act, renderHook, waitFor } from "@testing-library/react"; +import { afterEach, describe, expect, test, vi } from "vitest"; + +// Mock connectSSE — capture callbacks so we can simulate SSE events +let capturedCallbacks: { + 
onMessage?: (msg: { data: string }) => void; + onError?: (err: Error) => void; + signal?: AbortSignal; +} = {}; + +let resolveStream: (() => void) | null = null; + +const mockConnectSSE = vi.fn().mockImplementation((opts: any) => { + capturedCallbacks = { + onMessage: opts.onMessage, + onError: opts.onError, + signal: opts.signal, + }; + return new Promise((resolve) => { + resolveStream = resolve; + // Also resolve after a tick as fallback for tests that don't manually resolve + setTimeout(resolve, 0); + }); +}); + +vi.mock("@/js", () => ({ + connectSSE: (...args: unknown[]) => mockConnectSSE(...args), +})); + +import { useServingStream } from "../use-serving-stream"; + +describe("useServingStream", () => { + afterEach(() => { + capturedCallbacks = {}; + resolveStream = null; + vi.clearAllMocks(); + }); + + test("initial state is idle", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + expect(result.current.chunks).toEqual([]); + expect(result.current.streaming).toBe(false); + expect(result.current.error).toBeNull(); + expect(typeof result.current.stream).toBe("function"); + expect(typeof result.current.reset).toBe("function"); + }); + + test("calls connectSSE with correct URL on stream", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + expect(mockConnectSSE).toHaveBeenCalledWith( + expect.objectContaining({ + url: "/api/serving/stream", + payload: JSON.stringify({ messages: [] }), + }), + ); + }); + + test("uses override body when passed to stream()", () => { + const { result } = renderHook(() => + useServingStream({ messages: [{ role: "user", content: "old" }] }), + ); + + const overrideBody = { + messages: [{ role: "user" as const, content: "new" }], + }; + + act(() => { + result.current.stream(overrideBody); + }); + + expect(mockConnectSSE).toHaveBeenCalledWith( + expect.objectContaining({ + payload: JSON.stringify(overrideBody), + }), + 
); + }); + + test("uses alias in URL when provided", () => { + const { result } = renderHook(() => + useServingStream({ messages: [] }, { alias: "embedder" }), + ); + + act(() => { + result.current.stream(); + }); + + expect(mockConnectSSE).toHaveBeenCalledWith( + expect.objectContaining({ + url: "/api/serving/embedder/stream", + }), + ); + }); + + test("sets streaming to true when stream() is called", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + expect(result.current.streaming).toBe(true); + }); + + test("accumulates chunks from onMessage", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 2 }) }); + }); + + expect(result.current.chunks).toEqual([{ id: 1 }, { id: 2 }]); + }); + + test("accumulates chunks with error field as normal data", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ + data: JSON.stringify({ error: "Model overloaded" }), + }); + }); + + // Chunks with an `error` field are treated as data, not stream errors. + // Transport-level errors are delivered via onError callback instead. 
+ expect(result.current.chunks).toEqual([{ error: "Model overloaded" }]); + expect(result.current.error).toBeNull(); + expect(result.current.streaming).toBe(true); + }); + + test("sets error from onError callback", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onError?.(new Error("Connection lost")); + }); + + expect(result.current.error).toBe("Connection lost"); + expect(result.current.streaming).toBe(false); + }); + + test("silently skips malformed JSON messages", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: "not valid json{" }); + }); + + // No chunks added, no error set + expect(result.current.chunks).toEqual([]); + expect(result.current.error).toBeNull(); + }); + + test("reset() clears state and aborts active stream", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); + }); + + expect(result.current.chunks).toHaveLength(1); + expect(result.current.streaming).toBe(true); + + act(() => { + result.current.reset(); + }); + + expect(result.current.chunks).toEqual([]); + expect(result.current.streaming).toBe(false); + expect(result.current.error).toBeNull(); + }); + + test("autoStart triggers stream on mount", async () => { + renderHook(() => useServingStream({ messages: [] }, { autoStart: true })); + + await waitFor(() => { + expect(mockConnectSSE).toHaveBeenCalled(); + }); + }); + + test("passes abort signal to connectSSE", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + expect(capturedCallbacks.signal).toBeDefined(); + 
expect(capturedCallbacks.signal?.aborted).toBe(false); + }); + + test("aborts stream on unmount", () => { + const { result, unmount } = renderHook(() => + useServingStream({ messages: [] }), + ); + + act(() => { + result.current.stream(); + }); + + const signal = capturedCallbacks.signal; + expect(signal?.aborted).toBe(false); + + unmount(); + + expect(signal?.aborted).toBe(true); + }); + + test("sets streaming to false when connectSSE resolves", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + await waitFor(() => { + expect(result.current.streaming).toBe(false); + }); + }); + + test("calls onComplete with accumulated chunks when stream finishes", async () => { + const onComplete = vi.fn(); + + // Use a controllable mock so stream doesn't auto-resolve + mockConnectSSE.mockImplementationOnce((opts: any) => { + capturedCallbacks = { + onMessage: opts.onMessage, + onError: opts.onError, + signal: opts.signal, + }; + return new Promise((resolve) => { + resolveStream = resolve; + }); + }); + + const { result } = renderHook(() => + useServingStream({ messages: [] }, { onComplete }), + ); + + act(() => { + result.current.stream(); + }); + + // Send two chunks + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); + }); + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 2 }) }); + }); + + expect(onComplete).not.toHaveBeenCalled(); + + // Complete the stream + await act(async () => { + resolveStream?.(); + await new Promise((r) => setTimeout(r, 0)); + }); + + expect(onComplete).toHaveBeenCalledWith([{ id: 1 }, { id: 2 }]); + }); +}); diff --git a/packages/appkit-ui/src/react/hooks/index.ts b/packages/appkit-ui/src/react/hooks/index.ts index 84d51b53..a425b010 100644 --- a/packages/appkit-ui/src/react/hooks/index.ts +++ b/packages/appkit-ui/src/react/hooks/index.ts @@ -2,8 +2,13 @@ export type { AnalyticsFormat, 
InferResultByFormat, InferRowType, + InferServingChunk, + InferServingRequest, + InferServingResponse, PluginRegistry, QueryRegistry, + ServingAlias, + ServingEndpointRegistry, TypedArrowTable, UseAnalyticsQueryOptions, UseAnalyticsQueryResult, @@ -15,3 +20,13 @@ export { useChartData, } from "./use-chart-data"; export { usePluginClientConfig } from "./use-plugin-config"; +export { + type UseServingInvokeOptions, + type UseServingInvokeResult, + useServingInvoke, +} from "./use-serving-invoke"; +export { + type UseServingStreamOptions, + type UseServingStreamResult, + useServingStream, +} from "./use-serving-stream"; diff --git a/packages/appkit-ui/src/react/hooks/types.ts b/packages/appkit-ui/src/react/hooks/types.ts index 5db725fc..19ce1fac 100644 --- a/packages/appkit-ui/src/react/hooks/types.ts +++ b/packages/appkit-ui/src/react/hooks/types.ts @@ -134,3 +134,54 @@ export type InferParams = K extends AugmentedRegistry export interface PluginRegistry { [key: string]: Record; } + +// ============================================================================ +// Serving Endpoint Registry +// ============================================================================ + +/** + * Serving endpoint registry for type-safe alias names. + * Extend this interface via module augmentation to get alias autocomplete: + * + * @example + * ```typescript + * // Auto-generated by appKitServingTypesPlugin() + * declare module "@databricks/appkit-ui/react" { + * interface ServingEndpointRegistry { + * llm: { request: {...}; response: {...}; chunk: {...} }; + * } + * } + * ``` + */ +// biome-ignore lint/suspicious/noEmptyInterface: intentionally empty — populated via module augmentation +export interface ServingEndpointRegistry {} + +/** Resolves to registry keys if populated, otherwise string */ +export type ServingAlias = + AugmentedRegistry extends never + ? 
string + : AugmentedRegistry; + +/** Infers chunk type from registry when alias is a known key */ +export type InferServingChunk = + K extends AugmentedRegistry + ? ServingEndpointRegistry[K] extends { chunk: infer C } + ? C + : unknown + : unknown; + +/** Infers response type from registry when alias is a known key */ +export type InferServingResponse = + K extends AugmentedRegistry + ? ServingEndpointRegistry[K] extends { response: infer R } + ? R + : unknown + : unknown; + +/** Infers request type from registry when alias is a known key */ +export type InferServingRequest = + K extends AugmentedRegistry + ? ServingEndpointRegistry[K] extends { request: infer Req } + ? Req + : Record + : Record; diff --git a/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts b/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts new file mode 100644 index 00000000..8e80e82e --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts @@ -0,0 +1,111 @@ +import { useCallback, useEffect, useRef, useState } from "react"; +import type { + InferServingRequest, + InferServingResponse, + ServingAlias, +} from "./types"; + +export interface UseServingInvokeOptions< + K extends ServingAlias = ServingAlias, +> { + /** Endpoint alias for named mode. Omit for default mode. */ + alias?: K; + /** If false, does not invoke automatically on mount. Default: false */ + autoStart?: boolean; +} + +export interface UseServingInvokeResult< + T = unknown, + TBody = Record, +> { + /** Trigger the invocation. Pass an optional body override for this invocation. */ + invoke: (overrideBody?: TBody) => Promise; + /** Response data, null until loaded. */ + data: T | null; + /** Whether a request is in progress. */ + loading: boolean; + /** Error message, if any. */ + error: string | null; +} + +/** + * Hook for non-streaming invocation of a serving endpoint. + * Calls `POST /api/serving/invoke` (default) or `POST /api/serving/{alias}/invoke` (named). 
+ * + * When the type generator has populated `ServingEndpointRegistry`, the response type + * is automatically inferred from the endpoint's OpenAPI schema. + */ +export function useServingInvoke( + body: InferServingRequest, + options: UseServingInvokeOptions = {} as UseServingInvokeOptions, +): UseServingInvokeResult, InferServingRequest> { + type TResponse = InferServingResponse; + const { alias, autoStart = false } = options; + + const [data, setData] = useState(null); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + const abortControllerRef = useRef(null); + + const urlSuffix = alias + ? `/api/serving/${encodeURIComponent(String(alias))}/invoke` + : "/api/serving/invoke"; + + const bodyJson = JSON.stringify(body); + + const invoke = useCallback( + (overrideBody?: InferServingRequest): Promise => { + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + } + + setLoading(true); + setError(null); + setData(null); + + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + const payload = overrideBody ? 
JSON.stringify(overrideBody) : bodyJson; + + return fetch(urlSuffix, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: payload, + signal: abortController.signal, + }) + .then(async (res) => { + if (!res.ok) { + const errorBody = await res.json().catch(() => null); + throw new Error(errorBody?.error || `HTTP ${res.status}`); + } + return res.json(); + }) + .then((result: TResponse) => { + if (abortController.signal.aborted) return null; + setData(result); + setLoading(false); + return result; + }) + .catch((err: Error) => { + if (abortController.signal.aborted) return null; + setError(err.message || "Request failed"); + setLoading(false); + return null; + }); + }, + [urlSuffix, bodyJson], + ); + + useEffect(() => { + if (autoStart) { + invoke(); + } + + return () => { + abortControllerRef.current?.abort(); + }; + }, [invoke, autoStart]); + + return { invoke, data, loading, error }; +} diff --git a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts new file mode 100644 index 00000000..f0bb7bf2 --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts @@ -0,0 +1,137 @@ +import { useCallback, useEffect, useRef, useState } from "react"; +import { connectSSE } from "@/js"; +import type { + InferServingChunk, + InferServingRequest, + ServingAlias, +} from "./types"; + +export interface UseServingStreamOptions< + K extends ServingAlias = ServingAlias, + T = InferServingChunk, +> { + /** Endpoint alias for named mode. Omit for default mode. */ + alias?: K; + /** If true, starts streaming automatically on mount. Default: false */ + autoStart?: boolean; + /** Called with accumulated chunks when the stream completes successfully. */ + onComplete?: (chunks: T[]) => void; +} + +export interface UseServingStreamResult< + T = unknown, + TBody = Record, +> { + /** Trigger the streaming invocation. Pass an optional body override for this invocation. 
*/ + stream: (overrideBody?: TBody) => void; + /** Accumulated chunks received so far. */ + chunks: T[]; + /** Whether streaming is in progress. */ + streaming: boolean; + /** Error message, if any. */ + error: string | null; + /** Reset chunks and abort any active stream. */ + reset: () => void; +} + +/** + * Hook for streaming invocation of a serving endpoint via SSE. + * Calls `POST /api/serving/stream` (default) or `POST /api/serving/{alias}/stream` (named). + * Accumulates parsed chunks in state. + * + * When the type generator has populated `ServingEndpointRegistry`, the chunk type + * is automatically inferred from the endpoint's OpenAPI schema. + */ +export function useServingStream( + body: InferServingRequest, + options: UseServingStreamOptions = {} as UseServingStreamOptions, +): UseServingStreamResult, InferServingRequest> { + type TChunk = InferServingChunk; + const { alias, autoStart = false, onComplete } = options; + + const [chunks, setChunks] = useState([]); + const [streaming, setStreaming] = useState(false); + const [error, setError] = useState(null); + const abortControllerRef = useRef(null); + const chunksRef = useRef([]); + const onCompleteRef = useRef(onComplete); + onCompleteRef.current = onComplete; + + const urlSuffix = alias + ? `/api/serving/${encodeURIComponent(String(alias))}/stream` + : "/api/serving/stream"; + + const reset = useCallback(() => { + abortControllerRef.current?.abort(); + abortControllerRef.current = null; + chunksRef.current = []; + setChunks([]); + setStreaming(false); + setError(null); + }, []); + + const bodyJson = JSON.stringify(body); + + const stream = useCallback( + (overrideBody?: InferServingRequest) => { + // Abort any existing stream + abortControllerRef.current?.abort(); + + setStreaming(true); + setError(null); + setChunks([]); + chunksRef.current = []; + + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + const payload = overrideBody ? 
JSON.stringify(overrideBody) : bodyJson; + + connectSSE({ + url: urlSuffix, + payload, + signal: abortController.signal, + onMessage: async (message) => { + if (abortController.signal.aborted) return; + try { + const parsed = JSON.parse(message.data); + + chunksRef.current = [...chunksRef.current, parsed as TChunk]; + setChunks(chunksRef.current); + } catch { + // Skip malformed messages + } + }, + onError: (err) => { + if (abortController.signal.aborted) return; + setStreaming(false); + setError(err instanceof Error ? err.message : "Streaming failed"); + }, + }) + .then(() => { + if (abortController.signal.aborted) return; + // Stream completed + setStreaming(false); + onCompleteRef.current?.(chunksRef.current); + }) + .catch(() => { + if (abortController.signal.aborted) return; + setStreaming(false); + setError("Connection error"); + }); + }, + [urlSuffix, bodyJson], + ); + + useEffect(() => { + if (autoStart) { + stream(); + } + + return () => { + abortControllerRef.current?.abort(); + }; + }, [stream, autoStart]); + + return { stream, chunks, streaming, error, reset }; +} diff --git a/packages/appkit/package.json b/packages/appkit/package.json index 0613ec51..64166c4c 100644 --- a/packages/appkit/package.json +++ b/packages/appkit/package.json @@ -50,6 +50,7 @@ "typecheck": "tsc --noEmit" }, "dependencies": { + "@ast-grep/napi": "0.37.0", "@databricks/lakebase": "workspace:*", "@databricks/sdk-experimental": "0.16.0", "@opentelemetry/api": "1.9.0", diff --git a/packages/appkit/src/index.ts b/packages/appkit/src/index.ts index 662a9178..3df5572b 100644 --- a/packages/appkit/src/index.ts +++ b/packages/appkit/src/index.ts @@ -81,6 +81,10 @@ export { SpanStatusCode, type TelemetryConfig, } from "./telemetry"; - +export { + extractServingEndpoints, + findServerFile, +} from "./type-generator/serving/server-file-extractor"; +export { appKitServingTypesPlugin } from "./type-generator/serving/vite-plugin"; // Vite plugin and type generation export { appKitTypesPlugin 
} from "./type-generator/vite-plugin"; diff --git a/packages/appkit/src/plugins/serving/schema-filter.ts b/packages/appkit/src/plugins/serving/schema-filter.ts index 6e52294a..92a25c69 100644 --- a/packages/appkit/src/plugins/serving/schema-filter.ts +++ b/packages/appkit/src/plugins/serving/schema-filter.ts @@ -1,19 +1,9 @@ import fs from "node:fs/promises"; import { createLogger } from "../../logging/logger"; - -const CACHE_VERSION = "1"; - -interface ServingCacheEntry { - hash: string; - requestType: string; - responseType: string; - chunkType: string | null; -} - -interface ServingCache { - version: string; - endpoints: Record; -} +import { + CACHE_VERSION, + type ServingCache, +} from "../../type-generator/serving/cache"; const logger = createLogger("serving:schema-filter"); @@ -47,11 +37,8 @@ export async function loadEndpointSchemas( const cache = parsed; for (const [alias, entry] of Object.entries(cache.endpoints)) { - // Extract property keys from the requestType string - // The requestType is a TypeScript object type like "{ messages: ...; temperature: ...; }" - const keys = extractPropertyKeys(entry.requestType); - if (keys.size > 0) { - allowlists.set(alias, keys); + if (entry.requestKeys && entry.requestKeys.length > 0) { + allowlists.set(alias, new Set(entry.requestKeys)); } } } catch (err) { @@ -67,25 +54,6 @@ export async function loadEndpointSchemas( return allowlists; } -/** - * Extracts top-level property keys from a TypeScript object type string. - * Matches patterns like `key:` or `key?:` at the first nesting level. 
- */ -function extractPropertyKeys(typeStr: string): Set { - const keys = new Set(); - // Match property names at the top level of the object type - // Looking for patterns: ` propertyName:` or ` propertyName?:` - const propRegex = /^\s{2}(?:\/\*\*[^*]*\*\/\s*)?(\w+)\??:/gm; - for ( - let match = propRegex.exec(typeStr); - match !== null; - match = propRegex.exec(typeStr) - ) { - keys.add(match[1]); - } - return keys; -} - /** * Filters a request body against the allowed keys for an endpoint alias. * Returns the filtered body and logs a warning for stripped params. diff --git a/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts b/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts index 948b47f9..4fc030d8 100644 --- a/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts +++ b/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts @@ -109,7 +109,7 @@ describe("schema-filter", () => { expect(result.size).toBe(0); }); - test("extracts property keys from cached types", async () => { + test("reads requestKeys from cache entries", async () => { const fs = (await import("node:fs/promises")).default; vi.mocked(fs.readFile).mockResolvedValue( JSON.stringify({ @@ -117,13 +117,10 @@ describe("schema-filter", () => { endpoints: { default: { hash: "abc", - requestType: `{ - messages: string[]; - temperature?: number | null; - max_tokens: number; -}`, + requestType: "{}", responseType: "{}", chunkType: null, + requestKeys: ["messages", "temperature", "max_tokens"], }, }, }), @@ -137,5 +134,26 @@ describe("schema-filter", () => { expect(keys?.has("temperature")).toBe(true); expect(keys?.has("max_tokens")).toBe(true); }); + + test("skips entries without requestKeys (backwards compat)", async () => { + const fs = (await import("node:fs/promises")).default; + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + version: "1", + endpoints: { + default: { + hash: "abc", + requestType: "{ messages: string[] }", + responseType: "{}", 
+ chunkType: null, + }, + }, + }), + ); + + const result = await loadEndpointSchemas("/some/path"); + // No requestKeys → passthrough mode (no allowlist) + expect(result.size).toBe(0); + }); }); }); diff --git a/packages/appkit/src/type-generator/serving/cache.ts b/packages/appkit/src/type-generator/serving/cache.ts new file mode 100644 index 00000000..dc9bf7e2 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/cache.ts @@ -0,0 +1,56 @@ +import crypto from "node:crypto"; +import fs from "node:fs/promises"; +import path from "node:path"; +import { createLogger } from "../../logging/logger"; + +const logger = createLogger("type-generator:serving:cache"); + +export const CACHE_VERSION = "1"; +const CACHE_FILE = ".appkit-serving-types-cache.json"; +const CACHE_DIR = path.join( + process.cwd(), + "node_modules", + ".databricks", + "appkit", +); + +export interface ServingCacheEntry { + hash: string; + requestType: string; + responseType: string; + chunkType: string | null; + requestKeys: string[]; +} + +export interface ServingCache { + version: string; + endpoints: Record; +} + +export function hashSchema(schemaJson: string): string { + return crypto.createHash("sha256").update(schemaJson).digest("hex"); +} + +export async function loadServingCache(): Promise { + const cachePath = path.join(CACHE_DIR, CACHE_FILE); + try { + await fs.mkdir(CACHE_DIR, { recursive: true }); + const raw = await fs.readFile(cachePath, "utf8"); + const cache = JSON.parse(raw) as ServingCache; + if (cache.version === CACHE_VERSION) { + return cache; + } + logger.debug("Cache version mismatch, starting fresh"); + } catch (err) { + if ((err as NodeJS.ErrnoException).code !== "ENOENT") { + logger.warn("Cache file is corrupted, flushing cache completely."); + } + } + return { version: CACHE_VERSION, endpoints: {} }; +} + +export async function saveServingCache(cache: ServingCache): Promise { + const cachePath = path.join(CACHE_DIR, CACHE_FILE); + await fs.mkdir(CACHE_DIR, { recursive: 
true }); + await fs.writeFile(cachePath, JSON.stringify(cache, null, 2), "utf8"); +} diff --git a/packages/appkit/src/type-generator/serving/converter.ts b/packages/appkit/src/type-generator/serving/converter.ts new file mode 100644 index 00000000..b56b0460 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/converter.ts @@ -0,0 +1,159 @@ +import type { OpenApiOperation, OpenApiSchema } from "./fetcher"; + +/** + * Converts an OpenAPI schema to a TypeScript type string. + */ +function schemaToTypeString(schema: OpenApiSchema, indent = 0): string { + const pad = " ".repeat(indent); + + if (schema.oneOf) { + return schema.oneOf.map((s) => schemaToTypeString(s, indent)).join(" | "); + } + + if (schema.enum) { + return schema.enum.map((v) => JSON.stringify(v)).join(" | "); + } + + switch (schema.type) { + case "string": + return "string"; + case "integer": + case "number": + return "number"; + case "boolean": + return "boolean"; + case "array": { + if (!schema.items) return "unknown[]"; + const itemType = schemaToTypeString(schema.items, indent); + // Wrap union types in parens for array + if (itemType.includes(" | ") && !itemType.startsWith("{")) { + return `(${itemType})[]`; + } + return `${itemType}[]`; + } + case "object": { + if (!schema.properties) return "Record"; + const required = new Set(schema.required ?? []); + const entries = Object.entries(schema.properties).map(([key, prop]) => { + const optional = !required.has(key) ? "?" : ""; + const nullable = prop.nullable ? " | null" : ""; + const typeStr = schemaToTypeString(prop, indent + 1); + const formatComment = + prop.format && (prop.type === "number" || prop.type === "integer") + ? `/** @openapi ${prop.format}${prop.nullable ? ", nullable" : ""} */\n${pad} ` + : prop.nullable && prop.type === "integer" + ? 
`/** @openapi integer, nullable */\n${pad} ` + : ""; + return `${pad} ${formatComment}${key}${optional}: ${typeStr}${nullable};`; + }); + return `{\n${entries.join("\n")}\n${pad}}`; + } + default: + return "unknown"; + } +} + +/** + * Extracts the top-level property keys from the request schema. + * Strips the `stream` property (plugin-controlled). + */ +export function extractRequestKeys(operation: OpenApiOperation): string[] { + const schema = operation.requestBody?.content?.["application/json"]?.schema; + if (!schema?.properties) return []; + return Object.keys(schema.properties).filter((k) => k !== "stream"); +} + +/** + * Extracts and converts the request schema from an OpenAPI path operation. + * Strips the `stream` property from the request type. + */ +export function convertRequestSchema(operation: OpenApiOperation): string { + const schema = operation.requestBody?.content?.["application/json"]?.schema; + if (!schema || !schema.properties) return "Record"; + + // Strip `stream` property — the plugin controls this + const { stream: _stream, ...filteredProps } = schema.properties; + const filteredRequired = (schema.required ?? []).filter( + (r) => r !== "stream", + ); + + const filteredSchema: OpenApiSchema = { + ...schema, + properties: filteredProps, + required: filteredRequired.length > 0 ? filteredRequired : undefined, + }; + + return schemaToTypeString(filteredSchema); +} + +/** + * Extracts and converts the response schema from an OpenAPI path operation. + */ +export function convertResponseSchema(operation: OpenApiOperation): string { + const response = operation.responses?.["200"]; + const schema = response?.content?.["application/json"]?.schema; + if (!schema) return "unknown"; + return schemaToTypeString(schema); +} + +/** + * Derives a streaming chunk type from the response schema. + * Returns null if the response doesn't follow OpenAI-compatible format. 
+ * + * OpenAI-compatible heuristic: response has `choices` array where items + * have a `message` object property. + */ +export function deriveChunkType(operation: OpenApiOperation): string | null { + const response = operation.responses?.["200"]; + const schema = response?.content?.["application/json"]?.schema; + if (!schema?.properties) return null; + + const choicesProp = schema.properties.choices; + if (!choicesProp || choicesProp.type !== "array" || !choicesProp.items) + return null; + + const choiceItemProps = choicesProp.items.properties; + if (!choiceItemProps?.message) return null; + + // It's OpenAI-compatible. Build the chunk type by transforming. + const messageSchema = choiceItemProps.message; + + // Build chunk schema: replace message with delta (Partial), make finish_reason nullable, drop usage + const chunkProperties: Record = {}; + + for (const [key, prop] of Object.entries(schema.properties)) { + if (key === "usage") continue; // Drop usage from chunks + if (key === "choices") { + // Transform choices items + const chunkChoiceProps: Record = {}; + for (const [ck, cp] of Object.entries(choiceItemProps)) { + if (ck === "message") { + // Replace message with delta: Partial + chunkChoiceProps.delta = { ...messageSchema }; + } else if (ck === "finish_reason") { + chunkChoiceProps[ck] = { ...cp, nullable: true }; + } else { + chunkChoiceProps[ck] = cp; + } + } + chunkProperties[key] = { + type: "array", + items: { + type: "object", + properties: chunkChoiceProps, + }, + }; + } else { + chunkProperties[key] = prop; + } + } + + const chunkSchema: OpenApiSchema = { + type: "object", + properties: chunkProperties, + }; + + // Delta properties are already optional (no `required` array in the schema), + // so schemaToTypeString renders them with `?:` — no Partial<> wrapper needed. 
+ return schemaToTypeString(chunkSchema); +} diff --git a/packages/appkit/src/type-generator/serving/fetcher.ts b/packages/appkit/src/type-generator/serving/fetcher.ts new file mode 100644 index 00000000..bf733d7b --- /dev/null +++ b/packages/appkit/src/type-generator/serving/fetcher.ts @@ -0,0 +1,158 @@ +import type { WorkspaceClient } from "@databricks/sdk-experimental"; +import { createLogger } from "../../logging/logger"; + +const logger = createLogger("type-generator:serving:fetcher"); + +interface OpenApiSpec { + openapi: string; + info: { title: string; version: string }; + paths: Record>; +} + +export interface OpenApiOperation { + requestBody?: { + content: { + "application/json": { + schema: OpenApiSchema; + }; + }; + }; + responses?: Record< + string, + { + content?: { + "application/json": { + schema: OpenApiSchema; + }; + }; + } + >; +} + +export interface OpenApiSchema { + type?: string; + properties?: Record; + required?: string[]; + items?: OpenApiSchema; + enum?: string[]; + nullable?: boolean; + oneOf?: OpenApiSchema[]; + format?: string; +} + +/** + * Fetches the OpenAPI schema for a serving endpoint. + * Returns null if the endpoint is not found or access is denied. + */ +export async function fetchOpenApiSchema( + client: WorkspaceClient, + endpointName: string, + servedModel?: string, +): Promise<{ spec: OpenApiSpec; pathKey: string } | null> { + const headers = new Headers({ Accept: "application/json" }); + await client.config.authenticate(headers); + + const host = client.config.host; + if (!host) { + logger.warn("Databricks host not configured, skipping schema fetch"); + return null; + } + + const base = host.startsWith("http") ? 
host : `https://${host}`; + const url = new URL( + `/api/2.0/serving-endpoints/${encodeURIComponent(endpointName)}/openapi`, + base, + ); + + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), 5000); + + try { + const res = await fetch(url.toString(), { + headers, + signal: controller.signal, + }); + + if (!res.ok) { + const body = await res.text().catch(() => ""); + if (res.status === 404) { + logger.warn( + "Endpoint '%s' not found, skipping type generation%s", + endpointName, + body ? `: ${body}` : "", + ); + } else if (res.status === 403) { + logger.warn( + "Access denied to endpoint '%s' schema, skipping type generation%s", + endpointName, + body ? `: ${body}` : "", + ); + } else { + logger.warn( + "Failed to fetch schema for '%s' (HTTP %d), skipping%s", + endpointName, + res.status, + body ? `: ${body}` : "", + ); + } + return null; + } + + const rawSpec: unknown = await res.json(); + if ( + typeof rawSpec !== "object" || + rawSpec === null || + !("paths" in rawSpec) || + typeof (rawSpec as OpenApiSpec).paths !== "object" + ) { + logger.warn( + "Invalid OpenAPI schema structure for '%s', skipping", + endpointName, + ); + return null; + } + const spec = rawSpec as OpenApiSpec; + + // Find the right path key + const pathKeys = Object.keys(spec.paths ?? 
{}); + if (pathKeys.length === 0) { + logger.warn("No paths in OpenAPI schema for '%s'", endpointName); + return null; + } + + let pathKey: string; + if (servedModel) { + const match = pathKeys.find((k) => k.includes(`/${servedModel}/`)); + if (!match) { + logger.warn( + "Served model '%s' not found in schema for '%s', using first path", + servedModel, + endpointName, + ); + pathKey = pathKeys[0]; + } else { + pathKey = match; + } + } else { + pathKey = pathKeys[0]; + } + + return { spec, pathKey }; + } catch (err) { + if ((err as Error).name === "AbortError") { + logger.warn( + "Timeout fetching schema for '%s', skipping type generation", + endpointName, + ); + } else { + logger.warn( + "Error fetching schema for '%s': %s", + endpointName, + (err as Error).message, + ); + } + return null; + } finally { + clearTimeout(timeout); + } +} diff --git a/packages/appkit/src/type-generator/serving/generator.ts b/packages/appkit/src/type-generator/serving/generator.ts new file mode 100644 index 00000000..85ed7237 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/generator.ts @@ -0,0 +1,276 @@ +import fs from "node:fs/promises"; +import { WorkspaceClient } from "@databricks/sdk-experimental"; +import pc from "picocolors"; +import { createLogger } from "../../logging/logger"; +import type { EndpointConfig } from "../../plugins/serving/types"; +import { + CACHE_VERSION, + hashSchema, + loadServingCache, + type ServingCache, + saveServingCache, +} from "./cache"; +import { + convertRequestSchema, + convertResponseSchema, + deriveChunkType, + extractRequestKeys, +} from "./converter"; +import { fetchOpenApiSchema } from "./fetcher"; + +const logger = createLogger("type-generator:serving"); + +const GENERIC_REQUEST = "Record"; +const GENERIC_RESPONSE = "unknown"; +const GENERIC_CHUNK = "unknown"; + +interface GenerateServingTypesOptions { + outFile: string; + endpoints?: Record; + noCache?: boolean; +} + +/** + * Generates TypeScript type declarations for serving 
endpoints + * by fetching their OpenAPI schemas and converting to TypeScript. + */ +export async function generateServingTypes( + options: GenerateServingTypesOptions, +): Promise { + const { outFile, noCache } = options; + + // Resolve endpoints from config or env + const endpoints = options.endpoints ?? resolveDefaultEndpoints(); + if (Object.keys(endpoints).length === 0) { + logger.debug("No serving endpoints configured, skipping type generation"); + return; + } + + const startTime = performance.now(); + + const cache = noCache + ? { version: CACHE_VERSION, endpoints: {} } + : await loadServingCache(); + + let client: WorkspaceClient | undefined; + let updated = false; + + const registryEntries: string[] = []; + const logEntries: Array<{ + alias: string; + status: "HIT" | "MISS"; + error?: string; + }> = []; + + for (const [alias, config] of Object.entries(endpoints)) { + const endpointName = process.env[config.env]; + if (!endpointName) { + registryEntries.push( + buildRegistryEntry( + alias, + GENERIC_REQUEST, + GENERIC_RESPONSE, + GENERIC_CHUNK, + ), + ); + logEntries.push({ + alias, + status: "MISS", + error: `env ${config.env} not set`, + }); + continue; + } + + client ??= new WorkspaceClient({}); + const result = await fetchOpenApiSchema( + client, + endpointName, + config.servedModel, + ); + if (!result) { + registryEntries.push( + buildRegistryEntry( + alias, + GENERIC_REQUEST, + GENERIC_RESPONSE, + GENERIC_CHUNK, + ), + ); + logEntries.push({ + alias, + status: "MISS", + error: "schema fetch failed", + }); + continue; + } + + const { spec, pathKey } = result; + const schemaJson = JSON.stringify(spec); + const hash = hashSchema(schemaJson); + + // Check cache + const cached = cache.endpoints[alias]; + if (cached && cached.hash === hash) { + registryEntries.push( + buildRegistryEntry( + alias, + cached.requestType, + cached.responseType, + cached.chunkType, + ), + ); + logEntries.push({ alias, status: "HIT" }); + continue; + } + + // Cache miss — convert 
+ const operation = spec.paths[pathKey]?.post; + if (!operation) { + logEntries.push({ + alias, + status: "MISS", + error: "no POST operation", + }); + continue; + } + + let requestType: string; + let responseType: string; + let chunkType: string | null; + let requestKeys: string[]; + try { + requestType = convertRequestSchema(operation); + responseType = convertResponseSchema(operation); + chunkType = deriveChunkType(operation); + requestKeys = extractRequestKeys(operation); + } catch (convErr) { + logger.warn( + "Schema conversion failed for '%s': %s", + alias, + (convErr as Error).message, + ); + registryEntries.push( + buildRegistryEntry( + alias, + GENERIC_REQUEST, + GENERIC_RESPONSE, + GENERIC_CHUNK, + ), + ); + logEntries.push({ + alias, + status: "MISS", + error: "schema conversion failed", + }); + continue; + } + + cache.endpoints[alias] = { + hash, + requestType, + responseType, + chunkType, + requestKeys, + }; + updated = true; + + registryEntries.push( + buildRegistryEntry(alias, requestType, responseType, chunkType), + ); + logEntries.push({ alias, status: "MISS" }); + } + + // Print formatted table (matching analytics typegen output) + if (logEntries.length > 0) { + const maxNameLen = Math.max(...logEntries.map((e) => e.alias.length)); + const separator = pc.dim("─".repeat(50)); + console.log(""); + console.log( + ` ${pc.bold("Typegen Serving")} ${pc.dim(`(${logEntries.length})`)}`, + ); + console.log(` ${separator}`); + for (const entry of logEntries) { + const tag = + entry.status === "HIT" + ? `cache ${pc.bold(pc.green("HIT "))}` + : `cache ${pc.bold(pc.yellow("MISS "))}`; + const rawName = entry.alias.padEnd(maxNameLen); + const reason = entry.error ? 
` ${pc.dim(entry.error)}` : "";
+      console.log(`  ${tag} ${rawName}${reason}`);
+    }
+    const elapsed = ((performance.now() - startTime) / 1000).toFixed(2);
+    const newCount = logEntries.filter((e) => e.status === "MISS").length;
+    const cacheCount = logEntries.filter((e) => e.status === "HIT").length;
+    console.log(`  ${separator}`);
+    console.log(
+      `  ${newCount} new, ${cacheCount} from cache. ${pc.dim(`${elapsed}s`)}`,
+    );
+    console.log("");
+  }
+
+  const output = generateTypeDeclarations(registryEntries);
+  await fs.writeFile(outFile, output, "utf-8");
+
+  if (registryEntries.length === 0) {
+    logger.debug(
+      "Wrote empty serving types to %s (no endpoints resolved)",
+      outFile,
+    );
+  } else {
+    logger.debug("Wrote serving types to %s", outFile);
+  }
+
+  if (updated) {
+    await saveServingCache(cache as ServingCache);
+  }
+}
+
+function resolveDefaultEndpoints(): Record<string, EndpointConfig> {
+  if (process.env.DATABRICKS_SERVING_ENDPOINT_NAME) {
+    return { default: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" } };
+  }
+  return {};
+}
+
+function buildRegistryEntry(
+  alias: string,
+  requestType: string,
+  responseType: string,
+  chunkType: string | null,
+): string {
+  const indent = "      ";
+  const chunkEntry = chunkType ? chunkType : "unknown";
+  return `    ${alias}: {
+${indent}request: ${indentType(requestType, indent)};
+${indent}response: ${indentType(responseType, indent)};
+${indent}chunk: ${indentType(chunkEntry, indent)};
+    };`;
+}
+
+function indentType(typeStr: string, baseIndent: string): string {
+  if (!typeStr.includes("\n")) return typeStr;
+  return typeStr
+    .split("\n")
+    .map((line, i) => (i === 0 ? 
line : `${baseIndent}${line}`)) + .join("\n"); +} + +function generateTypeDeclarations(entries: string[]): string { + return `// Auto-generated by AppKit - DO NOT EDIT +// Generated from serving endpoint OpenAPI schemas +import "@databricks/appkit"; +import "@databricks/appkit-ui/react"; + +declare module "@databricks/appkit" { + interface ServingEndpointRegistry { +${entries.join("\n")} + } +} + +declare module "@databricks/appkit-ui/react" { + interface ServingEndpointRegistry { +${entries.join("\n")} + } +} +`; +} diff --git a/packages/appkit/src/type-generator/serving/server-file-extractor.ts b/packages/appkit/src/type-generator/serving/server-file-extractor.ts new file mode 100644 index 00000000..b26b0bf1 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/server-file-extractor.ts @@ -0,0 +1,221 @@ +import fs from "node:fs"; +import path from "node:path"; +import { Lang, parse, type SgNode } from "@ast-grep/napi"; +import { createLogger } from "../../logging/logger"; +import type { EndpointConfig } from "../../plugins/serving/types"; + +const logger = createLogger("type-generator:serving:extractor"); + +/** + * Candidate paths for the server entry file, relative to the project root. + * Checked in order; the first that exists is used. + * Same convention as plugin sync (sync.ts SERVER_FILE_CANDIDATES). + */ +const SERVER_FILE_CANDIDATES = ["server/index.ts", "server/server.ts"]; + +/** + * Find the server entry file by checking candidate paths in order. + * + * @param basePath - Project root directory to search from + * @returns Absolute path to the server file, or null if none found + */ +export function findServerFile(basePath: string): string | null { + for (const candidate of SERVER_FILE_CANDIDATES) { + const fullPath = path.join(basePath, candidate); + if (fs.existsSync(fullPath)) { + return fullPath; + } + } + return null; +} + +/** + * Extract serving endpoint config from a server file by AST-parsing it. 
+ * Looks for `serving({ endpoints: { alias: { env: "..." }, ... } })` calls
+ * and extracts the endpoint alias names and their environment variable mappings.
+ *
+ * @param serverFilePath - Absolute path to the server entry file
+ * @returns Extracted endpoint config, or null if not found or not extractable
+ */
+export function extractServingEndpoints(
+  serverFilePath: string,
+): Record<string, EndpointConfig> | null {
+  let content: string;
+  try {
+    content = fs.readFileSync(serverFilePath, "utf-8");
+  } catch {
+    logger.debug("Could not read server file: %s", serverFilePath);
+    return null;
+  }
+
+  const lang = serverFilePath.endsWith(".tsx") ? Lang.Tsx : Lang.TypeScript;
+  const ast = parse(lang, content);
+  const root = ast.root();
+
+  // Find serving(...) call expressions
+  const servingCall = findServingCall(root);
+  if (!servingCall) {
+    logger.debug("No serving() call found in %s", serverFilePath);
+    return null;
+  }
+
+  // Get the first argument (the config object)
+  const args = servingCall.field("arguments");
+  if (!args) {
+    return null;
+  }
+
+  const configArg = args.children().find((child) => child.kind() === "object");
+  if (!configArg) {
+    // serving() called with no args or non-object arg
+    return null;
+  }
+
+  // Find the "endpoints" property in the config object
+  const endpointsPair = findProperty(configArg, "endpoints");
+  if (!endpointsPair) {
+    // Config object has no "endpoints" property (e.g. serving({ timeout: 5000 }))
+    return null;
+  }
+
+  // Get the value of the endpoints property
+  const endpointsValue = getPropertyValue(endpointsPair);
+  if (!endpointsValue || endpointsValue.kind() !== "object") {
+    // endpoints is a variable reference, not an inline object
+    logger.debug(
+      "serving() endpoints is not an inline object literal in %s. " +
+        "Pass endpoints explicitly via appKitServingTypesPlugin({ endpoints }) in vite.config.ts.",
+      serverFilePath,
+    );
+    return null;
+  }
+
+  // Extract each endpoint entry
+  const endpoints: Record<string, EndpointConfig> = {};
+  const pairs = endpointsValue
+    .children()
+    .filter((child) => child.kind() === "pair");
+
+  for (const pair of pairs) {
+    const entry = extractEndpointEntry(pair);
+    if (entry) {
+      endpoints[entry.alias] = entry.config;
+    }
+  }
+
+  if (Object.keys(endpoints).length === 0) {
+    return null;
+  }
+
+  logger.debug(
+    "Extracted %d endpoint(s) from %s: %s",
+    Object.keys(endpoints).length,
+    serverFilePath,
+    Object.keys(endpoints).join(", "),
+  );
+
+  return endpoints;
+}
+
+/**
+ * Find the serving() call expression in the AST.
+ * Looks for call expressions where the callee identifier is "serving".
+ */
+function findServingCall(root: SgNode): SgNode | null {
+  const callExpressions = root.findAll({
+    rule: { kind: "call_expression" },
+  });
+
+  for (const call of callExpressions) {
+    const callee = call.children()[0];
+    if (callee?.kind() === "identifier" && callee.text() === "serving") {
+      return call;
+    }
+  }
+
+  return null;
+}
+
+/**
+ * Find a property (pair node) with the given key name in an object expression.
+ */
+function findProperty(objectNode: SgNode, propertyName: string): SgNode | null {
+  const pairs = objectNode
+    .children()
+    .filter((child) => child.kind() === "pair");
+
+  for (const pair of pairs) {
+    const key = pair.children()[0];
+    if (!key) continue;
+
+    const keyText =
+      key.kind() === "property_identifier"
+        ? key.text()
+        : key.kind() === "string"
+          ? key.text().replace(/^['"]|['"]$/g, "")
+          : null;
+
+    if (keyText === propertyName) {
+      return pair;
+    }
+  }
+
+  return null;
+}
+
+/**
+ * Get the value node from a pair (property: value).
+ * The value is typically the last meaningful child after the colon. 
+ */ +function getPropertyValue(pairNode: SgNode): SgNode | null { + const children = pairNode.children(); + // pair children: [key, ":", value] + return children.length >= 3 ? children[children.length - 1] : null; +} + +/** + * Extract a single endpoint entry from a pair node like: + * `demo: { env: "DATABRICKS_SERVING_ENDPOINT_NAME", servedModel: "my-model" }` + */ +function extractEndpointEntry( + pair: SgNode, +): { alias: string; config: EndpointConfig } | null { + const children = pair.children(); + if (children.length < 3) return null; + + // Get alias name (the key) + const keyNode = children[0]; + const alias = + keyNode.kind() === "property_identifier" + ? keyNode.text() + : keyNode.kind() === "string" + ? keyNode.text().replace(/^['"]|['"]$/g, "") + : null; + + if (!alias) return null; + + // Get the value (should be an object like { env: "..." }) + const valueNode = children[children.length - 1]; + if (valueNode.kind() !== "object") return null; + + // Extract env field + const envPair = findProperty(valueNode, "env"); + if (!envPair) return null; + + const envValue = getPropertyValue(envPair); + if (!envValue || envValue.kind() !== "string") return null; + + const env = envValue.text().replace(/^['"]|['"]$/g, ""); + + // Extract optional servedModel field + const config: EndpointConfig = { env }; + const servedModelPair = findProperty(valueNode, "servedModel"); + if (servedModelPair) { + const servedModelValue = getPropertyValue(servedModelPair); + if (servedModelValue?.kind() === "string") { + config.servedModel = servedModelValue.text().replace(/^['"]|['"]$/g, ""); + } + } + + return { alias, config }; +} diff --git a/packages/appkit/src/type-generator/serving/tests/cache.test.ts b/packages/appkit/src/type-generator/serving/tests/cache.test.ts new file mode 100644 index 00000000..0c99c997 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/cache.test.ts @@ -0,0 +1,109 @@ +import fs from "node:fs/promises"; +import { afterEach, 
beforeEach, describe, expect, test, vi } from "vitest"; +import { + CACHE_VERSION, + hashSchema, + loadServingCache, + type ServingCache, + saveServingCache, +} from "../cache"; + +vi.mock("node:fs/promises"); + +describe("serving cache", () => { + beforeEach(() => { + vi.mocked(fs.mkdir).mockResolvedValue(undefined); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("hashSchema", () => { + test("returns consistent SHA256 hash", () => { + const hash1 = hashSchema('{"openapi": "3.1.0"}'); + const hash2 = hashSchema('{"openapi": "3.1.0"}'); + expect(hash1).toBe(hash2); + expect(hash1).toHaveLength(64); // SHA256 hex + }); + + test("different inputs produce different hashes", () => { + const hash1 = hashSchema('{"a": 1}'); + const hash2 = hashSchema('{"a": 2}'); + expect(hash1).not.toBe(hash2); + }); + }); + + describe("loadServingCache", () => { + test("returns empty cache when file does not exist", async () => { + vi.mocked(fs.readFile).mockRejectedValue( + Object.assign(new Error("ENOENT"), { code: "ENOENT" }), + ); + + const cache = await loadServingCache(); + expect(cache).toEqual({ version: CACHE_VERSION, endpoints: {} }); + }); + + test("returns parsed cache when file exists with correct version", async () => { + const cached: ServingCache = { + version: CACHE_VERSION, + endpoints: { + llm: { + hash: "abc", + requestType: "{ messages: string[] }", + responseType: "{ model: string }", + chunkType: null, + requestKeys: ["messages"], + }, + }, + }; + vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(cached)); + + const cache = await loadServingCache(); + expect(cache).toEqual(cached); + }); + + test("flushes cache when version mismatches", async () => { + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ version: "0", endpoints: { old: {} } }), + ); + + const cache = await loadServingCache(); + expect(cache).toEqual({ version: CACHE_VERSION, endpoints: {} }); + }); + + test("flushes cache when file is corrupted", async () => { 
+      vi.mocked(fs.readFile).mockResolvedValue("not json");
+
+      const cache = await loadServingCache();
+      expect(cache).toEqual({ version: CACHE_VERSION, endpoints: {} });
+    });
+  });
+
+  describe("saveServingCache", () => {
+    test("writes cache to file", async () => {
+      vi.mocked(fs.writeFile).mockResolvedValue();
+
+      const cache: ServingCache = {
+        version: CACHE_VERSION,
+        endpoints: {
+          test: {
+            hash: "xyz",
+            requestType: "{}",
+            responseType: "{}",
+            chunkType: null,
+            requestKeys: [],
+          },
+        },
+      };
+
+      await saveServingCache(cache);
+
+      expect(fs.writeFile).toHaveBeenCalledWith(
+        expect.stringContaining(".appkit-serving-types-cache.json"),
+        JSON.stringify(cache, null, 2),
+        "utf8",
+      );
+    });
+  });
+});
diff --git a/packages/appkit/src/type-generator/serving/tests/converter.test.ts b/packages/appkit/src/type-generator/serving/tests/converter.test.ts
new file mode 100644
index 00000000..1be30738
--- /dev/null
+++ b/packages/appkit/src/type-generator/serving/tests/converter.test.ts
@@ -0,0 +1,308 @@
+import { describe, expect, test } from "vitest";
+import {
+  convertRequestSchema,
+  convertResponseSchema,
+  deriveChunkType,
+  extractRequestKeys,
+} from "../converter";
+import type { OpenApiOperation, OpenApiSchema } from "../fetcher";
+
+function makeOperation(
+  requestProps: Record<string, OpenApiSchema>,
+  responseProps?: Record<string, OpenApiSchema>,
+  required?: string[],
+): OpenApiOperation {
+  return {
+    requestBody: {
+      content: {
+        "application/json": {
+          schema: {
+            type: "object",
+            properties: requestProps,
+            required,
+          },
+        },
+      },
+    },
+    responses: responseProps ? 
{ + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: responseProps, + }, + }, + }, + }, + } + : undefined, + }; +} + +describe("converter", () => { + describe("convertRequestSchema", () => { + test("converts string type", () => { + const op = makeOperation({ name: { type: "string" } }); + const result = convertRequestSchema(op); + expect(result).toContain("name?: string;"); + }); + + test("converts integer type to number", () => { + const op = makeOperation({ count: { type: "integer" } }); + expect(convertRequestSchema(op)).toContain("count?: number;"); + }); + + test("converts number type", () => { + const op = makeOperation({ + temp: { type: "number", format: "double" }, + }); + expect(convertRequestSchema(op)).toContain("temp?: number;"); + }); + + test("converts boolean type", () => { + const op = makeOperation({ flag: { type: "boolean" } }); + expect(convertRequestSchema(op)).toContain("flag?: boolean;"); + }); + + test("converts enum to string literal union", () => { + const op = makeOperation({ + role: { type: "string", enum: ["user", "assistant"] }, + }); + const result = convertRequestSchema(op); + expect(result).toContain('"user" | "assistant"'); + }); + + test("converts array type", () => { + const op = makeOperation({ + items: { type: "array", items: { type: "string" } }, + }); + expect(convertRequestSchema(op)).toContain("items?: string[];"); + }); + + test("converts nested object", () => { + const op = makeOperation({ + messages: { + type: "array", + items: { + type: "object", + properties: { + role: { type: "string" }, + content: { type: "string" }, + }, + }, + }, + }); + const result = convertRequestSchema(op); + expect(result).toContain("role?: string;"); + expect(result).toContain("content?: string;"); + }); + + test("handles nullable properties", () => { + const op = makeOperation({ + temp: { type: "number", nullable: true }, + }); + expect(convertRequestSchema(op)).toContain("temp?: number | null;"); + 
}); + + test("handles oneOf union types", () => { + const op = makeOperation({ + stop: { + oneOf: [ + { type: "string" }, + { type: "array", items: { type: "string" } }, + ], + }, + }); + const result = convertRequestSchema(op); + expect(result).toContain("string | string[]"); + }); + + test("strips stream property from request", () => { + const op = makeOperation({ + messages: { type: "array", items: { type: "string" } }, + stream: { type: "boolean", nullable: true }, + temperature: { type: "number" }, + }); + const result = convertRequestSchema(op); + expect(result).not.toContain("stream"); + expect(result).toContain("messages"); + expect(result).toContain("temperature"); + }); + + test("marks required properties without ?", () => { + const op = makeOperation( + { + messages: { type: "array", items: { type: "string" } }, + temperature: { type: "number" }, + }, + undefined, + ["messages"], + ); + const result = convertRequestSchema(op); + expect(result).toContain("messages: string[];"); + expect(result).toContain("temperature?: number;"); + }); + + test("returns Record for missing schema", () => { + const op: OpenApiOperation = {}; + expect(convertRequestSchema(op)).toBe("Record"); + }); + }); + + describe("convertResponseSchema", () => { + test("converts response schema", () => { + const op = makeOperation( + {}, + { + model: { type: "string" }, + id: { type: "string" }, + }, + ); + const result = convertResponseSchema(op); + expect(result).toContain("model?: string;"); + expect(result).toContain("id?: string;"); + }); + + test("returns unknown for missing response", () => { + const op: OpenApiOperation = {}; + expect(convertResponseSchema(op)).toBe("unknown"); + }); + }); + + describe("deriveChunkType", () => { + test("derives chunk type from OpenAI-compatible response", () => { + const op: OpenApiOperation = { + responses: { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: { + model: { type: "string" }, + choices: { + 
type: "array", + items: { + type: "object", + properties: { + index: { type: "integer" }, + message: { + type: "object", + properties: { + role: { + type: "string", + enum: ["user", "assistant"], + }, + content: { type: "string" }, + }, + }, + finish_reason: { type: "string" }, + }, + }, + }, + usage: { + type: "object", + properties: { + prompt_tokens: { type: "integer" }, + }, + nullable: true, + }, + id: { type: "string" }, + }, + }, + }, + }, + }, + }, + }; + + const result = deriveChunkType(op); + expect(result).not.toBeNull(); + // Should have delta instead of message + expect(result).toContain("delta"); + expect(result).not.toContain("message"); + // Should make finish_reason nullable + expect(result).toContain("finish_reason"); + expect(result).toContain("| null"); + // Should drop usage + expect(result).not.toContain("usage"); + // Should keep model and id + expect(result).toContain("model"); + expect(result).toContain("id"); + }); + + test("returns null for non-OpenAI response (no choices)", () => { + const op = makeOperation( + {}, + { + predictions: { type: "array", items: { type: "number" } }, + }, + ); + expect(deriveChunkType(op)).toBeNull(); + }); + + test("returns null for choices without message", () => { + const op: OpenApiOperation = { + responses: { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: { + choices: { + type: "array", + items: { + type: "object", + properties: { + score: { type: "number" }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }; + expect(deriveChunkType(op)).toBeNull(); + }); + + test("returns null for missing response", () => { + const op: OpenApiOperation = {}; + expect(deriveChunkType(op)).toBeNull(); + }); + }); + + describe("extractRequestKeys", () => { + test("extracts top-level property keys excluding stream", () => { + const op = makeOperation({ + messages: { type: "array", items: { type: "string" } }, + temperature: { type: "number" }, + stream: { type: "boolean", 
nullable: true }, + }); + expect(extractRequestKeys(op)).toEqual(["messages", "temperature"]); + }); + + test("returns empty array for missing schema", () => { + const op: OpenApiOperation = {}; + expect(extractRequestKeys(op)).toEqual([]); + }); + + test("returns empty array for schema without properties", () => { + const op: OpenApiOperation = { + requestBody: { + content: { + "application/json": { + schema: { type: "object" }, + }, + }, + }, + }; + expect(extractRequestKeys(op)).toEqual([]); + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/fetcher.test.ts b/packages/appkit/src/type-generator/serving/tests/fetcher.test.ts new file mode 100644 index 00000000..802540b0 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/fetcher.test.ts @@ -0,0 +1,209 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { fetchOpenApiSchema } from "../fetcher"; + +const mockAuthenticate = vi.fn(async () => {}); + +function createMockClient(host?: string) { + return { + config: { + host, + authenticate: mockAuthenticate, + }, + } as any; +} + +function makeValidSpec( + paths: Record = { "/invocations": { post: {} } }, +) { + return { + openapi: "3.0.0", + info: { title: "test", version: "1" }, + paths, + }; +} + +describe("fetchOpenApiSchema", () => { + beforeEach(() => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(makeValidSpec()), { status: 200 }), + ); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + test("returns null when host is not configured", async () => { + const result = await fetchOpenApiSchema(createMockClient(undefined), "ep"); + expect(result).toBeNull(); + }); + + test("returns null on HTTP 404", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response("Not found", { status: 404 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + 
expect(result).toBeNull(); + }); + + test("returns null on HTTP 403", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response("Forbidden", { status: 403 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns null on generic error status", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response("Server error", { status: 500 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns null on timeout (AbortError)", async () => { + vi.spyOn(globalThis, "fetch").mockRejectedValue( + Object.assign(new Error("The operation was aborted"), { + name: "AbortError", + }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns null on network error", async () => { + vi.spyOn(globalThis, "fetch").mockRejectedValue(new Error("fetch failed")); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns spec and pathKey for valid response", async () => { + const spec = makeValidSpec({ + "/serving-endpoints/ep/invocations": { post: { requestBody: {} } }, + }); + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(spec), { status: 200 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + ); + expect(result).not.toBeNull(); + expect(result?.pathKey).toBe("/serving-endpoints/ep/invocations"); + expect(result?.spec.openapi).toBe("3.0.0"); + }); + + test("matches servedModel path when provided", async () => { + const spec = makeValidSpec({ + 
"/serving-endpoints/ep/served-models/gpt4/invocations": { post: {} }, + "/serving-endpoints/ep/invocations": { post: {} }, + }); + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(spec), { status: 200 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + "gpt4", + ); + expect(result?.pathKey).toBe( + "/serving-endpoints/ep/served-models/gpt4/invocations", + ); + }); + + test("falls back to first path when servedModel not found", async () => { + const spec = makeValidSpec({ + "/serving-endpoints/ep/invocations": { post: {} }, + }); + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(spec), { status: 200 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + "nonexistent-model", + ); + expect(result?.pathKey).toBe("/serving-endpoints/ep/invocations"); + }); + + test("returns null for invalid spec structure (missing paths)", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ openapi: "3.0.0", info: {} }), { + status: 200, + }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + ); + expect(result).toBeNull(); + }); + + test("returns null when paths object is empty", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(makeValidSpec({})), { status: 200 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + ); + expect(result).toBeNull(); + }); + + test("authenticates request headers", async () => { + await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + ); + expect(mockAuthenticate).toHaveBeenCalledWith(expect.any(Headers)); + }); + + test("constructs correct URL with encoded endpoint name", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + 
+ await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my endpoint", + ); + + expect(fetchSpy).toHaveBeenCalledWith( + expect.stringContaining("/serving-endpoints/my%20endpoint/openapi"), + expect.any(Object), + ); + }); + + test("prepends https when host lacks protocol", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + await fetchOpenApiSchema(createMockClient("host.databricks.com"), "ep"); + + const url = fetchSpy.mock.calls[0][0] as string; + expect(url.startsWith("https://")).toBe(true); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/generator.test.ts b/packages/appkit/src/type-generator/serving/tests/generator.test.ts new file mode 100644 index 00000000..8761519b --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/generator.test.ts @@ -0,0 +1,215 @@ +import fs from "node:fs/promises"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { generateServingTypes } from "../generator"; + +vi.mock("node:fs/promises"); + +// Mock cache module +vi.mock("../cache", () => ({ + CACHE_VERSION: "1", + hashSchema: vi.fn(() => "mock-hash"), + loadServingCache: vi.fn(async () => ({ version: "1", endpoints: {} })), + saveServingCache: vi.fn(async () => {}), +})); + +// Mock fetcher +const mockFetchOpenApiSchema = vi.fn(); +vi.mock("../fetcher", () => ({ + fetchOpenApiSchema: (...args: any[]) => mockFetchOpenApiSchema(...args), +})); + +// Mock WorkspaceClient +vi.mock("@databricks/sdk-experimental", () => ({ + WorkspaceClient: vi.fn(() => ({ config: {} })), +})); + +const CHAT_OPENAPI_SPEC = { + openapi: "3.1.0", + info: { title: "test", version: "1" }, + paths: { + "/served-models/llm/invocations": { + post: { + requestBody: { + content: { + "application/json": { + schema: { + type: "object", + properties: { + messages: { + type: "array", + items: { + type: "object", + properties: { + role: { type: "string" }, + content: { type: "string" }, + }, + }, + }, 
+ temperature: { type: "number", nullable: true }, + stream: { type: "boolean", nullable: true }, + }, + }, + }, + }, + }, + responses: { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: { + model: { type: "string" }, + choices: { + type: "array", + items: { + type: "object", + properties: { + message: { + type: "object", + properties: { + role: { type: "string" }, + content: { type: "string" }, + }, + }, + finish_reason: { type: "string" }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, +}; + +describe("generateServingTypes", () => { + const outFile = "/tmp/test-serving-types.d.ts"; + + beforeEach(() => { + vi.mocked(fs.writeFile).mockResolvedValue(); + process.env.TEST_SERVING_ENDPOINT = "my-endpoint"; + }); + + afterEach(() => { + delete process.env.TEST_SERVING_ENDPOINT; + delete process.env.DATABRICKS_SERVING_ENDPOINT_NAME; + vi.restoreAllMocks(); + }); + + test("generates .d.ts with module augmentation for a chat endpoint", async () => { + mockFetchOpenApiSchema.mockResolvedValue({ + spec: CHAT_OPENAPI_SPEC, + pathKey: "/served-models/llm/invocations", + }); + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + expect(fs.writeFile).toHaveBeenCalledWith( + outFile, + expect.any(String), + "utf-8", + ); + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + + // Verify module augmentation structure + expect(output).toContain("// Auto-generated by AppKit - DO NOT EDIT"); + expect(output).toContain('import "@databricks/appkit"'); + expect(output).toContain('import "@databricks/appkit-ui/react"'); + expect(output).toContain('declare module "@databricks/appkit"'); + expect(output).toContain('declare module "@databricks/appkit-ui/react"'); + expect(output).toContain("interface ServingEndpointRegistry"); + expect(output).toContain("llm:"); + expect(output).toContain("request:"); + expect(output).toContain("response:"); + 
expect(output).toContain("chunk:"); + }); + + test("strips stream property from generated request type", async () => { + mockFetchOpenApiSchema.mockResolvedValue({ + spec: CHAT_OPENAPI_SPEC, + pathKey: "/served-models/llm/invocations", + }); + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + // `stream` should be stripped from request type + expect(output).toContain("messages"); + expect(output).toContain("temperature"); + expect(output).not.toMatch(/\bstream\??\s*:/); + }); + + test("emits generic types when env var is not set", async () => { + delete process.env.TEST_SERVING_ENDPOINT; + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + expect(mockFetchOpenApiSchema).not.toHaveBeenCalled(); + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + expect(output).toContain("llm:"); + expect(output).toContain("Record"); + }); + + test("skips generation when no endpoints configured and no env var", async () => { + await generateServingTypes({ + outFile, + noCache: true, + }); + + expect(mockFetchOpenApiSchema).not.toHaveBeenCalled(); + expect(fs.writeFile).not.toHaveBeenCalled(); + }); + + test("emits generic types when schema fetch returns null", async () => { + mockFetchOpenApiSchema.mockResolvedValue(null); + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + expect(output).toContain("llm:"); + expect(output).toContain("Record"); + }); + + test("resolves default endpoint from DATABRICKS_SERVING_ENDPOINT_NAME", async () => { + process.env.DATABRICKS_SERVING_ENDPOINT_NAME = "my-default-endpoint"; + mockFetchOpenApiSchema.mockResolvedValue({ + spec: CHAT_OPENAPI_SPEC, + pathKey: 
"/served-models/llm/invocations", + }); + + await generateServingTypes({ + outFile, + noCache: true, + }); + + expect(mockFetchOpenApiSchema).toHaveBeenCalledWith( + expect.anything(), + "my-default-endpoint", + undefined, + ); + + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + expect(output).toContain("default:"); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/server-file-extractor.test.ts b/packages/appkit/src/type-generator/serving/tests/server-file-extractor.test.ts new file mode 100644 index 00000000..4d1a73c7 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/server-file-extractor.test.ts @@ -0,0 +1,216 @@ +import fs from "node:fs"; +import path from "node:path"; +import { afterEach, describe, expect, test, vi } from "vitest"; +import { + extractServingEndpoints, + findServerFile, +} from "../server-file-extractor"; + +describe("findServerFile", () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + test("returns server/index.ts when it exists", () => { + vi.spyOn(fs, "existsSync").mockImplementation((p) => + String(p).endsWith(path.join("server", "index.ts")), + ); + expect(findServerFile("/app")).toBe( + path.join("/app", "server", "index.ts"), + ); + }); + + test("returns server/server.ts when index.ts does not exist", () => { + vi.spyOn(fs, "existsSync").mockImplementation((p) => + String(p).endsWith(path.join("server", "server.ts")), + ); + expect(findServerFile("/app")).toBe( + path.join("/app", "server", "server.ts"), + ); + }); + + test("returns null when no server file exists", () => { + vi.spyOn(fs, "existsSync").mockReturnValue(false); + expect(findServerFile("/app")).toBeNull(); + }); +}); + +describe("extractServingEndpoints", () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + function mockServerFile(content: string) { + vi.spyOn(fs, "readFileSync").mockReturnValue(content); + } + + test("extracts inline endpoints from serving() call", () => { + mockServerFile(` 
+import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + } + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + }); + }); + + test("extracts servedModel when present", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: "DATABRICKS_SERVING_ENDPOINT_NAME", servedModel: "my-model" }, + } + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { + env: "DATABRICKS_SERVING_ENDPOINT_NAME", + servedModel: "my-model", + }, + }); + }); + + test("returns null when serving() has no arguments", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [serving()], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when serving() has config but no endpoints", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ timeout: 5000 }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when serving() has empty config object", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [serving({})], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when endpoints is a variable reference", () => { + mockServerFile(` +import { createApp, serving } from 
'@databricks/appkit'; + +const myEndpoints = { demo: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" } }; +createApp({ + plugins: [ + serving({ endpoints: myEndpoints }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when no serving() call exists", () => { + mockServerFile(` +import { createApp, analytics } from '@databricks/appkit'; + +createApp({ + plugins: [analytics({})], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when server file cannot be read", () => { + vi.spyOn(fs, "readFileSync").mockImplementation(() => { + throw new Error("ENOENT"); + }); + + const result = extractServingEndpoints("/app/server/nonexistent.ts"); + expect(result).toBeNull(); + }); + + test("handles single-quoted env values", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: 'DATABRICKS_SERVING_ENDPOINT_NAME' }, + } + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" }, + }); + }); + + test("handles endpoints with trailing commas", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + }, + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/vite-plugin.test.ts b/packages/appkit/src/type-generator/serving/tests/vite-plugin.test.ts new file mode 100644 index 
00000000..074d3d44 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/vite-plugin.test.ts @@ -0,0 +1,186 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; + +const mockGenerateServingTypes = vi.fn(async () => {}); +const mockFindServerFile = vi.fn((): string | null => null); +const mockExtractServingEndpoints = vi.fn( + (): Record<string, { env: string; servedModel?: string }> | null => null, +); + +vi.mock("../generator", () => ({ + generateServingTypes: (...args: any[]) => mockGenerateServingTypes(...args), +})); + +vi.mock("../server-file-extractor", () => ({ + findServerFile: (...args: any[]) => mockFindServerFile(...args), + extractServingEndpoints: (...args: any[]) => + mockExtractServingEndpoints(...args), +})); + +import { appKitServingTypesPlugin } from "../vite-plugin"; + +describe("appKitServingTypesPlugin", () => { + const originalEnv = { ...process.env }; + + beforeEach(() => { + mockGenerateServingTypes.mockReset(); + mockFindServerFile.mockReset(); + mockExtractServingEndpoints.mockReset(); + }); + + afterEach(() => { + process.env = { ...originalEnv }; + vi.restoreAllMocks(); + }); + + describe("apply()", () => { + test("returns true when explicit endpoints provided", () => { + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM_ENDPOINT" } }, + }); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns true when DATABRICKS_SERVING_ENDPOINT_NAME is set", () => { + process.env.DATABRICKS_SERVING_ENDPOINT_NAME = "my-endpoint"; + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns true when server file found in cwd", () => { + mockFindServerFile.mockReturnValueOnce("/app/server/index.ts"); + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns true when server file found in parent dir", () => { + mockFindServerFile + .mockReturnValueOnce(null) // cwd check + 
.mockReturnValueOnce("/app/server/index.ts"); // parent check + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns false when nothing configured", () => { + delete process.env.DATABRICKS_SERVING_ENDPOINT_NAME; + mockFindServerFile.mockReturnValue(null); + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(false); + }); + }); + + describe("configResolved()", () => { + test("resolves outFile relative to config.root", async () => { + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ + outFile: expect.stringContaining( + "/app/client/src/appKitServingTypes.d.ts", + ), + }), + ); + }); + + test("uses custom outFile when provided", async () => { + const plugin = appKitServingTypesPlugin({ + outFile: "types/serving.d.ts", + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ + outFile: expect.stringContaining("types/serving.d.ts"), + }), + ); + }); + }); + + describe("buildStart()", () => { + test("calls generateServingTypes with explicit endpoints", async () => { + const endpoints = { llm: { env: "LLM_ENDPOINT" } }; + const plugin = appKitServingTypesPlugin({ endpoints }); + (plugin as any).configResolved({ root: "/app/client" }); + + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ + endpoints, + noCache: false, + }), + ); + }); + + test("extracts endpoints from server file when not explicit", async () => { + const extracted = { llm: { env: "LLM_EP" } }; + mockFindServerFile.mockReturnValue("/app/server/index.ts"); + 
mockExtractServingEndpoints.mockReturnValue(extracted); + + const plugin = appKitServingTypesPlugin(); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ endpoints: extracted }), + ); + }); + + test("passes undefined endpoints when no server file found", async () => { + mockFindServerFile.mockReturnValue(null); + + const plugin = appKitServingTypesPlugin(); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ endpoints: undefined }), + ); + }); + + test("passes undefined when AST extraction returns null", async () => { + mockFindServerFile.mockReturnValue("/app/server/index.ts"); + mockExtractServingEndpoints.mockReturnValue(null); + + const plugin = appKitServingTypesPlugin(); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ endpoints: undefined }), + ); + }); + + test("swallows errors in dev mode", async () => { + process.env.NODE_ENV = "development"; + mockGenerateServingTypes.mockRejectedValue(new Error("fetch failed")); + + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + + // Should not throw + await expect((plugin as any).buildStart()).resolves.toBeUndefined(); + }); + + test("rethrows errors in production mode", async () => { + process.env.NODE_ENV = "production"; + mockGenerateServingTypes.mockRejectedValue(new Error("fetch failed")); + + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + + await expect((plugin as any).buildStart()).rejects.toThrow( + "fetch failed", + ); + 
}); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/vite-plugin.ts b/packages/appkit/src/type-generator/serving/vite-plugin.ts new file mode 100644 index 00000000..accde210 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/vite-plugin.ts @@ -0,0 +1,109 @@ +import path from "node:path"; +import type { Plugin } from "vite"; +import { createLogger } from "../../logging/logger"; +import type { EndpointConfig } from "../../plugins/serving/types"; +import { generateServingTypes } from "./generator"; +import { + extractServingEndpoints, + findServerFile, +} from "./server-file-extractor"; + +const logger = createLogger("type-generator:serving:vite-plugin"); + +interface AppKitServingTypesPluginOptions { + /** Path to the output .d.ts file (relative to client root). Default: "src/appKitServingTypes.d.ts" */ + outFile?: string; + /** Endpoint config override. If omitted, auto-discovers from the server file or falls back to DATABRICKS_SERVING_ENDPOINT_NAME env var. */ + endpoints?: Record<string, EndpointConfig>; +} + +/** + * Vite plugin to generate TypeScript types for AppKit serving endpoints. + * Fetches OpenAPI schemas from Databricks and generates a .d.ts with + * ServingEndpointRegistry module augmentation. + * + * Endpoint discovery order: + * 1. Explicit `endpoints` option (override) + * 2. AST extraction from server file (server/index.ts or server/server.ts) + * 3. DATABRICKS_SERVING_ENDPOINT_NAME env var (single default endpoint) + */ +export function appKitServingTypesPlugin( + options?: AppKitServingTypesPluginOptions, +): Plugin { + let outFile: string; + let projectRoot: string; + + async function generate() { + try { + // Resolve endpoints: explicit option > server file AST > env var fallback (handled by generator) + let endpoints = options?.endpoints; + if (!endpoints) { + const serverFile = findServerFile(projectRoot); + if (serverFile) { + endpoints = extractServingEndpoints(serverFile) ?? 
undefined; + } + } + + await generateServingTypes({ + outFile, + endpoints, + noCache: false, + }); + } catch (error) { + if (process.env.NODE_ENV === "production") { + throw error; + } + logger.error("Error generating serving types: %O", error); + } + } + + return { + name: "appkit-serving-types", + + apply() { + // Fast checks — no AST parsing here + if (options?.endpoints && Object.keys(options.endpoints).length > 0) { + return true; + } + + if (process.env.DATABRICKS_SERVING_ENDPOINT_NAME) { + return true; + } + + // Check if a server file exists (may contain serving() config) + // Use process.cwd() for apply() since configResolved hasn't run yet + if (findServerFile(process.cwd())) { + return true; + } + + // Also check parent dir (for when cwd is client/) + const parentDir = path.resolve(process.cwd(), ".."); + if (findServerFile(parentDir)) { + return true; + } + + logger.debug( + "No serving endpoints configured. Skipping type generation.", + ); + return false; + }, + + configResolved(config) { + // Resolve project root: go up one level from Vite root (client dir) + // This handles both: + // - pnpm dev: process.cwd() is app root, config.root is client/ + // - pnpm build: process.cwd() is client/ (cd client && vite build), config.root is client/ + projectRoot = path.resolve(config.root, ".."); + outFile = path.resolve( + config.root, + options?.outFile ?? 
"src/appKitServingTypes.d.ts", + ); + }, + + async buildStart() { + await generate(); + }, + + // No configureServer / watcher — schemas change on endpoint redeploy, not on file edit + }; +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 199fcfb8..9ca11b81 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -242,6 +242,9 @@ importers: packages/appkit: dependencies: + '@ast-grep/napi': + specifier: 0.37.0 + version: 0.37.0 '@databricks/lakebase': specifier: workspace:* version: link:../lakebase diff --git a/template/client/src/App.tsx b/template/client/src/App.tsx index fb4c28e6..a94bb5bc 100644 --- a/template/client/src/App.tsx +++ b/template/client/src/App.tsx @@ -17,6 +17,9 @@ import { GeniePage } from './pages/genie/GeniePage'; {{- if .plugins.files}} import { FilesPage } from './pages/files/FilesPage'; {{- end}} +{{- if .plugins.serving}} +import { ServingPage } from './pages/serving/ServingPage'; +{{- end}} const navLinkClass = ({ isActive }: { isActive: boolean }) => `px-3 py-1.5 rounded-md text-sm font-medium transition-colors ${ @@ -53,6 +56,11 @@ function Layout() { Files +{{- end}} +{{- if .plugins.serving}} + + Serving + {{- end}} @@ -80,6 +88,9 @@ const router = createBrowserRouter([ {{- end}} {{- if .plugins.files}} { path: '/files', element: }, +{{- end}} +{{- if .plugins.serving}} + { path: '/serving', element: }, {{- end}} ], }, diff --git a/template/client/src/pages/serving/ServingPage.tsx b/template/client/src/pages/serving/ServingPage.tsx new file mode 100644 index 00000000..b80934ba --- /dev/null +++ b/template/client/src/pages/serving/ServingPage.tsx @@ -0,0 +1,127 @@ +{{if .plugins.serving -}} +import { useServingInvoke } from '@databricks/appkit-ui/react'; +// For streaming endpoints (e.g. 
chat models), use useServingStream instead: +// import { useServingStream } from '@databricks/appkit-ui/react'; +import { useState } from 'react'; + +interface ChatChoice { + message?: { content?: string }; +} + +interface ChatResponse { + choices?: ChatChoice[]; +} + +function extractContent(data: unknown): string { + const resp = data as ChatResponse; + return resp?.choices?.[0]?.message?.content ?? JSON.stringify(data); +} + +interface Message { + id: string; + role: 'user' | 'assistant'; + content: string; +} + +export function ServingPage() { + const [input, setInput] = useState(''); + const [messages, setMessages] = useState<Message[]>([]); + + const { invoke, loading, error } = useServingInvoke({ messages: [] }); + // For streaming endpoints (e.g. chat models), use useServingStream instead: + // const { stream, chunks, streaming, error, reset } = useServingStream({ messages: [] }); + // Then accumulate chunks: chunks.map(c => c?.choices?.[0]?.delta?.content ?? '').join('') + + function handleSubmit(e: React.FormEvent) { + e.preventDefault(); + if (!input.trim() || loading) return; + + const userMessage: Message = { + id: crypto.randomUUID(), + role: 'user', + content: input.trim(), + }; + + const fullMessages = [ + ...messages.map(({ role, content }) => ({ role, content })), + { role: 'user' as const, content: userMessage.content }, + ]; + + setMessages((prev) => [...prev, userMessage]); + setInput(''); + + void invoke({ messages: fullMessages }).then((result) => { + if (result) { + setMessages((prev) => [ + ...prev, + { id: crypto.randomUUID(), role: 'assistant', content: extractContent(result) }, + ]); + } + }); + } + + return ( +
+
+

Model Serving

+

+ Chat with a Databricks Model Serving endpoint. +

+
+ +
+
+ {messages.map((msg) => ( +
+
+

{msg.content}

+
+
+ ))} + + {loading && ( +
+
+

...

+
+
+ )} + + {error && ( +
+ Error: {error} +
+ )} +
+ +
+ setInput(e.target.value)} + placeholder="Send a message..." + className="flex-1 rounded-md border px-3 py-2 text-sm bg-background" + disabled={loading} + /> + +
+
+
+ ); +} +{{- end}} diff --git a/template/client/vite.config.ts b/template/client/vite.config.ts index b49d4055..12c1d864 100644 --- a/template/client/vite.config.ts +++ b/template/client/vite.config.ts @@ -2,11 +2,20 @@ import { defineConfig } from 'vite'; import react from '@vitejs/plugin-react'; import tailwindcss from '@tailwindcss/vite'; import path from 'node:path'; +{{- if .plugins.serving}} +import { appKitServingTypesPlugin } from '@databricks/appkit'; +{{- end}} // https://vite.dev/config/ export default defineConfig({ root: __dirname, - plugins: [react(), tailwindcss()], + plugins: [ + react(), + tailwindcss(), +{{- if .plugins.serving}} + appKitServingTypesPlugin(), +{{- end}} + ], server: { middlewareMode: true, }, diff --git a/template/databricks.yml.tmpl b/template/databricks.yml.tmpl index accf7709..77997d31 100644 --- a/template/databricks.yml.tmpl +++ b/template/databricks.yml.tmpl @@ -13,7 +13,7 @@ resources: description: "{{.appDescription}}" source_code_path: ./ -{{- if or .plugins.genie .plugins.files}} +{{- if or .plugins.genie .plugins.files .plugins.serving}} user_api_scopes: {{- if .plugins.genie}} - dashboards.genie @@ -21,8 +21,11 @@ resources: {{- if .plugins.files}} - files.files {{- end}} +{{- if .plugins.serving}} + - serving.serving-endpoints +{{- end}} {{- else}} - # Uncomment to enable on behalf of user API scopes. Available scopes: sql, dashboards.genie, files.files + # Uncomment to enable on behalf of user API scopes. 
Available scopes: sql, dashboards.genie, files.files, serving.serving-endpoints # user_api_scopes: # - sql {{- end}} diff --git a/tools/generate-app-templates.ts b/tools/generate-app-templates.ts index 4b029121..1eff9357 100644 --- a/tools/generate-app-templates.ts +++ b/tools/generate-app-templates.ts @@ -55,21 +55,23 @@ const FEATURE_DEPENDENCIES: Record = { files: "Volume", genie: "Genie Space", lakebase: "Database", + serving: "Serving Endpoint", }; const APP_TEMPLATES: AppTemplate[] = [ { name: "appkit-all-in-one", - features: ["analytics", "files", "genie", "lakebase"], + features: ["analytics", "files", "genie", "lakebase", "serving"], set: { "analytics.sql-warehouse.id": "placeholder", "files.files.path": "placeholder", "genie.genie-space.id": "placeholder", "lakebase.postgres.branch": "placeholder", "lakebase.postgres.database": "placeholder", + "serving.serving-endpoint.name": "placeholder", }, description: - "Full-stack Node.js app with SQL analytics dashboards, file browser, Genie AI conversations, and Lakebase Autoscaling (Postgres) CRUD", + "Full-stack Node.js app with SQL analytics dashboards, file browser, Genie AI conversations, Lakebase Autoscaling (Postgres) CRUD, and Model Serving", }, { name: "appkit-analytics", @@ -96,6 +98,15 @@ const APP_TEMPLATES: AppTemplate[] = [ }, description: "Node.js app with file browser for Databricks Volumes", }, + { + name: "appkit-serving", + features: ["serving"], + set: { + "serving.serving-endpoint.name": "placeholder", + }, + description: + "Node.js app with Databricks Model Serving endpoint integration", + }, { name: "appkit-lakebase", features: ["lakebase"],