feat: sync account priorities after rate changes
This commit is contained in:
@@ -512,13 +512,14 @@ def check_now(uid: int, db: Session = Depends(get_db), _=Depends(get_current_use
|
||||
|
||||
if was_unhealthy:
|
||||
webhook_service.send_status_event(db, u.id, u.name, u.base_url, "upstream_recovered")
|
||||
# 先同步 Key 状态(标记 orphaned),再执行优先级同步(避免未标记的 key 参与计算)
|
||||
from app.services.scheduler import _sync_upstream_keys as _synck
|
||||
_synck(uid, snapshot, new_row.captured_at)
|
||||
|
||||
if changes:
|
||||
webhook_service.send_rate_changed(db, u.id, u.name, u.base_url, changes)
|
||||
website_sync.sync_affected_bindings(db, u.id, changes)
|
||||
|
||||
# 同步 SmartUp Key 状态(使用实际快照入库时间,与定时任务一致)
|
||||
from app.services.scheduler import _sync_upstream_keys as _synck
|
||||
_synck(uid, snapshot, new_row.captured_at)
|
||||
website_sync.sync_account_priorities_for_upstream(db, u.id)
|
||||
|
||||
msg = f"检测成功,{len(groups)} 个分组"
|
||||
if changes:
|
||||
|
||||
@@ -32,7 +32,7 @@ from app.schemas.website import (
|
||||
WebsiteUpdate,
|
||||
)
|
||||
from app.services.website_client import Sub2ApiWebsiteClient
|
||||
from app.services.website_sync import binding_sources, sync_binding
|
||||
from app.services.website_sync import binding_sources, sync_binding, build_rate_priority_map
|
||||
from app.utils.auth import get_current_user
|
||||
|
||||
router = APIRouter(tags=["websites"])
|
||||
@@ -171,24 +171,10 @@ def _numeric_group_id(value: str | None) -> int | None:
|
||||
|
||||
def _build_rate_priority_map(db: Session, upstream_ids: set[int]) -> dict[str, int]:
|
||||
"""根据上游分组倍率构建 group_id → priority 映射。
|
||||
|
||||
遍历所有涉及的上游的最新快照,收集分组的倍率,按倍率升序排列后赋值 priority。
|
||||
倍率最低的 priority=1,次低的 priority=2,以此类推。相同倍率的分组共享同一 priority。
|
||||
|
||||
委托给 website_sync.build_rate_priority_map 避免逻辑重复。
|
||||
"""
|
||||
group_rates: dict[str, float] = {}
|
||||
for uid in upstream_ids:
|
||||
groups = _latest_upstream_groups(db, uid)
|
||||
for g in groups:
|
||||
gid = _source_group_id(g)
|
||||
rate = _source_group_rate(g)
|
||||
if gid:
|
||||
# 同一 group_id 在同个 upstream 内是唯一的;跨 upstream 的相同 group_id
|
||||
# 如果倍率不同则以最后遇到的为准(实际很少冲突)
|
||||
group_rates[gid] = rate
|
||||
# 按倍率排序分配 priority
|
||||
unique_rates = sorted(set(group_rates.values()))
|
||||
rate_to_priority = {rate: idx + 1 for idx, rate in enumerate(unique_rates)}
|
||||
return {gid: rate_to_priority[rate] for gid, rate in group_rates.items()}
|
||||
return build_rate_priority_map(db, upstream_ids)
|
||||
|
||||
|
||||
@router.get("/api/websites", response_model=List[WebsiteResponse])
|
||||
@@ -545,7 +531,7 @@ def import_upstream_keys_as_accounts(
|
||||
exists = c.account_exists(row.imported_account_id)
|
||||
if exists is True:
|
||||
# 自动更新已有账号的 priority(分步导入时全局倍率排序可能已变)
|
||||
new_priority = rate_priority_map.get(row.group_id) if body.auto_priority_by_rate else None
|
||||
new_priority = rate_priority_map.get(f"{row.upstream_id}:{row.group_id}") if body.auto_priority_by_rate else None
|
||||
priority_msg = "已导入过,已跳过"
|
||||
if new_priority is not None:
|
||||
try:
|
||||
@@ -616,7 +602,7 @@ def import_upstream_keys_as_accounts(
|
||||
"group_ids": group_ids,
|
||||
"rate_multiplier": 1,
|
||||
"concurrency": body.concurrency,
|
||||
"priority": rate_priority_map.get(row.group_id, body.priority) if body.auto_priority_by_rate else body.priority,
|
||||
"priority": rate_priority_map.get(f"{row.upstream_id}:{row.group_id}", body.priority) if body.auto_priority_by_rate else body.priority,
|
||||
"notes": f"Imported by SmartUp from upstream key #{row.id}",
|
||||
}
|
||||
try:
|
||||
|
||||
@@ -160,6 +160,7 @@ def _check_upstream(upstream_id: int) -> None:
|
||||
if changes:
|
||||
_notify_rate_changed(upstream_id, upstream.name, upstream.base_url, changes)
|
||||
_sync_website_bindings(upstream_id, changes)
|
||||
_sync_account_priorities(upstream_id)
|
||||
|
||||
if balance_alert_triggered:
|
||||
_notify_balance_low(
|
||||
@@ -284,6 +285,17 @@ def _sync_upstream_keys(upstream_id: int, snapshot: dict[str, Any], captured_at:
|
||||
db.close()
|
||||
|
||||
|
||||
def _sync_account_priorities(upstream_id: int) -> None:
|
||||
"""倍率变更后自动更新已导入下游账号的 priority。"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
website_sync.sync_account_priorities_for_upstream(db, upstream_id)
|
||||
except Exception:
|
||||
logger.exception("account priority sync failed for upstream %s", upstream_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _sync_website_bindings(upstream_id: int, changes: list[dict[str, Any]]) -> None:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
|
||||
@@ -16,6 +16,7 @@ from app.utils.dingtalk import (
|
||||
format_dingtalk_website_rate_changed,
|
||||
format_dingtalk_status,
|
||||
format_dingtalk_balance_low,
|
||||
format_dingtalk_priority_changed,
|
||||
)
|
||||
|
||||
|
||||
@@ -223,6 +224,47 @@ def send_balance_low(
|
||||
_log(db, wh, event, generic_payload, "failed", str(exc))
|
||||
|
||||
|
||||
def send_account_priority_changed(
|
||||
db: Session,
|
||||
website_id: int,
|
||||
website_name: str,
|
||||
upstream_id: int,
|
||||
upstream_name: str,
|
||||
updates: list[dict],
|
||||
) -> None:
|
||||
webhooks = (
|
||||
db.query(WebhookConfig)
|
||||
.filter(WebhookConfig.enabled == True)
|
||||
.all()
|
||||
)
|
||||
event = "account_priority_changed"
|
||||
changed_at = _now_iso()
|
||||
success = sum(1 for u in updates if u.get("status") == "success")
|
||||
failed = sum(1 for u in updates if u.get("status") == "failed")
|
||||
skipped = sum(1 for u in updates if u.get("status") == "skipped")
|
||||
generic_payload = {
|
||||
"event": event,
|
||||
"website": {"id": website_id, "name": website_name},
|
||||
"upstream": {"id": upstream_id, "name": upstream_name},
|
||||
"changed_at": changed_at,
|
||||
"updates": updates,
|
||||
"summary": {"total": len(updates), "success": success, "failed": failed, "skipped": skipped},
|
||||
}
|
||||
for wh in webhooks:
|
||||
events = json.loads(wh.events_json or "[]")
|
||||
if event not in events:
|
||||
continue
|
||||
try:
|
||||
if wh.type == "dingtalk":
|
||||
msg = format_dingtalk_priority_changed(website_name, upstream_name, changed_at, updates)
|
||||
resp_text = _send_dingtalk(wh.url, wh.secret, msg)
|
||||
else:
|
||||
resp_text = _send_generic(wh.url, generic_payload)
|
||||
_log(db, wh, event, generic_payload, "success", resp_text)
|
||||
except Exception as exc:
|
||||
_log(db, wh, event, generic_payload, "failed", str(exc))
|
||||
|
||||
|
||||
def send_test_notification(db: Session, webhook: WebhookConfig) -> tuple[bool, str]:
|
||||
payload = {
|
||||
"event": "test",
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.snapshot import UpstreamRateSnapshot
|
||||
from app.models.upstream import Upstream
|
||||
from app.models.upstream_key import UpstreamGeneratedKey
|
||||
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
|
||||
from app.services.website_client import Sub2ApiWebsiteClient, WebsiteError, calculate_target_rate, decimal_string
|
||||
from app.services import webhook_service
|
||||
@@ -171,6 +172,251 @@ def sync_binding(db: Session, binding: WebsiteGroupBinding, write: bool = True)
|
||||
return _log(db, binding, website, source_rates, "success", message, old_rate, target_rate)
|
||||
|
||||
|
||||
def _snapshot_group_rate(group: dict) -> float:
|
||||
"""从快照分组数据中提取倍率(兼容多个字段名)。"""
|
||||
raw = group.get("rate") or group.get("default_rate") or group.get("rate_multiplier") or 1
|
||||
try:
|
||||
return float(raw)
|
||||
except (TypeError, ValueError):
|
||||
return 1.0
|
||||
|
||||
|
||||
def build_rate_priority_map(db: Session, upstream_ids: set[int]) -> dict[str, int]:
|
||||
"""根据上游分组倍率构建 f"{upstream_id}:{group_id}" → priority 映射。
|
||||
|
||||
使用 (upstream_id, group_id) 复合键避免不同上游的同名分组互相覆盖。
|
||||
遍历所有涉及的上游的最新快照,收集分组的倍率,按倍率升序排列后赋值 priority。
|
||||
倍率最低的 priority=1,次低的 priority=2,以此类推。相同倍率的分组共享同一 priority。
|
||||
"""
|
||||
group_rates: dict[str, float] = {}
|
||||
for uid in upstream_ids:
|
||||
groups = latest_rate_map(db, uid)
|
||||
for gid, g in groups.items():
|
||||
if not isinstance(g, dict):
|
||||
continue
|
||||
rate = _snapshot_group_rate(g)
|
||||
key = f"{uid}:{gid}"
|
||||
group_rates[key] = rate
|
||||
unique_rates = sorted(set(group_rates.values()))
|
||||
rate_to_priority = {rate: idx + 1 for idx, rate in enumerate(unique_rates)}
|
||||
return {key: rate_to_priority[rate] for key, rate in group_rates.items()}
|
||||
|
||||
|
||||
def _priority_result(row, new_priority: int | None, status: str, message: str) -> dict:
|
||||
"""构建统一的优先级同步结果 dict。"""
|
||||
return {
|
||||
"account_id": row.imported_account_id,
|
||||
"group_id": row.group_id,
|
||||
"upstream_id": row.upstream_id,
|
||||
"old_priority": None,
|
||||
"new_priority": new_priority,
|
||||
"status": status,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
|
||||
def _write_priority_sync_log_with_map(
|
||||
db: Session, wid: int, upstream_name: str,
|
||||
results: list[dict], priority_map: dict[str, int],
|
||||
) -> None:
|
||||
"""写入 priority_sync 日志,同时保存账号明细和 priority_map 快照。
|
||||
|
||||
source_rates_json 格式:[{"_meta": "priority_map", "data": {...}}, {"account_id": ..., ...}, ...]
|
||||
兼容 WebsiteSyncLogResponse.source_rates: list[dict] 类型约束。
|
||||
"""
|
||||
log_results: list[dict] = [
|
||||
{"_meta": "priority_map", "data": dict(priority_map)},
|
||||
]
|
||||
log_results.extend(results)
|
||||
success = sum(1 for r in results if r["status"] == "success")
|
||||
failed = sum(1 for r in results if r["status"] == "failed")
|
||||
skipped = sum(1 for r in results if r["status"] == "skipped")
|
||||
parts = []
|
||||
if success:
|
||||
parts.append(f"{success} 个更新成功")
|
||||
if failed:
|
||||
parts.append(f"{failed} 个失败")
|
||||
if skipped:
|
||||
parts.append(f"{skipped} 个跳过")
|
||||
log = WebsiteSyncLog(
|
||||
website_id=wid,
|
||||
binding_id=None,
|
||||
target_group_id="",
|
||||
target_group_name="",
|
||||
algorithm="priority_sync",
|
||||
percent=0,
|
||||
source_rates_json=json.dumps(log_results, ensure_ascii=False, default=str),
|
||||
old_rate=None,
|
||||
new_rate=None,
|
||||
status="failed" if failed else "success",
|
||||
message=f"优先级同步(上游={upstream_name}):{'、'.join(parts)} / 共 {len(results)} 个",
|
||||
)
|
||||
db.add(log)
|
||||
db.commit()
|
||||
|
||||
|
||||
def _try_send_priority_webhook(
|
||||
db: Session, wid: int, website_name: str,
|
||||
upstream_id: int, upstream_name: str,
|
||||
updates: list[dict],
|
||||
) -> None:
|
||||
"""发送 account_priority_changed webhook,失败不抛异常。"""
|
||||
if not updates:
|
||||
return
|
||||
# 如果没传入名称,尝试从 DB 查
|
||||
resolved_name = website_name
|
||||
if not resolved_name:
|
||||
row = db.query(Website.name).filter(Website.id == wid).first()
|
||||
if row:
|
||||
resolved_name = row[0]
|
||||
else:
|
||||
resolved_name = f"网站#{wid}"
|
||||
try:
|
||||
webhook_service.send_account_priority_changed(
|
||||
db,
|
||||
website_id=wid,
|
||||
website_name=resolved_name,
|
||||
upstream_id=upstream_id,
|
||||
upstream_name=upstream_name,
|
||||
updates=updates,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("account_priority_changed webhook failed for website %s: %s", wid, exc)
|
||||
|
||||
|
||||
def sync_account_priorities_for_upstream(db: Session, upstream_id: int) -> list[dict]:
|
||||
"""上游倍率变化后,自动更新已导入下游账号的 priority。
|
||||
|
||||
查询该上游下所有已导入(非 orphaned)的 Key,按目标网站分组后重新计算全局优先级,
|
||||
并通过 update_account API 推送到下游网站。返回详细结果列表。
|
||||
|
||||
同时写入 WebsiteSyncLog 持久化审计日志,并通过 webhook 发送通知。
|
||||
"""
|
||||
from app.services.website_client import Sub2ApiWebsiteClient as Client
|
||||
|
||||
key_rows = (
|
||||
db.query(UpstreamGeneratedKey)
|
||||
.filter(
|
||||
UpstreamGeneratedKey.upstream_id == upstream_id,
|
||||
UpstreamGeneratedKey.imported_website_id.isnot(None),
|
||||
UpstreamGeneratedKey.imported_account_id.isnot(None),
|
||||
UpstreamGeneratedKey.status != "orphaned",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
if not key_rows:
|
||||
return []
|
||||
|
||||
upstream_name = db.query(Upstream.name).filter(Upstream.id == upstream_id).scalar() or f"#{upstream_id}"
|
||||
|
||||
# 按 imported_website_id 分组
|
||||
website_groups: dict[int, list[UpstreamGeneratedKey]] = {}
|
||||
for row in key_rows:
|
||||
wid = row.imported_website_id
|
||||
if wid not in website_groups:
|
||||
website_groups[wid] = []
|
||||
website_groups[wid].append(row)
|
||||
|
||||
all_results: list[dict] = []
|
||||
|
||||
for wid, rows in website_groups.items():
|
||||
website = db.query(Website).filter(Website.id == wid).first()
|
||||
if not website or not website.enabled:
|
||||
logger.info("skip account priority sync: website %s not found or disabled", wid)
|
||||
site_results = []
|
||||
for row in rows:
|
||||
r = _priority_result(row, None, "failed", "网站不可用")
|
||||
site_results.append(r)
|
||||
all_results.append(r)
|
||||
_write_priority_sync_log_with_map(db, wid, upstream_name, site_results, {})
|
||||
_try_send_priority_webhook(db, wid, "", upstream_id, upstream_name, site_results)
|
||||
continue
|
||||
|
||||
# 查询该网站所有已导入 Key(跨上游),实现全局优先级排序
|
||||
all_website_keys = (
|
||||
db.query(UpstreamGeneratedKey)
|
||||
.filter(
|
||||
UpstreamGeneratedKey.imported_website_id == wid,
|
||||
UpstreamGeneratedKey.imported_account_id.isnot(None),
|
||||
UpstreamGeneratedKey.status != "orphaned",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
all_upstream_ids = {k.upstream_id for k in all_website_keys}
|
||||
try:
|
||||
priority_map = build_rate_priority_map(db, all_upstream_ids)
|
||||
except Exception as exc:
|
||||
logger.warning("build_rate_priority_map failed for website %s: %s", wid, exc)
|
||||
site_results = []
|
||||
for row in all_website_keys:
|
||||
r = _priority_result(row, None, "failed", f"构建优先级映射失败: {exc}")
|
||||
site_results.append(r)
|
||||
all_results.append(r)
|
||||
_write_priority_sync_log_with_map(db, wid, upstream_name, site_results, {})
|
||||
_try_send_priority_webhook(db, wid, "", upstream_id, upstream_name, site_results)
|
||||
continue
|
||||
|
||||
if not priority_map:
|
||||
logger.info("skip account priority sync for website %s: empty priority map", wid)
|
||||
site_results = []
|
||||
for row in all_website_keys:
|
||||
r = _priority_result(row, None, "skipped", "无上游倍率数据")
|
||||
site_results.append(r)
|
||||
all_results.append(r)
|
||||
_write_priority_sync_log_with_map(db, wid, upstream_name, site_results, {})
|
||||
_try_send_priority_webhook(db, wid, "", upstream_id, upstream_name, site_results)
|
||||
continue
|
||||
|
||||
site_results: list[dict] = []
|
||||
try:
|
||||
with Client(
|
||||
base_url=website.base_url,
|
||||
api_prefix=website.api_prefix,
|
||||
auth_type=website.auth_type,
|
||||
auth_config=json.loads(website.auth_config_json or "{}"),
|
||||
timeout=float(website.timeout_seconds),
|
||||
) as client:
|
||||
for row in all_website_keys:
|
||||
account_id = row.imported_account_id
|
||||
if not account_id:
|
||||
continue
|
||||
new_priority = priority_map.get(f"{row.upstream_id}:{row.group_id}")
|
||||
if new_priority is None:
|
||||
site_results.append(
|
||||
_priority_result(row, None, "skipped", "无倍率数据,跳过")
|
||||
)
|
||||
continue
|
||||
try:
|
||||
client.update_account(account_id, {"priority": new_priority})
|
||||
logger.info(
|
||||
"updated priority for account %s (website=%s, upstream=%s, group=%s): %s",
|
||||
account_id, wid, row.upstream_id, row.group_id, new_priority,
|
||||
)
|
||||
site_results.append(
|
||||
_priority_result(row, new_priority, "success", f"优先级已更新为 {new_priority}")
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"failed to update priority for account %s (website=%s): %s",
|
||||
account_id, wid, exc,
|
||||
)
|
||||
site_results.append(
|
||||
_priority_result(row, new_priority, "failed", str(exc))
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("failed to connect website %s for account priority sync: %s", wid, exc)
|
||||
for row in all_website_keys:
|
||||
site_results.append(
|
||||
_priority_result(row, None, "failed", f"连接网站失败: {exc}")
|
||||
)
|
||||
|
||||
all_results.extend(site_results)
|
||||
_write_priority_sync_log_with_map(db, wid, upstream_name, site_results, priority_map)
|
||||
_try_send_priority_webhook(db, wid, website.name, upstream_id, upstream_name, site_results)
|
||||
|
||||
return all_results
|
||||
|
||||
|
||||
def sync_affected_bindings(db: Session, upstream_id: int, changes: list[dict[str, Any]]) -> None:
|
||||
for binding in get_affected_bindings(db, changes, upstream_id):
|
||||
try:
|
||||
|
||||
@@ -64,6 +64,35 @@ def format_dingtalk_website_rate_changed(
|
||||
}
|
||||
|
||||
|
||||
def format_dingtalk_priority_changed(
|
||||
website_name: str, upstream_name: str, changed_at: str,
|
||||
updates: list[dict],
|
||||
) -> dict[str, Any]:
|
||||
success = sum(1 for u in updates if u.get("status") == "success")
|
||||
failed = sum(1 for u in updates if u.get("status") == "failed")
|
||||
skipped = sum(1 for u in updates if u.get("status") == "skipped")
|
||||
lines = [
|
||||
f"### 🔄 {website_name} 账号优先级变更",
|
||||
"",
|
||||
f"- **触发上游**:{upstream_name}",
|
||||
f"- **时间**:{changed_at}",
|
||||
f"- **摘要**:{success} 更新 / {failed} 失败 / {skipped} 跳过",
|
||||
"",
|
||||
]
|
||||
for u in updates:
|
||||
emoji = {"success": "✅", "failed": "❌", "skipped": "⏭️"}.get(u.get("status", ""), "➖")
|
||||
gid = u.get("group_id", "?")
|
||||
priority = u.get("new_priority", "—")
|
||||
lines.append(f"{emoji} `{gid}` → priority={priority}")
|
||||
return {
|
||||
"msgtype": "markdown",
|
||||
"markdown": {
|
||||
"title": f"{website_name} 账号优先级变更",
|
||||
"text": "\n".join(lines),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def format_dingtalk_balance_low(
|
||||
upstream_name: str, balance: float, threshold: float, changed_at: str
|
||||
) -> dict[str, Any]:
|
||||
|
||||
@@ -0,0 +1,221 @@
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.database import Base
|
||||
from app.models.upstream import Upstream
|
||||
from app.models.website import Website, WebsiteSyncLog
|
||||
from app.models.upstream_key import UpstreamGeneratedKey
|
||||
from app.models.snapshot import UpstreamRateSnapshot
|
||||
from app.services.website_sync import (
|
||||
build_rate_priority_map,
|
||||
sync_account_priorities_for_upstream
|
||||
)
|
||||
from app.services.website_client import Sub2ApiWebsiteClient
|
||||
|
||||
@pytest.fixture()
|
||||
def db_session():
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
db = TestingSessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
def test_priority_sync_cross_upstream_group(db_session):
|
||||
# Setup 2 upstreams
|
||||
u1 = Upstream(name="U1", base_url="http://u1")
|
||||
u2 = Upstream(name="U2", base_url="http://u2")
|
||||
db_session.add_all([u1, u2])
|
||||
db_session.commit()
|
||||
db_session.refresh(u1)
|
||||
db_session.refresh(u2)
|
||||
|
||||
# Setup snapshots for both with same group ID "VIP" but different rates
|
||||
s1 = UpstreamRateSnapshot(
|
||||
upstream_id=u1.id,
|
||||
snapshot_json=json.dumps({"groups": {"VIP": {"group_name": "VIP", "rate": 1.0}}}),
|
||||
captured_at=datetime.now(timezone.utc)
|
||||
)
|
||||
s2 = UpstreamRateSnapshot(
|
||||
upstream_id=u2.id,
|
||||
snapshot_json=json.dumps({"groups": {"VIP": {"group_name": "VIP", "rate": 2.0}}}),
|
||||
captured_at=datetime.now(timezone.utc)
|
||||
)
|
||||
db_session.add_all([s1, s2])
|
||||
db_session.commit()
|
||||
|
||||
priority_map = build_rate_priority_map(db_session, {u1.id, u2.id})
|
||||
|
||||
assert priority_map[f"{u1.id}:VIP"] == 1
|
||||
assert priority_map[f"{u2.id}:VIP"] == 2
|
||||
assert len(priority_map) == 2
|
||||
|
||||
def test_priority_sync_full_website_update(db_session, monkeypatch):
|
||||
# Setup website and upstreams
|
||||
w = Website(name="W1", base_url="http://w1", enabled=True, auth_config_json="{}", timeout_seconds=30)
|
||||
u1 = Upstream(name="U1", base_url="http://u1")
|
||||
u2 = Upstream(name="U2", base_url="http://u2")
|
||||
db_session.add_all([w, u1, u2])
|
||||
db_session.commit()
|
||||
db_session.refresh(w)
|
||||
db_session.refresh(u1)
|
||||
db_session.refresh(u2)
|
||||
|
||||
# Setup snapshots
|
||||
db_session.add(UpstreamRateSnapshot(
|
||||
upstream_id=u1.id,
|
||||
snapshot_json=json.dumps({"groups": {"G1": {"rate": 1.0}}}),
|
||||
captured_at=datetime.now(timezone.utc)
|
||||
))
|
||||
db_session.add(UpstreamRateSnapshot(
|
||||
upstream_id=u2.id,
|
||||
snapshot_json=json.dumps({"groups": {"G2": {"rate": 2.0}}}),
|
||||
captured_at=datetime.now(timezone.utc)
|
||||
))
|
||||
db_session.commit()
|
||||
|
||||
# Setup keys imported to website
|
||||
k1 = UpstreamGeneratedKey(upstream_id=u1.id, group_id="G1", key_name="K1", key_value="V1",
|
||||
imported_website_id=w.id, imported_account_id="A1")
|
||||
k2 = UpstreamGeneratedKey(upstream_id=u2.id, group_id="G2", key_name="K2", key_value="V2",
|
||||
imported_website_id=w.id, imported_account_id="A2")
|
||||
db_session.add_all([k1, k2])
|
||||
db_session.commit()
|
||||
|
||||
# Mock Sub2ApiWebsiteClient
|
||||
update_calls = []
|
||||
class MockClient:
|
||||
def __init__(self, **kwargs): pass
|
||||
def __enter__(self): return self
|
||||
def __exit__(self, *args): pass
|
||||
def update_account(self, account_id, data):
|
||||
update_calls.append((account_id, data))
|
||||
|
||||
monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", MockClient)
|
||||
monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", MockClient)
|
||||
|
||||
# Trigger sync for U1
|
||||
sync_account_priorities_for_upstream(db_session, u1.id)
|
||||
|
||||
# Verify BOTH A1 and A2 were updated because they belong to the same website
|
||||
assert len(update_calls) == 2
|
||||
account_ids = {c[0] for c in update_calls}
|
||||
assert account_ids == {"A1", "A2"}
|
||||
|
||||
# Priority check: G1(1.0) -> 1, G2(2.0) -> 2
|
||||
for aid, data in update_calls:
|
||||
if aid == "A1": assert data["priority"] == 1
|
||||
if aid == "A2": assert data["priority"] == 2
|
||||
|
||||
def test_priority_sync_log_structure(db_session, monkeypatch):
|
||||
w = Website(name="W1", base_url="http://w1", enabled=True, auth_config_json="{}", timeout_seconds=30)
|
||||
u1 = Upstream(name="U1", base_url="http://u1")
|
||||
db_session.add_all([w, u1])
|
||||
db_session.commit()
|
||||
db_session.refresh(w)
|
||||
db_session.refresh(u1)
|
||||
|
||||
db_session.add(UpstreamRateSnapshot(
|
||||
upstream_id=u1.id,
|
||||
snapshot_json=json.dumps({"groups": {"G1": {"rate": 1.0}}}),
|
||||
captured_at=datetime.now(timezone.utc)
|
||||
))
|
||||
db_session.add(UpstreamGeneratedKey(upstream_id=u1.id, group_id="G1", key_name="K1", key_value="V1",
|
||||
imported_website_id=w.id, imported_account_id="A1"))
|
||||
db_session.commit()
|
||||
|
||||
class MockClient:
|
||||
def __init__(self, **kwargs): pass
|
||||
def __enter__(self): return self
|
||||
def __exit__(self, *args): pass
|
||||
def update_account(self, account_id, data): pass
|
||||
|
||||
monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", MockClient)
|
||||
monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", MockClient)
|
||||
|
||||
sync_account_priorities_for_upstream(db_session, u1.id)
|
||||
|
||||
log = db_session.query(WebsiteSyncLog).filter(WebsiteSyncLog.website_id == w.id).first()
|
||||
assert log is not None
|
||||
assert log.algorithm == "priority_sync"
|
||||
|
||||
data = json.loads(log.source_rates_json)
|
||||
# The first item should be the priority map metadata
|
||||
assert data[0]["_meta"] == "priority_map"
|
||||
assert f"{u1.id}:G1" in data[0]["data"]
|
||||
# The second item should be the account result
|
||||
assert data[1]["account_id"] == "A1"
|
||||
|
||||
def test_import_auto_priority_by_rate(db_session, monkeypatch):
|
||||
from app.routers.websites import import_upstream_keys_as_accounts
|
||||
from app.schemas.website import ImportAccountsRequest
|
||||
|
||||
w = Website(name="W1", base_url="http://w1", enabled=True, auth_config_json="{}",
|
||||
groups_endpoint="/groups", group_update_endpoint="/groups/{id}", timeout_seconds=30)
|
||||
u1 = Upstream(name="U1", base_url="http://u1")
|
||||
u2 = Upstream(name="U2", base_url="http://u2")
|
||||
db_session.add_all([w, u1, u2])
|
||||
db_session.commit()
|
||||
db_session.refresh(w)
|
||||
db_session.refresh(u1)
|
||||
db_session.refresh(u2)
|
||||
|
||||
db_session.add(UpstreamRateSnapshot(
|
||||
upstream_id=u1.id,
|
||||
snapshot_json=json.dumps({"groups": {"G1": {"rate": 2.0}}}),
|
||||
captured_at=datetime.now(timezone.utc)
|
||||
))
|
||||
db_session.add(UpstreamRateSnapshot(
|
||||
upstream_id=u2.id,
|
||||
snapshot_json=json.dumps({"groups": {"G2": {"rate": 1.0}}}),
|
||||
captured_at=datetime.now(timezone.utc)
|
||||
))
|
||||
|
||||
k1 = UpstreamGeneratedKey(upstream_id=u1.id, group_id="G1", group_name="G1", key_name="K1", key_value="V1")
|
||||
k2 = UpstreamGeneratedKey(upstream_id=u2.id, group_id="G2", group_name="G2", key_name="K2", key_value="V2")
|
||||
db_session.add_all([k1, k2])
|
||||
db_session.commit()
|
||||
|
||||
created_accounts = []
|
||||
class MockClient:
|
||||
def __init__(self, **kwargs): pass
|
||||
def __enter__(self): return self
|
||||
def __exit__(self, *args): pass
|
||||
def create_account(self, body):
|
||||
created_accounts.append(body)
|
||||
return {"id": f"remote-{len(created_accounts)}", "name": body["name"]}
|
||||
def extract_id(self, data): return data["id"]
|
||||
def account_exists(self, aid): return False
|
||||
|
||||
monkeypatch.setattr("app.routers.websites._client", lambda website: MockClient())
|
||||
monkeypatch.setattr("app.services.website_client.Sub2ApiWebsiteClient", MockClient)
|
||||
|
||||
req = ImportAccountsRequest(
|
||||
upstream_key_ids=[k1.id, k2.id],
|
||||
target_group_map={},
|
||||
auto_priority_by_rate=True,
|
||||
priority=10,
|
||||
account_name_prefix="test",
|
||||
default_platform="openai"
|
||||
)
|
||||
|
||||
import_upstream_keys_as_accounts(w.id, req, db_session)
|
||||
|
||||
assert len(created_accounts) == 2
|
||||
# G2 has rate 1.0 -> priority 1
|
||||
# G1 has rate 2.0 -> priority 2
|
||||
p1 = next(a["priority"] for a in created_accounts if "G1" in a["name"])
|
||||
p2 = next(a["priority"] for a in created_accounts if "G2" in a["name"])
|
||||
assert p2 == 1
|
||||
assert p1 == 2
|
||||
Reference in New Issue
Block a user