696 lines
28 KiB
Python
696 lines
28 KiB
Python
"""Tests for upstream key uniquification and sync cleanup."""
|
||
import json
|
||
from datetime import datetime, timezone
|
||
|
||
import pytest
|
||
from sqlalchemy import create_engine, text
|
||
from sqlalchemy.orm import sessionmaker
|
||
from sqlalchemy.pool import StaticPool
|
||
|
||
from app.database import Base
|
||
from app.models.upstream import Upstream
|
||
from app.models.website import Website # noqa: F401 — registers table for FK refs
|
||
from app.models.upstream_key import UpstreamGeneratedKey
|
||
|
||
|
||
@pytest.fixture()
def db_session():
    """Yield a session bound to a fresh in-memory SQLite database.

    The schema is created from ``Base.metadata``, so every model imported at
    module level gets its table. ``StaticPool`` keeps a single shared
    connection, which is what lets the in-memory database persist across
    session operations. On teardown the session is closed, the schema is
    dropped and the engine is disposed.
    """
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()
        Base.metadata.drop_all(bind=engine)
        # Release the pooled connection; without this every test leaks one
        # engine (and its connection) for the lifetime of the test session.
        engine.dispose()
|
||
|
||
|
||
def test_duplicate_cleanup_keeps_latest_only():
    """Only the newest row survives per (upstream_id, group_id, key_name).

    Uses a dedicated engine and raw SQL only, to simulate the database state
    before the migration: the key table has no unique constraint yet, so
    duplicate rows can be inserted.
    """
    engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
    with engine.begin() as conn:
        conn.execute(text("""
            CREATE TABLE upstreams (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name VARCHAR(255) NOT NULL,
                base_url VARCHAR(512) NOT NULL,
                api_prefix VARCHAR(64) DEFAULT '',
                auth_type VARCHAR(32),
                auth_config_json TEXT DEFAULT '{}',
                groups_endpoint VARCHAR(256),
                rate_endpoint VARCHAR(256),
                enabled BOOLEAN DEFAULT 1,
                check_interval_seconds INTEGER DEFAULT 600,
                timeout_seconds INTEGER DEFAULT 30,
                last_status VARCHAR(32) DEFAULT 'unknown',
                last_checked_at DATETIME,
                last_error TEXT,
                consecutive_failures INTEGER DEFAULT 0,
                balance FLOAT,
                balance_updated_at DATETIME,
                balance_endpoint VARCHAR(256) DEFAULT '',
                balance_response_path VARCHAR(256) DEFAULT '',
                balance_divisor FLOAT DEFAULT 1.0,
                updated_at DATETIME,
                created_at DATETIME
            )
        """))
        conn.execute(text("""
            CREATE TABLE upstream_generated_keys (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                upstream_id INTEGER NOT NULL,
                group_id VARCHAR(255) NOT NULL,
                group_name VARCHAR(255) DEFAULT '',
                key_id VARCHAR(255),
                key_name VARCHAR(255) NOT NULL,
                key_value TEXT NOT NULL,
                masked_key VARCHAR(255) DEFAULT '',
                raw_json TEXT DEFAULT '{}',
                status VARCHAR(32) DEFAULT 'created',
                error TEXT,
                imported_website_id INTEGER,
                imported_account_id VARCHAR(255),
                imported_at DATETIME,
                created_at DATETIME,
                updated_at DATETIME
            )
        """))

    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        now = datetime.now(timezone.utc)
        db.execute(text("""
            INSERT INTO upstreams (name, base_url, api_prefix, auth_type, auth_config_json, groups_endpoint, rate_endpoint, created_at, updated_at)
            VALUES (:n, :b, :p, :a, :j, :g, :r, :ca, :ua)
        """), {"n": "Test", "b": "http://local", "p": "/api/v1", "a": "bearer",
               "j": "{}", "g": "/groups", "r": "/rates", "ca": now, "ua": now})
        db.commit()
        uid = db.execute(text("SELECT id FROM upstreams LIMIT 1")).scalar()

        # Insert 3 duplicates, oldest first, so the newest row also gets the largest id.
        for kv, ca in [
            ("old-key", datetime(2025, 1, 1, tzinfo=timezone.utc)),
            ("middle-key", datetime(2025, 6, 1, tzinfo=timezone.utc)),
            ("newest-key", datetime(2025, 12, 1, tzinfo=timezone.utc)),
        ]:
            db.execute(text("""
                INSERT INTO upstream_generated_keys
                    (upstream_id, group_id, group_name, key_name, key_value, masked_key, raw_json, status, created_at, updated_at)
                VALUES (:uid, :gid, :gn, :kn, :kv, :mk, :rj, :st, :ca, :ca)
            """), {
                "uid": uid, "gid": "vip", "gn": "VIP",
                "kn": "SmartUp-Test-VIP", "kv": kv,
                "mk": "", "rj": "{}", "st": "created", "ca": ca,
            })
        db.commit()

        # Cleanup: keep only the row with the largest id for each combination.
        db.execute(text("""
            DELETE FROM upstream_generated_keys
            WHERE id NOT IN (
                SELECT MAX(id) FROM upstream_generated_keys
                GROUP BY upstream_id, group_id, key_name
            )
        """))
        db.commit()

        remaining = db.execute(text("SELECT key_value FROM upstream_generated_keys")).fetchall()
        assert len(remaining) == 1, f"expected 1 after cleanup, got {len(remaining)}"
        assert remaining[0][0] == "newest-key"
    finally:
        db.close()
        # The engine is private to this test; dispose it so its pooled
        # connection (and the in-memory database) is released.
        engine.dispose()
