Files
SmartUp/backend/test_upstream_key_sync.py
T
2026-06-03 17:03:11 +08:00

826 lines
34 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Tests for upstream key uniquification and sync cleanup."""
import json
from datetime import datetime, timezone
import pytest
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.models.upstream import Upstream
from app.models.website import Website # noqa: F401 — registers table for FK refs
from app.models.upstream_key import UpstreamGeneratedKey
@pytest.fixture()
def db_session():
    """Yield a session bound to a fresh in-memory SQLite database.

    Every table registered on ``Base.metadata`` is created before the test
    and dropped afterwards. ``StaticPool`` reuses a single connection, so the
    in-memory database is shared by everything bound to this engine.
    """
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()
        Base.metadata.drop_all(bind=engine)
        # Release the pooled connection held by StaticPool.
        engine.dispose()
def test_duplicate_cleanup_keeps_latest_only():
    """Keep only the newest row per (upstream_id, group_id, key_name).

    Uses a standalone engine and raw SQL only, to mimic the database state
    before the migration. "Newest" means the largest ``id``. The rows are
    inserted in chronological order, so that is also the latest ``created_at``.
    """
    engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
    with engine.begin() as conn:
        conn.execute(text("""
            CREATE TABLE upstreams (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name VARCHAR(255) NOT NULL,
                base_url VARCHAR(512) NOT NULL,
                api_prefix VARCHAR(64) DEFAULT '',
                auth_type VARCHAR(32),
                auth_config_json TEXT DEFAULT '{}',
                groups_endpoint VARCHAR(256),
                rate_endpoint VARCHAR(256),
                enabled BOOLEAN DEFAULT 1,
                check_interval_seconds INTEGER DEFAULT 600,
                timeout_seconds INTEGER DEFAULT 30,
                last_status VARCHAR(32) DEFAULT 'unknown',
                last_checked_at DATETIME,
                last_error TEXT,
                consecutive_failures INTEGER DEFAULT 0,
                balance FLOAT,
                balance_updated_at DATETIME,
                balance_endpoint VARCHAR(256) DEFAULT '',
                balance_response_path VARCHAR(256) DEFAULT '',
                balance_divisor FLOAT DEFAULT 1.0,
                updated_at DATETIME,
                created_at DATETIME
            )
        """))
        conn.execute(text("""
            CREATE TABLE upstream_generated_keys (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                upstream_id INTEGER NOT NULL,
                group_id VARCHAR(255) NOT NULL,
                group_name VARCHAR(255) DEFAULT '',
                key_id VARCHAR(255),
                key_name VARCHAR(255) NOT NULL,
                key_value TEXT NOT NULL,
                masked_key VARCHAR(255) DEFAULT '',
                raw_json TEXT DEFAULT '{}',
                status VARCHAR(32) DEFAULT 'created',
                error TEXT,
                imported_website_id INTEGER,
                imported_account_id VARCHAR(255),
                imported_at DATETIME,
                created_at DATETIME,
                updated_at DATETIME
            )
        """))
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        now = datetime.now(timezone.utc)
        db.execute(text("""
            INSERT INTO upstreams (name, base_url, api_prefix, auth_type, auth_config_json, groups_endpoint, rate_endpoint, created_at, updated_at)
            VALUES (:n, :b, :p, :a, :j, :g, :r, :ca, :ua)
        """), {"n": "Test", "b": "http://local", "p": "/api/v1", "a": "bearer",
               "j": "{}", "g": "/groups", "r": "/rates", "ca": now, "ua": now})
        db.commit()
        uid = db.execute(text("SELECT id FROM upstreams LIMIT 1")).scalar()

        # Insert three duplicates of the same (upstream, group, key_name) combo,
        # oldest first.
        for kv, ca in [
            ("old-key", datetime(2025, 1, 1, tzinfo=timezone.utc)),
            ("middle-key", datetime(2025, 6, 1, tzinfo=timezone.utc)),
            ("newest-key", datetime(2025, 12, 1, tzinfo=timezone.utc)),
        ]:
            db.execute(text("""
                INSERT INTO upstream_generated_keys
                (upstream_id, group_id, group_name, key_name, key_value, masked_key, raw_json, status, created_at, updated_at)
                VALUES (:uid, :gid, :gn, :kn, :kv, :mk, :rj, :st, :ca, :ca)
            """), {
                "uid": uid, "gid": "vip", "gn": "VIP",
                "kn": "SmartUp-Test-VIP", "kv": kv,
                "mk": "", "rj": "{}", "st": "created", "ca": ca,
            })
        db.commit()

        # Cleanup: keep only the row with the largest id per combination.
        db.execute(text("""
            DELETE FROM upstream_generated_keys
            WHERE id NOT IN (
                SELECT MAX(id) FROM upstream_generated_keys
                GROUP BY upstream_id, group_id, key_name
            )
        """))
        db.commit()

        remaining = db.execute(text("SELECT key_value FROM upstream_generated_keys")).fetchall()
        assert len(remaining) == 1, f"expected 1 after cleanup, got {len(remaining)}"
        assert remaining[0][0] == "newest-key"
    finally:
        db.close()
        # The engine is private to this test; release its pooled connection.
        engine.dispose()
def test_migration_backfills_managed_prefix_and_deduplicates():
    """The migration backfills managed_prefix on legacy SmartUp rows and removes duplicates.

    Uses a standalone engine without the unique index, to mimic the
    pre-migration state. ``text`` comes from the module-level import.
    """
    engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
    with engine.begin() as conn:
        conn.execute(text("""
            CREATE TABLE upstreams (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name VARCHAR(255) NOT NULL, base_url VARCHAR(512) NOT NULL,
                api_prefix VARCHAR(64) DEFAULT '', auth_type VARCHAR(32),
                auth_config_json TEXT DEFAULT '{}', groups_endpoint VARCHAR(256),
                rate_endpoint VARCHAR(256), enabled BOOLEAN DEFAULT 1,
                check_interval_seconds INTEGER DEFAULT 600, timeout_seconds INTEGER DEFAULT 30,
                last_status VARCHAR(32) DEFAULT 'unknown', last_checked_at DATETIME,
                last_error TEXT, consecutive_failures INTEGER DEFAULT 0,
                balance FLOAT, balance_updated_at DATETIME,
                balance_endpoint VARCHAR(256) DEFAULT '', balance_response_path VARCHAR(256) DEFAULT '',
                balance_divisor FLOAT DEFAULT 1.0, updated_at DATETIME, created_at DATETIME
            )
        """))
        conn.execute(text("""
            CREATE TABLE upstream_generated_keys (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                upstream_id INTEGER NOT NULL, group_id VARCHAR(255) NOT NULL,
                group_name VARCHAR(255) DEFAULT '', key_id VARCHAR(255),
                key_name VARCHAR(255) NOT NULL, key_value TEXT NOT NULL,
                masked_key VARCHAR(255) DEFAULT '', raw_json TEXT DEFAULT '{}',
                managed_prefix VARCHAR(64), status VARCHAR(32) DEFAULT 'created',
                error TEXT, imported_website_id INTEGER, imported_account_id VARCHAR(255),
                imported_at DATETIME, created_at DATETIME, updated_at DATETIME
            )
        """))
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        now = datetime.now(timezone.utc)
        db.execute(text("""
            INSERT INTO upstreams (name, base_url, api_prefix, auth_type, auth_config_json, groups_endpoint, rate_endpoint, created_at, updated_at)
            VALUES (:n, :b, :p, :a, :j, :g, :r, :ca, :ua)
        """), {"n": "Old", "b": "http://local", "p": "/api/v1", "a": "bearer",
               "j": "{}", "g": "/groups", "r": "/rates", "ca": now, "ua": now})
        db.commit()
        uid = db.execute(text("SELECT id FROM upstreams LIMIT 1")).scalar()

        # Insert two duplicate rows: managed_prefix is NULL and key_name starts with "SmartUp".
        for kv in ("sk-old", "sk-new"):
            db.execute(text("""
                INSERT INTO upstream_generated_keys
                (upstream_id, group_id, group_name, key_name, key_value, masked_key, raw_json, status, created_at, updated_at)
                VALUES (:uid, :gid, :gn, :kn, :kv, '', '{}', 'created', :ca, :ca)
            """), {"uid": uid, "gid": "vip", "gn": "VIP",
                   "kn": "SmartUp-Old-vip", "kv": kv, "ca": now})
        db.commit()

        # Run the migration steps (meant to match the SQL in database.py).
        conn = db.connection()
        conn.execute(text(
            "UPDATE upstream_generated_keys SET managed_prefix = 'SmartUp' "
            "WHERE managed_prefix IS NULL AND key_name LIKE 'SmartUp-%'"
        ))
        to_delete = conn.execute(text("""
            SELECT id FROM upstream_generated_keys
            WHERE managed_prefix IS NOT NULL
            AND id NOT IN (
                SELECT MAX(id) FROM upstream_generated_keys
                WHERE managed_prefix IS NOT NULL
                GROUP BY upstream_id, group_id, managed_prefix
            )
        """)).fetchall()
        for (row_id,) in to_delete:
            conn.execute(text("DELETE FROM upstream_generated_keys WHERE id = :id"), {"id": row_id})
        db.commit()

        remaining = db.execute(text("SELECT key_value, managed_prefix FROM upstream_generated_keys")).fetchall()
        assert len(remaining) == 1, f"expected 1 after migration, got {len(remaining)}"
        assert remaining[0][0] == "sk-new"  # the newest (largest id) row survives
        assert remaining[0][1] == "SmartUp"  # managed_prefix was backfilled
    finally:
        db.close()
        # The engine is private to this test; release its pooled connection.
        engine.dispose()
def test_ensure_group_key_reuses_old_record(db_session, monkeypatch):
    """_ensure_group_key reuses a legacy row whose managed_prefix is NULL instead of adding one.

    ``monkeypatch`` is requested but not used by this test.
    """
    from app.routers.upstreams import _ensure_group_key
    from app.schemas.upstream import GenerateKeysByGroupsRequest

    upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
                        auth_type="bearer", auth_config_json="{}",
                        groups_endpoint="/groups", rate_endpoint="/rates")
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    # Legacy row created before managed_prefix existed.
    db_session.add(UpstreamGeneratedKey(
        upstream_id=upstream.id, group_id="vip", group_name="VIP",
        key_name="SmartUp-Test-vip", key_value="sk-old",
        managed_prefix=None, key_id="remote-999",
    ))
    db_session.commit()

    class MockClient:
        """Stub client: no remote key exists, so a new one is always created."""

        def find_smartup_group_key(self, gid, name, prefix):
            return None

        def create_api_key(self, name, group_id, **kw):
            return {"id": "new-remote", "key": "sk-new-value", "masked_key": "sk-****-lue"}

    group = {"id": "vip", "name": "VIP", "rate_multiplier": 1}
    body = GenerateKeysByGroupsRequest(
        group_ids=["vip"], name_prefix="SmartUp",
        quota=0, endpoint="/keys",
    )
    result = _ensure_group_key(db_session, MockClient(), upstream, group, "SmartUp", body)
    assert result.status == "created"

    rows = db_session.query(UpstreamGeneratedKey).filter(
        UpstreamGeneratedKey.upstream_id == upstream.id,
        UpstreamGeneratedKey.group_id == "vip",
    ).all()
    assert len(rows) == 1, f"expected 1 record, got {len(rows)}"
    assert rows[0].managed_prefix == "SmartUp"
    assert rows[0].key_value == "sk-new-value"
def test_ensure_group_key_backfills_plaintext_from_remote_existing_key(db_session):
    """Backfill the local key_value when the remote listing exposes the plaintext.

    The SmartUp key already exists upstream, so it must be reused and never
    re-created. The local row gets the plaintext, and its masked value is
    recomputed with ``mask_secret``.
    """
    from app.routers.upstreams import _ensure_group_key
    from app.schemas.upstream import GenerateKeysByGroupsRequest
    from app.services.upstream_client import mask_secret

    remote_plaintext = "sk-remote-plain-value-1234567890abcdef"

    upstream = Upstream(
        name="Test",
        base_url="http://local",
        api_prefix="/api/v1",
        auth_type="bearer",
        auth_config_json="{}",
        groups_endpoint="/groups",
        rate_endpoint="/rates",
    )
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    # Local row knows the remote id but has no plaintext yet.
    stale_row = UpstreamGeneratedKey(
        upstream_id=upstream.id,
        group_id="vip",
        group_name="VIP",
        key_name="SmartUp-Test-vip",
        key_value="",
        masked_key="sk-old-masked",
        key_id="remote-123",
        managed_prefix="SmartUp",
    )
    db_session.add(stale_row)
    db_session.commit()

    class RemoteHasKeyClient:
        """Stub client whose lookup returns an existing key including its plaintext."""

        def find_smartup_group_key(self, gid, name, prefix):
            return {
                "id": "remote-123",
                "name": "SmartUp-Test-vip",
                "key": remote_plaintext,
                "masked_key": "sk-re************cdef",
            }

        def create_api_key(self, *args, **kwargs):
            raise AssertionError("create_api_key should not be called when remote key exists")

    request = GenerateKeysByGroupsRequest(
        group_ids=["vip"],
        name_prefix="SmartUp",
        quota=0,
        endpoint="/keys",
    )
    target_group = {"id": "vip", "name": "VIP", "rate_multiplier": 1}
    outcome = _ensure_group_key(
        db_session, RemoteHasKeyClient(), upstream, target_group, "SmartUp", request
    )

    assert outcome.status == "exists"
    assert outcome.has_key_value is True

    stored = (
        db_session.query(UpstreamGeneratedKey)
        .filter(
            UpstreamGeneratedKey.upstream_id == upstream.id,
            UpstreamGeneratedKey.group_id == "vip",
        )
        .one()
    )
    assert stored.key_value == remote_plaintext
    assert stored.masked_key == mask_secret(stored.key_value)
    assert stored.status == "exists"
def test_sync_removes_remote_key_when_list_empty(db_session, monkeypatch):
    """Delete the local row for key_id when the remote key list comes back empty."""
    from app.services import scheduler as sched_mod
    from app.services.upstream_client import UpstreamClient

    upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
                        auth_type="bearer", auth_config_json="{}",
                        groups_endpoint="/groups", rate_endpoint="/rates")
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)
    db_session.add(UpstreamGeneratedKey(
        upstream_id=upstream.id, group_id="vip", group_name="VIP",
        key_name="SmartUp-Test-vip", key_value="sk-vip",
        managed_prefix="SmartUp", key_id="remote-key-id",
    ))
    db_session.commit()

    # The remote listing succeeds but returns no keys.
    monkeypatch.setattr(UpstreamClient, "list_api_keys", lambda self, **kw: [])
    monkeypatch.setattr(UpstreamClient, "login", lambda self: None)
    monkeypatch.setattr(UpstreamClient, "close", lambda self: None)
    monkeypatch.setattr(UpstreamClient, "__enter__", lambda self: self)
    monkeypatch.setattr(UpstreamClient, "__exit__", lambda self, *a: None)
    # Make _sync_upstream_keys use the test session instead of its own.
    monkeypatch.setattr(sched_mod, "SessionLocal", lambda: db_session)
    # Stop the db.close() in the sync function's finally from closing the test session.
    original_close = db_session.close
    monkeypatch.setattr(db_session, "close", lambda: None)

    # One timestamp, so snapshot["captured_at"] matches the captured_at argument.
    captured_at = datetime.now(timezone.utc)
    snapshot = {
        "upstream_id": upstream.id,
        "groups": {"vip": {"group_id": "vip", "rate": "1"}},
        "captured_at": captured_at.isoformat(),
    }
    try:
        sched_mod._sync_upstream_keys(upstream.id, snapshot, captured_at)
    finally:
        monkeypatch.setattr(db_session, "close", original_close)

    remaining = db_session.query(UpstreamGeneratedKey).all()
    assert len(remaining) == 0, f"expected 0 after sync with empty remote, got {len(remaining)}"
def test_sync_marks_imported_key_orphaned_when_remote_key_missing(db_session, monkeypatch):
    """Keep an imported key's local row, marked orphaned, when its remote key disappears.

    Deleting the row would lose the link to the target website account, so
    the sync is expected to set status "orphaned" and keep the import fields.
    """
    from app.services import scheduler as sched_mod
    from app.services.upstream_client import UpstreamClient

    website = Website(
        name="Target",
        site_type="sub2api",
        base_url="http://target.local",
        api_prefix="/api/v1/admin",
        auth_type="api_key",
        auth_config_json="{}",
        groups_endpoint="/groups",
        group_update_endpoint="/groups/{id}",
    )
    upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
                        auth_type="bearer", auth_config_json="{}",
                        groups_endpoint="/groups", rate_endpoint="/rates")
    db_session.add_all([website, upstream])
    db_session.commit()
    db_session.refresh(website)
    db_session.refresh(upstream)
    db_session.add(UpstreamGeneratedKey(
        upstream_id=upstream.id,
        group_id="vip",
        group_name="VIP",
        key_name="SmartUp-Test-vip",
        key_value="sk-vip",
        managed_prefix="SmartUp",
        key_id="remote-key-id",
        imported_website_id=website.id,
        imported_account_id="account-101",
    ))
    db_session.commit()

    # The remote listing succeeds but no longer contains the key.
    monkeypatch.setattr(UpstreamClient, "list_api_keys", lambda self, **kw: [])
    monkeypatch.setattr(UpstreamClient, "login", lambda self: None)
    monkeypatch.setattr(UpstreamClient, "close", lambda self: None)
    monkeypatch.setattr(UpstreamClient, "__enter__", lambda self: self)
    monkeypatch.setattr(UpstreamClient, "__exit__", lambda self, *a: None)
    # Make the sync function use the test session and not close it.
    monkeypatch.setattr(sched_mod, "SessionLocal", lambda: db_session)
    original_close = db_session.close
    monkeypatch.setattr(db_session, "close", lambda: None)

    # One timestamp, so snapshot["captured_at"] matches the captured_at argument.
    captured_at = datetime.now(timezone.utc)
    snapshot = {
        "upstream_id": upstream.id,
        "groups": {"vip": {"group_id": "vip", "rate": "1"}},
        "captured_at": captured_at.isoformat(),
    }
    try:
        sched_mod._sync_upstream_keys(upstream.id, snapshot, captured_at)
    finally:
        monkeypatch.setattr(db_session, "close", original_close)

    remaining = db_session.query(UpstreamGeneratedKey).all()
    assert len(remaining) == 1
    row = remaining[0]
    assert row.status == "orphaned"
    assert row.imported_website_id == website.id
    assert row.imported_account_id == "account-101"
    assert row.error == "远端 Key 已不存在"
def test_migration_function_integration(monkeypatch):
    """Call _migrate_upstream_generated_keys() directly and check the column and index it adds."""
    from sqlalchemy import inspect as sa_inspect

    from app.database import _migrate_upstream_generated_keys

    # Standalone engine so the real database is never touched.
    test_engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
    monkeypatch.setattr("app.database.engine", test_engine)
    try:
        # Legacy schema: no managed_prefix column (and no updated_at column).
        with test_engine.begin() as conn:
            conn.execute(text("""
                CREATE TABLE upstream_generated_keys (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    upstream_id INTEGER NOT NULL, group_id VARCHAR(255) NOT NULL,
                    group_name VARCHAR(255) DEFAULT '', key_id VARCHAR(255),
                    key_name VARCHAR(255) NOT NULL, key_value TEXT NOT NULL,
                    masked_key VARCHAR(255) DEFAULT '', raw_json TEXT DEFAULT '{}',
                    status VARCHAR(32) DEFAULT 'created', error TEXT,
                    imported_website_id INTEGER, imported_account_id VARCHAR(255),
                    imported_at DATETIME, created_at DATETIME
                )
            """))
            conn.execute(text("""
                INSERT INTO upstream_generated_keys
                (upstream_id, group_id, group_name, key_name, key_value, masked_key, raw_json, status, created_at)
                VALUES (1, 'vip', 'VIP', 'SmartUp-Old-vip', 'sk-val', '', '{}', 'created', datetime('now'))
            """))

        _migrate_upstream_generated_keys()

        # The managed_prefix column must exist and be backfilled.
        inspector = sa_inspect(test_engine)
        cols = {c["name"] for c in inspector.get_columns("upstream_generated_keys")}
        assert "managed_prefix" in cols, "managed_prefix column should exist after migration"
        with test_engine.connect() as conn:
            row = conn.execute(text("SELECT managed_prefix, key_value FROM upstream_generated_keys LIMIT 1")).fetchone()
        assert row is not None, "legacy row should survive the migration"
        assert row[0] == "SmartUp", f"expected SmartUp, got {row[0]}"
        assert row[1] == "sk-val"

        # The partial unique index must have been created.
        index_names = {ix["name"] for ix in inspector.get_indexes("upstream_generated_keys")}
        assert "uq_upstream_group_managed" in index_names, "partial unique index should exist"
    finally:
        # Restore app.database.engine before tearing down the test engine.
        monkeypatch.undo()
        test_engine.dispose()
def test_create_twice_only_one_record(db_session):
    """Two ensure calls for the same upstream and group leave exactly one local row.

    NOTE(review): the second "call" is an inline upsert in this test, not the
    production ensure path. This test only pins the expected data shape.
    """
    upstream = Upstream(
        name="Test",
        base_url="http://local",
        api_prefix="/api/v1",
        auth_type="bearer",
        auth_config_json="{}",
        groups_endpoint="/groups",
        rate_endpoint="/rates",
    )
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    # First call: a fresh row is inserted.
    first = UpstreamGeneratedKey(
        upstream_id=upstream.id,
        group_id="vip",
        group_name="VIP",
        key_name="SmartUp-Test-VIP",
        key_value="sk-first",
        status="created",
    )
    db_session.add(first)
    db_session.commit()

    # Second call: upsert on the same key_name, moving status to "exists".
    match = (
        db_session.query(UpstreamGeneratedKey)
        .filter(
            UpstreamGeneratedKey.upstream_id == upstream.id,
            UpstreamGeneratedKey.group_id == "vip",
            UpstreamGeneratedKey.key_name == "SmartUp-Test-VIP",
        )
        .first()
    )
    if not match:
        db_session.add(UpstreamGeneratedKey(
            upstream_id=upstream.id,
            group_id="vip",
            group_name="VIP",
            key_name="SmartUp-Test-VIP",
            key_value="sk-second",
            status="exists",
        ))
    else:
        match.status = "exists"
        match.updated_at = datetime.now(timezone.utc)
    db_session.commit()

    group_rows = (
        db_session.query(UpstreamGeneratedKey)
        .filter(
            UpstreamGeneratedKey.upstream_id == upstream.id,
            UpstreamGeneratedKey.group_id == "vip",
        )
        .all()
    )
    assert len(group_rows) == 1
    only = group_rows[0]
    assert only.status == "exists"
    assert only.key_value == "sk-first"  # the original row was updated, not replaced
def test_sync_removes_gone_group(db_session):
    """Delete local key rows whose group is missing from the latest snapshot.

    NOTE(review): the pruning loop is written inline here; it does not call
    the production sync code.
    """
    upstream = Upstream(
        name="Test",
        base_url="http://local",
        api_prefix="/api/v1",
        auth_type="bearer",
        auth_config_json="{}",
        groups_endpoint="/groups",
        rate_endpoint="/rates",
    )
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    vip_row = UpstreamGeneratedKey(
        upstream_id=upstream.id, group_id="vip", group_name="VIP",
        key_name="SmartUp-Test-VIP", key_value="sk-vip",
    )
    free_row = UpstreamGeneratedKey(
        upstream_id=upstream.id, group_id="free", group_name="Free",
        key_name="SmartUp-Test-Free", key_value="sk-free",
    )
    db_session.add_all([vip_row, free_row])
    db_session.commit()

    # The snapshot only lists "vip"; "free" is gone.
    snapshot_group_ids = {"vip"}
    upstream_rows = db_session.query(UpstreamGeneratedKey).filter(
        UpstreamGeneratedKey.upstream_id == upstream.id
    ).all()
    for stale in [r for r in upstream_rows if r.group_id not in snapshot_group_ids]:
        db_session.delete(stale)
    db_session.commit()

    survivors = db_session.query(UpstreamGeneratedKey).all()
    assert len(survivors) == 1
    assert survivors[0].group_id == "vip"
def test_sync_removes_deleted_remote_key(db_session):
    """Delete the local row once its remote key has been deleted upstream.

    NOTE(review): the pruning loop is written inline here; it does not call
    the production sync code.
    """
    upstream = Upstream(
        name="Test",
        base_url="http://local",
        api_prefix="/api/v1",
        auth_type="bearer",
        auth_config_json="{}",
        groups_endpoint="/groups",
        rate_endpoint="/rates",
    )
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    db_session.add(UpstreamGeneratedKey(
        upstream_id=upstream.id, group_id="vip", group_name="VIP",
        key_name="SmartUp-Test-VIP", key_value="sk-vip",
        key_id="remote-123",
    ))
    db_session.commit()

    # The remote side only reports these ids; "remote-123" is no longer among them.
    live_remote_ids = {"remote-456", "remote-789"}
    candidates = db_session.query(UpstreamGeneratedKey).filter(
        UpstreamGeneratedKey.upstream_id == upstream.id
    ).all()
    for gone in [r for r in candidates if r.key_id and r.key_id not in live_remote_ids]:
        db_session.delete(gone)
    db_session.commit()

    assert len(db_session.query(UpstreamGeneratedKey).all()) == 0
def test_new_api_create_token_fetches_plaintext_key(monkeypatch):
    """After creating a New-API token, the client fetches its plaintext key by id."""
    from app.services.upstream_client import UpstreamClient

    client = UpstreamClient(
        base_url="http://newapi.local",
        api_prefix="",
        auth_type="cookie",
        auth_config={"cookie_string": "session=abc", "new_api_user": "7"},
    )
    created_bodies = []

    def fake_request(method, path, body=None, auth=True):
        route = (method, path)
        if route == ("GET", "/api/status"):
            return {"success": True, "data": {"quota_per_unit": 500000}}
        if route == ("POST", "/api/token/"):
            created_bodies.append(body)
            return {"success": True, "message": ""}
        if route == ("POST", "/api/token/123/key"):
            return {"success": True, "data": {"key": "new-api-plain-key"}}
        raise AssertionError(f"unexpected request {method} {path}")

    def fake_list_tokens(search="", group_id=None):
        # The listing only exposes a masked key, forcing the per-id plaintext fetch.
        return [{"id": 123, "name": search, "group": group_id, "key": "new-****-key"}]

    monkeypatch.setattr(client, "_request", fake_request)
    monkeypatch.setattr(client, "_list_new_api_tokens", fake_list_tokens)

    result = client.create_api_key(
        "SmartUp-1-VIP-vip",
        "vip",
        quota=2,
        expires_in_days=3,
        endpoint="/api/token",
    )

    assert result["id"] == "123"
    assert result["key"] == "new-api-plain-key"
    sent = created_bodies[0]
    assert sent["group"] == "vip"
    # Presumably quota (2) * quota_per_unit (500000).
    assert sent["remain_quota"] == 1000000
    assert sent["unlimited_quota"] is False
    assert sent["expired_time"] > 0
def test_new_api_list_tokens_uses_full_list_and_fetches_plaintext_when_search_misses():
    """Fall back to the full token list and fetch plaintext by id when search misses.

    New-API search may not match the name prefix. The client should page the
    full token list, keep only SmartUp-prefixed tokens and backfill each one's
    plaintext key by id.
    """
    from app.services.upstream_client import UpstreamClient

    client = UpstreamClient(
        base_url="http://newapi.local",
        api_prefix="",
        auth_type="cookie",
        auth_config={"cookie_string": "session=abc", "new_api_user": "7"},
    )

    # Plaintext returned by POST /api/token/<id>/key for each SmartUp token.
    plaintext_by_path = {
        "/api/token/447/key": "sk-gptpro-plain",
        "/api/token/446/key": "sk-gptplus-plain",
        "/api/token/445/key": "sk-claude-plain",
    }

    class FakeResponse:
        """Minimal stand-in for an HTTP response carrying a JSON payload."""

        def __init__(self, payload):
            self._payload = payload
            self.cookies = {}
            self.headers = {"content-type": "application/json"}
            self.content = b"{}"
            self.text = "{}"

        def raise_for_status(self):
            return None

        def json(self):
            return self._payload

    class FakeHttpClient:
        """Routes requests by method and path to canned New-API responses."""

        def request(self, method, url, **kwargs):
            path = url.replace("http://newapi.local", "")
            params = kwargs.get("params") or {}
            if method == "GET" and path == "/api/token/search":
                assert params["keyword"] == "SmartUp"
                return FakeResponse({
                    "success": True,
                    "data": {"page": 1, "page_size": 100, "total": 0, "items": []},
                })
            if method == "GET" and path == "/api/token/":
                return FakeResponse({
                    "success": True,
                    "data": {
                        "page": 1,
                        "page_size": 100,
                        "total": 4,
                        "items": [
                            {"id": 447, "name": "SmartUp-4-gptpro-gpt pro", "group": "gpt pro", "key": "sk-XE2o********WWh"},
                            {"id": 446, "name": "SmartUp-4-gptplus-gpt plus", "group": "gpt plus", "key": "sk-JRi1********rtum"},
                            {"id": 445, "name": "SmartUp-4-claude特价kiro-claude 特价kiro", "group": "claude 特价kiro", "key": "sk-Aldb********08W2"},
                            {"id": 56, "name": "plus", "group": "gpt plus", "key": "sk-20cB********pEfE"},
                        ],
                    },
                })
            if method == "POST" and path in plaintext_by_path:
                return FakeResponse({"success": True, "data": {"key": plaintext_by_path[path]}})
            raise AssertionError(f"unexpected request {method} {path} {params}")

    client._client = FakeHttpClient()
    listed = client.list_api_keys(search="SmartUp", status="active")

    assert [entry["id"] for entry in listed] == [447, 446, 445]
    assert [entry["group_id"] for entry in listed] == ["gpt pro", "gpt plus", "claude 特价kiro"]
    assert [entry["key"] for entry in listed] == ["sk-gptpro-plain", "sk-gptplus-plain", "sk-claude-plain"]
def test_generated_keys_persists_new_api_tokens_with_plaintext(db_session, monkeypatch):
    """list_generated_keys stores remote New-API tokens as importable local rows."""
    from app.routers import upstreams as upstreams_router

    cookie_auth = json.dumps({"cookie_string": "session=abc", "new_api_user": "7"})
    upstream = Upstream(
        name="NewAPI",
        base_url="http://newapi.local",
        api_prefix="",
        auth_type="cookie",
        auth_config_json=cookie_auth,
        groups_endpoint="/api/user/self/groups",
        rate_endpoint="/api/user/self/groups",
    )
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    remote_token = {
        "id": 447,
        "name": "SmartUp-4-gptpro-gpt pro",
        "group": "gpt pro",
        "group_id": "gpt pro",
        "key": "sk-gptpro-plain",
        "masked_key": "sk-g********lain",
    }

    class FakeClient:
        """Replaces UpstreamClient; lists a single remote token with plaintext."""

        def __init__(self, **kwargs):
            self.kwargs = kwargs

        def __enter__(self):
            return self

        def __exit__(self, *args):
            return None

        def login(self):
            return None

        def list_api_keys(self, search="", status="active"):
            assert search == "SmartUp"
            return [remote_token]

    monkeypatch.setattr(upstreams_router, "UpstreamClient", FakeClient)
    listing = upstreams_router.list_generated_keys(upstream.id, db_session, object())

    assert len(listing) == 1
    first = listing[0]
    assert first.has_key_value is True
    assert first.id is not None
    assert first.key_name == "SmartUp-4-gptpro-gpt pro"

    persisted = (
        db_session.query(UpstreamGeneratedKey)
        .filter(
            UpstreamGeneratedKey.upstream_id == upstream.id,
            UpstreamGeneratedKey.key_id == "447",
        )
        .one()
    )
    assert persisted.group_id == "gpt pro"
    assert persisted.group_name == "gpt pro"
    assert persisted.key_value == "sk-gptpro-plain"
    assert persisted.managed_prefix == "SmartUp"
def test_generate_keys_allows_new_api_user_upstream(db_session, monkeypatch):
    """A regular New-API user account upstream may generate tokens per group."""
    from app.routers import upstreams as upstreams_router
    from app.schemas.upstream import GenerateKeysByGroupsRequest

    upstream = Upstream(
        name="NewAPI",
        base_url="http://newapi.local",
        api_prefix="",
        auth_type="cookie",
        auth_config_json=json.dumps({"cookie_string": "session=abc", "new_api_user": "7"}),
        groups_endpoint="/api/user/self/groups",
        rate_endpoint="/api/user/self/groups",
    )
    db_session.add(upstream)
    db_session.commit()
    db_session.refresh(upstream)

    # Skip the downstream website reconciliation; it is not under test here.
    monkeypatch.setattr(
        upstreams_router.website_sync,
        "reconcile_upstream_keys_full",
        lambda db, uid: True,
    )

    class FakeClient:
        """Replaces UpstreamClient: one "vip" group and no existing SmartUp key."""

        def __init__(self, **kwargs):
            self.kwargs = kwargs

        def __enter__(self):
            return self

        def __exit__(self, *args):
            return None

        def login(self):
            return None

        def get_available_groups(self, endpoint):
            assert endpoint == "/api/user/self/groups"
            return [{"id": "vip", "name": "VIP"}]

        def find_smartup_group_key(self, gid, expected_name, prefix):
            return None

        def create_api_key(self, name, group_id, **kwargs):
            assert kwargs["endpoint"] == "/api/token"
            return {"id": "123", "key": "new-api-plain-key", "masked_key": "new-****-key", "raw": {"id": 123}}

    monkeypatch.setattr(upstreams_router, "UpstreamClient", FakeClient)

    request = GenerateKeysByGroupsRequest(group_ids=["vip"], endpoint="/api/token")
    outcome = upstreams_router.generate_keys_by_groups(upstream.id, request, db_session, object())

    assert outcome.success is True
    created = outcome.items[0]
    assert created.status == "created"
    assert created.key_value == "new-api-plain-key"