Files
SmartUp/backend/test_upstream_key_sync.py
T
liumangmang 6044b00685 feat: 上游 Key 唯一化、分组导入跳过、账号导入平台识别&远端校验&base_url 注入
- 上游 Key 命名改为 {prefix}-{upstream.id}-{safe_group_name}-{group_id}
- 唯一约束 (upstream_id, group_id, managed_prefix) 加 managed_prefix 列
- 上游检测成功时同步 Key 状态,远端已删/分组已删自动清理
- 重复分组导入跳过,目标网站已存在同名分组返回 exists
- 账号导入平台自动识别(auto/manual 模式)
- 全选可导入 Key 按钮 + 目标分组自动匹配
- 导入幂等:已导入过的 Key 校验远端账号,不存在则重建
- 新增同步接口 POST /sync-imported-upstream-keys
- account_exists() 通过拉取账号列表判断,避免 404 误判
- credentials.base_url 注入来源上游地址,避免 401
- 前端导入弹窗自动同步+刷新按钮+并发/优先级设置
- 新增 12 个测试覆盖同步、幂等、远端删除、校验失败路径
2026-05-21 01:16:39 +08:00

470 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Tests for upstream key uniquification and sync cleanup."""
import json
from datetime import datetime, timezone
import pytest
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.models.upstream import Upstream
from app.models.website import Website # noqa: F401 — registers table for FK refs
from app.models.upstream_key import UpstreamGeneratedKey
@pytest.fixture()
def db_session():
    """Yield an ORM session bound to a fresh in-memory SQLite database.

    ``StaticPool`` makes the engine hand out one shared connection, so every
    session/connection created from it sees the same in-memory database. All
    tables registered on ``Base.metadata`` are created before the test and
    dropped afterwards; the engine is disposed last so the pooled connection
    is released instead of leaking across tests.
    """
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()
        try:
            Base.metadata.drop_all(bind=engine)
        finally:
            # Always release the single StaticPool connection, even if drop_all fails.
            engine.dispose()
def test_duplicate_cleanup_keeps_latest_only():
    """Duplicate rows for one (upstream_id, group_id, key_name) collapse to a single row.

    Uses a standalone engine and raw SQL only, to mimic the database state
    before the migration (no unique constraint, no ``managed_prefix`` column).
    The cleanup keeps the row with the largest ``id``. Rows are inserted in
    chronological order below, so that row is also the newest by ``created_at``.
    """
    engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
    with engine.begin() as conn:
        conn.execute(text("""
            CREATE TABLE upstreams (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name VARCHAR(255) NOT NULL,
                base_url VARCHAR(512) NOT NULL,
                api_prefix VARCHAR(64) DEFAULT '',
                auth_type VARCHAR(32),
                auth_config_json TEXT DEFAULT '{}',
                groups_endpoint VARCHAR(256),
                rate_endpoint VARCHAR(256),
                enabled BOOLEAN DEFAULT 1,
                check_interval_seconds INTEGER DEFAULT 600,
                timeout_seconds INTEGER DEFAULT 30,
                last_status VARCHAR(32) DEFAULT 'unknown',
                last_checked_at DATETIME,
                last_error TEXT,
                consecutive_failures INTEGER DEFAULT 0,
                balance FLOAT,
                balance_updated_at DATETIME,
                balance_endpoint VARCHAR(256) DEFAULT '',
                balance_response_path VARCHAR(256) DEFAULT '',
                balance_divisor FLOAT DEFAULT 1.0,
                updated_at DATETIME,
                created_at DATETIME
            )
        """))
        conn.execute(text("""
            CREATE TABLE upstream_generated_keys (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                upstream_id INTEGER NOT NULL,
                group_id VARCHAR(255) NOT NULL,
                group_name VARCHAR(255) DEFAULT '',
                key_id VARCHAR(255),
                key_name VARCHAR(255) NOT NULL,
                key_value TEXT NOT NULL,
                masked_key VARCHAR(255) DEFAULT '',
                raw_json TEXT DEFAULT '{}',
                status VARCHAR(32) DEFAULT 'created',
                error TEXT,
                imported_website_id INTEGER,
                imported_account_id VARCHAR(255),
                imported_at DATETIME,
                created_at DATETIME,
                updated_at DATETIME
            )
        """))
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        now = datetime.now(timezone.utc)
        db.execute(text("""
            INSERT INTO upstreams (name, base_url, api_prefix, auth_type, auth_config_json, groups_endpoint, rate_endpoint, created_at, updated_at)
            VALUES (:n, :b, :p, :a, :j, :g, :r, :ca, :ua)
        """), {"n": "Test", "b": "http://local", "p": "/api/v1", "a": "bearer",
               "j": "{}", "g": "/groups", "r": "/rates", "ca": now, "ua": now})
        db.commit()
        uid = db.execute(text("SELECT id FROM upstreams LIMIT 1")).scalar()
        # Insert three duplicate rows, oldest first, so the largest id is the newest row.
        for kv, ca in [
            ("old-key", datetime(2025, 1, 1, tzinfo=timezone.utc)),
            ("middle-key", datetime(2025, 6, 1, tzinfo=timezone.utc)),
            ("newest-key", datetime(2025, 12, 1, tzinfo=timezone.utc)),
        ]:
            db.execute(text("""
                INSERT INTO upstream_generated_keys
                (upstream_id, group_id, group_name, key_name, key_value, masked_key, raw_json, status, created_at, updated_at)
                VALUES (:uid, :gid, :gn, :kn, :kv, :mk, :rj, :st, :ca, :ca)
            """), {
                "uid": uid, "gid": "vip", "gn": "VIP",
                "kn": "SmartUp-Test-VIP", "kv": kv,
                "mk": "", "rj": "{}", "st": "created", "ca": ca,
            })
        db.commit()
        # Cleanup: per (upstream_id, group_id, key_name) keep only the row with the largest id.
        db.execute(text("""
            DELETE FROM upstream_generated_keys
            WHERE id NOT IN (
                SELECT MAX(id) FROM upstream_generated_keys
                GROUP BY upstream_id, group_id, key_name
            )
        """))
        db.commit()
        remaining = db.execute(text("SELECT key_value FROM upstream_generated_keys")).fetchall()
        assert len(remaining) == 1, f"expected 1 after cleanup, got {len(remaining)}"
        assert remaining[0][0] == "newest-key"
    finally:
        db.close()
        # Release the StaticPool connection so the in-memory database is freed.
        engine.dispose()
def test_migration_backfills_managed_prefix_and_deduplicates():
    """Migration SQL backfills managed_prefix on legacy SmartUp rows and removes duplicates.

    Uses a standalone engine whose table has a nullable ``managed_prefix``
    column but no unique constraint, to mimic the state just before the
    migration runs.
    """
    engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
    with engine.begin() as conn:
        conn.execute(text("""
            CREATE TABLE upstreams (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name VARCHAR(255) NOT NULL, base_url VARCHAR(512) NOT NULL,
                api_prefix VARCHAR(64) DEFAULT '', auth_type VARCHAR(32),
                auth_config_json TEXT DEFAULT '{}', groups_endpoint VARCHAR(256),
                rate_endpoint VARCHAR(256), enabled BOOLEAN DEFAULT 1,
                check_interval_seconds INTEGER DEFAULT 600, timeout_seconds INTEGER DEFAULT 30,
                last_status VARCHAR(32) DEFAULT 'unknown', last_checked_at DATETIME,
                last_error TEXT, consecutive_failures INTEGER DEFAULT 0,
                balance FLOAT, balance_updated_at DATETIME,
                balance_endpoint VARCHAR(256) DEFAULT '', balance_response_path VARCHAR(256) DEFAULT '',
                balance_divisor FLOAT DEFAULT 1.0, updated_at DATETIME, created_at DATETIME
            )
        """))
        conn.execute(text("""
            CREATE TABLE upstream_generated_keys (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                upstream_id INTEGER NOT NULL, group_id VARCHAR(255) NOT NULL,
                group_name VARCHAR(255) DEFAULT '', key_id VARCHAR(255),
                key_name VARCHAR(255) NOT NULL, key_value TEXT NOT NULL,
                masked_key VARCHAR(255) DEFAULT '', raw_json TEXT DEFAULT '{}',
                managed_prefix VARCHAR(64), status VARCHAR(32) DEFAULT 'created',
                error TEXT, imported_website_id INTEGER, imported_account_id VARCHAR(255),
                imported_at DATETIME, created_at DATETIME, updated_at DATETIME
            )
        """))
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        now = datetime.now(timezone.utc)
        db.execute(text("""
            INSERT INTO upstreams (name, base_url, api_prefix, auth_type, auth_config_json, groups_endpoint, rate_endpoint, created_at, updated_at)
            VALUES (:n, :b, :p, :a, :j, :g, :r, :ca, :ua)
        """), {"n": "Old", "b": "http://local", "p": "/api/v1", "a": "bearer",
               "j": "{}", "g": "/groups", "r": "/rates", "ca": now, "ua": now})
        db.commit()
        uid = db.execute(text("SELECT id FROM upstreams LIMIT 1")).scalar()
        # Insert two duplicate rows: managed_prefix left NULL, key_name starts with "SmartUp".
        for kv in ("sk-old", "sk-new"):
            db.execute(text("""
                INSERT INTO upstream_generated_keys
                (upstream_id, group_id, group_name, key_name, key_value, masked_key, raw_json, status, created_at, updated_at)
                VALUES (:uid, :gid, :gn, :kn, :kv, '', '{}', 'created', :ca, :ca)
            """), {"uid": uid, "gid": "vip", "gn": "VIP",
                   "kn": "SmartUp-Old-vip", "kv": kv, "ca": now})
        db.commit()
        # Run the migration steps; this SQL is meant to mirror the statements in database.py.
        conn = db.connection()
        conn.execute(text(
            "UPDATE upstream_generated_keys SET managed_prefix = 'SmartUp' "
            "WHERE managed_prefix IS NULL AND key_name LIKE 'SmartUp-%'"
        ))
        to_delete = conn.execute(text("""
            SELECT id FROM upstream_generated_keys
            WHERE managed_prefix IS NOT NULL
            AND id NOT IN (
                SELECT MAX(id) FROM upstream_generated_keys
                WHERE managed_prefix IS NOT NULL
                GROUP BY upstream_id, group_id, managed_prefix
            )
        """)).fetchall()
        for (row_id,) in to_delete:
            conn.execute(text("DELETE FROM upstream_generated_keys WHERE id = :id"), {"id": row_id})
        db.commit()
        remaining = db.execute(text("SELECT key_value, managed_prefix FROM upstream_generated_keys")).fetchall()
        assert len(remaining) == 1, f"expected 1 after migration, got {len(remaining)}"
        assert remaining[0][0] == "sk-new"  # the newest (largest id) row is kept
        assert remaining[0][1] == "SmartUp"  # prefix was backfilled
    finally:
        db.close()
        # Release the StaticPool connection so the in-memory database is freed.
        engine.dispose()
def test_ensure_group_key_reuses_old_record(db_session, monkeypatch):
    """_ensure_group_key should reuse a legacy row whose managed_prefix IS NULL.

    After the call the group must still have exactly one local row. That row
    carries the ``SmartUp`` prefix and the freshly created key value, rather
    than a second row being inserted next to the legacy one.
    """
    # UpstreamGeneratedKey comes from the module-level import; UpstreamClient is not needed
    # because a duck-typed MockClient is passed in directly.
    from app.routers.upstreams import _ensure_group_key
    from app.schemas.upstream import GenerateKeysByGroupsRequest

    upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
                        auth_type="bearer", auth_config_json="{}",
                        groups_endpoint="/groups", rate_endpoint="/rates")
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    # Legacy row written before managed_prefix existed.
    db_session.add(UpstreamGeneratedKey(
        upstream_id=upstream.id, group_id="vip", group_name="VIP",
        key_name="SmartUp-Test-vip", key_value="sk-old",
        managed_prefix=None, key_id="remote-999",
    ))
    db_session.commit()

    class MockClient:
        """Client stub: reports no matching remote key and returns a canned new key."""

        def find_smartup_group_key(self, gid, name, prefix):
            return None

        def create_api_key(self, name, group_id, **kw):
            return {"id": "new-remote", "key": "sk-new-value", "masked_key": "sk-****-lue"}

    group = {"id": "vip", "name": "VIP", "rate_multiplier": 1}
    body = GenerateKeysByGroupsRequest(
        group_ids=["vip"], name_prefix="SmartUp",
        quota=0, endpoint="/keys",
    )
    result = _ensure_group_key(db_session, MockClient(), upstream, group, "SmartUp", body)
    assert result.status == "created"

    rows = db_session.query(UpstreamGeneratedKey).filter(
        UpstreamGeneratedKey.upstream_id == upstream.id,
        UpstreamGeneratedKey.group_id == "vip",
    ).all()
    assert len(rows) == 1, f"expected 1 record, got {len(rows)}"
    assert rows[0].managed_prefix == "SmartUp"
    assert rows[0].key_value == "sk-new-value"
def test_sync_removes_remote_key_when_list_empty(db_session, monkeypatch):
    """Sync deletes the local row whose key_id is absent from an empty remote key list.

    ``list_api_keys`` is stubbed to return ``[]`` (the query succeeded but the
    upstream reports no keys), so the locally tracked ``remote-key-id`` is stale
    and must be removed.
    """
    from app.services import scheduler as sched_mod
    from app.services.upstream_client import UpstreamClient

    upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
                        auth_type="bearer", auth_config_json="{}",
                        groups_endpoint="/groups", rate_endpoint="/rates")
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)
    db_session.add(UpstreamGeneratedKey(
        upstream_id=upstream.id, group_id="vip", group_name="VIP",
        key_name="SmartUp-Test-vip", key_value="sk-vip",
        managed_prefix="SmartUp", key_id="remote-key-id",
    ))
    db_session.commit()

    # Stub the client: listing succeeds with no keys; login/close and the
    # context-manager protocol become no-ops so no live upstream is needed.
    monkeypatch.setattr(UpstreamClient, "list_api_keys", lambda self, **kw: [])
    monkeypatch.setattr(UpstreamClient, "login", lambda self: None)
    monkeypatch.setattr(UpstreamClient, "close", lambda self: None)
    monkeypatch.setattr(UpstreamClient, "__enter__", lambda self: self)
    monkeypatch.setattr(UpstreamClient, "__exit__", lambda self, *a: None)
    # Make _sync_upstream_keys run against the test session instead of opening its own.
    monkeypatch.setattr(sched_mod, "SessionLocal", lambda: db_session)

    # Use a single timestamp so the snapshot and the explicit argument agree.
    captured_at = datetime.now(timezone.utc)
    snapshot = {
        "upstream_id": upstream.id,
        "groups": {"vip": {"group_id": "vip", "rate": "1"}},
        "captured_at": captured_at.isoformat(),
    }

    # The sync closes its session when done; neutralise that so the fixture
    # session stays usable, and restore it even if the sync raises.
    original_close = db_session.close
    monkeypatch.setattr(db_session, "close", lambda: None)
    try:
        sched_mod._sync_upstream_keys(upstream.id, snapshot, captured_at)
    finally:
        monkeypatch.setattr(db_session, "close", original_close)

    remaining = db_session.query(UpstreamGeneratedKey).all()
    assert len(remaining) == 0, f"expected 0 after sync with empty remote, got {len(remaining)}"
def test_migration_function_integration(monkeypatch):
    """Call _migrate_upstream_generated_keys() directly against a legacy-schema table.

    Checks three things: the migration adds the ``managed_prefix`` column, it
    backfills that column for an existing ``SmartUp-`` row, and it creates the
    ``uq_upstream_group_managed`` index.
    """
    from sqlalchemy import inspect

    from app.database import _migrate_upstream_generated_keys

    # Point the module-level engine at an isolated in-memory DB so the real
    # database is never touched; the monkeypatch fixture restores it at teardown.
    test_engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
    monkeypatch.setattr("app.database.engine", test_engine)
    try:
        # Legacy schema: no managed_prefix column (and no updated_at column).
        with test_engine.begin() as conn:
            conn.execute(text("""
                CREATE TABLE upstream_generated_keys (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    upstream_id INTEGER NOT NULL, group_id VARCHAR(255) NOT NULL,
                    group_name VARCHAR(255) DEFAULT '', key_id VARCHAR(255),
                    key_name VARCHAR(255) NOT NULL, key_value TEXT NOT NULL,
                    masked_key VARCHAR(255) DEFAULT '', raw_json TEXT DEFAULT '{}',
                    status VARCHAR(32) DEFAULT 'created', error TEXT,
                    imported_website_id INTEGER, imported_account_id VARCHAR(255),
                    imported_at DATETIME, created_at DATETIME
                )
            """))
            conn.execute(text("""
                INSERT INTO upstream_generated_keys
                (upstream_id, group_id, group_name, key_name, key_value, masked_key, raw_json, status, created_at)
                VALUES (1, 'vip', 'VIP', 'SmartUp-Old-vip', 'sk-val', '', '{}', 'created', datetime('now'))
            """))

        _migrate_upstream_generated_keys()

        # Inspect only after the migration so reflection sees the new schema.
        inspector = inspect(test_engine)
        cols = {c["name"] for c in inspector.get_columns("upstream_generated_keys")}
        assert "managed_prefix" in cols, "managed_prefix column should exist after migration"

        with test_engine.connect() as conn:
            row = conn.execute(text("SELECT managed_prefix, key_value FROM upstream_generated_keys LIMIT 1")).fetchone()
        assert row[0] == "SmartUp", f"expected SmartUp, got {row[0]}"
        assert row[1] == "sk-val"

        index_names = {ix["name"] for ix in inspector.get_indexes("upstream_generated_keys")}
        assert "uq_upstream_group_managed" in index_names, "partial unique index should exist"
    finally:
        # Release the StaticPool connection so the in-memory database is freed.
        test_engine.dispose()
def test_create_twice_only_one_record(db_session):
    """A second "ensure" for the same upstream/group updates the row instead of duplicating it.

    The upsert is simulated inline and no production helper is called. The
    first "call" inserts a row. The second looks it up by
    (upstream_id, group_id, key_name) and marks it ``exists``.
    """
    upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
                        auth_type="bearer", auth_config_json="{}",
                        groups_endpoint="/groups", rate_endpoint="/rates")
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    # First creation.
    db_session.add(UpstreamGeneratedKey(
        upstream_id=upstream.id, group_id="vip", group_name="VIP",
        key_name="SmartUp-Test-VIP", key_value="sk-first",
        status="created",
    ))
    db_session.commit()

    # Second call: the row was just committed, so the lookup must find it. An
    # insert fallback here would be dead code and would hide a broken lookup.
    existing = db_session.query(UpstreamGeneratedKey).filter(
        UpstreamGeneratedKey.upstream_id == upstream.id,
        UpstreamGeneratedKey.group_id == "vip",
        UpstreamGeneratedKey.key_name == "SmartUp-Test-VIP",
    ).first()
    assert existing is not None
    existing.status = "exists"
    existing.updated_at = datetime.now(timezone.utc)
    db_session.commit()

    rows = db_session.query(UpstreamGeneratedKey).filter(
        UpstreamGeneratedKey.upstream_id == upstream.id,
        UpstreamGeneratedKey.group_id == "vip",
    ).all()
    assert len(rows) == 1
    assert rows[0].status == "exists"
    assert rows[0].key_value == "sk-first"  # the original row was updated, not replaced
def test_sync_removes_gone_group(db_session):
    """Rows for groups missing from the latest snapshot are deleted locally.

    The cleanup is reproduced inline. Two rows (``vip`` and ``free``) are
    seeded, and the snapshot lists only ``vip``.
    """
    upstream = Upstream(
        name="Test",
        base_url="http://local",
        api_prefix="/api/v1",
        auth_type="bearer",
        auth_config_json="{}",
        groups_endpoint="/groups",
        rate_endpoint="/rates",
    )
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    seed_specs = (
        ("vip", "VIP", "SmartUp-Test-VIP", "sk-vip"),
        ("free", "Free", "SmartUp-Test-Free", "sk-free"),
    )
    db_session.add_all([
        UpstreamGeneratedKey(
            upstream_id=upstream.id, group_id=gid, group_name=gname,
            key_name=kname, key_value=kvalue,
        )
        for gid, gname, kname, kvalue in seed_specs
    ])
    db_session.commit()

    # The latest snapshot only knows about "vip"; "free" has disappeared upstream.
    snapshot_group_ids = {"vip"}
    tracked_rows = (
        db_session.query(UpstreamGeneratedKey)
        .filter(UpstreamGeneratedKey.upstream_id == upstream.id)
        .all()
    )
    stale_rows = [r for r in tracked_rows if r.group_id not in snapshot_group_ids]
    for stale in stale_rows:
        db_session.delete(stale)
    db_session.commit()

    survivors = db_session.query(UpstreamGeneratedKey).all()
    assert len(survivors) == 1
    assert survivors[0].group_id == "vip"
def test_sync_removes_deleted_remote_key(db_session):
    """A local row whose remote key no longer exists upstream is deleted.

    The remote listing is simulated as a set of live key ids that excludes the
    row's ``remote-123``. Only rows with a non-empty ``key_id`` are eligible for
    removal.
    """
    upstream = Upstream(
        name="Test",
        base_url="http://local",
        api_prefix="/api/v1",
        auth_type="bearer",
        auth_config_json="{}",
        groups_endpoint="/groups",
        rate_endpoint="/rates",
    )
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    tracked = UpstreamGeneratedKey(
        upstream_id=upstream.id, group_id="vip", group_name="VIP",
        key_name="SmartUp-Test-VIP", key_value="sk-vip",
        key_id="remote-123",
    )
    db_session.add(tracked)
    db_session.commit()

    # Key ids the remote side still reports as live; "remote-123" is gone.
    live_remote_ids = {"remote-456", "remote-789"}
    candidates = (
        db_session.query(UpstreamGeneratedKey)
        .filter(UpstreamGeneratedKey.upstream_id == upstream.id)
        .all()
    )
    for candidate in candidates:
        has_remote_id = bool(candidate.key_id)
        if has_remote_id and candidate.key_id not in live_remote_ids:
            db_session.delete(candidate)
    db_session.commit()

    leftover = db_session.query(UpstreamGeneratedKey).all()
    assert not leftover