|
||
|
||
|
||
def test_migration_backfills_managed_prefix_and_deduplicates():
    """The migration backfills managed_prefix on legacy SmartUp rows and drops duplicates.

    Uses a dedicated engine whose key table has no unique index, to simulate
    the pre-migration state, and then runs the migration SQL steps inline.
    """
    engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
    with engine.begin() as conn:
        conn.execute(text("""
            CREATE TABLE upstreams (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name VARCHAR(255) NOT NULL, base_url VARCHAR(512) NOT NULL,
                api_prefix VARCHAR(64) DEFAULT '', auth_type VARCHAR(32),
                auth_config_json TEXT DEFAULT '{}', groups_endpoint VARCHAR(256),
                rate_endpoint VARCHAR(256), enabled BOOLEAN DEFAULT 1,
                check_interval_seconds INTEGER DEFAULT 600, timeout_seconds INTEGER DEFAULT 30,
                last_status VARCHAR(32) DEFAULT 'unknown', last_checked_at DATETIME,
                last_error TEXT, consecutive_failures INTEGER DEFAULT 0,
                balance FLOAT, balance_updated_at DATETIME,
                balance_endpoint VARCHAR(256) DEFAULT '', balance_response_path VARCHAR(256) DEFAULT '',
                balance_divisor FLOAT DEFAULT 1.0, updated_at DATETIME, created_at DATETIME
            )
        """))
        conn.execute(text("""
            CREATE TABLE upstream_generated_keys (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                upstream_id INTEGER NOT NULL, group_id VARCHAR(255) NOT NULL,
                group_name VARCHAR(255) DEFAULT '', key_id VARCHAR(255),
                key_name VARCHAR(255) NOT NULL, key_value TEXT NOT NULL,
                masked_key VARCHAR(255) DEFAULT '', raw_json TEXT DEFAULT '{}',
                managed_prefix VARCHAR(64), status VARCHAR(32) DEFAULT 'created',
                error TEXT, imported_website_id INTEGER, imported_account_id VARCHAR(255),
                imported_at DATETIME, created_at DATETIME, updated_at DATETIME
            )
        """))

    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        now = datetime.now(timezone.utc)
        db.execute(text("""
            INSERT INTO upstreams (name, base_url, api_prefix, auth_type, auth_config_json, groups_endpoint, rate_endpoint, created_at, updated_at)
            VALUES (:n, :b, :p, :a, :j, :g, :r, :ca, :ua)
        """), {"n": "Old", "b": "http://local", "p": "/api/v1", "a": "bearer",
               "j": "{}", "g": "/groups", "r": "/rates", "ca": now, "ua": now})
        db.commit()
        uid = db.execute(text("SELECT id FROM upstreams LIMIT 1")).scalar()

        # Insert two duplicates with no managed_prefix and a key_name starting with "SmartUp".
        for kv in ("sk-old", "sk-new"):
            db.execute(text("""
                INSERT INTO upstream_generated_keys
                    (upstream_id, group_id, group_name, key_name, key_value, masked_key, raw_json, status, created_at, updated_at)
                VALUES (:uid, :gid, :gn, :kn, :kv, '', '{}', 'created', :ca, :ca)
            """), {"uid": uid, "gid": "vip", "gn": "VIP",
                   "kn": "SmartUp-Old-vip", "kv": kv, "ca": now})
        db.commit()

        # Run the migration steps.
        # NOTE(review): these are meant to mirror the SQL in app/database.py;
        # keep the two in sync.
        conn = db.connection()
        conn.execute(text(
            "UPDATE upstream_generated_keys SET managed_prefix = 'SmartUp' "
            "WHERE managed_prefix IS NULL AND key_name LIKE 'SmartUp-%'"
        ))
        to_delete = conn.execute(text("""
            SELECT id FROM upstream_generated_keys
            WHERE managed_prefix IS NOT NULL
              AND id NOT IN (
                SELECT MAX(id) FROM upstream_generated_keys
                WHERE managed_prefix IS NOT NULL
                GROUP BY upstream_id, group_id, managed_prefix
              )
        """)).fetchall()
        for (row_id,) in to_delete:
            conn.execute(text("DELETE FROM upstream_generated_keys WHERE id = :id"), {"id": row_id})
        db.commit()

        remaining = db.execute(text("SELECT key_value, managed_prefix FROM upstream_generated_keys")).fetchall()
        assert len(remaining) == 1, f"expected 1 after migration, got {len(remaining)}"
        assert remaining[0][0] == "sk-new"  # the largest id (inserted last) is kept
        assert remaining[0][1] == "SmartUp"  # backfilled by the UPDATE step
    finally:
        db.close()
        # The engine is private to this test; release its pooled connection.
        engine.dispose()
