"""Tests for upstream key uniquification and sync cleanup."""

import json
from datetime import datetime, timezone

import pytest
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from app.database import Base
from app.models.upstream import Upstream
from app.models.website import Website  # noqa: F401 — registers table for FK refs
from app.models.upstream_key import UpstreamGeneratedKey


# Pre-migration ``upstreams`` schema, created with raw SQL so the legacy-state
# tests below do not depend on the current ORM metadata.
_LEGACY_UPSTREAMS_DDL = """
    CREATE TABLE upstreams (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        name VARCHAR(255) NOT NULL,
        base_url VARCHAR(512) NOT NULL,
        api_prefix VARCHAR(64) DEFAULT '',
        auth_type VARCHAR(32),
        auth_config_json TEXT DEFAULT '{}',
        groups_endpoint VARCHAR(256),
        rate_endpoint VARCHAR(256),
        enabled BOOLEAN DEFAULT 1,
        check_interval_seconds INTEGER DEFAULT 600,
        timeout_seconds INTEGER DEFAULT 30,
        last_status VARCHAR(32) DEFAULT 'unknown',
        last_checked_at DATETIME,
        last_error TEXT,
        consecutive_failures INTEGER DEFAULT 0,
        balance FLOAT,
        balance_updated_at DATETIME,
        balance_endpoint VARCHAR(256) DEFAULT '',
        balance_response_path VARCHAR(256) DEFAULT '',
        balance_divisor FLOAT DEFAULT 1.0,
        updated_at DATETIME,
        created_at DATETIME
    )
"""


def _memory_engine():
    """Return a fresh in-memory SQLite engine sharing one connection (StaticPool)."""
    return create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )


def _sqlite_ts(value):
    """Render a datetime exactly as sqlite3's default adapter would.

    Raw ``text()`` binds are untyped, so a ``datetime`` would otherwise reach
    sqlite3's default adapter, which is deprecated since Python 3.12.
    """
    return value.isoformat(" ")


def _insert_legacy_upstream(db, name):
    """Insert one upstream row via raw SQL, commit, and return its id."""
    now = _sqlite_ts(datetime.now(timezone.utc))
    db.execute(text("""
        INSERT INTO upstreams (name, base_url, api_prefix, auth_type,
            auth_config_json, groups_endpoint, rate_endpoint, created_at, updated_at)
        VALUES (:n, :b, :p, :a, :j, :g, :r, :ca, :ua)
    """), {"n": name, "b": "http://local", "p": "/api/v1", "a": "bearer",
           "j": "{}", "g": "/groups", "r": "/rates", "ca": now, "ua": now})
    db.commit()
    return db.execute(text("SELECT id FROM upstreams LIMIT 1")).scalar()


def _make_upstream(db):
    """Create, commit and refresh a standard ORM ``Upstream`` used by the tests."""
    upstream = Upstream(
        name="Test",
        base_url="http://local",
        api_prefix="/api/v1",
        auth_type="bearer",
        auth_config_json="{}",
        groups_endpoint="/groups",
        rate_endpoint="/rates",
    )
    db.add(upstream)
    db.commit()
    db.refresh(upstream)
    return upstream


@pytest.fixture()
def db_session():
    """Yield an ORM session on a fresh in-memory DB built from ``Base.metadata``."""
    engine = _memory_engine()
    Base.metadata.create_all(bind=engine)
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()
        Base.metadata.drop_all(bind=engine)
        engine.dispose()  # release the StaticPool connection


