fix: reuse upstream keys for account import
This commit is contained in:
@@ -5,7 +5,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,7 +23,7 @@ from app.schemas.upstream import (
|
||||
GeneratedUpstreamKeyResponse,
|
||||
UpstreamCreate, UpstreamUpdate, UpstreamResponse, SnapshotResponse, TestResult
|
||||
)
|
||||
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot, mask_secret
|
||||
from app.services.upstream_client import UpstreamClient, UpstreamError, build_snapshot, mask_secret, _extract_key_value
|
||||
from app.services.snapshot_service import diff_snapshots
|
||||
from app.services import scheduler as sched_svc
|
||||
from app.services import webhook_service
|
||||
@@ -65,6 +65,7 @@ def _key_response(row: UpstreamGeneratedKey, include_value: bool = False) -> Gen
|
||||
imported_at=row.imported_at,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
has_key_value=bool(row.key_value),
|
||||
)
|
||||
|
||||
|
||||
@@ -78,6 +79,17 @@ def _mask_auth_config(auth_type: str, cfg: dict) -> dict:
|
||||
return masked
|
||||
|
||||
|
||||
def _extract_plaintext_key(payload: dict[str, Any] | None) -> str:
|
||||
if not isinstance(payload, dict):
|
||||
return ""
|
||||
key_value = _extract_key_value(payload)
|
||||
if not key_value:
|
||||
return ""
|
||||
if "*" in key_value:
|
||||
return ""
|
||||
return key_value
|
||||
|
||||
|
||||
def _to_response(u: Upstream) -> UpstreamResponse:
|
||||
cfg = json.loads(u.auth_config_json or "{}")
|
||||
return UpstreamResponse(
|
||||
@@ -164,9 +176,21 @@ def _ensure_group_key(
|
||||
except Exception:
|
||||
existing = None
|
||||
if existing:
|
||||
key_id = str(existing.get("id") or "")
|
||||
key_value = _extract_plaintext_key(existing)
|
||||
masked = mask_secret(key_value) if key_value else (existing.get("masked_key") or existing.get("key") or "")
|
||||
row.key_id = key_id or row.key_id
|
||||
if key_value:
|
||||
row.key_value = key_value
|
||||
row.masked_key = masked
|
||||
elif masked:
|
||||
row.masked_key = str(masked)
|
||||
row.raw_json = json.dumps(existing, ensure_ascii=False)
|
||||
row.status = "exists"
|
||||
row.updated_at = datetime.now(timezone.utc)
|
||||
db.add(row)
|
||||
db.commit()
|
||||
db.refresh(row)
|
||||
return _key_response(row, include_value=False)
|
||||
# 远端不存在,需要重新创建
|
||||
row.status = "replaced"
|
||||
@@ -175,10 +199,16 @@ def _ensure_group_key(
|
||||
existing = client.find_smartup_group_key(gid, stable_name, prefix)
|
||||
if existing:
|
||||
key_id = str(existing.get("id") or "")
|
||||
masked = existing.get("masked_key") or existing.get("key") or ""
|
||||
key_value = _extract_plaintext_key(existing)
|
||||
masked = mask_secret(key_value) if key_value else (existing.get("masked_key") or existing.get("key") or "")
|
||||
if row:
|
||||
row.key_id = key_id or row.key_id
|
||||
row.masked_key = masked or row.masked_key
|
||||
if key_value:
|
||||
row.key_value = key_value
|
||||
row.masked_key = masked
|
||||
elif masked:
|
||||
row.masked_key = str(masked)
|
||||
row.raw_json = json.dumps(existing, ensure_ascii=False)
|
||||
row.status = "exists"
|
||||
row.updated_at = datetime.now(timezone.utc)
|
||||
else:
|
||||
@@ -188,13 +218,13 @@ def _ensure_group_key(
|
||||
group_name=gname,
|
||||
key_id=key_id or None,
|
||||
key_name=stable_name,
|
||||
key_value="",
|
||||
key_value=key_value or "",
|
||||
masked_key=masked,
|
||||
raw_json=json.dumps(existing, ensure_ascii=False),
|
||||
managed_prefix=prefix,
|
||||
status="exists",
|
||||
)
|
||||
db.add(row)
|
||||
db.add(row)
|
||||
db.commit()
|
||||
db.refresh(row)
|
||||
return _key_response(row, include_value=False)
|
||||
|
||||
@@ -118,6 +118,7 @@ class GeneratedUpstreamKeyResponse(BaseModel):
|
||||
imported_at: Optional[datetime] = None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
has_key_value: bool = False
|
||||
|
||||
|
||||
class GenerateKeysByGroupsResponse(BaseModel):
|
||||
|
||||
@@ -223,13 +223,25 @@ def _sync_upstream_keys(upstream_id: int, snapshot: dict[str, Any], captured_at:
|
||||
for row in key_rows:
|
||||
# 1. 分组已不在当前快照中 → 删除本地记录
|
||||
if row.group_id not in active_group_ids:
|
||||
db.delete(row)
|
||||
logger.info("removed key %s (group %s no longer in snapshot)", row.id, row.group_id)
|
||||
if row.imported_website_id and row.imported_account_id:
|
||||
row.status = "orphaned"
|
||||
row.error = "来源分组已不存在"
|
||||
row.updated_at = captured_at
|
||||
logger.info("marked key %s orphaned (group %s no longer in snapshot)", row.id, row.group_id)
|
||||
else:
|
||||
db.delete(row)
|
||||
logger.info("removed key %s (group %s no longer in snapshot)", row.id, row.group_id)
|
||||
continue
|
||||
# 2. 远端查询成功但 key_id 不在列表中 → 删除本地记录
|
||||
if row.key_id and remote_key_ids is not None and row.key_id not in remote_key_ids:
|
||||
db.delete(row)
|
||||
logger.info("removed key %s (key_id %s gone from remote)", row.id, row.key_id)
|
||||
if row.imported_website_id and row.imported_account_id:
|
||||
row.status = "orphaned"
|
||||
row.error = "远端 Key 已不存在"
|
||||
row.updated_at = captured_at
|
||||
logger.info("marked key %s orphaned (key_id %s gone from remote)", row.id, row.key_id)
|
||||
else:
|
||||
db.delete(row)
|
||||
logger.info("removed key %s (key_id %s gone from remote)", row.id, row.key_id)
|
||||
continue
|
||||
# 3. 更新同步时间戳(仅当查询成功且 Key 仍在远端时)
|
||||
if remote_key_ids is not None and row.key_id in remote_key_ids:
|
||||
|
||||
Reference in New Issue
Block a user