fix: remove stale _decimal_str ref, add context manager to HTTP clients
- UpstreamClient & Sub2ApiWebsiteClient: add __enter__/__exit__ - Convert all call sites to `with Client(...) as c:` pattern - Remove unused `upstream_name`/`upstream_base_url` locals in scheduler - Fix stale _decimal_str→decimal_string in _rate_from_group
This commit is contained in:
@@ -146,29 +146,29 @@ def test_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current
|
||||
if not u:
|
||||
raise HTTPException(404, "upstream not found")
|
||||
auth_config = json.loads(u.auth_config_json or "{}")
|
||||
client = UpstreamClient(
|
||||
with UpstreamClient(
|
||||
base_url=u.base_url,
|
||||
api_prefix=u.api_prefix,
|
||||
auth_type=u.auth_type,
|
||||
auth_config=auth_config,
|
||||
timeout=float(u.timeout_seconds),
|
||||
)
|
||||
try:
|
||||
client.login()
|
||||
groups = client.get_available_groups(u.groups_endpoint)
|
||||
u.last_status = "healthy"
|
||||
u.last_error = None
|
||||
u.last_checked_at = datetime.now(timezone.utc)
|
||||
u.consecutive_failures = 0
|
||||
db.commit()
|
||||
return TestResult(success=True, message=f"连接成功,获取到 {len(groups)} 个分组")
|
||||
except Exception as exc:
|
||||
u.last_status = "unhealthy"
|
||||
u.last_error = str(exc)
|
||||
u.last_checked_at = datetime.now(timezone.utc)
|
||||
u.consecutive_failures = (u.consecutive_failures or 0) + 1
|
||||
db.commit()
|
||||
return TestResult(success=False, message="连接失败", detail=str(exc))
|
||||
) as client:
|
||||
try:
|
||||
client.login()
|
||||
groups = client.get_available_groups(u.groups_endpoint)
|
||||
u.last_status = "healthy"
|
||||
u.last_error = None
|
||||
u.last_checked_at = datetime.now(timezone.utc)
|
||||
u.consecutive_failures = 0
|
||||
db.commit()
|
||||
return TestResult(success=True, message=f"连接成功,获取到 {len(groups)} 个分组")
|
||||
except Exception as exc:
|
||||
u.last_status = "unhealthy"
|
||||
u.last_error = str(exc)
|
||||
u.last_checked_at = datetime.now(timezone.utc)
|
||||
u.consecutive_failures = (u.consecutive_failures or 0) + 1
|
||||
db.commit()
|
||||
return TestResult(success=False, message="连接失败", detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/{uid}/check-now", response_model=TestResult)
|
||||
@@ -177,24 +177,24 @@ def check_now(uid: int, db: Session = Depends(get_db), _=Depends(get_current_use
|
||||
if not u:
|
||||
raise HTTPException(404, "upstream not found")
|
||||
auth_config = json.loads(u.auth_config_json or "{}")
|
||||
client = UpstreamClient(
|
||||
with UpstreamClient(
|
||||
base_url=u.base_url,
|
||||
api_prefix=u.api_prefix,
|
||||
auth_type=u.auth_type,
|
||||
auth_config=auth_config,
|
||||
timeout=float(u.timeout_seconds),
|
||||
)
|
||||
try:
|
||||
client.login()
|
||||
groups = client.get_available_groups(u.groups_endpoint)
|
||||
raw_rates = client.get_group_rates(u.rate_endpoint)
|
||||
snapshot = build_snapshot(u.id, u.base_url, u.api_prefix, groups, raw_rates)
|
||||
except Exception as exc:
|
||||
u.consecutive_failures = (u.consecutive_failures or 0) + 1
|
||||
u.last_error = str(exc)
|
||||
u.last_checked_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
return TestResult(success=False, message="检测失败", detail=str(exc))
|
||||
) as client:
|
||||
try:
|
||||
client.login()
|
||||
groups = client.get_available_groups(u.groups_endpoint)
|
||||
raw_rates = client.get_group_rates(u.rate_endpoint)
|
||||
snapshot = build_snapshot(u.id, u.base_url, u.api_prefix, groups, raw_rates)
|
||||
except Exception as exc:
|
||||
u.consecutive_failures = (u.consecutive_failures or 0) + 1
|
||||
u.last_error = str(exc)
|
||||
u.last_checked_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
return TestResult(success=False, message="检测失败", detail=str(exc))
|
||||
|
||||
prev_row = (
|
||||
db.query(UpstreamRateSnapshot)
|
||||
|
||||
@@ -188,7 +188,8 @@ def test_website(wid: int, db: Session = Depends(get_db), _=Depends(get_current_
|
||||
if not row:
|
||||
raise HTTPException(404, "website not found")
|
||||
try:
|
||||
groups = _client(row).get_groups(row.groups_endpoint)
|
||||
with _client(row) as c:
|
||||
groups = c.get_groups(row.groups_endpoint)
|
||||
row.last_status = "healthy"
|
||||
row.last_error = None
|
||||
row.last_checked_at = datetime.now(timezone.utc)
|
||||
@@ -208,7 +209,8 @@ def list_website_groups(wid: int, db: Session = Depends(get_db), _=Depends(get_c
|
||||
if not row:
|
||||
raise HTTPException(404, "website not found")
|
||||
try:
|
||||
return _client(row).get_groups(row.groups_endpoint)
|
||||
with _client(row) as c:
|
||||
return c.get_groups(row.groups_endpoint)
|
||||
except Exception as exc:
|
||||
raise HTTPException(502, str(exc))
|
||||
|
||||
|
||||
@@ -37,7 +37,6 @@ def _check_upstream(upstream_id: int) -> None:
|
||||
settings = get_settings()
|
||||
# ── Phase 1: upstream check + DB write ──────────────────────────
|
||||
db: Session = SessionLocal()
|
||||
client = None
|
||||
try:
|
||||
upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
|
||||
if not upstream or not upstream.enabled:
|
||||
@@ -45,46 +44,45 @@ def _check_upstream(upstream_id: int) -> None:
|
||||
return
|
||||
|
||||
auth_config = json.loads(upstream.auth_config_json or "{}")
|
||||
client = UpstreamClient(
|
||||
was_unhealthy = upstream.last_status == "unhealthy"
|
||||
snapshot = None
|
||||
changes = None
|
||||
|
||||
with UpstreamClient(
|
||||
base_url=upstream.base_url,
|
||||
api_prefix=upstream.api_prefix,
|
||||
auth_type=upstream.auth_type,
|
||||
auth_config=auth_config,
|
||||
timeout=float(upstream.timeout_seconds),
|
||||
)
|
||||
) as client:
|
||||
try:
|
||||
client.login()
|
||||
groups = client.get_available_groups(upstream.groups_endpoint)
|
||||
raw_rates = client.get_group_rates(upstream.rate_endpoint)
|
||||
snapshot = build_snapshot(
|
||||
upstream.id, upstream.base_url, upstream.api_prefix, groups, raw_rates
|
||||
)
|
||||
except Exception as exc:
|
||||
# failure path
|
||||
upstream.consecutive_failures = (upstream.consecutive_failures or 0) + 1
|
||||
upstream.last_error = str(exc)
|
||||
upstream.last_checked_at = datetime.now(timezone.utc)
|
||||
threshold = settings.unhealthy_threshold
|
||||
became_unhealthy = (
|
||||
upstream.consecutive_failures >= threshold
|
||||
and upstream.last_status != "unhealthy"
|
||||
)
|
||||
if became_unhealthy:
|
||||
upstream.last_status = "unhealthy"
|
||||
db.commit()
|
||||
logger.warning("upstream %s check failed: %s", upstream.name, exc)
|
||||
# Phase 2: notify unhealthy in a fresh session
|
||||
if became_unhealthy:
|
||||
_notify_status(upstream.id, upstream.name, upstream.base_url,
|
||||
"upstream_unhealthy", str(exc))
|
||||
return
|
||||
|
||||
was_unhealthy = upstream.last_status == "unhealthy"
|
||||
snapshot = None
|
||||
changes = None
|
||||
|
||||
try:
|
||||
client.login()
|
||||
groups = client.get_available_groups(upstream.groups_endpoint)
|
||||
raw_rates = client.get_group_rates(upstream.rate_endpoint)
|
||||
snapshot = build_snapshot(
|
||||
upstream.id, upstream.base_url, upstream.api_prefix, groups, raw_rates
|
||||
)
|
||||
except Exception as exc:
|
||||
# failure path
|
||||
upstream.consecutive_failures = (upstream.consecutive_failures or 0) + 1
|
||||
upstream.last_error = str(exc)
|
||||
upstream.last_checked_at = datetime.now(timezone.utc)
|
||||
threshold = settings.unhealthy_threshold
|
||||
became_unhealthy = (
|
||||
upstream.consecutive_failures >= threshold
|
||||
and upstream.last_status != "unhealthy"
|
||||
)
|
||||
if became_unhealthy:
|
||||
upstream.last_status = "unhealthy"
|
||||
db.commit()
|
||||
logger.warning("upstream %s check failed: %s", upstream.name, exc)
|
||||
# Phase 2: notify unhealthy in a fresh session
|
||||
if became_unhealthy:
|
||||
_notify_status(upstream.id, upstream.name, upstream.base_url,
|
||||
"upstream_unhealthy", str(exc))
|
||||
return
|
||||
|
||||
# success path
|
||||
# success path (client auto-closed by `with`)
|
||||
prev_snapshot_row = (
|
||||
db.query(UpstreamRateSnapshot)
|
||||
.filter(UpstreamRateSnapshot.upstream_id == upstream_id)
|
||||
@@ -116,7 +114,6 @@ def _check_upstream(upstream_id: int) -> None:
|
||||
)
|
||||
|
||||
finally:
|
||||
client.close()
|
||||
db.close()
|
||||
|
||||
# ── Phase 2: notifications (independent sessions) ──────────────
|
||||
|
||||
@@ -83,7 +83,7 @@ def _rate_from_group(group: dict[str, Any]) -> str:
|
||||
"effective_rate_multiplier", "effectiveRateMultiplier",
|
||||
"rate_multiplier", "rateMultiplier",
|
||||
):
|
||||
r = _decimal_str(group.get(key))
|
||||
r = decimal_string(group.get(key))
|
||||
if r:
|
||||
return r
|
||||
return ""
|
||||
@@ -214,6 +214,12 @@ class UpstreamClient:
|
||||
def close(self) -> None:
|
||||
self._client.close()
|
||||
|
||||
def __enter__(self) -> UpstreamClient:
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def _url(self, path: str) -> str:
|
||||
prefix = f"/{self.api_prefix}" if self.api_prefix else ""
|
||||
return f"{self.base_url}{prefix}/{path.lstrip('/')}"
|
||||
|
||||
@@ -105,6 +105,12 @@ class Sub2ApiWebsiteClient:
|
||||
def close(self) -> None:
|
||||
self._client.close()
|
||||
|
||||
def __enter__(self) -> Sub2ApiWebsiteClient:
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def _url(self, path: str) -> str:
|
||||
prefix = f"/{self.api_prefix}" if self.api_prefix else ""
|
||||
return f"{self.base_url}{prefix}/{path.lstrip('/')}"
|
||||
|
||||
@@ -118,11 +118,11 @@ def sync_binding(db: Session, binding: WebsiteGroupBinding, write: bool = True)
|
||||
old_rate = None
|
||||
if write and website.enabled and website.auto_sync_enabled and binding.enabled:
|
||||
try:
|
||||
client = _client_for(website)
|
||||
groups = client.get_groups(website.groups_endpoint)
|
||||
target = next((item for item in groups if item.get("id") == binding.target_group_id), None)
|
||||
old_rate = target.get("rate_multiplier") if target else None
|
||||
client.update_group_rate(website.group_update_endpoint, binding.target_group_id, target_rate)
|
||||
with _client_for(website) as client:
|
||||
groups = client.get_groups(website.groups_endpoint)
|
||||
target = next((item for item in groups if item.get("id") == binding.target_group_id), None)
|
||||
old_rate = target.get("rate_multiplier") if target else None
|
||||
client.update_group_rate(website.group_update_endpoint, binding.target_group_id, target_rate)
|
||||
website.last_status = "healthy"
|
||||
website.last_error = None
|
||||
except Exception as exc:
|
||||
|
||||
+18
-18
@@ -6,31 +6,31 @@ from app.services.upstream_client import UpstreamClient
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
def main():
|
||||
client = UpstreamClient(
|
||||
with UpstreamClient(
|
||||
base_url="http://170.106.100.210:55555",
|
||||
api_prefix="",
|
||||
auth_type="bearer",
|
||||
auth_config={"token": ""}, # We don't have token, but /api/group/ in some new-api may be open, or fail with 401
|
||||
timeout=10.0,
|
||||
)
|
||||
try:
|
||||
groups = client.get_available_groups("/api/group/")
|
||||
print("Groups:", groups)
|
||||
except Exception as e:
|
||||
print("Groups Error:", e)
|
||||
) as client:
|
||||
try:
|
||||
groups = client.get_available_groups("/api/group/")
|
||||
print("Groups:", groups)
|
||||
except Exception as e:
|
||||
print("Groups Error:", e)
|
||||
|
||||
try:
|
||||
rates = client.get_group_rates("/api/option/?key=GroupRatio")
|
||||
print("Rates:", rates)
|
||||
except Exception as e:
|
||||
print("Rates Error:", e)
|
||||
try:
|
||||
rates = client.get_group_rates("/api/option/?key=GroupRatio")
|
||||
print("Rates:", rates)
|
||||
except Exception as e:
|
||||
print("Rates Error:", e)
|
||||
|
||||
try:
|
||||
from app.services.upstream_client import _extract_rates_map, _unwrap_list
|
||||
print("Unwrapped Groups:", _unwrap_list(groups))
|
||||
print("Extracted Rates:", _extract_rates_map(rates))
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
from app.services.upstream_client import _extract_rates_map, _unwrap_list
|
||||
print("Unwrapped Groups:", _unwrap_list(groups))
|
||||
print("Extracted Rates:", _extract_rates_map(rates))
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user