def test_duplicate_cleanup_keeps_latest_only():
    """Only the newest row is kept per (upstream_id, group_id, key_name).

    Uses a standalone engine and raw SQL only, to mimic the database state
    before the migration (no unique constraint yet).
    """
    engine = _memory_engine()
    try:
        with engine.begin() as conn:
            conn.execute(text(_LEGACY_UPSTREAMS_DDL))
            conn.execute(text("""
                CREATE TABLE upstream_generated_keys (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    upstream_id INTEGER NOT NULL,
                    group_id VARCHAR(255) NOT NULL,
                    group_name VARCHAR(255) DEFAULT '',
                    key_id VARCHAR(255),
                    key_name VARCHAR(255) NOT NULL,
                    key_value TEXT NOT NULL,
                    masked_key VARCHAR(255) DEFAULT '',
                    raw_json TEXT DEFAULT '{}',
                    status VARCHAR(32) DEFAULT 'created',
                    error TEXT,
                    imported_website_id INTEGER,
                    imported_account_id VARCHAR(255),
                    imported_at DATETIME,
                    created_at DATETIME,
                    updated_at DATETIME
                )
            """))
        TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        db = TestingSessionLocal()
        try:
            uid = _insert_legacy_upstream(db, "Test")

            # Insert 3 duplicate rows; insertion order gives "newest-key" the largest id.
            for kv, ca in [
                ("old-key", datetime(2025, 1, 1, tzinfo=timezone.utc)),
                ("middle-key", datetime(2025, 6, 1, tzinfo=timezone.utc)),
                ("newest-key", datetime(2025, 12, 1, tzinfo=timezone.utc)),
            ]:
                db.execute(text("""
                    INSERT INTO upstream_generated_keys
                        (upstream_id, group_id, group_name, key_name, key_value,
                         masked_key, raw_json, status, created_at, updated_at)
                    VALUES (:uid, :gid, :gn, :kn, :kv, :mk, :rj, :st, :ca, :ca)
                """), {
                    "uid": uid, "gid": "vip", "gn": "VIP",
                    "kn": "SmartUp-Test-VIP", "kv": kv, "mk": "", "rj": "{}",
                    "st": "created", "ca": _sqlite_ts(ca),
                })
            db.commit()

            # Cleanup: keep only the row with the largest id for each combination.
            db.execute(text("""
                DELETE FROM upstream_generated_keys
                WHERE id NOT IN (
                    SELECT MAX(id) FROM upstream_generated_keys
                    GROUP BY upstream_id, group_id, key_name
                )
            """))
            db.commit()

            remaining = db.execute(text("SELECT key_value FROM upstream_generated_keys")).fetchall()
            assert len(remaining) == 1, f"expected 1 after cleanup, got {len(remaining)}"
            assert remaining[0][0] == "newest-key"
        finally:
            db.close()
    finally:
        engine.dispose()


