diff --git a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx index 2bf8d8643c7..5a4abdd4d28 100644 --- a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx +++ b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx @@ -4,13 +4,14 @@ import { useStore } from '@nanostores/react'; import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetQueueStatusQuery } from 'services/api/endpoints/queue'; -import { $isConnected, $lastProgressEvent } from 'services/events/stores'; +import { $isConnected, $lastProgressEvent, $loadingModelsCount } from 'services/events/stores'; const ProgressBar = (props: ProgressProps) => { const { t } = useTranslation(); const { data: queueStatus } = useGetQueueStatusQuery(); const isConnected = useStore($isConnected); const lastProgressEvent = useStore($lastProgressEvent); + const loadingModelsCount = useStore($loadingModelsCount); const value = useMemo(() => { if (!lastProgressEvent) { return 0; @@ -23,6 +24,10 @@ const ProgressBar = (props: ProgressProps) => { return false; } + if (loadingModelsCount > 0) { + return true; + } + if (!queueStatus?.queue.in_progress) { return false; } @@ -40,7 +45,7 @@ const ProgressBar = (props: ProgressProps) => { } return false; - }, [isConnected, lastProgressEvent, queueStatus?.queue.in_progress]); + }, [isConnected, lastProgressEvent, queueStatus?.queue.in_progress, loadingModelsCount]); return ( { log.debug('Connect error'); setIsConnected(false); $lastProgressEvent.set(null); + $loadingModelsCount.set(0); if (error && error.message) { const data: string | undefined = (error as unknown as { data: string | undefined }).data; if (data === 'ERR_UNAUTHENTICATED') { @@ -95,6 +97,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis socket.on('disconnect', () => { log.debug('Disconnected'); $lastProgressEvent.set(null); + $loadingModelsCount.set(0); setIsConnected(false); }); @@ -183,6 +186,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis const message = `Model load started: ${name} (${extras.join(', ')})`; log.debug({ data }, message); + $loadingModelsCount.set($loadingModelsCount.get() + 1); }); socket.on('model_load_complete', (data) => { @@ -197,6 +201,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis const message = `Model load complete: ${name} (${extras.join(', ')})`; log.debug({ data }, message); + $loadingModelsCount.set(Math.max(0, $loadingModelsCount.get() - 1)); }); socket.on('download_started', (data) => { diff --git a/invokeai/frontend/web/src/services/events/stores.ts b/invokeai/frontend/web/src/services/events/stores.ts index 720ba920cf2..180f4a3a636 100644 --- a/invokeai/frontend/web/src/services/events/stores.ts +++ b/invokeai/frontend/web/src/services/events/stores.ts @@ -6,6 +6,7 @@ import type { AppSocket } from 'services/events/types'; export const $socket = atom(null); export const $isConnected = atom(false); export const $lastProgressEvent = atom(null); +export const $loadingModelsCount = atom(0); export const $lastProgressMessage = computed($lastProgressEvent, (val) => { if (!val) {