|
||
|
||
|
||
def test_ensure_group_key_reuses_old_record(db_session, monkeypatch):
    """``_ensure_group_key`` reuses a legacy row whose managed_prefix IS NULL.

    The mock remote reports no matching key, so a new remote key is minted.
    The local table must still end up with exactly one row for the group,
    now tagged with the managed prefix and holding the new key value.
    """
    from app.routers.upstreams import _ensure_group_key
    from app.schemas.upstream import GenerateKeysByGroupsRequest

    upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
                        auth_type="bearer", auth_config_json="{}",
                        groups_endpoint="/groups", rate_endpoint="/rates")
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    # Seed one legacy row without managed_prefix.
    db_session.add(UpstreamGeneratedKey(
        upstream_id=upstream.id, group_id="vip", group_name="VIP",
        key_name="SmartUp-Test-vip", key_value="sk-old",
        managed_prefix=None, key_id="remote-999",
    ))
    db_session.commit()

    # Mock client: the remote has no existing key, and creating one succeeds.
    class MockClient:
        def find_smartup_group_key(self, gid, name, prefix):
            return None

        def create_api_key(self, name, group_id, **kw):
            return {"id": "new-remote", "key": "sk-new-value", "masked_key": "sk-****-lue"}

    group = {"id": "vip", "name": "VIP", "rate_multiplier": 1}
    body = GenerateKeysByGroupsRequest(
        group_ids=["vip"], name_prefix="SmartUp",
        quota=0, endpoint="/keys",
    )
    result = _ensure_group_key(db_session, MockClient(), upstream, group, "SmartUp", body)

    assert result.status == "created"

    rows = db_session.query(UpstreamGeneratedKey).filter(
        UpstreamGeneratedKey.upstream_id == upstream.id,
        UpstreamGeneratedKey.group_id == "vip",
    ).all()
    assert len(rows) == 1, f"expected 1 record, got {len(rows)}"
    assert rows[0].managed_prefix == "SmartUp"
    assert rows[0].key_value == "sk-new-value"
