Files
SmartUp/backend/app/database.py
T
liumangmang 6044b00685 feat: 上游 Key 唯一化、分组导入跳过、账号导入平台识别&远端校验&base_url 注入
- 上游 Key 命名改为 {prefix}-{upstream.id}-{safe_group_name}-{group_id}
- 唯一约束 (upstream_id, group_id, managed_prefix) 加 managed_prefix 列
- 上游检测成功时同步 Key 状态,远端已删/分组已删自动清理
- 重复分组导入跳过,目标网站已存在同名分组返回 exists
- 账号导入平台自动识别(auto/manual 模式)
- 全选可导入 Key 按钮 + 目标分组自动匹配
- 导入幂等:已导入过的 Key 校验远端账号,不存在则重建
- 新增同步接口 POST /sync-imported-upstream-keys
- account_exists() 通过拉取账号列表判断,避免 404 误判
- credentials.base_url 注入来源上游地址,避免 401
- 前端导入弹窗自动同步+刷新按钮+并发/优先级设置
- 新增 12 个测试覆盖同步、幂等、远端删除、校验失败路径
2026-05-21 01:16:39 +08:00

147 lines
7.1 KiB
Python

from sqlalchemy import create_engine, inspect, text
from sqlalchemy.orm import sessionmaker, DeclarativeBase
from app.config import get_settings
settings = get_settings()

# ``check_same_thread`` is a keyword of the sqlite3 driver's connect() call.
# It is only passed when the configured URL targets SQLite, so that pointing
# ``database_url`` at another backend does not hand its driver an argument it
# does not know about.
_connect_args = (
    {"check_same_thread": False}
    if str(settings.database_url).startswith("sqlite")
    else {}
)

# Engine shared by the session factory and by the migration helpers below.
engine = create_engine(
    settings.database_url,
    connect_args=_connect_args,
)

# Session factory bound to the engine; sessions neither autocommit nor autoflush.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
class Base(DeclarativeBase):
    """Declarative base whose metadata ``init_db`` uses to create the tables."""
def get_db():
    """Yield a session created by ``SessionLocal`` and close it afterwards.

    This is a generator: the session is handed to the consumer by ``yield``
    and is closed when the generator finishes, is closed, or has an exception
    thrown into it.
    """
    # Leaving the ``with`` block closes the session, including on error.
    with SessionLocal() as session:
        yield session
def init_db():
    """Create all tables, then run the in-place schema migrations."""
    # The model modules are imported purely for their side effect of
    # registering their tables with SQLAlchemy before create_all() runs.
    from app.models import (  # noqa: F401
        admin_user,
        upstream,
        snapshot,
        webhook_config,
        notification_log,
        custom_page,
        website,
        revoked_token,
        upstream_key,
    )

    Base.metadata.create_all(bind=engine)

    # Migrations run after create_all(), in this fixed order.
    migrations = (
        _migrate_custom_pages,
        _migrate_upstreams,
        _migrate_upstream_generated_keys,
    )
    for run_migration in migrations:
        run_migration()
def _migrate_custom_pages():
    """Bring an existing ``custom_pages`` table up to date without Alembic.

    Does nothing when the table does not exist. Otherwise each missing column
    is added with ``ALTER TABLE``; two of the columns also run a one-off data
    backfill immediately after being added. Everything executes inside a
    single ``engine.begin()`` transaction block.
    """
    schema = inspect(engine)
    if "custom_pages" not in schema.get_table_names():
        return
    existing = {column["name"] for column in schema.get_columns("custom_pages")}

    # Runs once, right after login_autofill_backfilled_at is introduced:
    # enables autofill for rows that already have non-blank credentials.
    autofill_backfill_sql = (
        "UPDATE custom_pages "
        "SET login_autofill_enabled = 1, login_autofill_backfilled_at = CURRENT_TIMESTAMP "
        "WHERE login_autofill_enabled = 0 "
        "AND NULLIF(TRIM(login_username), '') IS NOT NULL "
        "AND NULLIF(TRIM(login_password), '') IS NOT NULL"
    )

    # (column, statements executed in order when that column is missing)
    steps = (
        (
            "access_mode",
            (
                "ALTER TABLE custom_pages ADD COLUMN access_mode VARCHAR(32) NOT NULL DEFAULT 'direct'",
                # Derive the new column from the older use_proxy flag.
                "UPDATE custom_pages SET access_mode = CASE WHEN use_proxy = 1 THEN 'proxy' ELSE 'direct' END",
            ),
        ),
        (
            "login_username",
            ("ALTER TABLE custom_pages ADD COLUMN login_username VARCHAR(255)",),
        ),
        (
            "login_password",
            ("ALTER TABLE custom_pages ADD COLUMN login_password TEXT",),
        ),
        (
            "login_username_selector",
            ("ALTER TABLE custom_pages ADD COLUMN login_username_selector VARCHAR(512)",),
        ),
        (
            "login_password_selector",
            ("ALTER TABLE custom_pages ADD COLUMN login_password_selector VARCHAR(512)",),
        ),
        (
            "login_submit_selector",
            ("ALTER TABLE custom_pages ADD COLUMN login_submit_selector VARCHAR(512)",),
        ),
        (
            "login_autofill_enabled",
            ("ALTER TABLE custom_pages ADD COLUMN login_autofill_enabled BOOLEAN NOT NULL DEFAULT 0",),
        ),
        (
            "login_autofill_backfilled_at",
            (
                "ALTER TABLE custom_pages ADD COLUMN login_autofill_backfilled_at DATETIME",
                autofill_backfill_sql,
            ),
        ),
        (
            "linked_upstream_id",
            ("ALTER TABLE custom_pages ADD COLUMN linked_upstream_id INTEGER",),
        ),
    )

    with engine.begin() as conn:
        for column_name, statements in steps:
            if column_name in existing:
                continue
            for sql in statements:
                conn.execute(text(sql))
