feat: persist browser sessions and update admin workflows
This commit is contained in:
@@ -126,6 +126,10 @@ def _migrate_upstreams():
|
||||
conn.execute(text("ALTER TABLE upstreams ADD COLUMN balance_response_path VARCHAR(256) NOT NULL DEFAULT ''"))
|
||||
if "balance_divisor" not in columns:
|
||||
conn.execute(text("ALTER TABLE upstreams ADD COLUMN balance_divisor FLOAT NOT NULL DEFAULT 1.0"))
|
||||
if "balance_alert_threshold" not in columns:
|
||||
conn.execute(text("ALTER TABLE upstreams ADD COLUMN balance_alert_threshold FLOAT"))
|
||||
if "balance_alert_notified" not in columns:
|
||||
conn.execute(text("ALTER TABLE upstreams ADD COLUMN balance_alert_notified BOOLEAN NOT NULL DEFAULT 0"))
|
||||
|
||||
|
||||
def _migrate_upstream_generated_keys():
|
||||
|
||||
@@ -32,6 +32,9 @@ class Upstream(Base):
|
||||
balance_endpoint: Mapped[str] = mapped_column(String(256), default="")
|
||||
balance_response_path: Mapped[str] = mapped_column(String(256), default="")
|
||||
balance_divisor: Mapped[float] = mapped_column(Float, default=1.0)
|
||||
# Balance alert
|
||||
balance_alert_threshold: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||
balance_alert_notified: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)
|
||||
|
||||
@@ -112,6 +112,7 @@ def _to_response(u: Upstream) -> UpstreamResponse:
|
||||
balance_endpoint=u.balance_endpoint or "",
|
||||
balance_response_path=u.balance_response_path or "",
|
||||
balance_divisor=u.balance_divisor or 1.0,
|
||||
balance_alert_threshold=u.balance_alert_threshold,
|
||||
created_at=u.created_at,
|
||||
updated_at=u.updated_at,
|
||||
)
|
||||
@@ -352,6 +353,7 @@ def create_upstream(
|
||||
balance_endpoint=body.balance_endpoint,
|
||||
balance_response_path=body.balance_response_path,
|
||||
balance_divisor=body.balance_divisor,
|
||||
balance_alert_threshold=body.balance_alert_threshold,
|
||||
)
|
||||
db.add(u)
|
||||
db.commit()
|
||||
|
||||
@@ -169,6 +169,28 @@ def _numeric_group_id(value: str | None) -> int | None:
|
||||
return None
|
||||
|
||||
|
||||
def _build_rate_priority_map(db: Session, upstream_ids: set[int]) -> dict[str, int]:
|
||||
"""根据上游分组倍率构建 group_id → priority 映射。
|
||||
|
||||
遍历所有涉及的上游的最新快照,收集分组的倍率,按倍率升序排列后赋值 priority。
|
||||
倍率最低的 priority=1,次低的 priority=2,以此类推。相同倍率的分组共享同一 priority。
|
||||
"""
|
||||
group_rates: dict[str, float] = {}
|
||||
for uid in upstream_ids:
|
||||
groups = _latest_upstream_groups(db, uid)
|
||||
for g in groups:
|
||||
gid = _source_group_id(g)
|
||||
rate = _source_group_rate(g)
|
||||
if gid:
|
||||
# 同一 group_id 在同个 upstream 内是唯一的;跨 upstream 的相同 group_id
|
||||
# 如果倍率不同则以最后遇到的为准(实际很少冲突)
|
||||
group_rates[gid] = rate
|
||||
# 按倍率排序分配 priority
|
||||
unique_rates = sorted(set(group_rates.values()))
|
||||
rate_to_priority = {rate: idx + 1 for idx, rate in enumerate(unique_rates)}
|
||||
return {gid: rate_to_priority[rate] for gid, rate in group_rates.items()}
|
||||
|
||||
|
||||
@router.get("/api/websites", response_model=List[WebsiteResponse])
|
||||
def list_websites(db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
return [_website_response(row) for row in db.query(Website).order_by(Website.id).all()]
|
||||
@@ -496,6 +518,16 @@ def import_upstream_keys_as_accounts(
|
||||
if _u:
|
||||
upstream_base_url = _u.base_url
|
||||
|
||||
# 按倍率自动分配优先级
|
||||
rate_priority_map: dict[str, int] = {}
|
||||
if body.auto_priority_by_rate:
|
||||
upstream_ids = {row.upstream_id for row in rows}
|
||||
try:
|
||||
rate_priority_map = _build_rate_priority_map(db, upstream_ids)
|
||||
except HTTPException:
|
||||
# 没有快照时忽略,后续 fallback 到 body.priority
|
||||
pass
|
||||
|
||||
with _client(website) as c:
|
||||
for row in rows:
|
||||
# 先确定平台(失败项也需要记录)
|
||||
@@ -512,6 +544,16 @@ def import_upstream_keys_as_accounts(
|
||||
old_account_id = row.imported_account_id
|
||||
exists = c.account_exists(row.imported_account_id)
|
||||
if exists is True:
|
||||
# 自动更新已有账号的 priority(分步导入时全局倍率排序可能已变)
|
||||
new_priority = rate_priority_map.get(row.group_id) if body.auto_priority_by_rate else None
|
||||
priority_msg = "已导入过,已跳过"
|
||||
if new_priority is not None:
|
||||
try:
|
||||
c.update_account(old_account_id, {"priority": new_priority})
|
||||
priority_msg = f"已导入过,优先级已更新为 {new_priority}"
|
||||
except Exception as exc:
|
||||
logger.warning("update priority failed account=%s: %s", old_account_id, exc)
|
||||
priority_msg = f"已导入过,优先级更新失败: {exc}"
|
||||
items.append(ImportAccountItem(
|
||||
upstream_key_id=row.id,
|
||||
source_group_id=row.group_id,
|
||||
@@ -522,7 +564,7 @@ def import_upstream_keys_as_accounts(
|
||||
platform=platform,
|
||||
upstream_base_url=upstream_base_url,
|
||||
status="exists",
|
||||
message="已导入过,已跳过",
|
||||
message=priority_msg,
|
||||
))
|
||||
continue
|
||||
elif exists is False:
|
||||
@@ -574,7 +616,7 @@ def import_upstream_keys_as_accounts(
|
||||
"group_ids": group_ids,
|
||||
"rate_multiplier": 1,
|
||||
"concurrency": body.concurrency,
|
||||
"priority": body.priority,
|
||||
"priority": rate_priority_map.get(row.group_id, body.priority) if body.auto_priority_by_rate else body.priority,
|
||||
"notes": f"Imported by SmartUp from upstream key #{row.id}",
|
||||
}
|
||||
try:
|
||||
|
||||
@@ -32,6 +32,7 @@ class UpstreamCreate(BaseModel):
|
||||
balance_endpoint: str = ""
|
||||
balance_response_path: str = ""
|
||||
balance_divisor: float = 1.0
|
||||
balance_alert_threshold: Optional[float] = None
|
||||
|
||||
|
||||
class UpstreamUpdate(BaseModel):
|
||||
@@ -48,6 +49,7 @@ class UpstreamUpdate(BaseModel):
|
||||
balance_endpoint: Optional[str] = None
|
||||
balance_response_path: Optional[str] = None
|
||||
balance_divisor: Optional[float] = None
|
||||
balance_alert_threshold: Optional[float] = None
|
||||
|
||||
|
||||
class UpstreamResponse(BaseModel):
|
||||
@@ -70,6 +72,7 @@ class UpstreamResponse(BaseModel):
|
||||
balance_endpoint: str = ""
|
||||
balance_response_path: str = ""
|
||||
balance_divisor: float = 1.0
|
||||
balance_alert_threshold: Optional[float] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
@@ -157,6 +157,7 @@ class ImportAccountsRequest(BaseModel):
|
||||
platform_mode: str = "auto" # "auto" | "manual"
|
||||
concurrency: int = Field(default=10, ge=1)
|
||||
priority: int = Field(default=1, ge=0)
|
||||
auto_priority_by_rate: bool = True
|
||||
|
||||
|
||||
class ImportAccountItem(BaseModel):
|
||||
|
||||
@@ -34,6 +34,7 @@ class BrowserSession:
|
||||
lock: asyncio.Lock
|
||||
cdp_session: Any = None
|
||||
captured_headers: list[dict] = None # auth headers from CDP
|
||||
last_saved_state_at: float = 0.0
|
||||
|
||||
|
||||
class BrowserSessionService:
|
||||
@@ -92,12 +93,15 @@ class BrowserSessionService:
|
||||
self._profiles.pop(profile_key, None)
|
||||
# Idle cleanup: close stale sessions before spawning new ones
|
||||
await self._evict_idle_sessions()
|
||||
|
||||
context = await self._playwright.chromium.launch_persistent_context(
|
||||
str(self._profile_dir(profile_key)),
|
||||
headless=get_settings().browser_headless,
|
||||
viewport={"width": width, "height": height},
|
||||
color_scheme="dark",
|
||||
args=["--no-sandbox", "--disable-dev-shm-usage"],
|
||||
)
|
||||
await self._restore_session_state(context, profile_key)
|
||||
# Grant clipboard access for the page origin
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
@@ -137,6 +141,11 @@ class BrowserSessionService:
|
||||
self._touch(session_id)
|
||||
async with session.lock:
|
||||
self._ensure_open(session)
|
||||
if session.profile_key and not session.profile_key.startswith("auth-capture-"):
|
||||
now = time.monotonic()
|
||||
if now - session.last_saved_state_at > 10.0:
|
||||
await self._save_session_state(session)
|
||||
session.last_saved_state_at = now
|
||||
return await session.page.screenshot(type="jpeg", quality=65, full_page=False)
|
||||
|
||||
async def event(
|
||||
@@ -188,6 +197,12 @@ class BrowserSessionService:
|
||||
await page.set_viewport_size({"width": width, "height": height})
|
||||
else:
|
||||
raise ValueError("Unsupported browser event")
|
||||
if session.profile_key and not session.profile_key.startswith("auth-capture-"):
|
||||
now = time.monotonic()
|
||||
if now - session.last_saved_state_at > 5.0:
|
||||
await self._save_session_state(session)
|
||||
session.last_saved_state_at = now
|
||||
|
||||
if not include_state:
|
||||
return None
|
||||
return await self._session_state(session)
|
||||
@@ -242,6 +257,15 @@ class BrowserSessionService:
|
||||
session = self._discard_session(session_id)
|
||||
if not session:
|
||||
return
|
||||
|
||||
# 在完全关闭 context 前,强制将最新的状态落盘保存
|
||||
if session.profile_key and not session.profile_key.startswith("auth-capture-"):
|
||||
try:
|
||||
if not session.page.is_closed():
|
||||
await self._save_session_state(session)
|
||||
except Exception as exc:
|
||||
logger.debug("failed to save state during close: %s", exc)
|
||||
|
||||
# Detach CDP session if active
|
||||
if session.cdp_session:
|
||||
try:
|
||||
@@ -261,6 +285,7 @@ class BrowserSessionService:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
# Cancel the background eviction loop
|
||||
if self._evict_task is not None and not self._evict_task.done():
|
||||
@@ -524,6 +549,9 @@ class BrowserSessionService:
|
||||
profile.mkdir(parents=True, exist_ok=True)
|
||||
return profile
|
||||
|
||||
def _cookies_path(self, profile_key: str) -> Path:
|
||||
return self._profile_dir(profile_key) / "session-cookies.json"
|
||||
|
||||
def _profile_key(self, custom_page_id: int, url: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
origin = f"{parsed.scheme}-{parsed.netloc}".lower()
|
||||
@@ -553,6 +581,7 @@ class BrowserSessionService:
|
||||
str(self._profile_dir(profile_key)),
|
||||
headless=get_settings().browser_headless,
|
||||
viewport={"width": width, "height": height},
|
||||
color_scheme="dark",
|
||||
args=["--no-sandbox", "--disable-dev-shm-usage"],
|
||||
)
|
||||
# Grant clipboard access for the page origin
|
||||
@@ -613,5 +642,85 @@ class BrowserSessionService:
|
||||
except Exception as exc:
|
||||
logger.debug("CDP capture not available: %s", exc)
|
||||
|
||||
async def _save_session_state(self, session: BrowserSession) -> None:
|
||||
if not session.profile_key or session.profile_key.startswith("auth-capture-"):
|
||||
return
|
||||
try:
|
||||
state = await session.context.storage_state()
|
||||
cookies_path = self._cookies_path(session.profile_key)
|
||||
import json
|
||||
import tempfile
|
||||
import os
|
||||
# Ensure parent directories exist
|
||||
cookies_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
temp_fd, temp_path = tempfile.mkstemp(dir=str(cookies_path.parent))
|
||||
try:
|
||||
with os.fdopen(temp_fd, 'w', encoding='utf-8') as f:
|
||||
json.dump(state, f, ensure_ascii=False, indent=2)
|
||||
os.replace(temp_path, cookies_path)
|
||||
except Exception:
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.debug("failed to save session state for %s: %s", session.profile_key, exc)
|
||||
|
||||
async def _restore_session_state(self, context: Any, profile_key: str) -> None:
|
||||
if profile_key.startswith("auth-capture-"):
|
||||
return
|
||||
cookies_path = self._cookies_path(profile_key)
|
||||
if not cookies_path.exists() or cookies_path.stat().st_size == 0:
|
||||
return
|
||||
try:
|
||||
import json
|
||||
import time
|
||||
with open(cookies_path, 'r', encoding='utf-8') as f:
|
||||
state = json.load(f)
|
||||
cookies = state.get("cookies", [])
|
||||
if cookies:
|
||||
now = time.time()
|
||||
valid_cookies = []
|
||||
for c in cookies:
|
||||
expires = c.get("expires")
|
||||
if expires is not None and expires > 0 and expires <= now:
|
||||
continue
|
||||
if expires is not None and expires <= 0:
|
||||
c.pop("expires", None)
|
||||
valid_cookies.append(c)
|
||||
if valid_cookies:
|
||||
await context.add_cookies(valid_cookies)
|
||||
logger.info("restored %d cookies for profile %s", len(valid_cookies), profile_key)
|
||||
|
||||
# 还原 LocalStorage
|
||||
origins = state.get("origins", [])
|
||||
if origins:
|
||||
origins_json = json.dumps(origins)
|
||||
init_script = f"""
|
||||
(() => {{
|
||||
try {{
|
||||
const origins = {origins_json};
|
||||
const currentOrigin = window.location.origin;
|
||||
const target = origins.find(o => o.origin === currentOrigin);
|
||||
if (target && target.localStorage) {{
|
||||
for (const item of target.localStorage) {{
|
||||
try {{
|
||||
window.localStorage.setItem(item.name, item.value);
|
||||
}} catch (e) {{
|
||||
console.error('Failed to restore localStorage key', item.name, e);
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}} catch (err) {{
|
||||
console.error('LocalStorage restore initialization script failed', err);
|
||||
}}
|
||||
}})();
|
||||
"""
|
||||
await context.add_init_script(init_script)
|
||||
logger.info("registered LocalStorage init script for profile %s (origins: %d)", profile_key, len(origins))
|
||||
except Exception as exc:
|
||||
logger.warning("failed to restore cookies/state for profile %s: %s", profile_key, exc)
|
||||
|
||||
|
||||
browser_sessions = BrowserSessionService()
|
||||
|
||||
@@ -46,6 +46,7 @@ def _check_upstream(upstream_id: int) -> None:
|
||||
|
||||
auth_config = json.loads(upstream.auth_config_json or "{}")
|
||||
was_unhealthy = upstream.last_status == "unhealthy"
|
||||
balance_alert_triggered = False
|
||||
snapshot = None
|
||||
changes = None
|
||||
|
||||
@@ -76,6 +77,14 @@ def _check_upstream(upstream_id: int) -> None:
|
||||
if balance is not None:
|
||||
upstream.balance = balance
|
||||
upstream.balance_updated_at = datetime.now(timezone.utc)
|
||||
# ── 余额告警阈值检查 ──
|
||||
threshold = upstream.balance_alert_threshold
|
||||
if threshold is not None and threshold > 0:
|
||||
if balance < threshold and not upstream.balance_alert_notified:
|
||||
upstream.balance_alert_notified = True
|
||||
balance_alert_triggered = True
|
||||
elif balance >= threshold and upstream.balance_alert_notified:
|
||||
upstream.balance_alert_notified = False
|
||||
except Exception as exc:
|
||||
# failure path
|
||||
upstream.consecutive_failures = (upstream.consecutive_failures or 0) + 1
|
||||
@@ -152,6 +161,12 @@ def _check_upstream(upstream_id: int) -> None:
|
||||
_notify_rate_changed(upstream_id, upstream.name, upstream.base_url, changes)
|
||||
_sync_website_bindings(upstream_id, changes)
|
||||
|
||||
if balance_alert_triggered:
|
||||
_notify_balance_low(
|
||||
upstream_id, upstream.name, upstream.base_url,
|
||||
upstream.balance, upstream.balance_alert_threshold,
|
||||
)
|
||||
|
||||
|
||||
def _notify_status(
|
||||
upstream_id: int,
|
||||
@@ -184,6 +199,22 @@ def _notify_rate_changed(
|
||||
db.close()
|
||||
|
||||
|
||||
def _notify_balance_low(
|
||||
upstream_id: int,
|
||||
upstream_name: str,
|
||||
base_url: str,
|
||||
balance: float,
|
||||
threshold: float,
|
||||
) -> None:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
webhook_service.send_balance_low(db, upstream_id, upstream_name, base_url, balance, threshold)
|
||||
except Exception:
|
||||
logger.exception("balance low webhook failed for upstream %s", upstream_name)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _sync_upstream_keys(upstream_id: int, snapshot: dict[str, Any], captured_at: datetime) -> None:
|
||||
"""上游检测成功后同步 SmartUp Key 状态(远端删除/分组删除)。"""
|
||||
db = SessionLocal()
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.utils.dingtalk import (
|
||||
format_dingtalk_rate_changed,
|
||||
format_dingtalk_website_rate_changed,
|
||||
format_dingtalk_status,
|
||||
format_dingtalk_balance_low,
|
||||
)
|
||||
|
||||
|
||||
@@ -185,6 +186,43 @@ def send_status_event(
|
||||
_log(db, wh, event, generic_payload, "failed", str(exc))
|
||||
|
||||
|
||||
def send_balance_low(
|
||||
db: Session,
|
||||
upstream_id: int,
|
||||
upstream_name: str,
|
||||
base_url: str,
|
||||
balance: float,
|
||||
threshold: float,
|
||||
) -> None:
|
||||
webhooks = (
|
||||
db.query(WebhookConfig)
|
||||
.filter(WebhookConfig.enabled == True)
|
||||
.all()
|
||||
)
|
||||
event = "upstream_balance_low"
|
||||
changed_at = _now_iso()
|
||||
generic_payload = {
|
||||
"event": event,
|
||||
"upstream": {"id": upstream_id, "name": upstream_name, "base_url": base_url},
|
||||
"balance": balance,
|
||||
"threshold": threshold,
|
||||
"changed_at": changed_at,
|
||||
}
|
||||
for wh in webhooks:
|
||||
events = json.loads(wh.events_json or "[]")
|
||||
if event not in events:
|
||||
continue
|
||||
try:
|
||||
if wh.type == "dingtalk":
|
||||
msg = format_dingtalk_balance_low(upstream_name, balance, threshold, changed_at)
|
||||
resp_text = _send_dingtalk(wh.url, wh.secret, msg)
|
||||
else:
|
||||
resp_text = _send_generic(wh.url, generic_payload)
|
||||
_log(db, wh, event, generic_payload, "success", resp_text)
|
||||
except Exception as exc:
|
||||
_log(db, wh, event, generic_payload, "failed", str(exc))
|
||||
|
||||
|
||||
def send_test_notification(db: Session, webhook: WebhookConfig) -> tuple[bool, str]:
|
||||
payload = {
|
||||
"event": "test",
|
||||
|
||||
@@ -223,6 +223,12 @@ class Sub2ApiWebsiteClient:
|
||||
data = _unwrap_data(resp)
|
||||
return data if isinstance(data, dict) else {"value": data}
|
||||
|
||||
def update_account(self, account_id: str, body: dict[str, Any], endpoint: str = "/accounts") -> dict[str, Any]:
|
||||
"""更新远端账号(仅传入需要变更的字段)。"""
|
||||
resp = self._request("PUT", f"{endpoint}/{account_id}", body)
|
||||
data = _unwrap_data(resp)
|
||||
return data if isinstance(data, dict) else {"value": data}
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_list(value: dict) -> list | None:
|
||||
"""递归展开嵌套的列表包装:data.items、data.data、items、accounts 等。"""
|
||||
|
||||
@@ -64,6 +64,25 @@ def format_dingtalk_website_rate_changed(
|
||||
}
|
||||
|
||||
|
||||
def format_dingtalk_balance_low(
|
||||
upstream_name: str, balance: float, threshold: float, changed_at: str
|
||||
) -> dict[str, Any]:
|
||||
lines = [
|
||||
f"### ⚠️ {upstream_name} 余额不足",
|
||||
"",
|
||||
f"- **当前余额**:{balance:.2f}",
|
||||
f"- **告警阈值**:{threshold:.2f}",
|
||||
f"- **时间**:{changed_at}",
|
||||
]
|
||||
return {
|
||||
"msgtype": "markdown",
|
||||
"markdown": {
|
||||
"title": f"{upstream_name} 余额不足",
|
||||
"text": "\n".join(lines),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def format_dingtalk_status(upstream_name: str, event: str, changed_at: str, error: str = "") -> dict[str, Any]:
|
||||
emoji = "🔴" if event == "upstream_unhealthy" else "🟢"
|
||||
label = "服务异常" if event == "upstream_unhealthy" else "服务恢复"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
|
||||
from pathlib import Path
|
||||
from app.config import get_settings
|
||||
from app.routers.auth_capture import _sanitize_candidate
|
||||
from app.services.browser_session_service import BrowserSessionService
|
||||
|
||||
@@ -127,3 +128,293 @@ def test_sanitize_candidate_strips_secret_fields_but_keeps_metadata():
|
||||
"cookie_name": "session",
|
||||
"domain": "example.test",
|
||||
}
|
||||
|
||||
|
||||
def test_cookies_path_mapping():
|
||||
import tempfile
|
||||
import shutil
|
||||
service = BrowserSessionService()
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
original_dir = get_settings().browser_profiles_dir
|
||||
get_settings().browser_profiles_dir = temp_dir
|
||||
try:
|
||||
profile_key = "test-profile-123"
|
||||
expected_path = Path(service._cookies_path(profile_key))
|
||||
assert expected_path.name == "session-cookies.json"
|
||||
assert expected_path.parent == Path(service._profile_dir(profile_key))
|
||||
finally:
|
||||
get_settings().browser_profiles_dir = original_dir
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def test_screenshot_throttled_save():
|
||||
import tempfile
|
||||
import shutil
|
||||
service = BrowserSessionService()
|
||||
|
||||
# 准备临时目录并 mock settings.browser_profiles_dir
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
original_dir = get_settings().browser_profiles_dir
|
||||
get_settings().browser_profiles_dir = temp_dir
|
||||
|
||||
try:
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from app.services.browser_session_service import BrowserSession
|
||||
|
||||
# Mock Context & Page
|
||||
fake_context = AsyncMock()
|
||||
fake_page = MagicMock()
|
||||
fake_page.is_closed = MagicMock(return_value=False)
|
||||
fake_page.screenshot = AsyncMock(return_value=b"screenshot-bytes")
|
||||
|
||||
session = BrowserSession(
|
||||
id="session123",
|
||||
custom_page_id=1,
|
||||
profile_key="page-1-test",
|
||||
context=fake_context,
|
||||
page=fake_page,
|
||||
lock=asyncio.Lock(),
|
||||
last_saved_state_at=0.0
|
||||
)
|
||||
service._sessions[session.id] = session
|
||||
|
||||
# 第一次调用 screenshot: 触发存储
|
||||
res1 = run(service.screenshot(session.id))
|
||||
assert res1 == b"screenshot-bytes"
|
||||
assert fake_context.storage_state.call_count == 1
|
||||
|
||||
# 记录第一次保存后的时间戳
|
||||
first_save_time = session.last_saved_state_at
|
||||
assert first_save_time > 0
|
||||
|
||||
# 第二次立即调用 screenshot: 应该因为限流 10s 被跳过,不增加 call_count
|
||||
res2 = run(service.screenshot(session.id))
|
||||
assert res2 == b"screenshot-bytes"
|
||||
assert fake_context.storage_state.call_count == 1
|
||||
|
||||
# 模拟 11 秒后(防抖时间已过)再度截图
|
||||
session.last_saved_state_at = first_save_time - 11.0
|
||||
res3 = run(service.screenshot(session.id))
|
||||
assert res3 == b"screenshot-bytes"
|
||||
assert fake_context.storage_state.call_count == 2
|
||||
|
||||
# 测试临时 auth-capture 会话不触发任何 state 存储
|
||||
ephemeral_session = BrowserSession(
|
||||
id="session-eph",
|
||||
custom_page_id=0,
|
||||
profile_key="auth-capture-xyz",
|
||||
context=fake_context,
|
||||
page=fake_page,
|
||||
lock=asyncio.Lock(),
|
||||
last_saved_state_at=0.0
|
||||
)
|
||||
service._sessions[ephemeral_session.id] = ephemeral_session
|
||||
run(service.screenshot(ephemeral_session.id))
|
||||
# 它的 call_count 依然是 2,没有增加
|
||||
assert fake_context.storage_state.call_count == 2
|
||||
|
||||
finally:
|
||||
get_settings().browser_profiles_dir = original_dir
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def test_close_saves_state_and_cleans_up():
|
||||
import tempfile
|
||||
import shutil
|
||||
service = BrowserSessionService()
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
original_dir = get_settings().browser_profiles_dir
|
||||
get_settings().browser_profiles_dir = temp_dir
|
||||
|
||||
try:
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from app.services.browser_session_service import BrowserSession
|
||||
|
||||
fake_context = AsyncMock()
|
||||
fake_page = MagicMock()
|
||||
fake_page.is_closed = MagicMock(return_value=False)
|
||||
|
||||
import time
|
||||
session = BrowserSession(
|
||||
id="session_close",
|
||||
custom_page_id=2,
|
||||
profile_key="page-2-test",
|
||||
context=fake_context,
|
||||
page=fake_page,
|
||||
lock=asyncio.Lock(),
|
||||
last_saved_state_at=time.monotonic() # 此时在限流内
|
||||
)
|
||||
service._sessions[session.id] = session
|
||||
|
||||
# 即使在限流时间内,close 也必须强制保存
|
||||
run(service.close(session.id))
|
||||
assert fake_context.storage_state.call_count == 1
|
||||
assert fake_context.close.call_count == 1
|
||||
|
||||
# 测试 ephemeral 会话在 close 时不应该保存 state,并且其 profile_dir 应当被清理,导致 cookies json 不复存在
|
||||
eph_context = AsyncMock()
|
||||
eph_page = MagicMock()
|
||||
eph_page.is_closed = MagicMock(return_value=False)
|
||||
eph_session = BrowserSession(
|
||||
id="session_eph_close",
|
||||
custom_page_id=0,
|
||||
profile_key="auth-capture-abc",
|
||||
context=eph_context,
|
||||
page=eph_page,
|
||||
lock=asyncio.Lock(),
|
||||
last_saved_state_at=0.0
|
||||
)
|
||||
service._sessions[eph_session.id] = eph_session
|
||||
|
||||
# 先手动创建一个假 session-cookies.json
|
||||
eph_cookies_path = service._cookies_path(eph_session.profile_key)
|
||||
eph_cookies_path.write_text("{}")
|
||||
assert eph_cookies_path.exists()
|
||||
|
||||
run(service.close(eph_session.id))
|
||||
# ephemeral close 时不保存,所以 call_count 依然是 0
|
||||
assert eph_context.storage_state.call_count == 0
|
||||
assert eph_context.close.call_count == 1
|
||||
# 但对应的 profile 目录已被删除,文件自然不复存在
|
||||
assert not eph_cookies_path.exists()
|
||||
|
||||
finally:
|
||||
get_settings().browser_profiles_dir = original_dir
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def test_restore_session_state_decoding_and_inject():
|
||||
import json
|
||||
import tempfile
|
||||
import shutil
|
||||
import time
|
||||
service = BrowserSessionService()
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
original_dir = get_settings().browser_profiles_dir
|
||||
get_settings().browser_profiles_dir = temp_dir
|
||||
|
||||
try:
|
||||
from unittest.mock import AsyncMock
|
||||
fake_context = AsyncMock()
|
||||
profile_key = "test-restore-profile"
|
||||
|
||||
# 准备假 cookies.json,包含 cookies 和 origins/localStorage
|
||||
cookies_path = service._cookies_path(profile_key)
|
||||
|
||||
now = time.time()
|
||||
fake_state = {
|
||||
"cookies": [
|
||||
{
|
||||
"name": "valid_persistent",
|
||||
"value": "123",
|
||||
"domain": "example.test",
|
||||
"path": "/",
|
||||
"expires": now + 3600 # 未过期
|
||||
},
|
||||
{
|
||||
"name": "expired_cookie",
|
||||
"value": "456",
|
||||
"domain": "example.test",
|
||||
"path": "/",
|
||||
"expires": now - 3600 # 已过期
|
||||
},
|
||||
{
|
||||
"name": "session_cookie",
|
||||
"value": "789",
|
||||
"domain": "example.test",
|
||||
"path": "/",
|
||||
"expires": -1 # session cookie,应被保留并剔除 expires 字段
|
||||
}
|
||||
],
|
||||
"origins": [
|
||||
{
|
||||
"origin": "https://example.test",
|
||||
"localStorage": [
|
||||
{
|
||||
"name": "theme",
|
||||
"value": "dark"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
with open(cookies_path, "w", encoding='utf-8') as f:
|
||||
json.dump(fake_state, f)
|
||||
|
||||
# 运行还原方法
|
||||
run(service._restore_session_state(fake_context, profile_key))
|
||||
|
||||
# 检查是否成功调用 add_cookies
|
||||
assert fake_context.add_cookies.call_count == 1
|
||||
|
||||
# 检查过滤后的 cookies 内容
|
||||
injected_cookies = fake_context.add_cookies.call_args[0][0]
|
||||
assert len(injected_cookies) == 2
|
||||
|
||||
names = [c["name"] for c in injected_cookies]
|
||||
assert "valid_persistent" in names
|
||||
assert "session_cookie" in names
|
||||
assert "expired_cookie" not in names
|
||||
|
||||
# 校验 session_cookie 是否成功移除了 expires
|
||||
session_c = next(c for c in injected_cookies if c["name"] == "session_cookie")
|
||||
assert "expires" not in session_c
|
||||
|
||||
# 检查是否成功调用了 add_init_script (用于还原 LocalStorage)
|
||||
assert fake_context.add_init_script.call_count == 1
|
||||
init_script = fake_context.add_init_script.call_args[0][0]
|
||||
assert "window.localStorage.setItem" in init_script
|
||||
assert "theme" in init_script
|
||||
assert "dark" in init_script
|
||||
|
||||
finally:
|
||||
get_settings().browser_profiles_dir = original_dir
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def test_websocket_event_saves_state():
|
||||
import tempfile
|
||||
import shutil
|
||||
service = BrowserSessionService()
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
original_dir = get_settings().browser_profiles_dir
|
||||
get_settings().browser_profiles_dir = temp_dir
|
||||
|
||||
try:
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from app.services.browser_session_service import BrowserSession
|
||||
|
||||
fake_context = AsyncMock()
|
||||
fake_page = MagicMock()
|
||||
fake_page.is_closed = MagicMock(return_value=False)
|
||||
fake_page.mouse = MagicMock()
|
||||
fake_page.mouse.click = AsyncMock()
|
||||
|
||||
session = BrowserSession(
|
||||
id="session_ws",
|
||||
custom_page_id=3,
|
||||
profile_key="page-3-ws-test",
|
||||
context=fake_context,
|
||||
page=fake_page,
|
||||
lock=asyncio.Lock(),
|
||||
last_saved_state_at=0.0
|
||||
)
|
||||
service._sessions[session.id] = session
|
||||
|
||||
# 即使 include_state=False,也应当在 5 秒节流到期后保存状态
|
||||
run(service.event(
|
||||
session_id=session.id,
|
||||
event_type="click",
|
||||
payload={"x": 10.0, "y": 20.0},
|
||||
include_state=False
|
||||
))
|
||||
|
||||
# storage_state 应该被调用,说明保存成功触发了
|
||||
assert fake_context.storage_state.call_count == 1
|
||||
assert session.last_saved_state_at > 0
|
||||
|
||||
finally:
|
||||
get_settings().browser_profiles_dir = original_dir
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user