|
||
|
||
|
||
def test_ensure_group_key_backfills_plaintext_from_remote_existing_key(db_session):
    """Plaintext exposed by the remote key list is written back to the local row.

    The local SmartUp row starts with an empty key_value. The remote already
    has the key, so no new key may be created; the plaintext reported by the
    remote must be stored locally, together with a freshly computed mask.
    """
    from app.routers.upstreams import _ensure_group_key
    from app.models.upstream_key import UpstreamGeneratedKey
    from app.schemas.upstream import GenerateKeysByGroupsRequest
    from app.services.upstream_client import mask_secret

    remote_plaintext = "sk-remote-plain-value-1234567890abcdef"

    upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
                        auth_type="bearer", auth_config_json="{}",
                        groups_endpoint="/groups", rate_endpoint="/rates")
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    db_session.add(UpstreamGeneratedKey(
        upstream_id=upstream.id,
        group_id="vip",
        group_name="VIP",
        key_name="SmartUp-Test-vip",
        key_value="",
        masked_key="sk-old-masked",
        key_id="remote-123",
        managed_prefix="SmartUp",
    ))
    db_session.commit()

    class RemoteKeyExistsClient:
        """Client stub whose remote already holds the key, including plaintext."""

        def find_smartup_group_key(self, gid, name, prefix):
            return {
                "id": "remote-123",
                "name": "SmartUp-Test-vip",
                "key": remote_plaintext,
                "masked_key": "sk-re************cdef",
            }

        def create_api_key(self, *args, **kwargs):
            raise AssertionError("create_api_key should not be called when remote key exists")

    group_info = {"id": "vip", "name": "VIP", "rate_multiplier": 1}
    request_body = GenerateKeysByGroupsRequest(
        group_ids=["vip"],
        name_prefix="SmartUp",
        quota=0,
        endpoint="/keys",
    )
    outcome = _ensure_group_key(
        db_session, RemoteKeyExistsClient(), upstream, group_info, "SmartUp", request_body
    )

    assert outcome.status == "exists"
    assert outcome.has_key_value is True

    stored = db_session.query(UpstreamGeneratedKey).filter(
        UpstreamGeneratedKey.upstream_id == upstream.id,
        UpstreamGeneratedKey.group_id == "vip",
    ).one()
    assert stored.key_value == remote_plaintext
    assert stored.masked_key == mask_secret(stored.key_value)
    assert stored.status == "exists"
