diff --git a/src/subscription/index.ts b/src/subscription/index.ts index 41ab8864f..7e6a56e7d 100644 --- a/src/subscription/index.ts +++ b/src/subscription/index.ts @@ -72,13 +72,30 @@ export const subscription = ((useSWRNext: SWRHook) => subscriptions.set(subscriptionKey, refCount + 1) if (!refCount) { - const dispose = subscribe(args, { next }) - if (typeof dispose !== 'function') { + const result = subscribe(args, { next }) + + if (result && typeof (result as any).then === 'function') { + // Race condition guard: if cleanup runs before the async subscribe + // resolves, the flag tells the resolver to dispose immediately. + let shouldDisposeOnResolve = false + ;(result as Promise<(() => void) | void>).then(dispose => { + if (shouldDisposeOnResolve) { + if (typeof dispose === 'function') dispose() + } else if (typeof dispose === 'function') { + disposers.set(subscriptionKey!, dispose) + } + }) + + disposers.set(subscriptionKey, () => { + shouldDisposeOnResolve = true + }) + } else if (typeof result === 'function') { + disposers.set(subscriptionKey, result) + } else if (typeof result !== 'undefined') { throw new Error( 'The `subscribe` function must return a function to unsubscribe.' ) } - disposers.set(subscriptionKey, dispose) } return () => { diff --git a/src/subscription/types.ts b/src/subscription/types.ts index c9345db0f..cef832eb6 100644 --- a/src/subscription/types.ts +++ b/src/subscription/types.ts @@ -4,16 +4,22 @@ export type SWRSubscriptionOptions = { next: (err?: Error | null, data?: Data | MutatorCallback) => void } +type SWRSubscribeReturn = (() => void) | void +type SWRSubscribeFn = ( + key: Arg, + { next }: SWRSubscriptionOptions +) => SWRSubscribeReturn | Promise + export type SWRSubscription< SWRSubKey extends Key = Key, Data = any, Error = any > = SWRSubKey extends () => infer Arg | null | undefined | false - ? (key: Arg, { next }: SWRSubscriptionOptions) => void + ? SWRSubscribeFn : SWRSubKey extends null | undefined | false ? never : SWRSubKey extends infer Arg - ? (key: Arg, { next }: SWRSubscriptionOptions) => void + ? SWRSubscribeFn : never export type SWRSubscriptionResponse = { diff --git a/test/type/subscription.ts b/test/type/subscription.ts index 42d419ae3..85b707b9f 100644 --- a/test/type/subscription.ts +++ b/test/type/subscription.ts @@ -95,4 +95,20 @@ export function useTestSubscription() { const { data: data2, error: error2 } = useSWRSubscription('key', sub) expectType(data2) expectType(error2) + + // Async subscribe should be accepted. + useSWRSubscription( + 'key', + async (_key, { next: _ }: SWRSubscriptionOptions) => { + return () => {} + } + ) + + const asyncSub: SWRSubscription = async ( + _, + { next: __ } + ) => { + return () => {} + } + useSWRSubscription('key', asyncSub) } diff --git a/test/use-swr-subscription.test.tsx b/test/use-swr-subscription.test.tsx index 9ba56380d..c7285671a 100644 --- a/test/use-swr-subscription.test.tsx +++ b/test/use-swr-subscription.test.tsx @@ -293,6 +293,79 @@ describe('useSWRSubscription', () => { await screen.findByText(`data: 3`) }) + it('should support async subscribe', async () => { + const swrKey = createKey() + let emitter: ((data: string) => void) | null = null + let disposed = false + + async function subscribe(_key, { next }) { + await sleep(50) + emitter = (data: string) => next(undefined, data) + return () => { + disposed = true + emitter = null + } + } + + function Page() { + const { data } = useSWRSubscription(swrKey, subscribe, { + fallbackData: 'fallback' + }) + return
{'data:' + data}
+ } + + renderWithConfig() + screen.getByText('data:fallback') + + // Wait for async subscribe to resolve. + await act(() => sleep(100)) + act(() => emitter?.('hello')) + await act(() => sleep(10)) + screen.getByText('data:hello') + + expect(disposed).toBe(false) + }) + + it('should clean up async subscribe on unmount', async () => { + const swrKey = createKey() + let disposed = false + + async function subscribe(_key, { next }) { + await sleep(100) + next(undefined, 'connected') + return () => { + disposed = true + } + } + + function Page() { + const [show, setShow] = useState(true) + return ( + <> + {show ? : null} + + + ) + } + function Child() { + const { data } = useSWRSubscription(swrKey, subscribe, { + fallbackData: 'fallback' + }) + return
{'data:' + data}
+ } + + renderWithConfig() + screen.getByText('data:fallback') + + // Unmount before async subscribe resolves. + await act(() => sleep(10)) + fireEvent.click(screen.getByText('unmount')) + + // After the Promise resolves, the dispose should still be called. + await act(() => sleep(200)) + expect(disposed).toBe(true) + }) + it('should require a dispose function', async () => { jest.spyOn(console, 'error').mockImplementation(() => {}) @@ -303,6 +376,7 @@ describe('useSWRSubscription', () => { } function Page() { + // @ts-expect-error -- intentionally passing an invalid subscribe function useSWRSubscription(swrKey, subscribe) return null }