feat: add one-click sync for website group bindings
This commit is contained in:
@@ -30,7 +30,9 @@ from app.schemas.website import (
|
||||
WebsiteResponse,
|
||||
WebsiteSyncLogResponse,
|
||||
WebsiteUpdate,
|
||||
WebsiteBatchSyncResponse,
|
||||
)
|
||||
|
||||
from app.services.website_client import Sub2ApiWebsiteClient
|
||||
from app.services.website_sync import binding_sources, sync_binding, build_rate_priority_map
|
||||
from app.utils.auth import get_current_user
|
||||
@@ -669,6 +671,69 @@ def list_bindings(db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
return [_binding_response(db, row) for row in rows]
|
||||
|
||||
|
||||
@router.post("/api/websites/{wid}/group-bindings/sync-now", response_model=WebsiteBatchSyncResponse)
|
||||
def sync_website_group_bindings(
|
||||
wid: int,
|
||||
db: Session = Depends(get_db),
|
||||
_=Depends(get_current_user),
|
||||
):
|
||||
"""一键同步网站下所有分组绑定。"""
|
||||
website = db.query(Website).filter(Website.id == wid).first()
|
||||
if not website:
|
||||
raise HTTPException(404, "website not found")
|
||||
|
||||
bindings = db.query(WebsiteGroupBinding).filter(WebsiteGroupBinding.website_id == wid).order_by(WebsiteGroupBinding.id.asc()).all()
|
||||
if not bindings:
|
||||
return WebsiteBatchSyncResponse(
|
||||
total=0, success=0, failed=0, skipped=0,
|
||||
message="暂无绑定可同步",
|
||||
logs=[],
|
||||
)
|
||||
|
||||
results: list[WebsiteSyncLog] = []
|
||||
for binding in bindings:
|
||||
try:
|
||||
log = sync_binding(db, binding, write=True)
|
||||
results.append(log)
|
||||
except Exception as exc:
|
||||
logger.exception("batch sync failed for binding %s", binding.id)
|
||||
# Create and persist a synthetic failed log if sync_binding crashed before creating one
|
||||
synthetic_log = WebsiteSyncLog(
|
||||
website_id=wid,
|
||||
binding_id=binding.id,
|
||||
target_group_id=binding.target_group_id,
|
||||
target_group_name=binding.target_group_name,
|
||||
algorithm=binding.algorithm,
|
||||
percent=binding.percent,
|
||||
source_rates_json="[]",
|
||||
status="failed",
|
||||
message=str(exc),
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db.add(synthetic_log)
|
||||
db.commit()
|
||||
db.refresh(synthetic_log)
|
||||
results.append(synthetic_log)
|
||||
|
||||
success_count = sum(1 for r in results if r.status == "success")
|
||||
failed_count = sum(1 for r in results if r.status == "failed")
|
||||
skipped_count = sum(1 for r in results if r.status == "skipped")
|
||||
|
||||
msg_parts = []
|
||||
if success_count: msg_parts.append(f"成功 {success_count}")
|
||||
if failed_count: msg_parts.append(f"失败 {failed_count}")
|
||||
if skipped_count: msg_parts.append(f"跳过 {skipped_count}")
|
||||
|
||||
return WebsiteBatchSyncResponse(
|
||||
total=len(bindings),
|
||||
success=success_count,
|
||||
failed=failed_count,
|
||||
skipped=skipped_count,
|
||||
message=f"同步完成:{'、'.join(msg_parts)}" if msg_parts else "同步完成",
|
||||
logs=[_log_response(r) for r in results],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/group-bindings", response_model=BindingResponse, status_code=201)
|
||||
def create_binding(body: BindingCreate, db: Session = Depends(get_db), _=Depends(get_current_user)):
|
||||
website = db.query(Website).filter(Website.id == body.website_id).first()
|
||||
|
||||
@@ -178,3 +178,12 @@ class ImportAccountsResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
items: list[ImportAccountItem]
|
||||
|
||||
|
||||
class WebsiteBatchSyncResponse(BaseModel):
|
||||
total: int
|
||||
success: int
|
||||
failed: int
|
||||
skipped: int
|
||||
message: str
|
||||
logs: list[WebsiteSyncLogResponse]
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.database import Base
|
||||
# Import all models to ensure they are registered with Base.metadata
|
||||
from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page, website, revoked_token, upstream_key
|
||||
from app.models.website import Website, WebsiteGroupBinding, WebsiteSyncLog
|
||||
from app.models.upstream import Upstream
|
||||
from app.models.snapshot import UpstreamRateSnapshot
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
@pytest.fixture()
|
||||
def db_session():
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
db = TestingSessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
def test_sync_website_group_bindings_success(db_session, monkeypatch):
|
||||
from app.routers.websites import sync_website_group_bindings
|
||||
|
||||
# Setup website
|
||||
w = Website(name="W1", base_url="http://w1", enabled=True, auto_sync_enabled=True, auth_config_json="{}", timeout_seconds=30)
|
||||
db_session.add(w)
|
||||
db_session.commit()
|
||||
db_session.refresh(w)
|
||||
|
||||
# Setup upstream and snapshot
|
||||
u = Upstream(name="U1", base_url="http://u1")
|
||||
db_session.add(u)
|
||||
db_session.commit()
|
||||
db_session.refresh(u)
|
||||
|
||||
db_session.add(UpstreamRateSnapshot(
|
||||
upstream_id=u.id,
|
||||
snapshot_json=json.dumps({"groups": {"G1": {"rate": 1.0}, "G2": {"rate": 2.0}}}),
|
||||
captured_at=datetime.now(timezone.utc)
|
||||
))
|
||||
db_session.commit()
|
||||
|
||||
# Setup 2 bindings
|
||||
b1 = WebsiteGroupBinding(website_id=w.id, target_group_id="TG1", target_group_name="TG1",
|
||||
source_groups_json=json.dumps([{"upstream_id": u.id, "group_id": "G1"}]),
|
||||
algorithm="max_plus_percent", percent=10, enabled=True)
|
||||
b2 = WebsiteGroupBinding(website_id=w.id, target_group_id="TG2", target_group_name="TG2",
|
||||
source_groups_json=json.dumps([{"upstream_id": u.id, "group_id": "G2"}]),
|
||||
algorithm="max_plus_percent", percent=10, enabled=True)
|
||||
db_session.add_all([b1, b2])
|
||||
db_session.commit()
|
||||
|
||||
# Mock Website client
|
||||
class MockClient:
|
||||
def __init__(self, **kwargs): pass
|
||||
def __enter__(self): return self
|
||||
def __exit__(self, *args): pass
|
||||
def get_groups(self, endpoint): return [{"id": "TG1", "rate_multiplier": 1.0}, {"id": "TG2", "rate_multiplier": 2.0}]
|
||||
def update_group_rate(self, endpoint, gid, rate): pass
|
||||
|
||||
monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", MockClient)
|
||||
|
||||
# Call batch sync
|
||||
resp = sync_website_group_bindings(w.id, db_session)
|
||||
|
||||
assert resp.total == 2
|
||||
assert resp.success == 2
|
||||
assert "成功 2" in resp.message
|
||||
assert len(resp.logs) == 2
|
||||
|
||||
# Verify logs in DB
|
||||
logs = db_session.query(WebsiteSyncLog).filter(WebsiteSyncLog.website_id == w.id).all()
|
||||
assert len(logs) == 2
|
||||
|
||||
def test_sync_website_group_bindings_partial_failure(db_session, monkeypatch):
|
||||
from app.routers.websites import sync_website_group_bindings
|
||||
|
||||
w = Website(name="W1", base_url="http://w1", enabled=True, auto_sync_enabled=True, auth_config_json="{}", timeout_seconds=30)
|
||||
db_session.add(w)
|
||||
db_session.commit()
|
||||
db_session.refresh(w)
|
||||
|
||||
u = Upstream(name="U1", base_url="http://u1")
|
||||
db_session.add(u)
|
||||
db_session.commit()
|
||||
db_session.refresh(u)
|
||||
|
||||
db_session.add(UpstreamRateSnapshot(
|
||||
upstream_id=u.id,
|
||||
snapshot_json=json.dumps({"groups": {"G1": {"rate": 1.0}, "G2": {"rate": 2.0}}}),
|
||||
captured_at=datetime.now(timezone.utc)
|
||||
))
|
||||
db_session.commit()
|
||||
|
||||
b1 = WebsiteGroupBinding(website_id=w.id, target_group_id="TG1",
|
||||
source_groups_json=json.dumps([{"upstream_id": u.id, "group_id": "G1"}]),
|
||||
algorithm="max_plus_percent", percent=10, enabled=True)
|
||||
b2 = WebsiteGroupBinding(website_id=w.id, target_group_id="TG2",
|
||||
source_groups_json=json.dumps([{"upstream_id": u.id, "group_id": "G2"}]),
|
||||
algorithm="max_plus_percent", percent=10, enabled=True)
|
||||
db_session.add_all([b1, b2])
|
||||
db_session.commit()
|
||||
|
||||
class MockClient:
|
||||
def __init__(self, **kwargs): pass
|
||||
def __enter__(self): return self
|
||||
def __exit__(self, *args): pass
|
||||
def get_groups(self, endpoint): return [] # Will cause sync to fail or skip
|
||||
def update_group_rate(self, endpoint, gid, rate):
|
||||
if gid == "TG2": raise Exception("Simulated error")
|
||||
|
||||
monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", MockClient)
|
||||
|
||||
# The sync_binding implementation returns a log with status="failed" on exception during write back
|
||||
|
||||
resp = sync_website_group_bindings(w.id, db_session)
|
||||
|
||||
# Depending on how sync_binding handles missing target group, it might be "success" (with message) or "failed"
|
||||
# Actually if get_groups is empty, sync_binding might fail to find old_rate but it continues to update_group_rate
|
||||
# If update_group_rate fails, it returns a failed log.
|
||||
|
||||
assert resp.total == 2
|
||||
# TG1: update_group_rate succeeds -> success
|
||||
# TG2: update_group_rate fails -> failed
|
||||
assert resp.success == 1
|
||||
assert resp.failed == 1
|
||||
|
||||
def test_sync_website_group_bindings_no_bindings(db_session):
|
||||
from app.routers.websites import sync_website_group_bindings
|
||||
w = Website(name="W1", base_url="http://w1", enabled=True, auth_config_json="{}", timeout_seconds=30)
|
||||
db_session.add(w)
|
||||
db_session.commit()
|
||||
db_session.refresh(w)
|
||||
|
||||
resp = sync_website_group_bindings(w.id, db_session)
|
||||
assert resp.total == 0
|
||||
assert resp.message == "暂无绑定可同步"
|
||||
|
||||
def test_sync_website_group_bindings_not_found(db_session):
|
||||
from app.routers.websites import sync_website_group_bindings
|
||||
from fastapi import HTTPException
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
sync_website_group_bindings(999, db_session)
|
||||
assert exc.value.status_code == 404
|
||||
|
||||
def test_sync_website_group_bindings_exception_persists_log(db_session, monkeypatch):
|
||||
from app.routers.websites import sync_website_group_bindings
|
||||
|
||||
w = Website(name="W1", base_url="http://w1", enabled=True, auth_config_json="{}", timeout_seconds=30)
|
||||
db_session.add(w)
|
||||
db_session.commit()
|
||||
|
||||
b1 = WebsiteGroupBinding(website_id=w.id, target_group_id="TG1",
|
||||
source_groups_json="[]",
|
||||
algorithm="max_plus_percent", percent=10, enabled=True)
|
||||
db_session.add(b1)
|
||||
db_session.commit()
|
||||
|
||||
# Mock sync_binding to raise an exception
|
||||
def mock_sync_binding(db, binding, write):
|
||||
raise Exception("Fatal crash")
|
||||
|
||||
monkeypatch.setattr("app.routers.websites.sync_binding", mock_sync_binding)
|
||||
|
||||
resp = sync_website_group_bindings(w.id, db_session)
|
||||
|
||||
assert resp.total == 1
|
||||
assert resp.failed == 1
|
||||
assert resp.logs[0].status == "failed"
|
||||
assert "Fatal crash" in resp.logs[0].message
|
||||
|
||||
# Verify it was PERSISTED in the database
|
||||
log_in_db = db_session.query(WebsiteSyncLog).filter(WebsiteSyncLog.website_id == w.id).first()
|
||||
assert log_in_db is not None
|
||||
assert log_in_db.status == "failed"
|
||||
assert "Fatal crash" in log_in_db.message
|
||||
@@ -28,6 +28,7 @@ def db_session():
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
from app.models import admin_user, upstream, snapshot, webhook_config, notification_log, custom_page, website, revoked_token, upstream_key
|
||||
Base.metadata.create_all(bind=engine)
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
db = TestingSessionLocal()
|
||||
@@ -116,6 +117,10 @@ def test_create_binding_runs_initial_sync(monkeypatch, client, db_session):
|
||||
class FakeClient:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__(self, *args):
|
||||
pass
|
||||
|
||||
def get_groups(self, endpoint):
|
||||
return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]
|
||||
@@ -204,6 +209,10 @@ def test_update_binding_runs_sync_after_save(monkeypatch, client, db_session):
|
||||
class FakeClient:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__(self, *args):
|
||||
pass
|
||||
|
||||
def get_groups(self, endpoint):
|
||||
return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]
|
||||
@@ -254,6 +263,10 @@ def test_update_binding_skips_write_when_disabled(monkeypatch, client, db_sessio
|
||||
class FakeClient:
|
||||
def __init__(self, **kwargs):
|
||||
raise AssertionError("should not write when binding is disabled")
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__(self, *args):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(websites_router, "Sub2ApiWebsiteClient", FakeClient)
|
||||
monkeypatch.setattr("app.services.website_sync.Sub2ApiWebsiteClient", FakeClient)
|
||||
@@ -289,6 +302,10 @@ def test_create_binding_notifies_when_website_rate_changes(monkeypatch, client,
|
||||
class FakeClient:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__(self, *args):
|
||||
pass
|
||||
|
||||
def get_groups(self, endpoint):
|
||||
return [{"id": "target", "name": "Target group", "rate_multiplier": "1"}]
|
||||
@@ -333,6 +350,10 @@ def test_create_binding_does_not_notify_when_website_rate_unchanged(monkeypatch,
|
||||
class FakeClient:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__(self, *args):
|
||||
pass
|
||||
|
||||
def get_groups(self, endpoint):
|
||||
return [{"id": "target", "name": "Target group", "rate_multiplier": "2.2"}]
|
||||
|
||||
@@ -267,6 +267,15 @@ export interface ImportAccountItem {
|
||||
raw: Record<string, any>
|
||||
}
|
||||
|
||||
export interface WebsiteBatchSyncResponse {
|
||||
total: number
|
||||
success: number
|
||||
failed: number
|
||||
skipped: number
|
||||
message: string
|
||||
logs: WebsiteSyncLog[]
|
||||
}
|
||||
|
||||
export const websitesApi = {
|
||||
list: () => api.get<WebsiteData[]>('/api/websites'),
|
||||
create: (data: WebsiteForm) => api.post<WebsiteData>('/api/websites', data),
|
||||
@@ -293,6 +302,7 @@ export const websitesApi = {
|
||||
updateBinding: (id: number, data: Partial<GroupBindingForm>) => api.put<GroupBindingData>(`/api/group-bindings/${id}`, data),
|
||||
deleteBinding: (id: number) => api.delete(`/api/group-bindings/${id}`),
|
||||
syncNow: (id: number) => api.post<WebsiteSyncLog>(`/api/group-bindings/${id}/sync-now`),
|
||||
syncWebsiteBindings: (id: number) => api.post<WebsiteBatchSyncResponse>(`/api/websites/${id}/group-bindings/sync-now`),
|
||||
logs: (params?: { website_id?: number; binding_id?: number; limit?: number; offset?: number }) =>
|
||||
api.get<WebsiteSyncLog[]>('/api/website-sync-logs', { params }),
|
||||
}
|
||||
|
||||
@@ -101,6 +101,16 @@
|
||||
>
|
||||
<el-option v-for="site in websites" :key="site.id" :label="site.name" :value="site.id" />
|
||||
</el-select>
|
||||
<el-button
|
||||
size="small"
|
||||
text
|
||||
:loading="batchSyncLoading"
|
||||
:disabled="!selectedWebsite || selectedWebsiteBindings.length === 0"
|
||||
@click="syncAllWebsiteBindings"
|
||||
title="同步当前网站的所有分组绑定"
|
||||
>
|
||||
一键同步
|
||||
</el-button>
|
||||
<el-button size="small" text :disabled="websites.length === 0" @click="openBindingCreate(selectedWebsite || websites[0])">新增绑定</el-button>
|
||||
</div>
|
||||
<div class="binding-list" v-loading="bindingLoading">
|
||||
@@ -469,6 +479,7 @@ const websiteLoading = ref(false)
|
||||
const groupsLoading = ref(false)
|
||||
const bindingLoading = ref(false)
|
||||
const logLoading = ref(false)
|
||||
const batchSyncLoading = ref(false)
|
||||
const importingGroups = ref(false)
|
||||
const importingAccounts = ref(false)
|
||||
const generatedKeyLoading = ref(false)
|
||||
@@ -924,16 +935,6 @@ async function saveBinding() {
|
||||
}
|
||||
}
|
||||
|
||||
async function toggleBinding(row: GroupBindingData) {
|
||||
try {
|
||||
await websitesApi.updateBinding(row.id, { enabled: row.enabled })
|
||||
ElMessage.success(row.enabled ? '已启用绑定' : '已停用绑定')
|
||||
} catch {
|
||||
row.enabled = !row.enabled
|
||||
ElMessage.error('操作失败')
|
||||
}
|
||||
}
|
||||
|
||||
async function syncBinding(row: GroupBindingData & { _syncing?: boolean }) {
|
||||
row._syncing = true
|
||||
try {
|
||||
@@ -946,9 +947,35 @@ async function syncBinding(row: GroupBindingData & { _syncing?: boolean }) {
|
||||
}
|
||||
}
|
||||
|
||||
async function syncAllWebsiteBindings() {
|
||||
if (!selectedWebsite.value) return
|
||||
batchSyncLoading.value = true
|
||||
try {
|
||||
const res = await websitesApi.syncWebsiteBindings(selectedWebsite.value.id)
|
||||
ElMessage.success(res.data.message)
|
||||
// 刷新相关数据
|
||||
await Promise.all([loadLogs(), loadWebsiteGroups(), loadBindings()])
|
||||
} catch (e: any) {
|
||||
ElMessage.error(e.response?.data?.detail || '批量同步失败')
|
||||
} finally {
|
||||
batchSyncLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function toggleBinding(row: GroupBindingData) {
|
||||
try {
|
||||
await websitesApi.updateBinding(row.id, { enabled: row.enabled })
|
||||
ElMessage.success(row.enabled ? '已启用' : '已禁用')
|
||||
await loadLogs()
|
||||
} catch (e: any) {
|
||||
row.enabled = !row.enabled
|
||||
ElMessage.error(e.response?.data?.detail || '更新失败')
|
||||
}
|
||||
}
|
||||
|
||||
async function deleteBinding(row: GroupBindingData) {
|
||||
try {
|
||||
await ElMessageBox.confirm('确认删除该绑定?', '删除确认', { type: 'warning' })
|
||||
await ElMessageBox.confirm('确认删除该绑定规则?', '删除确认', { type: 'warning' })
|
||||
await websitesApi.deleteBinding(row.id)
|
||||
ElMessage.success('已删除')
|
||||
await loadBindings()
|
||||
|
||||
Reference in New Issue
Block a user