def test_migration_backfills_managed_prefix_and_deduplicates():
    """The migration backfills managed_prefix on legacy SmartUp rows and dedupes them.

    Uses a standalone engine (no unique constraint) to mimic the pre-migration state.
    """
    engine = _memory_engine()
    try:
        with engine.begin() as conn:
            conn.execute(text(_LEGACY_UPSTREAMS_DDL))
            conn.execute(text("""
                CREATE TABLE upstream_generated_keys (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    upstream_id INTEGER NOT NULL,
                    group_id VARCHAR(255) NOT NULL,
                    group_name VARCHAR(255) DEFAULT '',
                    key_id VARCHAR(255),
                    key_name VARCHAR(255) NOT NULL,
                    key_value TEXT NOT NULL,
                    masked_key VARCHAR(255) DEFAULT '',
                    raw_json TEXT DEFAULT '{}',
                    managed_prefix VARCHAR(64),
                    status VARCHAR(32) DEFAULT 'created',
                    error TEXT,
                    imported_website_id INTEGER,
                    imported_account_id VARCHAR(255),
                    imported_at DATETIME,
                    created_at DATETIME,
                    updated_at DATETIME
                )
            """))
        TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        db = TestingSessionLocal()
        try:
            uid = _insert_legacy_upstream(db, "Old")
            now = _sqlite_ts(datetime.now(timezone.utc))

            # Two duplicates with no managed_prefix and a key_name starting with "SmartUp".
            for kv in ("sk-old", "sk-new"):
                db.execute(text("""
                    INSERT INTO upstream_generated_keys
                        (upstream_id, group_id, group_name, key_name, key_value,
                         masked_key, raw_json, status, created_at, updated_at)
                    VALUES (:uid, :gid, :gn, :kn, :kv, '', '{}', 'created', :ca, :ca)
                """), {"uid": uid, "gid": "vip", "gn": "VIP",
                       "kn": "SmartUp-Old-vip", "kv": kv, "ca": now})
            db.commit()

            # Run the migration steps (the original comment says these match the SQL in database.py).
            conn = db.connection()
            conn.execute(text(
                "UPDATE upstream_generated_keys SET managed_prefix = 'SmartUp' "
                "WHERE managed_prefix IS NULL AND key_name LIKE 'SmartUp-%'"
            ))
            to_delete = conn.execute(text("""
                SELECT id FROM upstream_generated_keys
                WHERE managed_prefix IS NOT NULL
                  AND id NOT IN (
                    SELECT MAX(id) FROM upstream_generated_keys
                    WHERE managed_prefix IS NOT NULL
                    GROUP BY upstream_id, group_id, managed_prefix
                  )
            """)).fetchall()
            for (row_id,) in to_delete:
                conn.execute(text("DELETE FROM upstream_generated_keys WHERE id = :id"), {"id": row_id})
            db.commit()

            remaining = db.execute(
                text("SELECT key_value, managed_prefix FROM upstream_generated_keys")
            ).fetchall()
            assert len(remaining) == 1, f"expected 1 after migration, got {len(remaining)}"
            assert remaining[0][0] == "sk-new"  # newest (largest id) kept
            assert remaining[0][1] == "SmartUp"  # prefix backfilled
        finally:
            db.close()
    finally:
        engine.dispose()


def test_ensure_group_key_reuses_old_record(db_session):
    """_ensure_group_key reuses a legacy row with managed_prefix IS NULL instead of adding one."""
    from app.routers.upstreams import _ensure_group_key
    from app.schemas.upstream import GenerateKeysByGroupsRequest

    upstream = _make_upstream(db_session)

    # Legacy row without managed_prefix.
    db_session.add(UpstreamGeneratedKey(
        upstream_id=upstream.id, group_id="vip", group_name="VIP",
        key_name="SmartUp-Test-vip", key_value="sk-old",
        managed_prefix=None, key_id="remote-999",
    ))
    db_session.commit()

    class MockClient:
        """Stand-in client: no existing remote key, and creation returns a fixed key."""

        def find_smartup_group_key(self, gid, name, prefix):
            return None

        def create_api_key(self, name, group_id, **kw):
            return {"id": "new-remote", "key": "sk-new-value", "masked_key": "sk-****-lue"}

    group = {"id": "vip", "name": "VIP", "rate_multiplier": 1}
    body = GenerateKeysByGroupsRequest(
        group_ids=["vip"], name_prefix="SmartUp", quota=0, endpoint="/keys",
    )
    result = _ensure_group_key(db_session, MockClient(), upstream, group, "SmartUp", body)
    assert result.status == "created"

    rows = db_session.query(UpstreamGeneratedKey).filter(
        UpstreamGeneratedKey.upstream_id == upstream.id,
        UpstreamGeneratedKey.group_id == "vip",
    ).all()
    assert len(rows) == 1, f"expected 1 record, got {len(rows)}"
    assert rows[0].managed_prefix == "SmartUp"
    assert rows[0].key_value == "sk-new-value"


