diff --git a/packages/cli/src/acp-integration/acpAgent.test.ts b/packages/cli/src/acp-integration/acpAgent.test.ts index 59596c3d53..277ccd79c4 100644 --- a/packages/cli/src/acp-integration/acpAgent.test.ts +++ b/packages/cli/src/acp-integration/acpAgent.test.ts @@ -844,9 +844,10 @@ describe('QwenAgent MCP SSE/HTTP support', () => { sendAvailableCommandsUpdate: vi.fn().mockResolvedValue(undefined), replayHistory: vi.fn().mockResolvedValue(undefined), installRewriter: vi.fn(), - captureHistorySnapshot: vi - .fn() - .mockReturnValue([{ role: 'user', parts: [{ text: 'before' }] }]), + captureHistorySnapshot: vi.fn().mockReturnValue({ + history: [{ role: 'user', parts: [{ text: 'before' }] }], + modelFacingUserTurnCount: 1, + }), restoreHistory: vi.fn(), rewindToTurn: vi .fn() @@ -1719,7 +1720,10 @@ describe('QwenAgent MCP SSE/HTTP support', () => { expect(lastSessionMock?.rewindToTurn).toHaveBeenCalledWith(1); expect(response).toEqual({ success: true, - historyBeforeRewind: [{ role: 'user', parts: [{ text: 'before' }] }], + historyBeforeRewind: { + history: [{ role: 'user', parts: [{ text: 'before' }] }], + modelFacingUserTurnCount: 1, + }, targetTurnIndex: 1, apiTruncateIndex: 2, }); @@ -1843,6 +1847,85 @@ describe('QwenAgent MCP SSE/HTTP support', () => { await agentPromise; }); + it('restoreSessionHistory extension method restores history snapshots', async () => { + const sessionId = '11111111-1111-1111-1111-111111111111'; + await setupSessionMocks(sessionId); + + const agentPromise = runAcpAgent( + mockConfig, + makeSessionSettings(), + mockArgv, + ); + await vi.waitFor(() => expect(capturedAgentFactory).toBeDefined()); + + const agent = capturedAgentFactory!({ + get closed() { + return mockConnectionState.promise; + }, + }) as AgentLike; + + await agent.newSession({ cwd: '/tmp', mcpServers: [] }); + const snapshot = { + history: [{ role: 'user', parts: [{ text: 'restored' }] }], + modelFacingUserTurnCount: 1, + }; + const response = await agent.extMethod('restoreSessionHistory', { + sessionId, + history: snapshot, + cwd: '/tmp', + }); + + expect(lastSessionMock?.restoreHistory).toHaveBeenCalledWith(snapshot); + expect(response).toEqual({ success: true }); + + mockConnectionState.resolve(); + await agentPromise; + }); + + it('restoreSessionHistory rejects invalid history snapshot turn counts', async () => { + const sessionId = '11111111-1111-1111-1111-111111111111'; + await setupSessionMocks(sessionId); + + const agentPromise = runAcpAgent( + mockConfig, + makeSessionSettings(), + mockArgv, + ); + await vi.waitFor(() => expect(capturedAgentFactory).toBeDefined()); + + const agent = capturedAgentFactory!({ + get closed() { + return mockConnectionState.promise; + }, + }) as AgentLike; + + await agent.newSession({ cwd: '/tmp', mcpServers: [] }); + + for (const modelFacingUserTurnCount of [ + NaN, + Infinity, + -Infinity, + -1, + 1.5, + Number.MAX_SAFE_INTEGER + 1, + ]) { + await expect( + agent.extMethod('restoreSessionHistory', { + sessionId, + history: { + history: [], + modelFacingUserTurnCount, + }, + }), + ).rejects.toThrow('Invalid or missing history'); + } + + expect(lastSessionMock?.restoreHistory).not.toHaveBeenCalled(); + + mockConnectionState.resolve(); + await agentPromise; + }); + it('restoreSessionHistory rejects invalid session ids', async () => { await setupSessionMocks('11111111-1111-1111-1111-111111111111'); diff --git a/packages/cli/src/acp-integration/acpAgent.ts b/packages/cli/src/acp-integration/acpAgent.ts index 4a36b20cd5..541db82f2a 100644 --- a/packages/cli/src/acp-integration/acpAgent.ts +++ b/packages/cli/src/acp-integration/acpAgent.ts @@ -78,7 +78,11 @@ import type { ApprovalModeValue } from './session/types.js'; import { z } from 'zod'; import type { CliArgs } from '../config/config.js'; import { loadCliConfig } from '../config/config.js'; -import { Session, buildAvailableCommandsSnapshot } from './session/Session.js'; +import { + Session, + buildAvailableCommandsSnapshot, + type HistorySnapshot, +} from './session/Session.js'; import { formatAcpModelId, parseAcpBaseModelId, @@ -1732,7 +1736,24 @@ class QwenAgent implements Agent { 'Invalid or missing sessionId', ); } - if (!Array.isArray(history)) { + const isHistorySnapshot = + !!history && + typeof history === 'object' && + !Array.isArray(history) && + Array.isArray((history as { history?: unknown }).history) && + Number.isInteger( + (history as { modelFacingUserTurnCount?: unknown }) + .modelFacingUserTurnCount, + ) && + Number.isFinite( + (history as { modelFacingUserTurnCount?: unknown }) + .modelFacingUserTurnCount as number, + ) && + ((history as { modelFacingUserTurnCount?: unknown }) + .modelFacingUserTurnCount as number) >= 0 && + ((history as { modelFacingUserTurnCount?: unknown }) + .modelFacingUserTurnCount as number) <= Number.MAX_SAFE_INTEGER; + if (!Array.isArray(history) && !isHistorySnapshot) { throw RequestError.invalidParams( undefined, 'Invalid or missing history', @@ -1746,7 +1767,7 @@ class QwenAgent implements Agent { ); } - session.restoreHistory(history as Content[]); + session.restoreHistory(history as Content[] | HistorySnapshot); return { success: true }; } case 'getAccountInfo': { diff --git a/packages/cli/src/acp-integration/session/Session.test.ts b/packages/cli/src/acp-integration/session/Session.test.ts index f1124e3e6a..46cc4848c1 100644 --- a/packages/cli/src/acp-integration/session/Session.test.ts +++ b/packages/cli/src/acp-integration/session/Session.test.ts @@ -8,7 +8,11 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import * as fs from 'node:fs/promises'; import * as os from 'node:os'; import * as path from 'node:path'; -import { computeInitialTurnFromHistory, Session } from './Session.js'; +import { + computeInitialModelFacingUserTurnCountFromHistory, + computeInitialTurnFromHistory, + Session, +} from './Session.js'; import type { Content } from '@google/genai'; import type { ChatRecord, Config, GeminiChat } from '@qwen-code/qwen-code-core'; import { ApprovalMode, AuthType } from '@qwen-code/qwen-code-core'; @@ -123,6 +127,50 @@ describe('computeInitialTurnFromHistory', () => { }); }); +describe('computeInitialModelFacingUserTurnCountFromHistory', () => { + it('counts model-facing user text records without slash-only records', () => { + expect( + computeInitialModelFacingUserTurnCountFromHistory( + [ + chatRecord({ + uuid: 'user-1', + message: { parts: [{ text: 'first' }] }, + }), + chatRecord({ + uuid: 'slash-1', + message: { parts: [{ text: '/help' }] }, + }), + chatRecord({ + uuid: 'question-command-1', + message: { parts: [{ text: '?help' }] }, + }), + chatRecord({ + uuid: 'cron-1', + subtype: 'cron', + message: { parts: [{ text: 'cron prompt' }] }, + }), + chatRecord({ + uuid: 'mid-turn-1', + subtype: 'mid_turn_user_message', + message: { parts: [{ text: 'mid-turn prompt text' }] }, + }), + chatRecord({ + uuid: 'notification-1', + subtype: 'notification', + message: { parts: [{ text: 'FYI' }] }, + }), + chatRecord({ + uuid: 'other-session', + sessionId: 'other-session-id', + message: { parts: [{ text: 'other' }] }, + }), + ], + 'test-session-id', + ), + ).toBe(2); + }); +}); + // Helper to create empty async generator (avoids memory leak from inline generators) function createEmptyStream() { return (async function* () {})(); @@ -139,6 +187,13 @@ function createStreamWithChunks( })(); } +function createThrowingStream(error: Error) { + return (async function* () { + yield* []; + throw error; + })(); +} + function expectCompressBeforeSend( compressMock: ReturnType, sendMock: ReturnType, @@ -153,6 +208,18 @@ function expectCompressBeforeSend( ); } +function setSessionTurnCounters( + targetSession: Session, + counters: { turn?: number; modelFacingUserTurnCount?: number }, +) { + Object.assign(targetSession as unknown as Record, counters); +} + +function getSessionModelFacingUserTurnCount(targetSession: Session): number { + return (targetSession as unknown as { modelFacingUserTurnCount: number }) + .modelFacingUserTurnCount; +} + describe('Session', () => { let mockChat: GeminiChat; let mockConfig: Config; @@ -351,6 +418,214 @@ describe('Session', () => { expect(mockChat.truncateHistory).toHaveBeenCalledWith(2); }); + it('maps ACP rewind to the uncompressed tail after chat compression', () => { + setSessionTurnCounters(session, { + turn: 3, + modelFacingUserTurnCount: 3, + }); + const history: Content[] = [ + { role: 'user', parts: [{ text: 'summary of first two turns' }] }, + { + role: 'model', + parts: [{ text: core.COMPRESSION_SUMMARY_MODEL_ACK }], + }, + { role: 'user', parts: [{ text: 'third' }] }, + { role: 'model', parts: [{ text: 'third reply' }] }, + ]; + vi.mocked(mockChat.getHistory).mockReturnValue(history); + + const result = session.rewindToTurn(2); + + expect(result).toEqual({ targetTurnIndex: 2, apiTruncateIndex: 2 }); + expect(mockChat.truncateHistory).toHaveBeenCalledWith(2); + }); + + it('keeps compressed tail reachable after rewind and resend', async () => { + setSessionTurnCounters(session, { + turn: 3, + modelFacingUserTurnCount: 3, + }); + vi.mocked(mockChat.getHistory).mockReturnValue([ + { role: 'user', parts: [{ text: 'summary of first two turns' }] }, + { + role: 'model', + parts: [{ text: core.COMPRESSION_SUMMARY_MODEL_ACK }], + }, + { role: 'user', parts: [{ text: 'third' }] }, + { role: 'model', parts: [{ text: 'third reply' }] }, + ]); + + expect(session.rewindToTurn(2)).toEqual({ + targetTurnIndex: 2, + apiTruncateIndex: 2, + }); + + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValue(createEmptyStream()); + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'new third' }], + }); + + vi.mocked(mockChat.getHistory).mockReturnValue([ + { role: 'user', parts: [{ text: 'summary of first two turns' }] }, + { + role: 'model', + parts: [{ text: core.COMPRESSION_SUMMARY_MODEL_ACK }], + }, + { role: 'user', parts: [{ text: 'new third' }] }, + { role: 'model', parts: [{ text: 'new third reply' }] }, + ]); + mockChat.truncateHistory = vi.fn(); + + expect(session.rewindToTurn(2)).toEqual({ + targetTurnIndex: 2, + apiTruncateIndex: 2, + }); + expect(mockChat.truncateHistory).toHaveBeenCalledWith(2); + }); + + it('updates model-facing turn count through the real prompt send path', async () => { + setSessionTurnCounters(session, { + turn: 2, + modelFacingUserTurnCount: 2, + }); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValue(createEmptyStream()); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'third' }], + }); + + vi.mocked(mockChat.getHistory).mockReturnValue([ + { role: 'user', parts: [{ text: 'summary of first two turns' }] }, + { + role: 'model', + parts: [{ text: core.COMPRESSION_SUMMARY_MODEL_ACK }], + }, + { role: 'user', parts: [{ text: 'third' }] }, + { role: 'model', parts: [{ text: 'third reply' }] }, + ]); + + expect(() => session.rewindToTurn(1)).toThrow( + 'Cannot rewind to the requested turn', + ); + expect(mockChat.truncateHistory).not.toHaveBeenCalled(); + }); + + it('uses model-facing turn count when slash commands advance prompt ids', () => { + setSessionTurnCounters(session, { + turn: 4, + modelFacingUserTurnCount: 3, + }); + vi.mocked(mockChat.getHistory).mockReturnValue([ + { role: 'user', parts: [{ text: 'summary of first two turns' }] }, + { + role: 'model', + parts: [{ text: core.COMPRESSION_SUMMARY_MODEL_ACK }], + }, + { role: 'user', parts: [{ text: 'third' }] }, + { role: 'model', parts: [{ text: 'third reply' }] }, + ]); + + const result = session.rewindToTurn(2); + + expect(result).toEqual({ targetTurnIndex: 2, apiTruncateIndex: 2 }); + expect(mockChat.truncateHistory).toHaveBeenCalledWith(2); + }); + + it('rejects compressed rewind targets when the model-facing count is too low', () => { + setSessionTurnCounters(session, { + turn: 3, + modelFacingUserTurnCount: 2, + }); + vi.mocked(mockChat.getHistory).mockReturnValue([ + { role: 'user', parts: [{ text: 'summary of first two turns' }] }, + { + role: 'model', + parts: [{ text: core.COMPRESSION_SUMMARY_MODEL_ACK }], + }, + { role: 'user', parts: [{ text: 'third' }] }, + { role: 'model', parts: [{ text: 'third reply' }] }, + ]); + + expect(() => session.rewindToTurn(2)).toThrow( + 'Cannot rewind to the requested turn', + ); + expect(mockChat.truncateHistory).not.toHaveBeenCalled(); + }); + + it('uses model-facing turn count when cron adds user text entries', () => { + setSessionTurnCounters(session, { + turn: 2, + modelFacingUserTurnCount: 3, + }); + vi.mocked(mockChat.getHistory).mockReturnValue([ + { role: 'user', parts: [{ text: 'summary of first two turns' }] }, + { + role: 'model', + parts: [{ text: core.COMPRESSION_SUMMARY_MODEL_ACK }], + }, + { role: 'user', parts: [{ text: 'cron prompt' }] }, + { role: 'model', parts: [{ text: 'cron reply' }] }, + ]); + + expect(() => session.rewindToTurn(1)).toThrow( + 'Cannot rewind to the requested turn', + ); + expect(mockChat.truncateHistory).not.toHaveBeenCalled(); + }); + + it('rejects ACP rewind to the first turn after compression absorbed it', () => { + setSessionTurnCounters(session, { + turn: 3, + modelFacingUserTurnCount: 3, + }); + vi.mocked(mockChat.getHistory).mockReturnValue([ + { role: 'user', parts: [{ text: 'summary of first two turns' }] }, + { + role: 'model', + parts: [{ text: core.COMPRESSION_SUMMARY_MODEL_ACK }], + }, + { role: 'user', parts: [{ text: 'third' }] }, + { role: 'model', parts: [{ text: 'third reply' }] }, + ]); + + expect(() => session.rewindToTurn(0)).toThrow( + 'Cannot rewind to the requested turn', + ); + expect(mockChat.truncateHistory).not.toHaveBeenCalled(); + }); + + it('does not treat the compression bridge as an ACP rewind target', () => { + setSessionTurnCounters(session, { + turn: 3, + modelFacingUserTurnCount: 3, + }); + vi.mocked(mockChat.getHistory).mockReturnValue([ + { role: 'user', parts: [{ text: 'summary of first two turns' }] }, + { + role: 'model', + parts: [{ text: core.COMPRESSION_SUMMARY_MODEL_ACK }], + }, + { + role: 'user', + parts: [{ text: core.COMPRESSION_CONTINUATION_BRIDGE }], + }, + { role: 'model', parts: [{ text: 'continued response' }] }, + { role: 'user', parts: [{ text: 'third' }] }, + { role: 'model', parts: [{ text: 'third reply' }] }, + ]); + + expect(() => session.rewindToTurn(1)).toThrow( + 'Cannot rewind to the requested turn', + ); + expect(mockChat.truncateHistory).not.toHaveBeenCalled(); + }); + it('rejects unreachable user turns', () => { vi.mocked(mockChat.getHistory).mockReturnValue([ { role: 'user', parts: [{ text: 'first' }] }, @@ -400,6 +675,7 @@ describe('Session', () => { }); it('restores a captured history snapshot', () => { + setSessionTurnCounters(session, { modelFacingUserTurnCount: 1 }); const history: Content[] = [ { role: 'user', parts: [{ text: 'first' }] }, { role: 'model', parts: [{ text: 'first reply' }] }, @@ -409,8 +685,147 @@ describe('Session', () => { const snapshot = session.captureHistorySnapshot(); session.restoreHistory(snapshot); - expect(snapshot).toEqual(history); + expect(snapshot).toEqual({ + history, + modelFacingUserTurnCount: 1, + }); + expect(mockChat.setHistory).toHaveBeenCalledWith(history); + }); + + it('returns an isolated history snapshot from the chat history clone', () => { + const history: Content[] = [{ role: 'user', parts: [{ text: 'first' }] }]; + vi.mocked(mockChat.getHistory).mockImplementation(() => + structuredClone(history), + ); + + const snapshot = session.captureHistorySnapshot(); + (snapshot.history[0]!.parts![0] as { text: string }).text = 'mutated'; + + expect(history[0]!.parts![0]).toEqual({ text: 'first' }); + }); + + it('restores model-facing turn count with the history snapshot', () => { + setSessionTurnCounters(session, { + turn: 4, + modelFacingUserTurnCount: 4, + }); + const history: Content[] = [ + { role: 'user', parts: [{ text: 'summary of first two turns' }] }, + { + role: 'model', + parts: [{ text: core.COMPRESSION_SUMMARY_MODEL_ACK }], + }, + { role: 'user', parts: [{ text: 'third' }] }, + { role: 'model', parts: [{ text: 'third reply' }] }, + { role: 'user', parts: [{ text: 'fourth' }] }, + { role: 'model', parts: [{ text: 'fourth reply' }] }, + ]; + vi.mocked(mockChat.getHistory).mockReturnValue(history); + + const snapshot = session.captureHistorySnapshot(); + expect(session.rewindToTurn(2)).toEqual({ + targetTurnIndex: 2, + apiTruncateIndex: 2, + }); + + session.restoreHistory(snapshot); + vi.mocked(mockChat.getHistory).mockReturnValue(history); + vi.mocked(mockChat.truncateHistory).mockClear(); + + expect(() => session.rewindToTurn(1)).toThrow( + 'Cannot rewind to the requested turn', + ); + expect(mockChat.truncateHistory).not.toHaveBeenCalled(); + }); + + it('derives model-facing turn count when restoring legacy history arrays', () => { + setSessionTurnCounters(session, { modelFacingUserTurnCount: 99 }); + const history: Content[] = [ + { role: 'user', parts: [{ text: 'first' }] }, + { role: 'model', parts: [{ text: 'first reply' }] }, + { + role: 'user', + parts: [{ functionResponse: { name: 'tool', response: {} } }], + }, + { role: 'user', parts: [{ text: 'second' }] }, + ]; + + session.restoreHistory(history); + expect(mockChat.setHistory).toHaveBeenCalledWith(history); + expect(getSessionModelFacingUserTurnCount(session)).toBe(2); + }); + + it('restores history without sharing caller-owned arrays', () => { + const history: Content[] = [ + { role: 'user', parts: [{ text: 'first' }] }, + { role: 'model', parts: [{ text: 'first reply' }] }, + ]; + + session.restoreHistory({ history, modelFacingUserTurnCount: 1 }); + const restoredHistory = vi.mocked(mockChat.setHistory).mock + .calls[0]![0] as Content[]; + (history[0]!.parts![0] as { text: string }).text = 'mutated'; + + expect(restoredHistory).toEqual([ + { role: 'user', parts: [{ text: 'first' }] }, + { role: 'model', parts: [{ text: 'first reply' }] }, + ]); + expect(restoredHistory).not.toBe(history); + }); + + it('counts compression summaries as one compressed turn when restoring legacy history arrays', () => { + setSessionTurnCounters(session, { modelFacingUserTurnCount: 99 }); + const history: Content[] = [ + { role: 'user', parts: [{ text: 'summary of first two turns' }] }, + { + role: 'model', + parts: [{ text: core.COMPRESSION_SUMMARY_MODEL_ACK }], + }, + { role: 'user', parts: [{ text: 'third' }] }, + { role: 'model', parts: [{ text: 'third reply' }] }, + { role: 'user', parts: [{ text: 'fourth' }] }, + ]; + + session.restoreHistory(history); + + expect(mockChat.setHistory).toHaveBeenCalledWith(history); + expect(getSessionModelFacingUserTurnCount(session)).toBe(3); + }); + + it('keeps absorbed turns unreachable after restoring legacy compressed history arrays', () => { + setSessionTurnCounters(session, { modelFacingUserTurnCount: 99 }); + const history: Content[] = [ + { role: 'user', parts: [{ text: 'summary of earlier turns' }] }, + { + role: 'model', + parts: [{ text: core.COMPRESSION_SUMMARY_MODEL_ACK }], + }, + { role: 'user', parts: [{ text: 'visible tail turn' }] }, + { role: 'model', parts: [{ text: 'visible tail reply' }] }, + ]; + vi.mocked(mockChat.getHistory).mockReturnValue(history); + + session.restoreHistory(history); + + expect(() => session.rewindToTurn(0)).toThrow( + 'Cannot rewind to the requested turn', + ); + expect(mockChat.truncateHistory).not.toHaveBeenCalled(); + }); + + it('rejects restoring an empty history array', () => { + expect(() => session.restoreHistory([])).toThrow( + 'Cannot restore an empty history snapshot', + ); + expect(mockChat.setHistory).not.toHaveBeenCalled(); + }); + + it('rejects restoring an empty HistorySnapshot', () => { + expect(() => + session.restoreHistory({ history: [], modelFacingUserTurnCount: 0 }), + ).toThrow('Cannot restore an empty history snapshot'); + expect(mockChat.setHistory).not.toHaveBeenCalled(); }); it('rejects history restore while a prompt is running', () => { @@ -1257,6 +1672,7 @@ describe('Session', () => { expect(mockGeminiClient.tryCompressChat).toHaveBeenCalled(); expect(mockChat.sendMessageStream).not.toHaveBeenCalled(); expect(mockChat.addHistory).not.toHaveBeenCalled(); + expect(getSessionModelFacingUserTurnCount(session)).toBe(0); expect(mockClient.sessionUpdate).not.toHaveBeenCalledWith({ sessionId: 'test-session-id', update: { @@ -1308,6 +1724,40 @@ describe('Session', () => { expect(mockChat.addHistory).not.toHaveBeenCalled(); }); + it('keeps model-facing turn count when the main response stream throws after dispatch', async () => { + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValue( + createThrowingStream(new Error('main stream failed')), + ); + + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }), + ).rejects.toThrow('main stream failed'); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1); + expect(getSessionModelFacingUserTurnCount(session)).toBe(1); + }); + + it('rolls back model-facing turn count when the main send fails before dispatch', async () => { + mockChat.sendMessageStream = vi + .fn() + .mockRejectedValue(new Error('send failed before dispatch')); + + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }), + ).rejects.toThrow('send failed before dispatch'); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1); + expect(getSessionModelFacingUserTurnCount(session)).toBe(0); + }); + it('also runs automatic compression before tool response follow-up sends', async () => { const executeSpy = vi.fn().mockResolvedValue({ llmContent: 'file contents', @@ -1369,6 +1819,7 @@ describe('Session', () => { sendMessageStream, 1, ); + expect(getSessionModelFacingUserTurnCount(session)).toBe(1); }); it('stops tool response follow-up before sending when the session token limit is exceeded', async () => { @@ -1449,6 +1900,7 @@ describe('Session', () => { }), ], }); + expect(getSessionModelFacingUserTurnCount(session)).toBe(1); expect(mockClient.sessionUpdate).toHaveBeenCalledWith({ sessionId: 'test-session-id', update: { @@ -1515,6 +1967,81 @@ describe('Session', () => { sendMessageStream, 1, ); + expect(getSessionModelFacingUserTurnCount(session)).toBe(2); + }); + + it('keeps model-facing turn count when Stop-hook continuation stream throws after dispatch', async () => { + const messageBus = { + request: vi.fn().mockResolvedValueOnce({ + success: true, + output: { + decision: 'block', + reason: 'Continue after Stop hook', + }, + }), + }; + mockConfig.getMessageBus = vi.fn().mockReturnValue(messageBus); + mockConfig.getDisableAllHooks = vi.fn().mockReturnValue(false); + mockConfig.hasHooksForEvent = vi + .fn() + .mockImplementation((eventName: string) => eventName === 'Stop'); + mockChat.getHistory = vi + .fn() + .mockReturnValue([ + { role: 'model', parts: [{ text: 'response text' }] }, + ]); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValueOnce(createEmptyStream()) + .mockResolvedValueOnce( + createThrowingStream(new Error('stop stream failed')), + ); + + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }), + ).rejects.toThrow('stop stream failed'); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); + expect(getSessionModelFacingUserTurnCount(session)).toBe(2); + }); + + it('rolls back model-facing turn count when Stop-hook continuation send fails before dispatch', async () => { + const messageBus = { + request: vi.fn().mockResolvedValueOnce({ + success: true, + output: { + decision: 'block', + reason: 'Continue after Stop hook', + }, + }), + }; + mockConfig.getMessageBus = vi.fn().mockReturnValue(messageBus); + mockConfig.getDisableAllHooks = vi.fn().mockReturnValue(false); + mockConfig.hasHooksForEvent = vi + .fn() + .mockImplementation((eventName: string) => eventName === 'Stop'); + mockChat.getHistory = vi + .fn() + .mockReturnValue([ + { role: 'model', parts: [{ text: 'response text' }] }, + ]); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValueOnce(createEmptyStream()) + .mockRejectedValueOnce(new Error('stop send failed before dispatch')); + + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }), + ).rejects.toThrow('stop send failed before dispatch'); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); + expect(getSessionModelFacingUserTurnCount(session)).toBe(1); }); it('skips automatic compression after the first Stop-hook continuation', async () => { @@ -1640,6 +2167,7 @@ describe('Session', () => { expect.any(AbortSignal), ); expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1); + expect(getSessionModelFacingUserTurnCount(session)).toBe(1); expect(mockClient.sessionUpdate).toHaveBeenCalledWith({ sessionId: 'test-session-id', update: { @@ -1703,6 +2231,90 @@ describe('Session', () => { ); }); + it('keeps model-facing turn count when cron stream throws after dispatch', async () => { + let cronCallback: ((job: { prompt: string }) => void) | undefined; + const scheduler = { + size: 1, + start: vi.fn((callback: (job: { prompt: string }) => void) => { + cronCallback = callback; + }), + stop: vi.fn(), + getExitSummary: vi.fn().mockReturnValue(undefined), + }; + mockConfig.isCronEnabled = vi.fn().mockReturnValue(true); + mockConfig.getCronScheduler = vi.fn().mockReturnValue(scheduler); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValueOnce(createEmptyStream()) + .mockResolvedValueOnce( + createThrowingStream(new Error('cron stream failed')), + ); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }); + + cronCallback?.({ prompt: 'scheduled prompt' }); + + await vi.waitFor(() => { + expect(mockClient.sessionUpdate).toHaveBeenCalledWith({ + sessionId: 'test-session-id', + update: { + sessionUpdate: 'agent_message_chunk', + content: { + type: 'text', + text: '[cron error] cron stream failed', + }, + }, + }); + }); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); + expect(getSessionModelFacingUserTurnCount(session)).toBe(2); + }); + + it('rolls back model-facing turn count when cron send fails before dispatch', async () => { + let cronCallback: ((job: { prompt: string }) => void) | undefined; + const scheduler = { + size: 1, + start: vi.fn((callback: (job: { prompt: string }) => void) => { + cronCallback = callback; + }), + stop: vi.fn(), + getExitSummary: vi.fn().mockReturnValue(undefined), + }; + mockConfig.isCronEnabled = vi.fn().mockReturnValue(true); + mockConfig.getCronScheduler = vi.fn().mockReturnValue(scheduler); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValueOnce(createEmptyStream()) + .mockRejectedValueOnce(new Error('cron send failed before dispatch')); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }); + + cronCallback?.({ prompt: 'scheduled prompt' }); + + await vi.waitFor(() => { + expect(mockClient.sessionUpdate).toHaveBeenCalledWith({ + sessionId: 'test-session-id', + update: { + sessionUpdate: 'agent_message_chunk', + content: { + type: 'text', + text: '[cron error] cron send failed before dispatch', + }, + }, + }); + }); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); + expect(getSessionModelFacingUserTurnCount(session)).toBe(1); + }); + it('stops cron-fired ACP prompt before sending when the session token limit is exceeded', async () => { let cronCallback: ((job: { prompt: string }) => void) | undefined; const scheduler = { diff --git a/packages/cli/src/acp-integration/session/Session.ts b/packages/cli/src/acp-integration/session/Session.ts index 624a3554cb..2f6b4d0de5 100644 --- a/packages/cli/src/acp-integration/session/Session.ts +++ b/packages/cli/src/acp-integration/session/Session.ts @@ -54,7 +54,6 @@ import { getPlanModeSystemReminder, getSubagentSystemReminder, getArenaSystemReminder, - STARTUP_CONTEXT_MODEL_ACK, evaluatePermissionFlow, needsConfirmation, isPlanModeBlocked, @@ -100,6 +99,12 @@ import { import { isSlashCommand } from '../../ui/utils/commandUtils.js'; import { CommandKind } from '../../ui/commands/types.js'; import { parseAcpModelOption } from '../../utils/acpModelUtils.js'; +import { + getApiUserTextIndices, + hasCompressionSummaryPair, + hasStartupContext, + isApiUserTextContent, +} from '../../utils/apiHistoryUtils.js'; import { classifyApiError } from '../../ui/hooks/useGeminiStream.js'; import { getPersistScopeForModelSelection } from '../../config/modelProvidersScope.js'; @@ -129,6 +134,49 @@ type AutoCompressionSendResult = | { responseStream: AsyncGenerator; stopReason?: never } | { responseStream: null; stopReason: PromptResponse['stopReason'] }; +export interface HistorySnapshot { + history: Content[]; + modelFacingUserTurnCount: number; +} + +function computeVisibleModelFacingUserTurnCount(apiHistory: Content[]): number { + const startIndex = hasStartupContext(apiHistory) ? 2 : 0; + if (hasCompressionSummaryPair(apiHistory, startIndex)) { + const tailCount = getApiUserTextIndices( + apiHistory, + startIndex + 2, + true, + ).length; + // Legacy Content[] snapshots do not carry the exact number of turns that + // the compression summary represents. Count it as one compressed turn so + // rewind never treats absorbed history as the first visible tail turn. + return tailCount + 1; + } + return getApiUserTextIndices(apiHistory, startIndex, true).length; +} + +function validateModelFacingUserTurnCount(count: unknown): number { + if (typeof count !== 'number' || !Number.isInteger(count)) { + throw RequestError.invalidParams( + undefined, + `modelFacingUserTurnCount must be an integer, got ${typeof count}`, + ); + } + if (!Number.isFinite(count) || count < 0) { + throw RequestError.invalidParams( + undefined, + `modelFacingUserTurnCount must be a non-negative finite integer, got ${count}`, + ); + } + if (count > Number.MAX_SAFE_INTEGER) { + throw RequestError.invalidParams( + undefined, + `modelFacingUserTurnCount exceeds maximum safe integer, got ${count}`, + ); + } + return count; +} + export function computeInitialTurnFromHistory( records: ChatRecord[], sessionId: string, @@ -159,6 +207,16 @@ export function computeInitialTurnFromHistory( return maxPromptTurn > 0 ? maxPromptTurn : userMessageCount; } +export function computeInitialModelFacingUserTurnCountFromHistory( + records: ChatRecord[], + sessionId: string, +): number { + return records.filter( + (record) => + record.sessionId === sessionId && isModelFacingUserPromptRecord(record), + ).length; +} + function getRecordPromptIds(record: ChatRecord): string[] { const promptIds: string[] = []; const recordPromptId = (record as { promptId?: unknown }).promptId; @@ -195,6 +253,35 @@ function isUserPromptRecord(record: ChatRecord): boolean { ); } +function isModelFacingUserPromptRecord(record: ChatRecord): boolean { + if (record.type !== 'user') { + return false; + } + if ( + record.subtype === 'notification' || + record.subtype === 'mid_turn_user_message' + ) { + return false; + } + const textParts = + record.message?.parts + ?.filter( + (part): part is { text: string } & Part => + 'text' in part && + typeof part.text === 'string' && + part.text.trim().length > 0, + ) + .map((part) => part.text.trim()) ?? []; + if (textParts.length === 0) { + return false; + } + if (record.subtype === 'cron') { + return true; + } + const fullText = textParts.join(' '); + return !fullText.startsWith('?') && !isSlashCommand(fullText); +} + export interface AvailableCommandsSnapshot { availableCommands: AvailableCommand[]; availableSkills?: string[]; @@ -264,6 +351,7 @@ export class Session implements SessionContext { */ private pendingPromptCompletion: Promise | null = null; private turn: number = 0; + private modelFacingUserTurnCount: number = 0; private readonly runtimeBaseDir: string; // Cron scheduling state @@ -351,6 +439,13 @@ export class Session implements SessionContext { this.turn, computeInitialTurnFromHistory(records, this.config.getSessionId()), ); + this.modelFacingUserTurnCount = Math.max( + this.modelFacingUserTurnCount, + computeInitialModelFacingUserTurnCountFromHistory( + records, + this.config.getSessionId(), + ), + ); await this.historyReplayer.replay(records); } @@ -388,6 +483,9 @@ export class Session implements SessionContext { chat.truncateHistory(apiTruncateIndex); chat.stripThoughtsFromHistory(); + // targetTurnIndex is zero-based; after truncating before that turn, + // exactly targetTurnIndex model-facing user turns remain. + this.modelFacingUserTurnCount = targetTurnIndex; this.config.getChatRecordingService()?.rewindRecording(targetTurnIndex, { truncatedCount: Math.max(0, apiHistory.length - apiTruncateIndex), @@ -396,11 +494,14 @@ export class Session implements SessionContext { return { targetTurnIndex, apiTruncateIndex }; } - captureHistorySnapshot(): Content[] { - return this.config.getGeminiClient()!.getChat().getHistory(); + captureHistorySnapshot(): HistorySnapshot { + return { + history: this.config.getGeminiClient()!.getChat().getHistory(), + modelFacingUserTurnCount: this.modelFacingUserTurnCount, + }; } - restoreHistory(history: Content[]): void { + restoreHistory(snapshot: Content[] | HistorySnapshot): void { if (this.pendingPrompt || this.cronProcessing || this.cronAbortController) { throw RequestError.invalidParams( undefined, @@ -408,60 +509,73 @@ export class Session implements SessionContext { ); } + const history = Array.isArray(snapshot) ? snapshot : snapshot.history; + if (history.length === 0) { + throw RequestError.invalidParams( + undefined, + 'Cannot restore an empty history snapshot', + ); + } this.config .getGeminiClient()! .getChat() .setHistory(structuredClone(history)); + this.modelFacingUserTurnCount = Array.isArray(snapshot) + ? computeVisibleModelFacingUserTurnCount(history) + : validateModelFacingUserTurnCount(snapshot.modelFacingUserTurnCount); } #computeApiTruncationIndexForUserTurn( apiHistory: Content[], targetTurnIndex: number, ): number { - const startIndex = this.#hasStartupContext(apiHistory) ? 2 : 0; + const startIndex = hasStartupContext(apiHistory) ? 2 : 0; - if (targetTurnIndex === 0) { - return startIndex; - } - - let realUserPromptCount = 0; - for (let i = startIndex; i < apiHistory.length; i++) { - if (!this.#isUserTextContent(apiHistory[i]!)) { - continue; + if (hasCompressionSummaryPair(apiHistory, startIndex)) { + const apiTailUserIndices = getApiUserTextIndices( + apiHistory, + startIndex + 2, + true, + ); + if (this.modelFacingUserTurnCount < targetTurnIndex + 1) { + debugLogger.warn( + `Cannot rewind to user turn ${targetTurnIndex}; ` + + `model-facing user turn count is ${this.modelFacingUserTurnCount}.`, + ); + return -1; } + const totalUserTurns = this.modelFacingUserTurnCount; + const compressedTurnCount = Math.max( + 0, + totalUserTurns - apiTailUserIndices.length, + ); - if (realUserPromptCount === targetTurnIndex) { - return i; + if (targetTurnIndex < compressedTurnCount) { + debugLogger.warn( + `Rewind to turn ${targetTurnIndex} rejected: ` + + `modelFacingUserTurnCount=${this.modelFacingUserTurnCount}, ` + + `compressedTurnCount=${compressedTurnCount}, ` + + `apiTailUserIndicesLength=${apiTailUserIndices.length}, ` + + `totalUserTurns=${totalUserTurns}, ` + + 'hasCompressionSummaryPair=true', + ); + return -1; } - realUserPromptCount += 1; + // Defensive: the guard above (targetTurnIndex < compressedTurnCount) + // should always prevent out-of-bounds access here, so ?? -1 is + // unreachable in normal operation. + return apiTailUserIndices[targetTurnIndex - compressedTurnCount] ?? -1; } - return -1; - } + if (targetTurnIndex === 0) { + return startIndex; + } - #hasStartupContext(apiHistory: Content[]): boolean { - if (apiHistory.length < 2) return false; - const first = apiHistory[0]; - const second = apiHistory[1]; - if (first?.role !== 'user' || second?.role !== 'model') return false; return ( - second.parts?.some( - (part) => 'text' in part && part.text === STARTUP_CONTEXT_MODEL_ACK, - ) ?? false - ); - } - - #isUserTextContent(content: Content): boolean { - if (content.role !== 'user') return false; - if (!content.parts || content.parts.length === 0) return false; - - const hasFunctionResponse = content.parts.some( - (part) => 'functionResponse' in part, + getApiUserTextIndices(apiHistory, startIndex, false)[targetTurnIndex] ?? + -1 ); - if (hasFunctionResponse) return false; - - return content.parts.some((part) => 'text' in part && part.text); } async cancelPendingPrompt(): Promise { @@ -694,20 +808,28 @@ export class Session implements SessionContext { while (nextMessage !== null) { if (pendingSend.signal.aborted) { this.#getCurrentChat().addHistory(nextMessage); + this.#recordModelFacingUserTurn(nextMessage); return { stopReason: 'cancelled' }; } const functionCalls: FunctionCall[] = []; let usageMetadata: GenerateContentResponseUsageMetadata | null = null; const streamStartTime = Date.now(); + let recordedModelFacingTurn = false; + let sendDispatched = false; try { + recordedModelFacingTurn = + this.#recordModelFacingUserTurn(nextMessage); const sendResult = await this.#sendMessageStreamWithAutoCompression( promptId, nextMessage?.parts ?? [], pendingSend.signal, ); if (!sendResult.responseStream) { + if (sendResult.stopReason !== 'cancelled') { + this.#rollbackModelFacingUserTurn(recordedModelFacingTurn); + } this.#preserveUnsentMessageHistory( nextMessage, sendResult.stopReason === 'cancelled', @@ -715,6 +837,7 @@ export class Session implements SessionContext { return { stopReason: sendResult.stopReason }; } const responseStream = sendResult.responseStream; + sendDispatched = true; nextMessage = null; for await (const resp of responseStream) { @@ -756,6 +879,10 @@ export class Session implements SessionContext { } } } catch (error) { + if (!sendDispatched) { + this.#rollbackModelFacingUserTurn(recordedModelFacingTurn); + } + // Fire StopFailure hook (fire-and-forget, replaces Stop event for API errors) // Aligned with useGeminiStream.ts handleFinishedWithErrorEvent const errorStatus = getErrorStatus(error); @@ -959,8 +1086,12 @@ export class Session implements SessionContext { const functionCalls: FunctionCall[] = []; let usageMetadata: GenerateContentResponseUsageMetadata | null = null; const streamStartTime = Date.now(); + let recordedModelFacingTurn = false; + let sendDispatched = false; try { + recordedModelFacingTurn = + this.#recordModelFacingUserTurn(nextMessage); const continueSendResult = await this.#sendMessageStreamWithAutoCompression( promptId + '_stop_hook_' + stopHookIterationCount, @@ -969,6 +1100,9 @@ export class Session implements SessionContext { { skipCompression: stopHookIterationCount > 1 }, ); if (!continueSendResult.responseStream) { + if (continueSendResult.stopReason !== 'cancelled') { + this.#rollbackModelFacingUserTurn(recordedModelFacingTurn); + } this.#preserveUnsentMessageHistory( nextMessage, continueSendResult.stopReason === 'cancelled', @@ -976,6 +1110,7 @@ export class Session implements SessionContext { return { stopReason: continueSendResult.stopReason }; } const continueResponseStream = continueSendResult.responseStream; + sendDispatched = true; nextMessage = null; for await (const resp of continueResponseStream) { @@ -1014,6 +1149,10 @@ export class Session implements SessionContext { } } } catch (error) { + if (!sendDispatched) { + this.#rollbackModelFacingUserTurn(recordedModelFacingTurn); + } + // Fire StopFailure hook (fire-and-forget) const errorStatus = getErrorStatus(error); const errorMessage = @@ -1093,6 +1232,23 @@ export class Session implements SessionContext { return this.config.getGeminiClient()!.getChat(); } + #recordModelFacingUserTurn(message: Content): boolean { + if (isApiUserTextContent(message)) { + this.modelFacingUserTurnCount += 1; + return true; + } + return false; + } + + #rollbackModelFacingUserTurn(recorded: boolean): void { + if (recorded) { + this.modelFacingUserTurnCount = Math.max( + 0, + this.modelFacingUserTurnCount - 1, + ); + } + } + /** * Mirrors the core send path for ACP model sends. * @@ -1396,6 +1552,9 @@ export class Session implements SessionContext { const promptId = this.config.getSessionId() + '########cron' + Date.now(); + let recordedModelFacingTurn = false; + let sendDispatched = false; + try { // Echo the cron prompt as a user message so the client sees it await this.sendUpdate({ @@ -1413,6 +1572,8 @@ export class Session implements SessionContext { }; while (nextMessage !== null) { + recordedModelFacingTurn = false; + sendDispatched = false; if (ac.signal.aborted) return; const functionCalls: FunctionCall[] = []; @@ -1420,12 +1581,17 @@ export class Session implements SessionContext { null; const streamStartTime = Date.now(); + recordedModelFacingTurn = + this.#recordModelFacingUserTurn(nextMessage); const sendResult = await this.#sendMessageStreamWithAutoCompression( promptId, nextMessage.parts ?? [], ac.signal, ); if (!sendResult.responseStream) { + if (sendResult.stopReason !== 'cancelled') { + this.#rollbackModelFacingUserTurn(recordedModelFacingTurn); + } this.#preserveUnsentMessageHistory( nextMessage, sendResult.stopReason === 'cancelled', @@ -1436,6 +1602,7 @@ export class Session implements SessionContext { return; } const responseStream = sendResult.responseStream; + sendDispatched = true; nextMessage = null; for await (const resp of responseStream) { @@ -1496,6 +1663,10 @@ export class Session implements SessionContext { } } } catch (error) { + if (!sendDispatched) { + this.#rollbackModelFacingUserTurn(recordedModelFacingTurn); + } + if (ac.signal.aborted) return; debugLogger.error('Error processing cron prompt:', error); const msg = error instanceof Error ? error.message : String(error); diff --git a/packages/cli/src/ui/commands/directoryCommand.test.tsx b/packages/cli/src/ui/commands/directoryCommand.test.tsx index 23421ad2b1..3814b44167 100644 --- a/packages/cli/src/ui/commands/directoryCommand.test.tsx +++ b/packages/cli/src/ui/commands/directoryCommand.test.tsx @@ -351,10 +351,7 @@ describe('getDirPathCompletions', () => { fs.mkdirSync(path.join(tempTestDir, 'sub1', 'deep'), { recursive: true }); // Add some non-directory files (should be filtered out) fs.writeFileSync(path.join(tempTestDir, 'file.txt'), ''); - fs.writeFileSync( - path.join(tempTestDir, 'sub1', 'nested.txt'), - '', - ); + fs.writeFileSync(path.join(tempTestDir, 'sub1', 'nested.txt'), ''); }); afterAll(() => { diff --git a/packages/cli/src/ui/commands/directoryCommand.tsx b/packages/cli/src/ui/commands/directoryCommand.tsx index 1919e8c113..59e8837fcf 100644 --- a/packages/cli/src/ui/commands/directoryCommand.tsx +++ b/packages/cli/src/ui/commands/directoryCommand.tsx @@ -4,7 +4,11 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { SlashCommand, CommandContext, CommandCompletionItem } from './types.js'; +import type { + SlashCommand, + CommandContext, + CommandCompletionItem, +} from './types.js'; import { CommandKind } from './types.js'; import { MessageType } from '../types.js'; import * as fs from 'node:fs'; @@ -58,7 +62,9 @@ function findExistingWorkspaceDirectory( * Returns directory path completions for the given partial argument. * Supports comma-separated paths by completing only the last segment. */ -export function getDirPathCompletions(partialArg: string): CommandCompletionItem[] { +export function getDirPathCompletions( + partialArg: string, +): CommandCompletionItem[] { const lastComma = partialArg.lastIndexOf(','); const prefix = lastComma >= 0 ? partialArg.substring(0, lastComma + 1) : ''; const partial = diff --git a/packages/cli/src/ui/commands/skillsCommand.ts b/packages/cli/src/ui/commands/skillsCommand.ts index 192772a7be..7a93a78509 100644 --- a/packages/cli/src/ui/commands/skillsCommand.ts +++ b/packages/cli/src/ui/commands/skillsCommand.ts @@ -63,8 +63,7 @@ export const skillsCommand: SlashCommand = { const sortedSkills = [...skills].sort( (a, b) => normalizeSkillPriority(b.priority) - - normalizeSkillPriority(a.priority) || - a.name.localeCompare(b.name), + normalizeSkillPriority(a.priority) || a.name.localeCompare(b.name), ); const skillsListItem: HistoryItemSkillsList = { type: MessageType.SKILLS_LIST, diff --git a/packages/cli/src/ui/hooks/useCommandCompletion.test.ts b/packages/cli/src/ui/hooks/useCommandCompletion.test.ts index f0b22e3e88..527864fe9f 100644 --- a/packages/cli/src/ui/hooks/useCommandCompletion.test.ts +++ b/packages/cli/src/ui/hooks/useCommandCompletion.test.ts @@ -599,7 +599,11 @@ describe('useCommandCompletion', () => { it('should not append trailing space for directory completions', async () => { setupMocks({ atSuggestions: [ - { label: 'src/components/', value: 'src/components/', isDirectory: true }, + { + label: 'src/components/', + value: 'src/components/', + isDirectory: true, + }, ], }); @@ -696,9 +700,7 @@ describe('useCommandCompletion', () => { result.current.handleAutocomplete(0); }); - expect(result.current.textBuffer.text).toBe( - '@src/components/ is a dir', - ); + expect(result.current.textBuffer.text).toBe('@src/components/ is a dir'); }); }); diff --git a/packages/cli/src/ui/hooks/useCommandCompletion.tsx b/packages/cli/src/ui/hooks/useCommandCompletion.tsx index 09ddfc3a7f..5a45b8d7dd 100644 --- a/packages/cli/src/ui/hooks/useCommandCompletion.tsx +++ b/packages/cli/src/ui/hooks/useCommandCompletion.tsx @@ -229,7 +229,10 @@ export function useCommandCompletion( const lineCodePoints = toCodePoints(buffer.lines[cursorRow] || ''); const charAfterCompletion = lineCodePoints[end]; const isDirectory = suggestions[indexToUse].isDirectory; - if (charAfterCompletion !== ' ' && !(isDirectory && !charAfterCompletion)) { + if ( + charAfterCompletion !== ' ' && + !(isDirectory && !charAfterCompletion) + ) { suggestionText += ' '; } diff --git a/packages/cli/src/ui/hooks/useSlashCompletion.test.ts b/packages/cli/src/ui/hooks/useSlashCompletion.test.ts index 733efbc5ae..23042d712d 100644 --- a/packages/cli/src/ui/hooks/useSlashCompletion.test.ts +++ b/packages/cli/src/ui/hooks/useSlashCompletion.test.ts @@ -1125,10 +1125,12 @@ describe('useSlashCompletion', () => { describe('isDirectory propagation', () => { it('should propagate isDirectory from CommandCompletionItem to Suggestion', async () => { - const mockCompletionFn = vi.fn().mockResolvedValue([ - { value: '/tmp/workspace/', isDirectory: true }, - { value: '/tmp/file.txt' }, - ]); + const mockCompletionFn = vi + .fn() + .mockResolvedValue([ + { value: '/tmp/workspace/', isDirectory: true }, + { value: '/tmp/file.txt' }, + ]); const slashCommands = [ createTestCommand({ diff --git a/packages/cli/src/ui/utils/historyMapping.test.ts b/packages/cli/src/ui/utils/historyMapping.test.ts index 8f6426a6d9..f5af23ac78 100644 --- a/packages/cli/src/ui/utils/historyMapping.test.ts +++ b/packages/cli/src/ui/utils/historyMapping.test.ts @@ -8,6 +8,11 @@ import { describe, it, expect } from 'vitest'; import { computeApiTruncationIndex, isRealUserTurn } from './historyMapping.js'; import type { HistoryItem } from '../types.js'; import type { Content, Part } from '@google/genai'; +import { + COMPRESSION_CONTINUATION_BRIDGE, + COMPRESSION_SUMMARY_MODEL_ACK, + STARTUP_CONTEXT_MODEL_ACK, +} from '@qwen-code/qwen-code-core'; // --------------------------------------------------------------------------- // Helpers @@ -21,6 +26,17 @@ function modelContent(text: string): Content { return { role: 'model', parts: [{ text } as Part] }; } +function functionCallContent(): Content { + return { + role: 'model', + parts: [ + { + functionCall: { name: 'tool', args: {} }, + } as unknown as Part, + ], + }; +} + function functionResponseContent(): Content { return { role: 'user', @@ -35,7 +51,7 @@ function functionResponseContent(): Content { function startupPair(): [Content, Content] { return [ userContent('Environment context...'), - modelContent('Got it. Thanks for the context!'), + modelContent(STARTUP_CONTEXT_MODEL_ACK), ]; } @@ -58,6 +74,16 @@ describe('computeApiTruncationIndex', () => { expect(computeApiTruncationIndex(ui, 1, api)).toBe(0); }); + it('returns -1 when the target user item is absent', () => { + const ui: HistoryItem[] = [userItem(1), geminiItem(2)]; + const api: Content[] = [ + userContent('prompt 1'), + modelContent('response 1'), + ]; + + expect(computeApiTruncationIndex(ui, 99, api)).toBe(-1); + }); + describe('without startup context', () => { it('rewinds to the first user turn (keep nothing)', () => { const ui: HistoryItem[] = [ @@ -169,6 +195,153 @@ describe('computeApiTruncationIndex', () => { }); describe('compression fallback', () => { + it('maps tail user turns after a compression summary pair', () => { + const ui: HistoryItem[] = [ + userItem(1), + geminiItem(2), + userItem(3), + geminiItem(4), + userItem(5), + geminiItem(6), + ]; + const api: Content[] = [ + userContent('compressed summary of prompt 1 and prompt 3'), + modelContent(COMPRESSION_SUMMARY_MODEL_ACK), + userContent('prompt 5'), + modelContent('response 5'), + ]; + + expect(computeApiTruncationIndex(ui, 5, api)).toBe(2); + }); + + it('keeps compressed turns unreachable after a compression summary pair', () => { + const ui: HistoryItem[] = [ + userItem(1), + geminiItem(2), + userItem(3), + geminiItem(4), + userItem(5), + geminiItem(6), + ]; + const api: Content[] = [ + userContent('compressed summary of prompt 1 and prompt 3'), + modelContent(COMPRESSION_SUMMARY_MODEL_ACK), + userContent('prompt 5'), + modelContent('response 5'), + ]; + + expect(computeApiTruncationIndex(ui, 3, api)).toBe(-1); + }); + + it('keeps the first UI turn unreachable when compression absorbed it', () => { + const ui: HistoryItem[] = [ + userItem(1), + geminiItem(2), + userItem(3), + geminiItem(4), + userItem(5), + geminiItem(6), + ]; + const api: Content[] = [ + userContent('compressed summary of prompt 1 and prompt 3'), + modelContent(COMPRESSION_SUMMARY_MODEL_ACK), + userContent('prompt 5'), + modelContent('response 5'), + ]; + + expect(computeApiTruncationIndex(ui, 1, api)).toBe(-1); + }); + + it('maps compressed tail turns after startup context', () => { + const ui: HistoryItem[] = [ + userItem(1), + geminiItem(2), + userItem(3), + geminiItem(4), + userItem(5), + geminiItem(6), + ]; + const api: Content[] = [ + ...startupPair(), + userContent('compressed summary of prompt 1 and prompt 3'), + modelContent(COMPRESSION_SUMMARY_MODEL_ACK), + userContent('prompt 5'), + modelContent('response 5'), + ]; + + expect(computeApiTruncationIndex(ui, 5, api)).toBe(4); + }); + + it('ignores the compression continuation bridge when mapping tail turns', () => { + const ui: HistoryItem[] = [ + userItem(1), + geminiItem(2), + userItem(3), + geminiItem(4), + userItem(5), + geminiItem(6), + ]; + const api: Content[] = [ + userContent('compressed summary of prompt 1 and prompt 3'), + modelContent(COMPRESSION_SUMMARY_MODEL_ACK), + userContent(COMPRESSION_CONTINUATION_BRIDGE), + modelContent('continued response'), + userContent('prompt 5'), + modelContent('response 5'), + ]; + + expect(computeApiTruncationIndex(ui, 5, api)).toBe(4); + }); + + it('does not skip a real user prompt with the visible bridge text', () => { + const visibleBridgeText = + 'Continue with the prior task using the context above.'; + const ui: HistoryItem[] = [ + userItem(1), + geminiItem(2), + userItem(3), + geminiItem(4), + userItem(5, visibleBridgeText), + geminiItem(6), + userItem(7), + geminiItem(8), + ]; + const api: Content[] = [ + userContent('compressed summary of prompt 1 and prompt 3'), + modelContent(COMPRESSION_SUMMARY_MODEL_ACK), + userContent(visibleBridgeText), + modelContent('continued response'), + userContent('prompt 7'), + modelContent('response 7'), + ]; + + expect(computeApiTruncationIndex(ui, 5, api)).toBe(2); + expect(computeApiTruncationIndex(ui, 7, api)).toBe(4); + }); + + it('ignores tool-call entries between the bridge and next tail prompt', () => { + const ui: HistoryItem[] = [ + userItem(1), + geminiItem(2), + userItem(3), + geminiItem(4), + userItem(5), + geminiItem(6), + ]; + const api: Content[] = [ + userContent('compressed summary of prompt 1 and prompt 3'), + modelContent(COMPRESSION_SUMMARY_MODEL_ACK), + userContent(COMPRESSION_CONTINUATION_BRIDGE), + functionCallContent(), + functionResponseContent(), + modelContent('tool result response'), + userContent('prompt 5'), + modelContent('response 5'), + ]; + + expect(computeApiTruncationIndex(ui, 5, api)).toBe(6); + }); + it('returns -1 when not enough user prompts found', () => { const ui: HistoryItem[] = [ userItem(1), diff --git a/packages/cli/src/ui/utils/historyMapping.ts b/packages/cli/src/ui/utils/historyMapping.ts index f0d9c12e8d..4d1c8bb181 100644 --- a/packages/cli/src/ui/utils/historyMapping.ts +++ b/packages/cli/src/ui/utils/historyMapping.ts @@ -6,8 +6,15 @@ import type { HistoryItem, HistoryItemUser } from '../types.js'; import type { Content } from '@google/genai'; -import { STARTUP_CONTEXT_MODEL_ACK } from '@qwen-code/qwen-code-core'; +import { createDebugLogger } from '@qwen-code/qwen-code-core'; import { isSlashCommand } from './commandUtils.js'; +import { + getApiUserTextIndices, + hasCompressionSummaryPair, + hasStartupContext, +} from '../../utils/apiHistoryUtils.js'; + +const debugLogger = createDebugLogger('HISTORY_MAPPING'); /** * Returns true when the history item represents a real user prompt that was @@ -26,36 +33,23 @@ export function isRealUserTurn( return !isSlashCommand(item.text) && !item.text.startsWith('?'); } -/** - * Checks if a Content entry is a user-initiated text prompt - * as opposed to a tool result (functionResponse). - */ -function isUserTextContent(content: Content): boolean { - if (content.role !== 'user') return false; - if (!content.parts || content.parts.length === 0) return false; +function getUiTurnOrdinals( + uiHistory: HistoryItem[], + targetUserItemId: number, +): { targetOrdinal: number; totalRealUserTurns: number } { + let targetOrdinal = -1; + let totalRealUserTurns = 0; - const hasFunctionResponse = content.parts.some( - (part) => 'functionResponse' in part, - ); - if (hasFunctionResponse) return false; + for (const item of uiHistory) { + if (!isRealUserTurn(item)) continue; - return content.parts.some((part) => 'text' in part && part.text); -} + totalRealUserTurns++; + if (item.id === targetUserItemId) { + targetOrdinal = totalRealUserTurns; + } + } -/** - * Detects whether the API history starts with the startup context pair - * (user env context + model acknowledgment). - */ -function hasStartupContext(apiHistory: Content[]): boolean { - if (apiHistory.length < 2) return false; - const first = apiHistory[0]; - const second = apiHistory[1]; - if (first?.role !== 'user' || second?.role !== 'model') return false; - return ( - second.parts?.some( - (part) => 'text' in part && part.text === STARTUP_CONTEXT_MODEL_ACK, - ) ?? false - ); + return { targetOrdinal, totalRealUserTurns }; } /** @@ -88,39 +82,58 @@ export function computeApiTruncationIndex( targetUserItemId: number, apiHistory: Content[], ): number { - // Count how many UI user turns exist before the target - let uiUserTurnCount = 0; - for (const item of uiHistory) { - if (item.id === targetUserItemId) { - break; - } - if (isRealUserTurn(item)) { - uiUserTurnCount++; - } - } + const { targetOrdinal, totalRealUserTurns } = getUiTurnOrdinals( + uiHistory, + targetUserItemId, + ); + + if (targetOrdinal < 0) return -1; // Determine the starting index in the API history (skip startup context) const startIndex = hasStartupContext(apiHistory) ? 2 : 0; - if (uiUserTurnCount === 0) { + if (hasCompressionSummaryPair(apiHistory, startIndex)) { + // Compression replaces the oldest N UI turns with one synthetic + // summary user entry plus a fixed model acknowledgment. The remaining + // API user-text entries are the uncompressed tail, so align that tail + // against the end of the UI turn list instead of counting from the front. + const apiTailUserIndices = getApiUserTextIndices( + apiHistory, + startIndex + 2, + true, + ); + const compressedTurnCount = Math.max( + 0, + totalRealUserTurns - apiTailUserIndices.length, + ); + + if (targetOrdinal <= compressedTurnCount) { + debugLogger.info( + `Rewind target turn ${targetOrdinal} is unreachable: compressed ${compressedTurnCount} of ${totalRealUserTurns} total turns, tail has ${apiTailUserIndices.length} entries`, + ); + return -1; + } + + // Defensive: the guard above (targetOrdinal <= compressedTurnCount) + // should always prevent out-of-bounds access here, so ?? -1 is + // unreachable in normal operation. + return apiTailUserIndices[targetOrdinal - compressedTurnCount - 1] ?? -1; + } + + if (targetOrdinal === 1) { // Rewinding to the first user turn: keep only startup context (if any) return startIndex; } // Walk the API history from after the startup context, counting // user text prompts to find the one corresponding to the target turn. - let realUserPromptCount = 0; - - for (let i = startIndex; i < apiHistory.length; i++) { - if (isUserTextContent(apiHistory[i]!)) { - realUserPromptCount++; - // The target turn is the (uiUserTurnCount + 1)th real user prompt. - // We want to truncate right before it. - if (realUserPromptCount > uiUserTurnCount) { - return i; - } - } - } + const apiUserTextIndices = getApiUserTextIndices( + apiHistory, + startIndex, + false, + ); + const targetApiIndex = apiUserTextIndices[targetOrdinal - 1]; + if (targetApiIndex !== undefined) return targetApiIndex; // If we didn't find enough user prompts (e.g., after compression), // signal that the target turn is unreachable. diff --git a/packages/cli/src/utils/apiHistoryUtils.test.ts b/packages/cli/src/utils/apiHistoryUtils.test.ts new file mode 100644 index 0000000000..14b1aca18e --- /dev/null +++ b/packages/cli/src/utils/apiHistoryUtils.test.ts @@ -0,0 +1,337 @@ +/** + * @license + * Copyright 2025 Qwen Code + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import type { Content, Part } from '@google/genai'; +import { + COMPRESSION_CONTINUATION_BRIDGE, + COMPRESSION_CONTINUATION_BRIDGE_MARKER, + COMPRESSION_SUMMARY_MODEL_ACK, + STARTUP_CONTEXT_MODEL_ACK, +} from '@qwen-code/qwen-code-core'; +import { + hasTextPart, + hasModelTextPart, + isApiUserTextContent, + hasCompressionSummaryPair, + getApiUserTextIndices, + hasStartupContext, + isCompressionContinuationBridge, +} from './apiHistoryUtils.js'; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function userTextContent(text: string): Content { + return { role: 'user', parts: [{ text } as Part] }; +} + +function modelTextContent(text: string): Content { + return { role: 'model', parts: [{ text } as Part] }; +} + +function functionResponseContent(): Content { + return { + role: 'user', + parts: [ + { + functionResponse: { name: 'tool', response: { result: 'ok' } }, + } as unknown as Part, + ], + }; +} + +function functionCallContent(): Content { + return { + role: 'model', + parts: [{ functionCall: { name: 'tool', args: {} } } as unknown as Part], + }; +} + +// --------------------------------------------------------------------------- +// hasTextPart +// --------------------------------------------------------------------------- + +describe('hasTextPart', () => { + it('returns true when content has a text part matching exactly', () => { + expect(hasTextPart(userTextContent('hello'), 'hello')).toBe(true); + }); + + it('returns false when text does not match', () => { + expect(hasTextPart(userTextContent('hello'), 'world')).toBe(false); + }); + + it('returns false for undefined content', () => { + expect(hasTextPart(undefined, 'hello')).toBe(false); + }); + + it('returns false when parts is undefined', () => { + expect(hasTextPart({ role: 'user' }, 'hello')).toBe(false); + }); +}); + +// --------------------------------------------------------------------------- +// hasModelTextPart +// --------------------------------------------------------------------------- + +describe('hasModelTextPart', () => { + it('returns true when model content has matching text', () => { + expect(hasModelTextPart(modelTextContent('ack'), 'ack')).toBe(true); + }); + + it('returns false when role is not model', () => { + expect(hasModelTextPart(userTextContent('ack'), 'ack')).toBe(false); + }); + + it('returns false when text does not match', () => { + expect(hasModelTextPart(modelTextContent('ack'), 'other')).toBe(false); + }); +}); + +// --------------------------------------------------------------------------- +// isApiUserTextContent +// --------------------------------------------------------------------------- + +describe('isApiUserTextContent', () => { + it('returns true for user text content', () => { + expect(isApiUserTextContent(userTextContent('hello'))).toBe(true); + }); + + it('returns false for model content', () => { + expect(isApiUserTextContent(modelTextContent('hello'))).toBe(false); + }); + + it('returns false for functionResponse content', () => { + expect(isApiUserTextContent(functionResponseContent())).toBe(false); + }); + + it('returns false for empty parts', () => { + expect(isApiUserTextContent({ role: 'user', parts: [] })).toBe(false); + }); + + it('returns false for undefined parts', () => { + expect(isApiUserTextContent({ role: 'user' })).toBe(false); + }); + + it('returns false for functionCall content (model role)', () => { + expect(isApiUserTextContent(functionCallContent())).toBe(false); + }); + + it('rejects user content with no text (only functionResponse)', () => { + const content: Content = { + role: 'user', + parts: [ + { functionResponse: { name: 't', response: {} } } as unknown as Part, + { text: 'some text' } as Part, + ], + }; + expect(isApiUserTextContent(content)).toBe(false); + }); +}); + +// --------------------------------------------------------------------------- +// hasCompressionSummaryPair +// --------------------------------------------------------------------------- + +describe('hasCompressionSummaryPair', () => { + it('detects a compression summary pair', () => { + const history: Content[] = [ + userTextContent('summary text'), + modelTextContent(COMPRESSION_SUMMARY_MODEL_ACK), + ]; + expect(hasCompressionSummaryPair(history, 0)).toBe(true); + }); + + it('returns false when the ack text does not match', () => { + const history: Content[] = [ + userTextContent('summary text'), + modelTextContent('different ack'), + ]; + expect(hasCompressionSummaryPair(history, 0)).toBe(false); + }); + + it('returns false when startIndex is out of bounds', () => { + const history: Content[] = [userTextContent('only one')]; + expect(hasCompressionSummaryPair(history, 1)).toBe(false); + }); + + it('respects startIndex offset', () => { + const history: Content[] = [ + userTextContent('env context'), + modelTextContent(STARTUP_CONTEXT_MODEL_ACK), + userTextContent('summary'), + modelTextContent(COMPRESSION_SUMMARY_MODEL_ACK), + ]; + expect(hasCompressionSummaryPair(history, 0)).toBe(false); + expect(hasCompressionSummaryPair(history, 2)).toBe(true); + }); +}); + +// --------------------------------------------------------------------------- +// getApiUserTextIndices +// --------------------------------------------------------------------------- + +describe('getApiUserTextIndices', () => { + it('returns indices of all user text entries from startIndex', () => { + const history: Content[] = [ + userTextContent('first'), + modelTextContent('ack'), + userTextContent('second'), + modelTextContent('resp'), + userTextContent('third'), + ]; + expect(getApiUserTextIndices(history, 0, false)).toEqual([0, 2, 4]); + }); + + it('respects startIndex', () => { + const history: Content[] = [ + userTextContent('first'), + modelTextContent('ack'), + userTextContent('second'), + modelTextContent('resp'), + ]; + expect(getApiUserTextIndices(history, 2, false)).toEqual([2]); + }); + + it('skips functionResponse entries', () => { + const history: Content[] = [ + userTextContent('first'), + modelTextContent('resp'), + functionResponseContent(), + modelTextContent('resp2'), + userTextContent('second'), + ]; + expect(getApiUserTextIndices(history, 0, false)).toEqual([0, 4]); + }); + + describe('skipContinuationBridge', () => { + it('skips the compression continuation bridge', () => { + const history: Content[] = [ + userTextContent('summary'), + modelTextContent(COMPRESSION_SUMMARY_MODEL_ACK), + userTextContent(COMPRESSION_CONTINUATION_BRIDGE), + modelTextContent('continued'), + userTextContent('tail turn'), + ]; + const indices = getApiUserTextIndices(history, 0, true); + expect(indices).toEqual([0, 4]); + }); + + it('includes the bridge when skipContinuationBridge is false', () => { + const history: Content[] = [ + userTextContent('summary'), + modelTextContent(COMPRESSION_SUMMARY_MODEL_ACK), + userTextContent(COMPRESSION_CONTINUATION_BRIDGE), + modelTextContent('continued'), + userTextContent('tail turn'), + ]; + const indices = getApiUserTextIndices(history, 0, false); + expect(indices).toEqual([0, 2, 4]); + }); + + it('does not skip user prompts with same visible text but no sentinel', () => { + const visibleText = + 'Continue with the prior task using the context above.'; + const history: Content[] = [ + userTextContent('summary'), + modelTextContent(COMPRESSION_SUMMARY_MODEL_ACK), + userTextContent(visibleText), // no invisible prefix + userTextContent('tail turn'), + ]; + const indices = getApiUserTextIndices(history, 0, true); + // The visible text without sentinel is treated as a real user turn + expect(indices).toEqual([0, 2, 3]); + }); + }); +}); + +// --------------------------------------------------------------------------- +// hasStartupContext +// --------------------------------------------------------------------------- + +describe('hasStartupContext', () => { + it('detects the startup context pair', () => { + const history: Content[] = [ + userTextContent('Environment context...'), + modelTextContent(STARTUP_CONTEXT_MODEL_ACK), + ]; + expect(hasStartupContext(history)).toBe(true); + }); + + it('returns false for too-short history', () => { + expect(hasStartupContext([userTextContent('only one')])).toBe(false); + expect(hasStartupContext([])).toBe(false); + }); + + it('returns false when roles are wrong', () => { + const history: Content[] = [ + modelTextContent('not user'), + modelTextContent(STARTUP_CONTEXT_MODEL_ACK), + ]; + expect(hasStartupContext(history)).toBe(false); + }); + + it('returns false when ack text does not match', () => { + const history: Content[] = [ + userTextContent('Environment context...'), + modelTextContent('different ack'), + ]; + expect(hasStartupContext(history)).toBe(false); + }); +}); + +// --------------------------------------------------------------------------- +// isCompressionContinuationBridge +// --------------------------------------------------------------------------- + +describe('isCompressionContinuationBridge', () => { + it('detects the synthetic bridge by sentinel marker prefix', () => { + const bridge: Content = { + role: 'user', + parts: [{ text: COMPRESSION_CONTINUATION_BRIDGE } as Part], + }; + expect(isCompressionContinuationBridge(bridge)).toBe(true); + }); + + it('returns false for a real user prompt with identical visible text', () => { + const visibleText = 'Continue with the prior task using the context above.'; + const userPrompt: Content = { + role: 'user', + parts: [{ text: visibleText } as Part], + }; + expect(isCompressionContinuationBridge(userPrompt)).toBe(false); + }); + + it('returns false for model role content', () => { + const modelContent: Content = { + role: 'model', + parts: [{ text: COMPRESSION_CONTINUATION_BRIDGE } as Part], + }; + expect(isCompressionContinuationBridge(modelContent)).toBe(false); + }); + + it('returns false for undefined content', () => { + expect(isCompressionContinuationBridge(undefined)).toBe(false); + }); + + it('returns false when parts do not start with the sentinel', () => { + const content: Content = { + role: 'user', + parts: [{ text: 'some other text' } as Part], + }; + expect(isCompressionContinuationBridge(content)).toBe(false); + }); + + it('detects bridge even with additional content after the marker', () => { + const bridgeWithExtra = `${COMPRESSION_CONTINUATION_BRIDGE_MARKER}Continue with the prior task using the context above.`; + const content: Content = { + role: 'user', + parts: [{ text: bridgeWithExtra } as Part], + }; + expect(isCompressionContinuationBridge(content)).toBe(true); + }); +}); diff --git a/packages/cli/src/utils/apiHistoryUtils.ts b/packages/cli/src/utils/apiHistoryUtils.ts new file mode 100644 index 0000000000..ecd818264c --- /dev/null +++ b/packages/cli/src/utils/apiHistoryUtils.ts @@ -0,0 +1,114 @@ +/** + * @license + * Copyright 2025 Qwen Code + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Content } from '@google/genai'; +import { + COMPRESSION_CONTINUATION_BRIDGE_MARKER, + COMPRESSION_SUMMARY_MODEL_ACK, + STARTUP_CONTEXT_MODEL_ACK, + createDebugLogger, +} from '@qwen-code/qwen-code-core'; + +const debugLogger = createDebugLogger('API_HISTORY_UTILS'); + +/** + * Checks whether a Content entry is the synthetic continuation bridge + * inserted after compression. Detection uses the invisible sentinel + * marker prefix rather than matching on the full string, so a real user + * prompt with the same visible text is not mistaken for the bridge. + */ +export function isCompressionContinuationBridge( + content: Content | undefined, +): boolean { + if (!content || content.role !== 'user') return false; + return ( + content.parts?.some( + (part) => + 'text' in part && + typeof part.text === 'string' && + part.text.startsWith(COMPRESSION_CONTINUATION_BRIDGE_MARKER), + ) ?? false + ); +} + +export function hasTextPart( + content: Content | undefined, + text: string, +): boolean { + return ( + content?.parts?.some( + (part) => + 'text' in part && typeof part.text === 'string' && part.text === text, + ) ?? false + ); +} + +export function hasModelTextPart( + content: Content | undefined, + text: string, +): boolean { + return content?.role === 'model' && hasTextPart(content, text); +} + +/** + * Checks if a Content entry is a user-initiated text prompt + * as opposed to a tool result (functionResponse). + */ +export function isApiUserTextContent(content: Content): boolean { + if (content.role !== 'user') return false; + if (!content.parts || content.parts.length === 0) return false; + + const hasFunctionResponse = content.parts.some( + (part) => 'functionResponse' in part, + ); + if (hasFunctionResponse) return false; + + return content.parts.some((part) => 'text' in part && part.text); +} + +export function hasCompressionSummaryPair( + apiHistory: Content[], + startIndex: number, +): boolean { + const summary = apiHistory[startIndex]; + return ( + !!summary && + isApiUserTextContent(summary) && + hasModelTextPart(apiHistory[startIndex + 1], COMPRESSION_SUMMARY_MODEL_ACK) + ); +} + +export function getApiUserTextIndices( + apiHistory: Content[], + startIndex: number, + skipContinuationBridge: boolean, +): number[] { + const indices: number[] = []; + + for (let i = startIndex; i < apiHistory.length; i++) { + const content = apiHistory[i]!; + if (!isApiUserTextContent(content)) continue; + if (skipContinuationBridge && isCompressionContinuationBridge(content)) { + debugLogger.debug('Skipping compression continuation bridge at index', i); + continue; + } + indices.push(i); + } + + return indices; +} + +/** + * Detects whether the API history starts with the startup context pair + * (user env context + model acknowledgment). + */ +export function hasStartupContext(apiHistory: Content[]): boolean { + if (apiHistory.length < 2) return false; + const first = apiHistory[0]; + const second = apiHistory[1]; + if (first?.role !== 'user' || second?.role !== 'model') return false; + return hasTextPart(second, STARTUP_CONTEXT_MODEL_ACK); +} diff --git a/packages/core/src/core/openaiContentGenerator/provider/dashscope.ts b/packages/core/src/core/openaiContentGenerator/provider/dashscope.ts index cb0a7650d2..b39e73f952 100644 --- a/packages/core/src/core/openaiContentGenerator/provider/dashscope.ts +++ b/packages/core/src/core/openaiContentGenerator/provider/dashscope.ts @@ -112,10 +112,7 @@ export class DashScopeOpenAICompatibleProvider extends DefaultOpenAICompatiblePr } return ( - isDashscopeOrigin || - isTokenPlanOrigin || - isInternalOrigin || - isProxyMatch + isDashscopeOrigin || isTokenPlanOrigin || isInternalOrigin || isProxyMatch ); } diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 336f23ff33..710dc6d214 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -153,6 +153,7 @@ export * from './services/gitWorktreeService.js'; export * from './services/sessionRecap.js'; export * from './services/sessionService.js'; export * from './services/sessionTitle.js'; +export * from './services/chatCompressionConstants.js'; export * from './services/worktreeSessionService.js'; export { stripTerminalControlSequences, diff --git a/packages/core/src/services/chatCompressionConstants.ts b/packages/core/src/services/chatCompressionConstants.ts new file mode 100644 index 0000000000..c5ead46b61 --- /dev/null +++ b/packages/core/src/services/chatCompressionConstants.ts @@ -0,0 +1,19 @@ +/** + * @license + * Copyright 2025 Qwen Code + * SPDX-License-Identifier: Apache-2.0 + */ + +export const COMPRESSION_SUMMARY_MODEL_ACK = + 'Got it. Thanks for the additional context!'; + +export const COMPRESSION_CONTINUATION_BRIDGE_MARKER = + '\u200B\u200C\u200D\u2060'; +const COMPRESSION_CONTINUATION_BRIDGE_PROMPT = + 'Continue with the prior task using the context above.'; + +// The invisible sentinel marker prevents a real user prompt with the same +// visible text from being treated as the synthetic bridge inserted after +// compression. Detection should use isCompressionContinuationBridge() +// which checks for the marker prefix rather than the full string. +export const COMPRESSION_CONTINUATION_BRIDGE = `${COMPRESSION_CONTINUATION_BRIDGE_MARKER}${COMPRESSION_CONTINUATION_BRIDGE_PROMPT}`; diff --git a/packages/core/src/services/chatCompressionService.test.ts b/packages/core/src/services/chatCompressionService.test.ts index c73f08fcd7..beecae5374 100644 --- a/packages/core/src/services/chatCompressionService.test.ts +++ b/packages/core/src/services/chatCompressionService.test.ts @@ -7,6 +7,8 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { ChatCompressionService, + COMPRESSION_CONTINUATION_BRIDGE, + COMPRESSION_SUMMARY_MODEL_ACK, computeThresholds, findCompressSplitPoint, MAX_CONSECUTIVE_FAILURES, @@ -1815,6 +1817,9 @@ describe('ChatCompressionService', () => { expect(result.newHistory).toHaveLength(2); expect(result.newHistory![0].role).toBe('user'); expect(result.newHistory![1].role).toBe('model'); + expect(result.newHistory![1].parts?.[0].text).toBe( + COMPRESSION_SUMMARY_MODEL_ACK, + ); // The orphaned funcCall is stripped before compression, so only the first 5 // messages are sent, plus the compression instruction (+1) = history.length total. const optionsArg = mockGenerateContent.mock.calls[0][0]; @@ -2029,8 +2034,11 @@ describe('ChatCompressionService', () => { expect(newHistory[0].role).toBe('user'); expect(newHistory[0].parts?.[0].text).toBe('state snapshot summary'); expect(newHistory[1].role).toBe('model'); + expect(newHistory[1].parts?.[0].text).toBe(COMPRESSION_SUMMARY_MODEL_ACK); expect(newHistory[2].role).toBe('user'); - expect(newHistory[2].parts?.[0].text).toMatch(/Continue/); + expect(newHistory[2].parts?.[0].text).toBe( + COMPRESSION_CONTINUATION_BRIDGE, + ); // Retained two complete pairs (4 entries) + trailing model+fc = 5. expect(newHistory.slice(3)).toHaveLength(5); expect(newHistory[3].role).toBe('model'); diff --git a/packages/core/src/services/chatCompressionService.ts b/packages/core/src/services/chatCompressionService.ts index a6e2434e1c..e591761731 100644 --- a/packages/core/src/services/chatCompressionService.ts +++ b/packages/core/src/services/chatCompressionService.ts @@ -20,7 +20,15 @@ import { resolveSlimmingConfig, slimCompactionInput, } from './compactionInputSlimming.js'; +import { + COMPRESSION_CONTINUATION_BRIDGE, + COMPRESSION_SUMMARY_MODEL_ACK, +} from './chatCompressionConstants.js'; import { estimatePromptTokens } from './tokenEstimation.js'; +export { + COMPRESSION_CONTINUATION_BRIDGE, + COMPRESSION_SUMMARY_MODEL_ACK, +} from './chatCompressionConstants.js'; /** * The fraction of the latest chat history to keep. A value of 0.3 @@ -595,7 +603,7 @@ export class ChatCompressionService { }, { role: 'model', - parts: [{ text: 'Got it. Thanks for the additional context!' }], + parts: [{ text: COMPRESSION_SUMMARY_MODEL_ACK }], }, // When the kept slice starts with model+functionCall (because // tool-round absorption pulled the only fresh user message into @@ -607,7 +615,7 @@ export class ChatCompressionService { role: 'user' as const, parts: [ { - text: 'Continue with the prior task using the context above.', + text: COMPRESSION_CONTINUATION_BRIDGE, }, ], }, diff --git a/packages/vscode-ide-companion/src/services/acpConnection.ts b/packages/vscode-ide-companion/src/services/acpConnection.ts index c822ad0d94..259c207237 100644 --- a/packages/vscode-ide-companion/src/services/acpConnection.ts +++ b/packages/vscode-ide-companion/src/services/acpConnection.ts @@ -42,6 +42,13 @@ import * as fs from 'node:fs'; import { AcpFileHandler } from './acpFileHandler.js'; import { ACP_ERROR_CODES } from '../constants/acpSchema.js'; +export type AcpHistorySnapshot = + | unknown[] + | { + history: unknown[]; + modelFacingUserTurnCount: number; + }; + /** * ACP Connection Handler for VSCode Extension * @@ -485,7 +492,7 @@ export class AcpConnection { async rewindSession( targetTurnIndex: number, - ): Promise<{ historyBeforeRewind?: unknown[] }> { + ): Promise<{ historyBeforeRewind?: AcpHistorySnapshot }> { const conn = this.ensureConnection(); if (!this.sessionId) { throw new Error('No active ACP session'); @@ -495,10 +502,10 @@ export class AcpConnection { sessionId: this.sessionId, targetTurnIndex, cwd: this.workingDir, - })) as { historyBeforeRewind?: unknown[] }; + })) as { historyBeforeRewind?: AcpHistorySnapshot }; } - async restoreSessionHistory(history: unknown[]): Promise { + async restoreSessionHistory(history: AcpHistorySnapshot): Promise { const conn = this.ensureConnection(); if (!this.sessionId) { throw new Error('No active ACP session'); diff --git a/packages/vscode-ide-companion/src/services/qwenAgentManager.ts b/packages/vscode-ide-companion/src/services/qwenAgentManager.ts index e959d04899..7a790d3d37 100644 --- a/packages/vscode-ide-companion/src/services/qwenAgentManager.ts +++ b/packages/vscode-ide-companion/src/services/qwenAgentManager.ts @@ -3,7 +3,7 @@ * Copyright 2025 Qwen Team * SPDX-License-Identifier: Apache-2.0 */ -import { AcpConnection } from './acpConnection.js'; +import { AcpConnection, type AcpHistorySnapshot } from './acpConnection.js'; import type { ModelInfo, AvailableCommand, @@ -41,7 +41,7 @@ import { isAuthenticationRequiredError } from '../utils/authErrors.js'; import { getErrorMessage } from '../utils/errorMessage.js'; import { handleAuthenticateUpdate } from '../utils/authNotificationHandler.js'; -export type { ChatMessage, PlanEntry, ToolCallUpdateData }; +export type { AcpHistorySnapshot, ChatMessage, PlanEntry, ToolCallUpdateData }; /** * Extract session list items from ACP response. @@ -398,11 +398,11 @@ export class QwenAgentManager { async rewindSession( targetTurnIndex: number, - ): Promise<{ historyBeforeRewind?: unknown[] }> { + ): Promise<{ historyBeforeRewind?: AcpHistorySnapshot }> { return this.connection.rewindSession(targetTurnIndex); } - async restoreSessionHistory(history: unknown[]): Promise { + async restoreSessionHistory(history: AcpHistorySnapshot): Promise { await this.connection.restoreSessionHistory(history); } diff --git a/packages/vscode-ide-companion/src/webview/handlers/SessionMessageHandler.test.ts b/packages/vscode-ide-companion/src/webview/handlers/SessionMessageHandler.test.ts index a8400c7efb..8f25e788ba 100644 --- a/packages/vscode-ide-companion/src/webview/handlers/SessionMessageHandler.test.ts +++ b/packages/vscode-ide-companion/src/webview/handlers/SessionMessageHandler.test.ts @@ -238,7 +238,10 @@ describe('SessionMessageHandler', () => { isConnected: true, currentSessionId: 'session-1', rewindSession: vi.fn().mockResolvedValue({ - historyBeforeRewind: [{ role: 'user', parts: [{ text: 'first' }] }], + historyBeforeRewind: { + history: [{ role: 'user', parts: [{ text: 'first' }] }], + modelFacingUserTurnCount: 1, + }, }), restoreSessionHistory: vi.fn().mockResolvedValue(undefined), sendMessage: vi.fn().mockResolvedValue(undefined), @@ -445,7 +448,10 @@ describe('SessionMessageHandler', () => { isConnected: true, currentSessionId: 'session-1', rewindSession: vi.fn().mockResolvedValue({ - historyBeforeRewind: [{ role: 'user', parts: [{ text: 'first' }] }], + historyBeforeRewind: { + history: [{ role: 'user', parts: [{ text: 'first' }] }], + modelFacingUserTurnCount: 1, + }, }), restoreSessionHistory: vi.fn().mockResolvedValue(undefined), sendMessage: vi.fn().mockResolvedValue(undefined), @@ -524,12 +530,14 @@ describe('SessionMessageHandler', () => { createdAt: 1, updatedAt: 4, }; + const historyBeforeRewind = { + history: [{ role: 'user', parts: [{ text: 'first' }] }], + modelFacingUserTurnCount: 1, + }; const agentManager = { isConnected: true, currentSessionId: 'session-1', - rewindSession: vi.fn().mockResolvedValue({ - historyBeforeRewind: [{ role: 'user', parts: [{ text: 'first' }] }], - }), + rewindSession: vi.fn().mockResolvedValue({ historyBeforeRewind }), restoreSessionHistory: vi.fn().mockResolvedValue(undefined), sendMessage: vi.fn().mockRejectedValue(new Error('send failed')), }; @@ -557,9 +565,9 @@ describe('SessionMessageHandler', () => { }, }); - expect(agentManager.restoreSessionHistory).toHaveBeenCalledWith([ - { role: 'user', parts: [{ text: 'first' }] }, - ]); + expect(agentManager.restoreSessionHistory).toHaveBeenCalledWith( + historyBeforeRewind, + ); expect(conversationStore.replaceMessages).toHaveBeenCalledWith( 'session-1', originalConversation.messages, @@ -586,7 +594,10 @@ describe('SessionMessageHandler', () => { isConnected: true, currentSessionId: 'session-1', rewindSession: vi.fn().mockResolvedValue({ - historyBeforeRewind: [{ role: 'user', parts: [{ text: 'first' }] }], + historyBeforeRewind: { + history: [{ role: 'user', parts: [{ text: 'first' }] }], + modelFacingUserTurnCount: 1, + }, }), restoreSessionHistory: vi.fn(), sendMessage: vi.fn().mockResolvedValue(undefined), @@ -658,7 +669,10 @@ describe('SessionMessageHandler', () => { isConnected: true, currentSessionId: 'session-1', rewindSession: vi.fn().mockResolvedValue({ - historyBeforeRewind: [{ role: 'user', parts: [{ text: 'first' }] }], + historyBeforeRewind: { + history: [{ role: 'user', parts: [{ text: 'first' }] }], + modelFacingUserTurnCount: 1, + }, }), restoreSessionHistory: vi.fn().mockResolvedValue(undefined), sendMessage: vi.fn().mockResolvedValue(undefined), @@ -781,7 +795,10 @@ describe('SessionMessageHandler', () => { promptImages: [], }); - const historyBeforeRewind = [{ role: 'user', parts: [{ text: 'first' }] }]; + const historyBeforeRewind = { + history: [{ role: 'user', parts: [{ text: 'first' }] }], + modelFacingUserTurnCount: 1, + }; const originalConversation = { id: 'session-1', title: 'Existing session', diff --git a/packages/vscode-ide-companion/src/webview/handlers/SessionMessageHandler.ts b/packages/vscode-ide-companion/src/webview/handlers/SessionMessageHandler.ts index c60b866abe..d0d30b1c8e 100644 --- a/packages/vscode-ide-companion/src/webview/handlers/SessionMessageHandler.ts +++ b/packages/vscode-ide-companion/src/webview/handlers/SessionMessageHandler.ts @@ -6,7 +6,10 @@ import * as vscode from 'vscode'; import { BaseMessageHandler } from './BaseMessageHandler.js'; -import type { ChatMessage } from '../../services/qwenAgentManager.js'; +import type { + AcpHistorySnapshot, + ChatMessage, +} from '../../services/qwenAgentManager.js'; import type { Conversation } from '../../services/conversationStore.js'; import type { ImageAttachment } from '../../utils/imageSupport.js'; import type { ApprovalModeValue } from '../../types/approvalModeValueTypes.js'; @@ -594,7 +597,7 @@ export class SessionMessageHandler extends BaseMessageHandler { let editRestoreSnapshot: Conversation | null = null; let editStoreMutationApplied = false; let editAcpMutationApplied = false; - let editAcpHistorySnapshot: unknown[] | null = null; + let editAcpHistorySnapshot: AcpHistorySnapshot | null = null; if (editTargetTurnIndex !== undefined) { if (!Number.isInteger(editTargetTurnIndex) || editTargetTurnIndex < 0) {