fix: remove stale _decimal_str ref, add context manager to HTTP clients

- UpstreamClient & Sub2ApiWebsiteClient: add __enter__/__exit__
- Convert all call sites to `with Client(...) as c:` pattern
- Remove unused `upstream_name`/`upstream_base_url` locals in scheduler
- Fix stale _decimal_str→decimal_string in _rate_from_group
This commit is contained in:
SmartUp Developer
2026-05-17 11:29:51 +08:00
parent 8a6ed249be
commit 2934473770
7 changed files with 104 additions and 93 deletions
+31 -31
View File
@@ -146,29 +146,29 @@ def test_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current
if not u:
raise HTTPException(404, "upstream not found")
auth_config = json.loads(u.auth_config_json or "{}")
client = UpstreamClient(
with UpstreamClient(
base_url=u.base_url,
api_prefix=u.api_prefix,
auth_type=u.auth_type,
auth_config=auth_config,
timeout=float(u.timeout_seconds),
)
try:
client.login()
groups = client.get_available_groups(u.groups_endpoint)
u.last_status = "healthy"
u.last_error = None
u.last_checked_at = datetime.now(timezone.utc)
u.consecutive_failures = 0
db.commit()
return TestResult(success=True, message=f"连接成功,获取到 {len(groups)} 个分组")
except Exception as exc:
u.last_status = "unhealthy"
u.last_error = str(exc)
u.last_checked_at = datetime.now(timezone.utc)
u.consecutive_failures = (u.consecutive_failures or 0) + 1
db.commit()
return TestResult(success=False, message="连接失败", detail=str(exc))
) as client:
try:
client.login()
groups = client.get_available_groups(u.groups_endpoint)
u.last_status = "healthy"
u.last_error = None
u.last_checked_at = datetime.now(timezone.utc)
u.consecutive_failures = 0
db.commit()
return TestResult(success=True, message=f"连接成功,获取到 {len(groups)} 个分组")
except Exception as exc:
u.last_status = "unhealthy"
u.last_error = str(exc)
u.last_checked_at = datetime.now(timezone.utc)
u.consecutive_failures = (u.consecutive_failures or 0) + 1
db.commit()
return TestResult(success=False, message="连接失败", detail=str(exc))
@router.post("/{uid}/check-now", response_model=TestResult)
@@ -177,24 +177,24 @@ def check_now(uid: int, db: Session = Depends(get_db), _=Depends(get_current_use
if not u:
raise HTTPException(404, "upstream not found")
auth_config = json.loads(u.auth_config_json or "{}")
client = UpstreamClient(
with UpstreamClient(
base_url=u.base_url,
api_prefix=u.api_prefix,
auth_type=u.auth_type,
auth_config=auth_config,
timeout=float(u.timeout_seconds),
)
try:
client.login()
groups = client.get_available_groups(u.groups_endpoint)
raw_rates = client.get_group_rates(u.rate_endpoint)
snapshot = build_snapshot(u.id, u.base_url, u.api_prefix, groups, raw_rates)
except Exception as exc:
u.consecutive_failures = (u.consecutive_failures or 0) + 1
u.last_error = str(exc)
u.last_checked_at = datetime.now(timezone.utc)
db.commit()
return TestResult(success=False, message="检测失败", detail=str(exc))
) as client:
try:
client.login()
groups = client.get_available_groups(u.groups_endpoint)
raw_rates = client.get_group_rates(u.rate_endpoint)
snapshot = build_snapshot(u.id, u.base_url, u.api_prefix, groups, raw_rates)
except Exception as exc:
u.consecutive_failures = (u.consecutive_failures or 0) + 1
u.last_error = str(exc)
u.last_checked_at = datetime.now(timezone.utc)
db.commit()
return TestResult(success=False, message="检测失败", detail=str(exc))
prev_row = (
db.query(UpstreamRateSnapshot)
+4 -2
View File
@@ -188,7 +188,8 @@ def test_website(wid: int, db: Session = Depends(get_db), _=Depends(get_current_
if not row:
raise HTTPException(404, "website not found")
try:
groups = _client(row).get_groups(row.groups_endpoint)
with _client(row) as c:
groups = c.get_groups(row.groups_endpoint)
row.last_status = "healthy"
row.last_error = None
row.last_checked_at = datetime.now(timezone.utc)
@@ -208,7 +209,8 @@ def list_website_groups(wid: int, db: Session = Depends(get_db), _=Depends(get_c
if not row:
raise HTTPException(404, "website not found")
try:
return _client(row).get_groups(row.groups_endpoint)
with _client(row) as c:
return c.get_groups(row.groups_endpoint)
except Exception as exc:
raise HTTPException(502, str(exc))
+33 -36
View File
@@ -37,7 +37,6 @@ def _check_upstream(upstream_id: int) -> None:
settings = get_settings()
# ── Phase 1: upstream check + DB write ──────────────────────────
db: Session = SessionLocal()
client = None
try:
upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first()
if not upstream or not upstream.enabled:
@@ -45,46 +44,45 @@ def _check_upstream(upstream_id: int) -> None:
return
auth_config = json.loads(upstream.auth_config_json or "{}")
client = UpstreamClient(
was_unhealthy = upstream.last_status == "unhealthy"
snapshot = None
changes = None
with UpstreamClient(
base_url=upstream.base_url,
api_prefix=upstream.api_prefix,
auth_type=upstream.auth_type,
auth_config=auth_config,
timeout=float(upstream.timeout_seconds),
)
) as client:
try:
client.login()
groups = client.get_available_groups(upstream.groups_endpoint)
raw_rates = client.get_group_rates(upstream.rate_endpoint)
snapshot = build_snapshot(
upstream.id, upstream.base_url, upstream.api_prefix, groups, raw_rates
)
except Exception as exc:
# failure path
upstream.consecutive_failures = (upstream.consecutive_failures or 0) + 1
upstream.last_error = str(exc)
upstream.last_checked_at = datetime.now(timezone.utc)
threshold = settings.unhealthy_threshold
became_unhealthy = (
upstream.consecutive_failures >= threshold
and upstream.last_status != "unhealthy"
)
if became_unhealthy:
upstream.last_status = "unhealthy"
db.commit()
logger.warning("upstream %s check failed: %s", upstream.name, exc)
# Phase 2: notify unhealthy in a fresh session
if became_unhealthy:
_notify_status(upstream.id, upstream.name, upstream.base_url,
"upstream_unhealthy", str(exc))
return
was_unhealthy = upstream.last_status == "unhealthy"
snapshot = None
changes = None
try:
client.login()
groups = client.get_available_groups(upstream.groups_endpoint)
raw_rates = client.get_group_rates(upstream.rate_endpoint)
snapshot = build_snapshot(
upstream.id, upstream.base_url, upstream.api_prefix, groups, raw_rates
)
except Exception as exc:
# failure path
upstream.consecutive_failures = (upstream.consecutive_failures or 0) + 1
upstream.last_error = str(exc)
upstream.last_checked_at = datetime.now(timezone.utc)
threshold = settings.unhealthy_threshold
became_unhealthy = (
upstream.consecutive_failures >= threshold
and upstream.last_status != "unhealthy"
)
if became_unhealthy:
upstream.last_status = "unhealthy"
db.commit()
logger.warning("upstream %s check failed: %s", upstream.name, exc)
# Phase 2: notify unhealthy in a fresh session
if became_unhealthy:
_notify_status(upstream.id, upstream.name, upstream.base_url,
"upstream_unhealthy", str(exc))
return
# success path
# success path (client auto-closed by `with`)
prev_snapshot_row = (
db.query(UpstreamRateSnapshot)
.filter(UpstreamRateSnapshot.upstream_id == upstream_id)
@@ -116,7 +114,6 @@ def _check_upstream(upstream_id: int) -> None:
)
finally:
client.close()
db.close()
# ── Phase 2: notifications (independent sessions) ──────────────
+7 -1
View File
@@ -83,7 +83,7 @@ def _rate_from_group(group: dict[str, Any]) -> str:
"effective_rate_multiplier", "effectiveRateMultiplier",
"rate_multiplier", "rateMultiplier",
):
r = _decimal_str(group.get(key))
r = decimal_string(group.get(key))
if r:
return r
return ""
@@ -214,6 +214,12 @@ class UpstreamClient:
def close(self) -> None:
self._client.close()
def __enter__(self) -> UpstreamClient:
return self
def __exit__(self, *args: Any) -> None:
self.close()
def _url(self, path: str) -> str:
prefix = f"/{self.api_prefix}" if self.api_prefix else ""
return f"{self.base_url}{prefix}/{path.lstrip('/')}"
+6
View File
@@ -105,6 +105,12 @@ class Sub2ApiWebsiteClient:
def close(self) -> None:
self._client.close()
def __enter__(self) -> Sub2ApiWebsiteClient:
return self
def __exit__(self, *args: Any) -> None:
self.close()
def _url(self, path: str) -> str:
prefix = f"/{self.api_prefix}" if self.api_prefix else ""
return f"{self.base_url}{prefix}/{path.lstrip('/')}"
+5 -5
View File
@@ -118,11 +118,11 @@ def sync_binding(db: Session, binding: WebsiteGroupBinding, write: bool = True)
old_rate = None
if write and website.enabled and website.auto_sync_enabled and binding.enabled:
try:
client = _client_for(website)
groups = client.get_groups(website.groups_endpoint)
target = next((item for item in groups if item.get("id") == binding.target_group_id), None)
old_rate = target.get("rate_multiplier") if target else None
client.update_group_rate(website.group_update_endpoint, binding.target_group_id, target_rate)
with _client_for(website) as client:
groups = client.get_groups(website.groups_endpoint)
target = next((item for item in groups if item.get("id") == binding.target_group_id), None)
old_rate = target.get("rate_multiplier") if target else None
client.update_group_rate(website.group_update_endpoint, binding.target_group_id, target_rate)
website.last_status = "healthy"
website.last_error = None
except Exception as exc:
+18 -18
View File
@@ -6,31 +6,31 @@ from app.services.upstream_client import UpstreamClient
logging.basicConfig(level=logging.DEBUG)
def main():
client = UpstreamClient(
with UpstreamClient(
base_url="http://170.106.100.210:55555",
api_prefix="",
auth_type="bearer",
auth_config={"token": ""}, # We don't have token, but /api/group/ in some new-api may be open, or fail with 401
timeout=10.0,
)
try:
groups = client.get_available_groups("/api/group/")
print("Groups:", groups)
except Exception as e:
print("Groups Error:", e)
) as client:
try:
groups = client.get_available_groups("/api/group/")
print("Groups:", groups)
except Exception as e:
print("Groups Error:", e)
try:
rates = client.get_group_rates("/api/option/?key=GroupRatio")
print("Rates:", rates)
except Exception as e:
print("Rates Error:", e)
try:
rates = client.get_group_rates("/api/option/?key=GroupRatio")
print("Rates:", rates)
except Exception as e:
print("Rates Error:", e)
try:
from app.services.upstream_client import _extract_rates_map, _unwrap_list
print("Unwrapped Groups:", _unwrap_list(groups))
print("Extracted Rates:", _extract_rates_map(rates))
except Exception as e:
pass
try:
from app.services.upstream_client import _extract_rates_map, _unwrap_list
print("Unwrapped Groups:", _unwrap_list(groups))
print("Extracted Rates:", _extract_rates_map(rates))
except Exception as e:
pass
if __name__ == "__main__":
main()