def test_sync_removes_remote_key_when_list_empty(db_session, monkeypatch):
    """When the remote key list is empty, sync deletes the local row whose key_id is gone."""
    from app.services import scheduler as sched_mod
    from app.services.upstream_client import UpstreamClient

    upstream = _make_upstream(db_session)
    db_session.add(UpstreamGeneratedKey(
        upstream_id=upstream.id, group_id="vip", group_name="VIP",
        key_name="SmartUp-Test-vip", key_value="sk-vip",
        managed_prefix="SmartUp", key_id="remote-key-id",
    ))
    db_session.commit()

    # list_api_keys succeeds but returns no keys.
    monkeypatch.setattr(UpstreamClient, "list_api_keys", lambda self, **kw: [])
    monkeypatch.setattr(UpstreamClient, "login", lambda self: None)
    monkeypatch.setattr(UpstreamClient, "close", lambda self: None)
    monkeypatch.setattr(UpstreamClient, "__enter__", lambda self: self)
    monkeypatch.setattr(UpstreamClient, "__exit__", lambda self, *a: None)
    # Make the scheduler's SessionLocal hand back the test session.
    monkeypatch.setattr(sched_mod, "SessionLocal", lambda: db_session)
    # Neutralise db.close() inside the sync function (presumably in its finally)
    # so the test session remains usable for the assertions below.
    original_close = db_session.close
    monkeypatch.setattr(db_session, "close", lambda: None)

    snapshot = {
        "upstream_id": upstream.id,
        "groups": {"vip": {"group_id": "vip", "rate": "1"}},
        "captured_at": datetime.now(timezone.utc).isoformat(),
    }
    captured_at = datetime.now(timezone.utc)
    try:
        sched_mod._sync_upstream_keys(upstream.id, snapshot, captured_at)
    finally:
        monkeypatch.setattr(db_session, "close", original_close)

    remaining = db_session.query(UpstreamGeneratedKey).all()
    assert len(remaining) == 0, f"expected 0 after sync with empty remote, got {len(remaining)}"


def test_migration_function_integration(monkeypatch):
    """Call _migrate_upstream_generated_keys() and check the new column and index."""
    from app.database import _migrate_upstream_generated_keys

    # Use a throwaway engine so the real database is untouched; this assumes the
    # migration reads the module-level ``app.database.engine`` at call time.
    test_engine = _memory_engine()
    monkeypatch.setattr("app.database.engine", test_engine)
    try:
        # Legacy schema: no managed_prefix column (and no updated_at).
        with test_engine.begin() as conn:
            conn.execute(text("""
                CREATE TABLE upstream_generated_keys (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    upstream_id INTEGER NOT NULL,
                    group_id VARCHAR(255) NOT NULL,
                    group_name VARCHAR(255) DEFAULT '',
                    key_id VARCHAR(255),
                    key_name VARCHAR(255) NOT NULL,
                    key_value TEXT NOT NULL,
                    masked_key VARCHAR(255) DEFAULT '',
                    raw_json TEXT DEFAULT '{}',
                    status VARCHAR(32) DEFAULT 'created',
                    error TEXT,
                    imported_website_id INTEGER,
                    imported_account_id VARCHAR(255),
                    imported_at DATETIME,
                    created_at DATETIME
                )
            """))
            conn.execute(text("""
                INSERT INTO upstream_generated_keys
                    (upstream_id, group_id, group_name, key_name, key_value,
                     masked_key, raw_json, status, created_at)
                VALUES (1, 'vip', 'VIP', 'SmartUp-Old-vip', 'sk-val', '', '{}', 'created', datetime('now'))
            """))

        _migrate_upstream_generated_keys()

        # managed_prefix must now exist and be backfilled.
        inspector = inspect(test_engine)
        cols = {c["name"] for c in inspector.get_columns("upstream_generated_keys")}
        assert "managed_prefix" in cols, "managed_prefix column should exist after migration"

        with test_engine.connect() as conn:
            row = conn.execute(
                text("SELECT managed_prefix, key_value FROM upstream_generated_keys LIMIT 1")
            ).fetchone()
        assert row[0] == "SmartUp", f"expected SmartUp, got {row[0]}"
        assert row[1] == "sk-val"

        # The partial unique index must have been created.
        index_names = {ix["name"] for ix in inspector.get_indexes("upstream_generated_keys")}
        assert "uq_upstream_group_managed" in index_names, "partial unique index should exist"
    finally:
        # monkeypatch restores app.database.engine automatically at teardown.
        test_engine.dispose()


