From c5778bb3e75b5fd8e907665a5a200ae98fa60c58 Mon Sep 17 00:00:00 2001 From: liumangmang Date: Fri, 29 May 2026 16:00:43 +0800 Subject: [PATCH] feat: persist browser sessions and update admin workflows --- backend/app/database.py | 4 + backend/app/models/upstream.py | 3 + backend/app/routers/upstreams.py | 2 + backend/app/routers/websites.py | 46 ++- backend/app/schemas/upstream.py | 3 + backend/app/schemas/website.py | 1 + .../app/services/browser_session_service.py | 109 ++++++ backend/app/services/scheduler.py | 31 ++ backend/app/services/webhook_service.py | 38 +++ backend/app/services/website_client.py | 6 + backend/app/utils/dingtalk.py | 19 ++ backend/test_browser_session_service.py | 293 ++++++++++++++++- frontend/src/api/index.ts | 3 + frontend/src/components/AppLayout.vue | 31 +- frontend/src/router/index.ts | 4 +- frontend/src/views/NotificationLogs.vue | 105 ++++-- frontend/src/views/Upstreams.vue | 309 ++++++------------ frontend/src/views/Webhooks.vue | 62 ++-- frontend/src/views/Websites.vue | 129 +++----- 19 files changed, 829 insertions(+), 369 deletions(-) diff --git a/backend/app/database.py b/backend/app/database.py index 0d10afd..ae18562 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -126,6 +126,10 @@ def _migrate_upstreams(): conn.execute(text("ALTER TABLE upstreams ADD COLUMN balance_response_path VARCHAR(256) NOT NULL DEFAULT ''")) if "balance_divisor" not in columns: conn.execute(text("ALTER TABLE upstreams ADD COLUMN balance_divisor FLOAT NOT NULL DEFAULT 1.0")) + if "balance_alert_threshold" not in columns: + conn.execute(text("ALTER TABLE upstreams ADD COLUMN balance_alert_threshold FLOAT")) + if "balance_alert_notified" not in columns: + conn.execute(text("ALTER TABLE upstreams ADD COLUMN balance_alert_notified BOOLEAN NOT NULL DEFAULT 0")) def _migrate_upstream_generated_keys(): diff --git a/backend/app/models/upstream.py b/backend/app/models/upstream.py index d41130a..2ea70fc 100644 --- a/backend/app/models/upstream.py +++ b/backend/app/models/upstream.py @@ -32,6 +32,9 @@ class Upstream(Base): balance_endpoint: Mapped[str] = mapped_column(String(256), default="") balance_response_path: Mapped[str] = mapped_column(String(256), default="") balance_divisor: Mapped[float] = mapped_column(Float, default=1.0) + # Balance alert + balance_alert_threshold: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + balance_alert_notified: Mapped[bool] = mapped_column(Boolean, default=False) created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc)) updated_at: Mapped[datetime] = mapped_column( DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc) diff --git a/backend/app/routers/upstreams.py b/backend/app/routers/upstreams.py index a866ebe..3ad3c6f 100644 --- a/backend/app/routers/upstreams.py +++ b/backend/app/routers/upstreams.py @@ -112,6 +112,7 @@ def _to_response(u: Upstream) -> UpstreamResponse: balance_endpoint=u.balance_endpoint or "", balance_response_path=u.balance_response_path or "", balance_divisor=u.balance_divisor or 1.0, + balance_alert_threshold=u.balance_alert_threshold, created_at=u.created_at, updated_at=u.updated_at, ) @@ -352,6 +353,7 @@ def create_upstream( balance_endpoint=body.balance_endpoint, balance_response_path=body.balance_response_path, balance_divisor=body.balance_divisor, + balance_alert_threshold=body.balance_alert_threshold, ) db.add(u) db.commit() diff --git a/backend/app/routers/websites.py b/backend/app/routers/websites.py index 1b352b7..fc91dcf 100644 --- a/backend/app/routers/websites.py +++ b/backend/app/routers/websites.py @@ -169,6 +169,28 @@ def _numeric_group_id(value: str | None) -> int | None: return None +def _build_rate_priority_map(db: Session, upstream_ids: set[int]) -> dict[str, int]: + """根据上游分组倍率构建 group_id → priority 映射。 + + 遍历所有涉及的上游的最新快照,收集分组的倍率,按倍率升序排列后赋值 priority。 + 倍率最低的 priority=1,次低的 priority=2,以此类推。相同倍率的分组共享同一 priority。 + """ + group_rates: dict[str, float] = {} + for uid in upstream_ids: + groups = _latest_upstream_groups(db, uid) + for g in groups: + gid = _source_group_id(g) + rate = _source_group_rate(g) + if gid: + # 同一 group_id 在同个 upstream 内是唯一的;跨 upstream 的相同 group_id + # 如果倍率不同则以最后遇到的为准(实际很少冲突) + group_rates[gid] = rate + # 按倍率排序分配 priority + unique_rates = sorted(set(group_rates.values())) + rate_to_priority = {rate: idx + 1 for idx, rate in enumerate(unique_rates)} + return {gid: rate_to_priority[rate] for gid, rate in group_rates.items()} + + @router.get("/api/websites", response_model=List[WebsiteResponse]) def list_websites(db: Session = Depends(get_db), _=Depends(get_current_user)): return [_website_response(row) for row in db.query(Website).order_by(Website.id).all()] @@ -496,6 +518,16 @@ def import_upstream_keys_as_accounts( if _u: upstream_base_url = _u.base_url + # 按倍率自动分配优先级 + rate_priority_map: dict[str, int] = {} + if body.auto_priority_by_rate: + upstream_ids = {row.upstream_id for row in rows} + try: + rate_priority_map = _build_rate_priority_map(db, upstream_ids) + except HTTPException: + # 没有快照时忽略,后续 fallback 到 body.priority + pass + with _client(website) as c: for row in rows: # 先确定平台(失败项也需要记录) @@ -512,6 +544,16 @@ def import_upstream_keys_as_accounts( old_account_id = row.imported_account_id exists = c.account_exists(row.imported_account_id) if exists is True: + # 自动更新已有账号的 priority(分步导入时全局倍率排序可能已变) + new_priority = rate_priority_map.get(row.group_id) if body.auto_priority_by_rate else None + priority_msg = "已导入过,已跳过" + if new_priority is not None: + try: + c.update_account(old_account_id, {"priority": new_priority}) + priority_msg = f"已导入过,优先级已更新为 {new_priority}" + except Exception as exc: + logger.warning("update priority failed account=%s: %s", old_account_id, exc) + priority_msg = f"已导入过,优先级更新失败: {exc}" items.append(ImportAccountItem( upstream_key_id=row.id, source_group_id=row.group_id, @@ -522,7 +564,7 @@ def import_upstream_keys_as_accounts( platform=platform, upstream_base_url=upstream_base_url, status="exists", - message="已导入过,已跳过", + message=priority_msg, )) continue elif exists is False: @@ -574,7 +616,7 @@ def import_upstream_keys_as_accounts( "group_ids": group_ids, "rate_multiplier": 1, "concurrency": body.concurrency, - "priority": body.priority, + "priority": rate_priority_map.get(row.group_id, body.priority) if body.auto_priority_by_rate else body.priority, "notes": f"Imported by SmartUp from upstream key #{row.id}", } try: diff --git a/backend/app/schemas/upstream.py b/backend/app/schemas/upstream.py index 059e75b..b25978c 100644 --- a/backend/app/schemas/upstream.py +++ b/backend/app/schemas/upstream.py @@ -32,6 +32,7 @@ class UpstreamCreate(BaseModel): balance_endpoint: str = "" balance_response_path: str = "" balance_divisor: float = 1.0 + balance_alert_threshold: Optional[float] = None class UpstreamUpdate(BaseModel): @@ -48,6 +49,7 @@ class UpstreamUpdate(BaseModel): balance_endpoint: Optional[str] = None balance_response_path: Optional[str] = None balance_divisor: Optional[float] = None + balance_alert_threshold: Optional[float] = None class UpstreamResponse(BaseModel): @@ -70,6 +72,7 @@ class UpstreamResponse(BaseModel): balance_endpoint: str = "" balance_response_path: str = "" balance_divisor: float = 1.0 + balance_alert_threshold: Optional[float] = None created_at: datetime updated_at: datetime diff --git a/backend/app/schemas/website.py b/backend/app/schemas/website.py index efbf87e..cb18387 100644 --- a/backend/app/schemas/website.py +++ b/backend/app/schemas/website.py @@ -157,6 +157,7 @@ class ImportAccountsRequest(BaseModel): platform_mode: str = "auto" # "auto" | "manual" concurrency: int = Field(default=10, ge=1) priority: int = Field(default=1, ge=0) + auto_priority_by_rate: bool = True class ImportAccountItem(BaseModel): diff --git a/backend/app/services/browser_session_service.py b/backend/app/services/browser_session_service.py index 3f3b121..8c3cbd6 100644 --- a/backend/app/services/browser_session_service.py +++ b/backend/app/services/browser_session_service.py @@ -34,6 +34,7 @@ class BrowserSession: lock: asyncio.Lock cdp_session: Any = None captured_headers: list[dict] = None # auth headers from CDP + last_saved_state_at: float = 0.0 class BrowserSessionService: @@ -92,12 +93,15 @@ class BrowserSessionService: self._profiles.pop(profile_key, None) # Idle cleanup: close stale sessions before spawning new ones await self._evict_idle_sessions() + context = await self._playwright.chromium.launch_persistent_context( str(self._profile_dir(profile_key)), headless=get_settings().browser_headless, viewport={"width": width, "height": height}, + color_scheme="dark", args=["--no-sandbox", "--disable-dev-shm-usage"], ) + await self._restore_session_state(context, profile_key) # Grant clipboard access for the page origin try: parsed = urlparse(url) @@ -137,6 +141,11 @@ class BrowserSessionService: self._touch(session_id) async with session.lock: self._ensure_open(session) + if session.profile_key and not session.profile_key.startswith("auth-capture-"): + now = time.monotonic() + if now - session.last_saved_state_at > 10.0: + await self._save_session_state(session) + session.last_saved_state_at = now return await session.page.screenshot(type="jpeg", quality=65, full_page=False) async def event( @@ -188,6 +197,12 @@ class BrowserSessionService: await page.set_viewport_size({"width": width, "height": height}) else: raise ValueError("Unsupported browser event") + if session.profile_key and not session.profile_key.startswith("auth-capture-"): + now = time.monotonic() + if now - session.last_saved_state_at > 5.0: + await self._save_session_state(session) + session.last_saved_state_at = now + if not include_state: return None return await self._session_state(session) @@ -242,6 +257,15 @@ class BrowserSessionService: session = self._discard_session(session_id) if not session: return + + # 在完全关闭 context 前,强制将最新的状态落盘保存 + if session.profile_key and not session.profile_key.startswith("auth-capture-"): + try: + if not session.page.is_closed(): + await self._save_session_state(session) + except Exception as exc: + logger.debug("failed to save state during close: %s", exc) + # Detach CDP session if active if session.cdp_session: try: @@ -261,6 +285,7 @@ class BrowserSessionService: except Exception: pass + async def shutdown(self) -> None: # Cancel the background eviction loop if self._evict_task is not None and not self._evict_task.done(): @@ -524,6 +549,9 @@ class BrowserSessionService: profile.mkdir(parents=True, exist_ok=True) return profile + def _cookies_path(self, profile_key: str) -> Path: + return self._profile_dir(profile_key) / "session-cookies.json" + def _profile_key(self, custom_page_id: int, url: str) -> str: parsed = urlparse(url) origin = f"{parsed.scheme}-{parsed.netloc}".lower() @@ -553,6 +581,7 @@ class BrowserSessionService: str(self._profile_dir(profile_key)), headless=get_settings().browser_headless, viewport={"width": width, "height": height}, + color_scheme="dark", args=["--no-sandbox", "--disable-dev-shm-usage"], ) # Grant clipboard access for the page origin @@ -613,5 +642,85 @@ class BrowserSessionService: except Exception as exc: logger.debug("CDP capture not available: %s", exc) + async def _save_session_state(self, session: BrowserSession) -> None: + if not session.profile_key or session.profile_key.startswith("auth-capture-"): + return + try: + state = await session.context.storage_state() + cookies_path = self._cookies_path(session.profile_key) + import json + import tempfile + import os + # Ensure parent directories exist + cookies_path.parent.mkdir(parents=True, exist_ok=True) + temp_fd, temp_path = tempfile.mkstemp(dir=str(cookies_path.parent)) + try: + with os.fdopen(temp_fd, 'w', encoding='utf-8') as f: + json.dump(state, f, ensure_ascii=False, indent=2) + os.replace(temp_path, cookies_path) + except Exception: + try: + os.unlink(temp_path) + except Exception: + pass + raise + except Exception as exc: + logger.debug("failed to save session state for %s: %s", session.profile_key, exc) + + async def _restore_session_state(self, context: Any, profile_key: str) -> None: + if profile_key.startswith("auth-capture-"): + return + cookies_path = self._cookies_path(profile_key) + if not cookies_path.exists() or cookies_path.stat().st_size == 0: + return + try: + import json + import time + with open(cookies_path, 'r', encoding='utf-8') as f: + state = json.load(f) + cookies = state.get("cookies", []) + if cookies: + now = time.time() + valid_cookies = [] + for c in cookies: + expires = c.get("expires") + if expires is not None and expires > 0 and expires <= now: + continue + if expires is not None and expires <= 0: + c.pop("expires", None) + valid_cookies.append(c) + if valid_cookies: + await context.add_cookies(valid_cookies) + logger.info("restored %d cookies for profile %s", len(valid_cookies), profile_key) + + # 还原 LocalStorage + origins = state.get("origins", []) + if origins: + origins_json = json.dumps(origins) + init_script = f""" + (() => {{ + try {{ + const origins = {origins_json}; + const currentOrigin = window.location.origin; + const target = origins.find(o => o.origin === currentOrigin); + if (target && target.localStorage) {{ + for (const item of target.localStorage) {{ + try {{ + window.localStorage.setItem(item.name, item.value); + }} catch (e) {{ + console.error('Failed to restore localStorage key', item.name, e); + }} + }} + }} + }} catch (err) {{ + console.error('LocalStorage restore initialization script failed', err); + }} + }})(); + """ + await context.add_init_script(init_script) + logger.info("registered LocalStorage init script for profile %s (origins: %d)", profile_key, len(origins)) + except Exception as exc: + logger.warning("failed to restore cookies/state for profile %s: %s", profile_key, exc) + browser_sessions = BrowserSessionService() diff --git a/backend/app/services/scheduler.py b/backend/app/services/scheduler.py index 8fa97c6..70b1873 100644 --- a/backend/app/services/scheduler.py +++ b/backend/app/services/scheduler.py @@ -46,6 +46,7 @@ def _check_upstream(upstream_id: int) -> None: auth_config = json.loads(upstream.auth_config_json or "{}") was_unhealthy = upstream.last_status == "unhealthy" + balance_alert_triggered = False snapshot = None changes = None @@ -76,6 +77,14 @@ def _check_upstream(upstream_id: int) -> None: if balance is not None: upstream.balance = balance upstream.balance_updated_at = datetime.now(timezone.utc) + # ── 余额告警阈值检查 ── + threshold = upstream.balance_alert_threshold + if threshold is not None and threshold > 0: + if balance < threshold and not upstream.balance_alert_notified: + upstream.balance_alert_notified = True + balance_alert_triggered = True + elif balance >= threshold and upstream.balance_alert_notified: + upstream.balance_alert_notified = False except Exception as exc: # failure path upstream.consecutive_failures = (upstream.consecutive_failures or 0) + 1 @@ -152,6 +161,12 @@ def _check_upstream(upstream_id: int) -> None: _notify_rate_changed(upstream_id, upstream.name, upstream.base_url, changes) _sync_website_bindings(upstream_id, changes) + if balance_alert_triggered: + _notify_balance_low( + upstream_id, upstream.name, upstream.base_url, + upstream.balance, upstream.balance_alert_threshold, + ) + def _notify_status( upstream_id: int, @@ -184,6 +199,22 @@ def _notify_rate_changed( db.close() +def _notify_balance_low( + upstream_id: int, + upstream_name: str, + base_url: str, + balance: float, + threshold: float, +) -> None: + db = SessionLocal() + try: + webhook_service.send_balance_low(db, upstream_id, upstream_name, base_url, balance, threshold) + except Exception: + logger.exception("balance low webhook failed for upstream %s", upstream_name) + finally: + db.close() + + def _sync_upstream_keys(upstream_id: int, snapshot: dict[str, Any], captured_at: datetime) -> None: """上游检测成功后同步 SmartUp Key 状态(远端删除/分组删除)。""" db = SessionLocal() diff --git a/backend/app/services/webhook_service.py b/backend/app/services/webhook_service.py index 87c52ed..42ecacd 100644 --- a/backend/app/services/webhook_service.py +++ b/backend/app/services/webhook_service.py @@ -15,6 +15,7 @@ from app.utils.dingtalk import ( format_dingtalk_rate_changed, format_dingtalk_website_rate_changed, format_dingtalk_status, + format_dingtalk_balance_low, ) @@ -185,6 +186,43 @@ def send_status_event( _log(db, wh, event, generic_payload, "failed", str(exc)) +def send_balance_low( + db: Session, + upstream_id: int, + upstream_name: str, + base_url: str, + balance: float, + threshold: float, +) -> None: + webhooks = ( + db.query(WebhookConfig) + .filter(WebhookConfig.enabled == True) + .all() + ) + event = "upstream_balance_low" + changed_at = _now_iso() + generic_payload = { + "event": event, + "upstream": {"id": upstream_id, "name": upstream_name, "base_url": base_url}, + "balance": balance, + "threshold": threshold, + "changed_at": changed_at, + } + for wh in webhooks: + events = json.loads(wh.events_json or "[]") + if event not in events: + continue + try: + if wh.type == "dingtalk": + msg = format_dingtalk_balance_low(upstream_name, balance, threshold, changed_at) + resp_text = _send_dingtalk(wh.url, wh.secret, msg) + else: + resp_text = _send_generic(wh.url, generic_payload) + _log(db, wh, event, generic_payload, "success", resp_text) + except Exception as exc: + _log(db, wh, event, generic_payload, "failed", str(exc)) + + def send_test_notification(db: Session, webhook: WebhookConfig) -> tuple[bool, str]: payload = { "event": "test", diff --git a/backend/app/services/website_client.py b/backend/app/services/website_client.py index a75f5b2..9c9cc28 100644 --- a/backend/app/services/website_client.py +++ b/backend/app/services/website_client.py @@ -223,6 +223,12 @@ class Sub2ApiWebsiteClient: data = _unwrap_data(resp) return data if isinstance(data, dict) else {"value": data} + def update_account(self, account_id: str, body: dict[str, Any], endpoint: str = "/accounts") -> dict[str, Any]: + """更新远端账号(仅传入需要变更的字段)。""" + resp = self._request("PUT", f"{endpoint}/{account_id}", body) + data = _unwrap_data(resp) + return data if isinstance(data, dict) else {"value": data} + @staticmethod def _unwrap_list(value: dict) -> list | None: """递归展开嵌套的列表包装:data.items、data.data、items、accounts 等。""" diff --git a/backend/app/utils/dingtalk.py b/backend/app/utils/dingtalk.py index 8a0b867..7890f73 100644 --- a/backend/app/utils/dingtalk.py +++ b/backend/app/utils/dingtalk.py @@ -64,6 +64,25 @@ def format_dingtalk_website_rate_changed( } +def format_dingtalk_balance_low( + upstream_name: str, balance: float, threshold: float, changed_at: str +) -> dict[str, Any]: + lines = [ + f"### ⚠️ {upstream_name} 余额不足", + "", + f"- **当前余额**:{balance:.2f}", + f"- **告警阈值**:{threshold:.2f}", + f"- **时间**:{changed_at}", + ] + return { + "msgtype": "markdown", + "markdown": { + "title": f"{upstream_name} 余额不足", + "text": "\n".join(lines), + }, + } + + def format_dingtalk_status(upstream_name: str, event: str, changed_at: str, error: str = "") -> dict[str, Any]: emoji = "🔴" if event == "upstream_unhealthy" else "🟢" label = "服务异常" if event == "upstream_unhealthy" else "服务恢复" diff --git a/backend/test_browser_session_service.py b/backend/test_browser_session_service.py index c95eef9..9cb1616 100644 --- a/backend/test_browser_session_service.py +++ b/backend/test_browser_session_service.py @@ -1,5 +1,6 @@ import asyncio - +from pathlib import Path +from app.config import get_settings from app.routers.auth_capture import _sanitize_candidate from app.services.browser_session_service import BrowserSessionService @@ -127,3 +128,293 @@ def test_sanitize_candidate_strips_secret_fields_but_keeps_metadata(): "cookie_name": "session", "domain": "example.test", } + + +def test_cookies_path_mapping(): + import tempfile + import shutil + service = BrowserSessionService() + temp_dir = tempfile.mkdtemp() + original_dir = get_settings().browser_profiles_dir + get_settings().browser_profiles_dir = temp_dir + try: + profile_key = "test-profile-123" + expected_path = Path(service._cookies_path(profile_key)) + assert expected_path.name == "session-cookies.json" + assert expected_path.parent == Path(service._profile_dir(profile_key)) + finally: + get_settings().browser_profiles_dir = original_dir + shutil.rmtree(temp_dir, ignore_errors=True) + + +def test_screenshot_throttled_save(): + import tempfile + import shutil + service = BrowserSessionService() + + # 准备临时目录并 mock settings.browser_profiles_dir + temp_dir = tempfile.mkdtemp() + original_dir = get_settings().browser_profiles_dir + get_settings().browser_profiles_dir = temp_dir + + try: + from unittest.mock import AsyncMock, MagicMock + from app.services.browser_session_service import BrowserSession + + # Mock Context & Page + fake_context = AsyncMock() + fake_page = MagicMock() + fake_page.is_closed = MagicMock(return_value=False) + fake_page.screenshot = AsyncMock(return_value=b"screenshot-bytes") + + session = BrowserSession( + id="session123", + custom_page_id=1, + profile_key="page-1-test", + context=fake_context, + page=fake_page, + lock=asyncio.Lock(), + last_saved_state_at=0.0 + ) + service._sessions[session.id] = session + + # 第一次调用 screenshot: 触发存储 + res1 = run(service.screenshot(session.id)) + assert res1 == b"screenshot-bytes" + assert fake_context.storage_state.call_count == 1 + + # 记录第一次保存后的时间戳 + first_save_time = session.last_saved_state_at + assert first_save_time > 0 + + # 第二次立即调用 screenshot: 应该因为限流 10s 被跳过,不增加 call_count + res2 = run(service.screenshot(session.id)) + assert res2 == b"screenshot-bytes" + assert fake_context.storage_state.call_count == 1 + + # 模拟 11 秒后(防抖时间已过)再度截图 + session.last_saved_state_at = first_save_time - 11.0 + res3 = run(service.screenshot(session.id)) + assert res3 == b"screenshot-bytes" + assert fake_context.storage_state.call_count == 2 + + # 测试临时 auth-capture 会话不触发任何 state 存储 + ephemeral_session = BrowserSession( + id="session-eph", + custom_page_id=0, + profile_key="auth-capture-xyz", + context=fake_context, + page=fake_page, + lock=asyncio.Lock(), + last_saved_state_at=0.0 + ) + service._sessions[ephemeral_session.id] = ephemeral_session + run(service.screenshot(ephemeral_session.id)) + # 它的 call_count 依然是 2,没有增加 + assert fake_context.storage_state.call_count == 2 + + finally: + get_settings().browser_profiles_dir = original_dir + shutil.rmtree(temp_dir, ignore_errors=True) + + +def test_close_saves_state_and_cleans_up(): + import tempfile + import shutil + service = BrowserSessionService() + temp_dir = tempfile.mkdtemp() + original_dir = get_settings().browser_profiles_dir + get_settings().browser_profiles_dir = temp_dir + + try: + from unittest.mock import AsyncMock, MagicMock + from app.services.browser_session_service import BrowserSession + + fake_context = AsyncMock() + fake_page = MagicMock() + fake_page.is_closed = MagicMock(return_value=False) + + import time + session = BrowserSession( + id="session_close", + custom_page_id=2, + profile_key="page-2-test", + context=fake_context, + page=fake_page, + lock=asyncio.Lock(), + last_saved_state_at=time.monotonic() # 此时在限流内 + ) + service._sessions[session.id] = session + + # 即使在限流时间内,close 也必须强制保存 + run(service.close(session.id)) + assert fake_context.storage_state.call_count == 1 + assert fake_context.close.call_count == 1 + + # 测试 ephemeral 会话在 close 时不应该保存 state,并且其 profile_dir 应当被清理,导致 cookies json 不复存在 + eph_context = AsyncMock() + eph_page = MagicMock() + eph_page.is_closed = MagicMock(return_value=False) + eph_session = BrowserSession( + id="session_eph_close", + custom_page_id=0, + profile_key="auth-capture-abc", + context=eph_context, + page=eph_page, + lock=asyncio.Lock(), + last_saved_state_at=0.0 + ) + service._sessions[eph_session.id] = eph_session + + # 先手动创建一个假 session-cookies.json + eph_cookies_path = service._cookies_path(eph_session.profile_key) + eph_cookies_path.write_text("{}") + assert eph_cookies_path.exists() + + run(service.close(eph_session.id)) + # ephemeral close 时不保存,所以 call_count 依然是 0 + assert eph_context.storage_state.call_count == 0 + assert eph_context.close.call_count == 1 + # 但对应的 profile 目录已被删除,文件自然不复存在 + assert not eph_cookies_path.exists() + + finally: + get_settings().browser_profiles_dir = original_dir + shutil.rmtree(temp_dir, ignore_errors=True) + + +def test_restore_session_state_decoding_and_inject(): + import json + import tempfile + import shutil + import time + service = BrowserSessionService() + temp_dir = tempfile.mkdtemp() + original_dir = get_settings().browser_profiles_dir + get_settings().browser_profiles_dir = temp_dir + + try: + from unittest.mock import AsyncMock + fake_context = AsyncMock() + profile_key = "test-restore-profile" + + # 准备假 cookies.json,包含 cookies 和 origins/localStorage + cookies_path = service._cookies_path(profile_key) + + now = time.time() + fake_state = { + "cookies": [ + { + "name": "valid_persistent", + "value": "123", + "domain": "example.test", + "path": "/", + "expires": now + 3600 # 未过期 + }, + { + "name": "expired_cookie", + "value": "456", + "domain": "example.test", + "path": "/", + "expires": now - 3600 # 已过期 + }, + { + "name": "session_cookie", + "value": "789", + "domain": "example.test", + "path": "/", + "expires": -1 # session cookie,应被保留并剔除 expires 字段 + } + ], + "origins": [ + { + "origin": "https://example.test", + "localStorage": [ + { + "name": "theme", + "value": "dark" + } + ] + } + ] + } + with open(cookies_path, "w", encoding='utf-8') as f: + json.dump(fake_state, f) + + # 运行还原方法 + run(service._restore_session_state(fake_context, profile_key)) + + # 检查是否成功调用 add_cookies + assert fake_context.add_cookies.call_count == 1 + + # 检查过滤后的 cookies 内容 + injected_cookies = fake_context.add_cookies.call_args[0][0] + assert len(injected_cookies) == 2 + + names = [c["name"] for c in injected_cookies] + assert "valid_persistent" in names + assert "session_cookie" in names + assert "expired_cookie" not in names + + # 校验 session_cookie 是否成功移除了 expires + session_c = next(c for c in injected_cookies if c["name"] == "session_cookie") + assert "expires" not in session_c + + # 检查是否成功调用了 add_init_script (用于还原 LocalStorage) + assert fake_context.add_init_script.call_count == 1 + init_script = fake_context.add_init_script.call_args[0][0] + assert "window.localStorage.setItem" in init_script + assert "theme" in init_script + assert "dark" in init_script + + finally: + get_settings().browser_profiles_dir = original_dir + shutil.rmtree(temp_dir, ignore_errors=True) + + +def test_websocket_event_saves_state(): + import tempfile + import shutil + service = BrowserSessionService() + temp_dir = tempfile.mkdtemp() + original_dir = get_settings().browser_profiles_dir + get_settings().browser_profiles_dir = temp_dir + + try: + from unittest.mock import AsyncMock, MagicMock + from app.services.browser_session_service import BrowserSession + + fake_context = AsyncMock() + fake_page = MagicMock() + fake_page.is_closed = MagicMock(return_value=False) + fake_page.mouse = MagicMock() + fake_page.mouse.click = AsyncMock() + + session = BrowserSession( + id="session_ws", + custom_page_id=3, + profile_key="page-3-ws-test", + context=fake_context, + page=fake_page, + lock=asyncio.Lock(), + last_saved_state_at=0.0 + ) + service._sessions[session.id] = session + + # 即使 include_state=False,也应当在 5 秒节流到期后保存状态 + run(service.event( + session_id=session.id, + event_type="click", + payload={"x": 10.0, "y": 20.0}, + include_state=False + )) + + # storage_state 应该被调用,说明保存成功触发了 + assert fake_context.storage_state.call_count == 1 + assert session.last_saved_state_at > 0 + + finally: + get_settings().browser_profiles_dir = original_dir + shutil.rmtree(temp_dir, ignore_errors=True) + + diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index 2a850ab..9dc21df 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -88,6 +88,7 @@ export interface UpstreamData { balance_endpoint: string balance_response_path: string balance_divisor: number + balance_alert_threshold: number | null created_at: string updated_at: string } @@ -106,6 +107,7 @@ export interface UpstreamForm { balance_endpoint: string balance_response_path: string balance_divisor: number + balance_alert_threshold: number | null } export interface GeneratedUpstreamKey { @@ -284,6 +286,7 @@ export const websitesApi = { platform_mode?: string concurrency?: number priority?: number + auto_priority_by_rate?: boolean }) => api.post<{ success: boolean; message: string; items: ImportAccountItem[] }>(`/api/websites/${id}/accounts/import-upstream-keys`, data), listBindings: () => api.get('/api/group-bindings'), createBinding: (data: GroupBindingForm) => api.post('/api/group-bindings', data), diff --git a/frontend/src/components/AppLayout.vue b/frontend/src/components/AppLayout.vue index 65fa327..f39aac6 100644 --- a/frontend/src/components/AppLayout.vue +++ b/frontend/src/components/AppLayout.vue @@ -19,13 +19,6 @@