|
||
|
||
|
||
def test_sync_removes_remote_key_when_list_empty(db_session, monkeypatch):
    """When the remote key list is empty, sync deletes the local row for that key_id."""
    from app.services import scheduler as sched_mod
    from app.services.upstream_client import UpstreamClient

    upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
                        auth_type="bearer", auth_config_json="{}",
                        groups_endpoint="/groups", rate_endpoint="/rates")
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    db_session.add(UpstreamGeneratedKey(
        upstream_id=upstream.id, group_id="vip", group_name="VIP",
        key_name="SmartUp-Test-vip", key_value="sk-vip",
        managed_prefix="SmartUp", key_id="remote-key-id",
    ))
    db_session.commit()

    # list_api_keys succeeds but returns no keys.
    monkeypatch.setattr(UpstreamClient, "list_api_keys", lambda self, **kw: [])
    monkeypatch.setattr(UpstreamClient, "login", lambda self: None)
    monkeypatch.setattr(UpstreamClient, "close", lambda self: None)
    monkeypatch.setattr(UpstreamClient, "__enter__", lambda self: self)
    monkeypatch.setattr(UpstreamClient, "__exit__", lambda self, *a: None)

    # Make _sync_upstream_keys operate on the test session.
    monkeypatch.setattr(sched_mod, "SessionLocal", lambda: db_session)
    # Turn close() into a no-op so the sync function closing "its" session
    # does not close the test session.
    original_close = db_session.close
    monkeypatch.setattr(db_session, "close", lambda: None)

    # One timestamp for both the snapshot payload and the call argument, so
    # they describe the same capture instead of two slightly different ones.
    captured_at = datetime.now(timezone.utc)
    snapshot = {
        "upstream_id": upstream.id,
        "groups": {"vip": {"group_id": "vip", "rate": "1"}},
        "captured_at": captured_at.isoformat(),
    }

    try:
        sched_mod._sync_upstream_keys(upstream.id, snapshot, captured_at)
    finally:
        monkeypatch.setattr(db_session, "close", original_close)

    remaining = db_session.query(UpstreamGeneratedKey).all()
    assert len(remaining) == 0, f"expected 0 after sync with empty remote, got {len(remaining)}"
|
||
|
||
|
||
def test_sync_marks_imported_key_orphaned_when_remote_key_missing(db_session, monkeypatch):
    """A key already imported into a target account is kept when it vanishes remotely.

    Instead of being deleted, the row is marked ``orphaned`` with an error
    message, and its imported_website_id / imported_account_id link survives.
    """
    from app.services import scheduler as sched_mod
    from app.services.upstream_client import UpstreamClient

    website = Website(
        name="Target",
        site_type="sub2api",
        base_url="http://target.local",
        api_prefix="/api/v1/admin",
        auth_type="api_key",
        auth_config_json="{}",
        groups_endpoint="/groups",
        group_update_endpoint="/groups/{id}",
    )
    upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
                        auth_type="bearer", auth_config_json="{}",
                        groups_endpoint="/groups", rate_endpoint="/rates")
    db_session.add_all([website, upstream])
    db_session.commit()
    db_session.refresh(website)
    db_session.refresh(upstream)

    db_session.add(UpstreamGeneratedKey(
        upstream_id=upstream.id,
        group_id="vip",
        group_name="VIP",
        key_name="SmartUp-Test-vip",
        key_value="sk-vip",
        managed_prefix="SmartUp",
        key_id="remote-key-id",
        imported_website_id=website.id,
        imported_account_id="account-101",
    ))
    db_session.commit()

    # The remote reports no keys at all.
    monkeypatch.setattr(UpstreamClient, "list_api_keys", lambda self, **kw: [])
    monkeypatch.setattr(UpstreamClient, "login", lambda self: None)
    monkeypatch.setattr(UpstreamClient, "close", lambda self: None)
    monkeypatch.setattr(UpstreamClient, "__enter__", lambda self: self)
    monkeypatch.setattr(UpstreamClient, "__exit__", lambda self, *a: None)

    # Make the sync function operate on the test session, and turn close()
    # into a no-op so the test session stays open.
    monkeypatch.setattr(sched_mod, "SessionLocal", lambda: db_session)
    original_close = db_session.close
    monkeypatch.setattr(db_session, "close", lambda: None)

    # One timestamp for both the snapshot payload and the call argument.
    captured_at = datetime.now(timezone.utc)
    snapshot = {
        "upstream_id": upstream.id,
        "groups": {"vip": {"group_id": "vip", "rate": "1"}},
        "captured_at": captured_at.isoformat(),
    }

    try:
        sched_mod._sync_upstream_keys(upstream.id, snapshot, captured_at)
    finally:
        monkeypatch.setattr(db_session, "close", original_close)

    remaining = db_session.query(UpstreamGeneratedKey).all()
    assert len(remaining) == 1
    row = remaining[0]
    assert row.status == "orphaned"
    assert row.imported_website_id == website.id
    assert row.imported_account_id == "account-101"
    assert row.error == "远端 Key 已不存在"