def test_create_twice_only_one_record(db_session):
    """Two consecutive "ensure" calls for one upstream/group leave a single local row.

    NOTE(review): the upsert is simulated inline here rather than calling the
    production ensure function, so this only checks the simulated logic.
    """
    upstream = _make_upstream(db_session)

    # First creation.
    db_session.add(UpstreamGeneratedKey(
        upstream_id=upstream.id, group_id="vip", group_name="VIP",
        key_name="SmartUp-Test-VIP", key_value="sk-first", status="created",
    ))
    db_session.commit()

    # Second call: upsert by the same key_name with status=exists.
    existing = db_session.query(UpstreamGeneratedKey).filter(
        UpstreamGeneratedKey.upstream_id == upstream.id,
        UpstreamGeneratedKey.group_id == "vip",
        UpstreamGeneratedKey.key_name == "SmartUp-Test-VIP",
    ).first()
    if existing:
        existing.status = "exists"
        existing.updated_at = datetime.now(timezone.utc)
    else:
        db_session.add(UpstreamGeneratedKey(
            upstream_id=upstream.id, group_id="vip", group_name="VIP",
            key_name="SmartUp-Test-VIP", key_value="sk-second", status="exists",
        ))
    db_session.commit()

    rows = db_session.query(UpstreamGeneratedKey).filter(
        UpstreamGeneratedKey.upstream_id == upstream.id,
        UpstreamGeneratedKey.group_id == "vip",
    ).all()
    assert len(rows) == 1
    assert rows[0].status == "exists"
    assert rows[0].key_value == "sk-first"  # original row updated, not a new one inserted


def test_sync_removes_gone_group(db_session):
    """Key rows for a group missing from the latest snapshot are deleted.

    NOTE(review): the cleanup loop is reproduced inline, not invoked from the
    scheduler, so this only checks the inline logic.
    """
    upstream = _make_upstream(db_session)
    db_session.add_all([
        UpstreamGeneratedKey(
            upstream_id=upstream.id, group_id="vip", group_name="VIP",
            key_name="SmartUp-Test-VIP", key_value="sk-vip",
        ),
        UpstreamGeneratedKey(
            upstream_id=upstream.id, group_id="free", group_name="Free",
            key_name="SmartUp-Test-Free", key_value="sk-free",
        ),
    ])
    db_session.commit()

    # The snapshot only contains "vip"; "free" is gone.
    active_group_ids = {"vip"}
    for row in db_session.query(UpstreamGeneratedKey).filter(
            UpstreamGeneratedKey.upstream_id == upstream.id).all():
        if row.group_id not in active_group_ids:
            db_session.delete(row)
    db_session.commit()

    remaining = db_session.query(UpstreamGeneratedKey).all()
    assert len(remaining) == 1
    assert remaining[0].group_id == "vip"


def test_sync_removes_deleted_remote_key(db_session):
    """A local row is deleted once its remote key no longer exists.

    NOTE(review): the cleanup loop is reproduced inline, not invoked from the
    scheduler, so this only checks the inline logic.
    """
    upstream = _make_upstream(db_session)
    db_session.add(UpstreamGeneratedKey(
        upstream_id=upstream.id, group_id="vip", group_name="VIP",
        key_name="SmartUp-Test-VIP", key_value="sk-vip", key_id="remote-123",
    ))
    db_session.commit()

    # The remote's active key ids do not include "remote-123".
    remote_key_ids = {"remote-456", "remote-789"}
    for row in db_session.query(UpstreamGeneratedKey).filter(
            UpstreamGeneratedKey.upstream_id == upstream.id).all():
        if row.key_id and row.key_id not in remote_key_ids:
            db_session.delete(row)
    db_session.commit()

    remaining = db_session.query(UpstreamGeneratedKey).all()
    assert len(remaining) == 0