diff --git a/invokeai/app/invocations/flux2_denoise.py b/invokeai/app/invocations/flux2_denoise.py index 1b5ea372d68..d4239e41420 100644 --- a/invokeai/app/invocations/flux2_denoise.py +++ b/invokeai/app/invocations/flux2_denoise.py @@ -53,8 +53,8 @@ "flux2_denoise", title="FLUX2 Denoise", tags=["image", "flux", "flux2", "klein", "denoise"], - category="latents", - version="1.4.0", + category="image", + version="1.5.0", classification=Classification.Prototype, ) class Flux2DenoiseInvocation(BaseInvocation): @@ -101,6 +101,14 @@ class Flux2DenoiseInvocation(BaseInvocation): description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.", input=Input.Connection, ) + guidance: float = InputField( + default=4.0, + ge=0, + le=20, + description="Guidance strength for distilled guidance-embedding models. " + "Inert for all current FLUX.2 Klein variants (their guidance_embeds weights are absent/zero); " + "kept for node-graph compatibility and future guidance-embedded models.", + ) cfg_scale: float = InputField( default=1.0, description=FieldDescriptions.cfg_scale, @@ -467,6 +475,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor: txt_ids=txt_ids, timesteps=timesteps, step_callback=self._build_step_callback(context), + guidance=self.guidance, cfg_scale=cfg_scale_list, neg_txt=neg_txt, neg_txt_ids=neg_txt_ids, diff --git a/invokeai/app/invocations/flux2_klein_model_loader.py b/invokeai/app/invocations/flux2_klein_model_loader.py index f39e7688f3e..2091fd380d7 100644 --- a/invokeai/app/invocations/flux2_klein_model_loader.py +++ b/invokeai/app/invocations/flux2_klein_model_loader.py @@ -207,9 +207,9 @@ def _validate_qwen3_encoder_variant(self, context: InvocationContext, main_confi flux2_variant = main_config.variant # Validate the variants match - # Klein4B requires Qwen3_4B, Klein9B/Klein9BBase requires Qwen3_8B + # Klein4B/Klein4BBase requires Qwen3_4B, Klein9B/Klein9BBase requires Qwen3_8B expected_qwen3_variant = None - if flux2_variant == Flux2VariantType.Klein4B: + if flux2_variant in (Flux2VariantType.Klein4B, Flux2VariantType.Klein4BBase): expected_qwen3_variant = Qwen3VariantType.Qwen3_4B elif flux2_variant in (Flux2VariantType.Klein9B, Flux2VariantType.Klein9BBase): expected_qwen3_variant = Qwen3VariantType.Qwen3_8B diff --git a/invokeai/backend/flux/util.py b/invokeai/backend/flux/util.py index da6590c7573..81b10a913ac 100644 --- a/invokeai/backend/flux/util.py +++ b/invokeai/backend/flux/util.py @@ -133,7 +133,24 @@ def get_flux_ae_params() -> AutoEncoderParams: axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True, - guidance_embed=True, + guidance_embed=False, + ), + # Flux2 Klein 4B Base is the undistilled foundation model. It shares the same + # architecture as Klein 4B (distilled) and reports guidance_embeds=False in its + # HF transformer config - classical CFG (external negative pass) is the guidance mechanism. + Flux2VariantType.Klein4BBase: FluxParams( + in_channels=64, + vec_in_dim=2560, # Qwen3-4B hidden size (used for pooled output) + context_in_dim=7680, # 3 layers * 2560 = 7680 for Qwen3-4B + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, ), # Flux2 Klein 9B uses Qwen3 8B text encoder with stacked embeddings from layers [9, 18, 27] # The context_in_dim is 3 * hidden_size of Qwen3 (3 * 4096 = 12288) @@ -149,7 +166,24 @@ def get_flux_ae_params() -> AutoEncoderParams: axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True, - guidance_embed=True, + guidance_embed=False, + ), + # Flux2 Klein 9B Base is the undistilled foundation model. It shares the same + # architecture as Klein 9B (distilled) and reports guidance_embeds=False in its + # HF transformer config - the guidance scalar is inert for all Klein variants. + Flux2VariantType.Klein9BBase: FluxParams( + in_channels=64, + vec_in_dim=4096, # Qwen3-8B hidden size (used for pooled output) + context_in_dim=12288, # 3 layers * 4096 = 12288 for Qwen3-8B + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, ), } diff --git a/invokeai/backend/flux2/denoise.py b/invokeai/backend/flux2/denoise.py index 7b5bd6194e0..b4438094f7b 100644 --- a/invokeai/backend/flux2/denoise.py +++ b/invokeai/backend/flux2/denoise.py @@ -26,6 +26,7 @@ def denoise( # sampling parameters timesteps: list[float], step_callback: Callable[[PipelineIntermediateState], None], + guidance: float, cfg_scale: list[float], # Negative conditioning for CFG neg_txt: torch.Tensor | None = None, @@ -45,7 +46,10 @@ def denoise( This is a simplified denoise function for FLUX.2 Klein models that uses the diffusers Flux2Transformer2DModel interface. - Note: FLUX.2 Klein has guidance_embeds=False, so no guidance parameter is used. + All current FLUX.2 Klein variants (4B, 4B Base, 9B, 9B Base) have guidance_embeds=False + in their HF transformer config (or absent/zeroed projection weights), so the guidance + value is passed but effectively ignored by the model. The argument is retained for + node-graph compatibility and future variants that may ship trained guidance projections. CFG is applied externally using negative conditioning when cfg_scale != 1.0. Args: @@ -56,6 +60,8 @@ def denoise( txt_ids: Text position IDs tensor. timesteps: List of timesteps for denoising schedule (linear sigmas from 1.0 to 1/n). step_callback: Callback function for progress updates. + guidance: Guidance strength. Inert for all current FLUX.2 Klein variants + (their guidance_embeds projection weights are absent/zero). cfg_scale: List of CFG scale values per step. neg_txt: Negative text embeddings for CFG (optional). neg_txt_ids: Negative text position IDs (optional). @@ -76,9 +82,10 @@ def denoise( img = torch.cat([img, img_cond_seq], dim=1) img_ids = torch.cat([img_ids, img_cond_seq_ids], dim=1) - # Klein has guidance_embeds=False, but the transformer forward() still requires a guidance tensor - # We pass a dummy value (1.0) since it won't affect the output when guidance_embeds=False - guidance = torch.full((img.shape[0],), 1.0, device=img.device, dtype=img.dtype) + # The transformer forward() requires a guidance tensor even when guidance_embeds=False, + # because the Flux2TimestepGuidanceEmbeddings forward signature takes it unconditionally. + # All current Klein variants have guidance_embeds=False, so the value is ignored internally. + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) # Use scheduler if provided use_scheduler = scheduler is not None @@ -121,7 +128,7 @@ def denoise( timestep=t_vec, img_ids=img_ids, txt_ids=txt_ids, - guidance=guidance, + guidance=guidance_vec, return_dict=False, ) @@ -141,7 +148,7 @@ def denoise( timestep=t_vec, img_ids=img_ids, txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids, - guidance=guidance, + guidance=guidance_vec, return_dict=False, ) @@ -222,7 +229,7 @@ def denoise( timestep=t_vec, img_ids=img_ids, txt_ids=txt_ids, - guidance=guidance, + guidance=guidance_vec, return_dict=False, ) @@ -242,7 +249,7 @@ def denoise( timestep=t_vec, img_ids=img_ids, txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids, - guidance=guidance, + guidance=guidance_vec, return_dict=False, ) diff --git a/invokeai/backend/model_manager/configs/main.py b/invokeai/backend/model_manager/configs/main.py index a2f008f41ed..da5bc5eed36 100644 --- a/invokeai/backend/model_manager/configs/main.py +++ b/invokeai/backend/model_manager/configs/main.py @@ -81,8 +81,8 @@ def from_base( return cls(steps=35, cfg_scale=4.5, width=1024, height=1024) case BaseModelType.Flux2: # Different defaults based on variant - if variant == Flux2VariantType.Klein9BBase: - # Undistilled base model needs more steps + if variant in (Flux2VariantType.Klein4BBase, Flux2VariantType.Klein9BBase): + # Undistilled base models need more steps return cls(steps=28, cfg_scale=1.0, width=1024, height=1024) else: # Distilled models (Klein 4B, Klein 9B) use fewer steps @@ -389,6 +389,7 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N # Default to Klein9B - callers use filename heuristics to detect Klein9BBase return Flux2VariantType.Klein9B elif context_in_dim == KLEIN_4B_CONTEXT_DIM: + # Default to Klein4B - callers use filename heuristics to detect Klein4BBase return Flux2VariantType.Klein4B elif context_in_dim > 4096: # Unknown FLUX.2 variant, default to 4B @@ -573,10 +574,12 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType: if variant is None: raise NotAMatchError("unable to determine FLUX.2 model variant from state dict") - # Klein 9B Base and Klein 9B have identical architectures. - # Use filename heuristic to detect the Base (undistilled) variant. + # Base (undistilled) and distilled variants share identical architectures. + # Use filename heuristic to detect the Base variant. if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name): return Flux2VariantType.Klein9BBase + if variant == Flux2VariantType.Klein4B and _filename_suggests_base(mod.name): + return Flux2VariantType.Klein4BBase return variant @@ -745,10 +748,12 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType: if variant is None: raise NotAMatchError("unable to determine FLUX.2 model variant from state dict") - # Klein 9B Base and Klein 9B have identical architectures. - # Use filename heuristic to detect the Base (undistilled) variant. + # Base (undistilled) and distilled variants share identical architectures. + # Use filename heuristic to detect the Base variant. if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name): return Flux2VariantType.Klein9BBase + if variant == Flux2VariantType.Klein4B and _filename_suggests_base(mod.name): + return Flux2VariantType.Klein4BBase return variant @@ -856,11 +861,10 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType: """Determine the FLUX.2 variant from the transformer config. FLUX.2 Klein uses Qwen3 text encoder with larger joint_attention_dim: - - Klein 4B: joint_attention_dim = 7680 (3×Qwen3-4B hidden size) + - Klein 4B/4B Base: joint_attention_dim = 7680 (3×Qwen3-4B hidden size) - Klein 9B/9B Base: joint_attention_dim = 12288 (3×Qwen3-8B hidden size) - Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures - and both have guidance_embeds=False. We use a filename heuristic to detect Base models. + Distilled and Base variants share identical architectures. We use a filename heuristic to detect Base models. """ KLEIN_4B_CONTEXT_DIM = 7680 # 3 × 2560 KLEIN_9B_CONTEXT_DIM = 12288 # 3 × 4096 @@ -875,6 +879,8 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType: return Flux2VariantType.Klein9BBase return Flux2VariantType.Klein9B elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM: + if _filename_suggests_base(mod.name): + return Flux2VariantType.Klein4BBase return Flux2VariantType.Klein4B elif joint_attention_dim > 4096: # Unknown FLUX.2 variant, default to 4B diff --git a/invokeai/backend/model_manager/taxonomy.py b/invokeai/backend/model_manager/taxonomy.py index a141d43cf42..8f1fb00b5b7 100644 --- a/invokeai/backend/model_manager/taxonomy.py +++ b/invokeai/backend/model_manager/taxonomy.py @@ -132,7 +132,10 @@ class Flux2VariantType(str, Enum): """FLUX.2 model variants.""" Klein4B = "klein_4b" - """Flux2 Klein 4B variant using Qwen3 4B text encoder.""" + """Flux2 Klein 4B variant using Qwen3 4B text encoder (distilled).""" + + Klein4BBase = "klein_4b_base" + """Flux2 Klein 4B Base variant - undistilled foundation model using Qwen3 4B text encoder.""" Klein9B = "klein_9b" """Flux2 Klein 9B variant using Qwen3 8B text encoder (distilled).""" diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index 19e5a3a68e9..6bb5aa662f0 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -19924,7 +19924,7 @@ }, "Flux2VariantType": { "type": "string", - "enum": ["klein_4b", "klein_9b", "klein_9b_base"], + "enum": ["klein_4b", "klein_4b_base", "klein_9b", "klein_9b_base"], "title": "Flux2VariantType", "description": "FLUX.2 model variants." }, diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index e123d0ebd06..105ad3dfd67 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -63,6 +63,8 @@ export const ImageMetadataActions = memo((props: Props) => { + + ); diff --git a/invokeai/frontend/web/src/features/metadata/parsing.test.tsx b/invokeai/frontend/web/src/features/metadata/parsing.test.tsx new file mode 100644 index 00000000000..bb295303273 --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/parsing.test.tsx @@ -0,0 +1,174 @@ +import type { AppStore } from 'app/store/store'; +import type * as paramsSliceModule from 'features/controlLayers/store/paramsSlice'; +import { ImageMetadataHandlers } from 'features/metadata/parsing'; +import type * as modelsApiModule from 'services/api/endpoints/models'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +// --------------------------------------------------------------------------- +// Module mocks +// +// We are testing only the *gating* logic of the model-related metadata +// handlers (`VAEModel`, `KleinVAEModel`, `KleinQwen3EncoderModel`). The actual +// model lookup goes through `parseModelIdentifier`, which dispatches RTK +// Query thunks. We stub the models endpoint so that any lookup resolves to a +// canned model identifier — the parse step then succeeds and the assertions +// inside each handler become observable. +// --------------------------------------------------------------------------- + +let currentBase: string | null = 'flux2'; + +vi.mock('features/controlLayers/store/paramsSlice', async (importOriginal) => { + const mod = await importOriginal(); + return { ...mod, selectBase: () => currentBase }; +}); + +const fakeModel = (type: 'vae' | 'qwen3_encoder', base: string) => ({ + key: `${type}-key`, + hash: 'hash', + name: `Some ${type}`, + base, + type, +}); + +let nextResolved: ReturnType = fakeModel('vae', 'flux2'); + +vi.mock('services/api/endpoints/models', async (importOriginal) => { + const mod = await importOriginal(); + return { + ...mod, + modelsApi: { + ...mod.modelsApi, + endpoints: { + ...mod.modelsApi.endpoints, + getModelConfig: { initiate: (key: string) => ({ type: 'rtkq/initiate', key }) }, + }, + }, + }; +}); + +const makeStore = (): AppStore => + ({ + dispatch: vi.fn(() => ({ + unwrap: () => Promise.resolve(nextResolved), + })), + getState: () => ({}), + }) as unknown as AppStore; + +beforeEach(() => { + currentBase = 'flux2'; + nextResolved = fakeModel('vae', 'flux2'); +}); + +describe('ImageMetadataHandlers — Klein recall gating', () => { + describe('KleinVAEModel', () => { + it('parses metadata.vae when the current main model is FLUX.2 Klein', async () => { + currentBase = 'flux2'; + nextResolved = fakeModel('vae', 'flux2'); + const store = makeStore(); + + const parsed = await ImageMetadataHandlers.KleinVAEModel.parse({ vae: nextResolved }, store); + + expect(parsed.key).toBe('vae-key'); + expect(parsed.type).toBe('vae'); + }); + + it('rejects parsing when the current main model is not FLUX.2 Klein', async () => { + currentBase = 'sdxl'; + nextResolved = fakeModel('vae', 'flux2'); + const store = makeStore(); + + await expect(ImageMetadataHandlers.KleinVAEModel.parse({ vae: nextResolved }, store)).rejects.toThrow(); + }); + }); + + describe('KleinQwen3EncoderModel', () => { + it('parses metadata.qwen3_encoder when the current main model is FLUX.2 Klein', async () => { + currentBase = 'flux2'; + nextResolved = fakeModel('qwen3_encoder', 'flux2'); + const store = makeStore(); + + const parsed = await ImageMetadataHandlers.KleinQwen3EncoderModel.parse({ qwen3_encoder: nextResolved }, store); + + expect(parsed.key).toBe('qwen3_encoder-key'); + expect(parsed.type).toBe('qwen3_encoder'); + }); + + it('rejects parsing when the current main model is not FLUX.2 Klein', async () => { + currentBase = 'sdxl'; + nextResolved = fakeModel('qwen3_encoder', 'flux2'); + const store = makeStore(); + + await expect( + ImageMetadataHandlers.KleinQwen3EncoderModel.parse({ qwen3_encoder: nextResolved }, store) + ).rejects.toThrow(); + }); + }); + + describe('VAEModel (generic)', () => { + // The generic VAEModel handler must NOT also fire for FLUX.2 / Z-Image + // images, otherwise the metadata viewer renders duplicate VAE rows next + // to the dedicated KleinVAEModel / ZImageVAEModel handlers. + it.each(['flux2', 'z-image'])('rejects parsing when current base is %s', async (base) => { + currentBase = base; + nextResolved = fakeModel('vae', base); + const store = makeStore(); + + await expect(ImageMetadataHandlers.VAEModel.parse({ vae: nextResolved }, store)).rejects.toThrow(); + }); + + it('parses successfully for non-Klein, non-Z-Image bases', async () => { + currentBase = 'sdxl'; + nextResolved = fakeModel('vae', 'sdxl'); + const store = makeStore(); + + const parsed = await ImageMetadataHandlers.VAEModel.parse({ vae: nextResolved }, store); + expect(parsed.key).toBe('vae-key'); + }); + }); + + describe('Guidance (legacy FLUX.2 gating)', () => { + // Prior to the Klein guidance cleanup, FLUX.2 images wrote a `guidance` + // field into metadata. The guidance scalar is inert for all current Klein + // variants, so legacy values must not be recalled into the shared guidance + // state — otherwise they leak back into FLUX.1 when the user switches + // models. + it('rejects parsing when the image was generated with a FLUX.2 model', async () => { + const store = makeStore(); + + await expect( + Promise.resolve().then(() => + ImageMetadataHandlers.Guidance.parse( + { + model: { key: 'k', hash: 'h', name: 'Klein 9B Base', base: 'flux2', type: 'main' }, + guidance: 3.5, + }, + store + ) + ) + ).rejects.toThrow(); + }); + + it('parses successfully for FLUX.1 metadata', async () => { + const store = makeStore(); + + const parsed = await ImageMetadataHandlers.Guidance.parse( + { + model: { key: 'k', hash: 'h', name: 'FLUX Dev', base: 'flux', type: 'main' }, + guidance: 3.5, + }, + store + ); + + expect(parsed).toBe(3.5); + }); + + it('parses successfully when no model metadata is present', async () => { + // Metadata without a model field should still parse (back-compat for + // images where only scalar params were saved). + const store = makeStore(); + + const parsed = await ImageMetadataHandlers.Guidance.parse({ guidance: 3.5 }, store); + expect(parsed).toBe(3.5); + }); + }); +}); diff --git a/invokeai/frontend/web/src/features/metadata/parsing.tsx b/invokeai/frontend/web/src/features/metadata/parsing.tsx index 24b643da319..cf55f378106 100644 --- a/invokeai/frontend/web/src/features/metadata/parsing.tsx +++ b/invokeai/frontend/web/src/features/metadata/parsing.tsx @@ -379,6 +379,15 @@ const Guidance: SingleMetadataHandler = { [SingleMetadataKey]: true, type: 'Guidance', parse: (metadata, _store) => { + // Legacy FLUX.2 images may still carry a `guidance` field, but guidance_embeds + // is inert for all current Klein variants. Reject parsing for FLUX.2 metadata + // so the handler is skipped on both display and recall - avoids leaking a stale + // value into the shared guidance param (which is still used by FLUX.1). + const rawModel = getProperty(metadata, 'model'); + const modelBase = (rawModel as { base?: unknown } | undefined)?.base; + if (modelBase === 'flux2') { + throw new Error('Guidance is not used for FLUX.2 Klein models.'); + } const raw = getProperty(metadata, 'guidance'); const parsed = zParameterGuidance.parse(raw); return Promise.resolve(parsed); @@ -957,6 +966,9 @@ const VAEModel: SingleMetadataHandler = { const parsed = await parseModelIdentifier(raw, store, 'vae'); assert(parsed.type === 'vae'); assert(isCompatibleWithMainModel(parsed, store)); + // Z-Image and FLUX.2 Klein have dedicated VAE handlers; avoid rendering a duplicate row. + const base = selectBase(store.getState()); + assert(base !== 'z-image' && base !== 'flux2', 'VAEModel handler does not apply to Z-Image or FLUX.2 Klein'); return Promise.resolve(parsed); }, recall: (value, store) => { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/models.ts b/invokeai/frontend/web/src/features/modelManagerV2/models.ts index 1c0d2b20c7c..20f44850014 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/models.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/models.ts @@ -235,6 +235,7 @@ export const MODEL_VARIANT_TO_LONG_NAME: Record = { dev_fill: 'FLUX Dev - Fill', schnell: 'FLUX Schnell', klein_4b: 'FLUX.2 Klein 4B', + klein_4b_base: 'FLUX.2 Klein 4B Base', klein_9b: 'FLUX.2 Klein 9B', klein_9b_base: 'FLUX.2 Klein 9B Base', turbo: 'Z-Image Turbo', diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index d1aa0523a43..ba527df0354 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -158,7 +158,7 @@ export const zSubModelType = z.enum([ export const zClipVariantType = z.enum(['large', 'gigantic']); export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth']); export const zFluxVariantType = z.enum(['dev', 'dev_fill', 'schnell']); -export const zFlux2VariantType = z.enum(['klein_4b', 'klein_9b', 'klein_9b_base']); +export const zFlux2VariantType = z.enum(['klein_4b', 'klein_4b_base', 'klein_9b', 'klein_9b_base']); export const zZImageVariantType = z.enum(['turbo', 'zbase']); const zQwenImageVariantType = z.enum(['generate', 'edit']); export const zQwen3VariantType = z.enum(['qwen3_4b', 'qwen3_8b', 'qwen3_06b']); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts index 7f01becc3df..5b9f3d0a468 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts @@ -1,4 +1,15 @@ -import { afterEach, describe, expect, it, vi } from 'vitest'; +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +// --------------------------------------------------------------------------- +// Module mocks +// +// `buildFLUXGraph` pulls in a large slice of the app: redux selectors, every +// `add*` helper, validators, the canvas manager, etc. The function itself only +// orchestrates these; the units under test here are the orchestration bits +// (variant-gated guidance, scheduler propagation, metadata persistence, and +// qwen3_source_model auto-detection for GGUF Klein). So we stub out every +// collaborator and assert against the resulting `Graph` object. +// --------------------------------------------------------------------------- vi.mock('app/logging/logger', () => ({ logger: () => ({ @@ -71,39 +82,53 @@ const diffusers9BSourceModelFixture = { variant: 'klein_9b', }; -// --- Mutable state --- +const makeFlux2Model = (variant: string) => ({ + key: `flux2-${variant}`, + hash: 'hash', + name: `FLUX.2 Klein ${variant}`, + base: 'flux2', + type: 'main', + format: 'diffusers', + variant, +}); + +// --- Mutable state shared by all tests --- -let model: Record = { ...flux2DiffusersModel }; -let kleinVaeModel: Record | null = null; -let kleinQwen3EncoderModel: Record | null = null; +let currentModel: Record | null = null; +let currentKleinVae: Record | null = null; +let currentKleinQwen3: Record | null = null; let diffusersModels: Record[] = []; +const mockParams = { + guidance: 3.5, + steps: 28, + fluxScheduler: 'euler' as const, + fluxDypePreset: 'off' as const, + fluxDypeScale: 1, + fluxDypeExponent: 1, + fluxVAE: null, + t5EncoderModel: null, + clipEmbedModel: null, +}; + vi.mock('features/controlLayers/store/paramsSlice', () => ({ - selectMainModelConfig: vi.fn(() => model), - selectParamsSlice: vi.fn(() => ({ - guidance: 4, - steps: 20, - fluxScheduler: 'euler', - fluxDypePreset: 'off', - fluxDypeScale: 2.0, - fluxDypeExponent: 2.0, - fluxVAE: null, - t5EncoderModel: null, - clipEmbedModel: null, - })), - selectKleinVaeModel: vi.fn(() => kleinVaeModel), - selectKleinQwen3EncoderModel: vi.fn(() => kleinQwen3EncoderModel), + selectMainModelConfig: vi.fn(() => currentModel), + selectParamsSlice: vi.fn(() => mockParams), + selectKleinVaeModel: vi.fn(() => currentKleinVae), + selectKleinQwen3EncoderModel: vi.fn(() => currentKleinQwen3), })); vi.mock('features/controlLayers/store/refImagesSlice', () => ({ - selectRefImagesSlice: vi.fn(() => ({ - entities: [], - })), + selectRefImagesSlice: vi.fn(() => ({ entities: [] })), })); vi.mock('features/controlLayers/store/selectors', () => ({ + selectCanvasSlice: vi.fn(() => ({ + bbox: { rect: { x: 0, y: 0, width: 1024, height: 1024 } }, + controlLayers: { entities: [] }, + regionalGuidance: { entities: [] }, + })), selectCanvasMetadata: vi.fn(() => ({})), - selectCanvasSlice: vi.fn(() => ({})), })); vi.mock('features/controlLayers/store/types', () => ({ @@ -115,69 +140,48 @@ vi.mock('features/controlLayers/store/validators', () => ({ getGlobalReferenceImageWarnings: vi.fn(() => []), })); -vi.mock('features/nodes/util/graph/generation/addFlux2KleinLoRAs', () => ({ - addFlux2KleinLoRAs: vi.fn(), -})); - -vi.mock('features/nodes/util/graph/generation/addFLUXFill', () => ({ - addFLUXFill: vi.fn(), +vi.mock('features/ui/store/uiSelectors', () => ({ + selectActiveTab: vi.fn(() => 'generate'), })); -vi.mock('features/nodes/util/graph/generation/addFLUXLoRAs', () => ({ - addFLUXLoRAs: vi.fn(), +vi.mock('features/nodes/util/graph/graphBuilderUtils', () => ({ + selectCanvasOutputFields: vi.fn(() => ({})), + selectPresetModifiedPrompts: vi.fn(() => ({ + positive: 'a prompt', + negative: '', + })), })); -vi.mock('features/nodes/util/graph/generation/addFLUXRedux', () => ({ - addFLUXReduxes: vi.fn(), +// Helper add* functions: the tests care about the FLUX.2 orchestration path +// (metadata, denoise inputs, loader inputs). The actual node graphs produced +// by these helpers are irrelevant here. +vi.mock('features/nodes/util/graph/generation/addTextToImage', () => ({ + addTextToImage: vi.fn(({ l2i }) => l2i), })); - vi.mock('features/nodes/util/graph/generation/addImageToImage', () => ({ addImageToImage: vi.fn(), })); - -vi.mock('features/nodes/util/graph/generation/addInpaint', () => ({ - addInpaint: vi.fn(), -})); - +vi.mock('features/nodes/util/graph/generation/addInpaint', () => ({ addInpaint: vi.fn() })); +vi.mock('features/nodes/util/graph/generation/addOutpaint', () => ({ addOutpaint: vi.fn() })); vi.mock('features/nodes/util/graph/generation/addNSFWChecker', () => ({ addNSFWChecker: vi.fn((_g, node) => node), })); - -vi.mock('features/nodes/util/graph/generation/addOutpaint', () => ({ - addOutpaint: vi.fn(), -})); - -vi.mock('features/nodes/util/graph/generation/addRegions', () => ({ - addRegions: vi.fn(), -})); - -vi.mock('features/nodes/util/graph/generation/addTextToImage', () => ({ - addTextToImage: vi.fn(({ l2i }) => l2i), -})); - vi.mock('features/nodes/util/graph/generation/addWatermarker', () => ({ addWatermarker: vi.fn((_g, node) => node), })); - +vi.mock('features/nodes/util/graph/generation/addRegions', () => ({ addRegions: vi.fn(() => []) })); +vi.mock('features/nodes/util/graph/generation/addFLUXLoRAs', () => ({ addFLUXLoRAs: vi.fn() })); +vi.mock('features/nodes/util/graph/generation/addFlux2KleinLoRAs', () => ({ addFlux2KleinLoRAs: vi.fn() })); +vi.mock('features/nodes/util/graph/generation/addFLUXFill', () => ({ addFLUXFill: vi.fn() })); +vi.mock('features/nodes/util/graph/generation/addFLUXRedux', () => ({ + addFLUXReduxes: vi.fn(() => ({ addedFLUXReduxes: 0 })), +})); vi.mock('features/nodes/util/graph/generation/addControlAdapters', () => ({ + addControlNets: vi.fn(() => Promise.resolve({ addedControlNets: 0 })), addControlLoRA: vi.fn(), - addControlNets: vi.fn(), })); - vi.mock('features/nodes/util/graph/generation/addIPAdapters', () => ({ - addIPAdapters: vi.fn(), -})); - -vi.mock('features/nodes/util/graph/graphBuilderUtils', () => ({ - selectCanvasOutputFields: vi.fn(() => ({})), - selectPresetModifiedPrompts: vi.fn(() => ({ - positive: 'a prompt', - negative: '', - })), -})); - -vi.mock('features/ui/store/uiSelectors', () => ({ - selectActiveTab: vi.fn(() => 'generation'), + addIPAdapters: vi.fn(() => ({ addedIPAdapters: 0 })), })); vi.mock('services/api/hooks/modelsByType', () => ({ @@ -192,20 +196,36 @@ vi.mock('services/api/types', async () => { }; }); -import { buildFLUXGraph } from './buildFLUXGraph'; +import type { GraphBuilderArg } from 'features/nodes/util/graph/types'; +import type { Invocation } from 'services/api/types'; -const buildGraphArg = () => ({ - generationMode: 'txt2img' as const, - manager: null, - state: { - system: { - shouldUseNSFWChecker: false, - shouldUseWatermarker: false, +import { buildFLUXGraph } from './buildFLUXGraph'; +import type { Graph } from './Graph'; + +// --------------------------------------------------------------------------- +// Test harness +// --------------------------------------------------------------------------- + +const buildGraphArg = (): GraphBuilderArg => + ({ + generationMode: 'txt2img', + manager: null, + state: { + system: { + shouldUseNSFWChecker: false, + shouldUseWatermarker: false, + }, }, - } as never, -}); + }) as unknown as GraphBuilderArg; + +const findFlux2Denoise = (g: Graph): Invocation<'flux2_denoise'> | undefined => { + const nodes = (g as unknown as { _graph: { nodes: Record } })._graph.nodes; + return Object.values(nodes).find((n) => n.type === 'flux2_denoise') as Invocation<'flux2_denoise'> | undefined; +}; + +const getMetadata = (g: Graph): Record => + (g as unknown as { getMetadataNode: () => Record }).getMetadataNode(); -/** Find the flux2_klein_model_loader node in the graph. */ const getLoaderNode = async () => { const { g } = await buildFLUXGraph(buildGraphArg()); const graph = g.getGraph(); @@ -213,24 +233,85 @@ const getLoaderNode = async () => { return loaderEntry?.[1] as Record | undefined; }; -describe('buildFLUXGraph – FLUX.2 Klein qwen3_source_model', () => { - afterEach(() => { - nextId = 0; - model = { ...flux2DiffusersModel }; - kleinVaeModel = null; - kleinQwen3EncoderModel = null; - diffusersModels = []; +const resetState = () => { + nextId = 0; + currentModel = null; + currentKleinVae = null; + currentKleinQwen3 = null; + diffusersModels = []; +}; + +beforeEach(resetState); +afterEach(resetState); + +describe('buildFLUXGraph (FLUX.2 Klein)', () => { + describe('guidance gating', () => { + // guidance_embeds is inert for all current FLUX.2 Klein variants (weights are + // absent or zeroed), so the linear UI does not expose it and the graph builder + // must not write it into the denoise node or metadata. + it.each(['klein_9b_base', 'klein_9b', 'klein_4b_base', 'klein_4b'])( + 'omits guidance from metadata and denoise for variant %s', + async (variant) => { + currentModel = makeFlux2Model(variant); + + const { g } = await buildFLUXGraph(buildGraphArg()); + + const metadata = getMetadata(g); + expect(metadata.guidance).toBeUndefined(); + + const denoise = findFlux2Denoise(g); + expect(denoise).toBeDefined(); + expect(denoise?.guidance).toBeUndefined(); + } + ); }); + describe('scheduler persistence', () => { + it('writes the FLUX scheduler into metadata and the denoise node for FLUX.2', async () => { + currentModel = makeFlux2Model('klein_9b_base'); + + const { g } = await buildFLUXGraph(buildGraphArg()); + + expect(getMetadata(g).scheduler).toBe(mockParams.fluxScheduler); + expect(findFlux2Denoise(g)?.scheduler).toBe(mockParams.fluxScheduler); + }); + }); + + describe('Klein VAE / Qwen3 metadata', () => { + it('persists separately selected Klein VAE and Qwen3 encoder into metadata', async () => { + currentModel = makeFlux2Model('klein_9b_base'); + currentKleinVae = { key: 'vae-1', hash: 'h', name: 'Klein VAE', base: 'flux2', type: 'vae' }; + currentKleinQwen3 = { key: 'q3-1', hash: 'h', name: 'Qwen3', base: 'flux2', type: 'qwen3_encoder' }; + + const { g } = await buildFLUXGraph(buildGraphArg()); + + const metadata = getMetadata(g); + expect(metadata.vae).toEqual(currentKleinVae); + expect(metadata.qwen3_encoder).toEqual(currentKleinQwen3); + }); + + it('omits vae / qwen3_encoder when none are selected', async () => { + currentModel = makeFlux2Model('klein_9b_base'); + + const { g } = await buildFLUXGraph(buildGraphArg()); + + const metadata = getMetadata(g); + expect(metadata.vae).toBeUndefined(); + expect(metadata.qwen3_encoder).toBeUndefined(); + }); + }); +}); + +describe('buildFLUXGraph – FLUX.2 Klein qwen3_source_model', () => { it('does not set qwen3_source_model when main model is diffusers', async () => { - model = { ...flux2DiffusersModel }; + currentModel = { ...flux2DiffusersModel }; const loader = await getLoaderNode(); expect(loader).toBeDefined(); expect(loader!.qwen3_source_model).toBeUndefined(); }); it('sets qwen3_source_model when main model is GGUF and a diffusers model is available', async () => { - model = { ...flux2GGUFModel }; + currentModel = { ...flux2GGUFModel }; diffusersModels = [diffusersSourceModelFixture]; const loader = await getLoaderNode(); @@ -245,9 +326,9 @@ describe('buildFLUXGraph – FLUX.2 Klein qwen3_source_model', () => { }); it('does not set qwen3_source_model when main model is GGUF but standalone VAE and Qwen3 are both selected', async () => { - model = { ...flux2GGUFModel }; - kleinVaeModel = kleinVaeModelFixture; - kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + currentModel = { ...flux2GGUFModel }; + currentKleinVae = kleinVaeModelFixture; + currentKleinQwen3 = kleinQwen3EncoderModelFixture; diffusersModels = [diffusersSourceModelFixture]; const loader = await getLoaderNode(); @@ -256,7 +337,7 @@ describe('buildFLUXGraph – FLUX.2 Klein qwen3_source_model', () => { }); it('does not set qwen3_source_model when main model is GGUF and no diffusers model is available', async () => { - model = { ...flux2GGUFModel }; + currentModel = { ...flux2GGUFModel }; diffusersModels = []; const loader = await getLoaderNode(); @@ -265,9 +346,9 @@ describe('buildFLUXGraph – FLUX.2 Klein qwen3_source_model', () => { }); it('sets qwen3_source_model when only VAE is selected but Qwen3 is missing', async () => { - model = { ...flux2GGUFModel }; - kleinVaeModel = kleinVaeModelFixture; - kleinQwen3EncoderModel = null; + currentModel = { ...flux2GGUFModel }; + currentKleinVae = kleinVaeModelFixture; + currentKleinQwen3 = null; diffusersModels = [diffusersSourceModelFixture]; const loader = await getLoaderNode(); @@ -276,9 +357,9 @@ describe('buildFLUXGraph – FLUX.2 Klein qwen3_source_model', () => { }); it('sets qwen3_source_model when only Qwen3 is selected but VAE is missing', async () => { - model = { ...flux2GGUFModel }; - kleinVaeModel = null; - kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + currentModel = { ...flux2GGUFModel }; + currentKleinVae = null; + currentKleinQwen3 = kleinQwen3EncoderModelFixture; diffusersModels = [diffusersSourceModelFixture]; const loader = await getLoaderNode(); @@ -287,9 +368,9 @@ describe('buildFLUXGraph – FLUX.2 Klein qwen3_source_model', () => { }); it('passes standalone vae_model and qwen3_encoder_model when selected', async () => { - model = { ...flux2DiffusersModel }; - kleinVaeModel = kleinVaeModelFixture; - kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + currentModel = { ...flux2DiffusersModel }; + currentKleinVae = kleinVaeModelFixture; + currentKleinQwen3 = kleinQwen3EncoderModelFixture; const loader = await getLoaderNode(); expect(loader).toBeDefined(); @@ -300,43 +381,38 @@ describe('buildFLUXGraph – FLUX.2 Klein qwen3_source_model', () => { describe('variant matching', () => { it('selects a variant-matching diffusers model when multiple are available', async () => { - model = { ...flux2GGUF9BModel }; + currentModel = { ...flux2GGUF9BModel }; diffusersModels = [diffusersSourceModelFixture, diffusers9BSourceModelFixture]; const loader = await getLoaderNode(); expect(loader).toBeDefined(); - // Should pick the 9B diffusers model, not the 4B expect(loader!.qwen3_source_model).toEqual(expect.objectContaining({ key: diffusers9BSourceModelFixture.key })); }); it('falls back to any diffusers model for VAE when standalone Qwen3 is selected but no variant match', async () => { - model = { ...flux2GGUF9BModel }; - kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; - // Only 4B diffusers available, no 9B — but Qwen3 is already provided standalone + currentModel = { ...flux2GGUF9BModel }; + currentKleinQwen3 = kleinQwen3EncoderModelFixture; diffusersModels = [diffusersSourceModelFixture]; const loader = await getLoaderNode(); expect(loader).toBeDefined(); - // Should use the 4B diffusers model just for VAE extraction expect(loader!.qwen3_source_model).toEqual(expect.objectContaining({ key: diffusersSourceModelFixture.key })); }); it('does not set qwen3_source_model when GGUF 9B with only 4B diffusers available and no standalone Qwen3', async () => { - model = { ...flux2GGUF9BModel }; - kleinQwen3EncoderModel = null; - // Only 4B diffusers available — wrong variant for Qwen3, no standalone Qwen3 selected + currentModel = { ...flux2GGUF9BModel }; + currentKleinQwen3 = null; diffusersModels = [diffusersSourceModelFixture]; const loader = await getLoaderNode(); expect(loader).toBeDefined(); - // Should NOT use the 4B diffusers since it has the wrong Qwen3 encoder expect(loader!.qwen3_source_model).toBeUndefined(); }); }); describe('graph structure', () => { it('uses flux2_klein_model_loader for flux2 models', async () => { - model = { ...flux2DiffusersModel }; + currentModel = { ...flux2DiffusersModel }; const { g } = await buildFLUXGraph(buildGraphArg()); const graph = g.getGraph(); const nodeIds = Object.keys(graph.nodes); @@ -344,7 +420,7 @@ describe('buildFLUXGraph – FLUX.2 Klein qwen3_source_model', () => { }); it('uses flux2_vae_decode for flux2 models', async () => { - model = { ...flux2DiffusersModel }; + currentModel = { ...flux2DiffusersModel }; const { g } = await buildFLUXGraph(buildGraphArg()); const graph = g.getGraph(); const nodeIds = Object.keys(graph.nodes); @@ -352,7 +428,7 @@ describe('buildFLUXGraph – FLUX.2 Klein qwen3_source_model', () => { }); it('uses flux2_klein_text_encoder for flux2 models', async () => { - model = { ...flux2DiffusersModel }; + currentModel = { ...flux2DiffusersModel }; const { g } = await buildFLUXGraph(buildGraphArg()); const graph = g.getGraph(); const nodeIds = Object.keys(graph.nodes); @@ -360,7 +436,7 @@ describe('buildFLUXGraph – FLUX.2 Klein qwen3_source_model', () => { }); it('uses flux2_denoise for flux2 models', async () => { - model = { ...flux2DiffusersModel }; + currentModel = { ...flux2DiffusersModel }; const { g } = await buildFLUXGraph(buildGraphArg()); const graph = g.getGraph(); const nodeTypes = Object.values(graph.nodes).map((n) => n.type); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index 407c921421f..dafcd9310ec 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -180,6 +180,7 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise = { model: Graph.getModelMetadataField(model), steps, + scheduler: fluxScheduler, }; if (kleinVaeModel) { flux2Metadata.vae = kleinVaeModel; diff --git a/invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts b/invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts index b9508a4f82d..8b9363895a0 100644 --- a/invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts +++ b/invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts @@ -6,6 +6,7 @@ */ export const KLEIN_TO_QWEN3_VARIANT_MAP: Record = { klein_4b: 'qwen3_4b', + klein_4b_base: 'qwen3_4b', klein_9b: 'qwen3_8b', klein_9b_base: 'qwen3_8b', }; diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx index 7e07c9f5648..220008a38b0 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx @@ -105,12 +105,12 @@ export const GenerationSettingsAccordion = memo(() => { !isZImage && !isQwenImage && !isAnima && } - {!isExternal && isFLUX && } + {!isExternal && (isFLUX || isFlux2) && } {!isExternal && isZImage && } {!isExternal && isAnima && } {modelSupportsSteps && } {isExternal && modelSupportsGuidance && } - {!isExternal && (isFLUX || isFlux2) && modelConfig && !isFluxFillMainModelModelConfig(modelConfig) && ( + {!isExternal && isFLUX && modelConfig && !isFluxFillMainModelModelConfig(modelConfig) && ( )} {!isExternal && !isFLUX && !isFlux2 && } diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index cd184c0d155..dde4ad0485e 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -9901,6 +9901,12 @@ export type components = { * @default null */ negative_text_conditioning?: components["schemas"]["FluxConditioningField"] | null; + /** + * Guidance + * @description Guidance strength for distilled guidance-embedding models. Inert for all current FLUX.2 Klein variants (their guidance_embeds weights are absent/zero); kept for node-graph compatibility and future guidance-embedded models. + * @default 4 + */ + guidance?: number; /** * CFG Scale * @description Classifier-Free Guidance scale @@ -10331,7 +10337,7 @@ export type components = { * @description FLUX.2 model variants. * @enum {string} */ - Flux2VariantType: "klein_4b" | "klein_9b" | "klein_9b_base"; + Flux2VariantType: "klein_4b" | "klein_4b_base" | "klein_9b" | "klein_9b_base"; /** * FluxConditioningCollectionOutput * @description Base class for nodes that output a collection of conditioning tensors