|
||
|
||
|
||
def test_migration_function_integration(monkeypatch):
    """Call ``_migrate_upstream_generated_keys()`` directly on a legacy schema.

    Verifies three things: the migration adds the ``managed_prefix`` column,
    backfills it for a legacy ``SmartUp-`` row, and creates the
    ``uq_upstream_group_managed`` index.
    """
    from sqlalchemy import inspect

    from app.database import _migrate_upstream_generated_keys

    # Use a private engine so the real database is never touched. This only
    # isolates the migration if it looks up ``app.database.engine`` at call time.
    test_engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
    monkeypatch.setattr("app.database.engine", test_engine)

    try:
        # Create the table WITHOUT managed_prefix to simulate the old schema.
        with test_engine.begin() as conn:
            conn.execute(text("""
                CREATE TABLE upstream_generated_keys (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    upstream_id INTEGER NOT NULL, group_id VARCHAR(255) NOT NULL,
                    group_name VARCHAR(255) DEFAULT '', key_id VARCHAR(255),
                    key_name VARCHAR(255) NOT NULL, key_value TEXT NOT NULL,
                    masked_key VARCHAR(255) DEFAULT '', raw_json TEXT DEFAULT '{}',
                    status VARCHAR(32) DEFAULT 'created', error TEXT,
                    imported_website_id INTEGER, imported_account_id VARCHAR(255),
                    imported_at DATETIME, created_at DATETIME
                )
            """))
            conn.execute(text("""
                INSERT INTO upstream_generated_keys
                    (upstream_id, group_id, group_name, key_name, key_value, masked_key, raw_json, status, created_at)
                VALUES (1, 'vip', 'VIP', 'SmartUp-Old-vip', 'sk-val', '', '{}', 'created', datetime('now'))
            """))

        # Run the migration entry point.
        _migrate_upstream_generated_keys()

        # Inspect after the migration so reflection sees the new schema.
        inspector = inspect(test_engine)
        cols = {c["name"] for c in inspector.get_columns("upstream_generated_keys")}
        assert "managed_prefix" in cols, "managed_prefix column should exist after migration"

        with test_engine.connect() as conn:
            row = conn.execute(text("SELECT managed_prefix, key_value FROM upstream_generated_keys LIMIT 1")).fetchone()
            assert row[0] == "SmartUp", f"expected SmartUp, got {row[0]}"
            assert row[1] == "sk-val"

        # The unique index must have been created.
        index_names = {ix["name"] for ix in inspector.get_indexes("upstream_generated_keys")}
        assert "uq_upstream_group_managed" in index_names, "partial unique index should exist"
    finally:
        # monkeypatch restores app.database.engine at teardown; only the
        # private engine needs explicit cleanup.
        test_engine.dispose()
