fix: reuse upstream keys for account import
This commit is contained in:
@@ -5,7 +5,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,7 +23,7 @@ from app.schemas.upstream import (
|
||||
GeneratedUpstreamKeyResponse,
|
||||
UpstreamCreate, UpstreamUpdate, UpstreamResponse, SnapshotResponse, TestResult
|
||||
)
|
||||
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot, mask_secret
|
||||
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot, mask_secret, _extract_key_value
|
||||
from app.services.snapshot_service import diff_snapshots
|
||||
from app.services import scheduler as sched_svc
|
||||
from app.services import webhook_service
|
||||
@@ -65,6 +65,7 @@ def _key_response(row: UpstreamGeneratedKey, include_value: bool = False) -> Gen
|
||||
imported_at=row.imported_at,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
has_key_value=bool(row.key_value),
|
||||
)
|
||||
|
||||
|
||||
@@ -78,6 +79,17 @@ def _mask_auth_config(auth_type: str, cfg: dict) -> dict:
|
||||
return masked
|
||||
|
||||
|
||||
def _extract_plaintext_key(payload: dict[str, Any] | None) -> str:
|
||||
if not isinstance(payload, dict):
|
||||
return ""
|
||||
key_value = _extract_key_value(payload)
|
||||
if not key_value:
|
||||
return ""
|
||||
if "*" in key_value:
|
||||
return ""
|
||||
return key_value
|
||||
|
||||
|
||||
def _to_response(u: Upstream) -> UpstreamResponse:
|
||||
cfg = json.loads(u.auth_config_json or "{}")
|
||||
return UpstreamResponse(
|
||||
@@ -164,9 +176,21 @@ def _ensure_group_key(
|
||||
except Exception:
|
||||
existing = None
|
||||
if existing:
|
||||
key_id = str(existing.get("id") or "")
|
||||
key_value = _extract_plaintext_key(existing)
|
||||
masked = mask_secret(key_value) if key_value else (existing.get("masked_key") or existing.get("key") or "")
|
||||
row.key_id = key_id or row.key_id
|
||||
if key_value:
|
||||
row.key_value = key_value
|
||||
row.masked_key = masked
|
||||
elif masked:
|
||||
row.masked_key = str(masked)
|
||||
row.raw_json = json.dumps(existing, ensure_ascii=False)
|
||||
row.status = "exists"
|
||||
row.updated_at = datetime.now(timezone.utc)
|
||||
db.add(row)
|
||||
db.commit()
|
||||
db.refresh(row)
|
||||
return _key_response(row, include_value=False)
|
||||
# 远端不存在,需要重新创建
|
||||
row.status = "replaced"
|
||||
@@ -175,10 +199,16 @@ def _ensure_group_key(
|
||||
existing = client.find_smartup_group_key(gid, stable_name, prefix)
|
||||
if existing:
|
||||
key_id = str(existing.get("id") or "")
|
||||
masked = existing.get("masked_key") or existing.get("key") or ""
|
||||
key_value = _extract_plaintext_key(existing)
|
||||
masked = mask_secret(key_value) if key_value else (existing.get("masked_key") or existing.get("key") or "")
|
||||
if row:
|
||||
row.key_id = key_id or row.key_id
|
||||
row.masked_key = masked or row.masked_key
|
||||
if key_value:
|
||||
row.key_value = key_value
|
||||
row.masked_key = masked
|
||||
elif masked:
|
||||
row.masked_key = str(masked)
|
||||
row.raw_json = json.dumps(existing, ensure_ascii=False)
|
||||
row.status = "exists"
|
||||
row.updated_at = datetime.now(timezone.utc)
|
||||
else:
|
||||
@@ -188,13 +218,13 @@ def _ensure_group_key(
|
||||
group_name=gname,
|
||||
key_id=key_id or None,
|
||||
key_name=stable_name,
|
||||
key_value="",
|
||||
key_value=key_value or "",
|
||||
masked_key=masked,
|
||||
raw_json=json.dumps(existing, ensure_ascii=False),
|
||||
managed_prefix=prefix,
|
||||
status="exists",
|
||||
)
|
||||
db.add(row)
|
||||
db.add(row)
|
||||
db.commit()
|
||||
db.refresh(row)
|
||||
return _key_response(row, include_value=False)
|
||||
|
||||
@@ -118,6 +118,7 @@ class GeneratedUpstreamKeyResponse(BaseModel):
|
||||
imported_at: Optional[datetime] = None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
has_key_value: bool = False
|
||||
|
||||
|
||||
class GenerateKeysByGroupsResponse(BaseModel):
|
||||
|
||||
@@ -223,13 +223,25 @@ def _sync_upstream_keys(upstream_id: int, snapshot: dict[str, Any], captured_at:
|
||||
for row in key_rows:
|
||||
# 1. 分组已不在当前快照中 → 删除本地记录
|
||||
if row.group_id not in active_group_ids:
|
||||
db.delete(row)
|
||||
logger.info("removed key %s (group %s no longer in snapshot)", row.id, row.group_id)
|
||||
if row.imported_website_id and row.imported_account_id:
|
||||
row.status = "orphaned"
|
||||
row.error = "来源分组已不存在"
|
||||
row.updated_at = captured_at
|
||||
logger.info("marked key %s orphaned (group %s no longer in snapshot)", row.id, row.group_id)
|
||||
else:
|
||||
db.delete(row)
|
||||
logger.info("removed key %s (group %s no longer in snapshot)", row.id, row.group_id)
|
||||
continue
|
||||
# 2. 远端查询成功但 key_id 不在列表中 → 删除本地记录
|
||||
if row.key_id and remote_key_ids is not None and row.key_id not in remote_key_ids:
|
||||
db.delete(row)
|
||||
logger.info("removed key %s (key_id %s gone from remote)", row.id, row.key_id)
|
||||
if row.imported_website_id and row.imported_account_id:
|
||||
row.status = "orphaned"
|
||||
row.error = "远端 Key 已不存在"
|
||||
row.updated_at = captured_at
|
||||
logger.info("marked key %s orphaned (key_id %s gone from remote)", row.id, row.key_id)
|
||||
else:
|
||||
db.delete(row)
|
||||
logger.info("removed key %s (key_id %s gone from remote)", row.id, row.key_id)
|
||||
continue
|
||||
# 3. 更新同步时间戳(仅当查询成功且 Key 仍在远端时)
|
||||
if remote_key_ids is not None and row.key_id in remote_key_ids:
|
||||
|
||||
@@ -262,6 +262,64 @@ def test_ensure_group_key_reuses_old_record(db_session, monkeypatch):
|
||||
assert rows[0].key_value == "sk-new-value"
|
||||
|
||||
|
||||
def test_ensure_group_key_backfills_plaintext_from_remote_existing_key(db_session):
|
||||
"""远端已存在的 SmartUp Key 如果列表接口返回明文,应补写到本地 key_value。"""
|
||||
from app.routers.upstreams import _ensure_group_key
|
||||
from app.models.upstream_key import UpstreamGeneratedKey
|
||||
from app.schemas.upstream import GenerateKeysByGroupsRequest
|
||||
from app.services.upstream_client import mask_secret
|
||||
|
||||
upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
|
||||
auth_type="bearer", auth_config_json="{}",
|
||||
groups_endpoint="/groups", rate_endpoint="/rates")
|
||||
db_session.add(upstream)
|
||||
db_session.commit()
|
||||
db_session.refresh(upstream)
|
||||
|
||||
db_session.add(UpstreamGeneratedKey(
|
||||
upstream_id=upstream.id,
|
||||
group_id="vip",
|
||||
group_name="VIP",
|
||||
key_name="SmartUp-Test-vip",
|
||||
key_value="",
|
||||
masked_key="sk-old-masked",
|
||||
key_id="remote-123",
|
||||
managed_prefix="SmartUp",
|
||||
))
|
||||
db_session.commit()
|
||||
|
||||
class MockClient:
|
||||
def find_smartup_group_key(self, gid, name, prefix):
|
||||
return {
|
||||
"id": "remote-123",
|
||||
"name": "SmartUp-Test-vip",
|
||||
"key": "sk-remote-plain-value-1234567890abcdef",
|
||||
"masked_key": "sk-re************cdef",
|
||||
}
|
||||
|
||||
def create_api_key(self, *args, **kwargs):
|
||||
raise AssertionError("create_api_key should not be called when remote key exists")
|
||||
|
||||
group = {"id": "vip", "name": "VIP", "rate_multiplier": 1}
|
||||
body = GenerateKeysByGroupsRequest(
|
||||
group_ids=["vip"],
|
||||
name_prefix="SmartUp",
|
||||
quota=0,
|
||||
endpoint="/keys",
|
||||
)
|
||||
result = _ensure_group_key(db_session, MockClient(), upstream, group, "SmartUp", body)
|
||||
|
||||
assert result.status == "exists"
|
||||
assert result.has_key_value is True
|
||||
row = db_session.query(UpstreamGeneratedKey).filter(
|
||||
UpstreamGeneratedKey.upstream_id == upstream.id,
|
||||
UpstreamGeneratedKey.group_id == "vip",
|
||||
).one()
|
||||
assert row.key_value == "sk-remote-plain-value-1234567890abcdef"
|
||||
assert row.masked_key == mask_secret(row.key_value)
|
||||
assert row.status == "exists"
|
||||
|
||||
|
||||
def test_sync_removes_remote_key_when_list_empty(db_session, monkeypatch):
|
||||
"""同步函数在远端返回空列表时应删除本地 key_id 对应的记录。"""
|
||||
from app.services import scheduler as sched_mod
|
||||
@@ -310,6 +368,71 @@ def test_sync_removes_remote_key_when_list_empty(db_session, monkeypatch):
|
||||
assert len(remaining) == 0, f"expected 0 after sync with empty remote, got {len(remaining)}"
|
||||
|
||||
|
||||
def test_sync_marks_imported_key_orphaned_when_remote_key_missing(db_session, monkeypatch):
|
||||
"""已导入账号管理的 Key 远端消失时保留本地行,避免丢失目标账号关联。"""
|
||||
from app.services import scheduler as sched_mod
|
||||
from app.models.upstream_key import UpstreamGeneratedKey
|
||||
from app.services.upstream_client import UpstreamClient
|
||||
|
||||
website = Website(
|
||||
name="Target",
|
||||
site_type="sub2api",
|
||||
base_url="http://target.local",
|
||||
api_prefix="/api/v1/admin",
|
||||
auth_type="api_key",
|
||||
auth_config_json="{}",
|
||||
groups_endpoint="/groups",
|
||||
group_update_endpoint="/groups/{id}",
|
||||
)
|
||||
upstream = Upstream(name="Test", base_url="http://local", api_prefix="/api/v1",
|
||||
auth_type="bearer", auth_config_json="{}",
|
||||
groups_endpoint="/groups", rate_endpoint="/rates")
|
||||
db_session.add_all([website, upstream])
|
||||
db_session.commit()
|
||||
db_session.refresh(website)
|
||||
db_session.refresh(upstream)
|
||||
|
||||
db_session.add(UpstreamGeneratedKey(
|
||||
upstream_id=upstream.id,
|
||||
group_id="vip",
|
||||
group_name="VIP",
|
||||
key_name="SmartUp-Test-vip",
|
||||
key_value="sk-vip",
|
||||
managed_prefix="SmartUp",
|
||||
key_id="remote-key-id",
|
||||
imported_website_id=website.id,
|
||||
imported_account_id="account-101",
|
||||
))
|
||||
db_session.commit()
|
||||
|
||||
monkeypatch.setattr(UpstreamClient, "list_api_keys", lambda self, **kw: [])
|
||||
monkeypatch.setattr(UpstreamClient, "login", lambda self: None)
|
||||
monkeypatch.setattr(UpstreamClient, "close", lambda self: None)
|
||||
monkeypatch.setattr(UpstreamClient, "__enter__", lambda self: self)
|
||||
monkeypatch.setattr(UpstreamClient, "__exit__", lambda self, *a: None)
|
||||
monkeypatch.setattr(sched_mod, "SessionLocal", lambda: db_session)
|
||||
original_close = db_session.close
|
||||
monkeypatch.setattr(db_session, "close", lambda: None)
|
||||
|
||||
snapshot = {
|
||||
"upstream_id": upstream.id,
|
||||
"groups": {"vip": {"group_id": "vip", "rate": "1"}},
|
||||
"captured_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
captured_at = datetime.now(timezone.utc)
|
||||
sched_mod._sync_upstream_keys(upstream.id, snapshot, captured_at)
|
||||
|
||||
monkeypatch.setattr(db_session, "close", original_close)
|
||||
remaining = db_session.query(UpstreamGeneratedKey).all()
|
||||
assert len(remaining) == 1
|
||||
row = remaining[0]
|
||||
assert row.status == "orphaned"
|
||||
assert row.imported_website_id == website.id
|
||||
assert row.imported_account_id == "account-101"
|
||||
assert row.error == "远端 Key 已不存在"
|
||||
|
||||
|
||||
def test_migration_function_integration(monkeypatch):
|
||||
"""直接调用 _migrate_upstream_generated_keys() 验证列新增和索引创建。"""
|
||||
from app.database import _migrate_upstream_generated_keys, engine as real_engine
|
||||
|
||||
Reference in New Issue
Block a user