diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 087aa625bd..4c4f8ab58d 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -722,6 +722,13 @@ async def delete_platform_session(self, session_id: str) -> None: """Delete a Platform session by its ID.""" ... + @abc.abstractmethod + async def migrate_user_webchat_data( + self, old_username: str, new_username: str + ) -> None: + """Migrate all webchat user data when username is changed.""" + ... + # ==== # ChatUI Project Management # ==== diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index fd6668c0c7..e745b9579e 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1616,6 +1616,47 @@ async def delete_platform_session(self, session_id: str) -> None: ), ) + async def migrate_user_webchat_data( + self, old_username: str, new_username: str + ) -> None: + """Migrate all webchat user data when username is changed.""" + old_fragment = f"!{old_username}!" + new_fragment = f"!{new_username}!" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + update(PlatformSession) + .where(col(PlatformSession.creator) == old_username) + .values(creator=new_username) + ) + await session.execute( + update(ChatUIProject) + .where(col(ChatUIProject.creator) == old_username) + .values(creator=new_username) + ) + await session.execute( + update(ConversationV2) + .where(col(ConversationV2.user_id).like("webchat%")) + .values( + user_id=func.replace( + ConversationV2.user_id, old_fragment, new_fragment + ) + ) + ) + await session.execute( + update(Preference) + .where( + col(Preference.scope) == "umo", + col(Preference.scope_id).like("webchat%"), + ) + .values( + scope_id=func.replace( + Preference.scope_id, old_fragment, new_fragment + ) + ) + ) + # ==== # ChatUI Project Management # ==== diff --git a/astrbot/dashboard/routes/auth.py b/astrbot/dashboard/routes/auth.py index eac5f65b0b..ef7d559492 100644 --- a/astrbot/dashboard/routes/auth.py +++ b/astrbot/dashboard/routes/auth.py @@ -6,13 +6,15 @@ from astrbot import logger from astrbot.core import DEMO_MODE +from astrbot.core.db import BaseDatabase from .route import Response, Route, RouteContext class AuthRoute(Route): - def __init__(self, context: RouteContext) -> None: + def __init__(self, context: RouteContext, db: BaseDatabase) -> None: super().__init__(context) + self.db = db self.routes = { "/auth/login": ("POST", self.login), "/auth/account/edit": ("POST", self.edit_account), @@ -72,9 +74,15 @@ async def edit_account(self): if confirm_pwd != new_pwd: return Response().error("两次输入的新密码不一致").__dict__ self.config["dashboard"]["password"] = new_pwd + + old_username = self.config["dashboard"]["username"] if new_username: self.config["dashboard"]["username"] = new_username + # Migrate webchat user data before saving config to keep them in sync. + if new_username and new_username != old_username: + await self.db.migrate_user_webchat_data(old_username, new_username) + self.config.save_config() return Response().ok(None, "修改成功").__dict__ diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index cbb7296bd0..2a3f90f479 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -112,7 +112,7 @@ def __init__( self.cr = ConfigRoute(self.context, core_lifecycle) self.lr = LogRoute(self.context, core_lifecycle.log_broker) self.sfr = StaticFileRoute(self.context) - self.ar = AuthRoute(self.context) + self.ar = AuthRoute(self.context, db) self.api_key_route = ApiKeyRoute(self.context, db) self.chat_route = ChatRoute(self.context, db, core_lifecycle) self.open_api_route = OpenApiRoute(