def _migrate_upstreams():
    """Add any missing balance-related columns to an existing ``upstreams`` table.

    Does nothing when the table does not exist. All ``ALTER TABLE`` statements
    run inside a single ``engine.begin()`` transaction block.
    """
    schema = inspect(engine)
    if "upstreams" not in schema.get_table_names():
        return
    existing = {column["name"] for column in schema.get_columns("upstreams")}

    # Column name -> statement that adds it; applied in insertion order.
    add_column_sql = {
        "balance": "ALTER TABLE upstreams ADD COLUMN balance FLOAT",
        "balance_updated_at": "ALTER TABLE upstreams ADD COLUMN balance_updated_at DATETIME",
        "balance_endpoint": "ALTER TABLE upstreams ADD COLUMN balance_endpoint VARCHAR(256) NOT NULL DEFAULT ''",
        "balance_response_path": "ALTER TABLE upstreams ADD COLUMN balance_response_path VARCHAR(256) NOT NULL DEFAULT ''",
        "balance_divisor": "ALTER TABLE upstreams ADD COLUMN balance_divisor FLOAT NOT NULL DEFAULT 1.0",
    }

    with engine.begin() as conn:
        for column_name, sql in add_column_sql.items():
            if column_name not in existing:
                conn.execute(text(sql))
def _migrate_upstream_generated_keys():
    """Apply SQLite-safe migrations to the generated upstream keys table.

    Does nothing when ``upstream_generated_keys`` does not exist. Otherwise,
    in order:

    1. Add any missing columns; ``updated_at`` is initialised from
       ``created_at`` when it is first added.
    2. Backfill ``managed_prefix = 'SmartUp'`` on rows that have no prefix yet
       and whose ``key_name`` matches ``'SmartUp-%'``.
    3. Delete duplicate managed rows, keeping the highest ``id`` for each
       ``(upstream_id, group_id, managed_prefix)``.
    4. Create the two unique indexes. Each index is created in its own
       transaction so that a failure on one does not prevent the other; a
       failure is logged with its traceback and is not re-raised.
    """
    import logging

    logger = logging.getLogger(__name__)

    inspector = inspect(engine)
    if "upstream_generated_keys" not in inspector.get_table_names():
        return
    columns = {col["name"] for col in inspector.get_columns("upstream_generated_keys")}

    with engine.begin() as conn:
        if "imported_website_id" not in columns:
            conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN imported_website_id INTEGER"))
        if "imported_account_id" not in columns:
            conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN imported_account_id VARCHAR(255)"))
        if "imported_at" not in columns:
            conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN imported_at DATETIME"))
        if "updated_at" not in columns:
            conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN updated_at DATETIME"))
            conn.execute(text("UPDATE upstream_generated_keys SET updated_at = created_at WHERE updated_at IS NULL"))
        if "managed_prefix" not in columns:
            conn.execute(text("ALTER TABLE upstream_generated_keys ADD COLUMN managed_prefix VARCHAR(64)"))

    # --- Legacy data migration: backfill managed_prefix, then drop duplicates ---
    # Both statements are idempotent, so they run on every start-up.
    with engine.begin() as conn:
        # 1. Backfill: legacy rows whose key_name starts with "SmartUp-" get
        #    managed_prefix = 'SmartUp'.
        conn.execute(text(
            "UPDATE upstream_generated_keys SET managed_prefix = 'SmartUp' "
            "WHERE managed_prefix IS NULL AND key_name LIKE 'SmartUp-%'"
        ))
        # 2. Clean-up: keep only the row with the highest id for each
        #    (upstream_id, group_id, managed_prefix). SQLite accepts a subquery
        #    on the same table inside DELETE, so a single statement suffices.
        conn.execute(text("""
            DELETE FROM upstream_generated_keys
            WHERE managed_prefix IS NOT NULL
            AND id NOT IN (
                SELECT MAX(id) FROM upstream_generated_keys
                WHERE managed_prefix IS NOT NULL
                GROUP BY upstream_id, group_id, managed_prefix
            )
        """))

    # --- Create the unique indexes ---
    index_statements = (
        "CREATE UNIQUE INDEX IF NOT EXISTS uq_upstream_group_key "
        "ON upstream_generated_keys(upstream_id, group_id, key_name)",
        "CREATE UNIQUE INDEX IF NOT EXISTS uq_upstream_group_managed "
        "ON upstream_generated_keys(upstream_id, group_id, managed_prefix) "
        "WHERE managed_prefix IS NOT NULL",
    )
    for ddl in index_statements:
        try:
            with engine.begin() as conn:
                conn.execute(text(ddl))
        except Exception:
            # Deliberately best-effort: index creation must not block start-up.
            # The traceback is logged so the failure can be diagnosed.
            logger.warning(
                "could not create unique indexes on upstream_generated_keys (non-fatal)",
                exc_info=True,
            )