|
||
|
||
|
||
def test_create_twice_only_one_record(db_session):
    """A second upsert for the same upstream/group/key_name updates the row in place.

    NOTE(review): this does not call the production ``_ensure_group_key``.
    It replays an inline "look up, then update or insert" step and checks
    that the table still holds a single row whose original key_value is
    preserved.
    """
    upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
                        auth_type="bearer", auth_config_json="{}",
                        groups_endpoint="/groups", rate_endpoint="/rates")
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    # First generation: a single row is created.
    db_session.add(UpstreamGeneratedKey(
        upstream_id=upstream.id, group_id="vip", group_name="VIP",
        key_name="SmartUp-Test-VIP", key_value="sk-first",
        status="created",
    ))
    db_session.commit()

    # Second generation: match on (upstream_id, group_id, key_name) and update
    # the existing row with status "exists"; insert only if nothing matches.
    existing = db_session.query(UpstreamGeneratedKey).filter(
        UpstreamGeneratedKey.upstream_id == upstream.id,
        UpstreamGeneratedKey.group_id == "vip",
        UpstreamGeneratedKey.key_name == "SmartUp-Test-VIP",
    ).first()
    if existing:
        existing.status = "exists"
        existing.updated_at = datetime.now(timezone.utc)
    else:
        db_session.add(UpstreamGeneratedKey(
            upstream_id=upstream.id, group_id="vip", group_name="VIP",
            key_name="SmartUp-Test-VIP", key_value="sk-second",
            status="exists",
        ))
    db_session.commit()

    rows = db_session.query(UpstreamGeneratedKey).filter(
        UpstreamGeneratedKey.upstream_id == upstream.id,
        UpstreamGeneratedKey.group_id == "vip",
    ).all()
    assert len(rows) == 1
    assert rows[0].status == "exists"
    assert rows[0].key_value == "sk-first"  # the original row was updated, not replaced
|
||
|
||
|
||
def test_sync_removes_gone_group(db_session):
    """Key rows whose group is absent from the latest snapshot get deleted.

    Replays the pruning step inline: only ``vip`` is still active, so the
    ``free`` row must be removed and the ``vip`` row kept.
    """
    upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
                        auth_type="bearer", auth_config_json="{}",
                        groups_endpoint="/groups", rate_endpoint="/rates")
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    seed = [
        ("vip", "VIP", "SmartUp-Test-VIP", "sk-vip"),
        ("free", "Free", "SmartUp-Test-Free", "sk-free"),
    ]
    db_session.add_all([
        UpstreamGeneratedKey(
            upstream_id=upstream.id, group_id=gid, group_name=gname,
            key_name=kname, key_value=kval,
        )
        for gid, gname, kname, kval in seed
    ])
    db_session.commit()

    # The latest snapshot only contains "vip".
    active_group_ids = {"vip"}
    owned = db_session.query(UpstreamGeneratedKey).filter(
        UpstreamGeneratedKey.upstream_id == upstream.id).all()
    for stale in (r for r in owned if r.group_id not in active_group_ids):
        db_session.delete(stale)
    db_session.commit()

    survivors = db_session.query(UpstreamGeneratedKey).all()
    assert len(survivors) == 1
    assert survivors[0].group_id == "vip"
|
||
|
||
|
||
def test_sync_removes_deleted_remote_key(db_session):
    """A local row whose key_id the remote no longer reports is deleted."""
    from app.models.upstream_key import UpstreamGeneratedKey

    upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
                        auth_type="bearer", auth_config_json="{}",
                        groups_endpoint="/groups", rate_endpoint="/rates")
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    db_session.add(UpstreamGeneratedKey(
        upstream_id=upstream.id, group_id="vip", group_name="VIP",
        key_name="SmartUp-Test-VIP", key_value="sk-vip",
        key_id="remote-123",
    ))
    db_session.commit()

    # Active ids reported by the simulated remote; "remote-123" is missing.
    remote_key_ids = {"remote-456", "remote-789"}
    candidates = db_session.query(UpstreamGeneratedKey).filter(
        UpstreamGeneratedKey.upstream_id == upstream.id).all()
    vanished = [r for r in candidates if r.key_id and r.key_id not in remote_key_ids]
    for gone in vanished:
        db_session.delete(gone)
    db_session.commit()

    leftover = db_session.query(UpstreamGeneratedKey).all()
    assert len(leftover) == 0
