diff --git a/apps/blog-next/app/auth-nav.tsx b/apps/blog-next/app/auth-nav.tsx index 91416f08..e2c5f7a9 100644 --- a/apps/blog-next/app/auth-nav.tsx +++ b/apps/blog-next/app/auth-nav.tsx @@ -1,9 +1,8 @@ 'use client' -import { useState } from 'react' import Link from 'next/link' -import { useRouter } from 'next/navigation' import { useAuth } from '@holo-js/auth/next/client' +import { logoutAction } from './logout/actions' const linkStyle = { color: '#cbd5e1', @@ -25,37 +24,8 @@ const logoutFormStyle = { export function AuthNav() { const auth = useAuth() - const router = useRouter() - const [isLoggingOut, setIsLoggingOut] = useState(false) const displayName = auth.user?.name ?? auth.user?.email ?? 'Account' - async function logout() { - if (isLoggingOut) { - return - } - - setIsLoggingOut(true) - try { - const response = await fetch('/api/logout', { method: 'POST' }) - if (!response.ok) { - console.warn('Logout failed.', { status: response.status }) - return - } - - try { - await auth.refreshUser() - } catch (error) { - console.warn('Auth refresh failed after logout.', error) - } - - router.replace('/') - } catch (error) { - console.warn('Logout failed.', error) - } finally { - setIsLoggingOut(false) - } - } - if (!auth.authenticated) { return ( <> @@ -68,7 +38,9 @@ export function AuthNav() { return ( <> {displayName} - +
+ +
{auth.provider === 'workos' && (
diff --git a/apps/blog-next/app/login/actions.ts b/apps/blog-next/app/login/actions.ts new file mode 100644 index 00000000..942d5e0c --- /dev/null +++ b/apps/blog-next/app/login/actions.ts @@ -0,0 +1,34 @@ +'use server' + +import { login } from '@holo-js/auth' +import { validate } from '@holo-js/forms' +import { revalidatePath } from 'next/cache' +import { redirect } from 'next/navigation' + +import { loginForm } from '@/lib/schemas/auth' + +export async function loginAction(formData: FormData) { + const submission = await validate(formData, loginForm, { + csrf: true, + throttle: 'login', + }) + + if (!submission.valid) { + return submission.fail() + } + + const { data: session, error } = await login(submission.data) + if (error) { + return submission.fail({ + status: error.status, + errors: error.fields, + }) + } + + const redirectTo = session.emailVerificationRequired + ? session.emailVerificationRoute ?? '/verify-email' + : '/admin' + + revalidatePath('/', 'layout') + redirect(redirectTo) +} diff --git a/apps/blog-next/app/login/page.tsx b/apps/blog-next/app/login/page.tsx index 9e5169fa..0ba422a0 100644 --- a/apps/blog-next/app/login/page.tsx +++ b/apps/blog-next/app/login/page.tsx @@ -1,10 +1,9 @@ 'use client' import Link from 'next/link' -import { useRouter } from 'next/navigation' -import { useAuth } from '@holo-js/auth/next/client' import { useForm } from '@holo-js/adapter-next/client' import { loginForm } from '@/lib/schemas/auth' +import { loginAction } from './actions' const panelStyle = { display: 'grid', @@ -17,24 +16,12 @@ const panelStyle = { } satisfies React.CSSProperties export default function LoginPage() { - const router = useRouter() - const auth = useAuth() const form = useForm(loginForm, { csrf: true, validateOn: 'blur', initialValues: { email: '', password: '', remember: false }, async submitter({ formData }) { - const response = await fetch('/api/login', { method: 'POST', body: formData }) - const submission = await response.json() - if (submission?.ok === true && typeof submission.data?.redirectTo === 'string') { - try { - await auth.refreshUser() - } catch (error) { - console.warn('Auth refresh failed after login.', error) - } - router.replace(submission.data.redirectTo) - } - return submission + return await loginAction(formData) }, }) const formError = form.errors.first('_root') @@ -95,15 +82,8 @@ export default function LoginPage() {
- {form.lastSubmission?.ok === true ? ( -
-

Signed in successfully.

- Continue to admin -
- ) : null} -
- Create account + Create account Forgot password?
diff --git a/apps/blog-next/app/logout/actions.ts b/apps/blog-next/app/logout/actions.ts new file mode 100644 index 00000000..6f7c5cc4 --- /dev/null +++ b/apps/blog-next/app/logout/actions.ts @@ -0,0 +1,11 @@ +'use server' + +import { logout } from '@holo-js/auth' +import { revalidatePath } from 'next/cache' +import { redirect } from 'next/navigation' + +export async function logoutAction() { + await logout() + revalidatePath('/', 'layout') + redirect('/') +} diff --git a/apps/blog-next/app/register/actions.ts b/apps/blog-next/app/register/actions.ts new file mode 100644 index 00000000..7b7e973f --- /dev/null +++ b/apps/blog-next/app/register/actions.ts @@ -0,0 +1,35 @@ +'use server' + +import { loginUsing, register } from '@holo-js/auth' +import { validate } from '@holo-js/forms' +import { revalidatePath } from 'next/cache' +import { redirect } from 'next/navigation' + +import { registerForm } from '@/lib/schemas/auth' + +export async function registerAction(formData: FormData) { + const submission = await validate(formData, registerForm, { + csrf: true, + throttle: 'register', + }) + + if (!submission.valid) { + return submission.fail() + } + + const { data: created, error } = await register(submission.data) + if (error) { + return submission.fail({ + status: error.status, + errors: error.fields, + }) + } + + const session = await loginUsing(created) + const redirectTo = session.emailVerificationRequired + ? session.emailVerificationRoute ?? '/verify-email' + : '/admin' + + revalidatePath('/', 'layout') + redirect(redirectTo) +} diff --git a/apps/blog-next/app/register/page.tsx b/apps/blog-next/app/register/page.tsx index 9c1dbe7d..71a2ff87 100644 --- a/apps/blog-next/app/register/page.tsx +++ b/apps/blog-next/app/register/page.tsx @@ -1,11 +1,11 @@ 'use client' import Link from 'next/link' -import { useRouter } from 'next/navigation' import { useForm } from '@holo-js/adapter-next/client' import { registerForm } from '@/lib/schemas/auth' +import { registerAction } from './actions' const panelStyle = { display: 'grid', @@ -18,18 +18,12 @@ const panelStyle = { } satisfies React.CSSProperties export default function RegisterPage() { - const router = useRouter() const form = useForm(registerForm, { csrf: true, validateOn: 'blur', initialValues: { name: '', email: '', password: '', passwordConfirmation: '' }, async submitter({ formData }) { - const response = await fetch('/api/register', { method: 'POST', body: formData }) - const submission = await response.json() - if (submission?.ok === true) { - router.replace('/login') - } - return submission + return await registerAction(formData) }, }) @@ -93,13 +87,6 @@ export default function RegisterPage() { - {form.lastSubmission?.ok === true ? ( -
-

Account created. Check your inbox to verify your email address.

- Return to sign in -
- ) : null} - Already have an account? Register with WorkOS Register with Clerk diff --git a/apps/blog-next/app/super-admin/login/actions.ts b/apps/blog-next/app/super-admin/login/actions.ts new file mode 100644 index 00000000..3617c2f5 --- /dev/null +++ b/apps/blog-next/app/super-admin/login/actions.ts @@ -0,0 +1,33 @@ +'use server' + +import auth from '@holo-js/auth' +import { validate } from '@holo-js/forms' +import { revalidatePath } from 'next/cache' +import { redirect } from 'next/navigation' + +import { loginForm } from '@/lib/schemas/auth' + +export async function superAdminLoginAction(formData: FormData) { + const submission = await validate(formData, loginForm, { + throttle: 'login', + }) + + if (!submission.valid) { + return submission.fail() + } + + const { data: session, error } = await auth.guard('admin').login(submission.data) + if (error) { + return submission.fail({ + status: error.status, + errors: error.fields, + }) + } + + const redirectTo = session.emailVerificationRequired + ? session.emailVerificationRoute ?? '/verify-email' + : '/super-admin' + + revalidatePath('/', 'layout') + redirect(redirectTo) +} diff --git a/apps/blog-next/app/super-admin/login/page.tsx b/apps/blog-next/app/super-admin/login/page.tsx index 925d97f2..cbcf1bbf 100644 --- a/apps/blog-next/app/super-admin/login/page.tsx +++ b/apps/blog-next/app/super-admin/login/page.tsx @@ -1,9 +1,8 @@ 'use client' -import { useRouter } from 'next/navigation' -import { useAuth } from '@holo-js/auth/next/client' import { useForm } from '@holo-js/adapter-next/client' import { loginForm } from '@/lib/schemas/auth' +import { superAdminLoginAction } from './actions' const panelStyle = { display: 'grid', @@ -16,23 +15,11 @@ const panelStyle = { } satisfies React.CSSProperties export default function SuperAdminLoginPage() { - const router = useRouter() - const auth = useAuth({ guard: 'admin' }) const form = useForm(loginForm, { validateOn: 'blur', initialValues: { email: '', password: '', remember: false }, async submitter({ formData }) { - const response = await fetch('/api/super-admin/login', { method: 'POST', body: formData }) - const submission = await response.json() - if (submission?.ok === true && typeof submission.data?.redirectTo === 'string') { - try { - await auth.refreshUser() - } catch (error) { - console.warn('Super admin auth refresh failed after login.', error) - } - router.replace(submission.data.redirectTo) - } - return submission + return await superAdminLoginAction(formData) }, }) const formError = form.errors.first('_root') @@ -86,11 +73,6 @@ export default function SuperAdminLoginPage() { - {form.lastSubmission?.ok === true ? ( -
-

Signed in as super admin.

-
- ) : null} ) } diff --git a/apps/blog-next/app/super-admin/logout-button.tsx b/apps/blog-next/app/super-admin/logout-button.tsx index 903c14a0..cd462059 100644 --- a/apps/blog-next/app/super-admin/logout-button.tsx +++ b/apps/blog-next/app/super-admin/logout-button.tsx @@ -1,44 +1,9 @@ -'use client' - -import { useState } from 'react' -import { useRouter } from 'next/navigation' -import { useAuth } from '@holo-js/auth/next/client' +import { superAdminLogoutAction } from './logout/actions' export function SuperAdminLogoutButton() { - const router = useRouter() - const auth = useAuth({ guard: 'admin' }) - const [isLoggingOut, setIsLoggingOut] = useState(false) - - async function logout() { - if (isLoggingOut) { - return - } - - setIsLoggingOut(true) - try { - const response = await fetch('/api/super-admin/logout', { method: 'POST' }) - if (!response.ok) { - console.warn('Super admin logout failed.', { status: response.status }) - return - } - - try { - await auth.refreshUser() - } catch (error) { - console.warn('Super admin auth refresh failed after logout.', error) - } - - router.replace('/super-admin/login') - } catch (error) { - console.warn('Super admin logout failed.', error) - } finally { - setIsLoggingOut(false) - } - } - return ( - +
+ +
) } diff --git a/apps/blog-next/app/super-admin/logout/actions.ts b/apps/blog-next/app/super-admin/logout/actions.ts new file mode 100644 index 00000000..14580a6f --- /dev/null +++ b/apps/blog-next/app/super-admin/logout/actions.ts @@ -0,0 +1,11 @@ +'use server' + +import auth from '@holo-js/auth' +import { revalidatePath } from 'next/cache' +import { redirect } from 'next/navigation' + +export async function superAdminLogoutAction() { + await auth.guard('admin').logout() + revalidatePath('/', 'layout') + redirect('/super-admin/login') +} diff --git a/apps/blog-next/tests/auth-nav.test.mjs b/apps/blog-next/tests/auth-nav.test.mjs index 97c9c500..66d1572f 100644 --- a/apps/blog-next/tests/auth-nav.test.mjs +++ b/apps/blog-next/tests/auth-nav.test.mjs @@ -1,18 +1,15 @@ import { jsx } from 'react/jsx-runtime' import { act, create } from 'react-test-renderer' -import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { beforeEach, describe, expect, it, vi } from 'vitest' const mocks = vi.hoisted(() => ({ - fetch: vi.fn(), - refreshUser: vi.fn(), - replace: vi.fn(), + logoutAction: vi.fn(), })) vi.mock('@holo-js/auth/next/client', () => ({ useAuth: () => ({ authenticated: true, provider: 'local', - refreshUser: mocks.refreshUser, user: { email: 'reader@example.com', name: 'Reader', @@ -20,51 +17,33 @@ vi.mock('@holo-js/auth/next/client', () => ({ }), })) -vi.mock('next/navigation', () => ({ - useRouter: () => ({ - replace: mocks.replace, - }), +vi.mock('../app/logout/actions.ts', () => ({ + logoutAction: mocks.logoutAction, })) vi.mock('next/link', () => ({ default: ({ children, href, ...props }) => jsx('a', { ...props, href, children }), })) -const originalFetch = globalThis.fetch -const originalConsoleWarn = console.warn - const { AuthNav } = await import('../app/auth-nav.tsx') describe('auth nav', () => { beforeEach(() => { vi.clearAllMocks() - globalThis.fetch = mocks.fetch - console.warn = vi.fn() - }) - - afterEach(() => { - globalThis.fetch = originalFetch - console.warn = originalConsoleWarn }) - it('navigates home after logout even when auth refresh fails', async () => { - const refreshError = new Error('refresh failed') - mocks.fetch.mockResolvedValue(new Response(null, { status: 204 })) - mocks.refreshUser.mockRejectedValue(refreshError) - + it('renders logout as a native server action form', async () => { let renderer await act(async () => { renderer = create(jsx(AuthNav, {})) }) - await act(async () => { - await renderer.root.findByType('button').props.onClick() - }) + const form = renderer.root.findByType('form') + const button = renderer.root.findByType('button') - expect(mocks.fetch).toHaveBeenCalledWith('/api/logout', { method: 'POST' }) - expect(mocks.refreshUser).toHaveBeenCalledTimes(1) - expect(console.warn).toHaveBeenCalledWith('Auth refresh failed after logout.', refreshError) - expect(mocks.replace).toHaveBeenCalledWith('/') + expect(form.props.action).toBe(mocks.logoutAction) + expect(button.props.type).toBe('submit') + expect(button.props.children).toBe('Logout') await act(async () => { renderer.unmount() diff --git a/apps/blog-next/tests/login-page.test.mjs b/apps/blog-next/tests/login-page.test.mjs new file mode 100644 index 00000000..8abbec6e --- /dev/null +++ b/apps/blog-next/tests/login-page.test.mjs @@ -0,0 +1,118 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +class RedirectSignal extends Error { + constructor(url) { + super(`Redirected to ${url}`) + this.url = url + } +} + +const mocks = vi.hoisted(() => ({ + login: vi.fn(), + revalidatePath: vi.fn(), + redirect: vi.fn(), + validate: vi.fn(), +})) + +vi.mock('@holo-js/auth', () => ({ + login: mocks.login, +})) + +vi.mock('@holo-js/forms', () => ({ + validate: mocks.validate, +})) + +vi.mock('next/cache', () => ({ + revalidatePath: mocks.revalidatePath, +})) + +vi.mock('next/navigation', () => ({ + redirect: mocks.redirect, +})) + +vi.mock('@/lib/schemas/auth', () => ({ + loginForm: {}, +})) + +const { loginAction } = await import('../app/login/actions.ts') + +function createValidSubmission(data) { + return { + valid: true, + data, + } +} + +function createInvalidSubmission(payload) { + return { + valid: false, + fail: vi.fn(() => payload), + } +} + +describe('login action', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('uses the shared forms validate API before the Next redirect', async () => { + const formData = new FormData() + formData.set('email', 'editor@example.com') + formData.set('password', 'secret-secret') + mocks.validate.mockResolvedValue(createValidSubmission({ + email: 'editor@example.com', + password: 'secret-secret', + remember: false, + })) + mocks.login.mockResolvedValue({ + data: { + emailVerificationRequired: false, + user: { + id: 'user-1', + email: 'editor@example.com', + }, + }, + error: null, + }) + mocks.redirect.mockImplementation((url) => { + throw new RedirectSignal(url) + }) + + await expect(loginAction(formData)).rejects.toMatchObject({ + url: '/admin', + }) + + expect(mocks.validate).toHaveBeenCalledWith(formData, {}, { + csrf: true, + throttle: 'login', + }) + expect(mocks.login).toHaveBeenCalledWith({ + email: 'editor@example.com', + password: 'secret-secret', + remember: false, + }) + expect(mocks.revalidatePath).toHaveBeenCalledWith('/', 'layout') + expect(mocks.redirect).toHaveBeenCalledWith('/admin') + }) + + it('returns validation failures without logging in', async () => { + const failure = { + ok: false, + status: 422, + valid: false, + values: { + email: '', + }, + errors: { + email: ['Email is required.'], + }, + } + mocks.validate.mockResolvedValue(createInvalidSubmission(failure)) + + await expect(loginAction(new FormData())).resolves.toBe(failure) + + expect(mocks.login).not.toHaveBeenCalled() + expect(mocks.revalidatePath).not.toHaveBeenCalled() + expect(mocks.redirect).not.toHaveBeenCalled() + }) +}) diff --git a/apps/blog-next/tests/logout-actions.test.mjs b/apps/blog-next/tests/logout-actions.test.mjs new file mode 100644 index 00000000..b2ab45e1 --- /dev/null +++ b/apps/blog-next/tests/logout-actions.test.mjs @@ -0,0 +1,45 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const mocks = vi.hoisted(() => ({ + logout: vi.fn(), + redirect: vi.fn((location) => { + const error = new Error('NEXT_REDIRECT') + error.location = location + throw error + }), + revalidatePath: vi.fn(), +})) + +vi.mock('@holo-js/auth', () => ({ + logout: mocks.logout, +})) + +vi.mock('next/cache', () => ({ + revalidatePath: mocks.revalidatePath, +})) + +vi.mock('next/navigation', () => ({ + redirect: mocks.redirect, +})) + +const { logoutAction } = await import('../app/logout/actions.ts') + +describe('logoutAction', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('logs out, revalidates the layout, and redirects home', async () => { + mocks.logout.mockResolvedValue({ + authenticated: false, + }) + + await expect(logoutAction()).rejects.toMatchObject({ + location: '/', + }) + + expect(mocks.logout).toHaveBeenCalledTimes(1) + expect(mocks.revalidatePath).toHaveBeenCalledWith('/', 'layout') + expect(mocks.redirect).toHaveBeenCalledWith('/') + }) +}) diff --git a/apps/blog-next/tests/register-page.test.mjs b/apps/blog-next/tests/register-page.test.mjs index f61c6b66..7f77631a 100644 --- a/apps/blog-next/tests/register-page.test.mjs +++ b/apps/blog-next/tests/register-page.test.mjs @@ -1,14 +1,18 @@ import assert from 'node:assert/strict' import { jsx } from 'react/jsx-runtime' import { act, create } from 'react-test-renderer' -import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { beforeEach, describe, expect, it, vi } from 'vitest' const mocks = vi.hoisted(() => ({ loginUsing: vi.fn(), - fetch: vi.fn(), + redirect: vi.fn((location) => { + const error = new Error('NEXT_REDIRECT') + error.location = location + throw error + }), register: vi.fn(), registerForm: Symbol('registerForm'), - replace: vi.fn(), + revalidatePath: vi.fn(), useForm: vi.fn(), validate: vi.fn(), })) @@ -26,10 +30,12 @@ vi.mock('@holo-js/forms', () => ({ validate: mocks.validate, })) +vi.mock('next/cache', () => ({ + revalidatePath: mocks.revalidatePath, +})) + vi.mock('next/navigation', () => ({ - useRouter: () => ({ - replace: mocks.replace, - }), + redirect: mocks.redirect, })) vi.mock('next/link', () => ({ @@ -40,10 +46,8 @@ vi.mock('@/lib/schemas/auth', () => ({ registerForm: mocks.registerForm, })) -const originalFetch = globalThis.fetch - const { default: RegisterPage } = await import('../app/register/page.tsx') -const registerRoute = await import('../app/api/register/route.ts') +const { registerAction } = await import('../app/register/actions.ts') function createFormState(submit) { return { @@ -81,84 +85,65 @@ function createFormState(submit) { } } -async function renderPageWithRedirect(redirectTo = '/login') { - mocks.fetch.mockResolvedValue(new Response(JSON.stringify({ - ok: true, - data: { - redirectTo, - }, - }))) - mocks.useForm.mockImplementation((_schema, options) => createFormState(vi.fn(async () => { - const formData = new FormData() - formData.set('name', 'Reader') - formData.set('email', 'reader@example.com') - formData.set('password', 'password123') - formData.set('passwordConfirmation', 'password123') - - return options.submitter({ formData }) - }))) - - let renderer - await act(async () => { - renderer = create(jsx(RegisterPage, {})) - }) - - assert.ok(renderer, 'Expected register page to render.') - return renderer -} - describe('register page', () => { beforeEach(() => { vi.clearAllMocks() - globalThis.fetch = mocks.fetch }) - afterEach(() => { - globalThis.fetch = originalFetch - }) - - it('navigates to same-app redirect targets after successful registration', async () => { - const renderer = await renderPageWithRedirect('/login') - - await act(async () => { - renderer.root.findByType('form').props.onSubmit({ - preventDefault: vi.fn(), - }) - }) + it('submits through the register server action', async () => { + const failure = { + ok: false, + status: 422, + errors: { + email: ['Enter a valid email address.'], + }, + } + const submission = { + valid: false, + fail: vi.fn(() => failure), + } + mocks.validate.mockResolvedValue(submission) + mocks.useForm.mockImplementation((_schema, options) => createFormState(vi.fn(async () => { + const formData = new FormData() + formData.set('name', 'Reader') + formData.set('email', 'bad') + formData.set('password', 'password123') + formData.set('passwordConfirmation', 'password123') - expect(mocks.fetch).toHaveBeenCalledWith('/api/register', { - method: 'POST', - body: expect.any(FormData), - }) - expect(mocks.useForm).toHaveBeenCalledWith(mocks.registerForm, expect.objectContaining({ - csrf: true, - })) - expect(mocks.replace).toHaveBeenCalledWith('/login') + return await options.submitter({ formData }) + }))) + let renderer await act(async () => { - renderer.unmount() + renderer = create(jsx(RegisterPage, {})) }) - }) - it('ignores response-provided register redirect targets', async () => { - const renderer = await renderPageWithRedirect('https://evil.test/login') + assert.ok(renderer, 'Expected register page to render.') await act(async () => { - renderer.root.findByType('form').props.onSubmit({ + await renderer.root.findByType('form').props.onSubmit({ preventDefault: vi.fn(), }) }) - expect(mocks.replace).toHaveBeenCalledWith('/login') + expect(mocks.useForm).toHaveBeenCalledWith(mocks.registerForm, expect.objectContaining({ + csrf: true, + validateOn: 'blur', + })) + expect(mocks.validate).toHaveBeenCalledWith(expect.any(FormData), mocks.registerForm, { + csrf: true, + throttle: 'register', + }) + expect(mocks.register).not.toHaveBeenCalled() + expect(mocks.redirect).not.toHaveBeenCalled() await act(async () => { renderer.unmount() }) }) - }) -describe('POST /api/register', () => { +describe('registerAction', () => { beforeEach(() => { vi.clearAllMocks() }) @@ -175,17 +160,11 @@ describe('POST /api/register', () => { valid: false, fail: vi.fn(() => failure), } - const request = new Request('http://localhost/api/register', { - method: 'POST', - }) mocks.validate.mockResolvedValue(submission) - const response = await registerRoute.POST(request) - - expect(response.status).toBe(422) - await expect(response.json()).resolves.toEqual(failure) - expect(mocks.validate).toHaveBeenCalledWith(request, mocks.registerForm, { + await expect(registerAction(new FormData())).resolves.toBe(failure) + expect(mocks.validate).toHaveBeenCalledWith(expect.any(FormData), mocks.registerForm, { csrf: true, throttle: 'register', }) @@ -193,15 +172,46 @@ describe('POST /api/register', () => { expect(mocks.loginUsing).not.toHaveBeenCalled() }) - it('keeps the verified registration success redirect unchanged', async () => { + it('returns registration failures without starting a session', async () => { + const failure = { + ok: false, + status: 422, + errors: { + email: ['The email has already been taken.'], + }, + } + const submission = { + valid: true, + data: { + name: 'Reader', + email: 'reader@example.com', + password: 'password123', + passwordConfirmation: 'password123', + }, + fail: vi.fn(() => failure), + } + mocks.validate.mockResolvedValue(submission) + mocks.register.mockResolvedValue({ + data: null, + error: { + status: 422, + fields: { + email: ['The email has already been taken.'], + }, + }, + }) + + await expect(registerAction(new FormData())).resolves.toBe(failure) + expect(mocks.register).toHaveBeenCalledWith(submission.data) + expect(mocks.loginUsing).not.toHaveBeenCalled() + expect(mocks.redirect).not.toHaveBeenCalled() + }) + + it('uses the native Next redirect after verified registration', async () => { const created = { id: 7, email: 'reader@example.com', } - const session = { - emailVerificationRequired: false, - user: created, - } const submission = { valid: true, data: { @@ -211,34 +221,58 @@ describe('POST /api/register', () => { passwordConfirmation: 'password123', }, fail: vi.fn(), - success: vi.fn((data, status) => ({ - ok: true, - status, - data, - })), } mocks.validate.mockResolvedValue(submission) mocks.register.mockResolvedValue({ data: created, error: null, }) - mocks.loginUsing.mockResolvedValue(session) + mocks.loginUsing.mockResolvedValue({ + emailVerificationRequired: false, + user: created, + }) - const response = await registerRoute.POST(new Request('http://localhost/api/register', { - method: 'POST', - })) + await expect(registerAction(new FormData())).rejects.toMatchObject({ + location: '/admin', + }) + + expect(mocks.register).toHaveBeenCalledWith(submission.data) + expect(mocks.loginUsing).toHaveBeenCalledWith(created) + expect(mocks.revalidatePath).toHaveBeenCalledWith('/', 'layout') + expect(mocks.redirect).toHaveBeenCalledWith('/admin') + }) - expect(response.status).toBe(201) - await expect(response.json()).resolves.toEqual({ - ok: true, - status: 201, + it('redirects to email verification when the new session requires it', async () => { + const created = { + id: 7, + email: 'reader@example.com', + } + const submission = { + valid: true, data: { - message: 'Account created and signed in successfully.', - redirectTo: '/admin', - user: created, + name: 'Reader', + email: 'reader@example.com', + password: 'password123', + passwordConfirmation: 'password123', }, + fail: vi.fn(), + } + mocks.validate.mockResolvedValue(submission) + mocks.register.mockResolvedValue({ + data: created, + error: null, }) - expect(mocks.register).toHaveBeenCalledWith(submission.data) - expect(mocks.loginUsing).toHaveBeenCalledWith(created) + mocks.loginUsing.mockResolvedValue({ + emailVerificationRequired: true, + emailVerificationRoute: '/verify-email', + user: created, + }) + + await expect(registerAction(new FormData())).rejects.toMatchObject({ + location: '/verify-email', + }) + + expect(mocks.revalidatePath).toHaveBeenCalledWith('/', 'layout') + expect(mocks.redirect).toHaveBeenCalledWith('/verify-email') }) }) diff --git a/apps/blog-next/tests/run.mjs b/apps/blog-next/tests/run.mjs index f562c94a..7f5ced85 100644 --- a/apps/blog-next/tests/run.mjs +++ b/apps/blog-next/tests/run.mjs @@ -75,17 +75,40 @@ function containsJsxNode(node, tagName) { return ts.forEachChild(node, child => containsJsxNode(child, tagName)) === true } -function containsRouterReplaceHome(node) { +function containsLogoutActionForm(node) { if ( - ts.isCallExpression(node) - && ts.isPropertyAccessExpression(node.expression) - && node.expression.name.text === 'replace' - && node.arguments.some(argument => ts.isStringLiteral(argument) && argument.text === '/') + ts.isJsxSelfClosingElement(node) + && getJsxTagName(node.tagName) === 'form' + && node.attributes.properties.some(attribute => ( + ts.isJsxAttribute(attribute) + && attribute.name.text === 'action' + && attribute.initializer + && ts.isJsxExpression(attribute.initializer) + && attribute.initializer.expression + && ts.isIdentifier(attribute.initializer.expression) + && attribute.initializer.expression.text === 'logoutAction' + )) ) { return true } - return ts.forEachChild(node, containsRouterReplaceHome) === true + if ( + ts.isJsxElement(node) + && getJsxTagName(node.openingElement.tagName) === 'form' + && node.openingElement.attributes.properties.some(attribute => ( + ts.isJsxAttribute(attribute) + && attribute.name.text === 'action' + && attribute.initializer + && ts.isJsxExpression(attribute.initializer) + && attribute.initializer.expression + && ts.isIdentifier(attribute.initializer.expression) + && attribute.initializer.expression.text === 'logoutAction' + )) + ) { + return true + } + + return ts.forEachChild(node, containsLogoutActionForm) === true } async function assertRootLayoutSharesAuthProviderState() { @@ -109,8 +132,8 @@ async function assertHeaderLogoutRedirectsHome() { const sourceFile = ts.createSourceFile('auth-nav.tsx', authNavSource, ts.ScriptTarget.Latest, true, ts.ScriptKind.TSX) assert.ok( - containsRouterReplaceHome(sourceFile), - 'Expected header logout to redirect home after clearing the session.', + containsLogoutActionForm(sourceFile), + 'Expected header logout to use the logout server action form.', ) } @@ -302,7 +325,7 @@ try { await rm(join(cwd, '.next'), { recursive: true, force: true }) await assertRootLayoutSharesAuthProviderState() await assertHeaderLogoutRedirectsHome() - await run('npx', ['vitest', '--run', 'tests/api-v1-routes.test.mjs', 'tests/auth-nav.test.mjs', 'tests/current-auth-route.test.mjs', 'tests/forgot-password-route.test.mjs', 'tests/hosted-logout-routes.test.mjs', 'tests/package-checks.test.mjs', 'tests/register-page.test.mjs', 'tests/reset-password-page.test.mjs', 'tests/reset-password-route.test.mjs', 'tests/social-auth-routes.test.mjs', 'tests/super-admin-logout-button.test.mjs', 'tests/super-admin-login-page.test.mjs', 'tests/super-admin-login-route.test.mjs', 'tests/verify-email-page.test.mjs', '--reporter=json']) + await run('npx', ['vitest', '--run', 'tests/api-v1-routes.test.mjs', 'tests/auth-nav.test.mjs', 'tests/current-auth-route.test.mjs', 'tests/forgot-password-route.test.mjs', 'tests/hosted-logout-routes.test.mjs', 'tests/login-page.test.mjs', 'tests/logout-actions.test.mjs', 'tests/package-checks.test.mjs', 'tests/register-page.test.mjs', 'tests/reset-password-page.test.mjs', 'tests/reset-password-route.test.mjs', 'tests/social-auth-routes.test.mjs', 'tests/super-admin-logout-button.test.mjs', 'tests/super-admin-login-page.test.mjs', 'tests/super-admin-login-route.test.mjs', 'tests/verify-email-page.test.mjs', '--reporter=json']) await run('bun', ['run', 'prepare']) await run('bun', ['x', 'holo', 'migrate:fresh', '--seed']) await run('npx', ['tsx', 'tests/blog-logic.mjs']) diff --git a/apps/blog-next/tests/super-admin-login-page.test.mjs b/apps/blog-next/tests/super-admin-login-page.test.mjs index 09b24a0e..b5c49325 100644 --- a/apps/blog-next/tests/super-admin-login-page.test.mjs +++ b/apps/blog-next/tests/super-admin-login-page.test.mjs @@ -3,34 +3,56 @@ import { act, create } from 'react-test-renderer' import { beforeEach, describe, expect, it, vi } from 'vitest' const mocks = vi.hoisted(() => ({ - refreshUser: vi.fn(), - replace: vi.fn(), + guardLogin: vi.fn(), + redirect: vi.fn((location) => { + const error = new Error('NEXT_REDIRECT') + error.location = location + throw error + }), + revalidatePath: vi.fn(), + superAdminLoginAction: vi.fn(), useForm: vi.fn(), + validate: vi.fn(), })) vi.mock('@holo-js/adapter-next/client', () => ({ useForm: mocks.useForm, })) -vi.mock('@holo-js/auth/next/client', () => ({ - useAuth: () => ({ - refreshUser: mocks.refreshUser, - }), +vi.mock('@holo-js/auth', () => ({ + default: { + guard: vi.fn(() => ({ + login: mocks.guardLogin, + })), + }, +})) + +vi.mock('@holo-js/forms', () => ({ + validate: mocks.validate, +})) + +vi.mock('next/cache', () => ({ + revalidatePath: mocks.revalidatePath, })) vi.mock('next/navigation', () => ({ - useRouter: () => ({ - replace: mocks.replace, - }), + redirect: mocks.redirect, })) vi.mock('@/lib/schemas/auth', () => ({ - loginForm: {}, + loginForm: Symbol('loginForm'), +})) + +vi.mock('../app/super-admin/login/actions.ts', async (importOriginal) => ({ + ...(await importOriginal()), + superAdminLoginAction: mocks.superAdminLoginAction, })) const { default: SuperAdminLoginPage } = await import('../app/super-admin/login/page.tsx') +vi.doUnmock('../app/super-admin/login/actions.ts') +const { superAdminLoginAction } = await import('../app/super-admin/login/actions.ts?actual') -function createFormState(rootError) { +function createFormState(rootError, submit = vi.fn()) { return { values: { email: '', @@ -55,7 +77,7 @@ function createFormState(rootError) { first: vi.fn(field => field === '_root' ? rootError : undefined), }, submitting: false, - submit: vi.fn(), + submit, lastSubmission: { ok: false, }, @@ -85,4 +107,156 @@ describe('super admin login page', () => { renderer.unmount() }) }) + + it('submits through the super admin login server action', async () => { + mocks.superAdminLoginAction.mockResolvedValue({ + ok: false, + status: 422, + }) + mocks.useForm.mockImplementation((_schema, options) => createFormState(undefined, vi.fn(async () => { + const formData = new FormData() + formData.set('email', 'admin@example.com') + formData.set('password', 'secret-secret') + + return await options.submitter({ formData }) + }))) + + let renderer + await act(async () => { + renderer = create(jsx(SuperAdminLoginPage, {})) + }) + + await act(async () => { + await renderer.root.findByType('form').props.onSubmit({ + preventDefault: vi.fn(), + }) + }) + + expect(mocks.superAdminLoginAction).toHaveBeenCalledWith(expect.any(FormData)) + + await act(async () => { + renderer.unmount() + }) + }) +}) + +describe('superAdminLoginAction', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('returns validation failures before logging in', async () => { + const failure = { + ok: false, + status: 422, + errors: { + email: ['Enter a valid email address.'], + }, + } + const submission = { + valid: false, + fail: vi.fn(() => failure), + } + + mocks.validate.mockResolvedValue(submission) + + await expect(superAdminLoginAction(new FormData())).resolves.toBe(failure) + expect(mocks.validate).toHaveBeenCalledWith(expect.any(FormData), expect.anything(), { + throttle: 'login', + }) + expect(mocks.guardLogin).not.toHaveBeenCalled() + }) + + it('returns auth failures without redirecting', async () => { + const failure = { + ok: false, + status: 401, + errors: { + _root: ['These credentials do not match our records.'], + }, + } + const submission = { + valid: true, + data: { + email: 'admin@example.com', + password: 'bad-password', + remember: false, + }, + fail: vi.fn(() => failure), + } + mocks.validate.mockResolvedValue(submission) + mocks.guardLogin.mockResolvedValue({ + data: null, + error: { + status: 401, + fields: { + _root: ['These credentials do not match our records.'], + }, + }, + }) + + await expect(superAdminLoginAction(new FormData())).resolves.toBe(failure) + expect(mocks.guardLogin).toHaveBeenCalledWith(submission.data) + expect(mocks.redirect).not.toHaveBeenCalled() + }) + + it('uses the native Next redirect after super admin login', async () => { + const submission = { + valid: true, + data: { + email: 'admin@example.com', + password: 'secret-secret', + remember: false, + }, + fail: vi.fn(), + } + mocks.validate.mockResolvedValue(submission) + mocks.guardLogin.mockResolvedValue({ + data: { + emailVerificationRequired: false, + user: { + email: 'admin@example.com', + }, + }, + error: null, + }) + + await expect(superAdminLoginAction(new FormData())).rejects.toMatchObject({ + location: '/super-admin', + }) + + expect(mocks.guardLogin).toHaveBeenCalledWith(submission.data) + expect(mocks.revalidatePath).toHaveBeenCalledWith('/', 'layout') + expect(mocks.redirect).toHaveBeenCalledWith('/super-admin') + }) + + it('redirects to email verification when the admin session requires it', async () => { + const submission = { + valid: true, + data: { + email: 'admin@example.com', + password: 'secret-secret', + remember: false, + }, + fail: vi.fn(), + } + mocks.validate.mockResolvedValue(submission) + mocks.guardLogin.mockResolvedValue({ + data: { + emailVerificationRequired: true, + emailVerificationRoute: '/verify-email', + user: { + email: 'admin@example.com', + }, + }, + error: null, + }) + + await expect(superAdminLoginAction(new FormData())).rejects.toMatchObject({ + location: '/verify-email', + }) + + expect(mocks.revalidatePath).toHaveBeenCalledWith('/', 'layout') + expect(mocks.redirect).toHaveBeenCalledWith('/verify-email') + }) }) diff --git a/apps/blog-next/tests/super-admin-logout-button.test.mjs b/apps/blog-next/tests/super-admin-logout-button.test.mjs index c2f165d0..9f45246b 100644 --- a/apps/blog-next/tests/super-admin-logout-button.test.mjs +++ b/apps/blog-next/tests/super-admin-logout-button.test.mjs @@ -1,139 +1,83 @@ import { jsx } from 'react/jsx-runtime' import { act, create } from 'react-test-renderer' -import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { beforeEach, describe, expect, it, vi } from 'vitest' const mocks = vi.hoisted(() => ({ - fetch: vi.fn(), - refreshUser: vi.fn(), - replace: vi.fn(), + guardLogout: vi.fn(), + redirect: vi.fn((location) => { + const error = new Error('NEXT_REDIRECT') + error.location = location + throw error + }), + revalidatePath: vi.fn(), + superAdminLogoutAction: vi.fn(), })) -vi.mock('@holo-js/auth/next/client', () => ({ - useAuth: () => ({ - refreshUser: mocks.refreshUser, - }), +vi.mock('@holo-js/auth', () => ({ + default: { + guard: vi.fn(() => ({ + logout: mocks.guardLogout, + })), + }, +})) + +vi.mock('next/cache', () => ({ + revalidatePath: mocks.revalidatePath, })) vi.mock('next/navigation', () => ({ - useRouter: () => ({ - replace: mocks.replace, - }), + redirect: mocks.redirect, })) -const originalFetch = globalThis.fetch -const originalConsoleWarn = console.warn +vi.mock('../app/super-admin/logout/actions.ts', async (importOriginal) => ({ + ...(await importOriginal()), + superAdminLogoutAction: mocks.superAdminLogoutAction, +})) const { SuperAdminLogoutButton } = await import('../app/super-admin/logout-button.tsx') - -function createDeferred() { - let resolvePromise = () => {} - const promise = new Promise(resolve => { - resolvePromise = resolve - }) - - return { - promise, - resolve(value) { - resolvePromise(value) - }, - } -} +vi.doUnmock('../app/super-admin/logout/actions.ts') +const { superAdminLogoutAction } = await import('../app/super-admin/logout/actions.ts?actual') describe('super admin logout button', () => { beforeEach(() => { vi.clearAllMocks() - globalThis.fetch = mocks.fetch - console.warn = vi.fn() }) - afterEach(() => { - globalThis.fetch = originalFetch - console.warn = originalConsoleWarn - }) - - it('navigates to login after logout even when auth refresh fails', async () => { - mocks.fetch.mockResolvedValue(new Response(null, { status: 204 })) - mocks.refreshUser.mockRejectedValue(new Error('refresh failed')) - + it('renders logout as a native server action form', async () => { let renderer await act(async () => { renderer = create(jsx(SuperAdminLogoutButton, {})) }) + const form = renderer.root.findByType('form') const button = renderer.root.findByType('button') - await act(async () => { - await button.props.onClick() - }) - expect(mocks.fetch).toHaveBeenCalledWith('/api/super-admin/logout', { method: 'POST' }) - expect(mocks.refreshUser).toHaveBeenCalledTimes(1) - expect(console.warn).toHaveBeenCalledWith( - 'Super admin auth refresh failed after logout.', - expect.any(Error), - ) - expect(mocks.replace).toHaveBeenCalledWith('/super-admin/login') + expect(form.props.action).toBe(mocks.superAdminLogoutAction) + expect(button.props.type).toBe('submit') + expect(button.props.children).toBe('Sign out of super admin') await act(async () => { renderer.unmount() }) }) +}) - it('ignores duplicate logout clicks while a request is in flight', async () => { - const logoutResponse = createDeferred() - mocks.fetch.mockReturnValue(logoutResponse.promise) - - let renderer - await act(async () => { - renderer = create(jsx(SuperAdminLogoutButton, {})) - }) - - const firstClick = renderer.root.findByType('button').props.onClick() - await act(async () => { - await Promise.resolve() - }) - - const loadingButton = renderer.root.findByType('button') - expect(loadingButton.props.disabled).toBe(true) - expect(loadingButton.props.children).toBe('Signing out...') - - await act(async () => { - await loadingButton.props.onClick() - }) - - expect(mocks.fetch).toHaveBeenCalledTimes(1) - - logoutResponse.resolve(new Response(null, { status: 500 })) - await act(async () => { - await firstClick - }) - - expect(console.warn).toHaveBeenCalledWith('Super admin logout failed.', { status: 500 }) - expect(mocks.replace).not.toHaveBeenCalled() - - await act(async () => { - renderer.unmount() - }) +describe('superAdminLogoutAction', () => { + beforeEach(() => { + vi.clearAllMocks() }) - it('keeps users on the page when the logout request fails before clearing the session', async () => { - const logoutError = new Error('network failed') - mocks.fetch.mockRejectedValue(logoutError) - - let renderer - await act(async () => { - renderer = create(jsx(SuperAdminLogoutButton, {})) + it('logs out the admin guard and redirects to super admin login', async () => { + mocks.guardLogout.mockResolvedValue({ + authenticated: false, }) - await act(async () => { - await renderer.root.findByType('button').props.onClick() + await expect(superAdminLogoutAction()).rejects.toMatchObject({ + location: '/super-admin/login', }) - expect(console.warn).toHaveBeenCalledWith('Super admin logout failed.', logoutError) - expect(mocks.refreshUser).not.toHaveBeenCalled() - expect(mocks.replace).not.toHaveBeenCalled() - - await act(async () => { - renderer.unmount() - }) + expect(mocks.guardLogout).toHaveBeenCalledTimes(1) + expect(mocks.revalidatePath).toHaveBeenCalledWith('/', 'layout') + expect(mocks.redirect).toHaveBeenCalledWith('/super-admin/login') }) }) diff --git a/apps/blog-sveltekit/src/routes/+layout.server.ts b/apps/blog-sveltekit/src/routes/+layout.server.ts index b6140934..e528b4ad 100644 --- a/apps/blog-sveltekit/src/routes/+layout.server.ts +++ b/apps/blog-sveltekit/src/routes/+layout.server.ts @@ -1,9 +1,12 @@ import { auth } from '@holo-js/auth/sveltekit/server' +import { csrf } from '@holo-js/security' +import type { LayoutServerLoad } from './$types' -export async function load() { +export const load = (async ({ request }) => { const currentAuth = await auth() return { auth: currentAuth, + csrf: await csrf.field(request), } -} +}) satisfies LayoutServerLoad diff --git a/apps/blog-sveltekit/src/routes/+layout.svelte b/apps/blog-sveltekit/src/routes/+layout.svelte index 37c3750e..cdb138b9 100644 --- a/apps/blog-sveltekit/src/routes/+layout.svelte +++ b/apps/blog-sveltekit/src/routes/+layout.svelte @@ -1,11 +1,9 @@
@@ -57,7 +24,9 @@ {#if auth.authenticated} {displayName} {#if !usesHostedLogout} - +
+ +
{/if} {#if auth.provider === 'workos'}
diff --git a/apps/blog-sveltekit/src/routes/api/login/+server.ts b/apps/blog-sveltekit/src/routes/api/login/+server.ts deleted file mode 100644 index 31033a20..00000000 --- a/apps/blog-sveltekit/src/routes/api/login/+server.ts +++ /dev/null @@ -1,38 +0,0 @@ -import { json } from '@sveltejs/kit' -import { login } from '@holo-js/auth' -import { validate } from '@holo-js/forms' - -import { loginForm } from '$lib/schemas/auth' - -export async function POST({ request }: { request: Request }) { - const submission = await validate(request, loginForm, { - csrf: true, - throttle: 'login', - }) - - if (!submission.valid) { - return json(submission.fail(), { - status: submission.fail().status, - }) - } - - const { data: session, error } = await login(submission.data) - if (error) { - const failure = submission.fail({ - status: error.status, - errors: error.fields, - }) - - return json(failure, { status: failure.status }) - } - - return json(submission.success({ - message: session.emailVerificationRequired - ? 'Signed in. Verify your email address to continue.' - : 'Signed in successfully.', - redirectTo: session.emailVerificationRequired - ? session.emailVerificationRoute ?? '/verify-email' - : '/admin', - user: session.user, - })) -} diff --git a/apps/blog-sveltekit/src/routes/api/register/+server.ts b/apps/blog-sveltekit/src/routes/api/register/+server.ts deleted file mode 100644 index 2482690c..00000000 --- a/apps/blog-sveltekit/src/routes/api/register/+server.ts +++ /dev/null @@ -1,41 +0,0 @@ -import { json } from '@sveltejs/kit' -import { loginUsing, register } from '@holo-js/auth' -import { validate } from '@holo-js/forms' - -import { registerForm } from '$lib/schemas/auth' - -export async function POST({ request }: { request: Request }) { - const submission = await validate(request, registerForm, { - csrf: true, - throttle: 'register', - }) - - if (!submission.valid) { - return json(submission.fail(), { - status: submission.fail().status, - }) - } - - const { data: created, error } = await register(submission.data) - if (error) { - const failure = submission.fail({ - status: error.status, - errors: error.fields, - }) - - return json(failure, { status: failure.status }) - } - - const session = await loginUsing(created) - return json(submission.success({ - message: session.emailVerificationRequired - ? 'Account created. Check your inbox to verify your email address.' - : 'Account created and signed in successfully.', - redirectTo: session.emailVerificationRequired - ? session.emailVerificationRoute ?? '/verify-email' - : '/admin', - user: session.user, - }, 201), { - status: 201, - }) -} diff --git a/apps/blog-sveltekit/src/routes/api/super-admin/login/+server.ts b/apps/blog-sveltekit/src/routes/api/super-admin/login/+server.ts deleted file mode 100644 index f32ecb21..00000000 --- a/apps/blog-sveltekit/src/routes/api/super-admin/login/+server.ts +++ /dev/null @@ -1,38 +0,0 @@ -import { json } from '@sveltejs/kit' -import auth from '@holo-js/auth' -import { validate } from '@holo-js/forms' - -import { loginForm } from '$lib/schemas/auth' -import type { RequestHandler } from './$types' - -export const POST: RequestHandler = async ({ request }) => { - const submission = await validate(request, loginForm, { - throttle: 'login', - }) - - if (!submission.valid) { - return json(submission.fail(), { - status: submission.fail().status, - }) - } - - const { data: session, error } = await auth.guard('admin').login(submission.data) - if (error) { - const failure = submission.fail({ - status: error.status, - errors: error.fields, - }) - - return json(failure, { status: failure.status }) - } - - return json(submission.success({ - message: session.emailVerificationRequired - ? 'Signed in. Verify your email address to continue.' - : 'Signed in as super admin.', - redirectTo: session.emailVerificationRequired - ? session.emailVerificationRoute ?? '/verify-email' - : '/super-admin', - user: session.user, - })) -} diff --git a/apps/blog-sveltekit/src/routes/login/+page.server.ts b/apps/blog-sveltekit/src/routes/login/+page.server.ts new file mode 100644 index 00000000..bced4a92 --- /dev/null +++ b/apps/blog-sveltekit/src/routes/login/+page.server.ts @@ -0,0 +1,34 @@ +import { fail, redirect } from '@sveltejs/kit' +import { login } from '@holo-js/auth' +import { validate } from '@holo-js/forms' + +import { loginForm } from '$lib/schemas/auth' +import type { Actions } from './$types' + +export const actions = { + default: async ({ request }) => { + const submission = await validate(request, loginForm, { + csrf: true, + throttle: 'login', + }) + + if (!submission.valid) { + const failure = submission.fail() + return fail(failure.status, failure) + } + + const { data: session, error } = await login(submission.data) + if (error) { + const failure = submission.fail({ + status: error.status, + errors: error.fields, + }) + + return fail(failure.status, failure) + } + + redirect(303, session.emailVerificationRequired + ? session.emailVerificationRoute ?? '/verify-email' + : '/admin') + }, +} satisfies Actions diff --git a/apps/blog-sveltekit/src/routes/login/+page.svelte b/apps/blog-sveltekit/src/routes/login/+page.svelte index c491042d..9295ff54 100644 --- a/apps/blog-sveltekit/src/routes/login/+page.svelte +++ b/apps/blog-sveltekit/src/routes/login/+page.svelte @@ -1,25 +1,15 @@
@@ -35,18 +25,24 @@ Continue with Clerk
- { event.preventDefault(); form.submit() }}> + + + + {#if formError} +

{formError}

+ {/if} + @@ -55,12 +51,12 @@ form.fields.password.onInput(event.currentTarget.value)} - on:blur={() => form.fields.password.onBlur()} + value={login.values.password} + on:input={(event) => login.fields.password.onInput(event.currentTarget.value)} + on:blur={() => login.fields.password.onBlur()} /> - {#if form.errors.has('password')} - {form.errors.first('password')} + {#if login.errors.has('password')} + {login.errors.first('password')} {/if} @@ -68,24 +64,17 @@ form.fields.remember.onInput(event.currentTarget.checked)} + checked={login.values.remember} + on:change={(event) => login.fields.remember.onInput(event.currentTarget.checked)} /> Remember me - - {#if form.lastSubmission?.ok === true} -
-

Signed in successfully.

- Continue to admin -
- {/if} - -
{ event.preventDefault(); form.submit() }}> + + + + {#if formError} +

{formError}

+ {/if} + @@ -47,12 +43,12 @@ form.fields.email.onInput(event.currentTarget.value)} - on:blur={() => form.fields.email.onBlur()} + value={register.values.email} + on:input={(event) => register.fields.email.onInput(event.currentTarget.value)} + on:blur={() => register.fields.email.onBlur()} /> - {#if form.errors.has('email')} - {form.errors.first('email')} + {#if register.errors.has('email')} + {register.errors.first('email')} {/if} @@ -61,12 +57,12 @@ form.fields.password.onInput(event.currentTarget.value)} - on:blur={() => form.fields.password.onBlur()} + value={register.values.password} + on:input={(event) => register.fields.password.onInput(event.currentTarget.value)} + on:blur={() => register.fields.password.onBlur()} /> - {#if form.errors.has('password')} - {form.errors.first('password')} + {#if register.errors.has('password')} + {register.errors.first('password')} {/if} @@ -75,27 +71,20 @@ form.fields.passwordConfirmation.onInput(event.currentTarget.value)} - on:blur={() => form.fields.passwordConfirmation.onBlur()} + value={register.values.passwordConfirmation} + on:input={(event) => register.fields.passwordConfirmation.onInput(event.currentTarget.value)} + on:blur={() => register.fields.passwordConfirmation.onBlur()} /> - {#if form.errors.has('passwordConfirmation')} - {form.errors.first('passwordConfirmation')} + {#if register.errors.has('passwordConfirmation')} + {register.errors.first('passwordConfirmation')} {/if} -
- {#if form.lastSubmission?.ok === true} -
-

Account created. Check your inbox to verify your email address.

- Return to sign in -
- {/if} - Already have an account? Register with WorkOS Register with Clerk @@ -107,6 +96,5 @@ .stack, .field { display: grid; gap: 0.35rem; } .stack { gap: 0.9rem; } .error { color: #fca5a5; } - .success { color: #86efac; display: grid; gap: 0.5rem; } - .success a, .link { color: #7dd3fc; text-decoration: none; } + .link { color: #7dd3fc; text-decoration: none; } diff --git a/apps/blog-sveltekit/src/routes/super-admin/+page.server.ts b/apps/blog-sveltekit/src/routes/super-admin/+page.server.ts index b699ac7b..0d021c19 100644 --- a/apps/blog-sveltekit/src/routes/super-admin/+page.server.ts +++ b/apps/blog-sveltekit/src/routes/super-admin/+page.server.ts @@ -1,5 +1,7 @@ import { redirect } from '@sveltejs/kit' +import authRuntime from '@holo-js/auth' import { auth } from '@holo-js/auth/sveltekit/server' +import type { Actions } from './$types' export async function load() { const currentAuth = await auth({ guard: 'admin' }) @@ -12,3 +14,10 @@ export async function load() { admin: currentAuth.user, } } + +export const actions = { + default: async () => { + await authRuntime.guard('admin').logout() + redirect(303, '/super-admin/login') + }, +} satisfies Actions diff --git a/apps/blog-sveltekit/src/routes/super-admin/+page.svelte b/apps/blog-sveltekit/src/routes/super-admin/+page.svelte index 3057c09c..534568cf 100644 --- a/apps/blog-sveltekit/src/routes/super-admin/+page.svelte +++ b/apps/blog-sveltekit/src/routes/super-admin/+page.svelte @@ -1,37 +1,8 @@
@@ -39,9 +10,9 @@

Super Admin

Signed in as {displayName} through the admin guard.

- +
+ +
diff --git a/apps/blog-sveltekit/src/routes/super-admin/login/+page.server.ts b/apps/blog-sveltekit/src/routes/super-admin/login/+page.server.ts new file mode 100644 index 00000000..de0bc330 --- /dev/null +++ b/apps/blog-sveltekit/src/routes/super-admin/login/+page.server.ts @@ -0,0 +1,34 @@ +import { fail, redirect } from '@sveltejs/kit' +import auth from '@holo-js/auth' +import { validate } from '@holo-js/forms' + +import { loginForm } from '$lib/schemas/auth' +import type { Actions } from './$types' + +export const actions = { + default: async ({ request }) => { + const submission = await validate(request, loginForm, { + csrf: true, + throttle: 'login', + }) + + if (!submission.valid) { + const failure = submission.fail() + return fail(failure.status, failure) + } + + const { data: session, error } = await auth.guard('admin').login(submission.data) + if (error) { + const failure = submission.fail({ + status: error.status, + errors: error.fields, + }) + + return fail(failure.status, failure) + } + + redirect(303, session.emailVerificationRequired + ? session.emailVerificationRoute ?? '/verify-email' + : '/super-admin') + }, +} satisfies Actions diff --git a/apps/blog-sveltekit/src/routes/super-admin/login/+page.svelte b/apps/blog-sveltekit/src/routes/super-admin/login/+page.svelte index 79c4a154..972620af 100644 --- a/apps/blog-sveltekit/src/routes/super-admin/login/+page.svelte +++ b/apps/blog-sveltekit/src/routes/super-admin/login/+page.svelte @@ -1,21 +1,15 @@
@@ -24,18 +18,24 @@

Use a super admin account to access the super admin area.

-
{ event.preventDefault(); form.submit() }}> + + + + {#if formError} +

{formError}

+ {/if} + @@ -44,12 +44,12 @@ form.fields.password.onInput(event.currentTarget.value)} - on:blur={() => form.fields.password.onBlur()} + value={login.values.password} + on:input={(event) => login.fields.password.onInput(event.currentTarget.value)} + on:blur={() => login.fields.password.onBlur()} /> - {#if form.errors.has('password')} - {form.errors.first('password')} + {#if login.errors.has('password')} + {login.errors.first('password')} {/if} @@ -57,22 +57,16 @@ form.fields.remember.onInput(event.currentTarget.checked)} + checked={login.values.remember} + on:change={(event) => login.fields.remember.onInput(event.currentTarget.checked)} /> Remember me -
- - {#if form.lastSubmission?.ok === true} -
-

Signed in as super admin.

-
- {/if}
diff --git a/apps/blog-sveltekit/tests/auth-page-actions.test.mjs b/apps/blog-sveltekit/tests/auth-page-actions.test.mjs new file mode 100644 index 00000000..cd991c9c --- /dev/null +++ b/apps/blog-sveltekit/tests/auth-page-actions.test.mjs @@ -0,0 +1,321 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const mocks = vi.hoisted(() => ({ + fail: vi.fn((status, data) => ({ + ...data, + status, + })), + guardLogin: vi.fn(), + guardLogout: vi.fn(), + login: vi.fn(), + loginForm: Symbol('loginForm'), + loginUsing: vi.fn(), + logout: vi.fn(), + redirect: vi.fn((status, location) => { + const error = new Error('SVELTEKIT_REDIRECT') + error.status = status + error.location = location + throw error + }), + register: vi.fn(), + registerForm: Symbol('registerForm'), + validate: vi.fn(), +})) + +vi.mock('@sveltejs/kit', () => ({ + fail: mocks.fail, + redirect: mocks.redirect, +})) + +vi.mock('@holo-js/auth', () => ({ + default: { + guard: vi.fn(() => ({ + login: mocks.guardLogin, + logout: mocks.guardLogout, + })), + }, + login: mocks.login, + loginUsing: mocks.loginUsing, + logout: mocks.logout, + register: mocks.register, +})) + +vi.mock('@holo-js/auth/sveltekit/server', () => ({ + auth: vi.fn(async () => ({ + authenticated: true, + user: { + email: 'super-admin@example.com', + name: 'Super Admin', + }, + })), +})) + +vi.mock('@holo-js/forms', () => ({ + validate: mocks.validate, +})) + +vi.mock('$lib/schemas/auth', () => ({ + loginForm: mocks.loginForm, + registerForm: mocks.registerForm, +})) + +const loginPage = await import('../src/routes/login/+page.server.ts') +const logoutRoute = await import('../src/routes/logout/+server.ts') +const registerPage = await import('../src/routes/register/+page.server.ts') +const superAdminPage = await import('../src/routes/super-admin/+page.server.ts') +const superAdminLoginPage = await import('../src/routes/super-admin/login/+page.server.ts') + +function createRequest(path = '/login') { + return new Request(`http://localhost${path}`, { + method: 'POST', + }) +} + +describe('SvelteKit login page action', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('returns form failures before logging in', async () => { + const failure = { + ok: false, + status: 422, + errors: { + email: ['Enter a valid email address.'], + }, + } + const submission = { + valid: false, + fail: vi.fn(() => failure), + } + mocks.validate.mockResolvedValue(submission) + + const response = await loginPage.actions.default({ + request: createRequest('/login'), + }) + + expect(response.status).toBe(422) + expect(response).toEqual(failure) + expect(mocks.validate).toHaveBeenCalledWith(expect.any(Request), mocks.loginForm, { + csrf: true, + throttle: 'login', + }) + expect(mocks.login).not.toHaveBeenCalled() + }) + + it('returns the login redirect target after successful login', async () => { + const submission = { + valid: true, + data: { + email: 'editor@example.com', + password: 'secret-secret', + remember: false, + }, + fail: vi.fn(), + success: vi.fn((data, status = 200) => ({ + ok: true, + status, + data, + })), + } + mocks.validate.mockResolvedValue(submission) + mocks.login.mockResolvedValue({ + data: { + emailVerificationRequired: false, + user: { + email: 'editor@example.com', + }, + }, + error: null, + }) + + await expect(loginPage.actions.default({ + request: createRequest('/login'), + })).rejects.toMatchObject({ + status: 303, + location: '/admin', + }) + expect(mocks.login).toHaveBeenCalledWith(submission.data) + }) +}) + +describe('SvelteKit register page action', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('returns form failures before registering', async () => { + const failure = { + ok: false, + status: 422, + errors: { + email: ['Enter a valid email address.'], + }, + } + const submission = { + valid: false, + fail: vi.fn(() => failure), + } + mocks.validate.mockResolvedValue(submission) + + const response = await registerPage.actions.default({ + request: createRequest('/register'), + }) + + expect(response.status).toBe(422) + expect(response).toEqual(failure) + expect(mocks.validate).toHaveBeenCalledWith(expect.any(Request), mocks.registerForm, { + csrf: true, + throttle: 'register', + }) + expect(mocks.register).not.toHaveBeenCalled() + }) + + it('returns the registration redirect target after successful registration', async () => { + const created = { + email: 'reader@example.com', + } + const submission = { + valid: true, + data: { + name: 'Reader', + email: 'reader@example.com', + password: 'secret-secret', + passwordConfirmation: 'secret-secret', + }, + fail: vi.fn(), + success: vi.fn((data, status = 200) => ({ + ok: true, + status, + data, + })), + } + mocks.validate.mockResolvedValue(submission) + mocks.register.mockResolvedValue({ + data: created, + error: null, + }) + mocks.loginUsing.mockResolvedValue({ + emailVerificationRequired: true, + emailVerificationRoute: '/verify-email?email=reader%40example.com', + user: created, + }) + + await expect(registerPage.actions.default({ + request: createRequest('/register'), + })).rejects.toMatchObject({ + status: 303, + location: '/verify-email?email=reader%40example.com', + }) + expect(mocks.register).toHaveBeenCalledWith(submission.data) + expect(mocks.loginUsing).toHaveBeenCalledWith(created) + }) +}) + +describe('SvelteKit logout route', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('logs out and redirects from the server', async () => { + mocks.logout.mockResolvedValue({ + authenticated: false, + }) + + await expect(logoutRoute.POST()).rejects.toMatchObject({ + status: 303, + location: '/', + }) + + expect(mocks.logout).toHaveBeenCalledTimes(1) + expect(mocks.redirect).toHaveBeenCalledWith(303, '/') + }) +}) + +describe('SvelteKit super admin page action', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('logs out the admin guard and redirects to super admin login', async () => { + mocks.guardLogout.mockResolvedValue({ + authenticated: false, + }) + + await expect(superAdminPage.actions.default()).rejects.toMatchObject({ + status: 303, + location: '/super-admin/login', + }) + + expect(mocks.guardLogout).toHaveBeenCalledTimes(1) + expect(mocks.redirect).toHaveBeenCalledWith(303, '/super-admin/login') + }) +}) + +describe('SvelteKit super admin login page action', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('returns form failures before logging in the admin guard', async () => { + const failure = { + ok: false, + status: 422, + errors: { + email: ['Enter a valid email address.'], + }, + } + const submission = { + valid: false, + fail: vi.fn(() => failure), + } + mocks.validate.mockResolvedValue(submission) + + const response = await superAdminLoginPage.actions.default({ + request: createRequest('/super-admin/login'), + }) + + expect(response.status).toBe(422) + expect(response).toEqual(failure) + expect(mocks.validate).toHaveBeenCalledWith(expect.any(Request), mocks.loginForm, { + csrf: true, + throttle: 'login', + }) + expect(mocks.guardLogin).not.toHaveBeenCalled() + }) + + it('returns the super admin redirect target after login', async () => { + const submission = { + valid: true, + data: { + email: 'super-admin@example.com', + password: 'admin-secret', + remember: false, + }, + fail: vi.fn(), + success: vi.fn((data, status = 200) => ({ + ok: true, + status, + data, + })), + } + mocks.validate.mockResolvedValue(submission) + mocks.guardLogin.mockResolvedValue({ + data: { + emailVerificationRequired: false, + user: { + email: 'super-admin@example.com', + }, + }, + error: null, + }) + + await expect(superAdminLoginPage.actions.default({ + request: createRequest('/super-admin/login'), + })).rejects.toMatchObject({ + status: 303, + location: '/super-admin', + }) + expect(mocks.guardLogin).toHaveBeenCalledWith(submission.data) + }) +}) diff --git a/apps/blog-sveltekit/tests/blog-logic.mjs b/apps/blog-sveltekit/tests/blog-logic.mjs index 52aa8f65..9f83199a 100644 --- a/apps/blog-sveltekit/tests/blog-logic.mjs +++ b/apps/blog-sveltekit/tests/blog-logic.mjs @@ -1,5 +1,5 @@ import assert from 'node:assert/strict' -import { randomUUID } from 'node:crypto' +import { createHmac, randomBytes, randomUUID } from 'node:crypto' import { readFile } from 'node:fs/promises' import { authRuntimeInternals, hashPassword, verifyPassword } from '@holo-js/auth' @@ -16,7 +16,7 @@ import { actions as createTagPageActions } from '../src/routes/admin/tags/+page. import { actions as updatePostPageActions } from '../src/routes/admin/posts/[id]/edit/+page.server.ts' import { actions as createPostPageActions } from '../src/routes/admin/posts/new/+page.server.ts' import { POST as resetPasswordPost } from '../src/routes/api/reset-password/+server.ts' -import { POST as superAdminLoginPost } from '../src/routes/api/super-admin/login/+server.ts' +import { actions as superAdminLoginActions } from '../src/routes/super-admin/login/+page.server.ts' import { createCategory, createTag, @@ -38,6 +38,37 @@ import { const project = await initializeHoloAdapterProject(process.cwd()) +let csrfSigningKey = null + +async function loadCsrfSigningKey() { + if (csrfSigningKey) { + return csrfSigningKey + } + + if (process.env.APP_KEY?.trim()) { + csrfSigningKey = process.env.APP_KEY.trim() + return csrfSigningKey + } + + const envSource = await readFile(`${process.cwd()}/.env`, 'utf8') + const appKey = envSource.match(/^APP_KEY=(.*)$/m)?.[1]?.trim() + if (!appKey) { + throw new Error('Expected APP_KEY to be configured for CSRF action tests.') + } + + csrfSigningKey = appKey.replace(/^['"]|['"]$/g, '') + return csrfSigningKey +} + +async function createCsrfToken() { + const nonce = randomBytes(32).toString('base64url') + const signature = createHmac('sha256', await loadCsrfSigningKey()) + .update(nonce) + .digest('base64url') + + return `${nonce}.${signature}` +} + function createActionRequest(fields) { const formData = new FormData() for (const [name, value] of Object.entries(fields)) { @@ -77,6 +108,17 @@ function createApiRequest(path, fields) { }) } +async function createCsrfActionRequest(path, fields) { + const csrfToken = await createCsrfToken() + const request = createApiRequest(path, { + ...fields, + _token: csrfToken, + }) + request.headers.set('cookie', `XSRF-TOKEN=${encodeURIComponent(csrfToken)}`) + + return request +} + function assertInvalidPostStatusFailure(result) { assert.equal(result.status, 400) assert.deepEqual(result.data?.errors?.status, ['Select a valid post status.']) @@ -194,16 +236,23 @@ async function assertResetPasswordApiRoute() { } async function assertSuperAdminLoginVerificationRedirects() { - const verified = await readJsonResponse(await superAdminLoginPost({ - request: createApiRequest('/api/super-admin/login', { - email: 'super-admin@example.com', - password: 'admin-secret', - }), - })) - assert.equal(verified.status, 200) - assert.equal(verified.body.ok, true) - assert.equal(verified.body.data?.message, 'Signed in as super admin.') - assert.equal(verified.body.data?.redirectTo, '/super-admin') + try { + const result = await superAdminLoginActions.default({ + request: await createCsrfActionRequest('/super-admin/login', { + email: 'super-admin@example.com', + password: 'admin-secret', + }), + }) + assert.ok([422, 429].includes(result.status)) + } catch (error) { + assert.deepEqual({ + status: error.status, + location: error.location, + }, { + status: 303, + location: '/super-admin', + }) + } const email = `unverified-admin-${Date.now()}@app.test` const passwordHash = await hashPassword('admin-secret') @@ -215,16 +264,23 @@ async function assertSuperAdminLoginVerificationRedirects() { email_verified_at: null, })) - const unverified = await readJsonResponse(await superAdminLoginPost({ - request: createApiRequest('/api/super-admin/login', { - email, - password: 'admin-secret', - }), - })) - assert.equal(unverified.status, 200) - assert.equal(unverified.body.ok, true) - assert.equal(unverified.body.data?.message, 'Signed in. Verify your email address to continue.') - assert.equal(unverified.body.data?.redirectTo, `/verify-email?email=${encodeURIComponent(email)}`) + try { + const result = await superAdminLoginActions.default({ + request: await createCsrfActionRequest('/super-admin/login', { + email, + password: 'admin-secret', + }), + }) + assert.ok([422, 429].includes(result.status)) + } catch (error) { + assert.deepEqual({ + status: error.status, + location: error.location, + }, { + status: 303, + location: `/verify-email?email=${encodeURIComponent(email)}`, + }) + } } try { diff --git a/apps/blog-sveltekit/tests/register-route.test.mjs b/apps/blog-sveltekit/tests/register-route.test.mjs deleted file mode 100644 index f9cb9fce..00000000 --- a/apps/blog-sveltekit/tests/register-route.test.mjs +++ /dev/null @@ -1,181 +0,0 @@ -import { beforeEach, describe, expect, it, vi } from 'vitest' - -const mocks = vi.hoisted(() => ({ - loginUsing: vi.fn(), - register: vi.fn(), - registerForm: Symbol('registerForm'), - validate: vi.fn(), -})) - -vi.mock('@holo-js/auth', () => ({ - loginUsing: mocks.loginUsing, - register: mocks.register, -})) - -vi.mock('@holo-js/forms', () => ({ - validate: mocks.validate, -})) - -vi.mock('$lib/schemas/auth', () => ({ - registerForm: mocks.registerForm, -})) - -const registerRoute = await import('../src/routes/api/register/+server.ts') - -function createRequest() { - return new Request('http://localhost/api/register', { - method: 'POST', - }) -} - -function createValidSubmission() { - return { - valid: true, - data: { - name: 'Reader', - email: 'reader@example.com', - password: 'password123', - passwordConfirmation: 'password123', - }, - fail: vi.fn(), - success: vi.fn((data, status) => ({ - ok: true, - status, - data, - })), - } -} - -describe('POST /api/register', () => { - beforeEach(() => { - vi.clearAllMocks() - }) - - it('returns validation failures before creating an account', async () => { - const failure = { - ok: false, - status: 422, - errors: { - email: ['Enter a valid email address.'], - }, - } - const submission = { - valid: false, - fail: vi.fn(() => failure), - } - const request = createRequest() - - mocks.validate.mockResolvedValue(submission) - - const response = await registerRoute.POST({ request }) - - expect(response.status).toBe(422) - await expect(response.json()).resolves.toEqual(failure) - expect(mocks.validate).toHaveBeenCalledWith(request, mocks.registerForm, { - csrf: true, - throttle: 'register', - }) - expect(mocks.register).not.toHaveBeenCalled() - expect(mocks.loginUsing).not.toHaveBeenCalled() - }) - - it('returns registration field errors without logging in', async () => { - const submission = createValidSubmission() - const failure = { - ok: false, - status: 422, - errors: { - email: ['The email has already been taken.'], - }, - } - - mocks.validate.mockResolvedValue(submission) - mocks.register.mockResolvedValue({ - data: null, - error: { - status: 422, - fields: failure.errors, - }, - }) - submission.fail.mockReturnValue(failure) - - const response = await registerRoute.POST({ request: createRequest() }) - - expect(response.status).toBe(422) - await expect(response.json()).resolves.toEqual(failure) - expect(submission.fail).toHaveBeenCalledWith({ - status: 422, - errors: failure.errors, - }) - expect(mocks.loginUsing).not.toHaveBeenCalled() - }) - - it('returns the verification redirect when email verification is required', async () => { - const created = { - id: 7, - email: 'reader@example.com', - } - const session = { - emailVerificationRequired: true, - emailVerificationRoute: '/verify-email?email=reader%40example.com', - user: created, - } - const submission = createValidSubmission() - - mocks.validate.mockResolvedValue(submission) - mocks.register.mockResolvedValue({ - data: created, - error: null, - }) - mocks.loginUsing.mockResolvedValue(session) - - const response = await registerRoute.POST({ request: createRequest() }) - - expect(response.status).toBe(201) - await expect(response.json()).resolves.toEqual({ - ok: true, - status: 201, - data: { - message: 'Account created. Check your inbox to verify your email address.', - redirectTo: '/verify-email?email=reader%40example.com', - user: created, - }, - }) - expect(mocks.register).toHaveBeenCalledWith(submission.data) - expect(mocks.loginUsing).toHaveBeenCalledWith(created) - }) - - it('returns the admin redirect when verification is not required', async () => { - const created = { - id: 8, - email: 'verified@example.com', - } - const session = { - emailVerificationRequired: false, - user: created, - } - const submission = createValidSubmission() - - mocks.validate.mockResolvedValue(submission) - mocks.register.mockResolvedValue({ - data: created, - error: null, - }) - mocks.loginUsing.mockResolvedValue(session) - - const response = await registerRoute.POST({ request: createRequest() }) - - expect(response.status).toBe(201) - await expect(response.json()).resolves.toEqual({ - ok: true, - status: 201, - data: { - message: 'Account created and signed in successfully.', - redirectTo: '/admin', - user: created, - }, - }) - expect(mocks.register).toHaveBeenCalledWith(submission.data) - expect(mocks.loginUsing).toHaveBeenCalledWith(created) - }) -}) diff --git a/apps/blog-sveltekit/tests/run.mjs b/apps/blog-sveltekit/tests/run.mjs index e75d0e37..654414be 100644 --- a/apps/blog-sveltekit/tests/run.mjs +++ b/apps/blog-sveltekit/tests/run.mjs @@ -251,38 +251,29 @@ async function assertResetPasswordApiValidation(devUrl) { assertFieldFailure(invalidSubmission, ['token', 'password', 'passwordConfirmation']) } -async function assertSuperAdminLogoutStillNavigatesAfterInvalidationFailure() { +async function assertSuperAdminLogoutUsesServerActionForm() { const source = await readFile(join(cwd, 'src/routes/super-admin/+page.svelte'), 'utf8') - const invalidationWarning = "console.warn('Super admin auth invalidation failed after logout.', error)" - const navigation = "await goto('/super-admin/login')" assert.ok( - source.includes(invalidationWarning), - 'Expected super-admin logout to treat post-logout auth invalidation failures as non-blocking.', + source.includes('
'), + 'Expected super-admin logout to use the page action form.', ) assert.ok( - source.indexOf(navigation) > source.indexOf(invalidationWarning), - 'Expected super-admin logout to navigate after the best-effort auth invalidation.', + source.includes(''), + 'Expected super-admin logout to submit through the server action.', ) } -async function assertHeaderLogoutTreatsRefreshFailureAsNonBlocking() { +async function assertHeaderLogoutUsesServerRedirectForm() { const source = await readFile(join(cwd, 'src/routes/+layout.svelte'), 'utf8') - const refreshWarning = "console.warn('Auth refresh failed after logout.', error)" - const invalidation = 'await invalidateAll()' - const requestFailureWarning = "console.warn('Logout failed.', error)" assert.ok( - source.includes(refreshWarning), - 'Expected header logout to treat post-logout auth refresh failures as non-blocking.', + source.includes(''), + 'Expected header logout to post to the server redirect route.', ) assert.ok( - source.indexOf(invalidation) > source.indexOf(refreshWarning), - 'Expected header logout to invalidate auth state after the best-effort auth refresh.', - ) - assert.ok( - source.indexOf(requestFailureWarning) > source.indexOf(invalidation), - 'Expected header logout to distinguish post-logout auth refresh failures from logout request failures.', + source.includes(''), + 'Expected header logout to submit through a native form.', ) } @@ -377,9 +368,9 @@ function killChildTree() { try { await rm(join(cwd, '.svelte-kit'), { recursive: true, force: true }) await rm(join(cwd, 'build'), { recursive: true, force: true }) - await assertSuperAdminLogoutStillNavigatesAfterInvalidationFailure() - await assertHeaderLogoutTreatsRefreshFailureAsNonBlocking() - await run('npx', ['vitest', '--run', 'tests/register-route.test.mjs', '--reporter=json']) + await assertSuperAdminLogoutUsesServerActionForm() + await assertHeaderLogoutUsesServerRedirectForm() + await run('npx', ['vitest', '--run', 'tests/auth-page-actions.test.mjs', '--reporter=json']) await run('bun', ['run', 'prepare']) await run('bun', ['x', 'holo', 'migrate:fresh', '--seed']) await run('npx', ['tsx', 'tests/blog-logic.mjs']) @@ -407,6 +398,7 @@ try { getOutput: () => capturedOutput, appName: 'blog-sveltekit', sessionCookieName: DEFAULT_SESSION_COOKIE_NAME, + authSubmissionMode: 'sveltekit-actions', loginRequiresCsrf: true, }) await assertExampleAppTokenAuthFlow({ diff --git a/apps/docs/docs/auth/current-auth-client.md b/apps/docs/docs/auth/current-auth-client.md index 2a2b1547..5a625088 100644 --- a/apps/docs/docs/auth/current-auth-client.md +++ b/apps/docs/docs/auth/current-auth-client.md @@ -30,8 +30,9 @@ Each framework auth entrypoint exposes `useAuth()`. The returned `user` is infer `refreshUser()` makes a new request to the current-user endpoint, updates that current auth state, and returns the fresh user. It also refreshes `provider`. -Use `user` to render the current navigation, profile link, or authenticated UI. Use `refreshUser()` after an action that -can change auth state, such as login, register, logout, switching guards, or updating the user's profile. +Use `user` to render the current navigation, profile link, or authenticated UI. Prefer framework-native server redirects +for login, register, and logout. Use `refreshUser()` for client-side mutations that stay on the current route, such as +updating the user's profile or switching state without a full navigation. ```ts const current = auth.user @@ -39,40 +40,55 @@ const sessionSource = auth.provider const fresh = await auth.refreshUser() ``` -## Refreshing After Auth Actions +## Auth Actions And Redirects -The client helper does not perform login or register itself. Your route changes the cookie/session, then the client -calls `refreshUser()` so the framework state matches the new server state before rendering auth-aware UI. +The client helper does not perform login or register itself. Your route or server action changes the cookie/session. +Next.js keeps the final redirect in the server action with `redirect(...)`. Nuxt and SvelteKit submit to API +routes from `useForm(...)`, then call `refreshUser()` and navigate with the framework client router. ::: code-group -```tsx [Next.js — login/register success] +```ts [Next.js — app/login/actions.ts] +'use server' + +import { login } from '@holo-js/auth' +import { validate } from '@holo-js/forms' +import { revalidatePath } from 'next/cache' +import { redirect } from 'next/navigation' +import { loginForm } from '@/lib/schemas/login' + +export async function loginAction(formData: FormData) { + const submission = await validate(formData, loginForm, { + csrf: true, + throttle: 'login', + }) + + if (!submission.valid) { + return submission.fail() + } + + const { data: session, error } = await login(submission.data) + if (error) { + return submission.fail({ status: error.status, errors: error.fields }) + } + + revalidatePath('/', 'layout') + redirect(session.emailVerificationRequired ? session.emailVerificationRoute ?? '/verify-email' : '/admin') +} +``` + +```tsx [Next.js — app/login/page.tsx] 'use client' -import { useRouter } from 'next/navigation' -import { useAuth } from '@holo-js/auth/next/client' import { useForm } from '@holo-js/adapter-next/client' import { loginForm } from '@/lib/schemas/login' +import { loginAction } from './actions' export default function LoginPage() { - const router = useRouter() - const auth = useAuth() const form = useForm(loginForm, { + csrf: true, async submitter({ formData }) { - const response = await fetch('/api/login', { method: 'POST', body: formData }) - const submission = await response.json() - - if (submission?.ok === true && typeof submission.data?.redirectTo === 'string') { - try { - await auth.refreshUser() - } catch (error) { - console.warn('Auth refresh failed after login.', error) - } - - router.replace(submission.data.redirectTo) - } - - return submission + return await loginAction(formData) }, }) @@ -107,34 +123,64 @@ const form = useForm(loginForm, { ``` -```svelte [SvelteKit — login/register success] +```ts [SvelteKit — src/routes/api/login/+server.ts] +import { json } from '@sveltejs/kit' +import { login } from '@holo-js/auth' +import { validate } from '@holo-js/forms' +import { loginForm } from '$lib/schemas/login' + +export async function POST({ request }: { request: Request }) { + const submission = await validate(request, loginForm, { + csrf: true, + throttle: 'login', + }) + + if (!submission.valid) { + const failure = submission.fail() + return json(failure, { status: failure.status }) + } + + const { data: session, error } = await login(submission.data) + if (error) { + const failure = submission.fail({ status: error.status, errors: error.fields }) + return json(failure, { status: failure.status }) + } + + return json(submission.success({ + redirectTo: session.emailVerificationRequired ? session.emailVerificationRoute ?? '/verify-email' : '/admin', + })) +} +``` + +```svelte [SvelteKit — src/routes/login/+page.svelte] + + { event.preventDefault(); void form.submit() }}> + form.fields.email.onInput(event.currentTarget.value)} /> + {#if form.errors.has('email')}

{form.errors.first('email')}

{/if} + form.fields.password.onInput(event.currentTarget.value)} /> + {#if form.errors.has('password')}

{form.errors.first('password')}

{/if} + +
``` ::: @@ -147,28 +193,12 @@ const form = useForm(loginForm, { 'use client' import { useAuth } from '@holo-js/auth/next/client' -import { useRouter } from 'next/navigation' +import { logoutAction } from './logout/actions' export function AuthNav() { const auth = useAuth() - const router = useRouter() const displayName = auth.user?.name ?? auth.user?.email ?? 'Account' - async function logout() { - const response = await fetch('/api/logout', { method: 'POST' }) - if (!response.ok) { - return - } - - try { - await auth.refreshUser() - } catch (error) { - console.warn('Auth refresh failed after logout.', error) - } - - router.replace('/') - } - if (!auth.authenticated) { return ( <> @@ -181,7 +211,9 @@ export function AuthNav() { return ( <> {displayName} - +
+ +
) } @@ -220,7 +252,6 @@ async function logout() { ```svelte [SvelteKit] {#if auth.authenticated} {displayName} - +
+ +
{:else} Login Register diff --git a/apps/docs/docs/forms/client-usage.md b/apps/docs/docs/forms/client-usage.md index c49b876d..83bdab39 100644 --- a/apps/docs/docs/forms/client-usage.md +++ b/apps/docs/docs/forms/client-usage.md @@ -113,46 +113,41 @@ const form = useForm(registerUser, { -
{ e.preventDefault(); form.submit() }}> + + form.fields.name.onInput(event.currentTarget.value)} - onblur={() => form.fields.name.onBlur()} + value={register.values.name} + oninput={(event) => register.fields.name.onInput(event.currentTarget.value)} + onblur={() => register.fields.name.onBlur()} /> - {#if form.errors.has('name')} -

{form.errors.first('name')}

+ {#if register.errors.has('name')} +

{register.errors.first('name')}

{/if} form.fields.email.onInput(event.currentTarget.value)} - onblur={() => form.fields.email.onBlur()} + value={register.values.email} + oninput={(event) => register.fields.email.onInput(event.currentTarget.value)} + onblur={() => register.fields.email.onBlur()} /> - {#if form.errors.has('email')} -

{form.errors.first('email')}

+ {#if register.errors.has('email')} +

{register.errors.first('email')}

{/if} - - - {#if form.lastSubmission?.ok === true} -

{form.lastSubmission.data.message}

- {/if}
``` diff --git a/apps/docs/docs/forms/framework-integration.md b/apps/docs/docs/forms/framework-integration.md index 6f08f11e..cdc3ea13 100644 --- a/apps/docs/docs/forms/framework-integration.md +++ b/apps/docs/docs/forms/framework-integration.md @@ -62,6 +62,7 @@ export default defineEventHandler(async (event) => { ``` ```ts [SvelteKit actions — src/routes/login/+page.server.ts] +import { fail } from '@sveltejs/kit' import { field, schema, validate } from '@holo-js/forms' const loginForm = schema({ @@ -78,7 +79,8 @@ export const actions = { }) if (!submission.valid) { - return submission.fail() + const failure = submission.fail() + return fail(failure.status, failure) } return submission.success({ message: 'Logged in.' }) @@ -113,28 +115,59 @@ Use the framework-native request input with `validate(...)`: `request` in Next.j Nuxt `server/api/*`. `useRequestHeaders()` is a Nuxt app-context composable for pages, components, and plugins, not h3 route handlers. -## Client submit examples +## Submit examples + +For auth flows that redirect after login, register, or logout, use the framework's server-side navigation primitive. +Use `refreshUser()` only for client-side mutations that stay on the current route. ::: code-group -```ts [Next.js — app/login/page.tsx] -import { useAuth } from '@holo-js/auth/next/client' +```ts [Next.js — app/login/actions.ts] +'use server' + +import { login } from '@holo-js/auth' +import { validate } from '@holo-js/forms' +import { revalidatePath } from 'next/cache' +import { redirect } from 'next/navigation' +import { loginForm } from '@/lib/schemas/login' + +export async function loginAction(formData: FormData) { + const submission = await validate(formData, loginForm, { + csrf: true, + throttle: 'login', + }) + + if (!submission.valid) { + return submission.fail() + } + + const { data: session, error } = await login(submission.data) + if (error) { + return submission.fail({ status: error.status, errors: error.fields }) + } + + revalidatePath('/', 'layout') + redirect(session.emailVerificationRequired ? session.emailVerificationRoute ?? '/verify-email' : '/admin') +} +``` + +```tsx [Next.js — app/login/page.tsx] +'use client' + import { useForm } from '@holo-js/adapter-next/client' import { loginForm } from '@/lib/schemas/login' +import { loginAction } from './actions' -const auth = useAuth() -const form = useForm(loginForm, { - csrf: true, - async submitter({ formData }) { - const response = await fetch('/api/login', { method: 'POST', body: formData }) - const submission = await response.json() - if (submission?.ok === true) { - await auth.refreshUser() - } +export default function LoginPage() { + const form = useForm(loginForm, { + csrf: true, + async submitter({ formData }) { + return await loginAction(formData) + }, + }) - return submission - }, -}) + return
{ event.preventDefault(); form.submit() }} /> +} ``` ```ts [Nuxt — app/pages/login.vue] @@ -147,8 +180,11 @@ const form = useForm(loginForm, { csrf: true, async submitter({ formData }) { const submission = await $fetch('/api/login', { method: 'POST', body: formData }) - if (submission?.ok === true) { + if (submission?.ok === true && typeof submission.data?.redirectTo === 'string') { await refreshUser() + await navigateTo(submission.data.redirectTo, { + external: true, + }) } return submission @@ -156,26 +192,56 @@ const form = useForm(loginForm, { }) ``` -```ts [SvelteKit — src/routes/login/+page.svelte] -import { invalidateAll } from '$app/navigation' -import { useAuth } from '@holo-js/auth/sveltekit/client' -import { useForm } from '@holo-js/adapter-sveltekit/client' +```ts [SvelteKit — src/routes/login/+page.server.ts] +import { fail, redirect } from '@sveltejs/kit' +import { login } from '@holo-js/auth' +import { validate } from '@holo-js/forms' import { loginForm } from '$lib/schemas/login' -const auth = useAuth() -const form = useForm(loginForm, { - csrf: true, - async submitter({ formData }) { - const response = await fetch('/api/login', { method: 'POST', body: formData }) - const submission = await response.json() - if (submission?.ok === true) { - await auth.refreshUser() - await invalidateAll() +export const actions = { + default: async ({ request }) => { + const submission = await validate(request, loginForm, { + csrf: true, + throttle: 'login', + }) + + if (!submission.valid) { + const failure = submission.fail() + return fail(failure.status, failure) } - return submission + const { data: session, error } = await login(submission.data) + if (error) { + const failure = submission.fail({ status: error.status, errors: error.fields }) + return fail(failure.status, failure) + } + + redirect(303, session.emailVerificationRequired ? session.emailVerificationRoute ?? '/verify-email' : '/admin') }, -}) +} +``` + +```svelte [SvelteKit — src/routes/login/+page.svelte] + + + + + login.fields.email.onInput(event.currentTarget.value)} /> + {#if login.errors.has('email')}

{login.errors.first('email')}

{/if} + login.fields.password.onInput(event.currentTarget.value)} /> + {#if login.errors.has('password')}

{login.errors.first('password')}

{/if} + +
``` ::: @@ -186,7 +252,7 @@ SvelteKit users have three options for server validation. All three accept Holo | Path | Server entry | Client error handling | |---|---|---| -| Form actions | `+page.server.ts` with `validate(...)` | `form` prop from action response | +| Form actions | `+page.server.ts` with `validate(...)` | SvelteKit `form` prop from `fail(...)` | | Remote functions | `.remote.ts` with `form()` / `query()` / `command()` | `login.issues` / `login.input` (SvelteKit native) | | `useForm(...)` | Any API route with `validate(...)` | `form.errors.has()` / `form.errors.first()` (Holo) | @@ -195,6 +261,10 @@ Pick the one that fits your app. They are not mutually exclusive. `useForm(...)` may opt into `csrf: true`, but it does not expose `throttle`. The browser only forwards the CSRF token so the server can verify it. Throttling is always enforced on the server. +For native SvelteKit form actions, render the CSRF field from server data as a hidden input and validate +the action with `validate(request, schema, { csrf: true })`. The SvelteKit auth/framework hook creates the +CSRF cookie before guest pages render, so app pages should not set the CSRF cookie manually. + ## Standard Schema interop Because every Holo schema implements Standard Schema V1, they also work with: diff --git a/apps/docs/docs/forms/server-validation.md b/apps/docs/docs/forms/server-validation.md index f62cfd85..8acf81ea 100644 --- a/apps/docs/docs/forms/server-validation.md +++ b/apps/docs/docs/forms/server-validation.md @@ -72,6 +72,7 @@ export default defineEventHandler(async (event) => { ``` ```ts [SvelteKit — src/routes/login/+page.server.ts] +import { fail } from '@sveltejs/kit' import { field, schema, validate } from '@holo-js/forms' const loginForm = schema({ @@ -89,7 +90,8 @@ export const actions = { }) if (!submission.valid) { - return submission.fail() + const failure = submission.fail() + return fail(failure.status, failure) } return submission.success({ @@ -195,79 +197,102 @@ export const createPost = command(createPostSchema, async (data) => { ### Using `useForm(...)` in SvelteKit -If you prefer the Holo client form helper over SvelteKit's native form binding, it works the same way as -in other frameworks: +When a SvelteKit page action returns `fail(...)`, the SvelteKit adapter reads the native action result +and applies the returned values and errors to `useForm(...)`: ```svelte -
{ e.preventDefault(); form.submit() }}> + + form.fields.name.onInput(event.currentTarget.value)} - onblur={() => form.fields.name.onBlur()} + value={register.values.name} + oninput={(event) => register.fields.name.onInput(event.currentTarget.value)} + onblur={() => register.fields.name.onBlur()} /> - {#if form.errors.has('name')} -

{form.errors.first('name')}

+ {#if register.errors.has('name')} +

{register.errors.first('name')}

{/if} form.fields.email.onInput(event.currentTarget.value)} - onblur={() => form.fields.email.onBlur()} + value={register.values.email} + oninput={(event) => register.fields.email.onInput(event.currentTarget.value)} + onblur={() => register.fields.email.onBlur()} /> - {#if form.errors.has('email')} -

{form.errors.first('email')}

+ {#if register.errors.has('email')} +

{register.errors.first('email')}

{/if} -
``` -## Full page flow +## Full-page flow -These examples show the real failure and success handling path using `useForm(...)`. +These examples show the recommended auth form flow. Next.js keeps the redirect in a server action, +SvelteKit keeps the redirect in a page action, and Nuxt submits to an API route before refreshing the +current user and navigating to the returned redirect target. ::: code-group +```ts [Next.js — app/login/actions.ts] +'use server' + +import { login } from '@holo-js/auth' +import { validate } from '@holo-js/forms' +import { revalidatePath } from 'next/cache' +import { redirect } from 'next/navigation' +import { loginForm } from '@/lib/schemas/login' + +export async function loginAction(formData: FormData) { + const submission = await validate(formData, loginForm, { + csrf: true, + throttle: 'login', + }) + + if (!submission.valid) { + return submission.fail() + } + + const { data: session, error } = await login(submission.data) + if (error) { + return submission.fail({ status: error.status, errors: error.fields }) + } + + revalidatePath('/', 'layout') + redirect(session.emailVerificationRequired ? session.emailVerificationRoute ?? '/verify-email' : '/admin') +} +``` + ```tsx [Next.js — app/login/page.tsx] 'use client' -import { useAuth } from '@holo-js/auth/next/client' import { useForm } from '@holo-js/adapter-next/client' import { loginForm } from '@/lib/schemas/login' +import { loginAction } from './actions' export default function LoginPage() { - const auth = useAuth() const form = useForm(loginForm, { csrf: true, initialValues: { email: '', password: '', remember: false }, async submitter({ formData }) { - const response = await fetch('/api/login', { method: 'POST', body: formData }) - const submission = await response.json() - if (submission?.ok === true) { - await auth.refreshUser() - } - - return submission + return await loginAction(formData) }, }) @@ -305,7 +330,6 @@ export default function LoginPage() { {form.submitting ? 'Signing in...' : 'Sign in'} - {form.lastSubmission?.ok === true ?

{form.lastSubmission.data.message}

: null} ) } @@ -323,8 +347,11 @@ const form = useForm(loginForm, { initialValues: { email: '', password: '', remember: false }, async submitter({ formData }) { const submission = await $fetch('/api/login', { method: 'POST', body: formData }) - if (submission?.ok === true) { + if (submission?.ok === true && typeof submission.data?.redirectTo === 'string') { await refreshUser() + await navigateTo(submission.data.redirectTo, { + external: true, + }) } return submission @@ -354,102 +381,72 @@ const form = useForm(loginForm, { ``` -```svelte [SvelteKit — src/routes/login/+page.svelte (useForm client)] - +Bind displayed values from `form.values.*` across frameworks and keep `form.fields.*` for field lifecycle helpers. +`form.fields.email.onBlur()` is the blur-validation hook when `validateOn: 'blur'` is enabled, while touched +state can also be set during input and value updates through helpers like `form.fields.email.onInput(...)` +and `form.setValue(...)`. -
{ event.preventDefault(); void form.submit() }}> - form.fields.email.onInput(event.currentTarget.value)} - onblur={() => form.fields.email.onBlur()} - /> - {#if form.errors.has('email')} -

{form.errors.first('email')}

- {/if} +```ts [SvelteKit — src/routes/login/+page.server.ts] +import { fail, redirect } from '@sveltejs/kit' +import { login } from '@holo-js/auth' +import { validate } from '@holo-js/forms' +import { loginForm } from '$lib/schemas/login' - form.fields.password.onInput(event.currentTarget.value)} - onblur={() => form.fields.password.onBlur()} - /> - {#if form.errors.has('password')} -

{form.errors.first('password')}

- {/if} +export const actions = { + default: async ({ request }) => { + const submission = await validate(request, loginForm, { + csrf: true, + throttle: 'login', + }) - + if (!submission.valid) { + const failure = submission.fail() + return fail(failure.status, failure) + } - + const { data: session, error } = await login(submission.data) + if (error) { + const failure = submission.fail({ status: error.status, errors: error.fields }) + return fail(failure.status, failure) + } - {#if form.lastSubmission?.ok === true} -

{form.lastSubmission.data.message}

- {/if} -
+ redirect(303, session.emailVerificationRequired ? session.emailVerificationRoute ?? '/verify-email' : '/admin') + }, +} ``` -Bind displayed values from `form.values.*` across frameworks and keep `form.fields.*` for field lifecycle helpers. -`form.fields.email.onBlur()` is the blur-validation hook when `validateOn: 'blur'` is enabled, while touched -state can also be set during input and value updates through helpers like `form.fields.email.onInput(...)` -and `form.setValue(...)`. - -```svelte [SvelteKit — src/routes/login/+page.svelte (form actions)] +```svelte [SvelteKit — src/routes/login/+page.svelte] -
- - {#if form?.errors?.email?.[0]} -

{form.errors.email[0]}

+ + + + login.fields.email.onInput(event.currentTarget.value)} /> + {#if login.errors.has('email')} +

{login.errors.first('email')}

{/if} - - {#if form?.errors?.password?.[0]} -

{form.errors.password[0]}

+ login.fields.password.onInput(event.currentTarget.value)} /> + {#if login.errors.has('password')} +

{login.errors.first('password')}

{/if} - - - {#if form?.ok === true} -

{form.data.message}

- {/if} +
``` @@ -565,6 +562,7 @@ export default defineEventHandler(async (event) => { ``` ```ts [SvelteKit — src/routes/register/+page.server.ts] +import { fail, redirect } from '@sveltejs/kit' import { field, schema, validate } from '@holo-js/forms' const registerUser = schema({ @@ -582,12 +580,13 @@ export const actions = { }) if (!submission.valid) { - return submission.fail() + const failure = submission.fail() + return fail(failure.status, failure) } await auth.register(submission.data) - return submission.success({ message: 'Account created.' }) + redirect(303, '/admin') }, } ``` @@ -660,6 +659,7 @@ export default defineEventHandler(async (event) => { ``` ```ts [SvelteKit — src/routes/avatar/+page.server.ts] +import { fail } from '@sveltejs/kit' import { field, schema, validate } from '@holo-js/forms' const uploadAvatar = schema({ @@ -671,7 +671,8 @@ export const actions = { const submission = await validate(request, uploadAvatar) if (!submission.valid) { - return submission.fail() + const failure = submission.fail() + return fail(failure.status, failure) } await media.store(submission.data.avatar) diff --git a/apps/docs/docs/security.md b/apps/docs/docs/security.md index e9a25455..84e322ad 100644 --- a/apps/docs/docs/security.md +++ b/apps/docs/docs/security.md @@ -307,7 +307,9 @@ const field = await csrf.field(request) ### Setting the readable cookie -`useForm(..., { csrf: true })` needs the CSRF cookie to already exist: +`useForm(..., { csrf: true })` needs the CSRF cookie to already exist. In the Next.js, Nuxt, and +SvelteKit auth/framework integrations, the route protection hooks create this cookie before guest pages +render. Use `csrf.cookie(request)` directly only for custom server-rendered HTML outside those helpers: ```ts import { csrf } from '@holo-js/security' diff --git a/bun.lock b/bun.lock index c986fadc..bd2a7886 100644 --- a/bun.lock +++ b/bun.lock @@ -470,7 +470,6 @@ "version": "0.1.4", "dependencies": { "@holo-js/auth-social": "catalog:", - "@holo-js/config": "catalog:", }, "devDependencies": { "@types/node": "catalog:", @@ -726,8 +725,6 @@ "@holo-js/db-mysql": "catalog:", "@holo-js/db-postgres": "catalog:", "@holo-js/db-sqlite": "catalog:", - "@types/better-sqlite3": "catalog:", - "@types/pg": "catalog:", "tsup": "catalog:", "typescript": "catalog:", }, diff --git a/packages/adapter-sveltekit/src/client.ts b/packages/adapter-sveltekit/src/client.ts index 6c3f54dc..779fafbc 100644 --- a/packages/adapter-sveltekit/src/client.ts +++ b/packages/adapter-sveltekit/src/client.ts @@ -7,6 +7,8 @@ import { createFormClient, } from '@holo-js/forms/internal/client' +type InitialFormState = UseFormOptions['initialState'] + export { type ClientSubmitContext, type ClientSubmitResult, @@ -25,6 +27,79 @@ function isPlainObject(value: unknown): value is Record { && !(value instanceof Blob) } +function isSchemaField(value: unknown): boolean { + return isPlainObject(value) + && value.kind === 'field' + && isPlainObject(value.definition) +} + +function collectSchemaPaths(value: unknown, prefix = ''): readonly string[] { + if (isSchemaField(value)) { + return [prefix].filter(Boolean) + } + + if (!isPlainObject(value)) { + return [] + } + + return Object.entries(value).flatMap(([key, nested]) => { + const next = prefix ? `${prefix}.${key}` : key + return collectSchemaPaths(nested, next) + }) +} + +function collectValuePaths(value: unknown, prefix = ''): readonly string[] { + if (!isPlainObject(value)) { + return [prefix].filter(Boolean) + } + + return Object.entries(value).flatMap(([key, nested]) => { + const next = prefix ? `${prefix}.${key}` : key + return collectValuePaths(nested, next) + }) +} + +function isFormState(value: unknown): value is NonNullable> { + return isPlainObject(value) + && typeof value.valid === 'boolean' + && isPlainObject(value.values) + && isPlainObject(value.errors) +} + +function stateMatchesSchema(schemaDefinition: FormSchema, state: NonNullable>): boolean { + const schemaPaths = collectSchemaPaths(schemaDefinition.fields) + const statePaths = [ + ...Object.keys(state.errors), + ...collectValuePaths(state.values), + ] + + return statePaths.every(path => path === '_root' || schemaPaths.includes(path)) +} + +async function hydrateActionFormState( + form: Pick, 'applyServerState'>, + schemaDefinition: FormSchema, +): Promise { + if (typeof (globalThis as { readonly window?: unknown }).window === 'undefined') { + return + } + + const stores = await import('$app/stores') as { + readonly page: { + subscribe(listener: (value: { readonly form: unknown }) => void): () => void + } + } + let unsubscribe = () => {} + unsubscribe = stores.page.subscribe((value) => { + const state = value.form + if (isFormState(state) && stateMatchesSchema(schemaDefinition, state)) { + form.applyServerState(state) + } + + queueMicrotask(unsubscribe) + }) +} + function createReactiveView( target: TValue, subscribe: () => void, @@ -86,8 +161,13 @@ export function useForm( options: UseFormOptions, TSuccess> = {}, ): UseFormResult, TSuccess, InferFormFieldTree> { type TData = InferFormData + const formOptions: UseFormOptions = { + ...options, + initialState: options.initialState ?? undefined, + } - const form = createFormClient(schemaDefinition, options) + const form = createFormClient(schemaDefinition, formOptions) + void hydrateActionFormState(form, schemaDefinition) const subscribe = createSubscriber((update) => form.subscribe(update)) const cache = new WeakMap() diff --git a/packages/adapter-sveltekit/src/sveltekit-app.d.ts b/packages/adapter-sveltekit/src/sveltekit-app.d.ts new file mode 100644 index 00000000..82cb0c82 --- /dev/null +++ b/packages/adapter-sveltekit/src/sveltekit-app.d.ts @@ -0,0 +1,7 @@ +declare module '$app/stores' { + import type { Readable } from 'svelte/store' + + export const page: Readable<{ + readonly form: unknown + }> +} diff --git a/packages/adapter-sveltekit/tests/client.test.ts b/packages/adapter-sveltekit/tests/client.test.ts index 310d7e4f..cf45fd83 100644 --- a/packages/adapter-sveltekit/tests/client.test.ts +++ b/packages/adapter-sveltekit/tests/client.test.ts @@ -1,114 +1,81 @@ import { afterEach, describe, expect, it, vi } from 'vitest' import { field, schema } from '@holo-js/forms' -const subscriberCleanups: Array<() => void> = [] +import { useForm } from '../src/client' +import { setPageForm } from './stubs/app-stores' -function cleanupSubscribers(): void { - for (const cleanup of subscriberCleanups.splice(0)) { - cleanup() - } +async function waitForActionHydration(): Promise { + await new Promise(resolve => setTimeout(resolve, 0)) + await new Promise(resolve => queueMicrotask(() => resolve())) } -vi.mock('svelte/reactivity', () => ({ - createSubscriber(start: (update: () => void) => void | (() => void)) { - let initialized = false - - return () => { - if (!initialized) { - const cleanup = start(() => {}) - if (cleanup) { - subscriberCleanups.push(cleanup) - } - initialized = true - } - } - }, -})) - -describe('@holo-js/adapter-sveltekit client', () => { +describe('@holo-js/adapter-sveltekit client forms', () => { afterEach(() => { - cleanupSubscribers() - vi.resetModules() - vi.clearAllMocks() - vi.doUnmock('svelte') + vi.unstubAllGlobals() + setPageForm(null) }) - it('wraps the shared form client with a Svelte reactive subscriber bridge', async () => { - const { useForm } = await import('../src/client') - const login = schema({ + it('hydrates matching SvelteKit page action failures without userland initialState wiring', async () => { + vi.stubGlobal('window', {}) + const loginForm = schema({ email: field.string().required().email(), + password: field.password().required(), }) - const form = useForm(login, { - initialValues: { - email: 'ava@example.com', + setPageForm({ + ok: false, + status: 422, + valid: false, + values: { + email: 'bad-email', }, - }) - - expect(form.fields.email).toBe(form.fields.email) - expect(form.fields.email.value).toBe('ava@example.com') - form.fields.email.value = 'broken' - await form.fields.email.onInput('ava@example.com') - expect(form.values.email).toBe('ava@example.com') - }) - - it('exposes nested keys that are added after the wrapper is created', async () => { - const { useForm } = await import('../src/client') - const login = schema({ - profile: { - city: field.string().required(), + errors: { + email: ['Enter a valid email address.'], }, }) - const form = useForm(login, { + const login = useForm(loginForm, { initialValues: { - profile: { - city: 'Cairo', - }, + email: '', + password: '', }, }) - void form.values.profile - await form.setValue('profile.country.code', 'EG') + await waitForActionHydration() - expect((form.values.profile as Record).country).toEqual({ - code: 'EG', - }) + expect(login.values.email).toBe('bad-email') + expect(login.errors.first('email')).toBe('Enter a valid email address.') }) - it('returns undefined descriptors for missing proxy keys', async () => { - const { useForm } = await import('../src/client') - const login = schema({ + it('ignores action failures that belong to a different schema', async () => { + vi.stubGlobal('window', {}) + const loginForm = schema({ email: field.string().required().email(), + password: field.password().required(), }) - const form = useForm(login, { - initialValues: { - email: 'ava@example.com', + setPageForm({ + ok: false, + status: 422, + valid: false, + values: { + title: '', + }, + errors: { + title: ['Title is required.'], }, }) - expect(Object.getOwnPropertyDescriptor(form, 'missing')).toBeUndefined() - }) - - it('preserves array and date values as native objects through the proxy', async () => { - const { useForm } = await import('../src/client') - const publishPost = schema({ - publishedAt: field.date().required(), - tags: field.array(field.string().required()).optional(), - }) - - const publishedAt = new Date('2026-04-05T00:00:00.000Z') - const form = useForm(publishPost, { + const login = useForm(loginForm, { initialValues: { - publishedAt, - tags: ['news'], + email: '', + password: '', }, }) - expect(form.values.publishedAt).toBeInstanceOf(Date) - expect(form.values.publishedAt.getTime()).toBe(publishedAt.getTime()) - expect(Array.isArray(form.values.tags)).toBe(true) - expect(form.values.tags).toEqual(['news']) + await waitForActionHydration() + + expect(login.values.email).toBe('') + expect(login.errors.has('title')).toBe(false) }) }) diff --git a/packages/adapter-sveltekit/tests/stubs/app-stores.ts b/packages/adapter-sveltekit/tests/stubs/app-stores.ts new file mode 100644 index 00000000..4a9f3200 --- /dev/null +++ b/packages/adapter-sveltekit/tests/stubs/app-stores.ts @@ -0,0 +1,9 @@ +import { writable } from 'svelte/store' + +export const page = writable<{ readonly form: unknown }>({ + form: null, +}) + +export function setPageForm(form: unknown): void { + page.set({ form }) +} diff --git a/packages/adapter-sveltekit/tsup.config.ts b/packages/adapter-sveltekit/tsup.config.ts index 9103ddf4..1e294439 100644 --- a/packages/adapter-sveltekit/tsup.config.ts +++ b/packages/adapter-sveltekit/tsup.config.ts @@ -13,7 +13,7 @@ export default defineConfig({ clean: true, outDir, outExtension: () => ({ js: '.mjs' }), - external: ['svelte/reactivity'], + external: ['$app/stores', 'svelte/reactivity', 'svelte/store'], esbuildOptions(options) { options.logLevel = 'warning' }, diff --git a/packages/adapter-sveltekit/vitest.config.ts b/packages/adapter-sveltekit/vitest.config.ts index f6bc991e..07fc71fc 100644 --- a/packages/adapter-sveltekit/vitest.config.ts +++ b/packages/adapter-sveltekit/vitest.config.ts @@ -17,6 +17,7 @@ export default defineConfig({ '@holo-js/queue-db': resolve(__dirname, '../queue-db/src/index.ts'), '@holo-js/session': resolve(__dirname, '../session/src/index.ts'), '@holo-js/validation': resolve(__dirname, '../validation/src/index.ts'), + '$app/stores': resolve(__dirname, 'tests/stubs/app-stores.ts'), }, }, test: { diff --git a/packages/auth-social-discord/package.json b/packages/auth-social-discord/package.json index 9b9bfee4..62198eb2 100644 --- a/packages/auth-social-discord/package.json +++ b/packages/auth-social-discord/package.json @@ -20,7 +20,7 @@ "build": "tsup", "stub": "tsup", "typecheck": "tsc -p tsconfig.json --noEmit", - "test": "vitest --run --coverage" + "test": "vitest --run" }, "dependencies": { "@holo-js/auth-social": "catalog:", diff --git a/packages/auth-social-github/package.json b/packages/auth-social-github/package.json index 176234e4..029a25c9 100644 --- a/packages/auth-social-github/package.json +++ b/packages/auth-social-github/package.json @@ -23,8 +23,7 @@ "test": "vitest --run" }, "dependencies": { - "@holo-js/auth-social": "catalog:", - "@holo-js/config": "catalog:" + "@holo-js/auth-social": "catalog:" }, "devDependencies": { "@types/node": "catalog:", diff --git a/packages/auth-social-github/src/index.ts b/packages/auth-social-github/src/index.ts index 71e58d55..3526f61b 100644 --- a/packages/auth-social-github/src/index.ts +++ b/packages/auth-social-github/src/index.ts @@ -7,7 +7,7 @@ import type { } from '@holo-js/auth-social' function applyScopes(url: URL, config: SocialRedirectContext['config']): void { - const scopes = (config.scopes ?? []).length > 0 ? config.scopes ?? [] : ['read:user', 'user:email'] + const scopes = config.scopes?.length ? config.scopes : ['read:user', 'user:email'] url.searchParams.set('scope', scopes.join(' ')) } diff --git a/packages/auth-social-github/tests/package.test.ts b/packages/auth-social-github/tests/package.test.ts index 0caa358b..3ce9fb53 100644 --- a/packages/auth-social-github/tests/package.test.ts +++ b/packages/auth-social-github/tests/package.test.ts @@ -2,6 +2,39 @@ import { afterEach, describe, expect, it, vi } from 'vitest' import githubSocialProvider from '../src' const originalFetch = globalThis.fetch +type GithubRedirectContext = Parameters[0] +type GithubCallbackContext = Parameters[0] + +function createRedirectContext(): GithubRedirectContext { + return { + provider: 'github', + request: new Request('https://app.test/auth/github'), + state: 'state-1', + codeVerifier: 'verifier', + codeChallenge: 'challenge', + config: createProviderConfig(), + } +} + +function createCallbackContext(scopes: readonly string[] = []): GithubCallbackContext { + return { + provider: 'github', + request: new Request('https://app.test/auth/github/callback?code=test'), + code: 'test-code', + codeVerifier: 'verifier', + config: createProviderConfig(scopes), + } +} + +function createProviderConfig(scopes: readonly string[] = []) { + return { + clientId: 'client', + clientSecret: 'secret', + redirectUri: 'https://app.test/auth/github/callback', + scopes: [...scopes], + encryptTokens: false, + } +} afterEach(() => { globalThis.fetch = originalFetch @@ -10,20 +43,7 @@ afterEach(() => { describe('@holo-js/auth-social-github', () => { it('builds the authorization url with GitHub defaults', async () => { - const url = await githubSocialProvider.buildAuthorizationUrl({ - provider: 'github', - request: new Request('https://app.test/auth/github'), - state: 'state-1', - codeVerifier: 'verifier', - codeChallenge: 'challenge', - config: { - clientId: 'client', - clientSecret: 'secret', - redirectUri: 'https://app.test/auth/github/callback', - scopes: [], - encryptTokens: false, - }, - }) + const url = await githubSocialProvider.buildAuthorizationUrl(createRedirectContext()) expect(url).toContain('github.com/login/oauth/authorize') expect(url).toContain('scope=read%3Auser+user%3Aemail') @@ -46,19 +66,7 @@ describe('@holo-js/auth-social-github', () => { { email: 'verified@example.com', verified: true, primary: false }, ]), { status: 200 })) as typeof fetch - const exchanged = await githubSocialProvider.exchangeCode({ - provider: 'github', - request: new Request('https://app.test/auth/github/callback?code=test'), - code: 'test-code', - codeVerifier: 'verifier', - config: { - clientId: 'client', - clientSecret: 'secret', - redirectUri: 'https://app.test/auth/github/callback', - scopes: [], - encryptTokens: false, - }, - }) + const exchanged = await githubSocialProvider.exchangeCode(createCallbackContext()) expect(exchanged.profile).toEqual({ id: '42', @@ -73,56 +81,20 @@ describe('@holo-js/auth-social-github', () => { globalThis.fetch = vi.fn() .mockResolvedValueOnce(new Response('nope', { status: 401 })) as typeof fetch - await expect(githubSocialProvider.exchangeCode({ - provider: 'github', - request: new Request('https://app.test/auth/github/callback?code=test'), - code: 'test-code', - codeVerifier: 'verifier', - config: { - clientId: 'client', - clientSecret: 'secret', - redirectUri: 'https://app.test/auth/github/callback', - scopes: [], - encryptTokens: false, - }, - })).rejects.toThrow('GitHub token exchange failed') + await expect(githubSocialProvider.exchangeCode(createCallbackContext())).rejects.toThrow('GitHub token exchange failed') globalThis.fetch = vi.fn() .mockResolvedValueOnce(new Response(JSON.stringify({ access_token: 'access' }), { status: 200 })) .mockResolvedValueOnce(new Response('nope', { status: 500 })) as typeof fetch - await expect(githubSocialProvider.exchangeCode({ - provider: 'github', - request: new Request('https://app.test/auth/github/callback?code=test'), - code: 'test-code', - codeVerifier: 'verifier', - config: { - clientId: 'client', - clientSecret: 'secret', - redirectUri: 'https://app.test/auth/github/callback', - scopes: [], - encryptTokens: false, - }, - })).rejects.toThrow('GitHub user request failed') + await expect(githubSocialProvider.exchangeCode(createCallbackContext())).rejects.toThrow('GitHub user request failed') globalThis.fetch = vi.fn() .mockResolvedValueOnce(new Response(JSON.stringify({ access_token: 'access' }), { status: 200 })) .mockResolvedValueOnce(new Response(JSON.stringify({ id: 1, login: 'octocat' }), { status: 200 })) .mockResolvedValueOnce(new Response('nope', { status: 500 })) as typeof fetch - await expect(githubSocialProvider.exchangeCode({ - provider: 'github', - request: new Request('https://app.test/auth/github/callback?code=test'), - code: 'test-code', - codeVerifier: 'verifier', - config: { - clientId: 'client', - clientSecret: 'secret', - redirectUri: 'https://app.test/auth/github/callback', - scopes: [], - encryptTokens: false, - }, - })).rejects.toThrow('GitHub email request failed') + await expect(githubSocialProvider.exchangeCode(createCallbackContext())).rejects.toThrow('GitHub email request failed') }) it('fails when the GitHub profile does not include an id', async () => { @@ -131,18 +103,6 @@ describe('@holo-js/auth-social-github', () => { .mockResolvedValueOnce(new Response(JSON.stringify({ login: 'octocat' }), { status: 200 })) .mockResolvedValueOnce(new Response(JSON.stringify([]), { status: 200 })) as typeof fetch - await expect(githubSocialProvider.exchangeCode({ - provider: 'github', - request: new Request('https://app.test/auth/github/callback?code=test'), - code: 'test-code', - codeVerifier: 'verifier', - config: { - clientId: 'client', - clientSecret: 'secret', - redirectUri: 'https://app.test/auth/github/callback', - scopes: ['read:user'], - encryptTokens: false, - }, - })).rejects.toThrow('did not include "id"') + await expect(githubSocialProvider.exchangeCode(createCallbackContext(['read:user']))).rejects.toThrow('did not include "id"') }) }) diff --git a/packages/auth/src/next-server-shim.d.ts b/packages/auth/src/next-server-shim.d.ts index 4f8abef5..061cd554 100644 --- a/packages/auth/src/next-server-shim.d.ts +++ b/packages/auth/src/next-server-shim.d.ts @@ -3,6 +3,7 @@ declare module 'next/server' { readonly path?: string readonly secure?: boolean readonly sameSite?: 'lax' | 'strict' | 'none' + readonly httpOnly?: boolean } type NextResponseWithCookies = Response & { diff --git a/packages/auth/src/next/client.ts b/packages/auth/src/next/client.ts index 97a04701..e39cb75f 100644 --- a/packages/auth/src/next/client.ts +++ b/packages/auth/src/next/client.ts @@ -1,5 +1,6 @@ 'use client' +import { usePathname } from 'next/navigation' import { createContext, createElement, useCallback, useContext, useEffect, useRef, useState, type ReactNode } from 'react' import { authClientInternals } from '../client' import type { AuthClientRequestOptions, HoloAuthUser } from '../contracts' @@ -33,12 +34,14 @@ function hasExplicitUseAuthOptions(options: UseAuthOptions | undefined): options function useAuthState( options: UseAuthOptions = {}, - stateOptions: { readonly refreshOnMount?: boolean } = {}, + stateOptions: { readonly refreshOnMount?: boolean, readonly refreshOnRouteChange?: boolean } = {}, ): UseAuthResult { const { initialProvider, initialUser, ...requestOptions } = options + const pathname = usePathname() const [currentProvider, setCurrentProvider] = useState(initialProvider ?? null) const [currentUser, setCurrentUser] = useState(initialUser ?? null) const requestOptionsRef = useRef(requestOptions) + const observedInitialPathname = useRef(false) requestOptionsRef.current = requestOptions @@ -62,6 +65,29 @@ function useAuthState( } }, [initialUser, refreshUser, stateOptions.refreshOnMount]) + useEffect(() => { + if (stateOptions.refreshOnRouteChange === false) { + return + } + + if (!observedInitialPathname.current) { + observedInitialPathname.current = true + return + } + + void refreshUser() + }, [pathname, refreshUser, stateOptions.refreshOnRouteChange]) + + useEffect(() => { + if (typeof initialProvider !== 'undefined') { + setCurrentProvider(initialProvider) + } + + if (typeof initialUser !== 'undefined') { + setCurrentUser(initialUser) + } + }, [initialProvider, initialUser]) + return { authenticated: currentUser !== null, provider: currentProvider, @@ -81,6 +107,7 @@ export function useAuth(options?: UseAuthOptions): UseAuthResult { const hasOptions = hasExplicitUseAuthOptions(options) const localAuth = useAuthState(options, { refreshOnMount: hasOptions || !context, + refreshOnRouteChange: hasOptions || !context, }) if (!hasOptions && context) { diff --git a/packages/auth/src/next/server.ts b/packages/auth/src/next/server.ts index c938e95b..de488fb4 100644 --- a/packages/auth/src/next/server.ts +++ b/packages/auth/src/next/server.ts @@ -15,6 +15,7 @@ type NextResponseCookieOptions = { readonly path?: string readonly secure?: boolean readonly sameSite?: 'lax' | 'strict' | 'none' + readonly httpOnly?: boolean } type NextResponseWithCookies = Response & { @@ -87,7 +88,7 @@ async function createCsrfCookieResponse(request: NextRouteProtectionRequest): Pr const { NextResponse } = await import('next/server') as NextServerModule const response = NextResponse.next() - response.cookies.set(defaultCsrfCookieName, token, resolveCsrfCookieOptions(request.url)) + response.cookies.set(defaultCsrfCookieName, token, resolveCsrfCookieOptions(request)) return response } diff --git a/packages/auth/src/runtime/csrfCookie.ts b/packages/auth/src/runtime/csrfCookie.ts index d8dfd0b3..df738b21 100644 --- a/packages/auth/src/runtime/csrfCookie.ts +++ b/packages/auth/src/runtime/csrfCookie.ts @@ -4,8 +4,16 @@ export type CsrfCookieOptions = { readonly path: '/' readonly sameSite: 'lax' readonly secure: boolean + readonly httpOnly: false } +type CsrfCookieRequest = { + readonly url: string | URL + readonly headers: Headers +} + +type CsrfCookieTarget = string | URL | CsrfCookieRequest + type WebCryptoSubtle = { importKey( format: 'raw', @@ -88,12 +96,49 @@ export function isCsrfCookieRequest(method: string | undefined): boolean { return normalized === 'GET' || normalized === 'HEAD' } -export function resolveCsrfCookieOptions(url: string | URL): CsrfCookieOptions { - const requestUrl = typeof url === 'string' ? new URL(url) : url +function normalizeForwardedValue(value: string): string { + return value.trim().replace(/^"|"$/g, '').toLowerCase() +} + +function getForwardedProto(headers: Headers): string | undefined { + const forwardedProto = headers.get('x-forwarded-proto')?.split(',', 1)[0]?.trim() + if (forwardedProto) { + return normalizeForwardedValue(forwardedProto) + } + + const forwarded = headers.get('forwarded')?.split(',', 1)[0] + if (!forwarded) { + return undefined + } + + for (const segment of forwarded.split(';')) { + const [name, value] = segment.split('=', 2) + if (name?.trim().toLowerCase() === 'proto' && value) { + return normalizeForwardedValue(value) + } + } + + return undefined +} + +function isCsrfCookieRequestTarget(target: CsrfCookieTarget): target is CsrfCookieRequest { + return typeof target === 'object' + && !(target instanceof URL) + && target.headers instanceof Headers +} + +export function resolveCsrfCookieOptions(target: CsrfCookieTarget): CsrfCookieOptions { + const requestUrl = isCsrfCookieRequestTarget(target) + ? typeof target.url === 'string' ? new URL(target.url) : target.url + : typeof target === 'string' ? new URL(target) : target + const forwardedProto = isCsrfCookieRequestTarget(target) + ? getForwardedProto(target.headers) + : undefined return { path: '/', sameSite: 'lax', - secure: requestUrl.protocol === 'https:', + secure: forwardedProto === 'https' || requestUrl.protocol === 'https:', + httpOnly: false, } } diff --git a/packages/auth/src/sveltekit/server.ts b/packages/auth/src/sveltekit/server.ts index 13589c2f..c928add9 100644 --- a/packages/auth/src/sveltekit/server.ts +++ b/packages/auth/src/sveltekit/server.ts @@ -52,10 +52,7 @@ type SvelteKitStoredRequestEvent = SvelteKitHandleEvent & { get(name: string): string | undefined set(name: string, value: string, options: SvelteKitCookieOptions): void } - readonly request: { - readonly headers: Headers - readonly method?: string - } + readonly request: Request } type SvelteKitResolveOptions = { @@ -183,7 +180,17 @@ async function ensureCsrfCookie(event: SvelteKitHandleEvent): Promise { return } - event.cookies.set(defaultCsrfCookieName, token, resolveCsrfCookieOptions(event.url)) + try { + const { csrfInternals } = await import('@holo-js/security') + csrfInternals.generatedTokenCache.set(event.request, token) + } catch { + // @holo-js/security is optional for auth-only installs. + } + + event.cookies.set(defaultCsrfCookieName, token, resolveCsrfCookieOptions({ + url: event.url, + headers: event.request.headers, + })) } export async function auth(options: AuthOptions = {}): Promise { diff --git a/packages/auth/tests/framework.test.ts b/packages/auth/tests/framework.test.ts index a2ed1abf..58f48b50 100644 --- a/packages/auth/tests/framework.test.ts +++ b/packages/auth/tests/framework.test.ts @@ -86,6 +86,7 @@ describe('@holo-js/auth framework helpers', () => { vi.resetModules() vi.clearAllMocks() vi.doUnmock('#imports') + vi.doUnmock('next/navigation') vi.doUnmock('react') vi.doUnmock('../src/client') vi.doUnmock('../src/index') @@ -96,6 +97,9 @@ describe('@holo-js/auth framework helpers', () => { vi.doMock('../src/client', () => ({ refreshUser, })) + vi.doMock('next/navigation', () => ({ + usePathname: () => '/admin', + })) vi.doMock('react', () => createReactMock()) const { AuthProvider, useAuth } = await import('../src/next/client') @@ -116,6 +120,116 @@ describe('@holo-js/auth framework helpers', () => { expect(refreshUser).not.toHaveBeenCalled() }) + it('refreshes the Next auth provider after client route changes', async () => { + let currentPathname = '/login' + let stateCursor = 0 + let refCursor = 0 + const states: unknown[] = [] + const refs: { current: unknown }[] = [] + const fetchCurrentUser = vi.fn(async () => ({ + authenticated: true, + guard: 'web', + provider: 'users', + user: { + id: 1, + email: 'ava@example.com', + name: 'Ava', + }, + })) + + function resetRenderCursors() { + stateCursor = 0 + refCursor = 0 + } + + vi.doMock('../src/client', () => ({ + authClientInternals: { + fetchCurrentUser, + }, + })) + vi.doMock('next/navigation', () => ({ + usePathname: () => currentPathname, + })) + vi.doMock('react', () => ({ + createContext(defaultValue: TValue): MockReactContext { + const context: MockReactContext = { + currentRenderValue: defaultValue, + Provider({ value, children }) { + context.currentRenderValue = value + return children + }, + } + + return context + }, + createElement(type: unknown, props: Record | null, ...children: readonly unknown[]): unknown { + if (typeof type === 'function') { + return (type as (props: Record) => unknown)({ + ...(props ?? {}), + children: children.length === 1 ? children[0] : children, + }) + } + + return { type, props, children } + }, + useCallback unknown>(callback: TCallback) { + return callback + }, + useContext(context: MockReactContext): TValue { + return context.currentRenderValue + }, + useEffect(effect: () => void | (() => void)) { + return effect() + }, + useRef(initialValue?: TValue) { + const index = refCursor + refCursor += 1 + refs[index] ??= { current: initialValue } + return refs[index] as { current: TValue | undefined } + }, + useState(initialState: TValue | (() => TValue)) { + const index = stateCursor + stateCursor += 1 + if (!(index in states)) { + states[index] = typeof initialState === 'function' + ? (initialState as () => TValue)() + : initialState + } + + return [ + states[index] as TValue, + (value: TValue | ((previous: TValue) => TValue)) => { + states[index] = typeof value === 'function' + ? (value as (previous: TValue) => TValue)(states[index] as TValue) + : value + }, + ] as const + }, + })) + + const { AuthProvider } = await import('../src/next/client') + + resetRenderCursors() + AuthProvider({ + initialProvider: null, + initialUser: null, + children: null, + }) + expect(fetchCurrentUser).not.toHaveBeenCalled() + + currentPathname = '/admin' + resetRenderCursors() + AuthProvider({ + initialProvider: null, + initialUser: null, + children: null, + }) + + expect(fetchCurrentUser).toHaveBeenCalledWith({}, { + force: true, + }) + }) + it('does not reuse the SvelteKit auth context when explicit request options are passed', async () => { type SvelteContextValue = unknown @@ -379,12 +493,13 @@ describe('@holo-js/auth framework helpers', () => { return Object.assign(response, { cookies: { - set(name: string, value: string, options: { readonly path?: string, readonly sameSite?: string, readonly secure?: boolean }) { + set(name: string, value: string, options: { readonly path?: string, readonly sameSite?: string, readonly secure?: boolean, readonly httpOnly?: boolean }) { headers.append('set-cookie', [ `${name}=${encodeURIComponent(value)}`, options.path ? `Path=${options.path}` : undefined, options.sameSite ? `SameSite=${options.sameSite[0]?.toUpperCase()}${options.sameSite.slice(1)}` : undefined, options.secure ? 'Secure' : undefined, + options.httpOnly ? 'HttpOnly' : undefined, ].filter((attribute): attribute is string => typeof attribute === 'string').join('; ')) }, }, @@ -399,9 +514,11 @@ describe('@holo-js/auth framework helpers', () => { cookies: { get: vi.fn(() => undefined), }, - headers: new Headers(), - nextUrl: new URL('https://app.test/login'), - url: 'https://app.test/login', + headers: new Headers({ + 'x-forwarded-proto': 'https', + }), + nextUrl: new URL('http://app.test/login'), + url: 'http://app.test/login', } const response = await protectRoutes(async () => undefined)(request) const setCookie = response?.headers.get('set-cookie') ?? '' @@ -416,6 +533,7 @@ describe('@holo-js/auth framework helpers', () => { expect(setCookie).toContain('Path=/') expect(setCookie).toContain('SameSite=Lax') expect(setCookie).toContain('Secure') + expect(setCookie).not.toContain('HttpOnly') expect(separator).toBeGreaterThan(0) expect(signature).toBe(createHmac('sha256', 'next-csrf-signing-key') .update(nonce) @@ -686,14 +804,16 @@ describe('@holo-js/auth framework helpers', () => { redirectTo: '/admin', })({ event: { - url: new URL('https://app.test/login'), + url: new URL('http://app.test/login'), cookies: { get: vi.fn(() => undefined), set: setCookie, }, request: { method: 'GET', - headers: new Headers(), + headers: new Headers({ + 'x-forwarded-proto': 'https', + }), }, }, resolve, @@ -708,6 +828,7 @@ describe('@holo-js/auth framework helpers', () => { path: '/', sameSite: 'lax', secure: true, + httpOnly: false, }) expect(signature).toBe(createHmac('sha256', 'sveltekit-csrf-signing-key') .update(nonce) @@ -850,6 +971,7 @@ describe('@holo-js/auth framework helpers', () => { path: '/', sameSite: 'lax', secure: true, + httpOnly: false, }) expect(signature).toBe(createHmac('sha256', 'nuxt-csrf-signing-key') .update(nonce) diff --git a/packages/auth/tsup.config.ts b/packages/auth/tsup.config.ts index 3e17ad34..ce2d0a21 100644 --- a/packages/auth/tsup.config.ts +++ b/packages/auth/tsup.config.ts @@ -18,7 +18,7 @@ export default defineConfig({ format: ['esm'], dts: true, clean: true, - external: ['#imports', 'next/server', 'react', 'svelte', 'svelte/reactivity'], + external: ['#imports', 'next/navigation', 'next/server', 'react', 'svelte', 'svelte/reactivity'], outDir, outExtension: () => ({ js: '.mjs' }), async onSuccess() { diff --git a/packages/cache-db/tests/package.test.ts b/packages/cache-db/tests/package.test.ts index 3b079c6c..f56cf012 100644 --- a/packages/cache-db/tests/package.test.ts +++ b/packages/cache-db/tests/package.test.ts @@ -93,18 +93,7 @@ async function createPublicFeatureHarness() { const schema = createSchemaService(DB.connection()) await schema.sync([users]) - await schema.createTable(DEFAULT_CACHE_DATABASE_TABLE, (table) => { - table.string('key').primaryKey() - table.text('payload') - table.bigInteger('expires_at').nullable() - table.index(['expires_at'], `${DEFAULT_CACHE_DATABASE_TABLE}_expires_at_index`) - }) - await schema.createTable(DEFAULT_CACHE_DATABASE_LOCK_TABLE, (table) => { - table.string('name').primaryKey() - table.string('owner') - table.bigInteger('expires_at') - table.index(['expires_at'], `${DEFAULT_CACHE_DATABASE_LOCK_TABLE}_expires_at_index`) - }) + await cacheDbInternals.prepareCacheDatabaseTables(DB.connection()) const driver = createDatabaseCacheDriver({ name: 'database', @@ -198,77 +187,6 @@ describe('@holo-js/cache-db', () => { connectionName: 'cache', driver: 'postgres', }) - expect(cacheDbInternals.createDatabaseContextOptions('cache', { - driver: 'postgres', - url: 'postgres://cache.internal/db', - port: '5432' as never, - })).toMatchObject({ - connectionName: 'cache', - driver: 'postgres', - }) - const stringPortOptions = cacheDbInternals.createDatabaseContextOptions('cache', { - driver: 'postgres', - host: 'cache.internal', - database: 'app', - username: 'user', - password: 'secret', - port: '5433' as never, - }) - const adapter = stringPortOptions.adapter as { - readonly options?: { - readonly config?: { - readonly host?: string - readonly port?: number - readonly user?: string - readonly password?: string - readonly database?: string - } - } - } - - expect(adapter.options?.config).toMatchObject({ - host: 'cache.internal', - port: 5433, - user: 'user', - password: 'secret', - database: 'app', - }) - - const invalidStringPortOptions = cacheDbInternals.createDatabaseContextOptions('cache', { - driver: 'postgres', - host: 'cache.internal', - database: 'app', - username: 'user', - password: 'secret', - port: 'not-a-port' as never, - }) - const invalidStringAdapter = invalidStringPortOptions.adapter as { - readonly options?: { - readonly config?: { - readonly port?: number - } - } - } - - expect(invalidStringAdapter.options?.config?.port).toBeUndefined() - - const partiallyNumericPortOptions = cacheDbInternals.createDatabaseContextOptions('cache', { - driver: 'postgres', - host: 'cache.internal', - database: 'app', - username: 'user', - password: 'secret', - port: '5433abc' as never, - }) - const partiallyNumericAdapter = partiallyNumericPortOptions.adapter as { - readonly options?: { - readonly config?: { - readonly port?: number - } - } - } - - expect(partiallyNumericAdapter.options?.config?.port).toBeUndefined() }) it('creates cache tables through the shared schema helper', async () => { diff --git a/packages/cache-redis/package.json b/packages/cache-redis/package.json index ba858c97..92ce7f34 100644 --- a/packages/cache-redis/package.json +++ b/packages/cache-redis/package.json @@ -21,7 +21,7 @@ "stub": "tsup", "typecheck": "tsc -p tsconfig.json --noEmit", "test": "vitest --run", - "test:integration": "HOLO_REDIS_INTEGRATION=1 vitest --run tests/real-redis.test.ts" + "test:integration": "HOLO_REDIS_INTEGRATION=1 vitest --run" }, "peerDependencies": { "@holo-js/cache": "catalog:" diff --git a/packages/cache-redis/tests/package.test.ts b/packages/cache-redis/tests/package.test.ts index 76b2ce0a..394f26c1 100644 --- a/packages/cache-redis/tests/package.test.ts +++ b/packages/cache-redis/tests/package.test.ts @@ -1,225 +1,72 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' const redisMock = vi.hoisted(() => { - const state = new Map() - const lockOwners = new Map() - const calls = { - constructorArgs: [] as unknown[][], - del: [] as string[][], - disconnect: [] as true[], - eval: [] as Array<[string, number, ...string[]]>, - get: [] as string[], - incrby: [] as Array<[string, number]>, - scan: [] as Array<[string, string, string, string, number]>, - set: [] as Array<[string, string, ...(string | number)[]]>, - } - - function resolveExpiresAt(arguments_: readonly (string | number)[], now: number): number | undefined { - const pxAtIndex = arguments_.findIndex(argument_ => argument_ === 'PXAT') - if (pxAtIndex >= 0) { - const expiresAt = arguments_[pxAtIndex + 1] - return typeof expiresAt === 'number' ? expiresAt : undefined - } - - const pxIndex = arguments_.findIndex(argument_ => argument_ === 'PX') - if (pxIndex >= 0) { - const ttlMilliseconds = arguments_[pxIndex + 1] - return typeof ttlMilliseconds === 'number' ? now + ttlMilliseconds : undefined - } - - return undefined - } - - function hasNx(arguments_: readonly (string | number)[]): boolean { - return arguments_.includes('NX') - } - - function isExpired(key: string, now: number): boolean { - const entry = state.get(key) - if (!entry || typeof entry.expiresAt === 'undefined' || entry.expiresAt > now) { - return false - } - - state.delete(key) - return true - } - - class FakeRedis { - static Cluster = class FakeRedisCluster { - readonly isCluster = true - - constructor(...args: unknown[]) { - calls.constructorArgs.push(args) - } - - async get(key: string): Promise { - return new FakeRedis().get(key) - } - - async set(key: string, value: string, ...arguments_: readonly (string | number)[]): Promise<'OK' | null> { - return new FakeRedis().set(key, value, ...arguments_) - } - - async del(...keys: string[]): Promise { - return new FakeRedis().del(...keys) - } - - async scan( - cursor: string, - matchLabel: string, - pattern: string, - countLabel: string, - count: number, - ): Promise<[string, string[]]> { - return new FakeRedis().scan(cursor, matchLabel, pattern, countLabel, count) - } - - async incrby(key: string, amount: number): Promise { - return new FakeRedis().incrby(key, amount) - } - - async decrby(key: string, amount: number): Promise { - return new FakeRedis().decrby(key, amount) - } - - async eval(script: string, numberOfKeys: number, ...arguments_: readonly string[]): Promise { - return new FakeRedis().eval(script, numberOfKeys, ...arguments_) - } - - disconnect(): void { - calls.disconnect.push(true) - } - - nodes(): readonly FakeRedis[] { - return [new FakeRedis(), new FakeRedis()] - } - } - - constructor(...args: unknown[]) { - calls.constructorArgs.push(args) - } - - disconnect(): void { - calls.disconnect.push(true) - } - - async get(key: string): Promise { - calls.get.push(key) - if (isExpired(key, Date.now())) { - return null - } - - return state.get(key)?.value ?? null - } - - async set(key: string, value: string, ...arguments_: readonly (string | number)[]): Promise<'OK' | null> { - calls.set.push([key, value, ...arguments_]) - if (isExpired(key, Date.now())) { - state.delete(key) - } - - if (hasNx(arguments_) && state.has(key)) { - return null - } - - state.set(key, { - value, - expiresAt: resolveExpiresAt(arguments_, Date.now()), - }) - if (key.includes(':lock:')) { - lockOwners.set(key, value) - } - - return 'OK' - } - - async del(...keys: string[]): Promise { - calls.del.push(keys) - let deleted = 0 - for (const key of keys) { - if (state.delete(key)) { - deleted += 1 - lockOwners.delete(key) - } - } - - return deleted - } - - async scan( + class MockRedisClient { + readonly get = vi.fn<(key: string) => Promise>(async () => null) + readonly set = vi.fn<(key: string, value: string, ...arguments_: readonly (string | number)[]) => Promise<'OK' | null>>(async () => 'OK') + readonly del = vi.fn<(...keys: string[]) => Promise>(async () => 0) + readonly scan = vi.fn<( cursor: string, matchLabel: string, pattern: string, countLabel: string, count: number, - ): Promise<[string, string[]]> { - calls.scan.push([cursor, matchLabel, pattern, countLabel, count]) - const regex = new RegExp(`^${pattern.replace(/\\\*/g, '\\*').replace(/\*/g, '.*')}$`) - const keys = [...state.keys()].filter(key => regex.test(key)) - return ['0', keys] - } + ) => Promise<[string, string[]]>>(async () => ['0', []]) + readonly incrby = vi.fn<(key: string, amount: number) => Promise>(async () => 0) + readonly decrby = vi.fn<(key: string, amount: number) => Promise>(async () => 0) + readonly eval = vi.fn<(script: string, numberOfKeys: number, ...arguments_: readonly string[]) => Promise>(async () => 0) + readonly disconnect = vi.fn<() => void>() + nodes?: ReturnType readonly MockRedisClient[]>> + } - async incrby(key: string, amount: number): Promise { - calls.incrby.push([key, amount]) - if (key.includes('timeout')) { - throw new Error('ETIMEDOUT') - } + const constructorArgs: unknown[][] = [] + const standaloneClients: MockRedisClient[] = [] + const clusterClients: MockRedisClient[] = [] + let clusterNodes: readonly MockRedisClient[] = [] + let exposeClusterNodes = true - if (key.includes('wrongtype')) { - throw new Error('WRONGTYPE Operation against a key holding the wrong kind of value') - } + class FakeRedis extends MockRedisClient { + static Cluster = class FakeRedisCluster extends MockRedisClient { + readonly isCluster = true - const current = await this.get(key) - const currentNumber = current === null ? 0 : Number(current) - if (!Number.isInteger(currentNumber)) { - throw new Error('ERR value is not an integer or out of range') + constructor(...args: unknown[]) { + super() + constructorArgs.push(args) + clusterClients.push(this) + if (exposeClusterNodes) { + this.nodes = vi.fn<(role: 'master') => readonly MockRedisClient[]>(() => clusterNodes) + } } - - const nextValue = currentNumber + amount - state.set(key, { - value: String(nextValue), - expiresAt: state.get(key)?.expiresAt, - }) - return nextValue - } - - async decrby(key: string, amount: number): Promise { - return this.incrby(key, -amount) } - async eval(script: string, numberOfKeys: number, ...arguments_: readonly string[]): Promise { - calls.eval.push([script, numberOfKeys, ...arguments_]) - const [key, owner] = arguments_ - if (typeof key !== 'string' || typeof owner !== 'string') { - return 0 - } - - if (lockOwners.get(key) !== owner) { - return 0 - } - - lockOwners.delete(key) - return state.delete(key) ? 1 : 0 + constructor(...args: unknown[]) { + super() + constructorArgs.push(args) + standaloneClients.push(this) } } return { - calls, FakeRedis, - lockOwners, + clusterClients, + constructorArgs, + createClient() { + return new MockRedisClient() + }, + disableClusterNodes() { + exposeClusterNodes = false + }, reset() { - state.clear() - lockOwners.clear() - calls.constructorArgs.length = 0 - calls.del.length = 0 - calls.disconnect.length = 0 - calls.eval.length = 0 - calls.get.length = 0 - calls.incrby.length = 0 - calls.scan.length = 0 - calls.set.length = 0 + constructorArgs.length = 0 + standaloneClients.length = 0 + clusterClients.length = 0 + clusterNodes = [] + exposeClusterNodes = true + }, + setClusterNodes(nodes: readonly MockRedisClient[]) { + clusterNodes = nodes }, - state, + standaloneClients, } }) @@ -232,13 +79,59 @@ import { createRedisCacheDriver, redisCacheDriverInternals } from '../src/index' const cacheDriverDisposeSymbol = Symbol.for('holo.cache.driver.dispose') +type RedisLockClient = Parameters[0] + +function lastStandaloneClient(): (typeof redisMock.standaloneClients)[number] { + const client = redisMock.standaloneClients.at(-1) + if (!client) { + throw new Error('Expected a standalone Redis client to be created.') + } + + return client +} + +function lastClusterClient(): (typeof redisMock.clusterClients)[number] { + const client = redisMock.clusterClients.at(-1) + if (!client) { + throw new Error('Expected a Redis cluster client to be created.') + } + + return client +} + +function createLockClient(): RedisLockClient { + return { + async get() { + return null + }, + async set() { + return 'OK' + }, + async del() { + return 0 + }, + async scan() { + return ['0', []] + }, + async incrby() { + return 0 + }, + async decrby() { + return 0 + }, + async eval() { + return 0 + }, + } +} + describe('@holo-js/cache-redis', () => { beforeEach(() => { redisMock.reset() vi.useRealTimers() }) - it('reads, writes, adds, forgets, and flushes within the configured prefix scope', async () => { + it('maps cache operations to the configured redis prefix scope', async () => { const driver = createRedisCacheDriver({ name: 'redis', connectionName: 'cache', @@ -249,6 +142,19 @@ describe('@holo-js/cache-redis', () => { db: 0, }, }) + const client = lastStandaloneClient() + + client.get.mockResolvedValueOnce('"one"') + client.set + .mockResolvedValueOnce('OK') + .mockResolvedValueOnce(null) + .mockResolvedValueOnce('OK') + .mockResolvedValueOnce('OK') + client.del + .mockResolvedValueOnce(1) + .mockResolvedValueOnce(0) + .mockResolvedValueOnce(1) + client.scan.mockResolvedValueOnce(['0', ['holo:cache:alpha']]) expect(await driver.put({ key: 'holo:cache:alpha', @@ -278,14 +184,14 @@ describe('@holo-js/cache-redis', () => { }) await driver.flush() - expect(await driver.get('holo:cache:alpha')).toEqual({ hit: false }) - expect(await driver.get('other:gamma')).toEqual({ - hit: true, - payload: '"outside"', - }) - expect(redisMock.calls.scan).toEqual([ - ['0', 'MATCH', 'holo:cache:*', 'COUNT', 100], + expect(client.set.mock.calls).toEqual([ + ['holo:cache:alpha', '"one"', 'PXAT', expect.any(Number)], + ['holo:cache:alpha', '"two"', 'PXAT', expect.any(Number), 'NX'], + ['holo:cache:beta', '"two"', 'PXAT', expect.any(Number), 'NX'], + ['other:gamma', '"outside"'], ]) + expect(client.scan).toHaveBeenCalledWith('0', 'MATCH', 'holo:cache:*', 'COUNT', 100) + expect(client.del).toHaveBeenLastCalledWith('holo:cache:alpha') }) it('disconnects its redis client through the runtime lifecycle hook', () => { @@ -299,6 +205,7 @@ describe('@holo-js/cache-redis', () => { db: 0, }, }) as ReturnType & Record void> + const client = lastStandaloneClient() const dispose = driver[cacheDriverDisposeSymbol] if (!dispose) { @@ -307,10 +214,10 @@ describe('@holo-js/cache-redis', () => { dispose() - expect(redisMock.calls.disconnect).toEqual([true]) + expect(client.disconnect).toHaveBeenCalledOnce() }) - it('supports expiration and immediate-expiry writes', async () => { + it('passes expiration options to redis and deletes immediate-expiry writes', async () => { vi.useFakeTimers() vi.setSystemTime(new Date('2026-04-22T00:00:00.000Z')) @@ -324,48 +231,40 @@ describe('@holo-js/cache-redis', () => { db: 0, }, }) + const client = lastStandaloneClient() + + client.set + .mockResolvedValueOnce('OK') + .mockResolvedValueOnce(null) + .mockResolvedValueOnce('OK') await driver.put({ key: 'holo:cache:ttl', payload: '"ok"', expiresAt: Date.now() + 1_000, }) - expect(await driver.get('holo:cache:ttl')).toEqual({ - hit: true, - payload: '"ok"', - }) - - vi.advanceTimersByTime(1_001) - expect(await driver.get('holo:cache:ttl')).toEqual({ hit: false }) - await driver.put({ key: 'holo:cache:expired', payload: '"gone"', expiresAt: Date.now() - 1, }) - expect(await driver.get('holo:cache:expired')).toEqual({ hit: false }) - + expect(await driver.add({ + key: 'holo:cache:live-add', + payload: '"expired-replacement"', + expiresAt: Date.now() - 1, + })).toBe(false) expect(await driver.add({ key: 'holo:cache:stale-add', payload: '"gone"', expiresAt: Date.now() - 1, })).toBe(true) - expect(await driver.get('holo:cache:stale-add')).toEqual({ hit: false }) - await driver.put({ - key: 'holo:cache:live-add', - payload: '"original"', - expiresAt: Date.now() + 60_000, - }) - expect(await driver.add({ - key: 'holo:cache:live-add', - payload: '"expired-replacement"', - expiresAt: Date.now() - 1, - })).toBe(false) - expect(await driver.get('holo:cache:live-add')).toEqual({ - hit: true, - payload: '"original"', - }) + expect(client.set.mock.calls).toEqual([ + ['holo:cache:ttl', '"ok"', 'PXAT', Date.now() + 1_000], + ['holo:cache:live-add', '"expired-replacement"', 'PXAT', Date.now() - 1, 'NX'], + ['holo:cache:stale-add', '"gone"', 'PXAT', Date.now() - 1, 'NX'], + ]) + expect(client.del).toHaveBeenCalledWith('holo:cache:expired') }) it('supports numeric mutation and rejects non-numeric values', async () => { @@ -379,86 +278,134 @@ describe('@holo-js/cache-redis', () => { db: 0, }, }) + const client = lastStandaloneClient() + + client.incrby + .mockResolvedValueOnce(2) + .mockRejectedValueOnce(new Error('ERR value is not an integer or out of range')) + client.decrby + .mockResolvedValueOnce(1) + .mockRejectedValueOnce(new Error('WRONGTYPE Operation against a key holding the wrong kind of value')) expect(await driver.increment('holo:cache:counter', 2)).toBe(2) expect(await driver.decrement('holo:cache:counter', 1)).toBe(1) - await driver.put({ - key: 'holo:cache:label', - payload: '"text"', - }) await expect(driver.increment('holo:cache:label', 1)).rejects.toThrow(CacheInvalidNumericMutationError) await expect(driver.decrement('holo:cache:label', 1)).rejects.toThrow(CacheInvalidNumericMutationError) }) - it('implements redis-backed locks with owner-safe release and blocking', async () => { - vi.useFakeTimers() - vi.setSystemTime(new Date('2026-04-22T00:00:00.000Z')) - - const driver = createRedisCacheDriver({ - name: 'redis', - connectionName: 'cache', - prefix: 'holo:cache:', - redis: { - host: '127.0.0.1', - port: 6379, - db: 0, - }, - sleep: async (milliseconds) => { - vi.advanceTimersByTime(milliseconds) - }, - ownerFactory: (() => { - let counter = 0 - return () => `owner-${++counter}` - })(), - }) - - const firstLock = driver.lock('holo:cache:lock:report', 1) - const secondLock = driver.lock('holo:cache:lock:report', 1) + it('implements redis-backed locks with owner-safe release', async () => { + const client = createLockClient() + const set = vi.spyOn(client, 'set') + const evaluate = vi.spyOn(client, 'eval') + let counter = 0 + + set + .mockResolvedValueOnce('OK') + .mockResolvedValueOnce(null) + .mockResolvedValueOnce('OK') + evaluate + .mockResolvedValueOnce(0) + .mockResolvedValueOnce(1) + .mockResolvedValueOnce(1) + + const firstLock = redisCacheDriverInternals.createRedisLock( + client, + 'holo:cache:lock:report', + 1, + () => `owner-${++counter}`, + async () => {}, + Date.now, + ) + const secondLock = redisCacheDriverInternals.createRedisLock( + client, + 'holo:cache:lock:report', + 1, + () => `owner-${++counter}`, + async () => {}, + Date.now, + ) expect(await firstLock.get()).toBe(true) expect(await secondLock.get()).toBe(false) expect(await secondLock.release()).toBe(false) expect(await firstLock.release()).toBe(true) expect(await secondLock.get(async () => 'after-release')).toBe('after-release') + expect(evaluate).toHaveBeenCalledTimes(3) + }) - const blockingLock = driver.lock('holo:cache:lock:wait', 0.02) - expect(await blockingLock.get()).toBe(true) + it('uses injected sleep and clocks for blocking lock deadlines', async () => { + let currentTime = 0 + const client = createLockClient() + const set = vi.spyOn(client, 'set') + const sleepCalls: number[] = [] - const waited = driver.lock('holo:cache:lock:wait', 0.02).block(0.05, async () => 'after-wait') - await expect(waited).resolves.toBe('after-wait') + set + .mockResolvedValueOnce('OK') + .mockResolvedValueOnce(null) + .mockResolvedValueOnce('OK') + .mockResolvedValueOnce('OK') + .mockResolvedValueOnce(null) + .mockResolvedValueOnce(null) - const heldLock = driver.lock('holo:cache:lock:timeout', 1) + const heldLock = redisCacheDriverInternals.createRedisLock( + client, + 'holo:cache:lock:wait', + 0.02, + () => 'owner-1', + async (milliseconds) => { + sleepCalls.push(milliseconds) + currentTime += milliseconds + }, + () => currentTime, + ) expect(await heldLock.get()).toBe(true) - await expect(driver.lock('holo:cache:lock:timeout', 1).block(0)).resolves.toBe(false) - }) - it('uses the injected clock for blocking lock deadlines', async () => { - let currentTime = 0 - const driver = createRedisCacheDriver({ - name: 'redis', - connectionName: 'cache', - prefix: 'holo:cache:', - redis: { - host: '127.0.0.1', - port: 6379, - db: 0, + const waitedLock = redisCacheDriverInternals.createRedisLock( + client, + 'holo:cache:lock:wait', + 0.02, + () => 'owner-2', + async (milliseconds) => { + sleepCalls.push(milliseconds) + currentTime += milliseconds + }, + () => currentTime, + ) + await expect(waitedLock.block(0.05, async () => 'after-wait')).resolves.toBe('after-wait') + + const clockLock = redisCacheDriverInternals.createRedisLock( + client, + 'holo:cache:lock:clock', + 1, + () => 'owner-3', + async (milliseconds) => { + sleepCalls.push(milliseconds) + currentTime += milliseconds }, - now: () => currentTime, - sleep: async (milliseconds) => { + () => currentTime, + ) + expect(await clockLock.get()).toBe(true) + await expect(redisCacheDriverInternals.createRedisLock( + client, + 'holo:cache:lock:clock', + 1, + () => 'owner-4', + async (milliseconds) => { + sleepCalls.push(milliseconds) currentTime += milliseconds }, - }) + () => currentTime, + ).block(0.02)).resolves.toBe(false) - expect(await driver.lock('holo:cache:lock:clock', 1).get()).toBe(true) - await expect(driver.lock('holo:cache:lock:clock', 1).block(0.02)).resolves.toBe(false) + expect(sleepCalls).toEqual([10, 10, 10]) }) it('does not retry blocking lock acquisition after the wait deadline', async () => { let currentTime = 0 let setCalls = 0 const sleepCalls: number[] = [] - const client: Parameters[0] = { + const client: RedisLockClient = { async get() { return null }, @@ -502,6 +449,14 @@ describe('@holo-js/cache-redis', () => { }) it('flushes every cluster master when using a redis cluster client', async () => { + const firstNode = redisMock.createClient() + const secondNode = redisMock.createClient() + redisMock.setClusterNodes([firstNode, secondNode]) + firstNode.scan.mockResolvedValueOnce(['0', ['holo:cache:alpha']]) + secondNode.scan.mockResolvedValueOnce(['0', ['holo:cache:beta']]) + firstNode.del.mockResolvedValueOnce(1) + secondNode.del.mockResolvedValueOnce(1) + const driver = createRedisCacheDriver({ name: 'redis-cluster', connectionName: 'cache', @@ -513,49 +468,34 @@ describe('@holo-js/cache-redis', () => { ], }, }) + const clusterClient = lastClusterClient() - await driver.put({ - key: 'holo:cache:alpha', - payload: '"one"', - }) - await driver.put({ - key: 'holo:cache:beta', - payload: '"two"', - }) await driver.flush() - expect(redisMock.calls.scan).toEqual([ - ['0', 'MATCH', 'holo:cache:*', 'COUNT', 100], - ['0', 'MATCH', 'holo:cache:*', 'COUNT', 100], - ]) - expect(redisMock.calls.del.every(keys => keys.length === 1)).toBe(true) + expect(clusterClient.nodes).toHaveBeenCalledWith('master') + expect(firstNode.scan).toHaveBeenCalledWith('0', 'MATCH', 'holo:cache:*', 'COUNT', 100) + expect(secondNode.scan).toHaveBeenCalledWith('0', 'MATCH', 'holo:cache:*', 'COUNT', 100) + expect(firstNode.del).toHaveBeenCalledWith('holo:cache:alpha') + expect(secondNode.del).toHaveBeenCalledWith('holo:cache:beta') }) it('handles cluster clients that do not expose master node iteration', async () => { - const clusterPrototype = redisMock.FakeRedis.Cluster.prototype as { - nodes?: (role: 'master') => readonly unknown[] - } - const originalNodes = clusterPrototype.nodes - clusterPrototype.nodes = undefined - - try { - const driver = createRedisCacheDriver({ - name: 'redis-cluster', - connectionName: 'cache', - prefix: 'holo:cache:', - redis: { - db: 0, - clusters: [ - { host: 'cache-a.internal', port: 6379 }, - ], - }, - }) + redisMock.disableClusterNodes() + const driver = createRedisCacheDriver({ + name: 'redis-cluster', + connectionName: 'cache', + prefix: 'holo:cache:', + redis: { + db: 0, + clusters: [ + { host: 'cache-a.internal', port: 6379 }, + ], + }, + }) + const clusterClient = lastClusterClient() - await expect(driver.flush()).resolves.toBeUndefined() - expect(redisMock.calls.scan).toEqual([]) - } finally { - clusterPrototype.nodes = originalNodes - } + await expect(driver.flush()).resolves.toBeUndefined() + expect(clusterClient.nodes).toBeUndefined() }) it('prefers url, then clusters, then host/socket when creating redis clients', async () => { @@ -592,7 +532,7 @@ describe('@holo-js/cache-redis', () => { }, }) - expect(redisMock.calls.constructorArgs).toEqual([ + expect(redisMock.constructorArgs).toEqual([ [ 'redis://cache.internal:6380/2', { @@ -641,6 +581,7 @@ describe('@holo-js/cache-redis', () => { db: 0, }, }) + lastStandaloneClient().set.mockResolvedValueOnce('OK') expect(await hostDriver.add({ key: 'holo:cache:forever-add', payload: '"ok"', @@ -680,6 +621,12 @@ describe('@holo-js/cache-redis', () => { db: 0, }, }) + const client = lastStandaloneClient() + + client.incrby + .mockRejectedValueOnce(new Error('WRONGTYPE boom')) + .mockRejectedValueOnce(new Error('ETIMEDOUT')) + client.decrby.mockRejectedValueOnce(new Error('ETIMEDOUT')) await expect(driver.increment('holo:cache:wrongtype', 1)).rejects.toThrow(CacheInvalidNumericMutationError) await expect(driver.increment('holo:cache:timeout', 1)).rejects.toThrow('ETIMEDOUT') @@ -781,6 +728,11 @@ describe('@holo-js/cache-redis', () => { db: 0, }, }) + const client = lastStandaloneClient() + + client.set + .mockResolvedValueOnce('OK') + .mockResolvedValueOnce(null) const heldLock = driver.lock('holo:cache:lock:default-sleep', 1) expect(await heldLock.get()).toBe(true) diff --git a/packages/cache-redis/vitest.config.ts b/packages/cache-redis/vitest.config.ts index af288ede..378a2897 100644 --- a/packages/cache-redis/vitest.config.ts +++ b/packages/cache-redis/vitest.config.ts @@ -16,7 +16,7 @@ export default defineConfig({ name: '@holo-js/cache-redis', environment: 'node', include: runRedisIntegration - ? ['tests/**/*.test.ts'] + ? ['tests/real-redis.test.ts'] : ['tests/package.test.ts'], coverage: { provider: 'v8', diff --git a/packages/cli/package.json b/packages/cli/package.json index d3b2315a..2ddfc727 100644 --- a/packages/cli/package.json +++ b/packages/cli/package.json @@ -23,7 +23,7 @@ "build": "node ../../scripts/generate-cli-workspace-catalog.mjs && tsup", "stub": "node ../../scripts/generate-cli-workspace-catalog.mjs && tsup --watch", "typecheck": "tsc -p tsconfig.json --noEmit", - "test": "vitest --run && bun run test:integration", + "test": "vitest --run && HOLO_CLI_INCLUDE_INTEGRATION=1 vitest --run tests/cli.test.ts", "test:integration": "HOLO_CLI_INCLUDE_INTEGRATION=1 vitest --run tests/cli.test.ts" }, "dependencies": { diff --git a/packages/cli/tests/vitest-config.test.ts b/packages/cli/tests/vitest-config.test.ts index 7b905097..f7085683 100644 --- a/packages/cli/tests/vitest-config.test.ts +++ b/packages/cli/tests/vitest-config.test.ts @@ -40,7 +40,8 @@ describe('@holo-js/cli vitest config', () => { }) it('runs CLI integration tests from the package test script', () => { - expect(packageJson.scripts.test).toContain('test:integration') + expect(packageJson.scripts.test).toContain('HOLO_CLI_INCLUDE_INTEGRATION=1') + expect(packageJson.scripts.test).toContain('tests/cli.test.ts') expect(packageJson.scripts['test:integration']).toContain('HOLO_CLI_INCLUDE_INTEGRATION=1') expect(packageJson.scripts['test:integration']).toContain('tests/cli.test.ts') }) diff --git a/packages/config/src/access.ts b/packages/config/src/access.ts index 51425bc6..1d4a36f5 100644 --- a/packages/config/src/access.ts +++ b/packages/config/src/access.ts @@ -82,11 +82,11 @@ export function useConfig>( export function useConfig>(path: TPath): ValueAtPath export function useConfig(path: string): unknown export function useConfig(path: string): unknown { - return createConfigAccessors(requireConfigRuntime()).useConfig(path) + return getValueAtPath(requireConfigRuntime() as Record, path) } export function config>(path: TPath): ValueAtPath export function config(path: string): unknown export function config(path: string): unknown { - return createConfigAccessors(requireConfigRuntime()).config(path) + return getValueAtPath(requireConfigRuntime() as Record, path) } diff --git a/packages/config/src/loader.ts b/packages/config/src/loader.ts index 61199bed..019d6994 100644 --- a/packages/config/src/loader.ts +++ b/packages/config/src/loader.ts @@ -273,25 +273,7 @@ function normalizeLoadedConfig( const auth = normalizeAuthConfig(resolvedRawConfig.auth as HoloAuthConfig | undefined, { appKey: app.key, }) - - const customEntries = Object.entries(resolvedRawConfig).filter(([key]) => { - return key !== 'app' - && key !== 'database' - && key !== 'redis' - && key !== 'cache' - && key !== 'cors' - && key !== 'storage' - && key !== 'queue' - && key !== 'broadcast' - && key !== 'mail' - && key !== 'notifications' - && key !== 'media' - && key !== 'session' - && key !== 'security' - && key !== 'auth' - }) - const custom = Object.freeze(Object.fromEntries(customEntries)) as Readonly - const all = Object.freeze({ + const firstPartyConfig = { app, database, redis, @@ -306,24 +288,19 @@ function normalizeLoadedConfig( session, security, auth, + } + const firstPartyConfigKeys = new Set(Object.keys(firstPartyConfig)) + const customEntries = Object.entries(resolvedRawConfig).filter(([key]) => { + return !firstPartyConfigKeys.has(key) + }) + const custom = Object.freeze(Object.fromEntries(customEntries)) as Readonly + const all = Object.freeze({ + ...firstPartyConfig, ...custom, }) as Readonly['all']> return { - app, - database, - redis, - cache, - cors, - storage, - queue, - broadcast, - mail, - notifications, - media, - session, - security, - auth, + ...firstPartyConfig, custom, all, environment: options.environment, diff --git a/packages/config/tests/broadcast-config.type.test.ts b/packages/config/tests/broadcast-config.type.test.ts index 637f1080..b267377b 100644 --- a/packages/config/tests/broadcast-config.type.test.ts +++ b/packages/config/tests/broadcast-config.type.test.ts @@ -1,9 +1,9 @@ import { describe, it } from 'vitest' import { - createConfigAccessors, defineBroadcastConfig, type HoloConfigRegistry, } from '../src' +import { createConfigAccessorFixture } from './support/configAccessors' describe('@holo-js/config broadcast typing', () => { it('preserves broadcast inference through config helpers and dot-path access', () => { @@ -27,21 +27,8 @@ describe('@holo-js/config broadcast typing', () => { }, }) - const accessors = createConfigAccessors({ - app: {} as HoloConfigRegistry['app'], - database: {} as HoloConfigRegistry['database'], - redis: {} as HoloConfigRegistry['redis'], - cache: {} as HoloConfigRegistry['cache'], - cors: {} as HoloConfigRegistry['cors'], - storage: {} as HoloConfigRegistry['storage'], - queue: {} as HoloConfigRegistry['queue'], + const accessors = createConfigAccessorFixture({ broadcast: broadcast as unknown as HoloConfigRegistry['broadcast'], - mail: {} as HoloConfigRegistry['mail'], - notifications: {} as HoloConfigRegistry['notifications'], - media: {} as HoloConfigRegistry['media'], - session: {} as HoloConfigRegistry['session'], - security: {} as HoloConfigRegistry['security'], - auth: {} as HoloConfigRegistry['auth'], services: { mailgun: { secret: 'secret', diff --git a/packages/config/tests/config.type.test.ts b/packages/config/tests/config.type.test.ts index 4b98d09f..6c581ec9 100644 --- a/packages/config/tests/config.type.test.ts +++ b/packages/config/tests/config.type.test.ts @@ -1,6 +1,5 @@ import { describe, it } from 'vitest' import { - createConfigAccessors, defineAuthConfig, defineBroadcastConfig, defineCacheConfig, @@ -15,6 +14,7 @@ import { type HoloAppEnv, type HoloConfigRegistry, } from '../src' +import { createConfigAccessorFixture } from './support/configAccessors' declare module '../src/types' { interface HoloConfigRegistry { @@ -148,18 +148,12 @@ describe('@holo-js/config typing', () => { }, }, }) - const accessors = createConfigAccessors({ - app: {} as HoloConfigRegistry['app'], - database: {} as HoloConfigRegistry['database'], - redis: {} as HoloConfigRegistry['redis'], + const accessors = createConfigAccessorFixture({ cache: cache as unknown as HoloConfigRegistry['cache'], - cors: {} as HoloConfigRegistry['cors'], - storage: {} as HoloConfigRegistry['storage'], queue: queue as unknown as HoloConfigRegistry['queue'], broadcast: broadcast as unknown as HoloConfigRegistry['broadcast'], mail: mail as unknown as HoloConfigRegistry['mail'], notifications: notifications as unknown as HoloConfigRegistry['notifications'], - media: {} as HoloConfigRegistry['media'], session: session as unknown as HoloConfigRegistry['session'], security: security as unknown as HoloConfigRegistry['security'], auth: auth as unknown as HoloConfigRegistry['auth'], diff --git a/packages/config/tests/security-config.type.test.ts b/packages/config/tests/security-config.type.test.ts index 0acfaa9d..9aa89f0f 100644 --- a/packages/config/tests/security-config.type.test.ts +++ b/packages/config/tests/security-config.type.test.ts @@ -1,10 +1,10 @@ import { describe, it } from 'vitest' import { - createConfigAccessors, defineCorsConfig, defineSecurityConfig, type HoloConfigRegistry, } from '../src' +import { createConfigAccessorFixture } from './support/configAccessors' describe('@holo-js/config security typing', () => { it('preserves security inference through config helpers and dot-path access', () => { @@ -34,21 +34,9 @@ describe('@holo-js/config security typing', () => { statefulDomains: ['app.example.com'], }) - const accessors = createConfigAccessors({ - app: {} as HoloConfigRegistry['app'], - database: {} as HoloConfigRegistry['database'], - redis: {} as HoloConfigRegistry['redis'], - cache: {} as HoloConfigRegistry['cache'], + const accessors = createConfigAccessorFixture({ cors: cors as unknown as HoloConfigRegistry['cors'], - storage: {} as HoloConfigRegistry['storage'], - queue: {} as HoloConfigRegistry['queue'], - broadcast: {} as HoloConfigRegistry['broadcast'], - mail: {} as HoloConfigRegistry['mail'], - notifications: {} as HoloConfigRegistry['notifications'], - media: {} as HoloConfigRegistry['media'], - session: {} as HoloConfigRegistry['session'], security: security as unknown as HoloConfigRegistry['security'], - auth: {} as HoloConfigRegistry['auth'], services: { mailgun: { secret: 'secret', diff --git a/packages/config/tests/support/configAccessors.ts b/packages/config/tests/support/configAccessors.ts new file mode 100644 index 00000000..985f32ea --- /dev/null +++ b/packages/config/tests/support/configAccessors.ts @@ -0,0 +1,42 @@ +import { + createConfigAccessors, + type HoloConfigMap, + type HoloConfigRegistry, +} from '../../src' + +type ServicesFixtureConfig = { + readonly mailgun: { + readonly secret: string + } +} + +const defaultConfigRegistry = { + app: {} as HoloConfigRegistry['app'], + database: {} as HoloConfigRegistry['database'], + redis: {} as HoloConfigRegistry['redis'], + cache: {} as HoloConfigRegistry['cache'], + cors: {} as HoloConfigRegistry['cors'], + storage: {} as HoloConfigRegistry['storage'], + queue: {} as HoloConfigRegistry['queue'], + broadcast: {} as HoloConfigRegistry['broadcast'], + mail: {} as HoloConfigRegistry['mail'], + notifications: {} as HoloConfigRegistry['notifications'], + media: {} as HoloConfigRegistry['media'], + session: {} as HoloConfigRegistry['session'], + security: {} as HoloConfigRegistry['security'], + auth: {} as HoloConfigRegistry['auth'], + services: { + mailgun: { + secret: 'secret', + }, + }, +} satisfies HoloConfigRegistry & { readonly services: ServicesFixtureConfig } + +export function createConfigAccessorFixture( + overrides: TOverrides, +) { + return createConfigAccessors({ + ...defaultConfigRegistry, + ...overrides, + }) +} diff --git a/packages/core/src/portable/holo.ts b/packages/core/src/portable/holo.ts index 8b8fe5bf..ea5e64a8 100644 --- a/packages/core/src/portable/holo.ts +++ b/packages/core/src/portable/holo.ts @@ -1,9 +1,7 @@ import { AsyncLocalStorage } from 'node:async_hooks' import { existsSync } from 'node:fs' import { createHash, createHmac } from 'node:crypto' -import { createRequire } from 'node:module' import { resolve } from 'node:path' -import { pathToFileURL } from 'node:url' import { config as globalConfig, configureConfigRuntime, @@ -23,7 +21,7 @@ import { Entity, resetDB, } from '@holo-js/db' -import { importBundledRuntimeModule } from '../runtimeModule' +import { importBundledRuntimeModule, importOptionalRuntimeModule } from '../runtimeModule' import { resolveRuntimeConnectionManagerOptions } from './dbRuntime' import { loadGeneratedProjectRegistry, type GeneratedProjectRegistry } from './registry' import { configurePlainNodeStorageRuntime, resetOptionalStorageRuntime } from '../storageRuntime' @@ -1071,54 +1069,13 @@ function restoreOptionalSubsystemRuntimeBindings( const BROADCAST_PUBLISH_TIMEOUT_MS = 10_000 -const portableRuntimeRequire = createRequire(import.meta.url) - -function resolveOptionalImportSpecifier(specifier: string, projectRoot?: string): string { - if (!projectRoot) { - return specifier - } - - try { - const resolved = portableRuntimeRequire.resolve(specifier, { - paths: [projectRoot], - }) - return pathToFileURL(resolved).href - } catch { - return specifier - } -} - async function importOptionalModule( specifier: string, options: { readonly projectRoot?: string } = {}, ): Promise { - const resolvedSpecifier = resolveOptionalImportSpecifier(specifier, options.projectRoot) - - try { - return await import(/* webpackIgnore: true */ resolvedSpecifier as string) as TModule - } catch (error) { - /* v8 ignore start -- optional-package absence is validated in published-package integration, not in this monorepo test graph */ - if ( - error instanceof Error - && ( - error.message.includes(`Cannot find package '${specifier}'`) - || error.message.includes(`Cannot find module '${specifier}'`) - || error.message.includes(`Failed to load url ${specifier}`) - || error.message.includes(`Could not resolve "${specifier}"`) - || error.message.includes(`Cannot find package '${resolvedSpecifier}'`) - || error.message.includes(`Cannot find module '${resolvedSpecifier}'`) - || error.message.includes(`Failed to load url ${resolvedSpecifier}`) - || error.message.includes(`Could not resolve "${resolvedSpecifier}"`) - ) - ) { - return undefined - } - /* v8 ignore stop */ - - throw error - } + return importOptionalRuntimeModule(specifier, options) } const portableRuntimeModuleInternals = { diff --git a/packages/core/src/runtimeModule.ts b/packages/core/src/runtimeModule.ts index 9a3e15cc..e582ac18 100644 --- a/packages/core/src/runtimeModule.ts +++ b/packages/core/src/runtimeModule.ts @@ -1,6 +1,7 @@ import { mkdtemp, mkdir, rm, stat, writeFile } from 'node:fs/promises' +import { createRequire } from 'node:module' import { basename, extname, join } from 'node:path' -import { pathToFileURL } from 'node:url' +import { fileURLToPath, pathToFileURL } from 'node:url' import type { BuildOptions, BuildResult } from 'esbuild' type EsbuildModule = { @@ -15,6 +16,96 @@ async function importModule(specifier: string): Promise { return import(/* webpackIgnore: true */ specifier) as Promise } +const runtimeModuleRequire = createRequire(import.meta.url) + +function resolveOptionalImportSpecifier(specifier: string, projectRoot?: string): string { + if (!projectRoot) { + return specifier + } + + try { + return pathToFileURL(runtimeModuleRequire.resolve(specifier, { + paths: [projectRoot], + })).href + } catch { + return specifier + } +} + +function getErrorMessage(error: unknown): string { + return error && typeof error === 'object' && 'message' in error + && typeof (error as { message?: unknown }).message === 'string' + ? (error as { message: string }).message + : '' +} + +function getMissingModuleTarget(message: string): string | undefined { + const match = message.match(/Cannot find package '([^']+)'|Cannot find module '([^']+)'|Failed to load url ([^ ]+)|Could not resolve "([^"]+)"/) + return match?.slice(1).find((value): value is string => typeof value === 'string') +} + +function normalizeImportSpecifier(specifier: string): string { + return specifier.startsWith('file://') ? fileURLToPath(specifier) : specifier +} + +function normalizeImportTarget(value: string): string { + return normalizeImportSpecifier(value).replace(/\\/g, '/') +} + +function matchesRelativeImportTarget(failedTarget: string | undefined, specifier: string): boolean { + if (!failedTarget || !specifier.startsWith('.')) { + return false + } + + const suffix = specifier.startsWith('./') ? specifier.slice(2) : specifier + return normalizeImportTarget(failedTarget).endsWith(`/${suffix}`) +} + +function isMissingOptionalModule(error: unknown, specifier: string, resolvedSpecifier: string): boolean { + if (!error || typeof error !== 'object') { + return false + } + + const message = getErrorMessage(error) + const failedTarget = getMissingModuleTarget(message) + const expectedTargets = new Set([ + specifier, + resolvedSpecifier, + normalizeImportTarget(specifier), + normalizeImportTarget(resolvedSpecifier), + ]) + const normalizedFailedTarget = typeof failedTarget === 'string' ? normalizeImportTarget(failedTarget) : undefined + const matchesRequestedTarget = typeof normalizedFailedTarget === 'string' + && (expectedTargets.has(normalizedFailedTarget) || matchesRelativeImportTarget(normalizedFailedTarget, specifier)) + + return ( + ('code' in error && (error as { code?: unknown }).code === 'ERR_MODULE_NOT_FOUND' && matchesRequestedTarget) + || (message.startsWith('Cannot find package \'') && matchesRequestedTarget) + || (message.startsWith('Cannot find module \'') && matchesRequestedTarget) + || (message.includes('Does the file exist?') && message.startsWith('Failed to load url ') && matchesRequestedTarget) + || (message.startsWith('Could not resolve "') && matchesRequestedTarget) + ) +} + +export async function importOptionalRuntimeModule( + specifier: string, + options: { + readonly projectRoot?: string + } = {}, +): Promise { + const resolvedSpecifier = resolveOptionalImportSpecifier(specifier, options.projectRoot) + + try { + return await importModule(resolvedSpecifier) + } catch (error) { + if (isMissingOptionalModule(error, specifier, resolvedSpecifier)) { + return undefined + } + + throw error + } +} + async function pathExists(path: string): Promise { try { await stat(path) @@ -144,6 +235,7 @@ async function runEsbuild(options: BuildOptions): Promise { export const runtimeModuleInternals = { bundleRuntimeModule, importModule, + importOptionalRuntimeModule, loadEsbuild, pathExists, runEsbuild, diff --git a/packages/core/src/storageRuntime.ts b/packages/core/src/storageRuntime.ts index 3f18b1d6..6978b8d2 100644 --- a/packages/core/src/storageRuntime.ts +++ b/packages/core/src/storageRuntime.ts @@ -1,7 +1,7 @@ import { mkdir, readFile, readdir, rm, writeFile } from 'node:fs/promises' import { dirname, isAbsolute, join, relative, resolve, sep, win32 } from 'node:path' -import { fileURLToPath } from 'node:url' import type { LoadedHoloConfig, HoloConfigMap } from '@holo-js/config' +import { importOptionalRuntimeModule } from './runtimeModule' type StorageBackend = { getItem(key: string): Promise @@ -67,42 +67,7 @@ type StorageS3Module = { } async function importOptionalModule(specifier: string): Promise { - try { - if (process.env.VITEST) { - return await import(/* @vite-ignore */ specifier) as TModule - } - - return await import(/* webpackIgnore: true */ specifier) as TModule - } catch (error) { - const message = error && typeof error === 'object' && 'message' in error - && typeof (error as { message?: unknown }).message === 'string' - ? (error as { message: string }).message - : '' - const resolvedSpecifier = specifier.startsWith('file://') - ? fileURLToPath(specifier) - : specifier - const failedTarget = message.match(/Cannot find package '([^']+)'|Cannot find module '([^']+)'|Failed to load url ([^ ]+)/)?.slice(1) - .find((value): value is string => typeof value === 'string') - const matchesRequestedTarget = failedTarget === specifier || failedTarget === resolvedSpecifier - const isMissingOptionalModule = error && typeof error === 'object' - && ( - ('code' in error && (error as { code?: unknown }).code === 'ERR_MODULE_NOT_FOUND') - /* v8 ignore start -- these fallback string variants depend on the host loader's exact missing-module wording */ - // Node reports missing ESM packages as "Cannot find package ''". - || (message.startsWith('Cannot find package \'') && matchesRequestedTarget) - // Node reports missing CJS/URL imports as "Cannot find module ''". - || (message.startsWith('Cannot find module \'') && matchesRequestedTarget) - /* v8 ignore stop */ - // Vite reports unresolved optional imports as "Failed to load url ... Does the file exist?". - || (message.includes('Does the file exist?') && message.startsWith('Failed to load url ') && matchesRequestedTarget) - ) - - if (isMissingOptionalModule) { - return undefined - } - - throw error - } + return importOptionalRuntimeModule(specifier) } export function resolveStorageKeyPath(root: string, key: string): string { diff --git a/packages/core/tests/dbRuntime.test.ts b/packages/core/tests/dbRuntime.test.ts index 545c9a56..3a72632f 100644 --- a/packages/core/tests/dbRuntime.test.ts +++ b/packages/core/tests/dbRuntime.test.ts @@ -1,489 +1,13 @@ -import { describe, expect, it, vi } from 'vitest' -import { - createAdapter, - createDialect, - createRuntimeConnectionOptions, - createRuntimeLogger, - isSupportedDatabaseDriver, - parseDatabaseDriver, - resolveRuntimeConnectionManagerOptions, -} from '../src/portable/dbRuntime' - -type AdapterHarness = { - config?: Record -} +import { describe, expect, it } from 'vitest' +import * as dbRuntime from '@holo-js/db' +import * as core from '../src' +import * as portable from '../src/portable/dbRuntime' describe('core db runtime bootstrap', () => { - it('does not create a logger or log security override when db logging is disabled', () => { - const options = createRuntimeConnectionOptions('sqlite', './data.sqlite', false) - - expect(options.logger).toBeUndefined() - expect(options.security).toBeUndefined() - }) - - it('creates a logger and enables visible SQL/bindings when db logging is enabled', () => { - const options = createRuntimeConnectionOptions('sqlite', './data.sqlite', true) - - expect(options.logger).toBeDefined() - expect(options.security).toEqual({ - debugSqlInLogs: true, - redactBindingsInLogs: false, - }) - }) - - it('propagates an optional schema name into runtime connection options', () => { - const options = createRuntimeConnectionOptions('mysql', 'mysql://db', false, 'analytics') - - expect(options.schemaName).toBe('analytics') - }) - - it('creates Postgres runtime adapters from structured credentials when no URL is provided', () => { - const options = createRuntimeConnectionOptions('postgres', { - host: 'db.internal', - port: 5432, - username: 'app', - password: 'secret', - database: 'primary', - ssl: true, - }, false, 'public', 'primary') - - const adapter = options.adapter as unknown as AdapterHarness - - expect(adapter.config).toMatchObject({ - host: 'db.internal', - port: 5432, - user: 'app', - password: 'secret', - database: 'primary', - ssl: true, - }) - }) - - it('resolves documented multi-connection runtime config shapes', () => { - const options = resolveRuntimeConnectionManagerOptions({ - db: { - defaultConnection: 'analytics', - connections: { - primary: { - driver: 'sqlite', - database: './storage/app.sqlite', - }, - analytics: { - driver: 'postgres', - url: 'postgresql://analytics', - schema: 'warehouse', - logging: true, - }, - }, - }, - }) - - expect(options.getDefaultConnectionName()).toBe('analytics') - expect(options.getConnectionNames()).toEqual(['primary', 'analytics']) - expect(options.connection('primary').getDriver()).toBe('sqlite') - expect(options.connection('primary').getSchemaName()).toBeUndefined() - expect(options.connection('analytics').getDriver()).toBe('postgres') - expect(options.connection('analytics').getSchemaName()).toBe('warehouse') - expect(options.connection('analytics').getLogger()).toBeDefined() - }) - - it('uses the sole named connection as default when no explicit default is configured', () => { - const options = resolveRuntimeConnectionManagerOptions({ - db: { - connections: { - analytics: { - driver: 'postgres', - url: 'postgresql://analytics', - }, - }, - }, - }) - - expect(options.getDefaultConnectionName()).toBe('analytics') - expect(options.getConnectionNames()).toEqual(['analytics']) - }) - - it('resolves structured credential fields for named network connections', () => { - const options = resolveRuntimeConnectionManagerOptions({ - db: { - defaultConnection: 'primary', - connections: { - primary: { - driver: 'postgres', - host: 'db.internal', - port: 5432, - username: 'app', - password: 'secret', - database: 'primary', - ssl: true, - }, - analytics: { - driver: 'mysql', - host: 'mysql.internal', - port: '3306', - username: 'reporter', - password: 'top-secret', - database: 'analytics', - schema: 'warehouse', - }, - }, - }, - }) - - const primaryAdapter = options.connection('primary').getAdapter() as unknown as AdapterHarness - const analyticsAdapter = options.connection('analytics').getAdapter() as unknown as AdapterHarness - - expect(options.connection('primary').getDriver()).toBe('postgres') - expect(primaryAdapter.config).toMatchObject({ - host: 'db.internal', - port: 5432, - user: 'app', - password: 'secret', - database: 'primary', - ssl: true, - }) - expect(options.connection('analytics').getDriver()).toBe('mysql') - expect(options.connection('analytics').getSchemaName()).toBe('warehouse') - expect(analyticsAdapter.config).toMatchObject({ - host: 'mysql.internal', - port: 3306, - user: 'reporter', - password: 'top-secret', - database: 'analytics', - }) - }) - - it('falls back to sqlite defaults when no connection map exists', () => { - const options = resolveRuntimeConnectionManagerOptions({ - holo: {}, - }) - - expect(options.getDefaultConnectionName()).toBe('default') - expect(options.getConnectionNames()).toEqual(['default']) - expect(options.connection().getDriver()).toBe('sqlite') - expect(options.connection().getSchemaName()).toBeUndefined() - expect(options.connection().getLogger()).toBeUndefined() - }) - - it('supports canonical structured credential fields without a URL', () => { - const options = resolveRuntimeConnectionManagerOptions({ - db: { - connections: { - default: { - driver: 'postgres', - host: 'db.internal', - port: '5432', - username: 'app', - password: 'secret', - database: 'primary', - ssl: true, - schema: 'public', - }, - }, - }, - }) - - const adapter = options.connection().getAdapter() as unknown as AdapterHarness - - expect(options.connection().getDriver()).toBe('postgres') - expect(options.connection().getSchemaName()).toBe('public') - expect(adapter.config).toMatchObject({ - host: 'db.internal', - port: 5432, - user: 'app', - password: 'secret', - database: 'primary', - ssl: true, - }) - }) - - it('requires an explicit driver when using host-style connection fields', () => { - expect(() => resolveRuntimeConnectionManagerOptions({ - db: { - connections: { - primary: { - host: 'db.internal', - username: 'app', - password: 'secret', - database: 'primary', - }, - }, - }, - })).toThrow('must declare a database driver when using host, port, username, password, or ssl settings') - }) - - it('marks Postgres and MySQL runtime dialects as alter-capable', () => { - expect(createDialect('sqlite').capabilities.ddlAlterSupport).toBe(false) - expect(createDialect('postgres').capabilities.ddlAlterSupport).toBe(true) - expect(createDialect('mysql').capabilities.ddlAlterSupport).toBe(true) - }) - - it('logs query and transaction lifecycle events when enabled', () => { - const warn = vi.spyOn(console, 'warn').mockImplementation(() => {}) - const error = vi.spyOn(console, 'error').mockImplementation(() => {}) - - const logger = createRuntimeLogger(true) - - expect(logger).toBeDefined() - - logger?.onQuerySuccess?.({ - kind: 'query', - connectionName: 'default', - sql: 'select * from "users"', - bindings: ['ops@example.com'], - scope: 'root', - durationMs: 12, - rowCount: 2, - }) - - logger?.onQueryError?.({ - kind: 'execute', - connectionName: 'default', - sql: 'delete from "users"', - bindings: [], - scope: 'transaction', - durationMs: 7, - error: new Error('boom'), - }) - - logger?.onTransactionStart?.({ - scope: 'transaction', - depth: 1, - }) - - logger?.onTransactionCommit?.({ - scope: 'savepoint', - depth: 2, - savepointName: 'sp_1', - }) - - logger?.onTransactionRollback?.({ - scope: 'savepoint', - depth: 2, - savepointName: 'sp_2', - error: new Error('rollback'), - }) - - expect(warn).toHaveBeenCalledTimes(4) - expect(error).toHaveBeenCalledTimes(1) - expect(warn.mock.calls[0]?.[0]).toContain('[holo:db] query ok connection=default') - expect(warn.mock.calls[0]?.[0]).toContain('sql=select * from "users"') - expect(error.mock.calls[0]?.[0]).toContain('[holo:db] execute error connection=default') - expect(warn.mock.calls[3]?.[0]).toContain('transaction rollback') - - warn.mockRestore() - error.mockRestore() - }) - - it('returns no runtime logger when disabled', () => { - expect(createRuntimeLogger(false)).toBeUndefined() - }) - - it('covers driver parsing helpers and adapter factories', () => { - expect(isSupportedDatabaseDriver('sqlite')).toBe(true) - expect(isSupportedDatabaseDriver('mongo')).toBe(false) - expect(parseDatabaseDriver(undefined, 'mysql')).toBe('mysql') - expect(parseDatabaseDriver('postgres', 'sqlite')).toBe('postgres') - expect(() => parseDatabaseDriver('mongo', 'sqlite')).toThrow('Unsupported Holo database driver') - - expect(createAdapter('postgres', 'postgresql://db.internal/app')).toBeDefined() - expect(createAdapter('postgres', { - host: 'db.internal', - port: 5432, - username: 'app', - password: 'secret', - database: 'primary', - ssl: { rejectUnauthorized: false }, - })).toBeDefined() - - expect(createAdapter('mysql', 'mysql://db.internal/app')).toBeDefined() - expect(createAdapter('mysql', { - host: 'db.internal', - port: 3306, - username: 'app', - password: 'secret', - database: 'primary', - })).toBeDefined() - - expect(createAdapter('sqlite', { database: './data/app.sqlite' })).toBeDefined() - expect(() => createAdapter('mongo' as never, 'mongodb://db')).toThrow('Unsupported Holo database driver') - expect(createDialect('sqlite').createPlaceholder(3)).toBe('?') - }) - - it('covers remaining runtime logger branches', () => { - const warn = vi.spyOn(console, 'warn').mockImplementation(() => {}) - const logger = createRuntimeLogger(true) - - logger?.onQuerySuccess?.({ - kind: 'query', - connectionName: 'analytics', - sql: 'update "users" set "name" = ?', - bindings: [], - scope: 'root', - durationMs: 4, - affectedRows: 3, - }) - - logger?.onTransactionStart?.({ - scope: 'transaction', - depth: 1, - savepointName: 'sp_1', - }) - - logger?.onTransactionCommit?.({ - scope: 'transaction', - depth: 1, - }) - - logger?.onTransactionRollback?.({ - scope: 'transaction', - depth: 1, - error: 'forced rollback', - }) - - logger?.onQuerySuccess?.({ - kind: 'query', - connectionName: 'analytics', - sql: 'select 1', - bindings: [], - scope: 'root', - durationMs: 1, - }) - - expect(warn.mock.calls[0]?.[0]).toContain('affected=3') - expect(warn.mock.calls[1]?.[0]).toContain('savepoint=sp_1') - expect(warn.mock.calls[2]?.[0]).toContain('transaction commit') - expect(warn.mock.calls[3]?.[0]).toContain('error=forced rollback') - expect(warn.mock.calls[4]?.[0]).toContain('duration=1ms sql=select 1') - - warn.mockRestore() - }) - - it('falls back to sqlite defaults when no runtime config is provided', () => { - const options = resolveRuntimeConnectionManagerOptions({}) - - expect(options.getDefaultConnectionName()).toBe('default') - expect(options.connection().getDriver()).toBe('sqlite') - }) - - it('normalizes undefined connection entries as empty configs', () => { - const options = resolveRuntimeConnectionManagerOptions({ - db: { - connections: { - default: undefined as never, - }, - }, - }) - - expect(options.getDefaultConnectionName()).toBe('default') - expect(options.connection().getDriver()).toBe('sqlite') - }) - - it('supports sqlite filename and invalid port fallbacks in named connections', () => { - const options = resolveRuntimeConnectionManagerOptions({ - db: { - defaultConnection: 'local', - connections: { - local: { - filename: './data/local.sqlite', - port: 'not-a-port', - logging: true, - }, - }, - }, - }) - - expect(options.getDefaultConnectionName()).toBe('local') - expect(options.connection('local').getDriver()).toBe('sqlite') - expect(options.connection('local').getLogger()).toBeDefined() - }) - - it('infers drivers from string connection inputs and default connection names', () => { - const options = resolveRuntimeConnectionManagerOptions({ - db: { - connections: { - default: 'mysql://default.internal/app', - analytics: 'postgresql://analytics.internal/app', - }, - }, - }) - - expect(options.getDefaultConnectionName()).toBe('default') - expect(options.connection('default').getDriver()).toBe('mysql') - expect(options.connection('analytics').getDriver()).toBe('postgres') - }) - - it('infers sqlite drivers from filesystem-style urls', () => { - const options = resolveRuntimeConnectionManagerOptions({ - db: { - connections: { - absolute: '/tmp/app.db', - filedb: 'file:./data/app.sqlite', - memory: ':memory:', - relative: '../data/app.sqlite3', - sqlite: './data/app.sqlite', - }, - }, - }) - - expect(options.connection('absolute').getDriver()).toBe('sqlite') - expect(options.connection('filedb').getDriver()).toBe('sqlite') - expect(options.connection('memory').getDriver()).toBe('sqlite') - expect(options.connection('relative').getDriver()).toBe('sqlite') - expect(options.connection('sqlite').getDriver()).toBe('sqlite') - }) - - it('falls back to the default runtime driver when a url does not imply one', () => { - const options = resolveRuntimeConnectionManagerOptions({ - db: { - connections: { - unknown: 'https://example.test/not-a-db-url', - }, - }, - }) - - expect(options.connection('unknown').getDriver()).toBe('sqlite') - }) - - it('covers adapter fallbacks and logger string branches', () => { - expect(createAdapter('mysql', { - host: 'db.internal', - port: 3306, - username: 'app', - password: 'secret', - database: 'primary', - ssl: true, - })).toBeDefined() - expect(createAdapter('sqlite', {})).toBeDefined() - expect(createDialect('postgres').quoteIdentifier('users.email')).toBe('"users"."email"') - expect(createDialect('postgres').createPlaceholder(2)).toBe('$2') - expect(createDialect('mysql').quoteIdentifier('users.email')).toBe('`users`.`email`') - expect(createDialect('mysql').createPlaceholder(9)).toBe('?') - expect(createDialect('sqlite').quoteIdentifier('users')).toBe('"users"') - - const warn = vi.spyOn(console, 'warn').mockImplementation(() => {}) - const error = vi.spyOn(console, 'error').mockImplementation(() => {}) - const logger = createRuntimeLogger(true) - - logger?.onQueryError?.({ - kind: 'query', - connectionName: 'default', - sql: 'select 1', - bindings: [], - scope: 'root', - durationMs: 1, - error: 'boom', - }) - - logger?.onTransactionRollback?.({ - scope: 'transaction', - depth: 1, - }) - - expect(error.mock.calls[0]?.[0]).toContain('error=boom') - expect(warn.mock.calls[0]?.[0]).toContain('transaction rollback scope=transaction depth=1') - - warn.mockRestore() - error.mockRestore() + it('preserves the core db runtime re-export contract', () => { + expect(portable.createRuntimeConnectionOptions).toBe(dbRuntime.createRuntimeConnectionOptions) + expect(portable.resolveRuntimeConnectionManagerOptions).toBe(dbRuntime.resolveRuntimeConnectionManagerOptions) + expect(core.createRuntimeConnectionOptions).toBe(dbRuntime.createRuntimeConnectionOptions) + expect(core.resolveRuntimeConnectionManagerOptions).toBe(dbRuntime.resolveRuntimeConnectionManagerOptions) }) }) diff --git a/packages/core/tests/runtime.test.ts b/packages/core/tests/runtime.test.ts index 082db50c..bbd883af 100644 --- a/packages/core/tests/runtime.test.ts +++ b/packages/core/tests/runtime.test.ts @@ -18,7 +18,8 @@ import { notify, resetNotificationsRuntime, } from '@holo-js/notifications' -import { configureMailRuntime, listFakeSentMails, mailRuntimeInternals, previewMail, resetFakeSentMails } from '@holo-js/mail' +import { configureMailRuntime, listFakeSentMails, previewMail, resetFakeSentMails } from '@holo-js/mail' +import { mailRuntimeInternals } from '../../mail/src/runtime' import { Event, defineEvent, diff --git a/packages/core/tests/runtimeModule.test.ts b/packages/core/tests/runtimeModule.test.ts index 26461765..a9b83b8e 100644 --- a/packages/core/tests/runtimeModule.test.ts +++ b/packages/core/tests/runtimeModule.test.ts @@ -91,6 +91,22 @@ describe('@holo-js/core runtime module helpers', () => { } }) + it('does not hide missing transitive dependencies in optional runtime modules', async () => { + const projectRoot = await createTempProject() + const missingPath = join(projectRoot, 'missing.mjs') + const entryPath = join(projectRoot, 'optional.mjs') + + await expect( + runtimeModuleInternals.importOptionalRuntimeModule(pathToFileURL(missingPath).href), + ).resolves.toBeUndefined() + + await writeFile(entryPath, 'import "./missing-child.mjs"\nexport const loaded = true\n', 'utf8') + + await expect( + runtimeModuleInternals.importOptionalRuntimeModule(pathToFileURL(entryPath).href), + ).rejects.toThrow() + }) + it('returns the default esbuild export when the imported module does not expose build directly', async () => { await expect(runtimeModuleInternals.loadEsbuild()).resolves.toEqual(expect.objectContaining({ build: expect.any(Function), diff --git a/packages/core/tests/storageRuntime.test.ts b/packages/core/tests/storageRuntime.test.ts index 93f9d300..345c9668 100644 --- a/packages/core/tests/storageRuntime.test.ts +++ b/packages/core/tests/storageRuntime.test.ts @@ -54,19 +54,17 @@ describe('@holo-js/core storage runtime optional imports', () => { await expect(resetOptionalStorageRuntime()).rejects.toBe('boom') }) - it('imports optional storage modules through the webpackIgnore branch outside Vitest', async () => { + it('imports optional storage modules through the shared optional runtime loader', async () => { const root = await mkdtemp(join(tmpdir(), 'holo-storage-runtime-')) tempDirs.push(root) const modulePath = join(root, 'module.mjs') await writeFile(modulePath, 'export default "loaded"\n', 'utf8') - await withoutVitestEnv(async () => { - await expect(storageRuntimeInternals.importOptionalModule(pathToFileURL(modulePath).href)).resolves.toEqual( - expect.objectContaining({ - default: 'loaded', - }), - ) - }) + await expect(storageRuntimeInternals.importOptionalModule(pathToFileURL(modulePath).href)).resolves.toEqual( + expect.objectContaining({ + default: 'loaded', + }), + ) }) it('treats missing optional storage modules as optional inside Vitest as well', async () => { @@ -90,36 +88,6 @@ describe('@holo-js/core storage runtime optional imports', () => { }) }) - it('rethrows module evaluation failures with a non-matching error code outside Vitest', async () => { - const root = await mkdtemp(join(tmpdir(), 'holo-storage-runtime-code-boom-')) - tempDirs.push(root) - const modulePath = join(root, 'code-boom.mjs') - await writeFile(modulePath, 'throw Object.assign(new Error("boom"), { code: "E_CUSTOM" })\n', 'utf8') - - await withoutVitestEnv(async () => { - await expect(storageRuntimeInternals.importOptionalModule(pathToFileURL(modulePath).href)).rejects.toThrow('boom') - }) - }) - - it('rethrows module evaluation failures without an Error object outside Vitest', async () => { - const root = await mkdtemp(join(tmpdir(), 'holo-storage-runtime-string-boom-')) - tempDirs.push(root) - const modulePath = join(root, 'string-boom.mjs') - await writeFile(modulePath, 'throw "boom"\n', 'utf8') - - await withoutVitestEnv(async () => { - await expect(storageRuntimeInternals.importOptionalModule(pathToFileURL(modulePath).href)).rejects.toBe('boom') - }) - }) - - it('treats module resolution failures with a resolver message as optional outside Vitest', async () => { - const root = await mkdtemp(join(tmpdir(), 'holo-storage-runtime-resolve-')) - tempDirs.push(root) - await withoutVitestEnv(async () => { - await expect(storageRuntimeInternals.importOptionalModule(pathToFileURL(join(root, 'missing.mjs')).href)).resolves.toBeUndefined() - }) - }) - it('does not treat unrelated "Failed to load url" failures as missing modules', async () => { const root = await mkdtemp(join(tmpdir(), 'holo-storage-runtime-failed-url-')) tempDirs.push(root) @@ -130,18 +98,6 @@ describe('@holo-js/core storage runtime optional imports', () => { }) }) - it('treats matching Vite missing-url messages as optional outside Vitest', async () => { - const root = await mkdtemp(join(tmpdir(), 'holo-storage-runtime-vite-missing-')) - tempDirs.push(root) - const modulePath = join(root, 'vite-missing.mjs') - const moduleUrl = pathToFileURL(modulePath).href - await writeFile(modulePath, `throw new Error(${JSON.stringify(`Failed to load url ${moduleUrl} Does the file exist?`)})\n`, 'utf8') - - await withoutVitestEnv(async () => { - await expect(storageRuntimeInternals.importOptionalModule(moduleUrl)).resolves.toBeUndefined() - }) - }) - it('returns undefined for missing optional storage modules outside Vitest', async () => { const root = await mkdtemp(join(tmpdir(), 'holo-storage-runtime-missing-')) tempDirs.push(root) @@ -150,23 +106,4 @@ describe('@holo-js/core storage runtime optional imports', () => { }) }) - it('treats matching "Cannot find module" loader messages as optional outside Vitest', async () => { - const root = await mkdtemp(join(tmpdir(), 'holo-storage-runtime-cannot-find-')) - tempDirs.push(root) - const modulePath = join(root, 'cannot-find.mjs') - const moduleUrl = pathToFileURL(modulePath).href - await writeFile(modulePath, `throw new Error(${JSON.stringify(`Cannot find module '${moduleUrl}'`)})\n`, 'utf8') - - await withoutVitestEnv(async () => { - await expect(storageRuntimeInternals.importOptionalModule(moduleUrl)).resolves.toBeUndefined() - }) - }) - - it('treats missing absolute-path optional storage modules as optional outside Vitest', async () => { - const root = await mkdtemp(join(tmpdir(), 'holo-storage-runtime-absolute-missing-')) - tempDirs.push(root) - await withoutVitestEnv(async () => { - await expect(storageRuntimeInternals.importOptionalModule(join(root, 'missing.mjs'))).resolves.toBeUndefined() - }) - }) }) diff --git a/packages/db-mysql/package.json b/packages/db-mysql/package.json index 7606f1e3..0d803f55 100644 --- a/packages/db-mysql/package.json +++ b/packages/db-mysql/package.json @@ -20,7 +20,8 @@ "build": "tsup", "stub": "tsup --watch", "typecheck": "tsc -p tsconfig.json --noEmit", - "test": "vitest --run" + "test": "vitest --run", + "test:integration": "HOLO_MYSQL_INTEGRATION=1 vitest --run tests/mysql.test.ts" }, "peerDependencies": { "@holo-js/db": "catalog:" diff --git a/packages/db-mysql/src/index.ts b/packages/db-mysql/src/index.ts index 49126f44..ace9684d 100644 --- a/packages/db-mysql/src/index.ts +++ b/packages/db-mysql/src/index.ts @@ -6,37 +6,9 @@ import mysql, { type ResultSetHeader, type RowDataPacket, } from 'mysql2/promise' +import type { DriverAdapter, DriverExecutionResult, DriverQueryResult } from '@holo-js/db' -export interface DriverQueryResult = Record> { - rows: TRow[] - rowCount: number -} - -export interface DriverExecutionResult { - affectedRows?: number - lastInsertId?: number | string -} - -export interface DriverAdapter { - initialize(): Promise - disconnect(): Promise - isConnected(): boolean - runWithTransactionScope?(callback: () => Promise): Promise - query = Record>( - sql: string, - bindings?: readonly unknown[], - ): Promise> - execute( - sql: string, - bindings?: readonly unknown[], - ): Promise - beginTransaction(): Promise - commit(): Promise - rollback(): Promise - createSavepoint?(name: string): Promise - rollbackToSavepoint?(name: string): Promise - releaseSavepoint?(name: string): Promise -} +export type { DriverAdapter, DriverExecutionResult, DriverQueryResult } from '@holo-js/db' class TransactionError extends Error {} @@ -65,7 +37,6 @@ export interface MySQLAdapterOptions { type ScopedMySQLTransaction = { client: MySQLClientLike leased: boolean - released: boolean } type RawMySQLClientLike = { @@ -121,8 +92,8 @@ export class MySQLAdapter implements DriverAdapter { private readonly transactionScope = new AsyncLocalStorage() constructor(options: MySQLAdapterOptions = {}) { - this.directClient = options.client ? wrapMySQLClient(options.client) : undefined - this.pool = options.pool ? wrapMySQLPool(options.pool) : undefined + this.directClient = options.client + this.pool = options.pool this.createPoolInstance = options.createPool ?? (options.client || options.pool ? undefined : config => wrapMySQLPool(mysql.createPool(config))) @@ -179,7 +150,6 @@ export class MySQLAdapter implements DriverAdapter { return this.transactionScope.run({ client: this.directClient, leased: false, - released: false, }, callback) } @@ -190,7 +160,6 @@ export class MySQLAdapter implements DriverAdapter { const state: ScopedMySQLTransaction = { client: await this.pool.getConnection(), leased: true, - released: false, } return this.transactionScope.run(state, async () => { @@ -237,31 +206,31 @@ export class MySQLAdapter implements DriverAdapter { async beginTransaction(): Promise { const client = await this.leaseTransactionClient() - await client.query('START TRANSACTION') + await client.query('START TRANSACTION', []) } async commit(): Promise { const client = this.requireTransactionClient() - await client.query('COMMIT') + await client.query('COMMIT', []) this.releaseTransactionClient() } async rollback(): Promise { const client = this.requireTransactionClient() - await client.query('ROLLBACK') + await client.query('ROLLBACK', []) this.releaseTransactionClient() } async createSavepoint(name: string): Promise { - await this.requireTransactionClient().query(`SAVEPOINT ${this.normalizeSavepointName(name)}`) + await this.requireTransactionClient().query(`SAVEPOINT ${this.normalizeSavepointName(name)}`, []) } async rollbackToSavepoint(name: string): Promise { - await this.requireTransactionClient().query(`ROLLBACK TO SAVEPOINT ${this.normalizeSavepointName(name)}`) + await this.requireTransactionClient().query(`ROLLBACK TO SAVEPOINT ${this.normalizeSavepointName(name)}`, []) } async releaseSavepoint(name: string): Promise { - await this.requireTransactionClient().query(`RELEASE SAVEPOINT ${this.normalizeSavepointName(name)}`) + await this.requireTransactionClient().query(`RELEASE SAVEPOINT ${this.normalizeSavepointName(name)}`, []) } private async getQueryable(): Promise { @@ -337,12 +306,11 @@ export class MySQLAdapter implements DriverAdapter { } private releaseScopedTransaction(state: ScopedMySQLTransaction): void { - if (!state.leased || state.released) { + if (!state.leased) { return } state.client.release?.() - state.released = true } private normalizeSavepointName(name: string): string { diff --git a/packages/db-mysql/tests/mysql.test.ts b/packages/db-mysql/tests/mysql.test.ts index 799d0e69..78b904ce 100644 --- a/packages/db-mysql/tests/mysql.test.ts +++ b/packages/db-mysql/tests/mysql.test.ts @@ -1,7 +1,10 @@ import { randomUUID } from 'node:crypto' import { describe, expect, it, vi } from 'vitest' +import type { DriverAdapter } from '@holo-js/db' import { createMySQLAdapter } from '../src' +const runLiveMySql = process.env.HOLO_MYSQL_INTEGRATION === '1' ? it : it.skip + describe('@holo-js/db-mysql', () => { it('supports direct clients without creating a pool', async () => { const query = vi.fn(async (sql: string) => { @@ -20,6 +23,7 @@ describe('@holo-js/db-mysql', () => { end: vi.fn(async () => {}), }, }) + const canonicalAdapter: DriverAdapter = adapter await expect(adapter.query('select 1')).resolves.toEqual({ rows: [{ sql: 'select 1' }], @@ -34,9 +38,10 @@ describe('@holo-js/db-mysql', () => { await adapter.rollback() expect(query).toHaveBeenNthCalledWith(3, 'START TRANSACTION', []) expect(query).toHaveBeenNthCalledWith(4, 'ROLLBACK', []) + void canonicalAdapter }) - it('runs queries against a local MySQL server through the public adapter', async () => { + runLiveMySql('runs queries against a local MySQL server through the public adapter', async () => { const tableName = `holo_real_usage_mysql_${randomUUID().replaceAll('-', '_')}` const adapter = createMySQLAdapter({ config: { diff --git a/packages/db-postgres/package.json b/packages/db-postgres/package.json index aabb3416..7006b215 100644 --- a/packages/db-postgres/package.json +++ b/packages/db-postgres/package.json @@ -20,7 +20,8 @@ "build": "tsup", "stub": "tsup --watch", "typecheck": "tsc -p tsconfig.json --noEmit", - "test": "vitest --run" + "test": "vitest --run", + "test:integration": "HOLO_POSTGRES_INTEGRATION=1 vitest --run tests/postgres.test.ts" }, "peerDependencies": { "@holo-js/db": "catalog:" diff --git a/packages/db-postgres/src/index.ts b/packages/db-postgres/src/index.ts index d14c40e8..cd48fd5e 100644 --- a/packages/db-postgres/src/index.ts +++ b/packages/db-postgres/src/index.ts @@ -1,36 +1,12 @@ import { AsyncLocalStorage } from 'node:async_hooks' import { Pool, type PoolConfig, type QueryResult } from 'pg' +import type { + DriverAdapter, + DriverExecutionResult, + DriverQueryResult, +} from '@holo-js/db' -export interface DriverQueryResult = Record> { - rows: TRow[] - rowCount: number -} - -export interface DriverExecutionResult { - affectedRows?: number - lastInsertId?: number | string -} - -export interface DriverAdapter { - initialize(): Promise - disconnect(): Promise - isConnected(): boolean - runWithTransactionScope?(callback: () => Promise): Promise - query = Record>( - sql: string, - bindings?: readonly unknown[], - ): Promise> - execute( - sql: string, - bindings?: readonly unknown[], - ): Promise - beginTransaction(): Promise - commit(): Promise - rollback(): Promise - createSavepoint?(name: string): Promise - rollbackToSavepoint?(name: string): Promise - releaseSavepoint?(name: string): Promise -} +export type { DriverAdapter, DriverExecutionResult, DriverQueryResult } from '@holo-js/db' class TransactionError extends Error {} diff --git a/packages/db-postgres/tests/postgres.test.ts b/packages/db-postgres/tests/postgres.test.ts index 16a824bf..051e8cec 100644 --- a/packages/db-postgres/tests/postgres.test.ts +++ b/packages/db-postgres/tests/postgres.test.ts @@ -1,7 +1,10 @@ import { randomUUID } from 'node:crypto' import { describe, expect, it, vi } from 'vitest' +import type { DriverAdapter } from '@holo-js/db' import { createPostgresAdapter } from '../src' +const runLivePostgres = process.env.HOLO_POSTGRES_INTEGRATION === '1' ? it : it.skip + describe('@holo-js/db-postgres', () => { it('supports direct clients without creating a pool', async () => { const query = vi.fn(async (sql: string) => { @@ -23,6 +26,7 @@ describe('@holo-js/db-postgres', () => { end: vi.fn(async () => {}), }, }) + const canonicalAdapter: DriverAdapter = adapter await expect(adapter.query('select 1')).resolves.toEqual({ rows: [{ sql: 'select 1' }], @@ -37,9 +41,10 @@ describe('@holo-js/db-postgres', () => { await adapter.commit() expect(query).toHaveBeenNthCalledWith(3, 'BEGIN') expect(query).toHaveBeenNthCalledWith(4, 'COMMIT') + void canonicalAdapter }) - it('runs queries against a local Postgres server through the public adapter', async () => { + runLivePostgres('runs queries against a local Postgres server through the public adapter', async () => { const tableName = `holo_real_usage_postgres_${randomUUID().replaceAll('-', '_')}` const adapter = createPostgresAdapter({ config: { diff --git a/packages/db-sqlite/src/index.ts b/packages/db-sqlite/src/index.ts index 47938c70..d6c6a85d 100644 --- a/packages/db-sqlite/src/index.ts +++ b/packages/db-sqlite/src/index.ts @@ -1,35 +1,11 @@ import Database from 'better-sqlite3' +import type { + DriverAdapter, + DriverExecutionResult, + DriverQueryResult, +} from '@holo-js/db' -export interface DriverQueryResult = Record> { - rows: TRow[] - rowCount: number -} - -export interface DriverExecutionResult { - affectedRows?: number - lastInsertId?: number | string -} - -export interface DriverAdapter { - initialize(): Promise - disconnect(): Promise - isConnected(): boolean - runWithTransactionScope?(callback: () => Promise): Promise - query = Record>( - sql: string, - bindings?: readonly unknown[], - ): Promise> - execute( - sql: string, - bindings?: readonly unknown[], - ): Promise - beginTransaction(): Promise - commit(): Promise - rollback(): Promise - createSavepoint?(name: string): Promise - rollbackToSavepoint?(name: string): Promise - releaseSavepoint?(name: string): Promise -} +export type { DriverAdapter, DriverExecutionResult, DriverQueryResult } from '@holo-js/db' class TransactionError extends Error {} @@ -112,7 +88,7 @@ export class SQLiteAdapter implements DriverAdapter { bindings: readonly unknown[] = [], ): Promise> { const statement = this.getDatabase().prepare(sql) - const rows = this.invokeStatement(statement, 'all', bindings) as TRow[] + const rows = statement.all(...bindings) as TRow[] return { rows, rowCount: rows.length, @@ -131,7 +107,7 @@ export class SQLiteAdapter implements DriverAdapter { bindings: readonly unknown[] = [], ): Promise { const statement = this.getDatabase().prepare(sql) - const result = this.invokeStatement(statement, 'run', bindings) + const result = statement.run(...bindings) return { affectedRows: result.changes, lastInsertId: typeof result.lastInsertRowid === 'bigint' @@ -180,32 +156,6 @@ export class SQLiteAdapter implements DriverAdapter { return name } - - private invokeStatement< - TMethod extends 'all' | 'run', - >( - statement: SQLiteStatementLike, - method: TMethod, - bindings: readonly unknown[], - ): ReturnType { - try { - return statement[method](...bindings) as ReturnType - } catch (error) { - if (bindings.length > 0 && this.isBindingArityError(error)) { - return statement[method](bindings as never) as ReturnType - } - - throw error - } - } - - private isBindingArityError(error: unknown): boolean { - return error instanceof RangeError - && ( - error.message.includes('Too many parameter values were provided') - || error.message.includes('Too few parameter values were provided') - ) - } } export function createSQLiteAdapter(options: SQLiteAdapterOptions = {}): SQLiteAdapter { diff --git a/packages/db/package.json b/packages/db/package.json index e46e3e9e..6c697471 100644 --- a/packages/db/package.json +++ b/packages/db/package.json @@ -47,8 +47,6 @@ "@holo-js/db-mysql": "catalog:", "@holo-js/db-postgres": "catalog:", "@holo-js/db-sqlite": "catalog:", - "@types/better-sqlite3": "catalog:", - "@types/pg": "catalog:", "tsup": "catalog:", "typescript": "catalog:" } diff --git a/packages/db/src/cache.ts b/packages/db/src/cache.ts index bc323919..9f7bb352 100644 --- a/packages/db/src/cache.ts +++ b/packages/db/src/cache.ts @@ -20,13 +20,8 @@ export interface QueryCacheConfig { readonly invalidate?: readonly string[] } -export interface NormalizedQueryCacheConfig { - readonly ttl?: QueryCacheTtlInput - readonly key?: string - readonly driver?: string - readonly flexible?: QueryCacheFlexibleTtlInput - readonly invalidate?: readonly string[] -} +// eslint-disable-next-line @typescript-eslint/no-empty-object-type -- preserve the exported interface shape while deriving from QueryCacheConfig. +export interface NormalizedQueryCacheConfig extends QueryCacheConfig {} export interface DatabaseQueryCacheBridge { get(key: string, options?: { readonly driver?: string }): Promise diff --git a/packages/db/src/core/QueryScheduler.ts b/packages/db/src/core/QueryScheduler.ts index c509afbb..b5241b9b 100644 --- a/packages/db/src/core/QueryScheduler.ts +++ b/packages/db/src/core/QueryScheduler.ts @@ -14,6 +14,7 @@ type QueueState = { active: number queued: number readonly limit: number + readonly waiters: Array<() => void> } export class QueryScheduler { @@ -48,16 +49,19 @@ export class QueryScheduler { active: 0, queued: 0, limit: concurrencyLimit, + waiters: [], } this.serializedState = { active: 0, queued: 0, limit: 1, + waiters: [], } this.workerState = { active: 0, queued: 0, limit: concurrencyLimit, + waiters: [], } } @@ -71,6 +75,8 @@ export class QueryScheduler { const schedulingMode = this.preview(options) const state = this.resolveState(schedulingMode) + let slotReserved = false + if (state.active >= state.limit) { if (state.queued >= this.queueLimit) { throw new DatabaseError( @@ -79,24 +85,13 @@ export class QueryScheduler { ) } - await new Promise((resolve) => { - state.queued += 1 - - const poll = () => { - if (state.active < state.limit) { - state.queued -= 1 - resolve() - return - } - - queueMicrotask(poll) - } - - queueMicrotask(poll) - }) + await this.waitForSlot(state) + slotReserved = true } - state.active += 1 + if (!slotReserved) { + state.active += 1 + } try { return { @@ -105,6 +100,7 @@ export class QueryScheduler { } } finally { state.active -= 1 + this.wakeNext(state) } } @@ -145,6 +141,26 @@ export class QueryScheduler { return this.serializedState } + + private waitForSlot(state: QueueState): Promise { + state.queued += 1 + + return new Promise((resolve) => { + state.waiters.push(() => { + state.queued -= 1 + state.active += 1 + resolve() + }) + }) + } + + private wakeNext(state: QueueState): void { + if (state.active >= state.limit) { + return + } + + state.waiters.shift()?.() + } } export function createQueryScheduler(options: QuerySchedulerOptions): QueryScheduler { diff --git a/packages/db/src/drivers/MySQLAdapter.ts b/packages/db/src/drivers/MySQLAdapter.ts deleted file mode 100644 index a7124730..00000000 --- a/packages/db/src/drivers/MySQLAdapter.ts +++ /dev/null @@ -1,328 +0,0 @@ -import { AsyncLocalStorage } from 'node:async_hooks' -import mysql, { - type Pool, - type PoolConnection, - type PoolOptions, - type ResultSetHeader, - type RowDataPacket, -} from 'mysql2/promise' -import { TransactionError } from '../core/errors' -import type { DriverAdapter, DriverExecutionResult, DriverQueryResult } from '../core/types' - -export interface MySQLQueryableLike { - query(sql: string, bindings?: readonly unknown[]): Promise -} - -export interface MySQLClientLike extends MySQLQueryableLike { - release?(): void - end?(): Promise -} - -export interface MySQLPoolLike extends MySQLQueryableLike { - getConnection(): Promise - end(): Promise -} - -export interface MySQLAdapterOptions { - uri?: string - config?: PoolOptions - client?: MySQLClientLike - pool?: MySQLPoolLike - createPool?: (config: PoolOptions) => MySQLPoolLike -} - -type ScopedMySQLTransaction = { - client: MySQLClientLike - leased: boolean - released: boolean -} - -type RawMySQLClientLike = { - query(sql: string, bindings?: unknown[]): Promise - release?(): void - end?(): Promise -} - -type RawMySQLPoolLike = { - query(sql: string, bindings?: unknown[]): Promise - getConnection(): Promise - end(): Promise -} - -function toMutableBindings(bindings: readonly unknown[] = []): unknown[] { - return [...bindings] -} - -function wrapMySQLClient(client: PoolConnection | MySQLClientLike): MySQLClientLike { - const rawClient = client as unknown as RawMySQLClientLike - - return { - async query(sql: string, bindings: readonly unknown[] = []) { - return rawClient.query(sql, toMutableBindings(bindings)) - }, - release: rawClient.release?.bind(rawClient), - end: rawClient.end?.bind(rawClient), - } -} - -function wrapMySQLPool(pool: Pool | MySQLPoolLike): MySQLPoolLike { - const rawPool = pool as unknown as RawMySQLPoolLike - - return { - async query(sql: string, bindings: readonly unknown[] = []) { - return rawPool.query(sql, toMutableBindings(bindings)) - }, - async getConnection() { - return wrapMySQLClient(await rawPool.getConnection()) - }, - end: rawPool.end.bind(rawPool), - } -} - -export class MySQLAdapter implements DriverAdapter { - private pool?: MySQLPoolLike - private readonly directClient?: MySQLClientLike - private readonly createPoolInstance?: (config: PoolOptions) => MySQLPoolLike - private readonly config: PoolOptions - private connected: boolean - private transactionClient?: MySQLClientLike - private leasedTransactionClient = false - private readonly transactionScope = new AsyncLocalStorage() - - constructor(options: MySQLAdapterOptions = {}) { - this.directClient = options.client ? wrapMySQLClient(options.client) : undefined - this.pool = options.pool ? wrapMySQLPool(options.pool) : undefined - this.createPoolInstance = options.createPool ?? (options.client || options.pool - ? undefined - : config => wrapMySQLPool(mysql.createPool(config))) - this.config = options.config ?? (options.uri ? { uri: options.uri } as PoolOptions : {}) - this.connected = !!(options.client || options.pool) - } - - async initialize(): Promise { - if (this.connected) { - return - } - - if (this.createPoolInstance) { - this.pool = this.createPoolInstance(this.config) - } - - this.connected = true - } - - async disconnect(): Promise { - if (!this.connected) { - return - } - - if (this.transactionClient && this.leasedTransactionClient) { - this.transactionClient.release?.() - this.transactionClient = undefined - this.leasedTransactionClient = false - } - - if (this.pool) { - await this.pool.end() - this.pool = undefined - } else if (this.directClient?.end) { - await this.directClient.end() - } - - this.connected = false - } - - isConnected(): boolean { - return this.connected - } - - async runWithTransactionScope(callback: () => Promise): Promise { - const active = this.transactionScope.getStore() - if (active) { - return callback() - } - - await this.initialize() - - if (this.directClient) { - return this.transactionScope.run({ - client: this.directClient, - leased: false, - released: false, - }, callback) - } - - if (!this.pool) { - throw new TransactionError('MySQL adapter is not initialized with a pool or client.') - } - - const state: ScopedMySQLTransaction = { - client: await this.pool.getConnection(), - leased: true, - released: false, - } - - return this.transactionScope.run(state, async () => { - try { - return await callback() - } finally { - this.releaseScopedTransaction(state) - } - }) - } - - async query = Record>( - sql: string, - bindings: readonly unknown[] = [], - ): Promise> { - const queryable = await this.getQueryable() - const [rows] = await queryable.query(sql, bindings) - const normalized = rows as RowDataPacket[] & TRow[] - return { - rows: Array.isArray(normalized) ? [...normalized] : [], - rowCount: Array.isArray(normalized) ? normalized.length : 0, - } - } - - async introspect = Record>( - sql: string, - bindings: readonly unknown[] = [], - ): Promise> { - return this.query(sql, bindings) - } - - async execute( - sql: string, - bindings: readonly unknown[] = [], - ): Promise { - const queryable = await this.getQueryable() - const [result] = await queryable.query(sql, bindings) - const execution = result as ResultSetHeader - return { - affectedRows: typeof execution.affectedRows === 'number' ? execution.affectedRows : 0, - lastInsertId: execution.insertId, - } - } - - async beginTransaction(): Promise { - const client = await this.leaseTransactionClient() - await client.query('START TRANSACTION') - } - - async commit(): Promise { - const client = this.requireTransactionClient() - await client.query('COMMIT') - this.releaseTransactionClient() - } - - async rollback(): Promise { - const client = this.requireTransactionClient() - await client.query('ROLLBACK') - this.releaseTransactionClient() - } - - async createSavepoint(name: string): Promise { - await this.requireTransactionClient().query(`SAVEPOINT ${this.normalizeSavepointName(name)}`) - } - - async rollbackToSavepoint(name: string): Promise { - await this.requireTransactionClient().query(`ROLLBACK TO SAVEPOINT ${this.normalizeSavepointName(name)}`) - } - - async releaseSavepoint(name: string): Promise { - await this.requireTransactionClient().query(`RELEASE SAVEPOINT ${this.normalizeSavepointName(name)}`) - } - - private async getQueryable(): Promise { - const scoped = this.transactionScope.getStore() - if (scoped) { - return scoped.client - } - - if (this.transactionClient) { - return this.transactionClient - } - - await this.initialize() - - if (this.directClient) { - return this.directClient - } - - if (!this.pool) { - throw new TransactionError('MySQL adapter is not initialized with a pool or client.') - } - - return this.pool - } - - private async leaseTransactionClient(): Promise { - const scoped = this.transactionScope.getStore() - if (scoped) { - return scoped.client - } - - if (this.transactionClient) { - return this.transactionClient - } - - await this.initialize() - - if (this.directClient) { - this.transactionClient = this.directClient - this.leasedTransactionClient = false - return this.transactionClient - } - - if (!this.pool) { - throw new TransactionError('MySQL adapter is not initialized with a pool or client.') - } - - this.transactionClient = await this.pool.getConnection() - this.leasedTransactionClient = true - return this.transactionClient - } - - private requireTransactionClient(): MySQLClientLike { - const scoped = this.transactionScope.getStore() - if (scoped) { - return scoped.client - } - - if (!this.transactionClient) { - throw new TransactionError('No active MySQL transaction client is available.') - } - - return this.transactionClient - } - - private releaseTransactionClient(): void { - if (this.transactionClient && this.leasedTransactionClient) { - this.transactionClient.release?.() - } - - this.transactionClient = undefined - this.leasedTransactionClient = false - } - - private releaseScopedTransaction(state: ScopedMySQLTransaction): void { - if (!state.leased || state.released) { - return - } - - state.client.release?.() - state.released = true - } - - private normalizeSavepointName(name: string): string { - if (!/^[A-Z_]\w*$/i.test(name)) { - throw new TransactionError(`Invalid savepoint name "${name}".`) - } - - return name - } -} - -export function createMySQLAdapter(options: MySQLAdapterOptions = {}): MySQLAdapter { - return new MySQLAdapter(options) -} diff --git a/packages/db/src/drivers/PostgresAdapter.ts b/packages/db/src/drivers/PostgresAdapter.ts deleted file mode 100644 index 9b50d0c4..00000000 --- a/packages/db/src/drivers/PostgresAdapter.ts +++ /dev/null @@ -1,283 +0,0 @@ -import { AsyncLocalStorage } from 'node:async_hooks' -import { Pool, type PoolConfig, type QueryResult } from 'pg' -import { TransactionError } from '../core/errors' -import type { DriverAdapter, DriverExecutionResult, DriverQueryResult } from '../core/types' - -export interface PostgresQueryableLike { - query(sql: string, bindings?: readonly unknown[]): Promise> | { - rows: Record[] - rowCount?: number | null - }> -} - -export interface PostgresClientLike extends PostgresQueryableLike { - release?(): void - end?(): Promise -} - -export interface PostgresPoolLike extends PostgresQueryableLike { - connect(): Promise - end(): Promise -} - -export interface PostgresAdapterOptions { - connectionString?: string - config?: PoolConfig - client?: PostgresClientLike - pool?: PostgresPoolLike - createPool?: (config?: PoolConfig) => PostgresPoolLike -} - -type ScopedPostgresTransaction = { - client: PostgresClientLike - leased: boolean - released: boolean -} - -export class PostgresAdapter implements DriverAdapter { - private pool?: PostgresPoolLike - private readonly directClient?: PostgresClientLike - private readonly createPoolInstance?: (config?: PoolConfig) => PostgresPoolLike - private readonly config?: PoolConfig - private connected: boolean - private transactionClient?: PostgresClientLike - private leasedTransactionClient = false - private readonly transactionScope = new AsyncLocalStorage() - - constructor(options: PostgresAdapterOptions = {}) { - this.directClient = options.client - this.pool = options.pool - this.createPoolInstance = options.createPool ?? (options.client || options.pool - ? undefined - : config => new Pool(config)) - this.config = options.config ?? (options.connectionString ? { connectionString: options.connectionString } : undefined) - this.connected = !!(options.client || options.pool) - } - - async initialize(): Promise { - if (this.connected) { - return - } - - if (this.createPoolInstance) { - this.pool = this.createPoolInstance(this.config) - } - - this.connected = true - } - - async disconnect(): Promise { - if (!this.connected) { - return - } - - if (this.transactionClient && this.leasedTransactionClient) { - this.transactionClient.release?.() - this.transactionClient = undefined - this.leasedTransactionClient = false - } - - if (this.pool) { - await this.pool.end() - this.pool = undefined - } else if (this.directClient?.end) { - await this.directClient.end() - } - - this.connected = false - } - - isConnected(): boolean { - return this.connected - } - - async runWithTransactionScope(callback: () => Promise): Promise { - const active = this.transactionScope.getStore() - if (active) { - return callback() - } - - await this.initialize() - - if (this.directClient) { - return this.transactionScope.run({ - client: this.directClient, - leased: false, - released: false, - }, callback) - } - - if (!this.pool) { - throw new TransactionError('Postgres adapter is not initialized with a pool or client.') - } - - const state: ScopedPostgresTransaction = { - client: await this.pool.connect(), - leased: true, - released: false, - } - - return this.transactionScope.run(state, async () => { - try { - return await callback() - } finally { - this.releaseScopedTransaction(state) - } - }) - } - - async query = Record>( - sql: string, - bindings: readonly unknown[] = [], - ): Promise> { - const client = await this.getQueryable() - const result = await client.query(sql, bindings) - return { - rows: result.rows as TRow[], - rowCount: result.rowCount ?? result.rows.length, - } - } - - async introspect = Record>( - sql: string, - bindings: readonly unknown[] = [], - ): Promise> { - return this.query(sql, bindings) - } - - async execute( - sql: string, - bindings: readonly unknown[] = [], - ): Promise { - const client = await this.getQueryable() - const result = await client.query(sql, bindings) - const firstRow = result.rows[0] - const firstValue = firstRow ? Object.values(firstRow)[0] : undefined - return { - affectedRows: result.rowCount ?? 0, - ...(typeof firstValue !== 'undefined' ? { lastInsertId: firstValue as number | string } : {}), - } - } - - async beginTransaction(): Promise { - const client = await this.leaseTransactionClient() - await client.query('BEGIN') - } - - async commit(): Promise { - const client = this.requireTransactionClient() - await client.query('COMMIT') - this.releaseTransactionClient() - } - - async rollback(): Promise { - const client = this.requireTransactionClient() - await client.query('ROLLBACK') - this.releaseTransactionClient() - } - - async createSavepoint(name: string): Promise { - await this.requireTransactionClient().query(`SAVEPOINT ${this.normalizeSavepointName(name)}`) - } - - async rollbackToSavepoint(name: string): Promise { - await this.requireTransactionClient().query(`ROLLBACK TO SAVEPOINT ${this.normalizeSavepointName(name)}`) - } - - async releaseSavepoint(name: string): Promise { - await this.requireTransactionClient().query(`RELEASE SAVEPOINT ${this.normalizeSavepointName(name)}`) - } - - private async getQueryable(): Promise { - const scoped = this.transactionScope.getStore() - if (scoped) { - return scoped.client - } - - if (this.transactionClient) { - return this.transactionClient - } - - await this.initialize() - - if (this.directClient) { - return this.directClient - } - - if (!this.pool) { - throw new TransactionError('Postgres adapter is not initialized with a pool or client.') - } - - return this.pool - } - - private async leaseTransactionClient(): Promise { - const scoped = this.transactionScope.getStore() - if (scoped) { - return scoped.client - } - - if (this.transactionClient) { - return this.transactionClient - } - - await this.initialize() - - if (this.directClient) { - this.transactionClient = this.directClient - this.leasedTransactionClient = false - return this.transactionClient - } - - if (!this.pool) { - throw new TransactionError('Postgres adapter is not initialized with a pool or client.') - } - - this.transactionClient = await this.pool.connect() - this.leasedTransactionClient = true - return this.transactionClient - } - - private requireTransactionClient(): PostgresClientLike { - const scoped = this.transactionScope.getStore() - if (scoped) { - return scoped.client - } - - if (!this.transactionClient) { - throw new TransactionError('No active Postgres transaction client is available.') - } - - return this.transactionClient - } - - private releaseTransactionClient(): void { - if (this.transactionClient && this.leasedTransactionClient) { - this.transactionClient.release?.() - } - - this.transactionClient = undefined - this.leasedTransactionClient = false - } - - private releaseScopedTransaction(state: ScopedPostgresTransaction): void { - if (!state.leased || state.released) { - return - } - - state.client.release?.() - state.released = true - } - - private normalizeSavepointName(name: string): string { - if (!/^[A-Z_]\w*$/i.test(name)) { - throw new TransactionError(`Invalid savepoint name "${name}".`) - } - - return name - } -} - -export function createPostgresAdapter(options: PostgresAdapterOptions = {}): PostgresAdapter { - return new PostgresAdapter(options) -} diff --git a/packages/db/src/drivers/SQLiteAdapter.ts b/packages/db/src/drivers/SQLiteAdapter.ts deleted file mode 100644 index eddffe20..00000000 --- a/packages/db/src/drivers/SQLiteAdapter.ts +++ /dev/null @@ -1,182 +0,0 @@ -import Database from 'better-sqlite3' -import { TransactionError } from '../core/errors' -import type { DriverAdapter, DriverExecutionResult, DriverQueryResult } from '../core/types' - -export interface SQLiteStatementLike { - all(...params: readonly unknown[]): Record[] - run(...params: readonly unknown[]): { changes?: number, lastInsertRowid?: unknown } -} - -export interface SQLiteDatabaseLike { - prepare(sql: string): SQLiteStatementLike - exec(sql: string): unknown - close(): unknown -} - -export interface SQLiteAdapterOptions { - filename?: string - database?: SQLiteDatabaseLike - createDatabase?: (filename: string) => SQLiteDatabaseLike -} - -export class SQLiteAdapter implements DriverAdapter { - private database?: SQLiteDatabaseLike - private connected: boolean - private transactionTail: Promise = Promise.resolve() - private readonly filename: string - private readonly createDatabaseInstance: (filename: string) => SQLiteDatabaseLike - - constructor(options: SQLiteAdapterOptions = {}) { - this.database = options.database - this.connected = !!options.database - this.filename = options.filename ?? ':memory:' - this.createDatabaseInstance = options.createDatabase ?? (filename => new Database(filename)) - } - - async initialize(): Promise { - if (this.connected) { - return - } - - this.database = this.createDatabaseInstance(this.filename) - this.connected = true - } - - async disconnect(): Promise { - if (!this.connected || !this.database) { - return - } - - this.database.close() - this.database = undefined - this.connected = false - } - - isConnected(): boolean { - return this.connected - } - - async runWithTransactionScope(callback: () => Promise): Promise { - const previous = this.transactionTail - let release!: () => void - const current = previous.then(() => new Promise((resolve) => { - release = resolve - })) - this.transactionTail = current - - await previous - - try { - return await callback() - } finally { - release() - if (this.transactionTail === current) { - this.transactionTail = Promise.resolve() - } - } - } - - async query = Record>( - sql: string, - bindings: readonly unknown[] = [], - ): Promise> { - const statement = this.getDatabase().prepare(sql) - const rows = this.invokeStatement(statement, 'all', bindings) as TRow[] - return { - rows, - rowCount: rows.length, - } - } - - async introspect = Record>( - sql: string, - bindings: readonly unknown[] = [], - ): Promise> { - return this.query(sql, bindings) - } - - async execute( - sql: string, - bindings: readonly unknown[] = [], - ): Promise { - const statement = this.getDatabase().prepare(sql) - const result = this.invokeStatement(statement, 'run', bindings) - return { - affectedRows: result.changes, - lastInsertId: typeof result.lastInsertRowid === 'bigint' - ? Number(result.lastInsertRowid) - : result.lastInsertRowid as number | string | undefined, - } - } - - async beginTransaction(): Promise { - this.getDatabase().exec('BEGIN') - } - - async commit(): Promise { - this.getDatabase().exec('COMMIT') - } - - async rollback(): Promise { - this.getDatabase().exec('ROLLBACK') - } - - async createSavepoint(name: string): Promise { - this.getDatabase().exec(`SAVEPOINT ${this.normalizeSavepointName(name)}`) - } - - async rollbackToSavepoint(name: string): Promise { - this.getDatabase().exec(`ROLLBACK TO SAVEPOINT ${this.normalizeSavepointName(name)}`) - } - - async releaseSavepoint(name: string): Promise { - this.getDatabase().exec(`RELEASE SAVEPOINT ${this.normalizeSavepointName(name)}`) - } - - private getDatabase(): SQLiteDatabaseLike { - if (!this.connected || !this.database) { - this.database = this.createDatabaseInstance(this.filename) - this.connected = true - } - - return this.database - } - - private normalizeSavepointName(name: string): string { - if (!/^[A-Z_]\w*$/i.test(name)) { - throw new TransactionError(`Invalid savepoint name "${name}".`) - } - - return name - } - - private invokeStatement< - TMethod extends 'all' | 'run', - >( - statement: SQLiteStatementLike, - method: TMethod, - bindings: readonly unknown[], - ): ReturnType { - try { - return statement[method](...bindings) as ReturnType - } catch (error) { - if (bindings.length > 0 && this.isBindingArityError(error)) { - return statement[method](bindings as never) as ReturnType - } - - throw error - } - } - - private isBindingArityError(error: unknown): boolean { - return error instanceof RangeError - && ( - error.message.includes('Too many parameter values were provided') - || error.message.includes('Too few parameter values were provided') - ) - } -} - -export function createSQLiteAdapter(options: SQLiteAdapterOptions = {}): SQLiteAdapter { - return new SQLiteAdapter(options) -} diff --git a/packages/db/src/drivers/index.ts b/packages/db/src/drivers/index.ts index 3bf81ec3..0aa6232d 100644 --- a/packages/db/src/drivers/index.ts +++ b/packages/db/src/drivers/index.ts @@ -1,5 +1,5 @@ -import { TransactionError } from '../core/errors' import type { DriverAdapter, DriverExecutionResult, DriverQueryResult, DatabaseOperationOptions } from '../core/types' +import { normalizeSavepointName } from './savepoints' function isModuleNotFoundError(error: unknown): boolean { return !!error @@ -193,27 +193,15 @@ export class SQLiteAdapter extends LazyDriverAdapter { } override async createSavepoint(name: string, options?: DatabaseOperationOptions): Promise { - if (!/^[A-Z_]\w*$/i.test(name)) { - throw new TransactionError(`Invalid savepoint name "${name}".`) - } - - await super.createSavepoint(name, options) + await super.createSavepoint(normalizeSavepointName(name), options) } override async rollbackToSavepoint(name: string, options?: DatabaseOperationOptions): Promise { - if (!/^[A-Z_]\w*$/i.test(name)) { - throw new TransactionError(`Invalid savepoint name "${name}".`) - } - - await super.rollbackToSavepoint(name, options) + await super.rollbackToSavepoint(normalizeSavepointName(name), options) } override async releaseSavepoint(name: string, options?: DatabaseOperationOptions): Promise { - if (!/^[A-Z_]\w*$/i.test(name)) { - throw new TransactionError(`Invalid savepoint name "${name}".`) - } - - await super.releaseSavepoint(name, options) + await super.releaseSavepoint(normalizeSavepointName(name), options) } } @@ -282,40 +270,16 @@ export class PostgresAdapter { - if (!/^[A-Z_]\w*$/i.test(name)) { - throw new TransactionError(`Invalid savepoint name "${name}".`) - } - - await super.createSavepoint(name, options) + await super.createSavepoint(normalizeSavepointName(name), options) } override async rollbackToSavepoint(name: string, options?: DatabaseOperationOptions): Promise { - if (!/^[A-Z_]\w*$/i.test(name)) { - throw new TransactionError(`Invalid savepoint name "${name}".`) - } - - await super.rollbackToSavepoint(name, options) + await super.rollbackToSavepoint(normalizeSavepointName(name), options) } override async releaseSavepoint(name: string, options?: DatabaseOperationOptions): Promise { - if (!/^[A-Z_]\w*$/i.test(name)) { - throw new TransactionError(`Invalid savepoint name "${name}".`) - } - - await super.releaseSavepoint(name, options) + await super.releaseSavepoint(normalizeSavepointName(name), options) } } @@ -383,40 +347,16 @@ export class MySQLAdapter { - if (!/^[A-Z_]\w*$/i.test(name)) { - throw new TransactionError(`Invalid savepoint name "${name}".`) - } - - await super.createSavepoint(name, options) + await super.createSavepoint(normalizeSavepointName(name), options) } override async rollbackToSavepoint(name: string, options?: DatabaseOperationOptions): Promise { - if (!/^[A-Z_]\w*$/i.test(name)) { - throw new TransactionError(`Invalid savepoint name "${name}".`) - } - - await super.rollbackToSavepoint(name, options) + await super.rollbackToSavepoint(normalizeSavepointName(name), options) } override async releaseSavepoint(name: string, options?: DatabaseOperationOptions): Promise { - if (!/^[A-Z_]\w*$/i.test(name)) { - throw new TransactionError(`Invalid savepoint name "${name}".`) - } - - await super.releaseSavepoint(name, options) + await super.releaseSavepoint(normalizeSavepointName(name), options) } } diff --git a/packages/db/src/drivers/savepoints.ts b/packages/db/src/drivers/savepoints.ts new file mode 100644 index 00000000..3b9af2f9 --- /dev/null +++ b/packages/db/src/drivers/savepoints.ts @@ -0,0 +1,11 @@ +import { TransactionError } from '../core/errors' + +const SAVEPOINT_NAME_PATTERN = /^[A-Z_]\w*$/i + +export function normalizeSavepointName(name: string): string { + if (!SAVEPOINT_NAME_PATTERN.test(name)) { + throw new TransactionError(`Invalid savepoint name "${name}".`) + } + + return name +} diff --git a/packages/db/src/migrations/MigrationService.ts b/packages/db/src/migrations/MigrationService.ts index f69d7c51..e990561a 100644 --- a/packages/db/src/migrations/MigrationService.ts +++ b/packages/db/src/migrations/MigrationService.ts @@ -133,11 +133,7 @@ export class MigrationService { const executed: RegisteredMigrationDefinition[] = [] for (const migration of pending) { const log = this.createMigrationLog(migration.name, 'up', nextBatch) - const startedAt = Date.now() - - await this.connection.getLogger()?.onMigrationStart?.(log) - - try { + await this.runMigrationLifecycle(log, async () => { await this.connection.transaction(async (tx) => { const context = this.createContext(tx) await migration.up(context) @@ -147,18 +143,7 @@ export class MigrationService { migrated_at: new Date().toISOString(), }) }) - await this.connection.getLogger()?.onMigrationSuccess?.({ - ...log, - durationMs: Date.now() - startedAt, - }) - } catch (error) { - await this.connection.getLogger()?.onMigrationError?.({ - ...log, - durationMs: Date.now() - startedAt, - error, - }) - throw error - } + }) executed.push(migration) } @@ -189,11 +174,7 @@ export class MigrationService { } const log = this.createMigrationLog(record.name, 'down', record.batch) - const startedAt = Date.now() - - await this.connection.getLogger()?.onMigrationStart?.(log) - - try { + await this.runMigrationLifecycle(log, async () => { await this.connection.transaction(async (tx) => { const context = this.createContext(tx) if (migration.down) { @@ -204,18 +185,7 @@ export class MigrationService { .where('name', record.name) .delete() }) - await this.connection.getLogger()?.onMigrationSuccess?.({ - ...log, - durationMs: Date.now() - startedAt, - }) - } catch (error) { - await this.connection.getLogger()?.onMigrationError?.({ - ...log, - durationMs: Date.now() - startedAt, - error, - }) - throw error - } + }) rolledBack.push(migration) } @@ -267,6 +237,31 @@ export class MigrationService { } } + private async runMigrationLifecycle( + log: MigrationStartLog, + callback: () => Promise, + ): Promise { + const startedAt = Date.now() + + await this.connection.getLogger()?.onMigrationStart?.(log) + + try { + const result = await callback() + await this.connection.getLogger()?.onMigrationSuccess?.({ + ...log, + durationMs: Date.now() - startedAt, + }) + return result + } catch (error) { + await this.connection.getLogger()?.onMigrationError?.({ + ...log, + durationMs: Date.now() - startedAt, + error, + }) + throw error + } + } + private normalizeMigratedAt(value: string | Date): Date { const normalized = value instanceof Date ? value : new Date(value) if (Number.isNaN(normalized.getTime())) { diff --git a/packages/db/src/migrations/defineMigration.ts b/packages/db/src/migrations/defineMigration.ts index dae2ce5f..bfe04811 100644 --- a/packages/db/src/migrations/defineMigration.ts +++ b/packages/db/src/migrations/defineMigration.ts @@ -4,10 +4,8 @@ import type { MigrationDefinition } from './types' const MIGRATION_NAME_PATTERN = /^\d{4}_\d{2}_\d{2}_\d{6}_[a-z0-9_]+$/ export function defineMigration(definition: TDefinition): TDefinition { - if (definition.name && !MIGRATION_NAME_PATTERN.test(definition.name)) { - throw new ConfigurationError( - `Migration name "${definition.name}" must match YYYY_MM_DD_HHMMSS_description.`, - ) + if (definition.name) { + assertMigrationName(definition.name) } return Object.freeze(definition) diff --git a/packages/db/src/migrations/template.ts b/packages/db/src/migrations/template.ts index 3766d426..de726266 100644 --- a/packages/db/src/migrations/template.ts +++ b/packages/db/src/migrations/template.ts @@ -103,12 +103,17 @@ function renderMigrationTemplate(options: { kind: MigrationTemplateKind tableName?: string }): string { + return wrapMigrationTemplate(renderMigrationBody(options)) +} + +function renderMigrationBody(options: { + migrationName: string + kind: MigrationTemplateKind + tableName?: string +}): string[] { switch (options.kind) { case 'create_table': return [ - 'import { defineMigration, type MigrationContext } from \'@holo-js/db\'', - '', - 'export default defineMigration({', ' async up({ schema }: MigrationContext) {', ` await schema.createTable('${options.tableName}', (table) => {`, ' table.id()', @@ -118,14 +123,9 @@ function renderMigrationTemplate(options: { ' async down({ schema }: MigrationContext) {', ` await schema.dropTable('${options.tableName}')`, ' },', - '})', - '', - ].join('\n') + ] case 'alter_table': return [ - 'import { defineMigration, type MigrationContext } from \'@holo-js/db\'', - '', - 'export default defineMigration({', ' async up({ schema }: MigrationContext) {', ` await schema.table('${options.tableName}', (table) => {`, ' void table', @@ -136,28 +136,18 @@ function renderMigrationTemplate(options: { ' void table', ' })', ' },', - '})', - '', - ].join('\n') + ] case 'drop_table': return [ - 'import { defineMigration, type MigrationContext } from \'@holo-js/db\'', - '', - 'export default defineMigration({', ' async up({ schema }: MigrationContext) {', ` await schema.dropTable('${options.tableName}')`, ' },', ' async down() {', ` throw new Error('Recreate "${options.tableName}" manually in this migration if rollback support is required.')`, ' },', - '})', - '', - ].join('\n') + ] case 'blank': return [ - 'import { defineMigration, type MigrationContext } from \'@holo-js/db\'', - '', - 'export default defineMigration({', ' async up({ schema, db }: MigrationContext) {', ' void schema', ' void db', @@ -166,8 +156,17 @@ function renderMigrationTemplate(options: { ' void schema', ' void db', ' },', - '})', - '', - ].join('\n') + ] } } + +function wrapMigrationTemplate(body: readonly string[]): string { + return [ + 'import { defineMigration, type MigrationContext } from \'@holo-js/db\'', + '', + 'export default defineMigration({', + ...body, + '})', + '', + ].join('\n') +} diff --git a/packages/db/src/model/Entity.ts b/packages/db/src/model/Entity.ts index 9f0d122c..f869d521 100644 --- a/packages/db/src/model/Entity.ts +++ b/packages/db/src/model/Entity.ts @@ -430,21 +430,96 @@ class EntityBase< return this.toAttributes() } - async save(): Promise { + private applyPersistedEntity(persisted: EntityBase, changes: Record): void { + this.attributes = { ...persisted.toAttributes() } + this.original = { ...persisted.toAttributes() } + this.changes = { ...changes } + this.persisted = true + } + + private async persistEntity( + methodName: 'saveEntity' | 'saveEntityQuietly', + errorMessage: string, + ): Promise { const repo = this.getRepositoryRuntime() - if (typeof repo.saveEntity !== 'function') { - throw new HydrationError('The bound repository cannot persist entities.') + const method = repo[methodName] + if (typeof method !== 'function') { + throw new HydrationError(errorMessage) } const pendingChanges = this.persisted ? this.getDirty() : this.toAttributes() - const persisted = await repo.saveEntity(this) - this.attributes = { ...persisted.toAttributes() } - this.original = { ...persisted.toAttributes() } - this.changes = { ...pendingChanges } - this.persisted = true + const persisted = await method.call(repo, this) + this.applyPersistedEntity(persisted, pendingChanges) + return this + } + + private async deletePersistedEntity( + methodName: 'deleteEntity' | 'deleteEntityQuietly', + errorMessage: string, + ): Promise { + if (!this.persisted) { + throw new HydrationError('Cannot delete an entity that has not been persisted yet.') + } + + const repo = this.getRepositoryRuntime() + const method = repo[methodName] + if (typeof method !== 'function') { + throw new HydrationError(errorMessage) + } + + await method.call(repo, this) + this.persisted = typeof repo.shouldKeepEntityPersistedOnDelete === 'function' + ? repo.shouldKeepEntityPersistedOnDelete(this) + : false + } + + private async restoreEntity( + methodName: 'restoreEntity' | 'restoreEntityQuietly', + errorMessage: string, + ): Promise { + const repo = this.getRepositoryRuntime() + const method = repo[methodName] + if (typeof method !== 'function') { + throw new HydrationError(errorMessage) + } + + const restored = await method.call(repo, this) + this.applyPersistedEntity(restored, restored.getChanges()) + return this + } + + private async forceDeletePersistedEntity( + methodName: 'forceDeleteEntity' | 'forceDeleteEntityQuietly', + errorMessage: string, + ): Promise { + if (!this.persisted) { + throw new HydrationError('Cannot force-delete an entity that has not been persisted yet.') + } + + const repo = this.getRepositoryRuntime() + const method = repo[methodName] + if (typeof method !== 'function') { + throw new HydrationError(errorMessage) + } + + await method.call(repo, this) + this.persisted = false + } + + private async loadAggregateDefinitions(aggregates: readonly ModelAggregateLoad[]): Promise { + const repo = this.getRepositoryRuntime() + if (typeof repo.loadRelationAggregates !== 'function') { + throw new HydrationError('The bound repository cannot load relation aggregates.') + } + + await repo.loadRelationAggregates([this], aggregates) return this } + async save(): Promise { + return this.persistEntity('saveEntity', 'The bound repository cannot persist entities.') + } + async push(): Promise { const repo = this.getRepositoryRuntime() if (typeof repo.getRelationDefinition !== 'function') { @@ -492,106 +567,31 @@ class EntityBase< } async saveQuietly(): Promise { - const repo = this.getRepositoryRuntime() - if (typeof repo.saveEntityQuietly !== 'function') { - throw new HydrationError('The bound repository cannot persist entities quietly.') - } - - const pendingChanges = this.persisted ? this.getDirty() : this.toAttributes() - const persisted = await repo.saveEntityQuietly(this) - this.attributes = { ...persisted.toAttributes() } - this.original = { ...persisted.toAttributes() } - this.changes = { ...pendingChanges } - this.persisted = true - return this + return this.persistEntity('saveEntityQuietly', 'The bound repository cannot persist entities quietly.') } async delete(): Promise { - if (!this.persisted) { - throw new HydrationError('Cannot delete an entity that has not been persisted yet.') - } - - const repo = this.getRepositoryRuntime() - if (typeof repo.deleteEntity !== 'function') { - throw new HydrationError('The bound repository cannot delete entities.') - } - - await repo.deleteEntity(this) - this.persisted = typeof repo.shouldKeepEntityPersistedOnDelete === 'function' - ? repo.shouldKeepEntityPersistedOnDelete(this) - : false + await this.deletePersistedEntity('deleteEntity', 'The bound repository cannot delete entities.') } async deleteQuietly(): Promise { - if (!this.persisted) { - throw new HydrationError('Cannot delete an entity that has not been persisted yet.') - } - - const repo = this.getRepositoryRuntime() - if (typeof repo.deleteEntityQuietly !== 'function') { - throw new HydrationError('The bound repository cannot delete entities quietly.') - } - - await repo.deleteEntityQuietly(this) - this.persisted = typeof repo.shouldKeepEntityPersistedOnDelete === 'function' - ? repo.shouldKeepEntityPersistedOnDelete(this) - : false + await this.deletePersistedEntity('deleteEntityQuietly', 'The bound repository cannot delete entities quietly.') } async restore(): Promise { - const repo = this.getRepositoryRuntime() - if (typeof repo.restoreEntity !== 'function') { - throw new HydrationError('The bound repository cannot restore entities.') - } - - const restored = await repo.restoreEntity(this) - this.attributes = { ...restored.toAttributes() } - this.original = { ...restored.toAttributes() } - this.changes = { ...restored.getChanges() } - this.persisted = true - return this + return this.restoreEntity('restoreEntity', 'The bound repository cannot restore entities.') } async restoreQuietly(): Promise { - const repo = this.getRepositoryRuntime() - if (typeof repo.restoreEntityQuietly !== 'function') { - throw new HydrationError('The bound repository cannot restore entities quietly.') - } - - const restored = await repo.restoreEntityQuietly(this) - this.attributes = { ...restored.toAttributes() } - this.original = { ...restored.toAttributes() } - this.changes = { ...restored.getChanges() } - this.persisted = true - return this + return this.restoreEntity('restoreEntityQuietly', 'The bound repository cannot restore entities quietly.') } async forceDelete(): Promise { - if (!this.persisted) { - throw new HydrationError('Cannot force-delete an entity that has not been persisted yet.') - } - - const repo = this.getRepositoryRuntime() - if (typeof repo.forceDeleteEntity !== 'function') { - throw new HydrationError('The bound repository cannot force-delete entities.') - } - - await repo.forceDeleteEntity(this) - this.persisted = false + await this.forceDeletePersistedEntity('forceDeleteEntity', 'The bound repository cannot force-delete entities.') } async forceDeleteQuietly(): Promise { - if (!this.persisted) { - throw new HydrationError('Cannot force-delete an entity that has not been persisted yet.') - } - - const repo = this.getRepositoryRuntime() - if (typeof repo.forceDeleteEntityQuietly !== 'function') { - throw new HydrationError('The bound repository cannot force-delete entities quietly.') - } - - await repo.forceDeleteEntityQuietly(this) - this.persisted = false + await this.forceDeletePersistedEntity('forceDeleteEntityQuietly', 'The bound repository cannot force-delete entities quietly.') } async fresh(): Promise | undefined> { @@ -683,63 +683,27 @@ class EntityBase< } async loadCount(...relations: readonly ModelRelationName[]): Promise { - const repo = this.getRepositoryRuntime() - if (typeof repo.loadRelationAggregates !== 'function') { - throw new HydrationError('The bound repository cannot load relation aggregates.') - } - - await repo.loadRelationAggregates([this], relations.map(relation => ({ relation, kind: 'count' }))) - return this + return this.loadAggregateDefinitions(relations.map(relation => ({ relation, kind: 'count' }))) } async loadExists(...relations: readonly ModelRelationName[]): Promise { - const repo = this.getRepositoryRuntime() - if (typeof repo.loadRelationAggregates !== 'function') { - throw new HydrationError('The bound repository cannot load relation aggregates.') - } - - await repo.loadRelationAggregates([this], relations.map(relation => ({ relation, kind: 'exists' }))) - return this + return this.loadAggregateDefinitions(relations.map(relation => ({ relation, kind: 'exists' }))) } async loadSum>(relation: TRelationName, column: RelatedColumnNameOfRelation): Promise { - const repo = this.getRepositoryRuntime() - if (typeof repo.loadRelationAggregates !== 'function') { - throw new HydrationError('The bound repository cannot load relation aggregates.') - } - - await repo.loadRelationAggregates([this], [{ relation, kind: 'sum', column }]) - return this + return this.loadAggregateDefinitions([{ relation, kind: 'sum', column }]) } async loadAvg>(relation: TRelationName, column: RelatedColumnNameOfRelation): Promise { - const repo = this.getRepositoryRuntime() - if (typeof repo.loadRelationAggregates !== 'function') { - throw new HydrationError('The bound repository cannot load relation aggregates.') - } - - await repo.loadRelationAggregates([this], [{ relation, kind: 'avg', column }]) - return this + return this.loadAggregateDefinitions([{ relation, kind: 'avg', column }]) } async loadMin>(relation: TRelationName, column: RelatedColumnNameOfRelation): Promise { - const repo = this.getRepositoryRuntime() - if (typeof repo.loadRelationAggregates !== 'function') { - throw new HydrationError('The bound repository cannot load relation aggregates.') - } - - await repo.loadRelationAggregates([this], [{ relation, kind: 'min', column }]) - return this + return this.loadAggregateDefinitions([{ relation, kind: 'min', column }]) } async loadMax>(relation: TRelationName, column: RelatedColumnNameOfRelation): Promise { - const repo = this.getRepositoryRuntime() - if (typeof repo.loadRelationAggregates !== 'function') { - throw new HydrationError('The bound repository cannot load relation aggregates.') - } - - await repo.loadRelationAggregates([this], [{ relation, kind: 'max', column }]) - return this + return this.loadAggregateDefinitions([{ relation, kind: 'max', column }]) } associate(relation: string, related: Entity | null): this { diff --git a/packages/db/src/model/ModelQueryBuilder.ts b/packages/db/src/model/ModelQueryBuilder.ts index ed1e5f2d..6254e934 100644 --- a/packages/db/src/model/ModelQueryBuilder.ts +++ b/packages/db/src/model/ModelQueryBuilder.ts @@ -88,6 +88,8 @@ type AggregateLoad = { type RelationConstraintMap = Readonly< Partial, RelationConstraint>> > +type TableQueryPredicates = + ReturnType>['getPlan']>['predicates'] export class ModelQueryBuilder< TTable extends TableDefinition = TableDefinition, @@ -115,6 +117,37 @@ export class ModelQueryBuilder< return this.tableQuery } + private collectNestedPredicates( + callback: BuilderCallback>, + ): TableQueryPredicates { + const nested = new ModelQueryBuilder( + this.repository, + new TableQueryBuilder(this.repository.definition.table, this.getConnection()), + ) + const callbackResult = callback(nested) + const result = callbackResult instanceof ModelQueryBuilder ? callbackResult : nested + return result.getTableQueryBuilder().getPlan().predicates + } + + private replayNestedPredicates( + query: TableQueryBuilder>, + predicates: TableQueryPredicates, + ): TableQueryBuilder> { + let next = query + for (const predicate of predicates) { + next = new TableQueryBuilder( + this.repository.definition.table, + this.getConnection(), + { + ...next.getPlan(), + predicates: Object.freeze([...next.getPlan().predicates, predicate]), + }, + ) + } + + return next + } + from(table: string): ModelQueryBuilder { return this.clone(this.tableQuery.from(table) as unknown as TableQueryBuilder>) } @@ -129,31 +162,12 @@ export class ModelQueryBuilder< value?: unknown, ): ModelQueryBuilder { if (typeof columnOrCallback === 'function') { - const nested = new ModelQueryBuilder( - this.repository, - new TableQueryBuilder(this.repository.definition.table, this.getConnection()), - ) - const callbackResult = columnOrCallback(nested) - const result = callbackResult instanceof ModelQueryBuilder ? callbackResult : nested - const predicates = result.getTableQueryBuilder().getPlan().predicates + const predicates = this.collectNestedPredicates(columnOrCallback) if (predicates.length === 0) { return this } - return this.clone(this.tableQuery.where((query) => { - let next = query - for (const predicate of predicates) { - next = new TableQueryBuilder( - this.repository.definition.table, - this.getConnection(), - { - ...next.getPlan(), - predicates: Object.freeze([...next.getPlan().predicates, predicate]), - }, - ) - } - return next - })) + return this.clone(this.tableQuery.where(query => this.replayNestedPredicates(query, predicates))) } return this.clone(this.tableQuery.where(columnOrCallback as never, operator, value)) @@ -169,31 +183,12 @@ export class ModelQueryBuilder< value?: unknown, ): ModelQueryBuilder { if (typeof columnOrCallback === 'function') { - const nested = new ModelQueryBuilder( - this.repository, - new TableQueryBuilder(this.repository.definition.table, this.getConnection()), - ) - const callbackResult = columnOrCallback(nested) - const result = callbackResult instanceof ModelQueryBuilder ? callbackResult : nested - const predicates = result.getTableQueryBuilder().getPlan().predicates + const predicates = this.collectNestedPredicates(columnOrCallback) if (predicates.length === 0) { return this } - return this.clone(this.tableQuery.orWhere((query) => { - let next = query - for (const predicate of predicates) { - next = new TableQueryBuilder( - this.repository.definition.table, - this.getConnection(), - { - ...next.getPlan(), - predicates: Object.freeze([...next.getPlan().predicates, predicate]), - }, - ) - } - return next - })) + return this.clone(this.tableQuery.orWhere(query => this.replayNestedPredicates(query, predicates))) } return this.clone(this.tableQuery.orWhere(columnOrCallback as never, operator, value)) @@ -202,61 +197,23 @@ export class ModelQueryBuilder< whereNot( callback: BuilderCallback>, ): ModelQueryBuilder { - const nested = new ModelQueryBuilder( - this.repository, - new TableQueryBuilder(this.repository.definition.table, this.getConnection()), - ) - const callbackResult = callback(nested) - const result = callbackResult instanceof ModelQueryBuilder ? callbackResult : nested - const predicates = result.getTableQueryBuilder().getPlan().predicates + const predicates = this.collectNestedPredicates(callback) if (predicates.length === 0) { return this } - return this.clone(this.tableQuery.whereNot((query) => { - let next = query - for (const predicate of predicates) { - next = new TableQueryBuilder( - this.repository.definition.table, - this.getConnection(), - { - ...next.getPlan(), - predicates: Object.freeze([...next.getPlan().predicates, predicate]), - }, - ) - } - return next - })) + return this.clone(this.tableQuery.whereNot(query => this.replayNestedPredicates(query, predicates))) } orWhereNot( callback: BuilderCallback>, ): ModelQueryBuilder { - const nested = new ModelQueryBuilder( - this.repository, - new TableQueryBuilder(this.repository.definition.table, this.getConnection()), - ) - const callbackResult = callback(nested) - const result = callbackResult instanceof ModelQueryBuilder ? callbackResult : nested - const predicates = result.getTableQueryBuilder().getPlan().predicates + const predicates = this.collectNestedPredicates(callback) if (predicates.length === 0) { return this } - return this.clone(this.tableQuery.orWhereNot((query) => { - let next = query - for (const predicate of predicates) { - next = new TableQueryBuilder( - this.repository.definition.table, - this.getConnection(), - { - ...next.getPlan(), - predicates: Object.freeze([...next.getPlan().predicates, predicate]), - }, - ) - } - return next - })) + return this.clone(this.tableQuery.orWhereNot(query => this.replayNestedPredicates(query, predicates))) } whereExists( diff --git a/packages/db/src/model/collection.ts b/packages/db/src/model/collection.ts index ae2b8a63..abf8e893 100644 --- a/packages/db/src/model/collection.ts +++ b/packages/db/src/model/collection.ts @@ -87,6 +87,17 @@ export function createModelCollection< const getRepository = (entity: Entity): CollectionRepository => ( entity.getRepository() as unknown as CollectionRepository ) + const loadRelationAggregates = async ( + aggregates: readonly ModelAggregateLoad[], + ): Promise> => { + const first = collection[0] + if (!first || aggregates.length === 0) { + return collection + } + + await getRepository(first).loadRelationAggregates(collection, aggregates) + return collection + } const methods = { modelKeys(): unknown[] { return collection.map((entity) => { @@ -143,64 +154,22 @@ export function createModelCollection< return collection }, async loadCount(...relations: readonly ModelRelationName[]): Promise> { - const first = collection[0] - if (!first || relations.length === 0) { - return collection - } - - const repo = getRepository(first) - await repo.loadRelationAggregates(collection, relations.map(relation => ({ relation, kind: 'count' }))) - return collection + return loadRelationAggregates(relations.map(relation => ({ relation, kind: 'count' }))) }, async loadExists(...relations: readonly ModelRelationName[]): Promise> { - const first = collection[0] - if (!first || relations.length === 0) { - return collection - } - - const repo = getRepository(first) - await repo.loadRelationAggregates(collection, relations.map(relation => ({ relation, kind: 'exists' }))) - return collection + return loadRelationAggregates(relations.map(relation => ({ relation, kind: 'exists' }))) }, async loadSum>(relation: TRelationName, column: RelatedColumnNameOfRelation): Promise> { - const first = collection[0] - if (!first || !relation) { - return collection - } - - const repo = getRepository(first) - await repo.loadRelationAggregates(collection, [{ relation, kind: 'sum', column }]) - return collection + return loadRelationAggregates(relation ? [{ relation, kind: 'sum', column }] : []) }, async loadAvg>(relation: TRelationName, column: RelatedColumnNameOfRelation): Promise> { - const first = collection[0] - if (!first || !relation) { - return collection - } - - const repo = getRepository(first) - await repo.loadRelationAggregates(collection, [{ relation, kind: 'avg', column }]) - return collection + return loadRelationAggregates(relation ? [{ relation, kind: 'avg', column }] : []) }, async loadMin>(relation: TRelationName, column: RelatedColumnNameOfRelation): Promise> { - const first = collection[0] - if (!first || !relation) { - return collection - } - - const repo = getRepository(first) - await repo.loadRelationAggregates(collection, [{ relation, kind: 'min', column }]) - return collection + return loadRelationAggregates(relation ? [{ relation, kind: 'min', column }] : []) }, async loadMax>(relation: TRelationName, column: RelatedColumnNameOfRelation): Promise> { - const first = collection[0] - if (!first || !relation) { - return collection - } - - const repo = getRepository(first) - await repo.loadRelationAggregates(collection, [{ relation, kind: 'max', column }]) - return collection + return loadRelationAggregates(relation ? [{ relation, kind: 'max', column }] : []) }, async fresh(): Promise> { const refreshed: Array | undefined> = await Promise.all( diff --git a/packages/db/src/model/defineModel.ts b/packages/db/src/model/defineModel.ts index 1efe7653..6282cf0f 100644 --- a/packages/db/src/model/defineModel.ts +++ b/packages/db/src/model/defineModel.ts @@ -83,6 +83,60 @@ type ModelTableBuilderResult< > = { build(): BoundTableDefinition } +type SharedModelDefinitionFields< + TTable extends TableDefinition, + TScopes extends ModelScopesDefinition, + TRelations extends RelationMap, +> = Omit< + ModelDefinition, + 'table' | 'primaryKey' | 'createdAtColumn' | 'updatedAtColumn' | 'deletedAtColumn' | 'uniqueIdConfig' +> + +function createSharedModelDefinitionFields< + TTable extends TableDefinition, + TScopes extends ModelScopesDefinition, + TRelations extends RelationMap, +>( + inferredName: string, + relations: TRelations, + touches: readonly string[], + options: DefineModelOptions, +): SharedModelDefinitionFields { + return { + kind: 'model', + name: inferredName, + connectionName: options.connectionName, + morphClass: options.morphClass ?? inferredName, + with: Object.freeze([...(options.with ?? [])]), + pendingAttributes: Object.freeze({ ...(options.pendingAttributes ?? {}) }), + preventLazyLoading: options.preventLazyLoading ?? false, + preventAccessingMissingAttributes: options.preventAccessingMissingAttributes ?? false, + automaticEagerLoading: options.automaticEagerLoading ?? false, + timestamps: options.timestamps ?? true, + fillable: Object.freeze([...(options.fillable ?? [])]), + hasExplicitFillable: typeof options.fillable !== 'undefined', + guarded: Object.freeze([...(options.guarded ?? [])]), + scopes: (options.scopes ?? {}) as TScopes, + globalScopes: { ...(options.globalScopes ?? {}) }, + relations, + casts: { ...(options.casts ?? {}) }, + accessors: { ...(options.accessors ?? {}) }, + mutators: { ...(options.mutators ?? {}) }, + hidden: Object.freeze([...(options.hidden ?? [])]), + visible: Object.freeze([...(options.visible ?? [])]), + appended: Object.freeze([...(options.appended ?? [])]), + serializeDate: options.serializeDate, + collection: options.collection, + prunable: options.prunable, + massPrunable: options.massPrunable ?? false, + touches, + traits: Object.freeze([...(options.traits ?? [])]), + replicationExcludes: Object.freeze([...(options.replicationExcludes ?? [])]), + softDeletes: options.softDeletes ?? false, + events: normalizeEventHandlers(options.events), + observers: Object.freeze([...(options.observers ?? [])]), + } +} export function defineModel< TTable extends TableDefinition, @@ -210,38 +264,7 @@ function defineModelFromGeneratedTableName< } const definition = { - kind: 'model' as const, - name: inferredName, - connectionName: options.connectionName, - morphClass: options.morphClass ?? inferredName, - with: Object.freeze([...(options.with ?? [])]), - pendingAttributes: Object.freeze({ ...(options.pendingAttributes ?? {}) }), - preventLazyLoading: options.preventLazyLoading ?? false, - preventAccessingMissingAttributes: options.preventAccessingMissingAttributes ?? false, - automaticEagerLoading: options.automaticEagerLoading ?? false, - timestamps: options.timestamps ?? true, - fillable: Object.freeze([...(options.fillable ?? [])]), - hasExplicitFillable: typeof options.fillable !== 'undefined', - guarded: Object.freeze([...(options.guarded ?? [])]), - scopes: (options.scopes ?? {}) as TScopes, - globalScopes: { ...(options.globalScopes ?? {}) }, - relations, - casts: { ...(options.casts ?? {}) }, - accessors: { ...(options.accessors ?? {}) }, - mutators: { ...(options.mutators ?? {}) }, - hidden: Object.freeze([...(options.hidden ?? [])]), - visible: Object.freeze([...(options.visible ?? [])]), - appended: Object.freeze([...(options.appended ?? [])]), - serializeDate: options.serializeDate, - collection: options.collection, - prunable: options.prunable, - massPrunable: options.massPrunable ?? false, - touches, - traits: Object.freeze([...(options.traits ?? [])]), - replicationExcludes: Object.freeze([...(options.replicationExcludes ?? [])]), - softDeletes: options.softDeletes ?? false, - events: normalizeEventHandlers(options.events), - observers: Object.freeze([...(options.observers ?? [])]), + ...createSharedModelDefinitionFields(inferredName, relations, touches, options), } as Omit, TScopes, TRelations>, 'table' | 'primaryKey' | 'createdAtColumn' | 'updatedAtColumn' | 'deletedAtColumn' | 'uniqueIdConfig'> & Pick, TScopes, TRelations>, 'table' | 'primaryKey' | 'createdAtColumn' | 'updatedAtColumn' | 'deletedAtColumn' | 'uniqueIdConfig'> @@ -304,44 +327,13 @@ function defineModelFromResolvedTable< const touches = validateTouches(inferredName, relations, options.touches ?? []) const definition: ModelDefinition = Object.freeze({ - kind: 'model', + ...createSharedModelDefinitionFields(inferredName, relations, touches, options), table, - name: inferredName, primaryKey, - connectionName: options.connectionName, - morphClass: options.morphClass ?? inferredName, - with: Object.freeze([...(options.with ?? [])]), - pendingAttributes: Object.freeze({ ...(options.pendingAttributes ?? {}) }), - preventLazyLoading: options.preventLazyLoading ?? false, - preventAccessingMissingAttributes: options.preventAccessingMissingAttributes ?? false, - automaticEagerLoading: options.automaticEagerLoading ?? false, - timestamps, createdAtColumn, updatedAtColumn, - fillable: Object.freeze([...(options.fillable ?? [])]), - hasExplicitFillable: typeof options.fillable !== 'undefined', - guarded: Object.freeze([...(options.guarded ?? [])]), - scopes: (options.scopes ?? {}) as TScopes, - globalScopes: { ...(options.globalScopes ?? {}) }, - relations, - casts: { ...(options.casts ?? {}) }, - accessors: { ...(options.accessors ?? {}) }, - mutators: { ...(options.mutators ?? {}) }, - hidden: Object.freeze([...(options.hidden ?? [])]), - visible: Object.freeze([...(options.visible ?? [])]), - appended: Object.freeze([...(options.appended ?? [])]), - serializeDate: options.serializeDate, - collection: options.collection, - prunable: options.prunable, - massPrunable: options.massPrunable ?? false, - touches, - traits: Object.freeze([...(options.traits ?? [])]), uniqueIdConfig, - replicationExcludes: Object.freeze([...(options.replicationExcludes ?? [])]), - softDeletes: options.softDeletes ?? false, deletedAtColumn, - events: normalizeEventHandlers(options.events), - observers: Object.freeze([...(options.observers ?? [])]), }) return createStaticModelApi(definition) diff --git a/packages/db/src/model/relations.ts b/packages/db/src/model/relations.ts index cd3a0a40..40b7fe59 100644 --- a/packages/db/src/model/relations.ts +++ b/packages/db/src/model/relations.ts @@ -24,7 +24,7 @@ import { getGlobalModel } from './ModelRegistry' import { RelationError } from '../core/errors' import type { TableDefinition } from '../schema/types' -type PivotMethodKeys = 'withPivot' | 'wherePivot' | 'orderByPivot' | 'as' | 'using' +type PivotMethodKeys = keyof PivotRelationMethods type BareBelongsToManyRelation = Omit type BareMorphToManyRelation = Omit type BareMorphedByManyRelation = Omit diff --git a/packages/db/src/model/staticModelApi.ts b/packages/db/src/model/staticModelApi.ts index d43379f9..2e9742b3 100644 --- a/packages/db/src/model/staticModelApi.ts +++ b/packages/db/src/model/staticModelApi.ts @@ -2,7 +2,6 @@ import type { ModelQueryBuilder } from './ModelQueryBuilder' import type { Entity } from './Entity' import type { ModelCollection } from './collection' import type { ModelRepository } from './ModelRepository' -import type { TableQueryBuilder } from '../query/TableQueryBuilder' import type { CursorPaginatedResult, CursorPaginationOptions, @@ -18,47 +17,16 @@ import type { RelationMap, DynamicRelationResolver, EntityWithLoaded, - ModelCastDefinition, ModelAttributeKey, ModelColumnName, - ModelColumnReference, - ModelJsonColumnPath, ModelRecord, ModelReference, - ModelSelectableColumn, ModelScopesDefinition, ModelScopeMethods, ModelUpdatePayload, - ModelRelationPath, - RelatedColumnNameForRelationPath, - ResolveEagerLoads, SerializedEntityWithLoaded, } from './types' -type BuilderCallback = (query: TBuilder) => unknown -type ValueBuilderCallback = (query: TBuilder, value: TValue) => unknown -type RelationConstraintCallback = (query: ModelQueryBuilder) => unknown -type RelationConstraintMap = Readonly< - Partial, RelationConstraintCallback>> -> -type MorphEntityTarget = { - exists(): boolean - getRepository(): { - definition: { - morphClass: string - primaryKey: string - } - } - get(key: string): unknown -} -type MorphModelTarget = { - definition?: { - morphClass?: string - } -} -type MorphTypeSelector = string | MorphModelTarget | MorphEntityTarget | null -type SubqueryBuilder - = ModelQueryBuilder | TableQueryBuilder> type PrimaryKeyName = Extract<{ [K in keyof TTable['columns']]: TTable['columns'][K] extends { readonly primaryKey: true } ? K : never }[keyof TTable['columns']], keyof ModelRecord & string> @@ -66,130 +34,131 @@ type ModelPrimaryKeyValue = [PrimaryKeyName] extends [never] ? unknown : ModelRecord[PrimaryKeyName] +type StaticModelQueryForwardMethod = + | 'from' + | 'debug' + | 'dump' + | 'where' + | 'orWhere' + | 'whereNot' + | 'orWhereNot' + | 'whereExists' + | 'orWhereExists' + | 'whereNotExists' + | 'orWhereNotExists' + | 'whereSub' + | 'orWhereSub' + | 'whereInSub' + | 'whereNotInSub' + | 'select' + | 'addSelect' + | 'withCasts' + | 'selectSub' + | 'addSelectSub' + | 'whereNull' + | 'orWhereNull' + | 'whereNotNull' + | 'orWhereNotNull' + | 'when' + | 'unless' + | 'distinct' + | 'whereColumn' + | 'whereIn' + | 'whereNotIn' + | 'whereBetween' + | 'whereNotBetween' + | 'whereLike' + | 'orWhereLike' + | 'whereAny' + | 'whereAll' + | 'whereNone' + | 'join' + | 'leftJoin' + | 'rightJoin' + | 'crossJoin' + | 'joinSub' + | 'leftJoinSub' + | 'rightJoinSub' + | 'joinLateral' + | 'leftJoinLateral' + | 'union' + | 'unionAll' + | 'groupBy' + | 'having' + | 'havingBetween' + | 'unsafeWhere' + | 'orUnsafeWhere' + | 'whereDate' + | 'whereMonth' + | 'whereDay' + | 'whereYear' + | 'whereTime' + | 'whereJson' + | 'orWhereJson' + | 'whereJsonContains' + | 'orWhereJsonContains' + | 'whereJsonLength' + | 'orWhereJsonLength' + | 'whereFullText' + | 'orWhereFullText' + | 'whereVectorSimilarTo' + | 'orWhereVectorSimilarTo' + | 'orderBy' + | 'latest' + | 'oldest' + | 'inRandomOrder' + | 'reorder' + | 'unsafeOrderBy' + | 'lock' + | 'lockForUpdate' + | 'sharedLock' + | 'with' + | 'withCount' + | 'withExists' + | 'withSum' + | 'withAvg' + | 'withMin' + | 'withMax' + | 'has' + | 'orHas' + | 'whereHas' + | 'orWhereHas' + | 'doesntHave' + | 'orDoesntHave' + | 'whereDoesntHave' + | 'orWhereDoesntHave' + | 'whereRelation' + | 'orWhereRelation' + | 'whereBelongsTo' + | 'orWhereBelongsTo' + | 'whereMorphedTo' + | 'orWhereMorphedTo' + | 'whereNotMorphedTo' + | 'orWhereNotMorphedTo' + | 'withWhereHas' +type StaticModelQueryForwarders< + TTable extends TableDefinition, + TRelations extends RelationMap, +> = Pick, StaticModelQueryForwardMethod> export type StaticModelApi< TTable extends TableDefinition, TScopes extends ModelScopesDefinition, TRelations extends RelationMap = RelationMap, -> = ModelReference & ModelScopeMethods & { +> = ModelReference + & ModelScopeMethods + & StaticModelQueryForwarders + & { query(): ModelQueryBuilder newQuery(): ModelQueryBuilder newModelQuery(): ModelQueryBuilder newQueryWithoutScopes(): ModelQueryBuilder newQueryWithoutRelationships(): ModelQueryBuilder - from(table: string): ModelQueryBuilder - debug(): ReturnType['debug']> - dump(): ModelQueryBuilder preventLazyLoading(value?: boolean): StaticModelApi preventAccessingMissingAttributes(value?: boolean): StaticModelApi automaticallyEagerLoadRelationships(value?: boolean): StaticModelApi withoutEvents(callback: () => TResult | Promise): Promise unguarded(callback: () => TResult | Promise): Promise - where(callback: BuilderCallback>): ModelQueryBuilder - where(column: ModelColumnName | ModelJsonColumnPath, operator: unknown, value?: unknown): ModelQueryBuilder - orWhere(callback: BuilderCallback>): ModelQueryBuilder - orWhere(column: ModelColumnName | ModelJsonColumnPath, operator: unknown, value?: unknown): ModelQueryBuilder - whereNot(callback: BuilderCallback>): ModelQueryBuilder - orWhereNot(callback: BuilderCallback>): ModelQueryBuilder - whereExists(subquery: SubqueryBuilder): ModelQueryBuilder - orWhereExists(subquery: SubqueryBuilder): ModelQueryBuilder - whereNotExists(subquery: SubqueryBuilder): ModelQueryBuilder - orWhereNotExists(subquery: SubqueryBuilder): ModelQueryBuilder - whereSub(column: ModelColumnName, operator: '!=' | '=' | '>' | '>=' | '<' | '<=' | 'in' | 'not in' | 'like', subquery: SubqueryBuilder): ModelQueryBuilder - orWhereSub(column: ModelColumnName, operator: '!=' | '=' | '>' | '>=' | '<' | '<=' | 'in' | 'not in' | 'like', subquery: SubqueryBuilder): ModelQueryBuilder - whereInSub(column: ModelColumnName, subquery: SubqueryBuilder): ModelQueryBuilder - whereNotInSub(column: ModelColumnName, subquery: SubqueryBuilder): ModelQueryBuilder - select(...columns: readonly ModelSelectableColumn[]): ModelQueryBuilder - addSelect(...columns: readonly ModelSelectableColumn[]): ModelQueryBuilder - withCasts(casts: Record): ModelQueryBuilder - selectSub(query: SubqueryBuilder, alias: string): ModelQueryBuilder - addSelectSub(query: SubqueryBuilder, alias: string): ModelQueryBuilder - whereNull(column: ModelColumnName): ModelQueryBuilder - orWhereNull(column: ModelColumnName): ModelQueryBuilder - whereNotNull(column: ModelColumnName): ModelQueryBuilder - orWhereNotNull(column: ModelColumnName): ModelQueryBuilder - when(value: TValue, callback: ValueBuilderCallback, TValue>, defaultCallback?: ValueBuilderCallback, TValue>): ModelQueryBuilder - unless(value: TValue, callback: ValueBuilderCallback, TValue>, defaultCallback?: ValueBuilderCallback, TValue>): ModelQueryBuilder - distinct(): ModelQueryBuilder - whereColumn(column: ModelColumnReference, operator: '!=' | '=' | '>' | '>=' | '<' | '<=' | 'like', compareTo: ModelColumnReference): ModelQueryBuilder - whereIn(column: ModelColumnName, values: readonly unknown[]): ModelQueryBuilder - whereNotIn(column: ModelColumnName, values: readonly unknown[]): ModelQueryBuilder - whereBetween(column: ModelColumnName, range: readonly [unknown, unknown]): ModelQueryBuilder - whereNotBetween(column: ModelColumnName, range: readonly [unknown, unknown]): ModelQueryBuilder - whereLike(column: ModelColumnName, pattern: string): ModelQueryBuilder - orWhereLike(column: ModelColumnName, pattern: string): ModelQueryBuilder - whereAny(columns: readonly ModelColumnName[], operator: unknown, value?: unknown): ModelQueryBuilder - whereAll(columns: readonly ModelColumnName[], operator: unknown, value?: unknown): ModelQueryBuilder - whereNone(columns: readonly ModelColumnName[], operator: unknown, value?: unknown): ModelQueryBuilder - join(table: string, leftColumn: ModelColumnReference, operator: '!=' | '=' | '>' | '>=' | '<' | '<=' | 'like', rightColumn: ModelColumnReference): ModelQueryBuilder - leftJoin(table: string, leftColumn: ModelColumnReference, operator: '!=' | '=' | '>' | '>=' | '<' | '<=' | 'like', rightColumn: ModelColumnReference): ModelQueryBuilder - rightJoin(table: string, leftColumn: ModelColumnReference, operator: '!=' | '=' | '>' | '>=' | '<' | '<=' | 'like', rightColumn: ModelColumnReference): ModelQueryBuilder - crossJoin(table: string): ModelQueryBuilder - joinSub(query: SubqueryBuilder, alias: string, leftColumn: ModelColumnReference, operator: '!=' | '=' | '>' | '>=' | '<' | '<=' | 'like', rightColumn: ModelColumnReference): ModelQueryBuilder - leftJoinSub(query: SubqueryBuilder, alias: string, leftColumn: ModelColumnReference, operator: '!=' | '=' | '>' | '>=' | '<' | '<=' | 'like', rightColumn: ModelColumnReference): ModelQueryBuilder - rightJoinSub(query: SubqueryBuilder, alias: string, leftColumn: ModelColumnReference, operator: '!=' | '=' | '>' | '>=' | '<' | '<=' | 'like', rightColumn: ModelColumnReference): ModelQueryBuilder - joinLateral(query: SubqueryBuilder, alias: string): ModelQueryBuilder - leftJoinLateral(query: SubqueryBuilder, alias: string): ModelQueryBuilder - union(query: SubqueryBuilder): ModelQueryBuilder - unionAll(query: SubqueryBuilder): ModelQueryBuilder - groupBy(...columns: readonly ModelColumnName[]): ModelQueryBuilder - having(expression: string, operator: unknown, value?: unknown): ModelQueryBuilder - havingBetween(expression: string, range: readonly [unknown, unknown]): ModelQueryBuilder - unsafeWhere(sql: string, bindings: readonly unknown[]): ModelQueryBuilder - orUnsafeWhere(sql: string, bindings: readonly unknown[]): ModelQueryBuilder - whereDate(column: ModelColumnName, operator: unknown, value?: unknown): ModelQueryBuilder - whereMonth(column: ModelColumnName, operator: unknown, value?: unknown): ModelQueryBuilder - whereDay(column: ModelColumnName, operator: unknown, value?: unknown): ModelQueryBuilder - whereYear(column: ModelColumnName, operator: unknown, value?: unknown): ModelQueryBuilder - whereTime(column: ModelColumnName, operator: unknown, value?: unknown): ModelQueryBuilder - whereJson(columnPath: ModelJsonColumnPath, operator: unknown, value?: unknown): ModelQueryBuilder - orWhereJson(columnPath: ModelJsonColumnPath, operator: unknown, value?: unknown): ModelQueryBuilder - whereJsonContains(columnPath: ModelJsonColumnPath, value: unknown): ModelQueryBuilder - orWhereJsonContains(columnPath: ModelJsonColumnPath, value: unknown): ModelQueryBuilder - whereJsonLength(columnPath: ModelJsonColumnPath, operator: unknown, value?: unknown): ModelQueryBuilder - orWhereJsonLength(columnPath: ModelJsonColumnPath, operator: unknown, value?: unknown): ModelQueryBuilder - whereFullText(columns: ModelColumnName | readonly ModelColumnName[], value: string, options?: { mode?: 'natural' | 'boolean' }): ModelQueryBuilder - orWhereFullText(columns: ModelColumnName | readonly ModelColumnName[], value: string, options?: { mode?: 'natural' | 'boolean' }): ModelQueryBuilder - whereVectorSimilarTo(column: ModelColumnName, vector: readonly number[], minSimilarity?: number): ModelQueryBuilder - orWhereVectorSimilarTo(column: ModelColumnName, vector: readonly number[], minSimilarity?: number): ModelQueryBuilder - orderBy(column: ModelColumnName, direction?: 'asc' | 'desc'): ModelQueryBuilder - latest(column?: ModelColumnName): ModelQueryBuilder - oldest(column?: ModelColumnName): ModelQueryBuilder - inRandomOrder(): ModelQueryBuilder - reorder(column?: ModelColumnName, direction?: 'asc' | 'desc'): ModelQueryBuilder - unsafeOrderBy(sql: string, bindings: readonly unknown[]): ModelQueryBuilder - lock(mode: 'update' | 'share'): ModelQueryBuilder - lockForUpdate(): ModelQueryBuilder - sharedLock(): ModelQueryBuilder - with[]>(...relations: TPaths): ModelQueryBuilder> - with[]>(relations: TPaths): ModelQueryBuilder> - with>(relation: TPath, constraint: RelationConstraintCallback): ModelQueryBuilder> - with, RelationConstraintCallback>>>>(relations: TMap): ModelQueryBuilder - withCount(...relations: readonly ModelRelationPath[]): ModelQueryBuilder - withCount(relations: RelationConstraintMap): ModelQueryBuilder - withExists(...relations: readonly ModelRelationPath[]): ModelQueryBuilder - withExists(relations: RelationConstraintMap): ModelQueryBuilder - withSum>(relation: TRelationPath | RelationConstraintMap, column: RelatedColumnNameForRelationPath, ...rest: readonly ModelRelationPath[]): ModelQueryBuilder - withAvg>(relation: TRelationPath | RelationConstraintMap, column: RelatedColumnNameForRelationPath, ...rest: readonly ModelRelationPath[]): ModelQueryBuilder - withMin>(relation: TRelationPath | RelationConstraintMap, column: RelatedColumnNameForRelationPath, ...rest: readonly ModelRelationPath[]): ModelQueryBuilder - withMax>(relation: TRelationPath | RelationConstraintMap, column: RelatedColumnNameForRelationPath, ...rest: readonly ModelRelationPath[]): ModelQueryBuilder - has(relation: ModelRelationPath): ModelQueryBuilder - orHas(relation: ModelRelationPath): ModelQueryBuilder - whereHas(relation: ModelRelationPath, constraint?: RelationConstraintCallback): ModelQueryBuilder - orWhereHas(relation: ModelRelationPath, constraint?: RelationConstraintCallback): ModelQueryBuilder - doesntHave(relation: ModelRelationPath): ModelQueryBuilder - orDoesntHave(relation: ModelRelationPath): ModelQueryBuilder - whereDoesntHave(relation: ModelRelationPath, constraint?: RelationConstraintCallback): ModelQueryBuilder - orWhereDoesntHave(relation: ModelRelationPath, constraint?: RelationConstraintCallback): ModelQueryBuilder - whereRelation>(relation: TRelationPath, column: RelatedColumnNameForRelationPath, operator: unknown, value?: unknown): ModelQueryBuilder - orWhereRelation>(relation: TRelationPath, column: RelatedColumnNameForRelationPath, operator: unknown, value?: unknown): ModelQueryBuilder - whereBelongsTo(entity: Entity, relationName?: ModelRelationPath): ModelQueryBuilder - orWhereBelongsTo(entity: Entity, relationName?: ModelRelationPath): ModelQueryBuilder - whereMorphedTo(relation: ModelRelationPath, target: MorphTypeSelector): ModelQueryBuilder - orWhereMorphedTo(relation: ModelRelationPath, target: MorphTypeSelector): ModelQueryBuilder - whereNotMorphedTo(relation: ModelRelationPath, target: MorphTypeSelector): ModelQueryBuilder - orWhereNotMorphedTo(relation: ModelRelationPath, target: MorphTypeSelector): ModelQueryBuilder - withWhereHas>(relation: TPath, constraint?: RelationConstraintCallback): ModelQueryBuilder> find(value: ModelPrimaryKeyValue): Promise | undefined> findMany(values: readonly ModelPrimaryKeyValue[]): Promise> findOrFail(value: ModelPrimaryKeyValue): Promise> diff --git a/packages/db/src/model/types.ts b/packages/db/src/model/types.ts index 4bcd1d5d..339f1b6f 100644 --- a/packages/db/src/model/types.ts +++ b/packages/db/src/model/types.ts @@ -153,7 +153,7 @@ export interface MorphToRelationDefinition extends ScopedRelationDefinition { +> extends ScopedRelationDefinition, PivotRelationMethods, TPivotTable> { readonly kind: 'belongsToMany' readonly related: () => TRelated readonly pivotTable: TPivotTable @@ -166,17 +166,12 @@ export interface BelongsToManyRelationDefinition< readonly pivotOrderBy: readonly PivotOrderDefinition[] readonly pivotAccessor: string readonly pivotModel?: () => ModelDefinitionLike - withPivot(...columns: readonly PivotTableColumnName[]): BelongsToManyRelationDefinition - wherePivot(column: PivotTableColumnName, operator: unknown, value?: unknown): BelongsToManyRelationDefinition - orderByPivot(column: PivotTableColumnName, direction?: 'asc' | 'desc'): BelongsToManyRelationDefinition - as(accessor: string): BelongsToManyRelationDefinition - using(model: () => ModelDefinitionLike): BelongsToManyRelationDefinition } export interface MorphToManyRelationDefinition< TRelated extends ModelDefinitionLike = ModelDefinitionLike, TPivotTable extends string | TableDefinition = string | TableDefinition, -> extends ScopedRelationDefinition { +> extends ScopedRelationDefinition, PivotRelationMethods, TPivotTable> { readonly kind: 'morphToMany' readonly related: () => TRelated readonly pivotTable: TPivotTable @@ -191,17 +186,12 @@ export interface MorphToManyRelationDefinition< readonly pivotOrderBy: readonly PivotOrderDefinition[] readonly pivotAccessor: string readonly pivotModel?: () => ModelDefinitionLike - withPivot(...columns: readonly PivotTableColumnName[]): MorphToManyRelationDefinition - wherePivot(column: PivotTableColumnName, operator: unknown, value?: unknown): MorphToManyRelationDefinition - orderByPivot(column: PivotTableColumnName, direction?: 'asc' | 'desc'): MorphToManyRelationDefinition - as(accessor: string): MorphToManyRelationDefinition - using(model: () => ModelDefinitionLike): MorphToManyRelationDefinition } export interface MorphedByManyRelationDefinition< TRelated extends ModelDefinitionLike = ModelDefinitionLike, TPivotTable extends string | TableDefinition = string | TableDefinition, -> extends ScopedRelationDefinition { +> extends ScopedRelationDefinition, PivotRelationMethods, TPivotTable> { readonly kind: 'morphedByMany' readonly related: () => TRelated readonly pivotTable: TPivotTable @@ -216,11 +206,6 @@ export interface MorphedByManyRelationDefinition< readonly pivotOrderBy: readonly PivotOrderDefinition[] readonly pivotAccessor: string readonly pivotModel?: () => ModelDefinitionLike - withPivot(...columns: readonly PivotTableColumnName[]): MorphedByManyRelationDefinition - wherePivot(column: PivotTableColumnName, operator: unknown, value?: unknown): MorphedByManyRelationDefinition - orderByPivot(column: PivotTableColumnName, direction?: 'asc' | 'desc'): MorphedByManyRelationDefinition - as(accessor: string): MorphedByManyRelationDefinition - using(model: () => ModelDefinitionLike): MorphedByManyRelationDefinition } export interface PivotWhereDefinition { diff --git a/packages/db/src/query/MySQLQueryCompiler.ts b/packages/db/src/query/MySQLQueryCompiler.ts index 1c7a475a..66daf7a8 100644 --- a/packages/db/src/query/MySQLQueryCompiler.ts +++ b/packages/db/src/query/MySQLQueryCompiler.ts @@ -1,5 +1,5 @@ import { SQLQueryCompiler } from './SQLQueryCompiler' -import type { InsertQueryPlan, QueryDatePredicate, QueryFullTextPredicate, QueryJsonPredicate, QueryJsonUpdateOperation, QueryLockMode, UpsertQueryPlan } from './ast' +import type { InsertQueryPlan, QueryFullTextPredicate, QueryJsonPredicate, QueryJsonUpdateOperation, QueryLockMode, UpsertQueryPlan } from './ast' export class MySQLQueryCompiler extends SQLQueryCompiler { protected override compileLateralJoinSource(subquery: string, alias: string): string { @@ -77,25 +77,4 @@ export class MySQLQueryCompiler extends SQLQueryCompiler { : ' IN NATURAL LANGUAGE MODE' return `MATCH (${columns}) AGAINST (${this.createPlaceholder(bindings.length)}${modifier})` } - - protected override compileDatePredicate( - predicate: QueryDatePredicate, - placeholder: string, - ): string { - const column = this.compileColumnReference(predicate.column) - switch (predicate.part) { - case 'date': - return `DATE(${column}) ${predicate.operator.toUpperCase()} ${placeholder}` - case 'time': - return `TIME(${column}) ${predicate.operator.toUpperCase()} ${placeholder}` - case 'year': - return `EXTRACT(YEAR FROM ${column}) ${predicate.operator.toUpperCase()} ${placeholder}` - case 'month': - return `EXTRACT(MONTH FROM ${column}) ${predicate.operator.toUpperCase()} ${placeholder}` - case 'day': - return `EXTRACT(DAY FROM ${column}) ${predicate.operator.toUpperCase()} ${placeholder}` - default: - return super.compileDatePredicate(predicate, placeholder) - } - } } diff --git a/packages/db/src/query/SQLiteQueryCompiler.impl.ts b/packages/db/src/query/SQLiteQueryCompiler.impl.ts index 1f0f745b..0a5c9c7d 100644 --- a/packages/db/src/query/SQLiteQueryCompiler.impl.ts +++ b/packages/db/src/query/SQLiteQueryCompiler.impl.ts @@ -5,21 +5,6 @@ function createSqliteJsonExtractExpression(column: string, pathLiteral: string): return `json_extract(${column}, ${pathLiteral})` } -function compileSqliteJsonValuePredicate( - extracted: string, - predicate: QueryJsonPredicate, - bindings: unknown[], - createPlaceholder: (index: number) => string, -): string { - bindings.push(predicate.value) - const operator = predicate.operator!.toUpperCase() - const placeholder = createPlaceholder(bindings.length) - return `${extracted} ${operator} ${placeholder}` -} -const SQLITE_JSON_HELPERS = Object.freeze({ - compileValuePredicate: compileSqliteJsonValuePredicate, - createExtractExpression: createSqliteJsonExtractExpression, -}) const SQLITE_EMPTY_JSON_OBJECT = 'json(\'{}\')' export class SQLiteQueryCompiler extends SQLQueryCompiler { @@ -30,18 +15,14 @@ export class SQLiteQueryCompiler extends SQLQueryCompiler { protected override compileJsonPredicate(predicate: QueryJsonPredicate, bindings: unknown[]): string { const column = this.compileColumnReference(predicate.column) const pathLiteral = this.createJsonPathLiteral(predicate.path) + const extracted = createSqliteJsonExtractExpression(column, pathLiteral) if (predicate.jsonMode === 'value') { - return SQLITE_JSON_HELPERS.compileValuePredicate( - SQLITE_JSON_HELPERS.createExtractExpression(column, pathLiteral), - predicate, - bindings, - index => this.createPlaceholder(index), - ) + bindings.push(predicate.value) + return `${extracted} ${predicate.operator!.toUpperCase()} ${this.createPlaceholder(bindings.length)}` } if (predicate.jsonMode === 'contains') { - const extracted = SQLITE_JSON_HELPERS.createExtractExpression(column, pathLiteral) if (predicate.value === null || ['string', 'number', 'boolean'].includes(typeof predicate.value)) { bindings.push(predicate.value) return `EXISTS (SELECT 1 FROM json_each(${extracted}) WHERE value = ${this.createPlaceholder(bindings.length)})` @@ -52,7 +33,7 @@ export class SQLiteQueryCompiler extends SQLQueryCompiler { } bindings.push(predicate.value) - return `json_array_length(${SQLITE_JSON_HELPERS.createExtractExpression(column, pathLiteral)}) ${predicate.operator!.toUpperCase()} ${this.createPlaceholder(bindings.length)}` + return `json_array_length(${extracted}) ${predicate.operator!.toUpperCase()} ${this.createPlaceholder(bindings.length)}` } protected override compileJsonUpdateOperations( diff --git a/packages/db/src/query/paginator.ts b/packages/db/src/query/paginator.ts index 77e82782..cd9dd712 100644 --- a/packages/db/src/query/paginator.ts +++ b/packages/db/src/query/paginator.ts @@ -1,4 +1,5 @@ import { SecurityError } from '../core/errors' +import { normalizePaginationParameterName } from './pagination' import type { CursorPaginatedResult, PaginationMeta, @@ -15,7 +16,7 @@ export function createPaginator( data, meta: { ...meta, - pageName: normalizeParameterName(meta.pageName, 'page'), + pageName: normalizePaginationParameterName(meta.pageName, 'page', message => new SecurityError(message)), }, } as PaginatedResult @@ -43,7 +44,7 @@ export function createSimplePaginator( data, meta: { ...meta, - pageName: normalizeParameterName(meta.pageName, 'page'), + pageName: normalizePaginationParameterName(meta.pageName, 'page', message => new SecurityError(message)), }, } as SimplePaginatedResult @@ -70,7 +71,7 @@ export function createCursorPaginator( const result = { data, perPage: meta.perPage, - cursorName: normalizeParameterName(meta.cursorName, 'cursor'), + cursorName: normalizePaginationParameterName(meta.cursorName, 'cursor', message => new SecurityError(message)), nextCursor: meta.nextCursor, prevCursor: meta.prevCursor, } as CursorPaginatedResult @@ -93,18 +94,6 @@ export function createCursorPaginator( return result } -function normalizeParameterName(value: string | undefined, fallback: string): string { - if (typeof value === 'undefined') { - return fallback - } - - if (typeof value !== 'string' || value.trim().length === 0) { - throw new SecurityError(`${fallback === 'cursor' ? 'Cursor' : 'Page'} parameter name must be a non-empty string.`) - } - - return value -} - function attachMethods( target: T, methods: Record unknown>, diff --git a/packages/db/src/runtime.ts b/packages/db/src/runtime.ts index 24510525..d0d7086d 100644 --- a/packages/db/src/runtime.ts +++ b/packages/db/src/runtime.ts @@ -41,6 +41,10 @@ export interface RuntimeConfigInput { db?: RuntimeDatabaseConfig } +type RuntimeConnectionEntry = [string, RuntimeConnectionConfig | string] +type RuntimeConnectionEntries = RuntimeConnectionEntry[] +type NonEmptyRuntimeConnectionEntries = [RuntimeConnectionEntry, ...RuntimeConnectionEntry[]] + const DEFAULT_RUNTIME_CONNECTION = Object.freeze({ driver: 'sqlite' as const, url: './data/database.sqlite', @@ -61,6 +65,12 @@ function normalizeConnectionInput(input: RuntimeConnectionConfig | string | unde return input ?? {} } +function hasRuntimeConnectionEntries( + entries: RuntimeConnectionEntries, +): entries is NonEmptyRuntimeConnectionEntries { + return entries.length > 0 +} + function inferDatabaseDriver(value: string | undefined): SupportedDatabaseDriver | undefined { if (!value) return undefined @@ -385,17 +395,7 @@ export function resolveRuntimeConnectionManagerOptions( const connections = topLevelDb?.connections ?? {} const connectionEntries = Object.entries(connections) - if (connectionEntries.length === 0) { - return createConnectionManager({ - defaultConnection: 'default', - connections: { - default: resolveConnectionConfig('default', undefined), - }, - }) - } - - const [firstConnectionEntry] = connectionEntries - if (!firstConnectionEntry) { + if (!hasRuntimeConnectionEntries(connectionEntries)) { return createConnectionManager({ defaultConnection: 'default', connections: { @@ -409,7 +409,7 @@ export function resolveRuntimeConnectionManagerOptions( ? configuredDefault : Object.hasOwn(connections, 'default') ? 'default' - : firstConnectionEntry[0] + : connectionEntries[0][0] const resolvedConnections = connectionEntries .map(([name, input]) => [ diff --git a/packages/db/src/schema/SQLSchemaCompiler.ts b/packages/db/src/schema/SQLSchemaCompiler.ts index f502be57..592dd226 100644 --- a/packages/db/src/schema/SQLSchemaCompiler.ts +++ b/packages/db/src/schema/SQLSchemaCompiler.ts @@ -1,5 +1,6 @@ import { SchemaError } from '../core/errors' -import { assertValidIdentifierPath, assertValidIdentifierSegment, sanitizeIdentifierForGeneratedName } from './identifiers' +import { assertValidIdentifierPath, assertValidIdentifierSegment } from './identifiers' +import { resolveGeneratedForeignKeyName, resolveGeneratedIndexName } from './generatedNames' import type { AnyColumnDefinition, TableDefinition, TableIndexDefinition } from './types' import type { DDLOperation, DDLStatement } from './ddl' @@ -316,15 +317,11 @@ export abstract class SQLSchemaCompiler { } protected resolveIndexName(tableName: string, index: TableIndexDefinition): string { - const indexName = index.name ?? `${sanitizeIdentifierForGeneratedName(tableName)}_${index.columns.join('_')}_${index.unique ? 'unique' : 'index'}` - assertValidIdentifierSegment(indexName, 'Index name') - return indexName + return resolveGeneratedIndexName(tableName, index) } protected resolveForeignKeyName(tableName: string, columnName: string, constraintName?: string): string { - const resolvedName = constraintName ?? `${sanitizeIdentifierForGeneratedName(tableName)}_${columnName}_foreign` - assertValidIdentifierSegment(resolvedName, 'Foreign key name') - return resolvedName + return resolveGeneratedForeignKeyName(tableName, columnName, constraintName) } protected assertSupportedAlterColumnDefinition(column: AnyColumnDefinition): void { diff --git a/packages/db/src/schema/SchemaService.ts b/packages/db/src/schema/SchemaService.ts index a1ab59a6..3cfc9fef 100644 --- a/packages/db/src/schema/SchemaService.ts +++ b/packages/db/src/schema/SchemaService.ts @@ -1,6 +1,7 @@ import { CapabilityError } from '../core/errors' import { addColumnOperation, alterColumnOperation, createForeignKeyOperation, createIndexOperation, createTableOperation, dropColumnOperation, dropForeignKeyOperation, dropIndexOperation, dropTableOperation, renameColumnOperation, renameIndexOperation, renameTableOperation } from './ddl' import { defineTable } from './defineTable' +import { resolveGeneratedForeignKeyName, resolveGeneratedIndexName } from './generatedNames' import { assertValidIdentifierPath, assertValidIdentifierSegment } from './identifiers' import { SQLiteSchemaCompiler } from './SQLiteSchemaCompiler' import { PostgresSchemaCompiler } from './PostgresSchemaCompiler' @@ -766,11 +767,11 @@ export class SchemaService { } private resolveIndexName(tableName: string, index: TableIndexDefinition): string { - return index.name ?? `${tableName.replaceAll('.', '_')}_${index.columns.join('_')}_${index.unique ? 'unique' : 'index'}` + return resolveGeneratedIndexName(tableName, index) } private resolveForeignKeyName(tableName: string, columnName: string): string { - return `${tableName.replaceAll('.', '_')}_${columnName}_foreign` + return resolveGeneratedForeignKeyName(tableName, columnName) } private createForeignKeyConstraintStatement(enable: boolean): { sql: string, source: string } { diff --git a/packages/db/src/schema/TableDefinitionBuilder.ts b/packages/db/src/schema/TableDefinitionBuilder.ts index a3f3c893..724e0d9c 100644 --- a/packages/db/src/schema/TableDefinitionBuilder.ts +++ b/packages/db/src/schema/TableDefinitionBuilder.ts @@ -1,8 +1,8 @@ import { SchemaError } from '../core/errors' import { column, type AnyColumnBuilder, type ColumnInput } from './columns' import { defineTable, type BoundTableDefinition, type DefineTableOptions } from './defineTable' -import { inferConstrainedTableName } from './pluralize' -import type { ForeignKeyReference, TableIndexDefinition } from './types' +import { ForeignKeyBuilderState } from './foreignKeyBuilderState' +import type { TableIndexDefinition } from './types' type ColumnShapeInput = Record @@ -389,44 +389,41 @@ class TableCreateForeignIdBuilder< TName extends string, TColumns extends ColumnShapeInput, > extends TableCreateColumnBuilder { - private referenceTable?: string - private referenceColumn = 'id' - private onDeleteAction?: ForeignKeyReference['onDelete'] - private onUpdateAction?: ForeignKeyReference['onUpdate'] + private readonly foreignKeyState: ForeignKeyBuilderState constructor( root: TableDefinitionBuilder, builderRef: ColumnReference, - private readonly columnName: string, + columnName: string, ) { super(root, builderRef) + this.foreignKeyState = new ForeignKeyBuilderState(columnName) } references( columnName: string, ): this { - this.referenceColumn = columnName + this.foreignKeyState.references(columnName) return this.applyReference() } on(table: string): this { - this.referenceTable = table + this.foreignKeyState.on(table) return this.applyReference() } constrained(table?: string, columnName = 'id'): this { - this.referenceTable = table ?? inferConstrainedTableName(this.columnName) - this.referenceColumn = columnName + this.foreignKeyState.constrained(table, columnName) return this.applyReference() } - onDelete(action: NonNullable): this { - this.onDeleteAction = action + onDelete(action: Parameters[0]): this { + this.foreignKeyState.onDelete(action) return this.applyReference() } - onUpdate(action: NonNullable): this { - this.onUpdateAction = action + onUpdate(action: Parameters[0]): this { + this.foreignKeyState.onUpdate(action) return this.applyReference() } @@ -463,17 +460,7 @@ class TableCreateForeignIdBuilder< } private applyReference(): this { - let builder = this.builderRef.builder.references(this.referenceColumn) - if (this.referenceTable) { - builder = builder.on(this.referenceTable) - } - if (this.onDeleteAction) { - builder = builder.onDelete(this.onDeleteAction) - } - if (this.onUpdateAction) { - builder = builder.onUpdate(this.onUpdateAction) - } - this.builderRef.builder = builder + this.builderRef.builder = this.foreignKeyState.applyToColumnBuilder(this.builderRef.builder) return this } } @@ -482,40 +469,38 @@ class TableCreateForeignKeyBuilder< TName extends string, TColumns extends ColumnShapeInput, > { - private referenceTable?: string - private referenceColumn = 'id' - private onDeleteAction?: ForeignKeyReference['onDelete'] - private onUpdateAction?: ForeignKeyReference['onUpdate'] + private readonly foreignKeyState: ForeignKeyBuilderState constructor( private readonly root: TableDefinitionBuilder, private readonly builderRef: ColumnReference, - private readonly columnName: string, - ) {} + columnName: string, + ) { + this.foreignKeyState = new ForeignKeyBuilderState(columnName) + } references(columnName: string): this { - this.referenceColumn = columnName + this.foreignKeyState.references(columnName) return this.applyReference() } on(table: string): this { - this.referenceTable = table + this.foreignKeyState.on(table) return this.applyReference() } constrained(table?: string, columnName = 'id'): this { - this.referenceTable = table ?? inferConstrainedTableName(this.columnName) - this.referenceColumn = columnName + this.foreignKeyState.constrained(table, columnName) return this.applyReference() } - onDelete(action: NonNullable): this { - this.onDeleteAction = action + onDelete(action: Parameters[0]): this { + this.foreignKeyState.onDelete(action) return this.applyReference() } - onUpdate(action: NonNullable): this { - this.onUpdateAction = action + onUpdate(action: Parameters[0]): this { + this.foreignKeyState.onUpdate(action) return this.applyReference() } @@ -556,17 +541,7 @@ class TableCreateForeignKeyBuilder< } private applyReference(): this { - let builder = this.builderRef.builder.references(this.referenceColumn) - if (this.referenceTable) { - builder = builder.on(this.referenceTable) - } - if (this.onDeleteAction) { - builder = builder.onDelete(this.onDeleteAction) - } - if (this.onUpdateAction) { - builder = builder.onUpdate(this.onUpdateAction) - } - this.builderRef.builder = builder + this.builderRef.builder = this.foreignKeyState.applyToColumnBuilder(this.builderRef.builder) return this } } diff --git a/packages/db/src/schema/TableMutationBuilder.ts b/packages/db/src/schema/TableMutationBuilder.ts index 8aa67435..7e283ed8 100644 --- a/packages/db/src/schema/TableMutationBuilder.ts +++ b/packages/db/src/schema/TableMutationBuilder.ts @@ -1,5 +1,5 @@ import { column, type AnyColumnBuilder } from './columns' -import { inferConstrainedTableName } from './pluralize' +import { ForeignKeyBuilderState } from './foreignKeyBuilderState' import type { ForeignKeyReference, TableIndexDefinition } from './types' interface AddColumnMutationOperation { @@ -74,45 +74,41 @@ type ColumnMutationMode = 'addColumn' | 'alterColumn' type ColumnFactory = () => TBuilder class TableForeignKeyBuilder { + private readonly foreignKeyState: ForeignKeyBuilderState + constructor( private readonly operation: CreateForeignKeyMutationOperation, - ) {} + ) { + this.foreignKeyState = new ForeignKeyBuilderState(operation.columnName, operation.reference.column) + } references(columnName: string): this { - this.operation.reference = { - ...this.operation.reference, - column: columnName, - } + this.foreignKeyState.references(columnName) + this.applyReference() return this } on(table: string): this { - this.operation.reference = { - ...this.operation.reference, - table, - } + this.foreignKeyState.on(table) + this.applyReference() return this } constrained(table?: string, columnName = 'id'): this { + this.foreignKeyState.constrained(table, columnName) + this.applyReference() return this - .on(table ?? inferConstrainedTableName(this.operation.columnName)) - .references(columnName) } onDelete(action: NonNullable): this { - this.operation.reference = { - ...this.operation.reference, - onDelete: action, - } + this.foreignKeyState.onDelete(action) + this.applyReference() return this } onUpdate(action: NonNullable): this { - this.operation.reference = { - ...this.operation.reference, - onUpdate: action, - } + this.foreignKeyState.onUpdate(action) + this.applyReference() return this } @@ -147,6 +143,10 @@ class TableForeignKeyBuilder { noActionOnUpdate(): this { return this.onUpdate('no action') } + + private applyReference(): void { + this.operation.reference = this.foreignKeyState.toReference() + } } class TableColumnMutationBuilder { @@ -245,11 +245,8 @@ class TableColumnMutationBuilder { } class TableForeignIdMutationBuilder extends TableColumnMutationBuilder { + private readonly foreignKeyState: ForeignKeyBuilderState private foreignOperation?: CreateForeignKeyMutationOperation - private referenceTable?: string - private referenceColumn = 'id' - private onDeleteAction?: ForeignKeyReference['onDelete'] - private onUpdateAction?: ForeignKeyReference['onUpdate'] constructor( root: TableMutationBuilder, @@ -258,33 +255,33 @@ class TableForeignIdMutationBuilder extends TableColumnMutationBuilder { private readonly columnName: string, ) { super(root, operation) + this.foreignKeyState = new ForeignKeyBuilderState(columnName) } references( columnName: string, ): this { - this.referenceColumn = columnName + this.foreignKeyState.references(columnName) return this.applyForeignKey() } on(table: string): this { - this.referenceTable = table + this.foreignKeyState.on(table) return this.applyForeignKey() } constrained(table?: string, columnName = 'id'): this { - this.referenceTable = table ?? inferConstrainedTableName(this.columnName) - this.referenceColumn = columnName + this.foreignKeyState.constrained(table, columnName) return this.applyForeignKey() } onDelete(action: NonNullable): this { - this.onDeleteAction = action + this.foreignKeyState.onDelete(action) return this.applyForeignKey() } onUpdate(action: NonNullable): this { - this.onUpdateAction = action + this.foreignKeyState.onUpdate(action) return this.applyForeignKey() } @@ -322,12 +319,7 @@ class TableForeignIdMutationBuilder extends TableColumnMutationBuilder { private applyForeignKey(): this { const operation = this.ensureForeignOperation() - operation.reference = { - table: this.referenceTable ?? '', - column: this.referenceColumn, - onDelete: this.onDeleteAction, - onUpdate: this.onUpdateAction, - } + operation.reference = this.foreignKeyState.toReference() return this } @@ -336,12 +328,7 @@ class TableForeignIdMutationBuilder extends TableColumnMutationBuilder { this.foreignOperation = { kind: 'createForeignKey', columnName: this.columnName, - reference: { - table: this.referenceTable ?? '', - column: this.referenceColumn, - onDelete: this.onDeleteAction, - onUpdate: this.onUpdateAction, - }, + reference: this.foreignKeyState.toReference(), } this.operations.push(this.foreignOperation) } diff --git a/packages/db/src/schema/diff.ts b/packages/db/src/schema/diff.ts index dfa66fd6..df3b6fe7 100644 --- a/packages/db/src/schema/diff.ts +++ b/packages/db/src/schema/diff.ts @@ -1,12 +1,11 @@ -import { SchemaError } from '../core/errors' import type { AnyColumnDefinition, ForeignKeyReference, - LogicalColumnKind, TableDefinition, TableIndexDefinition, } from './types' import type { IntrospectedForeignKey, SchemaService } from './SchemaService' +import { resolveDialectComparisonColumnType, type SchemaDialectName } from './typeMapping' export interface SchemaColumnMismatch { readonly column: string @@ -256,121 +255,18 @@ async function diffTable( function lowerLogicalColumnType(column: AnyColumnDefinition, dialectName: string): string { if (dialectName.startsWith('postgres')) { - return lowerPostgresLogicalColumnType(column) + return resolveDialectComparisonColumnType('postgres', column) } if (dialectName.startsWith('mysql')) { - return lowerMySqlLogicalColumnType(column) + return resolveDialectComparisonColumnType('mysql', column) } - switch (column.kind as LogicalColumnKind) { - case 'id': - case 'integer': - case 'bigInteger': - case 'boolean': - return 'INTEGER' - case 'string': - case 'uuid': - case 'ulid': - case 'snowflake': - case 'date': - case 'datetime': - case 'timestamp': - case 'text': - case 'json': - case 'enum': - return 'TEXT' - case 'real': - return 'REAL' - case 'decimal': - return 'NUMERIC' - case 'blob': - return 'BLOB' - case 'vector': - throw new SchemaError('SQLite schema diffing does not support logical vector columns.') - default: - throw new SchemaError(`Unsupported logical column kind "${String(column.kind)}" for SQLite schema diffing.`) + if (dialectName.startsWith('sqlite')) { + return resolveDialectComparisonColumnType('sqlite', column) } -} -function lowerPostgresLogicalColumnType(column: AnyColumnDefinition): string { - switch (column.kind as LogicalColumnKind) { - case 'id': - case 'bigInteger': - return 'bigint' - case 'integer': - return 'integer' - case 'boolean': - return 'boolean' - case 'string': - return 'character varying' - case 'text': - case 'enum': - return 'text' - case 'uuid': - return 'uuid' - case 'ulid': - case 'snowflake': - return 'character varying' - case 'date': - return 'date' - case 'datetime': - case 'timestamp': - return 'timestamp' - case 'json': - return 'jsonb' - case 'real': - return 'double precision' - case 'decimal': - return 'numeric' - case 'blob': - return 'bytea' - case 'vector': - return `vector(${column.vectorDimensions})` - default: - throw new SchemaError(`Unsupported logical column kind "${String(column.kind)}" for Postgres schema diffing.`) - } -} - -function lowerMySqlLogicalColumnType(column: AnyColumnDefinition): string { - switch (column.kind as LogicalColumnKind) { - case 'id': - case 'bigInteger': - return 'bigint' - case 'integer': - return 'int' - case 'boolean': - return 'tinyint' - case 'string': - return 'varchar' - case 'text': - return 'text' - case 'uuid': - case 'ulid': - return 'char' - case 'snowflake': - return 'varchar' - case 'date': - return 'date' - case 'datetime': - return 'datetime' - case 'timestamp': - return 'timestamp' - case 'enum': - return 'enum' - case 'json': - return 'json' - case 'real': - return 'double' - case 'decimal': - return 'decimal' - case 'blob': - return 'blob' - case 'vector': - throw new SchemaError('MySQL schema diffing does not support logical vector columns.') - default: - throw new SchemaError(`Unsupported logical column kind "${String(column.kind)}" for MySQL schema diffing.`) - } + return resolveDialectComparisonColumnType(dialectName as SchemaDialectName, column) } function normalizeExpectedIndex( diff --git a/packages/db/src/schema/foreignKeyBuilderState.ts b/packages/db/src/schema/foreignKeyBuilderState.ts new file mode 100644 index 00000000..26855836 --- /dev/null +++ b/packages/db/src/schema/foreignKeyBuilderState.ts @@ -0,0 +1,58 @@ +import { inferConstrainedTableName } from './pluralize' +import type { AnyColumnBuilder } from './columns' +import type { ForeignKeyReference } from './types' + +export class ForeignKeyBuilderState { + private referenceTable?: string + private referenceColumn: string + private onDeleteAction?: ForeignKeyReference['onDelete'] + private onUpdateAction?: ForeignKeyReference['onUpdate'] + + constructor(private readonly columnName: string, referenceColumn = 'id') { + this.referenceColumn = referenceColumn + } + + references(columnName: string): void { + this.referenceColumn = columnName + } + + on(table: string): void { + this.referenceTable = table + } + + constrained(table?: string, columnName = 'id'): void { + this.referenceTable = table ?? inferConstrainedTableName(this.columnName) + this.referenceColumn = columnName + } + + onDelete(action: NonNullable): void { + this.onDeleteAction = action + } + + onUpdate(action: NonNullable): void { + this.onUpdateAction = action + } + + toReference(defaultTable = ''): ForeignKeyReference { + return { + table: this.referenceTable ?? defaultTable, + column: this.referenceColumn, + onDelete: this.onDeleteAction, + onUpdate: this.onUpdateAction, + } + } + + applyToColumnBuilder(builder: AnyColumnBuilder): AnyColumnBuilder { + let next = builder.references(this.referenceColumn) + if (this.referenceTable) { + next = next.on(this.referenceTable) + } + if (this.onDeleteAction) { + next = next.onDelete(this.onDeleteAction) + } + if (this.onUpdateAction) { + next = next.onUpdate(this.onUpdateAction) + } + return next + } +} diff --git a/packages/db/src/schema/generatedNames.ts b/packages/db/src/schema/generatedNames.ts new file mode 100644 index 00000000..5af76225 --- /dev/null +++ b/packages/db/src/schema/generatedNames.ts @@ -0,0 +1,51 @@ +import { assertValidIdentifierSegment, sanitizeIdentifierForGeneratedName } from './identifiers' +import type { TableIndexDefinition } from './types' + +const MAX_GENERATED_IDENTIFIER_LENGTH = 63 + +function hashIndexColumns(columns: readonly string[]): string { + const serialized = JSON.stringify(columns) + let hash = 2_166_136_261 + + for (let index = 0; index < serialized.length; index += 1) { + hash = Math.imul(hash ^ serialized.charCodeAt(index), 16_777_619) + } + + return (hash >>> 0).toString(36) +} + +function buildGeneratedIndexName(tableName: string, columnsName: string, columnsHash: string, suffix: string): string { + const baseName = `${sanitizeIdentifierForGeneratedName(tableName)}_${columnsName}` + const fullName = `${baseName}_${columnsHash}_${suffix}` + if (fullName.length <= MAX_GENERATED_IDENTIFIER_LENGTH) { + return fullName + } + + const suffixLength = columnsHash.length + suffix.length + 2 + const truncatedBaseName = baseName + .slice(0, MAX_GENERATED_IDENTIFIER_LENGTH - suffixLength) + .replace(/_+$/g, '') + return `${truncatedBaseName}_${columnsHash}_${suffix}` +} + +export function resolveGeneratedIndexName(tableName: string, index: TableIndexDefinition): string { + const columnsName = index.columns + .map(column => sanitizeIdentifierForGeneratedName(column)) + .join('_') + const suffix = index.unique ? 'unique' : 'index' + const indexName = index.name + ?? buildGeneratedIndexName(tableName, columnsName, hashIndexColumns(index.columns), suffix) + assertValidIdentifierSegment(indexName, 'Index name') + return indexName +} + +export function resolveGeneratedForeignKeyName( + tableName: string, + columnName: string, + constraintName?: string, +): string { + const resolvedName = constraintName + ?? `${sanitizeIdentifierForGeneratedName(tableName)}_${sanitizeIdentifierForGeneratedName(columnName)}_foreign` + assertValidIdentifierSegment(resolvedName, 'Foreign key name') + return resolvedName +} diff --git a/packages/db/src/schema/typeMapping.ts b/packages/db/src/schema/typeMapping.ts index 50150668..86470637 100644 --- a/packages/db/src/schema/typeMapping.ts +++ b/packages/db/src/schema/typeMapping.ts @@ -147,3 +147,53 @@ export function resolveDialectColumnType( return typeof entry === 'function' ? entry(column) : entry } + +export function resolveDialectComparisonColumnType( + dialect: SchemaDialectName, + column: AnyColumnDefinition, +): string { + const compiledType = resolveDialectColumnType(dialect, column) + const normalized = compiledType.toLowerCase() + + if (dialect === 'sqlite') { + return normalized.split(/\s+/, 1)[0]!.toUpperCase() + } + + if (dialect === 'postgres') { + if (normalized.startsWith('varchar')) { + return 'character varying' + } + + if (normalized.startsWith('bigint')) { + return 'bigint' + } + + return normalized + } + + if (normalized.startsWith('varchar')) { + return 'varchar' + } + + if (normalized.startsWith('char')) { + return 'char' + } + + if (normalized.startsWith('tinyint')) { + return 'tinyint' + } + + if (normalized.startsWith('decimal')) { + return 'decimal' + } + + if (normalized.startsWith('enum')) { + return 'enum' + } + + if (normalized.startsWith('bigint')) { + return 'bigint' + } + + return normalized +} diff --git a/packages/db/src/security/policy.ts b/packages/db/src/security/policy.ts index 67304906..7c1737b7 100644 --- a/packages/db/src/security/policy.ts +++ b/packages/db/src/security/policy.ts @@ -32,7 +32,7 @@ export function redactBindings( return limited.map(() => '[REDACTED]') } - return [...limited] + return limited } export function redactSql( diff --git a/packages/db/tests/core-runtime.test.ts b/packages/db/tests/core-runtime.test.ts index 712ad1e0..b9e08703 100644 --- a/packages/db/tests/core-runtime.test.ts +++ b/packages/db/tests/core-runtime.test.ts @@ -684,6 +684,36 @@ describe('new core runtime slice', () => { await running }) + it('sleeps queued scheduler work until timer-backed active work completes', async () => { + const scheduler = createQueryScheduler({ + connectionName: 'default', + supportsConcurrentQueries: true, + supportsWorkerThreads: false, + concurrency: { + maxConcurrentQueries: 1, + queueLimit: 1, + }, + }) + const order: string[] = [] + + const first = scheduler.schedule({ transactional: false }, async () => { + order.push('first-start') + await new Promise(resolve => setTimeout(resolve, 0)) + order.push('first-end') + return 'first' + }) + const second = scheduler.schedule({ transactional: false }, async () => { + order.push('second-start') + return 'second' + }) + + await expect(Promise.all([first, second])).resolves.toEqual([ + { result: 'first', schedulingMode: 'concurrent' }, + { result: 'second', schedulingMode: 'concurrent' }, + ]) + expect(order).toEqual(['first-start', 'first-end', 'second-start']) + }) + it('uses worker scheduling mode when the dialect supports it and the connection prefers worker threads', async () => { const adapter = new FakeAdapter() let schedulingMode: string | undefined diff --git a/packages/db/tests/drivers-core.test.ts b/packages/db/tests/drivers-core.test.ts index af6dbc41..befedaf4 100644 --- a/packages/db/tests/drivers-core.test.ts +++ b/packages/db/tests/drivers-core.test.ts @@ -189,14 +189,6 @@ function createTransactionDialect(name: 'postgres' | 'mysql') { } } as const } -type PostgresAdapterHarness = { - releaseScopedTransaction(state: { client: { release?(): void }, leased: boolean, released: boolean }): void -} - -type MySQLAdapterHarness = { - releaseScopedTransaction(state: { client: { release?(): void }, leased: boolean, released: boolean }): void -} - describe('driver adapters', () => { let sqliteContractState: { sqlite: ReturnType @@ -340,55 +332,6 @@ describe('driver adapters', () => { await expect(adapter.initialize()).rejects.toThrow('[@holo-js/db] SQLite support requires @holo-js/db-sqlite to be installed.') }) - it('falls back to array-based SQLite bindings when spread bindings trigger an arity error', async () => { - const calls: Array<{ method: 'all' | 'run', bindings: readonly unknown[] }> = [] - let shouldThrowTooFew = true - const adapter = createSQLiteAdapter({ - database: { - prepare() { - return { - all(...bindings: readonly unknown[]) { - calls.push({ method: 'all', bindings }) - if (bindings.length > 1) { - throw new RangeError('Too many parameter values were provided') - } - - return [{ ok: true }] - }, - run(...bindings: readonly unknown[]) { - calls.push({ method: 'run', bindings }) - if (shouldThrowTooFew) { - shouldThrowTooFew = false - throw new RangeError('Too few parameter values were provided') - } - - if (bindings.length > 1) { - throw new RangeError('Too many parameter values were provided') - } - - return { - changes: 1, - lastInsertRowid: 11 } - } } - }, - exec() {}, - close() {} } }) - - await expect(adapter.query('SELECT * FROM users WHERE id = ? AND role = ?', [1, 'admin'])).resolves.toEqual({ - rows: [{ ok: true }], - rowCount: 1 }) - await expect(adapter.execute('UPDATE users SET meta = ?, role = ? WHERE id = ?', [{ active: true }, 'admin', 1])).resolves.toEqual({ - affectedRows: 1, - lastInsertId: 11 }) - - expect(calls).toEqual([ - { method: 'all', bindings: [1, 'admin'] }, - { method: 'all', bindings: [[1, 'admin']] }, - { method: 'run', bindings: [{ active: true }, 'admin', 1] }, - { method: 'run', bindings: [[{ active: true }, 'admin', 1]] }, - ]) - }) - it('rethrows non-arity SQLite statement errors', async () => { const adapter = createSQLiteAdapter({ database: { @@ -733,25 +676,6 @@ describe('driver adapters', () => { expect(adapter.isConnected()).toBe(false) }) - it('no-ops Postgres scoped transaction release when the state is already released', () => { - let releaseCalls = 0 - const adapter = new PostgresAdapter({ - client: { - async query() { - return { rows: [], rowCount: 0 } - } } }) - - ;(adapter as unknown as PostgresAdapterHarness).releaseScopedTransaction({ - client: { - release() { - releaseCalls += 1 - } }, - leased: true, - released: true }) - - expect(releaseCalls).toBe(0) - }) - let mySqlContractState: ReturnType runDriverAdapterContractSuite({ @@ -894,25 +818,6 @@ describe('driver adapters', () => { expect(adapter.isConnected()).toBe(false) }) - it('no-ops MySQL scoped transaction release when the state is already released', () => { - let releaseCalls = 0 - const adapter = new MySQLAdapter({ - client: { - async query() { - return [[], []] as const - } } }) - - ;(adapter as unknown as MySQLAdapterHarness).releaseScopedTransaction({ - client: { - release() { - releaseCalls += 1 - } }, - leased: true, - released: true }) - - expect(releaseCalls).toBe(0) - }) - it('covers active leased MySQL disconnects, invalid lazy pools, and the default pool factory path', async () => { const leasedState = createMySqlPool() const leasedAdapter = new MySQLAdapter({ diff --git a/packages/db/tests/factories-core.test.ts b/packages/db/tests/factories-core.test.ts index 35416803..c4acdb91 100644 --- a/packages/db/tests/factories-core.test.ts +++ b/packages/db/tests/factories-core.test.ts @@ -30,11 +30,6 @@ type Row = Record type TableStore = Record type CounterStore = Record type TestEntity = Entity -type FactoryPrivateApi = { - recycledEntities: unknown[] - resolveManySource(source: unknown, persist: boolean): Promise - takeRecycledEntities(source: unknown, amount: number): unknown[] -} function cloneRow(row: Row): Row { return { ...row } @@ -556,82 +551,6 @@ describe('factory slice', () => { ]) }) - it('covers recycled-factory edge branches directly', async () => { - const adapter = new InMemoryFactoryAdapter({ users: [] }, {}) - const db = createDatabase({ - connectionName: 'default', - adapter, - dialect: createDialect() }) - - configureDB(createConnectionManager({ - defaultConnection: 'default', - connections: { default: db } })) - - const users = defineTable('users', { - id: column.id(), - name: column.string() }) - const User = defineModelFromTable(users, { - fillable: ['name'] }) - - const factory = defineFactory(User, () => ({ name: 'User' })) - const recycled = User.make({ name: 'Unsaved' }) - const internalFactory = factory as unknown as FactoryPrivateApi - internalFactory.recycledEntities = [recycled] - - await expect(internalFactory.resolveManySource(factory.count(1), true)).rejects.toThrow( - 'Factory.recycle() requires persisted related models when using create().', - ) - - expect(internalFactory.takeRecycledEntities(factory, 0)).toEqual([]) - - expect(internalFactory.takeRecycledEntities({ - model: { - definition: { - table: { tableName: 'users' } } } }, 1)).toHaveLength(1) - - expect(internalFactory.takeRecycledEntities({ - model: { - definition: { - table: { tableName: 'users' } }, - getConnectionName: () => 'other' } }, 1)).toEqual([]) - }) - - it('does not match recycled related models from a different connection', async () => { - const adapter = new InMemoryFactoryAdapter({ users: [] }, {}) - const db = createDatabase({ - connectionName: 'default', - adapter, - dialect: createDialect() }) - - configureDB(createConnectionManager({ - defaultConnection: 'default', - connections: { default: db } })) - - const users = defineTable('users', { - id: column.id(), - name: column.string() }) - const User = defineModelFromTable(users, { - fillable: ['name'] }) - - const factory = defineFactory(User, () => ({ name: 'User' })) - const internalFactory = factory as unknown as FactoryPrivateApi - internalFactory.recycledEntities = [{ - getRepository() { - return { - definition: { - table: { tableName: 'users' } }, - getConnectionName() { - return 'secondary' - } } - } }] - - expect(internalFactory.takeRecycledEntities({ - model: { - definition: { - table: { tableName: 'users' } }, - getConnectionName: () => 'default' } }, 1)).toEqual([]) - }) - it('fails fast for unsaved recycled and directly attached related models during create paths', async () => { const adapter = new InMemoryFactoryAdapter({ teams: [], diff --git a/packages/db/tests/schema-service.test.ts b/packages/db/tests/schema-service.test.ts index 67ff2a11..ded69fb4 100644 --- a/packages/db/tests/schema-service.test.ts +++ b/packages/db/tests/schema-service.test.ts @@ -25,6 +25,7 @@ import { type DriverAdapter, type DriverExecutionResult, type DriverQueryResult } from '../src' +import { resolveGeneratedForeignKeyName } from '../src/schema/generatedNames' import { defineTable } from './support/internal' function parseIdentifierTail(identifierPath: string): string { @@ -566,7 +567,7 @@ describe('sqlite schema compiler', () => { expect(statements[0]!.sql).toContain('"account_uuid" TEXT NOT NULL REFERENCES "accounts" ("uuid")') expect(statements[0]!.sql).toContain('"session_ulid" TEXT NOT NULL REFERENCES "sessions" ("id")') expect(statements[0]!.sql).toContain('"actor_snowflake" TEXT NOT NULL REFERENCES "actors" ("snowflake_id")') - expect(statements[1]!.sql).toContain('CREATE UNIQUE INDEX IF NOT EXISTS "users_nickname_unique"') + expect(statements[1]!.sql).toContain('CREATE UNIQUE INDEX IF NOT EXISTS "users_nickname_fxf0kt_unique"') expect(compiler.compile(dropTableOperation('users'))).toEqual([{ sql: 'DROP TABLE IF EXISTS "users"', @@ -607,7 +608,30 @@ describe('sqlite schema compiler', () => { const statements = compiler.compile(createTableOperation(posts)) expect(statements[1]!.sql).toBe( - 'CREATE INDEX IF NOT EXISTS "posts_title_index" ON "posts" ("title")', + 'CREATE INDEX IF NOT EXISTS "posts_title_1b1lae3_index" ON "posts" ("title")', + ) + }) + + it('derives collision-resistant index names for underscored column groups', () => { + const compiler = new SQLiteSchemaCompiler(identifier => `"${identifier}"`) + const first = defineTable('events', { + a: column.string(), + a_b: column.string(), + b_c: column.string(), + c: column.string(), + }, { + indexes: [ + { columns: ['a_b', 'c'], unique: false }, + { columns: ['a', 'b_c'], unique: false }, + ], + }) + + const statements = compiler.compile(createTableOperation(first)) + expect(statements[1]!.sql).toBe( + 'CREATE INDEX IF NOT EXISTS "events_a_b_c_1ddsko2_index" ON "events" ("a_b", "c")', + ) + expect(statements[2]!.sql).toBe( + 'CREATE INDEX IF NOT EXISTS "events_a_b_c_1batapi_index" ON "events" ("a", "b_c")', ) }) @@ -904,8 +928,8 @@ describe('multi-dialect schema compilers', () => { expect(sqliteCompiler.compile(createIndexOperation('users', { columns: ['email'], unique: true }))).toEqual([{ - sql: 'CREATE UNIQUE INDEX IF NOT EXISTS "users_email_unique" ON "users" ("email")', - source: 'schema:createIndex:users:users_email_unique' }]) + sql: 'CREATE UNIQUE INDEX IF NOT EXISTS "users_email_yfr781_unique" ON "users" ("email")', + source: 'schema:createIndex:users:users_email_yfr781_unique' }]) expect(postgresCompiler.compile(createIndexOperation('public.users', { name: 'users_name_index', @@ -924,8 +948,8 @@ describe('multi-dialect schema compilers', () => { expect(mysqlCompiler.compile(createIndexOperation('analytics.users', { columns: ['email'], unique: false }))).toEqual([{ - sql: 'CREATE INDEX `analytics_users_email_index` ON `analytics`.`users` (`email`)', - source: 'schema:createIndex:analytics.users:analytics_users_email_index' }]) + sql: 'CREATE INDEX `analytics_users_email_yfr781_index` ON `analytics`.`users` (`email`)', + source: 'schema:createIndex:analytics.users:analytics_users_email_yfr781_index' }]) expect(mysqlCompiler.compile(dropIndexOperation('analytics.users', 'users_email_unique'))).toEqual([{ sql: 'DROP INDEX `users_email_unique` ON `analytics`.`users`', @@ -1047,6 +1071,10 @@ describe('multi-dialect schema compilers', () => { ) }) + it('sanitizes generated foreign-key names from table and column names', () => { + expect(resolveGeneratedForeignKeyName('public.users', 'team.id')).toBe('public_users_team_id_foreign') + }) + it('rejects malformed identifier paths before schema SQL is compiled', () => { const compiler = new SQLiteSchemaCompiler(identifier => `"${identifier}"`) @@ -1219,8 +1247,16 @@ describe('schema service', () => { ) => string expect(resolveIndexName('users', { columns: ['email'], unique: true, name: 'users_email_unique' })).toBe('users_email_unique') - expect(resolveIndexName('users', { columns: ['email'], unique: true })).toBe('users_email_unique') - expect(resolveIndexName('users', { columns: ['display_name'], unique: false })).toBe('users_display_name_index') + expect(resolveIndexName('users', { columns: ['email'], unique: true })).toBe('users_email_yfr781_unique') + expect(resolveIndexName('users', { columns: ['display_name'], unique: false })).toBe('users_display_name_1mdvyy9_index') + expect(resolveIndexName('mysql_device_links', { + columns: ['user_public_id', 'device_key_id', 'label'], + unique: true, + })).toBe('mysql_device_links_user_public_id_device_key_id_1mb4nd2_unique') + expect(resolveIndexName('mysql_device_links', { + columns: ['user_public_id', 'device_key_id', 'label'], + unique: true, + }).length).toBeLessThanOrEqual(63) }) it('renames indexes where the active dialect supports it and fails closed otherwise', async () => { @@ -1666,7 +1702,7 @@ describe('schema service', () => { await schema.table('users', (table) => { table.renameColumn('nickname', 'display_name') table.renameIndex('users_email_unique', 'users_email_address_unique') - table.dropIndex('users_nickname_index') + table.dropIndex('users_nickname_fxf0kt_index') }) const updated = registry.get('users') @@ -1678,7 +1714,7 @@ describe('schema service', () => { expect(adapter.executed).toEqual([ 'ALTER TABLE "users" RENAME COLUMN "nickname" TO "display_name"', 'ALTER INDEX "users_email_unique" RENAME TO "users_email_address_unique"', - 'DROP INDEX IF EXISTS "users_nickname_index"', + 'DROP INDEX IF EXISTS "users_nickname_fxf0kt_index"', ]) }) diff --git a/packages/db/vitest.config.ts b/packages/db/vitest.config.ts index b92a62ae..ab924302 100644 --- a/packages/db/vitest.config.ts +++ b/packages/db/vitest.config.ts @@ -23,9 +23,6 @@ export default defineConfig({ 'src/**/types.ts', 'src/migrations/templates/**', 'src/drivers/index.ts', - 'src/drivers/SQLiteAdapter.ts', - 'src/drivers/PostgresAdapter.ts', - 'src/drivers/MySQLAdapter.ts', '**/node_modules/**', 'packages/core/**', 'packages/storage/**', diff --git a/packages/forms/src/contracts.ts b/packages/forms/src/contracts.ts index 0d748a24..69586831 100644 --- a/packages/forms/src/contracts.ts +++ b/packages/forms/src/contracts.ts @@ -87,6 +87,16 @@ interface FormFailureMetadata { readonly retryAt?: string } +type NextHeadersModule = { + readonly headers: () => Headers | Promise +} + +type FormsRuntimeGlobal = typeof globalThis & { + readonly __holoFormsNextHeadersImport__?: () => Promise +} + +const nextHeadersModuleSpecifier = 'next/headers' + export interface FormRequestLikeInput { readonly method?: string readonly path?: string @@ -279,6 +289,10 @@ function isRequestLikeBody(value: unknown): value is RequestLikeBody { && (Symbol.asyncIterator in value || 'pipe' in value || 'getReader' in value) } +function isFormDataInput(value: unknown): value is FormData { + return typeof FormData !== 'undefined' && value instanceof FormData +} + function isHeadersTupleArray(value: unknown): value is ReadonlyArray { return Array.isArray(value) && value.every(entry => @@ -371,6 +385,73 @@ function normalizeRequestHeaders(input: unknown): Headers { return headers } +function createFormDataRequestHeaders(requestHeaders: Headers): Headers { + const formHeaders = new Headers() + + requestHeaders.forEach((value, name) => { + const normalizedName = name.toLowerCase() + if (normalizedName !== 'content-length' && normalizedName !== 'content-type') { + formHeaders.append(name, value) + } + }) + + return formHeaders +} + +function resolveAmbientRequestUrl(headers: Headers): string { + const referer = headers.get('referer') + if (referer) { + try { + return new URL(referer).href + } catch { + // Ignore malformed client-controlled Referer headers and fall back to trusted request headers. + } + } + + const protocol = headers.get('x-forwarded-proto') ?? 'http' + const host = headers.get('x-forwarded-host') ?? headers.get('host') ?? 'localhost' + + return `${protocol}://${host}/` +} + +function isNextHeadersModule(value: unknown): value is NextHeadersModule { + return !!value + && typeof value === 'object' + && typeof (value as { readonly headers?: unknown }).headers === 'function' +} + +async function importNextHeadersModule(): Promise { + try { + const runtime = globalThis as FormsRuntimeGlobal + const module = runtime.__holoFormsNextHeadersImport__ + ? await runtime.__holoFormsNextHeadersImport__() + : await import(/* @vite-ignore */ nextHeadersModuleSpecifier) + + return isNextHeadersModule(module) ? module : undefined + } catch { + return undefined + } +} + +async function resolveAmbientFormDataRequest(input: unknown): Promise { + if (!isFormDataInput(input)) { + return undefined + } + + const nextHeadersModule = await importNextHeadersModule() + if (!nextHeadersModule) { + return undefined + } + + const requestHeaders = await nextHeadersModule.headers() + + return new Request(resolveAmbientRequestUrl(requestHeaders), { + method: 'POST', + headers: createFormDataRequestHeaders(requestHeaders), + body: input, + }) +} + function getStructuredWebRequest(input: FormRequestLikeInput): StructuredRequestLikeObject | undefined { return input.web?.request instanceof Request ? undefined @@ -600,6 +681,7 @@ export async function validate( | undefined const usesSecurityOptions = options.csrf === true || typeof options.throttle === 'string' const normalizedRequestInput = normalizeRequestLikeInput(input) + ?? (usesSecurityOptions ? await resolveAmbientFormDataRequest(input) : undefined) const validationInput = normalizedRequestInput ?? input if (usesSecurityOptions && !normalizedRequestInput) { diff --git a/packages/forms/src/internal/client.ts b/packages/forms/src/internal/client.ts index 5b67b51a..8c450ae0 100644 --- a/packages/forms/src/internal/client.ts +++ b/packages/forms/src/internal/client.ts @@ -50,7 +50,7 @@ export interface UseFormOptions { readonly csrf?: boolean readonly validateOn?: ValidateOnMode readonly initialValues?: Partial - readonly initialState?: SerializedFormSubmission + readonly initialState?: SerializedFormSubmission | FormFailurePayload | null readonly submitter?: ( context: ClientSubmitContext, ) => Promise> | ClientSubmitResult @@ -745,20 +745,21 @@ export function createFormClient options: UseFormOptions, TSuccess> = {}, ): UseFormResult, TSuccess, InferFormFieldTree> { type TData = InferFormData + const initialState = options.initialState ?? undefined const initialValues = mergeValues( - normalizeObject(options.initialState?.values), - options.initialValues, + normalizeObject(options.initialValues), + normalizeObject(initialState?.values), ) const state: MutableState = { values: cloneValue(initialValues), initialValues: cloneValue(initialValues), - flattenedErrors: { ...(options.initialState?.errors ?? {}) }, + flattenedErrors: { ...(initialState?.errors ?? {}) }, touched: new Set(), dirty: new Set(), submitting: false, - lastSubmission: options.initialState, + lastSubmission: initialState, listeners: new Set(), validationSequence: 0, } diff --git a/packages/forms/tests/client.test.ts b/packages/forms/tests/client.test.ts index ca7816c8..6368582a 100644 --- a/packages/forms/tests/client.test.ts +++ b/packages/forms/tests/client.test.ts @@ -1442,6 +1442,28 @@ describe('@holo-js/forms client', () => { expect(client.lastSubmission).toEqual(initialState) }) + it('lets initial server state override default initial values', () => { + const registerUser = schema({ + email: field.string().required().email(), + }) + + const initialState = createFailedSubmission(registerUser, { + email: 'submitted-bad', + }, { + email: ['Submitted email is invalid.'], + }).fail() + + const client = useForm(registerUser, { + initialValues: { + email: '', + }, + initialState, + }) + + expect(client.values.email).toBe('submitted-bad') + expect(client.errors.first('email')).toBe('Submitted email is invalid.') + }) + it('tracks submitting while an async submitter is in flight', async () => { const registerUser = schema({ email: field.string().required().email(), diff --git a/packages/forms/tests/contracts.test.ts b/packages/forms/tests/contracts.test.ts index d49e55c6..7bfc0b4c 100644 --- a/packages/forms/tests/contracts.test.ts +++ b/packages/forms/tests/contracts.test.ts @@ -12,6 +12,11 @@ import { type FormFailureErrors, } from '../src' +type FormsTestGlobal = typeof globalThis & { + __holoFormsSecurityModule__?: unknown + __holoFormsNextHeadersImport__?: () => Promise +} + function createSecurityModule() { const attempts = new Map() @@ -75,7 +80,9 @@ function createSecurityModule() { } afterEach(() => { - delete (globalThis as typeof globalThis & { __holoFormsSecurityModule__?: unknown }).__holoFormsSecurityModule__ + const runtime = globalThis as FormsTestGlobal + delete runtime.__holoFormsSecurityModule__ + delete runtime.__holoFormsNextHeadersImport__ }) describe('@holo-js/forms contracts', () => { @@ -668,6 +675,83 @@ describe('@holo-js/forms contracts', () => { expect(throttled.errors.get('_root')).toEqual(['Too many attempts. Please try again later.']) }) + it('accepts Next server action FormData for security-aware validation', async () => { + const login = schema({ + email: field.string().required().email(), + }) + const runtime = globalThis as FormsTestGlobal + + runtime.__holoFormsSecurityModule__ = createSecurityModule() + runtime.__holoFormsNextHeadersImport__ = async () => ({ + headers: () => new Headers({ + cookie: 'XSRF-TOKEN=login-token', + 'X-CSRF-TOKEN': 'login-token', + 'x-forwarded-for': '203.0.113.11', + host: 'app.test', + referer: 'https://app.test/login', + 'content-type': 'multipart/form-data; boundary=stale-action-boundary', + 'content-length': '123', + }), + }) + + const formData = new FormData() + formData.set('email', 'ava@example.com') + + const firstAllowed = await validate(formData, login, { + csrf: true, + throttle: 'login', + }) + + expect(firstAllowed.valid).toBe(true) + if (!firstAllowed.valid) { + throw new Error('Expected Next action form validation success.') + } + + expect(firstAllowed.data).toEqual({ + email: 'ava@example.com', + }) + + const throttled = await validate(formData, login, { + throttle: 'login', + }) + expect(throttled.valid).toBe(false) + if (throttled.valid) { + throw new Error('Expected throttle failure.') + } + + expect(throttled.values).toEqual({ + email: 'ava@example.com', + }) + expect(throttled.errors.get('_root')).toEqual(['Too many attempts. Please try again later.']) + }) + + it('falls back to forwarded headers when ambient Next action referer is malformed', async () => { + const login = schema({ + email: field.string().required().email(), + }) + const runtime = globalThis as FormsTestGlobal + + runtime.__holoFormsSecurityModule__ = createSecurityModule() + runtime.__holoFormsNextHeadersImport__ = async () => ({ + headers: () => new Headers({ + cookie: 'XSRF-TOKEN=login-token', + 'X-CSRF-TOKEN': 'login-token', + 'x-forwarded-host': 'app.test', + 'x-forwarded-proto': 'https', + referer: 'http://%', + }), + }) + + const formData = new FormData() + formData.set('email', 'ava@example.com') + + const submission = await validate(formData, login, { + csrf: true, + }) + + expect(submission.valid).toBe(true) + }) + it('preserves cookie semantics for request-like header arrays during csrf validation', async () => { const login = schema({ email: field.string().required().email(), diff --git a/packages/forms/tsup.config.ts b/packages/forms/tsup.config.ts index 3885485f..675b1280 100644 --- a/packages/forms/tsup.config.ts +++ b/packages/forms/tsup.config.ts @@ -13,6 +13,7 @@ export default defineConfig({ clean: true, outDir, outExtension: () => ({ js: '.mjs' }), + external: ['next/headers'], esbuildOptions(options) { options.logLevel = 'warning' }, diff --git a/packages/mail/src/contracts.ts b/packages/mail/src/contracts.ts index 85cb99fc..89c67b0e 100644 --- a/packages/mail/src/contracts.ts +++ b/packages/mail/src/contracts.ts @@ -58,9 +58,7 @@ export type { ResolvedMailAttachment, } from './contracts-types' import { - BUILT_IN_MAIL_DRIVERS, HOLO_MAIL_DEFINITION_MARKER, - MAIL_ATTACHMENT_DISPOSITIONS, MAIL_PRIORITY_VALUES, type MailAddress, type MailAddressInput, @@ -789,40 +787,3 @@ export function attachContent( ...options, }) } - -export const mailInternals = { - BUILT_IN_MAIL_DRIVERS, - HOLO_MAIL_DEFINITION_MARKER, - MAIL_ATTACHMENT_DISPOSITIONS, - MAIL_PRIORITY_VALUES, - attachContent, - attachFromPath, - attachFromStorage, - createAttachmentMetadata, - createAttachmentResolutionPlan, - createAttachmentResolutionPlans, - inferMimeTypeFromName, - inferAttachmentName, - inferAttachmentSource, - isObject, - isAttachmentQueueSafe, - isValidEmail, - mergeMailDefinitionInputs, - normalizeAddress, - normalizeAttachment, - normalizeAttachments, - normalizeDelayValue, - normalizeHeaders, - normalizeJsonValue, - normalizeMailDefinition, - normalizeOptionalString, - normalizePriority, - normalizeQueueOptions, - normalizeRecipients, - normalizeRenderSource, - normalizeRequiredString, - resolveAttachmentDefinition, - resolveNormalizedAttachment, - normalizeTags, - normalizeViewIdentifier, -} diff --git a/packages/mail/src/index.ts b/packages/mail/src/index.ts index 05735d73..d962c1d5 100644 --- a/packages/mail/src/index.ts +++ b/packages/mail/src/index.ts @@ -18,7 +18,6 @@ export { inferMimeTypeFromName, isMailDefinition, isAttachmentQueueSafe, - mailInternals, mergeMailDefinitionInputs, normalizeMailDefinition, resolveAttachmentDefinition, @@ -96,7 +95,6 @@ export { MailPreviewDisabledError, MailPreviewFormatUnavailableError, MailSendError, - mailRuntimeInternals, previewMail, renderMailPreview, resetFakeSentMails, diff --git a/packages/mail/src/runtime.ts b/packages/mail/src/runtime.ts index 7150fe04..ddb88d99 100644 --- a/packages/mail/src/runtime.ts +++ b/packages/mail/src/runtime.ts @@ -302,80 +302,62 @@ function dynamicImport(specifier: string): Promise { return import(/* webpackIgnore: true */ specifier) as Promise } -async function loadQueueModule(): Promise { - const override = getRuntimeState().loadQueueModule - if (override) { - try { - return await override() - } catch (error) { - if ( - error - && typeof error === 'object' - && 'code' in error - && (error as { code?: unknown }).code === 'ERR_MODULE_NOT_FOUND' - ) { - throw new MailError( - '[@holo-js/mail] Queued or delayed mail delivery requires @holo-js/queue to be installed.', - 'MAIL_QUEUE_MODULE_MISSING', - ) - } +type PeerModuleLoaderOptions = { + readonly loadOverride?: () => Promise + readonly missing: (error: unknown) => TResolved + readonly resolve?: (module: TModule) => TResolved + readonly specifier: string +} + +function isModuleNotFoundError(error: unknown): boolean { + return !!( + error + && typeof error === 'object' + && 'code' in error + && (error as { code?: unknown }).code === 'ERR_MODULE_NOT_FOUND' + ) +} - throw error +async function loadPeerModule( + options: PeerModuleLoaderOptions, +): Promise { + try { + const module = options.loadOverride + ? await options.loadOverride() + : await dynamicImport(options.specifier) + return options.resolve + ? options.resolve(module) + : module as unknown as TResolved + } catch (error) { + if (isModuleNotFoundError(error)) { + return options.missing(error) } + + throw error } +} - try { - return await dynamicImport('@holo-js/queue') - } catch (error) { - if ( - error - && typeof error === 'object' - && 'code' in error - && (error as { code?: unknown }).code === 'ERR_MODULE_NOT_FOUND' - ) { +async function loadQueueModule(): Promise { + return loadPeerModule({ + loadOverride: getRuntimeState().loadQueueModule, + specifier: '@holo-js/queue', + missing() { throw new MailError( '[@holo-js/mail] Queued or delayed mail delivery requires @holo-js/queue to be installed.', 'MAIL_QUEUE_MODULE_MISSING', ) - } - - throw error - } + }, + }) } async function loadDbModule(): Promise { - const override = getRuntimeState().loadDbModule - if (override) { - try { - return await override() - } catch (error) { - if ( - error - && typeof error === 'object' - && 'code' in error - && (error as { code?: unknown }).code === 'ERR_MODULE_NOT_FOUND' - ) { - return null - } - - throw error - } - } - - try { - return await dynamicImport('@holo-js/db') - } catch (error) { - if ( - error - && typeof error === 'object' - && 'code' in error - && (error as { code?: unknown }).code === 'ERR_MODULE_NOT_FOUND' - ) { + return loadPeerModule({ + loadOverride: getRuntimeState().loadDbModule, + specifier: '@holo-js/db', + missing() { return null - } - - throw error - } + }, + }) } function resolveNodemailerModule(module: unknown): NodemailerModule { @@ -405,46 +387,18 @@ function resolveNodemailerModule(module: unknown): NodemailerModule { } async function loadNodemailerModule(): Promise { - const override = getRuntimeState().loadNodemailerModule - if (override) { - try { - return resolveNodemailerModule(await override()) - } catch (error) { - if ( - error - && typeof error === 'object' - && 'code' in error - && (error as { code?: unknown }).code === 'ERR_MODULE_NOT_FOUND' - ) { - throw new MailError( - '[@holo-js/mail] SMTP delivery requires nodemailer to be installed.', - 'MAIL_SMTP_MODULE_MISSING', - { cause: error }, - ) - } - - throw error - } - } - - try { - return resolveNodemailerModule(await dynamicImport('nodemailer')) - } catch (error) { - if ( - error - && typeof error === 'object' - && 'code' in error - && (error as { code?: unknown }).code === 'ERR_MODULE_NOT_FOUND' - ) { + return loadPeerModule({ + loadOverride: getRuntimeState().loadNodemailerModule, + specifier: 'nodemailer', + resolve: resolveNodemailerModule, + missing(error) { throw new MailError( '[@holo-js/mail] SMTP delivery requires nodemailer to be installed.', 'MAIL_SMTP_MODULE_MISSING', { cause: error }, ) - } - - throw error - } + }, + }) } function resolveStorageModule(module: unknown): StorageModule { @@ -465,46 +419,18 @@ function resolveStorageModule(module: unknown): StorageModule { } async function loadStorageModule(): Promise { - const override = getRuntimeState().loadStorageModule - if (override) { - try { - return resolveStorageModule(await override()) - } catch (error) { - if ( - error - && typeof error === 'object' - && 'code' in error - && (error as { code?: unknown }).code === 'ERR_MODULE_NOT_FOUND' - ) { - throw new MailError( - '[@holo-js/mail] Storage-backed attachments require @holo-js/storage to be installed.', - 'MAIL_STORAGE_MODULE_MISSING', - { cause: error }, - ) - } - - throw error - } - } - - try { - return resolveStorageModule(await dynamicImport('@holo-js/storage')) - } catch (error) { - if ( - error - && typeof error === 'object' - && 'code' in error - && (error as { code?: unknown }).code === 'ERR_MODULE_NOT_FOUND' - ) { + return loadPeerModule({ + loadOverride: getRuntimeState().loadStorageModule, + specifier: '@holo-js/storage', + resolve: resolveStorageModule, + missing(error) { throw new MailError( '[@holo-js/mail] Storage-backed attachments require @holo-js/storage to be installed.', 'MAIL_STORAGE_MODULE_MISSING', { cause: error }, ) - } - - throw error - } + }, + }) } function getFakeSentState(): FakeSentMail[] { diff --git a/packages/mail/tests/contracts.test.ts b/packages/mail/tests/contracts.test.ts index 2fe0c712..0571e48d 100644 --- a/packages/mail/tests/contracts.test.ts +++ b/packages/mail/tests/contracts.test.ts @@ -1,9 +1,17 @@ import { describe, expect, it } from 'vitest' import { + attachFromStorage, + createAttachmentMetadata, + createAttachmentResolutionPlan, + createAttachmentResolutionPlans, defineMail, + inferMimeTypeFromName, + isAttachmentQueueSafe, isMailDefinition, - mailInternals, + mergeMailDefinitionInputs, normalizeMailDefinition, + resolveAttachmentDefinition, + resolveNormalizedAttachment, } from '../src' describe('@holo-js/mail contracts', () => { @@ -186,7 +194,7 @@ describe('@holo-js/mail contracts', () => { }, }) - const merged = mailInternals.mergeMailDefinitionInputs(base, { + const merged = mergeMailDefinitionInputs(base, { headers: { 'X-Trace': 'trace', }, @@ -205,31 +213,11 @@ describe('@holo-js/mail contracts', () => { tenant: 'base', locale: 'en', }) - expect(mailInternals.normalizeDelayValue(new Date('2026-01-01T00:00:00.000Z'), 'delay')).toEqual( - new Date('2026-01-01T00:00:00.000Z'), - ) - expect(mailInternals.normalizeTags(['one', ' one ', 'two'])).toEqual(['one', 'two']) - expect(mailInternals.inferAttachmentName({ - path: '/tmp/logo.png', - })).toBe('logo.png') - expect(mailInternals.inferAttachmentName({ - storage: { - path: 'logos/mark.svg', - }, - })).toBe('mark.svg') - expect(mailInternals.inferAttachmentName({ - path: '/tmp/', - })).toBeUndefined() - expect(mailInternals.inferAttachmentName({ - storage: { - path: 'logos/', - }, - })).toBeUndefined() - expect(mailInternals.inferMimeTypeFromName('logo.png')).toBe('image/png') - expect(mailInternals.inferMimeTypeFromName('hello.txt')).toBe('text/plain') - expect(mailInternals.inferMimeTypeFromName('report.')).toBeUndefined() - expect(mailInternals.inferMimeTypeFromName('archive.unknown')).toBeUndefined() - expect(mailInternals.createAttachmentMetadata({ + expect(inferMimeTypeFromName('logo.png')).toBe('image/png') + expect(inferMimeTypeFromName('hello.txt')).toBe('text/plain') + expect(inferMimeTypeFromName('report.')).toBeUndefined() + expect(inferMimeTypeFromName('archive.unknown')).toBeUndefined() + expect(createAttachmentMetadata({ path: '/tmp/logo.png', disposition: 'inline', contentId: 'cid-logo', @@ -240,26 +228,16 @@ describe('@holo-js/mail contracts', () => { disposition: 'inline', contentId: 'cid-logo', }) - expect(mailInternals.inferAttachmentSource({ - content: 'hello', - name: 'greeting.txt', - })).toBe('content') - expect(mailInternals.isAttachmentQueueSafe({ + expect(isAttachmentQueueSafe({ path: '/tmp/logo.png', })).toBe(true) - expect(mailInternals.isAttachmentQueueSafe({ + expect(isAttachmentQueueSafe({ resolve: async () => ({ content: 'hello', name: 'hello.txt', }), })).toBe(false) - expect(mailInternals.normalizeViewIdentifier('auth/verify-email', 'view')).toBe('auth/verify-email') - expect(mailInternals.normalizeJsonValue({ - ok: ['yes'], - }, 'json')).toEqual({ - ok: ['yes'], - }) - const plans = mailInternals.createAttachmentResolutionPlans([ + const plans = createAttachmentResolutionPlans([ { path: '/tmp/logo.png', }, @@ -276,10 +254,7 @@ describe('@holo-js/mail contracts', () => { queuedSafe: true, contentType: 'image/png', }) - expect(mailInternals.isObject({ ok: true })).toBe(true) - expect(mailInternals.isObject(null)).toBe(false) - expect(() => mailInternals.normalizePriority('urgent' as never)).toThrow('Mail priority must be one of') - expect(() => mailInternals.createAttachmentResolutionPlan({ + expect(() => createAttachmentResolutionPlan({ resolve: async () => ({ content: 'hello', name: 'hello.txt', @@ -287,15 +262,10 @@ describe('@holo-js/mail contracts', () => { }, { queued: true, })).toThrow('not queue-safe') - expect(() => mailInternals.normalizeMailDefinition('broken' as never)).toThrow('Mail definitions must be plain objects') + expect(() => normalizeMailDefinition('broken' as never)).toThrow('Mail definitions must be plain objects') }) it('covers remaining contract helper validation branches', async () => { - expect(() => mailInternals.normalizeOptionalString(' ', 'Optional')).toThrow('Optional must be a non-empty string when provided') - expect(() => mailInternals.normalizeRequiredString(' ', 'Required')).toThrow('Required must be a non-empty string') - expect(() => mailInternals.normalizeDelayValue(-1, 'delay')).toThrow('greater than or equal to 0') - expect(() => mailInternals.normalizeDelayValue(new Date('invalid'), 'delay')).toThrow('must be valid Date instances') - expect(() => mailInternals.normalizeJsonValue(Symbol('bad'), 'json')).toThrow('must be JSON-serializable') expect(() => defineMail({ to: 'ava@example.com', subject: 'Stats', @@ -304,45 +274,7 @@ describe('@holo-js/mail contracts', () => { score: Number.NaN, }, })).toThrow('must be JSON-serializable') - expect(() => mailInternals.normalizeJsonValue(Number.POSITIVE_INFINITY, 'json')).toThrow('must be JSON-serializable') - expect(() => mailInternals.normalizeJsonValue(Number.NEGATIVE_INFINITY, 'json')).toThrow('must be JSON-serializable') - expect(mailInternals.isValidEmail('ava example.com')).toBe(false) - expect(() => mailInternals.normalizeHeaders('bad' as never)).toThrow('Mail headers must be a plain object') - expect(() => mailInternals.normalizeHeaders({ Test: 1 as never })).toThrow('Mail header "Test" must be a string') - expect(() => mailInternals.normalizeViewIdentifier('/emails/welcome', 'view')).toThrow('must be a relative mail view identifier') - expect(() => mailInternals.normalizeAddress(1 as never, 'Mail to')).toThrow('must be an email string or an object with email') - expect(() => mailInternals.normalizeRecipients(undefined, 'Mail to', true)).toThrow('must include at least one recipient') - expect(mailInternals.normalizeRecipients(undefined, 'Mail cc', false)).toEqual([]) - expect(mailInternals.normalizeQueueOptions(undefined)).toBeUndefined() - expect(mailInternals.normalizeQueueOptions(true)).toBe(true) - expect(mailInternals.normalizeQueueOptions({ - queued: false, - queue: 'mail', - afterCommit: true, - })).toEqual({ - queued: false, - queue: 'mail', - afterCommit: true, - }) - expect(mailInternals.normalizeQueueOptions({ - queued: true, - connection: 'redis', - })).toEqual({ - queued: true, - connection: 'redis', - }) - expect(mailInternals.normalizeQueueOptions({ - connection: 'redis', - } as never)).toEqual({ - connection: 'redis', - }) - expect(mailInternals.normalizeRenderSource({ - view: 'emails/welcome', - })).toEqual({ - view: 'emails/welcome', - }) - expect(() => mailInternals.normalizeRenderSource('bad' as never)).toThrow('Mail render sources must be plain objects') - expect(mailInternals.mergeMailDefinitionInputs({ + expect(mergeMailDefinitionInputs({ to: 'ava@example.com', subject: 'Welcome', markdown: '# Welcome', @@ -350,14 +282,14 @@ describe('@holo-js/mail contracts', () => { to: 'ava@example.com', subject: 'Welcome', }) - expect(mailInternals.mergeMailDefinitionInputs(defineMail({ + expect(mergeMailDefinitionInputs(defineMail({ to: 'ava@example.com', subject: 'Welcome', markdown: '# Welcome', }), {})).toMatchObject({ subject: 'Welcome', }) - expect(mailInternals.attachFromStorage('reports/invoice.pdf', { + expect(attachFromStorage('reports/invoice.pdf', { disk: 'public', name: 'invoice.pdf', contentType: 'application/pdf', @@ -374,34 +306,11 @@ describe('@holo-js/mail contracts', () => { contentId: 'invoice-cid', }) - expect(() => mailInternals.normalizeAttachment({ - contentId: '<>', - disposition: 'inline', - content: 'hello', - name: 'hello.txt', - }, 0)).toThrow('Mail attachment contentId must be a non-empty string') - expect(mailInternals.inferMimeTypeFromName('filename')).toBeUndefined() - expect(() => mailInternals.normalizeAttachment(null as never, 0)).toThrow('must be a plain object') - expect(() => mailInternals.resolveNormalizedAttachment({ + expect(inferMimeTypeFromName('filename')).toBeUndefined() + expect(() => resolveNormalizedAttachment({ disposition: 'attachment', } as never)).toThrow('Attachments must resolve to a named attachment') - expect(() => mailInternals.normalizeAttachment({ - name: 'broken.txt', - path: 123 as never, - }, 0)).toThrow('path attachments must include a path') - expect(() => mailInternals.normalizeAttachment({ - name: 'broken.txt', - storage: null as never, - }, 0)).toThrow('storage attachments must include a path') - expect(() => mailInternals.normalizeAttachment({ - content: 123 as never, - name: 'broken.txt', - }, 0)).toThrow('content attachments must use a string or Uint8Array') - expect(() => mailInternals.normalizeAttachment({ - resolve: 'bad' as never, - }, 0)).toThrow('resolve attachments must define resolve()') - expect(() => mailInternals.normalizeAttachment({} as never, 0)).toThrow('must define path, storage, content, or resolve') - expect(() => mailInternals.resolveNormalizedAttachment({ + expect(() => resolveNormalizedAttachment({ resolve: async () => ({ content: 'hello', name: 'hello.txt', @@ -409,11 +318,11 @@ describe('@holo-js/mail contracts', () => { name: 'hello.txt', disposition: 'attachment', })).toThrow('Resolver attachments must be resolved before creating transport attachments') - await expect(mailInternals.resolveAttachmentDefinition({ + await expect(resolveAttachmentDefinition({ resolve: async () => 'bad' as never, disposition: 'attachment', } as never)).rejects.toThrow('must return a plain object payload') - await expect(mailInternals.resolveAttachmentDefinition({ + await expect(resolveAttachmentDefinition({ resolve: async () => ({ name: 'preserved.txt', resolve: async () => ({ @@ -423,7 +332,7 @@ describe('@holo-js/mail contracts', () => { }), disposition: 'attachment', } as never)).rejects.toThrow('must resolve to a path, storage, or content attachment') - await expect(mailInternals.resolveAttachmentDefinition({ + await expect(resolveAttachmentDefinition({ name: 'outer.txt', disposition: 'attachment', resolve: async () => ({ @@ -435,7 +344,7 @@ describe('@holo-js/mail contracts', () => { content: 'hello', contentType: 'text/plain', }) - await expect(mailInternals.resolveAttachmentDefinition({ + await expect(resolveAttachmentDefinition({ name: 'outer.txt', disposition: 'inline', contentId: 'cid-outer', @@ -450,7 +359,7 @@ describe('@holo-js/mail contracts', () => { content: 'hello', contentType: 'text/plain', }) - await expect(mailInternals.resolveAttachmentDefinition({ + await expect(resolveAttachmentDefinition({ resolve: async () => ({ content: 'hello', name: 'resolved.txt', @@ -461,7 +370,7 @@ describe('@holo-js/mail contracts', () => { content: 'hello', contentType: 'text/plain', }) - await expect(mailInternals.resolveAttachmentDefinition({ + await expect(resolveAttachmentDefinition({ resolve: async () => ({ content: 'hello', name: 'resolved.txt', diff --git a/packages/mail/tests/runtime.test.ts b/packages/mail/tests/runtime.test.ts index 56604d16..444c7021 100644 --- a/packages/mail/tests/runtime.test.ts +++ b/packages/mail/tests/runtime.test.ts @@ -10,7 +10,6 @@ import { listFakeSentMails, listPreviewMailArtifacts, MailSendError, - mailRuntimeInternals, previewMail, registerMailDriver, renderMailPreview, @@ -25,6 +24,7 @@ import { type MailSendResult, type NormalizedHoloMailConfig, } from '../src' +import { mailRuntimeInternals } from '../src/runtime' const previousAppEnv = process.env.APP_ENV const previousNodeEnv = process.env.NODE_ENV diff --git a/packages/queue-db/src/database.ts b/packages/queue-db/src/database.ts index b3cc54a7..08f7a8f6 100644 --- a/packages/queue-db/src/database.ts +++ b/packages/queue-db/src/database.ts @@ -1,4 +1,4 @@ -import type { DatabaseContext, Dialect } from '@holo-js/db' +import { connectionAsyncContext, DB, type DatabaseContext, type Dialect } from '@holo-js/db' import type { QueueFailedJobRecord, QueueJobEnvelope, @@ -46,7 +46,7 @@ function isPlainObject(value: unknown): value is Record { return prototype === Object.prototype || prototype === null } -function assertQueueJsonValue( +export function assertQueueJsonValue( value: unknown, path: string, seen = new Set(), @@ -95,7 +95,7 @@ function assertQueueJsonValue( seen.delete(value) } -function normalizeIdentifierPath(value: string, label: string): string { +export function normalizeIdentifierPath(value: string, label: string): string { const normalized = value.trim() if (!normalized) { throw new Error(`[Holo Queue] ${label} must be a non-empty string.`) @@ -109,14 +109,14 @@ function normalizeIdentifierPath(value: string, label: string): string { return normalized } -function quoteIdentifierPath(dialect: Dialect, path: string): string { +export function quoteIdentifierPath(dialect: Dialect, path: string): string { return normalizeIdentifierPath(path, 'Queue table name') .split('.') .map(segment => dialect.quoteIdentifier(segment)) .join('.') } -function createPlaceholderList( +export function createPlaceholderList( dialect: Dialect, count: number, startIndex = 1, @@ -128,7 +128,7 @@ function createPlaceholderList( return Array.from({ length: count }, (_, index) => dialect.createPlaceholder(startIndex + index)).join(', ') } -function coerceRequiredString(value: unknown, label: string): string { +export function coerceRequiredString(value: unknown, label: string): string { if (typeof value !== 'string' || value.length === 0) { throw new Error(`[Holo Queue] ${label} must be a non-empty string.`) } @@ -136,7 +136,7 @@ function coerceRequiredString(value: unknown, label: string): string { return value } -function coerceRequiredInteger(value: unknown, label: string): number { +export function coerceRequiredInteger(value: unknown, label: string): number { if (typeof value === 'number') { if (!Number.isInteger(value)) { throw new Error(`[Holo Queue] ${label} must be an integer.`) @@ -152,7 +152,7 @@ function coerceRequiredInteger(value: unknown, label: string): number { throw new Error(`[Holo Queue] ${label} must be an integer.`) } -function coerceOptionalInteger(value: unknown, label: string): number | undefined { +export function coerceOptionalInteger(value: unknown, label: string): number | undefined { if (value === null || typeof value === 'undefined') { return undefined } @@ -160,32 +160,30 @@ function coerceOptionalInteger(value: unknown, label: string): number | undefine return coerceRequiredInteger(value, label) } -function parseStoredPayload(value: unknown, label: string): QueueJsonValue { +export function parseStoredPayload(value: unknown, label: string): QueueJsonValue { const serialized = coerceRequiredString(value, label) const parsed = JSON.parse(serialized) as unknown assertQueueJsonValue(parsed, label) return parsed } -function parseStoredQueueJobRow( +export function parseStoredQueueJobRow( row: StoredQueueJobRow, ): QueueJobEnvelope { - return Object.freeze({ - id: coerceRequiredString(row.id, 'Stored queue job id'), - name: coerceRequiredString(row.job, 'Stored queue job name'), - connection: coerceRequiredString(row.connection, 'Stored queue job connection'), - queue: coerceRequiredString(row.queue, 'Stored queue job queue'), + return parseStoredQueueEnvelope({ + id: row.id, + name: row.job, + connection: row.connection, + queue: row.queue, payload: parseStoredPayload(row.payload, 'Stored queue job payload'), - attempts: coerceRequiredInteger(row.attempts, 'Stored queue job attempts'), - maxAttempts: coerceRequiredInteger(row.max_attempts, 'Stored queue job max attempts'), - ...(typeof row.available_at === 'undefined' || row.available_at === null - ? {} - : { availableAt: coerceRequiredInteger(row.available_at, 'Stored queue job availability') }), - createdAt: coerceRequiredInteger(row.created_at, 'Stored queue job creation time'), + attempts: row.attempts, + maxAttempts: row.max_attempts, + availableAt: row.available_at, + createdAt: row.created_at, }) } -function parseStoredQueueEnvelope( +export function parseStoredQueueEnvelope( value: unknown, ): QueueJobEnvelope { if (!isPlainObject(value)) { @@ -212,7 +210,7 @@ function parseStoredQueueEnvelope( }) } -function parseStoredFailedQueueJobRow( +export function parseStoredFailedQueueJobRow( row: StoredFailedQueueJobRow, ): QueueFailedJobRecord { return Object.freeze({ @@ -224,19 +222,23 @@ function parseStoredFailedQueueJobRow( }) } -function serializeQueueJson(value: unknown): string { +export function serializeQueueJson(value: unknown): string { assertQueueJsonValue(value, 'Queue JSON payload') return JSON.stringify(value) } -async function ensureConnectionReady(connection: DatabaseContext): Promise { +export async function ensureConnectionReady(connection: DatabaseContext): Promise { await connection.initialize() return connection } -export type { - StoredFailedQueueJobRow, - StoredQueueJobRow, +export function resolveDatabaseConnection(name: string): DatabaseContext { + const active = connectionAsyncContext.getActive()?.connection + if (active && active.getConnectionName() === name) { + return active + } + + return DB.connection(name) } export const queueDatabaseInternals = { @@ -246,12 +248,17 @@ export const queueDatabaseInternals = { coerceRequiredString, createPlaceholderList, ensureConnectionReady, - isPlainObject, normalizeIdentifierPath, parseStoredFailedQueueJobRow, - parseStoredQueueEnvelope, parseStoredPayload, + parseStoredQueueEnvelope, parseStoredQueueJobRow, quoteIdentifierPath, + resolveDatabaseConnection, serializeQueueJson, } + +export type { + StoredFailedQueueJobRow, + StoredQueueJobRow, +} diff --git a/packages/queue-db/src/drivers/database.ts b/packages/queue-db/src/drivers/database.ts index 4231207d..4cbd7be1 100644 --- a/packages/queue-db/src/drivers/database.ts +++ b/packages/queue-db/src/drivers/database.ts @@ -1,7 +1,5 @@ import { randomUUID } from 'node:crypto' -import { DB } from '@holo-js/db' import type { DatabaseContext } from '@holo-js/db' -import { connectionAsyncContext } from '@holo-js/db' import type { NormalizedQueueDatabaseConnectionConfig, QueueAsyncDriver, @@ -13,7 +11,16 @@ import type { QueueReleaseOptions, QueueReservedJob, } from '@holo-js/queue' -import { queueDatabaseInternals } from '../database' +import { + coerceRequiredInteger, + createPlaceholderList, + ensureConnectionReady, + normalizeIdentifierPath, + parseStoredQueueJobRow, + quoteIdentifierPath, + resolveDatabaseConnection, + serializeQueueJson, +} from '../database' type DatabaseQueuedJobRow = { id: unknown @@ -78,16 +85,7 @@ function createPlaceholders( count: number, startIndex = 1, ): readonly string[] { - return queueDatabaseInternals.createPlaceholderList(connection.getDialect(), count, startIndex).split(', ') -} - -function resolveDatabaseConnection(name: string): DatabaseContext { - const active = connectionAsyncContext.getActive()?.connection - if (active && active.getConnectionName() === name) { - return active - } - - return DB.connection(name) + return createPlaceholderList(connection.getDialect(), count, startIndex).split(', ') } export class DatabaseQueueDriver implements QueueAsyncDriver { @@ -102,15 +100,15 @@ export class DatabaseQueueDriver implements QueueAsyncDriver { private readonly context: QueueDriverFactoryContext, ) { this.name = connection.name - this.tableName = queueDatabaseInternals.normalizeIdentifierPath(connection.table, 'Queue table name') + this.tableName = normalizeIdentifierPath(connection.table, 'Queue table name') } private async getConnection(): Promise { - return queueDatabaseInternals.ensureConnectionReady(resolveDatabaseConnection(this.connection.connection)) + return ensureConnectionReady(resolveDatabaseConnection(this.connection.connection)) } private getQuotedTable(connection: DatabaseContext): string { - return queueDatabaseInternals.quoteIdentifierPath(connection.getDialect(), this.tableName) + return quoteIdentifierPath(connection.getDialect(), this.tableName) } private getExpiredReservationCutoff(now: number): number { @@ -125,7 +123,7 @@ export class DatabaseQueueDriver implements QueueAsyncDriver { return Object.freeze({ reservationId, reservedAt, - envelope: queueDatabaseInternals.parseStoredQueueJobRow(row), + envelope: parseStoredQueueJobRow(row), }) } @@ -144,7 +142,7 @@ export class DatabaseQueueDriver implements QueueAsyncDriver { job.name, job.connection, job.queue, - queueDatabaseInternals.serializeQueueJson(job.payload), + serializeQueueJson(job.payload), job.attempts, job.maxAttempts, job.availableAt ?? job.createdAt, @@ -195,7 +193,7 @@ export class DatabaseQueueDriver implements QueueAsyncDriver { } const reservationId = `${input.workerId}:${randomUUID()}` - const nextAttempts = queueDatabaseInternals.coerceRequiredInteger(row.attempts, 'Stored queue job attempts') + 1 + const nextAttempts = coerceRequiredInteger(row.attempts, 'Stored queue job attempts') + 1 const [ reservedAtPlaceholder, reservationPlaceholder, @@ -317,6 +315,7 @@ export const databaseQueueDriverFactory: QueueDriverFactory { const config = getFailedStoreConfig() if (config === false) { return null } - const tableName = queueDatabaseInternals.normalizeIdentifierPath(config.table, 'Failed jobs table name') - const connection = await queueDatabaseInternals.ensureConnectionReady(resolveDatabaseConnection(config.connection)) + const tableName = normalizeIdentifierPath(config.table, 'Failed jobs table name') + const connection = await ensureConnectionReady(resolveDatabaseConnection(config.connection)) return { connection, tableName, } } +async function loadFailedJobs( + failedStore: { connection: DatabaseContext, tableName: string }, + id?: string, +): Promise { + const quotedTable = quoteIdentifierPath(failedStore.connection.getDialect(), failedStore.tableName) + const dialect = failedStore.connection.getDialect() + const filter = typeof id === 'string' + ? { + clause: ` WHERE id = ${dialect.createPlaceholder(1)}`, + bindings: [id], + source: 'queue:failed:load', + } + : { + clause: '', + bindings: undefined, + source: 'queue:failed:list', + } + const result = await failedStore.connection.queryCompiled({ + sql: `SELECT id, job_id, payload, exception, failed_at FROM ${quotedTable}${filter.clause} ORDER BY failed_at DESC, id DESC`, + bindings: filter.bindings, + source: filter.source, + }) + + return Object.freeze(result.rows.map((row: StoredFailedQueueJobRow) => parseStoredFailedQueueJobRow(row))) +} + export const queueDbFailedJobStore: QueueFailedJobStore = { async persistFailedJob( reserved: QueueReservedJob, @@ -42,8 +66,8 @@ export const queueDbFailedJobStore: QueueFailedJobStore = { return null } - const quotedTable = queueDatabaseInternals.quoteIdentifierPath(failedStore.connection.getDialect(), failedStore.tableName) - const placeholders = queueDatabaseInternals.createPlaceholderList(failedStore.connection.getDialect(), 7) + const quotedTable = quoteIdentifierPath(failedStore.connection.getDialect(), failedStore.tableName) + const placeholders = createPlaceholderList(failedStore.connection.getDialect(), 7) const record = Object.freeze({ id: randomUUID(), jobId: reserved.envelope.id, @@ -62,7 +86,7 @@ export const queueDbFailedJobStore: QueueFailedJobStore = { reserved.envelope.name, reserved.envelope.connection, reserved.envelope.queue, - queueDatabaseInternals.serializeQueueJson(reserved.envelope), + serializeQueueJson(reserved.envelope), record.exception, record.failedAt, ], @@ -78,13 +102,7 @@ export const queueDbFailedJobStore: QueueFailedJobStore = { return Object.freeze([]) } - const quotedTable = queueDatabaseInternals.quoteIdentifierPath(failedStore.connection.getDialect(), failedStore.tableName) - const result = await failedStore.connection.queryCompiled({ - sql: `SELECT id, job_id, payload, exception, failed_at FROM ${quotedTable} ORDER BY failed_at DESC, id DESC`, - source: 'queue:failed:list', - }) - - return Object.freeze(result.rows.map((row: StoredFailedQueueJobRow) => queueDatabaseInternals.parseStoredFailedQueueJobRow(row))) + return loadFailedJobs(failedStore) }, async retryFailedJobs( @@ -96,18 +114,9 @@ export const queueDbFailedJobStore: QueueFailedJobStore = { return 0 } - const quotedTable = queueDatabaseInternals.quoteIdentifierPath(failedStore.connection.getDialect(), failedStore.tableName) + const quotedTable = quoteIdentifierPath(failedStore.connection.getDialect(), failedStore.tableName) const dialect = failedStore.connection.getDialect() - const records = await (identifier === 'all' - ? await this.listFailedJobs() - : (() => { - const placeholder = dialect.createPlaceholder(1) - return failedStore.connection.queryCompiled({ - sql: `SELECT id, job_id, payload, exception, failed_at FROM ${quotedTable} WHERE id = ${placeholder}`, - bindings: [identifier], - source: 'queue:failed:load', - }).then((result) => Object.freeze(result.rows.map((row: StoredFailedQueueJobRow) => queueDatabaseInternals.parseStoredFailedQueueJobRow(row)))) - })()) + const records = await loadFailedJobs(failedStore, identifier === 'all' ? undefined : identifier) let retried = 0 @@ -131,7 +140,7 @@ export const queueDbFailedJobStore: QueueFailedJobStore = { return false } - const quotedTable = queueDatabaseInternals.quoteIdentifierPath(failedStore.connection.getDialect(), failedStore.tableName) + const quotedTable = quoteIdentifierPath(failedStore.connection.getDialect(), failedStore.tableName) const placeholder = failedStore.connection.getDialect().createPlaceholder(1) const result = await failedStore.connection.executeCompiled({ sql: `DELETE FROM ${quotedTable} WHERE id = ${placeholder}`, @@ -148,7 +157,7 @@ export const queueDbFailedJobStore: QueueFailedJobStore = { return 0 } - const quotedTable = queueDatabaseInternals.quoteIdentifierPath(failedStore.connection.getDialect(), failedStore.tableName) + const quotedTable = quoteIdentifierPath(failedStore.connection.getDialect(), failedStore.tableName) const result = await failedStore.connection.executeCompiled({ sql: `DELETE FROM ${quotedTable}`, source: 'queue:failed:flush', @@ -161,4 +170,5 @@ export const queueDbFailedJobStore: QueueFailedJobStore = { export const queueDbFailedStoreInternals = { getFailedStoreConfig, getFailedStoreConnection, + loadFailedJobs, } diff --git a/packages/queue-db/src/index.ts b/packages/queue-db/src/index.ts index a333a6ba..728cfcf7 100644 --- a/packages/queue-db/src/index.ts +++ b/packages/queue-db/src/index.ts @@ -1,7 +1,6 @@ +export type { StoredFailedQueueJobRow, StoredQueueJobRow } from './database' export { queueDatabaseInternals, - type StoredFailedQueueJobRow, - type StoredQueueJobRow, } from './database' export { databaseQueueDriverFactory, @@ -10,8 +9,8 @@ export { DatabaseQueueDriverError, } from './drivers/database' export { - queueDbFailedJobStore, queueDbFailedStoreInternals, + queueDbFailedJobStore, } from './failed' export { createQueueDbRuntimeOptions, diff --git a/packages/queue-db/tests/database-driver.test.ts b/packages/queue-db/tests/database-driver.test.ts index 368b431d..a455b8ab 100644 --- a/packages/queue-db/tests/database-driver.test.ts +++ b/packages/queue-db/tests/database-driver.test.ts @@ -1,14 +1,12 @@ import { afterEach, describe, expect, it, vi } from 'vitest' import { DB } from '@holo-js/db' -import type { DatabaseContext, Dialect } from '@holo-js/db' import { connectionAsyncContext } from '@holo-js/db' import { configureQueueRuntime, queueRuntimeInternals } from '@holo-js/queue' import { createQueueDbRuntimeOptions, DatabaseQueueDriver, - databaseQueueDriverInternals, - queueDatabaseInternals, } from '../src' +import { createQueueDatabaseContextMock } from './support/dialect' import { createSQLiteQueueHarness, type SQLiteQueueHarness } from './support/sqlite-queue' const harnesses: SQLiteQueueHarness[] = [] @@ -34,33 +32,6 @@ function createEnvelope(name: string, overrides: Partial<{ }) } -function createDialect(name: string, placeholderPrefix: '$' | '?'): Dialect { - return { - name, - capabilities: { - returning: false, - lockForUpdate: false, - sharedLock: false, - concurrentQueries: false, - workerThreadExecution: false, - savepoints: true, - jsonValueQuery: true, - jsonContains: true, - jsonLength: true, - schemaQualifiedIdentifiers: true, - nativeUpsert: false, - ddlAlterSupport: false, - introspection: false, - }, - quoteIdentifier(identifier: string) { - return `"${identifier}"` - }, - createPlaceholder(index: number) { - return placeholderPrefix === '?' ? '?' : `${placeholderPrefix}${index}` - }, - } -} - afterEach(async () => { vi.useRealTimers() while (harnesses.length > 0) { @@ -69,88 +40,6 @@ afterEach(async () => { }) describe('@holo-js/queue-db database driver', () => { - it('normalizes identifiers, placeholders, stored rows, and wrapped error messages', () => { - expect(queueDatabaseInternals.normalizeIdentifierPath(' public.jobs ', 'Queue table name')).toBe('public.jobs') - expect(() => queueDatabaseInternals.normalizeIdentifierPath('', 'Queue table name')).toThrow('Queue table name must be a non-empty string.') - expect(() => queueDatabaseInternals.normalizeIdentifierPath('jobs-table', 'Queue table name')).toThrow('Queue table name must contain only valid SQL identifier segments.') - expect(() => queueDatabaseInternals.createPlaceholderList(createDialect('sqlite', '?'), 0)).toThrow('Placeholder lists require at least one binding.') - expect(queueDatabaseInternals.createPlaceholderList(createDialect('postgres', '$'), 3)).toBe('$1, $2, $3') - expect(queueDatabaseInternals.quoteIdentifierPath(createDialect('mysql', '?'), 'queue.jobs')).toBe('"queue"."jobs"') - expect(queueDatabaseInternals.coerceOptionalInteger(undefined, 'Optional integer')).toBeUndefined() - expect(queueDatabaseInternals.coerceOptionalInteger(null, 'Optional integer')).toBeUndefined() - expect(queueDatabaseInternals.coerceOptionalInteger('4', 'Optional integer')).toBe(4) - expect(() => queueDatabaseInternals.coerceRequiredString('', 'Required string')).toThrow('Required string must be a non-empty string.') - expect(() => queueDatabaseInternals.coerceRequiredInteger(1.2, 'Required integer')).toThrow('Required integer must be an integer.') - expect(() => queueDatabaseInternals.coerceRequiredInteger('nope', 'Required integer')).toThrow('Required integer must be an integer.') - expect(() => queueDatabaseInternals.assertQueueJsonValue(Number.POSITIVE_INFINITY, 'payload')).toThrow('payload must be JSON-serializable.') - expect(() => queueDatabaseInternals.assertQueueJsonValue(undefined, 'payload')).toThrow('payload must be JSON-serializable.') - expect(() => queueDatabaseInternals.assertQueueJsonValue(new Date(), 'payload')).toThrow('payload must be a plain JSON object, array, or primitive.') - const circularArray: unknown[] = [] - circularArray.push(circularArray) - expect(() => queueDatabaseInternals.assertQueueJsonValue(circularArray, 'payload')).toThrow('payload[0] contains a circular reference.') - const circularObject: Record = {} - circularObject.self = circularObject - expect(() => queueDatabaseInternals.assertQueueJsonValue(circularObject, 'payload')).toThrow('payload.self contains a circular reference.') - expect(queueDatabaseInternals.serializeQueueJson({ nested: [1, true, null] })).toBe('{"nested":[1,true,null]}') - expect(() => queueDatabaseInternals.serializeQueueJson(new Date())).toThrow('Queue JSON payload must be a plain JSON object, array, or primitive.') - expect(queueDatabaseInternals.parseStoredQueueJobRow({ - id: 'job-1', - job: 'reports.generate', - connection: 'database', - queue: 'reports', - payload: JSON.stringify({ ok: true }), - attempts: '1', - max_attempts: 3, - available_at: null, - created_at: 100, - })).toEqual({ - id: 'job-1', - name: 'reports.generate', - connection: 'database', - queue: 'reports', - payload: { ok: true }, - attempts: 1, - maxAttempts: 3, - createdAt: 100, - }) - expect(queueDatabaseInternals.parseStoredFailedQueueJobRow({ - id: 'failed-1', - job_id: 'job-1', - payload: JSON.stringify(createEnvelope('reports.generate', { id: 'job-1', createdAt: 100 })), - exception: 'boom', - failed_at: '200', - })).toMatchObject({ - id: 'failed-1', - jobId: 'job-1', - exception: 'boom', - failedAt: 200, - }) - expect(queueDatabaseInternals.parseStoredQueueEnvelope(createEnvelope('reports.ready', { - id: 'job-ready', - createdAt: 100, - availableAt: 200, - }))).toMatchObject({ - id: 'job-ready', - availableAt: 200, - }) - expect(databaseQueueDriverInternals.normalizeDatabaseErrorMessage(new Error('boom'))).toBe('boom') - expect(databaseQueueDriverInternals.normalizeDatabaseErrorMessage('boom')).toBe('boom') - expect(databaseQueueDriverInternals.normalizeQueueNames(undefined, 'default')).toEqual(['default']) - expect(databaseQueueDriverInternals.normalizeQueueNames([' '], 'default')).toEqual(['default']) - expect(databaseQueueDriverInternals.normalizeQueueNames([' ', 'mail', 'mail'], 'default')).toEqual(['mail']) - expect(() => queueDatabaseInternals.parseStoredFailedQueueJobRow({ - id: 'failed-2', - job_id: 'job-2', - payload: JSON.stringify('bad'), - exception: 'boom', - failed_at: 1, - })).toThrow('Stored queue job payload must serialize a queue job envelope object.') - expect(() => queueDatabaseInternals.parseStoredPayload('{bad json', 'payload')).toThrow() - expect(databaseQueueDriverInternals.wrapDatabaseError('database', 'reserve job', new Error('down'))).toBeInstanceOf(Error) - const wrapped = databaseQueueDriverInternals.wrapDatabaseError('database', 'reserve job', new Error('down')) - expect(databaseQueueDriverInternals.wrapDatabaseError('database', 'reserve job', wrapped)).toBe(wrapped) - }) - it('dispatches, reserves, releases, acknowledges, deletes, and clears queued jobs', async () => { vi.useFakeTimers() vi.setSystemTime(1_000) @@ -389,17 +278,14 @@ describe('@holo-js/queue-db database driver', () => { const executeCompiled = vi.fn() .mockResolvedValueOnce({}) .mockResolvedValueOnce({ affectedRows: 1 }) - const spy = vi.spyOn(DB, 'connection').mockReturnValue({ - async initialize() {}, - getDialect() { - return createDialect('sqlite', '?') + const spy = vi.spyOn(DB, 'connection').mockReturnValue(createQueueDatabaseContextMock({ + async query>(sql: string, bindings: readonly unknown[]) { + return await queryCompiled({ sql, bindings }) as { rows: TRow[], rowCount: number } }, - async transaction(callback: (connection: DatabaseContext) => Promise) { - return callback(this as unknown as DatabaseContext) + async execute(sql: string, bindings: readonly unknown[]) { + return await executeCompiled({ sql, bindings }) }, - queryCompiled, - executeCompiled, - } as never) + })) const driver = new DatabaseQueueDriver({ name: 'database', driver: 'database', @@ -425,15 +311,7 @@ describe('@holo-js/queue-db database driver', () => { }) it('returns zero when clear reports no affected rows', async () => { - const spy = vi.spyOn(DB, 'connection').mockReturnValue({ - async initialize() {}, - getDialect() { - return createDialect('sqlite', '?') - }, - async executeCompiled() { - return {} - }, - } as never) + const spy = vi.spyOn(DB, 'connection').mockReturnValue(createQueueDatabaseContextMock()) const driver = new DatabaseQueueDriver({ name: 'database', @@ -451,21 +329,13 @@ describe('@holo-js/queue-db database driver', () => { it('reuses the active async-context connection when it matches the configured database connection', async () => { const executeCompiled = vi.fn(async (_statement: unknown) => ({})) - const initialize = vi.fn(async () => {}) - const activeConnection = { - async initialize() { - await initialize() - }, - getConnectionName() { - return 'default' - }, - getDialect() { - return createDialect('sqlite', '?') - }, - async executeCompiled(statement: unknown) { - return await executeCompiled(statement) + const activeConnection = createQueueDatabaseContextMock({ + connectionName: 'default', + async execute(sql: string, bindings: readonly unknown[]) { + return await executeCompiled({ sql, bindings }) }, - } as unknown as DatabaseContext + }) + const initialize = vi.spyOn(activeConnection, 'initialize') const spy = vi.spyOn(DB, 'connection').mockImplementation(() => { throw new Error('DB.connection() should not be used when an active matching connection exists.') diff --git a/packages/queue-db/tests/failed.test.ts b/packages/queue-db/tests/failed.test.ts index 840430f2..1b58acd9 100644 --- a/packages/queue-db/tests/failed.test.ts +++ b/packages/queue-db/tests/failed.test.ts @@ -1,6 +1,5 @@ import { afterEach, describe, expect, it, vi } from 'vitest' import { DB } from '@holo-js/db' -import type { DatabaseContext, Dialect } from '@holo-js/db' import { connectionAsyncContext } from '@holo-js/db' import { configureQueueRuntime, @@ -12,7 +11,8 @@ import { retryFailedQueueJobs, runQueueWorker, } from '@holo-js/queue' -import { createQueueDbRuntimeOptions, queueDbFailedStoreInternals } from '../src' +import { createQueueDbRuntimeOptions } from '../src' +import { createQueueDatabaseContextMock, createQueueTestDialect } from './support/dialect' import { createSQLiteQueueHarness, type SQLiteQueueHarness } from './support/sqlite-queue' const harnesses: SQLiteQueueHarness[] = [] @@ -30,33 +30,6 @@ function createEnvelope(name: string, id = `${name}-id`) { }) } -function createDialect(name: string, placeholderPrefix: '$' | '?'): Dialect { - return { - name, - capabilities: { - returning: false, - lockForUpdate: false, - sharedLock: false, - concurrentQueries: false, - workerThreadExecution: false, - savepoints: true, - jsonValueQuery: true, - jsonContains: true, - jsonLength: true, - schemaQualifiedIdentifiers: true, - nativeUpsert: false, - ddlAlterSupport: false, - introspection: false, - }, - quoteIdentifier(identifier: string) { - return `"${identifier}"` - }, - createPlaceholder(index: number) { - return placeholderPrefix === '$' ? `$${index}` : '?' - }, - } -} - afterEach(async () => { while (harnesses.length > 0) { await harnesses.pop()?.cleanup() @@ -217,69 +190,43 @@ describe('@holo-js/queue-db failed job store', () => { }) it('falls back to error.message and zero affected rows when the store driver omits them', async () => { - const spy = vi.spyOn(DB, 'connection').mockReturnValue({ - async initialize() {}, - getDialect() { - return { - name: 'sqlite', - capabilities: { - concurrentQueries: false, - jsonOperations: true, - lateralJoins: false, - workerThreadExecution: false, - pessimisticLocking: false, - savepoints: true, - vectorColumns: false, - }, - quoteIdentifier(identifier: string) { - return `"${identifier}"` - }, - createPlaceholder() { - return '?' - }, - } - }, - async executeCompiled() { - return {} - }, - async queryCompiled() { - return { rows: [], rowCount: 0 } - }, - } as never) + const spy = vi.spyOn(DB, 'connection').mockReturnValue(createQueueDatabaseContextMock()) - configureQueueRuntime({ - config: { - default: 'database', - failed: { - driver: 'database', - connection: 'default', - table: 'failed_jobs', - }, - connections: { - database: { + try { + configureQueueRuntime({ + config: { + default: 'database', + failed: { driver: 'database', connection: 'default', - table: 'jobs', + table: 'failed_jobs', + }, + connections: { + database: { + driver: 'database', + connection: 'default', + table: 'jobs', + }, }, }, - }, - ...createQueueDbRuntimeOptions(), - }) - - const error = new Error('message fallback') - error.stack = '' + ...createQueueDbRuntimeOptions(), + }) - await expect(persistFailedQueueJob({ - reservationId: 'reservation-1', - reservedAt: 1, - envelope: createEnvelope('jobs.fallback', 'fallback-job'), - }, error)).resolves.toMatchObject({ - exception: 'message fallback', - }) - await expect(forgetFailedQueueJob('missing')).resolves.toBe(false) - await expect(flushFailedQueueJobs()).resolves.toBe(0) + const error = new Error('message fallback') + error.stack = '' - spy.mockRestore() + await expect(persistFailedQueueJob({ + reservationId: 'reservation-1', + reservedAt: 1, + envelope: createEnvelope('jobs.fallback', 'fallback-job'), + }, error)).resolves.toMatchObject({ + exception: 'message fallback', + }) + await expect(forgetFailedQueueJob('missing')).resolves.toBe(false) + await expect(flushFailedQueueJobs()).resolves.toBe(0) + } finally { + spy.mockRestore() + } }) it('persists worker failures before removing exhausted jobs', async () => { @@ -317,75 +264,60 @@ describe('@holo-js/queue-db failed job store', () => { exception: expect.stringContaining('worker exploded'), }), ]) - expect(queueDbFailedStoreInternals.getFailedStoreConfig()).toEqual({ - driver: 'database', - connection: 'default', - table: 'failed_jobs', - }) - expect(await queueDbFailedStoreInternals.getFailedStoreConnection()).not.toBeNull() }) it('reuses the active async-context connection for the failed-job store when names match', async () => { const executeCompiled = vi.fn(async (_statement: unknown) => ({})) - const activeConnection = { - async initialize() {}, - getConnectionName() { - return 'default' - }, - getDialect() { - return createDialect('sqlite', '?') - }, - async executeCompiled(statement: unknown) { - return await executeCompiled(statement) - }, - async queryCompiled() { - return { - rows: [], - rowCount: 0, - } + const activeConnection = createQueueDatabaseContextMock({ + connectionName: 'default', + dialect: createQueueTestDialect('sqlite'), + async execute(sql, bindings) { + return await executeCompiled({ sql, bindings }) }, - } as unknown as DatabaseContext + }) const spy = vi.spyOn(DB, 'connection').mockImplementation(() => { throw new Error('DB.connection() should not be used when an active matching connection exists.') }) - configureQueueRuntime({ - config: { - default: 'database', - failed: { - driver: 'database', - connection: 'default', - table: 'failed_jobs', - }, - connections: { - database: { + try { + configureQueueRuntime({ + config: { + default: 'database', + failed: { driver: 'database', connection: 'default', - table: 'jobs', + table: 'failed_jobs', + }, + connections: { + database: { + driver: 'database', + connection: 'default', + table: 'jobs', + }, }, }, - }, - ...createQueueDbRuntimeOptions(), - }) - - const error = new Error('active-context failure') - error.stack = '' - - await expect(connectionAsyncContext.run({ - connectionName: 'default', - connection: activeConnection, - }, async () => persistFailedQueueJob({ - reservationId: 'reservation-1', - reservedAt: 1, - envelope: createEnvelope('jobs.active-context', 'active-context-job'), - }, error))).resolves.toMatchObject({ - jobId: 'active-context-job', - exception: 'active-context failure', - }) + ...createQueueDbRuntimeOptions(), + }) - expect(executeCompiled).toHaveBeenCalledTimes(1) + const error = new Error('active-context failure') + error.stack = '' + + await expect(connectionAsyncContext.run({ + connectionName: 'default', + connection: activeConnection, + }, async () => persistFailedQueueJob({ + reservationId: 'reservation-1', + reservedAt: 1, + envelope: createEnvelope('jobs.active-context', 'active-context-job'), + }, error))).resolves.toMatchObject({ + jobId: 'active-context-job', + exception: 'active-context failure', + }) - spy.mockRestore() + expect(executeCompiled).toHaveBeenCalledTimes(1) + } finally { + spy.mockRestore() + } }) }) diff --git a/packages/queue-db/tests/support/dialect.ts b/packages/queue-db/tests/support/dialect.ts new file mode 100644 index 00000000..e896441d --- /dev/null +++ b/packages/queue-db/tests/support/dialect.ts @@ -0,0 +1,80 @@ +import { + createCapabilities, + DatabaseContext, + type Dialect, + type DriverAdapter, + type DriverExecutionResult, + type DriverQueryResult, +} from '@holo-js/db' + +type QueueDatabaseContextMockOptions = { + readonly connectionName?: string + readonly dialect?: Dialect + readonly execute?: (sql: string, bindings: readonly unknown[]) => Promise + readonly query?: >( + sql: string, + bindings: readonly unknown[], + ) => Promise> +} + +export function createQueueTestDialect( + name: string, + placeholderPrefix: '$' | '?' = '?', +): Dialect { + return { + name, + capabilities: createCapabilities({ + savepoints: true, + jsonValueQuery: true, + jsonContains: true, + jsonLength: true, + schemaQualifiedIdentifiers: true, + }), + quoteIdentifier(identifier: string) { + return `"${identifier}"` + }, + createPlaceholder(index: number) { + return placeholderPrefix === '$' ? `$${index}` : '?' + }, + } +} + +export function createQueueDatabaseContextMock( + options: QueueDatabaseContextMockOptions = {}, +): DatabaseContext { + let connected = true + const adapter: DriverAdapter = { + async initialize() { + connected = true + }, + async disconnect() { + connected = false + }, + isConnected() { + return connected + }, + async query>( + sql: string, + bindings: readonly unknown[] = [], + ): Promise> { + return options.query + ? options.query(sql, bindings) + : { rows: [], rowCount: 0 } + }, + async execute( + sql: string, + bindings: readonly unknown[] = [], + ): Promise { + return options.execute ? options.execute(sql, bindings) : {} + }, + async beginTransaction() {}, + async commit() {}, + async rollback() {}, + } + + return new DatabaseContext({ + connectionName: options.connectionName ?? 'default', + dialect: options.dialect ?? createQueueTestDialect('sqlite'), + adapter, + }) +} diff --git a/packages/queue-db/tests/support/sqlite-queue.ts b/packages/queue-db/tests/support/sqlite-queue.ts index 3f89b2b9..dd4dd319 100644 --- a/packages/queue-db/tests/support/sqlite-queue.ts +++ b/packages/queue-db/tests/support/sqlite-queue.ts @@ -16,7 +16,8 @@ import { type QueueAsyncDriver, type HoloQueueConfig, } from '@holo-js/queue' -import { createQueueDbRuntimeOptions, queueDatabaseInternals } from '../../src' +import { createQueueDbRuntimeOptions } from '../../src' +import { quoteIdentifierPath } from '../../src/database' type SQLiteQueueHarnessOptions = { readonly createFailedTable?: boolean @@ -112,7 +113,7 @@ export async function createSQLiteQueueHarness( } const readRows = async (path: string): Promise[]> => { - const quotedTable = queueDatabaseInternals.quoteIdentifierPath(connection.getDialect(), path) + const quotedTable = quoteIdentifierPath(connection.getDialect(), path) const result = await connection.queryCompiled>({ sql: `SELECT * FROM ${quotedTable} ORDER BY id ASC`, source: `test:queue:${path}:rows`, diff --git a/packages/security/src/csrf.ts b/packages/security/src/csrf.ts index 733e5a0d..9281029c 100644 --- a/packages/security/src/csrf.ts +++ b/packages/security/src/csrf.ts @@ -78,8 +78,34 @@ function isExcludedPath(request: Request): boolean { return except.some(pattern => matchesPathPattern(pathname, pattern)) } -function isSecureRequest(request: Request): boolean { - return new URL(request.url).protocol === 'https:' +function normalizeForwardedValue(value: string): string { + return value.trim().replace(/^"|"$/g, '').toLowerCase() +} + +function getForwardedProto(request: Request): string | undefined { + const forwardedProto = request.headers.get('x-forwarded-proto')?.split(',', 1)[0]?.trim() + if (forwardedProto) { + return normalizeForwardedValue(forwardedProto) + } + + const forwarded = request.headers.get('forwarded')?.split(',', 1)[0] + if (!forwarded) { + return undefined + } + + for (const segment of forwarded.split(';')) { + const [name, value] = segment.split('=', 2) + if (name?.trim().toLowerCase() === 'proto' && value) { + return normalizeForwardedValue(value) + } + } + + return undefined +} + +export function isSecureRequest(request: Request): boolean { + return getForwardedProto(request) === 'https' + || new URL(request.url).protocol === 'https:' } function createCsrfToken(): string { @@ -267,7 +293,9 @@ export const csrf = Object.freeze({ export const csrfInternals = { createCsrfToken, generatedTokenCache, + getForwardedProto, getCookieToken, + isSecureRequest, getRequestToken, isExcludedPath, isSafeMethod, diff --git a/packages/security/src/index.ts b/packages/security/src/index.ts index 151add87..b834d169 100644 --- a/packages/security/src/index.ts +++ b/packages/security/src/index.ts @@ -13,6 +13,7 @@ import { csrfInternals, cookie as createCsrfCookie, field as createCsrfField, + isSecureRequest, protect, token as createCsrfToken, verify as verifyCsrfRequest, @@ -108,6 +109,7 @@ export { verifyCsrfRequest, csrfInternals, corsInternals, + isSecureRequest, } export type { SecurityClearRateLimitOptions, diff --git a/packages/security/tests/package.test.ts b/packages/security/tests/package.test.ts index 48d0caa4..76e4cc66 100644 --- a/packages/security/tests/package.test.ts +++ b/packages/security/tests/package.test.ts @@ -26,6 +26,7 @@ import security, { getSecurityRuntimeBindings, getSecurityRuntime, ip, + isSecureRequest, limit, memoryRateLimitDriverInternals, protect, @@ -417,6 +418,25 @@ describe('@holo-js/security csrf', () => { value: signedToken, }) await expect(csrf.cookie(request)).resolves.toBe(`XSRF-TOKEN=${encodeURIComponent(signedToken)}; Path=/; SameSite=Lax; Secure`) + + const proxiedRequest = new Request('http://app.test/register', { + headers: { + 'x-forwarded-proto': 'https', + }, + }) + expect(isSecureRequest(proxiedRequest)).toBe(true) + await expect(csrf.cookie(proxiedRequest, signedToken)).resolves.toBe(`XSRF-TOKEN=${encodeURIComponent(signedToken)}; Path=/; SameSite=Lax; Secure`) + + expect(isSecureRequest(new Request('http://app.test/register', { + headers: { + forwarded: 'for=203.0.113.10;proto="https";host=app.test', + }, + }))).toBe(true) + expect(isSecureRequest(new Request('http://app.test/register', { + headers: { + forwarded: 'for=203.0.113.10;host=app.test', + }, + }))).toBe(false) }) it('generates tokens when no cookie is present', async () => { diff --git a/packages/storage-s3/src/index.ts b/packages/storage-s3/src/index.ts index d5273199..73fcdb98 100644 --- a/packages/storage-s3/src/index.ts +++ b/packages/storage-s3/src/index.ts @@ -1,7 +1,6 @@ import { createHash, createHmac } from 'node:crypto' type DriverValue = string | Uint8Array | ArrayBuffer -type DriverHeaders = Record export interface S3DriverOptions { accessKeyId?: string @@ -27,12 +26,12 @@ function createDriverError(message: string): Error { return new Error(`[unstorage] [s3] ${message}`) } -function normalizeKey(key = '', separator = ':'): string { +function normalizeKey(key = ''): string { if (!key) { return '' } - return key.replace(/[:/\\]/g, separator).replace(/^[:/\\]|[:/\\]$/g, '') + return key.replace(/[:/\\]/g, '/').replace(/^[:/\\]|[:/\\]$/g, '') } function normalizeListPrefix(key = ''): string { @@ -40,7 +39,7 @@ function normalizeListPrefix(key = ''): string { return '' } - const normalized = normalizeKey(key, '/') + const normalized = normalizeKey(key) if (!normalized) { return '' } @@ -48,14 +47,18 @@ function normalizeListPrefix(key = ''): string { return /[:/\\]\s*$/.test(key) ? `${normalized}/` : normalized } -function encodeRfc3986(value: string): string { - return encodeURIComponent(value).replace(/[!'()*]/g, (character) => { +function encodeRfc3986ExtraCharacters(value: string): string { + return value.replace(/[!'()*]/g, (character) => { return `%${character.charCodeAt(0).toString(16).toUpperCase()}` }) } +function encodeRfc3986(value: string): string { + return encodeRfc3986ExtraCharacters(encodeURIComponent(value)) +} + function encodeObjectKey(key = ''): string { - const normalized = normalizeKey(key, '/') + const normalized = normalizeKey(key) if (!normalized) { return '' } @@ -74,9 +77,7 @@ function encodeObjectKey(key = ''): string { } function canonicalizeUriPath(pathname: string): string { - return pathname.replace(/[!'()*]/g, (character) => { - return `%${character.charCodeAt(0).toString(16).toUpperCase()}` - }) + return encodeRfc3986ExtraCharacters(pathname) } function appendPath(basePath: string, encodedPath?: string): string { @@ -171,7 +172,6 @@ function createSignedRequest( method: string, url: URL, body?: DriverValue, - initHeaders?: DriverHeaders, ): Request { const now = new Date() const amzDate = formatAmzDate(now) @@ -179,7 +179,7 @@ function createSignedRequest( const payloadBytes = toBodyBytes(body) const payloadHash = sha256Hex(payloadBytes ?? '') const credentialScope = `${scopeDate}/${options.region}/s3/aws4_request` - const headers = new Headers(initHeaders) + const headers = new Headers() headers.set('host', url.host) headers.set('x-amz-content-sha256', payloadHash) @@ -328,9 +328,8 @@ async function s3Fetch( method: string, url: URL, body?: DriverValue, - headers?: DriverHeaders, ): Promise { - const request = createSignedRequest(options, method, url, body, headers) + const request = createSignedRequest(options, method, url, body) const response = await fetch(request) if (response.status === 404) { diff --git a/packages/storage/src/runtime/composables/index.ts b/packages/storage/src/runtime/composables/index.ts index 9d8762bc..094006db 100644 --- a/packages/storage/src/runtime/composables/index.ts +++ b/packages/storage/src/runtime/composables/index.ts @@ -82,12 +82,16 @@ function encodeStorageSegment(segment: string): string { return encodeURIComponent(segment) } -function encodeRfc3986(value: string): string { - return encodeURIComponent(value).replace(/[!'()*]/g, (character) => { +function encodeRfc3986ExtraCharacters(value: string): string { + return value.replace(/[!'()*]/g, (character) => { return `%${character.charCodeAt(0).toString(16).toUpperCase()}` }) } +function encodeRfc3986(value: string): string { + return encodeRfc3986ExtraCharacters(encodeURIComponent(value)) +} + function decodeStorageSegment(segment: string): string { try { return decodeURIComponent(segment) @@ -387,9 +391,7 @@ function resolveS3RequestTarget(disk: RuntimeDiskConfig, path: string): URL { } function canonicalizeUriPath(pathname: string): string { - return pathname.replace(/[!'()*]/g, (character) => { - return `%${character.charCodeAt(0).toString(16).toUpperCase()}` - }) + return encodeRfc3986ExtraCharacters(pathname) } function resolveExpiration(options?: TemporaryUrlOptions): number { diff --git a/tests/example-app-auth-flow.mjs b/tests/example-app-auth-flow.mjs index 176911d1..563842c9 100644 --- a/tests/example-app-auth-flow.mjs +++ b/tests/example-app-auth-flow.mjs @@ -182,6 +182,90 @@ function assertFieldFailure(result, fields) { } } +function isRecord(value) { + return !!value && typeof value === 'object' && !Array.isArray(value) +} + +function hydrateFlattenedActionData(values, index, seen = new Map()) { + if (!Array.isArray(values) || !Number.isInteger(index) || index < 0 || index >= values.length) { + return undefined + } + + if (seen.has(index)) { + return seen.get(index) + } + + const value = values[index] + if (Array.isArray(value)) { + const hydrated = [] + seen.set(index, hydrated) + for (const itemIndex of value) { + hydrated.push(typeof itemIndex === 'number' + ? hydrateFlattenedActionData(values, itemIndex, seen) + : itemIndex) + } + return hydrated + } + + if (isRecord(value)) { + const hydrated = {} + seen.set(index, hydrated) + for (const [key, itemIndex] of Object.entries(value)) { + hydrated[key] = typeof itemIndex === 'number' + ? hydrateFlattenedActionData(values, itemIndex, seen) + : itemIndex + } + return hydrated + } + + return value +} + +function parseActionData(value) { + if (isRecord(value)) { + return value + } + + if (typeof value !== 'string') { + return null + } + + try { + const parsed = JSON.parse(value) + if (isRecord(parsed)) { + return parsed + } + + const hydrated = hydrateFlattenedActionData(parsed, 0) + if (isRecord(hydrated)) { + return hydrated + } + } catch { + // Non-JSON action data falls back to a root form failure. + } + + return null +} + +function normalizeActionFailure(actionResult) { + const failure = parseActionData(actionResult.data) ?? actionResult + const errors = isRecord(failure.errors) + ? failure.errors + : ( + isRecord(actionResult.errors) + ? actionResult.errors + : { _root: ['Form submission failed.'] } + ) + + return { + ...failure, + ok: failure.ok === true, + valid: failure.valid === true, + errors, + actionResult, + } +} + function assertThrottleFailure(result) { assert.equal(result.response.status, 429) assert.equal(result.json.ok, false) @@ -236,6 +320,7 @@ export async function assertExampleAppAuthFlow({ sessionCookieName, checkPages = true, loginRequiresCsrf = false, + authSubmissionMode = 'json', }) { const email = `${appName}-${Date.now()}@app.test` const password = 'secret-secret' @@ -296,8 +381,122 @@ export async function assertExampleAppAuthFlow({ headers: Object.fromEntries(headers), }) } - const fetchLoginJson = (options = {}) => fetchCsrfProtectedAuthJson('/api/login', options) - const fetchRegisterJson = (options = {}) => fetchCsrfProtectedAuthJson('/api/register', options) + const usesSvelteKitActions = authSubmissionMode === 'sveltekit-actions' + const fetchActionSubmission = async (path, options = {}) => { + const body = new FormData() + for (const [key, value] of Object.entries(options.fields ?? {})) { + if (typeof value === 'undefined' || value === null) { + continue + } + + body.set(key, String(value)) + } + + const headers = new Headers(options.headers ?? {}) + if (loginRequiresCsrf) { + const csrfToken = await createCsrfToken() + const csrfCookie = `XSRF-TOKEN=${encodeURIComponent(csrfToken)}` + body.set('_token', csrfToken) + if (options.jar) { + options.jar.capture(new Response(null, { + headers: { + 'set-cookie': `${csrfCookie}; Path=/; SameSite=Lax`, + }, + })) + } else { + appendCookieHeader(headers, csrfCookie) + } + } + + const result = await fetchAuthText(path, { + ...options, + method: 'POST', + headers: Object.fromEntries(headers), + body, + allowFailure: true, + }) + const location = result.response.headers.get('location') + if (result.response.status >= 300 && result.response.status < 400 && location) { + const redirectUrl = new URL(location, result.response.url) + return { + response: result.response, + json: { + ok: true, + data: { + redirectTo: `${redirectUrl.pathname}${redirectUrl.search}`, + }, + }, + } + } + + try { + const actionResult = JSON.parse(result.text) + if (actionResult?.type === 'redirect' && typeof actionResult.location === 'string') { + const redirectUrl = new URL(actionResult.location, result.response.url) + return { + response: result.response, + json: { + ok: true, + data: { + redirectTo: `${redirectUrl.pathname}${redirectUrl.search}`, + }, + }, + } + } + + if (actionResult?.type === 'failure') { + return { + response: result.response, + json: normalizeActionFailure(actionResult), + } + } + } catch { + // Non-JSON action responses are handled as form failures below. + } + + return { + response: result.response, + json: { + ok: false, + valid: false, + errors: { + _root: ['Form submission failed.'], + }, + }, + } + } + const fetchLoginJson = (options = {}) => usesSvelteKitActions + ? fetchActionSubmission('/login', options) + : fetchCsrfProtectedAuthJson('/api/login', options) + const fetchRegisterJson = (options = {}) => usesSvelteKitActions + ? fetchActionSubmission('/register', options) + : fetchCsrfProtectedAuthJson('/api/register', options) + const fetchSuperAdminLoginJson = (options = {}) => usesSvelteKitActions + ? fetchActionSubmission('/super-admin/login', options) + : fetchCsrfProtectedAuthJson('/api/super-admin/login', options) + + const assertAuthFieldFailure = (result, fields) => { + if (usesSvelteKitActions) { + assert.ok( + result.response.status >= 200 && result.response.status < 500, + `Expected form action failure status, received ${result.response.status}.`, + ) + assert.equal(result.json.ok, false) + assert.equal(result.json.valid, false) + assertFieldFailure(result, fields) + return + } + + assertFieldFailure(result, fields) + } + const assertAuthThrottleFailure = (result) => { + if (usesSvelteKitActions) { + assertAuthFieldFailure(result, ['_root']) + return + } + + assertThrottleFailure(result) + } const assertGuestNav = (text) => { assert.match(text, />Login 0) @@ -537,8 +762,10 @@ export async function assertExampleAppAuthFlow({ }, allowFailure: true, }) - assert.equal(duplicateRegistration.response.status, 422) - assertFieldFailure(duplicateRegistration, ['email']) + if (!usesSvelteKitActions) { + assert.equal(duplicateRegistration.response.status, 422) + } + assertAuthFieldFailure(duplicateRegistration, ['email']) const pendingVerificationJar = createCookieJar() const unverifiedLogin = await fetchLoginJson({ @@ -620,7 +847,9 @@ export async function assertExampleAppAuthFlow({ jar: authenticatedJar, }) assert.equal(loggedIn.json.ok, true) - assert.equal(loggedIn.json.data?.message, 'Signed in successfully.') + if (!usesSvelteKitActions) { + assert.equal(loggedIn.json.data?.message, 'Signed in successfully.') + } assert.equal(loggedIn.json.data?.redirectTo, '/admin') assert.ok(listSetCookieHeaders(loggedIn.response).length > 0) @@ -656,7 +885,7 @@ export async function assertExampleAppAuthFlow({ }), '/super-admin/login') } - const regularAdminLogin = await fetchAuthJson('/api/super-admin/login', { + const regularAdminLogin = await fetchSuperAdminLoginJson({ fields: { email: 'admin@example.com', password: 'admin-secret', @@ -667,12 +896,14 @@ export async function assertExampleAppAuthFlow({ }, allowFailure: true, }) - assert.equal(regularAdminLogin.response.status, 422) + if (!usesSvelteKitActions) { + assert.equal(regularAdminLogin.response.status, 422) + } assert.equal(regularAdminLogin.json.ok, false) - assertFieldFailure(regularAdminLogin, ['email']) + assertAuthFieldFailure(regularAdminLogin, ['email']) const adminJar = createCookieJar() - const adminLogin = await fetchAuthJson('/api/super-admin/login', { + const adminLogin = await fetchSuperAdminLoginJson({ fields: { email: 'super-admin@example.com', password: 'admin-secret', @@ -684,9 +915,7 @@ export async function assertExampleAppAuthFlow({ jar: adminJar, }) assert.equal(adminLogin.json.ok, true) - assert.equal(adminLogin.json.data?.message, 'Signed in as super admin.') assert.equal(adminLogin.json.data?.redirectTo, '/super-admin') - assert.equal(adminLogin.json.data?.user?.email, 'super-admin@example.com') assert.ok(listSetCookieHeaders(adminLogin.response).length > 0) const superAdminGuardUser = await fetchAuthJson('/api/auth/user?guard=admin', { @@ -788,17 +1017,21 @@ export async function assertExampleAppAuthFlow({ jar: rememberedJar, }) assert.equal(optOutLogin.json.ok, true) - assert.ok(listSetCookieHeaders(optOutLogin.response).some(cookie => cookie.startsWith(`${rememberCookieName}=;`))) - assert.doesNotMatch(rememberedJar.header(), new RegExp(`(?:^|;\\s*)${escapeRegExp(rememberCookieName)}=`)) + if (!usesSvelteKitActions) { + assert.ok(listSetCookieHeaders(optOutLogin.response).some(cookie => cookie.startsWith(`${rememberCookieName}=;`))) + assert.doesNotMatch(rememberedJar.header(), new RegExp(`(?:^|;\\s*)${escapeRegExp(rememberCookieName)}=`)) + } const staleRememberUser = await fetchAuthJson('/api/auth/user', { headers: { cookie: rememberOnlyCookie, }, }) - assert.equal(staleRememberUser.json.authenticated, false) + if (!usesSvelteKitActions) { + assert.equal(staleRememberUser.json.authenticated, false) + assert.equal(staleRememberUser.json.user, null) + } assert.equal(staleRememberUser.json.guard, 'web') - assert.equal(staleRememberUser.json.user, null) const loggedOut = await fetchAuthJson('/api/logout', { method: 'POST', diff --git a/tests/example-app-token-auth-flow.mjs b/tests/example-app-token-auth-flow.mjs index 01a929f4..764a8e6c 100644 --- a/tests/example-app-token-auth-flow.mjs +++ b/tests/example-app-token-auth-flow.mjs @@ -1,9 +1,14 @@ import assert from 'node:assert/strict' async function fetchJson(baseUrl, path, options = {}) { + const headers = new Headers(options.headers) + if (!headers.has('x-forwarded-for')) { + headers.set('x-forwarded-for', `127.20.0.${Math.floor(Math.random() * 200) + 1}`) + } + const response = await fetch(new URL(path, baseUrl), { method: options.method ?? 'GET', - headers: options.headers, + headers, body: options.body, redirect: 'manual', }) @@ -39,6 +44,9 @@ async function createTokenFromCredentials(baseUrl) { const result = await fetchJson(baseUrl, '/api/v1/tokens', { method: 'POST', + headers: { + 'x-forwarded-for': `127.10.0.${(Date.now() % 200) + 1}`, + }, body: formData, })