From 3227268ce2a8943965345d20e302919a94b90b44 Mon Sep 17 00:00:00 2001 From: Joshua Horton Date: Tue, 19 May 2026 13:09:12 -0500 Subject: [PATCH] refactor(web): expose suggestion-root parameters for use in unit tests Build-bot: skip build:web Test-bot: skip --- .../worker-thread/src/main/predict-helpers.ts | 83 ++-- ...ine-tokenized-correction-sequence.tests.ts | 360 ++++++++++++++++++ 2 files changed, 420 insertions(+), 23 deletions(-) create mode 100644 web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-tokenized-correction-sequence.tests.ts diff --git a/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts b/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts index ebd62c09dfb..1839f2e7639 100644 --- a/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts +++ b/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts @@ -13,7 +13,6 @@ import { ContextTransition } from './correction/context-transition.js'; import { ExecutionTimer } from './correction/execution-timer.js'; import ModelCompositor from './model-compositor.js'; import { EDIT_DISTANCE_COST_SCALE, getBestTokenMatches } from './correction/distance-modeler.js'; -import { TokenResultMapping } from './correction/token-result-mapping.js'; const searchForProperty = defaultWordbreaker.searchForProperty; @@ -28,6 +27,7 @@ import Reversion = LexicalModelTypes.Reversion; import Suggestion = LexicalModelTypes.Suggestion; import SuggestionTag = LexicalModelTypes.SuggestionTag; import Transform = LexicalModelTypes.Transform; +import { TokenResult } from './correction/tokenization-corrector.js'; /* * The functions in this file exist to provide unit-testable stateless components for the @@ -390,24 +390,55 @@ export function determineSuggestionRange( } } +/** + * Specifies the core, preprocessed data necessary for generating predictions, + * regardless of model type. 
+ */ +export interface PredictionParameters { + /** + * The portion of context that should remain unchanged by generated suggestions + */ + rootContext: Context, + + /** + * A tokenization of the corrected part of the context, usable to generate + * suggestions. + * + * Note that each correction will be applied iteratively to the rootContext. + * That is, when suggesting based on the correction at index 1, the + * "unchanged" (root) context used for that suggestion will include the + * changes from the entry at index 0 (or possibly, a suggestion derived from it). + */ + tokenizedCorrection: ProbabilityMass[], + + /** + * A closure to be applied to the generated suggestion's metadata. + * @param entry The correction-prediction tuple to update; it is modified in place. + * @returns Nothing - the closure's only effect is the in-place update of `entry`. + */ + applyInPost: (entry: CorrectionPredictionTuple) => void +} + /** * This function takes in metadata about generated corrections (for models that - * implement Traversals) and uses that to construct predictions based upon those - * corrections. - * @param transition Context-transition data underlying the tokenization that led to the correction - * @param tokenization The tokenization from which the correction was generated. - * @param match The generated correction itself - the correction string and its cost - * @param costFactor A multiplicative factor used to adjust the cost when building prediction probabilities. + * implement Traversals) and uses that to produce the corresponding parameters + * to use for generating suggestions. + * @param transition Context-transition data underlying the tokenization that + * led to the correction + * @param tokenization The tokenization from which the correction was + * generated. + * @param match The generated correction itself - the correction string + * and its cost + * @param costFactor A multiplicative factor used to adjust the cost when + * building prediction probabilities. 
* @returns */ -export function buildAndMapPredictions( +export function determineTokenizedCorrectionSequence( transition: ContextTransition, tokenization: ContextTokenization, - match: Readonly, + match: Readonly, costFactor: number -): CorrectionPredictionTuple[] { - const model = transition.final.model; - +): PredictionParameters { const applicationTarget = transition.base.displayTokenization; const { tokensToRemove, tokensToPredict } = determineSuggestionRange(applicationTarget, tokenization); @@ -418,7 +449,10 @@ export function buildAndMapPredictions( const correctionTransform: Transform = { insert: match.matchString, // insert correction string deleteLeft: 0, - id: transition.transitionId // The correction should always be based on the most recent external transform/transcription ID. + } + + if(transition.transitionId) { + correctionTransform.id = transition.transitionId // The correction should always be based on the most recent external transform/transcription ID. } const rootCost = match.totalCost; @@ -427,15 +461,16 @@ export function buildAndMapPredictions( p: Math.exp(-rootCost * costFactor) }; - const predictions = predictFromCorrectionSequence(model, [predictionRoot], rootContext); - predictions.forEach((entry) => { - entry.preservationTransform = tokenization.taillessTrueKeystroke; - // // Will need an extra lookup layer if the suggestion is generated from within a cluster. - // entry.baseTokenization = transition.final.tokenizationSourceMap.get(tokenization); - entry.prediction.sample.transform.deleteLeft = deleteLeft; - }); - - return predictions; + return { + rootContext, + tokenizedCorrection: [predictionRoot], + applyInPost: (entry: CorrectionPredictionTuple) => { + entry.preservationTransform = tokenization.taillessTrueKeystroke; + // // Will need an extra lookup layer if the suggestion is generated from within a cluster. 
+ // entry.baseTokenization = transition.final.tokenizationSourceMap.get(tokenization); + entry.prediction.sample.transform.deleteLeft = deleteLeft; + } + }; } /** @@ -549,7 +584,9 @@ export async function correctAndEnumerate( */ const costFactor = (tokenization.tail.inputCount <= 1) ? ModelCompositor.SINGLE_CHAR_KEY_PROB_EXPONENT : 1; - const predictions = buildAndMapPredictions(transition, tokenization, match, costFactor); + const predictionPrep = determineTokenizedCorrectionSequence(transition, tokenization, match, costFactor); + const predictions = predictFromCorrectionSequence(lexicalModel, predictionPrep.tokenizedCorrection, predictionPrep.rootContext); + predictions.forEach((p) => predictionPrep.applyInPost(p)); // Only set 'best correction' cost when a correction ACTUALLY YIELDS predictions. if(predictions.length > 0 && bestCorrectionCost === undefined) { diff --git a/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-tokenized-correction-sequence.tests.ts b/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-tokenized-correction-sequence.tests.ts new file mode 100644 index 00000000000..a8182d890d2 --- /dev/null +++ b/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-tokenized-correction-sequence.tests.ts @@ -0,0 +1,360 @@ +/* + * Keyman is copyright (C) SIL Global. MIT License. + * + * Created by jahorton on 2026-05-19 + * + * This file tests the prediction helper-method responsible for preparing + * corrections for multi-token prediction for our standard models, all of which + * utilize LexiconTraversals and the context-tokenization-caching subsystem. 
+ */ + +import { assert } from 'chai'; + +import { LexicalModelTypes } from "@keymanapp/common-types"; +import * as wordBreakers from '@keymanapp/models-wordbreakers'; +import { jsonFixture } from '@keymanapp/common-test-resources/model-helpers.mjs'; +import { KMWString } from '@keymanapp/web-utils'; + +import { determineTokenizedCorrectionSequence, models, ContextState, ContextToken, ContextTokenization, CorrectionPredictionTuple } from "@keymanapp/lm-worker/test-index"; + +import Context = LexicalModelTypes.Context; +import ProbabilityMass = LexicalModelTypes.ProbabilityMass; +import Transform = LexicalModelTypes.Transform; +import TrieModel = models.TrieModel; + +const testModel = new TrieModel( + jsonFixture('models/tries/english-1000'), { + wordBreaker: wordBreakers.default, + } +); + +describe('determineTokenizedCorrectionSequence', () => { + it(`properly analyzes common-case token-extension - adding a letter to an existing word`, () => { + const context: Context = { + left: 'the quick brown f', + right: '', + startOfBuffer: true, + endOfBuffer: true + }; + + const trueInput: ProbabilityMass = { + sample: { + insert: 'o', + deleteLeft: 0 + }, + p: .5 + }; + + const state = new ContextState(context, testModel); + const transition = state.analyzeTransition(context, [trueInput]); + + + const results = determineTokenizedCorrectionSequence( + transition, + transition.final.displayTokenization, { + matchString: 'fo', + inputSamplingCost: -Math.log(trueInput.p), + knownCost: 0, + totalCost: -Math.log(trueInput.p) + }, + 1 + ); + + assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, { + casingForm: undefined, + left: 'the quick brown ', + right: '', + startOfBuffer: true, + endOfBuffer: true + }); + + assert.deepEqual(results.tokenizedCorrection, [ + { + sample: { + insert: 'fo', + deleteLeft: 0 + }, + p: trueInput.p + } + ]); + }); + + it(`properly analyzes common-case whitespace - ending a token and adding a new one`, () => { + 
const context: Context = { + left: 'the quick brown', + right: '', + startOfBuffer: true, + endOfBuffer: true + }; + + const trueInput: ProbabilityMass = { + sample: { + insert: ' ', + deleteLeft: 0 + }, + p: .5 + }; + + const state = new ContextState(context, testModel); + const transition = state.analyzeTransition(context, [trueInput]); + + + const results = determineTokenizedCorrectionSequence( + transition, + transition.final.displayTokenization, { + matchString: ' ', + inputSamplingCost: -Math.log(trueInput.p), + knownCost: 0, + totalCost: -Math.log(trueInput.p) + }, + 1 + ); + + assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, { + casingForm: undefined, + left: 'the quick brown', + right: '', + startOfBuffer: true, + endOfBuffer: true + }); + + assert.deepEqual(results.tokenizedCorrection, [ + { + sample: { + insert: ' ', + deleteLeft: 0 + }, + p: trueInput.p + } + ]); + }); + + it(`properly analyzes common-case word-start - beginning a new token`, () => { + const context: Context = { + left: 'the quick brown ', + right: '', + startOfBuffer: true, + endOfBuffer: true + }; + + const trueInput: ProbabilityMass = { + sample: { + insert: 'f', + deleteLeft: 0 + }, + p: .5 + }; + + const state = new ContextState(context, testModel); + const transition = state.analyzeTransition(context, [trueInput]); + + + const results = determineTokenizedCorrectionSequence( + transition, + transition.final.displayTokenization, { + matchString: 'f', + inputSamplingCost: -Math.log(trueInput.p), + knownCost: 0, + totalCost: -Math.log(trueInput.p) + }, + 1 + ); + + assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, { + casingForm: undefined, + left: 'the quick brown ', + right: '', + startOfBuffer: true, + endOfBuffer: true + }); + + assert.deepEqual(results.tokenizedCorrection, [ + { + sample: { + insert: 'f', + deleteLeft: 0 + }, + p: trueInput.p + } + ]); + }); + + it(`properly analyzes post-merge case`, () 
=> { + let context: Context = { + left: 'the quick brown fox ', + right: '', + startOfBuffer: true, + endOfBuffer: true + }; + + const trueInput: ProbabilityMass = { + sample: { + insert: 't', + deleteLeft: 0 + }, + p: .5 + }; + + const constructingState = new ContextState(context, testModel); + const tokens = constructingState.displayTokenization.tokens; + tokens.push(ContextToken.fromRawText(testModel, 'can')); + tokens.push(ContextToken.fromRawText(testModel, '\'')); + + context = models.applyTransform({insert: 'can\'', deleteLeft: 0}, context); + + const state = new ContextState(context, testModel, new ContextTokenization(tokens)); + const transition = state.analyzeTransition(context, [trueInput]); + + const results = determineTokenizedCorrectionSequence( + transition, + transition.final.displayTokenization, { + matchString: 'can\'t', + inputSamplingCost: -Math.log(trueInput.p), + knownCost: 0, + totalCost: -Math.log(trueInput.p) + }, + 1 + ); + + assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, { + casingForm: undefined, + left: 'the quick brown fox ', + right: '', + startOfBuffer: true, + endOfBuffer: true + }); + + assert.deepEqual(results.tokenizedCorrection, [ + { + sample: { + insert: 'can\'t', + deleteLeft: 0 + }, + p: trueInput.p + } + ]); + }); + + it(`properly analyzes post-split case`, () => { + const context: Context = { + left: 'the quick brown fox can\'', + right: '', + startOfBuffer: true, + endOfBuffer: true + }; + + const trueInput: ProbabilityMass = { + sample: { + insert: ' ', + deleteLeft: 0 + }, + p: .5 + }; + + const state = new ContextState(context, testModel); + assert.equal(state.displayTokenization.tail.exampleInput, 'can\''); + const transition = state.analyzeTransition(context, [trueInput]); + + const results = determineTokenizedCorrectionSequence( + transition, + transition.final.displayTokenization, { + matchString: ' ', + inputSamplingCost: -Math.log(trueInput.p), + knownCost: 0, + totalCost: 
-Math.log(trueInput.p) + }, + 1 + ); + + assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, { + casingForm: undefined, + left: 'the quick brown fox can\'', + right: '', + startOfBuffer: true, + endOfBuffer: true + }); + + assert.deepEqual(results.tokenizedCorrection, [ + { + sample: { + insert: ' ', + deleteLeft: 0 + }, + p: trueInput.p + } + ]); + }); + + it(`properly analyzes complex transition - multi-token replacement`, () => { + const context: Context = { + left: 'the quick brown f', + right: '', + startOfBuffer: true, + endOfBuffer: true + }; + + const trueInput: ProbabilityMass = { + sample: { + insert: 'fast red d', + deleteLeft: 'quick brown f'.length + }, + p: .5 + }; + + const state = new ContextState(context, testModel); + const transition = state.analyzeTransition(context, [trueInput]); + + const results = determineTokenizedCorrectionSequence( + transition, + transition.final.displayTokenization, { + matchString: 'd', + inputSamplingCost: -Math.log(trueInput.p), + knownCost: 0, + totalCost: -Math.log(trueInput.p) + }, + 1 + ); + + // Large-scale deletions will receive enhanced handling soon. But, for now, it's + // deleted by the `preservationTransform`, not here. 
+ assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, { + casingForm: undefined, + left: 'the quick brown f', + right: '', + startOfBuffer: true, + endOfBuffer: true + }); + + assert.deepEqual(results.tokenizedCorrection, [ + { + sample: { + insert: 'd', + deleteLeft: 0 + }, + p: trueInput.p + } + ]); + + const dummiedTuple: CorrectionPredictionTuple = { + prediction: { + sample: { + transform: { insert: 'dog', deleteLeft: 0 }, + displayAs: 'dog' + }, + p: .25 + }, + correction: { + sample: 'd', + p: trueInput.p + }, + totalProb: .25 * trueInput.p + }; + + results.applyInPost(dummiedTuple); + + assert.deepEqual(dummiedTuple.preservationTransform, { + insert: trueInput.sample.insert.substring(0, KMWString.length(trueInput.sample.insert) - 1), // remove the 'd'. + deleteLeft: trueInput.sample.deleteLeft - 1 + }); + }); +}); \ No newline at end of file