|
||
|
||
|
||
def test_new_api_create_token_fetches_plaintext_key(monkeypatch):
    """After creating a New-API token, the client fetches its plaintext key by id.

    The create POST returns no key, so the client is expected to locate the
    token via ``_list_new_api_tokens`` and then POST ``/api/token/{id}/key``
    to get the plaintext.
    """
    from app.services.upstream_client import UpstreamClient

    client = UpstreamClient(
        base_url="http://newapi.local",
        api_prefix="",
        auth_type="cookie",
        auth_config={"cookie_string": "session=abc", "new_api_user": "7"},
    )
    created_bodies = []

    def fake_request(method, path, body=None, auth=True):
        if method == "GET" and path == "/api/status":
            return {"success": True, "data": {"quota_per_unit": 500000}}
        if method == "POST" and path == "/api/token/":
            created_bodies.append(body)
            return {"success": True, "message": ""}
        if method == "POST" and path == "/api/token/123/key":
            return {"success": True, "data": {"key": "new-api-plain-key"}}
        raise AssertionError(f"unexpected request {method} {path}")

    monkeypatch.setattr(client, "_request", fake_request)
    monkeypatch.setattr(
        client,
        "_list_new_api_tokens",
        lambda search="", group_id=None: [{"id": 123, "name": search, "group": group_id, "key": "new-****-key"}],
    )

    result = client.create_api_key(
        "SmartUp-1-VIP-vip",
        "vip",
        quota=2,
        expires_in_days=3,
        endpoint="/api/token",
    )

    assert result["id"] == "123"
    assert result["key"] == "new-api-plain-key"
    # Fail with a clear message rather than an IndexError if no create POST happened.
    assert created_bodies, "expected create_api_key to POST /api/token/"
    created = created_bodies[0]
    assert created["group"] == "vip"
    # Presumably quota (2) * quota_per_unit (500000 from the faked /api/status).
    assert created["remain_quota"] == 1000000
    assert created["unlimited_quota"] is False
    assert created["expired_time"] > 0
|
||
|
||
|
||
def test_generate_keys_allows_new_api_user_upstream(db_session, monkeypatch):
    """A New-API upstream logged in as a regular user may generate per-group tokens."""
    from app.routers import upstreams as upstreams_router
    from app.schemas.upstream import GenerateKeysByGroupsRequest

    groups_path = "/api/user/self/groups"
    auth_config = {"cookie_string": "session=abc", "new_api_user": "7"}

    upstream = Upstream(
        name="NewAPI",
        base_url="http://newapi.local",
        api_prefix="",
        auth_type="cookie",
        auth_config_json=json.dumps(auth_config),
        groups_endpoint=groups_path,
        rate_endpoint=groups_path,
    )
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    # Stub website_sync.reconcile_upstream_keys_full so it always reports success.
    monkeypatch.setattr(upstreams_router.website_sync, "reconcile_upstream_keys_full", lambda db, uid: True)

    class StubNewApiClient:
        """Stand-in for UpstreamClient: serves one 'vip' group and mints one token."""

        def __init__(self, **kwargs):
            self.kwargs = kwargs

        def __enter__(self):
            return self

        def __exit__(self, *args):
            return None

        def login(self):
            return None

        def get_available_groups(self, endpoint):
            assert endpoint == groups_path
            return [{"id": "vip", "name": "VIP"}]

        def find_smartup_group_key(self, gid, expected_name, prefix):
            # No pre-existing token, so a new one has to be created.
            return None

        def create_api_key(self, name, group_id, **kwargs):
            assert kwargs["endpoint"] == "/api/token"
            return {"id": "123", "key": "new-api-plain-key", "masked_key": "new-****-key", "raw": {"id": 123}}

    monkeypatch.setattr(upstreams_router, "UpstreamClient", StubNewApiClient)

    request_body = GenerateKeysByGroupsRequest(group_ids=["vip"], endpoint="/api/token")
    response = upstreams_router.generate_keys_by_groups(
        upstream.id,
        request_body,
        db_session,
        object(),  # stand-in for the route's fourth argument (presumably the current user)
    )

    first_item = response.items[0]
    assert response.success is True
    assert first_item.status == "created"
    assert first_item.key_value == "new-api-plain-key"
|