From 2934473770398c4043d5e9b2159f620308745d43 Mon Sep 17 00:00:00 2001 From: SmartUp Developer Date: Sun, 17 May 2026 11:29:51 +0800 Subject: [PATCH] fix: remove stale _decimal_str ref, add context manager to HTTP clients MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - UpstreamClient & Sub2ApiWebsiteClient: add __enter__/__exit__ - Convert all call sites to `with Client(...) as c:` pattern - Remove unused `upstream_name`/`upstream_base_url` locals in scheduler - Fix stale _decimal_str→decimal_string in _rate_from_group --- backend/app/routers/upstreams.py | 62 +++++++++++----------- backend/app/routers/websites.py | 6 ++- backend/app/services/scheduler.py | 69 ++++++++++++------------- backend/app/services/upstream_client.py | 8 ++- backend/app/services/website_client.py | 6 +++ backend/app/services/website_sync.py | 10 ++-- backend/test_upstream.py | 36 ++++++------- 7 files changed, 104 insertions(+), 93 deletions(-) diff --git a/backend/app/routers/upstreams.py b/backend/app/routers/upstreams.py index 114e106..48e2a6c 100644 --- a/backend/app/routers/upstreams.py +++ b/backend/app/routers/upstreams.py @@ -146,29 +146,29 @@ def test_upstream(uid: int, db: Session = Depends(get_db), _=Depends(get_current if not u: raise HTTPException(404, "upstream not found") auth_config = json.loads(u.auth_config_json or "{}") - client = UpstreamClient( + with UpstreamClient( base_url=u.base_url, api_prefix=u.api_prefix, auth_type=u.auth_type, auth_config=auth_config, timeout=float(u.timeout_seconds), - ) - try: - client.login() - groups = client.get_available_groups(u.groups_endpoint) - u.last_status = "healthy" - u.last_error = None - u.last_checked_at = datetime.now(timezone.utc) - u.consecutive_failures = 0 - db.commit() - return TestResult(success=True, message=f"连接成功,获取到 {len(groups)} 个分组") - except Exception as exc: - u.last_status = "unhealthy" - u.last_error = str(exc) - u.last_checked_at = datetime.now(timezone.utc) - u.consecutive_failures = (u.consecutive_failures or 0) + 1 - db.commit() - return TestResult(success=False, message="连接失败", detail=str(exc)) + ) as client: + try: + client.login() + groups = client.get_available_groups(u.groups_endpoint) + u.last_status = "healthy" + u.last_error = None + u.last_checked_at = datetime.now(timezone.utc) + u.consecutive_failures = 0 + db.commit() + return TestResult(success=True, message=f"连接成功,获取到 {len(groups)} 个分组") + except Exception as exc: + u.last_status = "unhealthy" + u.last_error = str(exc) + u.last_checked_at = datetime.now(timezone.utc) + u.consecutive_failures = (u.consecutive_failures or 0) + 1 + db.commit() + return TestResult(success=False, message="连接失败", detail=str(exc)) @router.post("/{uid}/check-now", response_model=TestResult) @@ -177,24 +177,24 @@ def check_now(uid: int, db: Session = Depends(get_db), _=Depends(get_current_use if not u: raise HTTPException(404, "upstream not found") auth_config = json.loads(u.auth_config_json or "{}") - client = UpstreamClient( + with UpstreamClient( base_url=u.base_url, api_prefix=u.api_prefix, auth_type=u.auth_type, auth_config=auth_config, timeout=float(u.timeout_seconds), - ) - try: - client.login() - groups = client.get_available_groups(u.groups_endpoint) - raw_rates = client.get_group_rates(u.rate_endpoint) - snapshot = build_snapshot(u.id, u.base_url, u.api_prefix, groups, raw_rates) - except Exception as exc: - u.consecutive_failures = (u.consecutive_failures or 0) + 1 - u.last_error = str(exc) - u.last_checked_at = datetime.now(timezone.utc) - db.commit() - return TestResult(success=False, message="检测失败", detail=str(exc)) + ) as client: + try: + client.login() + groups = client.get_available_groups(u.groups_endpoint) + raw_rates = client.get_group_rates(u.rate_endpoint) + snapshot = build_snapshot(u.id, u.base_url, u.api_prefix, groups, raw_rates) + except Exception as exc: + u.consecutive_failures = (u.consecutive_failures or 0) + 1 + u.last_error = str(exc) + u.last_checked_at = datetime.now(timezone.utc) + db.commit() + return TestResult(success=False, message="检测失败", detail=str(exc)) prev_row = ( db.query(UpstreamRateSnapshot) diff --git a/backend/app/routers/websites.py b/backend/app/routers/websites.py index f3bade1..c9c1774 100644 --- a/backend/app/routers/websites.py +++ b/backend/app/routers/websites.py @@ -188,7 +188,8 @@ def test_website(wid: int, db: Session = Depends(get_db), _=Depends(get_current_ if not row: raise HTTPException(404, "website not found") try: - groups = _client(row).get_groups(row.groups_endpoint) + with _client(row) as c: + groups = c.get_groups(row.groups_endpoint) row.last_status = "healthy" row.last_error = None row.last_checked_at = datetime.now(timezone.utc) @@ -208,7 +209,8 @@ def list_website_groups(wid: int, db: Session = Depends(get_db), _=Depends(get_c if not row: raise HTTPException(404, "website not found") try: - return _client(row).get_groups(row.groups_endpoint) + with _client(row) as c: + return c.get_groups(row.groups_endpoint) except Exception as exc: raise HTTPException(502, str(exc)) diff --git a/backend/app/services/scheduler.py b/backend/app/services/scheduler.py index bc19dda..f9b43d1 100644 --- a/backend/app/services/scheduler.py +++ b/backend/app/services/scheduler.py @@ -37,7 +37,6 @@ def _check_upstream(upstream_id: int) -> None: settings = get_settings() # ── Phase 1: upstream check + DB write ────────────────────────── db: Session = SessionLocal() - client = None try: upstream = db.query(Upstream).filter(Upstream.id == upstream_id).first() if not upstream or not upstream.enabled: @@ -45,46 +44,45 @@ def _check_upstream(upstream_id: int) -> None: return auth_config = json.loads(upstream.auth_config_json or "{}") - client = UpstreamClient( + was_unhealthy = upstream.last_status == "unhealthy" + snapshot = None + changes = None + + with UpstreamClient( base_url=upstream.base_url, api_prefix=upstream.api_prefix, auth_type=upstream.auth_type, auth_config=auth_config, timeout=float(upstream.timeout_seconds), - ) + ) as client: + try: + client.login() + groups = client.get_available_groups(upstream.groups_endpoint) + raw_rates = client.get_group_rates(upstream.rate_endpoint) + snapshot = build_snapshot( + upstream.id, upstream.base_url, upstream.api_prefix, groups, raw_rates + ) + except Exception as exc: + # failure path + upstream.consecutive_failures = (upstream.consecutive_failures or 0) + 1 + upstream.last_error = str(exc) + upstream.last_checked_at = datetime.now(timezone.utc) + threshold = settings.unhealthy_threshold + became_unhealthy = ( + upstream.consecutive_failures >= threshold + and upstream.last_status != "unhealthy" + ) + if became_unhealthy: + upstream.last_status = "unhealthy" + db.commit() + logger.warning("upstream %s check failed: %s", upstream.name, exc) + # Phase 2: notify unhealthy in a fresh session + if became_unhealthy: + _notify_status(upstream.id, upstream.name, upstream.base_url, + "upstream_unhealthy", str(exc)) + return - was_unhealthy = upstream.last_status == "unhealthy" - snapshot = None - changes = None - - try: - client.login() - groups = client.get_available_groups(upstream.groups_endpoint) - raw_rates = client.get_group_rates(upstream.rate_endpoint) - snapshot = build_snapshot( - upstream.id, upstream.base_url, upstream.api_prefix, groups, raw_rates - ) - except Exception as exc: - # failure path - upstream.consecutive_failures = (upstream.consecutive_failures or 0) + 1 - upstream.last_error = str(exc) - upstream.last_checked_at = datetime.now(timezone.utc) - threshold = settings.unhealthy_threshold - became_unhealthy = ( - upstream.consecutive_failures >= threshold - and upstream.last_status != "unhealthy" - ) - if became_unhealthy: - upstream.last_status = "unhealthy" - db.commit() - logger.warning("upstream %s check failed: %s", upstream.name, exc) - # Phase 2: notify unhealthy in a fresh session - if became_unhealthy: - _notify_status(upstream.id, upstream.name, upstream.base_url, - "upstream_unhealthy", str(exc)) - return - - # success path + # success path (client auto-closed by `with`) prev_snapshot_row = ( db.query(UpstreamRateSnapshot) .filter(UpstreamRateSnapshot.upstream_id == upstream_id) @@ -116,7 +114,6 @@ def _check_upstream(upstream_id: int) -> None: ) finally: - client.close() db.close() # ── Phase 2: notifications (independent sessions) ────────────── diff --git a/backend/app/services/upstream_client.py b/backend/app/services/upstream_client.py index 518edef..0a885ab 100644 --- a/backend/app/services/upstream_client.py +++ b/backend/app/services/upstream_client.py @@ -83,7 +83,7 @@ def _rate_from_group(group: dict[str, Any]) -> str: "effective_rate_multiplier", "effectiveRateMultiplier", "rate_multiplier", "rateMultiplier", ): - r = _decimal_str(group.get(key)) + r = decimal_string(group.get(key)) if r: return r return "" @@ -214,6 +214,12 @@ class UpstreamClient: def close(self) -> None: self._client.close() + def __enter__(self) -> UpstreamClient: + return self + + def __exit__(self, *args: Any) -> None: + self.close() + def _url(self, path: str) -> str: prefix = f"/{self.api_prefix}" if self.api_prefix else "" return f"{self.base_url}{prefix}/{path.lstrip('/')}" diff --git a/backend/app/services/website_client.py b/backend/app/services/website_client.py index 8e28441..f84a701 100644 --- a/backend/app/services/website_client.py +++ b/backend/app/services/website_client.py @@ -105,6 +105,12 @@ class Sub2ApiWebsiteClient: def close(self) -> None: self._client.close() + def __enter__(self) -> Sub2ApiWebsiteClient: + return self + + def __exit__(self, *args: Any) -> None: + self.close() + def _url(self, path: str) -> str: prefix = f"/{self.api_prefix}" if self.api_prefix else "" return f"{self.base_url}{prefix}/{path.lstrip('/')}" diff --git a/backend/app/services/website_sync.py b/backend/app/services/website_sync.py index 3dd7031..dab1152 100644 --- a/backend/app/services/website_sync.py +++ b/backend/app/services/website_sync.py @@ -118,11 +118,11 @@ def sync_binding(db: Session, binding: WebsiteGroupBinding, write: bool = True) old_rate = None if write and website.enabled and website.auto_sync_enabled and binding.enabled: try: - client = _client_for(website) - groups = client.get_groups(website.groups_endpoint) - target = next((item for item in groups if item.get("id") == binding.target_group_id), None) - old_rate = target.get("rate_multiplier") if target else None - client.update_group_rate(website.group_update_endpoint, binding.target_group_id, target_rate) + with _client_for(website) as client: + groups = client.get_groups(website.groups_endpoint) + target = next((item for item in groups if item.get("id") == binding.target_group_id), None) + old_rate = target.get("rate_multiplier") if target else None + client.update_group_rate(website.group_update_endpoint, binding.target_group_id, target_rate) website.last_status = "healthy" website.last_error = None except Exception as exc: diff --git a/backend/test_upstream.py b/backend/test_upstream.py index f166ea3..46ec497 100644 --- a/backend/test_upstream.py +++ b/backend/test_upstream.py @@ -6,31 +6,31 @@ from app.services.upstream_client import UpstreamClient logging.basicConfig(level=logging.DEBUG) def main(): - client = UpstreamClient( + with UpstreamClient( base_url="http://170.106.100.210:55555", api_prefix="", auth_type="bearer", auth_config={"token": ""}, # We don't have token, but /api/group/ in some new-api may be open, or fail with 401 timeout=10.0, - ) - try: - groups = client.get_available_groups("/api/group/") - print("Groups:", groups) - except Exception as e: - print("Groups Error:", e) + ) as client: + try: + groups = client.get_available_groups("/api/group/") + print("Groups:", groups) + except Exception as e: + print("Groups Error:", e) - try: - rates = client.get_group_rates("/api/option/?key=GroupRatio") - print("Rates:", rates) - except Exception as e: - print("Rates Error:", e) + try: + rates = client.get_group_rates("/api/option/?key=GroupRatio") + print("Rates:", rates) + except Exception as e: + print("Rates Error:", e) - try: - from app.services.upstream_client import _extract_rates_map, _unwrap_list - print("Unwrapped Groups:", _unwrap_list(groups)) - print("Extracted Rates:", _extract_rates_map(rates)) - except Exception as e: - pass + try: + from app.services.upstream_client import _extract_rates_map, _unwrap_list + print("Unwrapped Groups:", _unwrap_list(groups)) + print("Extracted Rates:", _extract_rates_map(rates)) + except Exception as e: + pass if __name__ == "__main__": main()