diff --git a/app/utils/security.py b/app/utils/security.py index fb206d8..1671827 100644 --- a/app/utils/security.py +++ b/app/utils/security.py @@ -13,6 +13,7 @@ from app.models.base import ReservedWords, User from settings import SECRET_KEY redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True) +ALGORITHM = "HS256" async def validate_username(username: str): @@ -67,48 +68,62 @@ def hash_password(raw_password: str) -> str: return bcrypt.hashpw(raw_password.encode("utf-8"), salt).decode("utf-8") -@asynccontextmanager -async def redis_pool(): - client = redis.Redis(host="localhost", port=6379, decode_responses=True) - yield client +# @asynccontextmanager +# async def redis_pool(): +# client = redis.Redis(host="localhost", port=6379, decode_responses=True) +# yield client + + +async def _extract_bearer_token(request: Request) -> str: + """ + 小工具:提取 Bearer Token(兼容大小写/多空格) + :return: token + """ + auth = request.headers.get("Authorization") + if not auth: + raise HTTPException(status_code=401, detail="未登录") + # 兼容 "bearer" / "Bearer" 等写法,并裁剪多余空格 + parts = auth.strip().split() + if len(parts) != 2 or parts[0].lower() != "bearer": + raise HTTPException(status_code=401, detail="无效的授权头") + return parts[1] # token 内容 + + +async def _decode_and_load_user(token: str) -> Tuple[User, Dict]: + # 黑名单校验(登出或主动失效) + if await redis_client.get(f"blacklist:{token}") == "true": + raise HTTPException(status_code=401, detail="token 已失效") + + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + except ExpiredSignatureError: + raise HTTPException(status_code=401, detail="登陆信息已过期") + except JWTError: + raise HTTPException(status_code=401, detail="无效的令牌") + + user_id = payload.get("user_id") + if not user_id: + raise HTTPException(status_code=401, detail="无效 token 载荷") + user = await User.get_or_none(id=user_id) + if not user: + raise HTTPException(status_code=401, detail="用户不存在") + return user, payload # 从请求头中获取当前用户信息 async def get_current_user(request: Request) -> Tuple[User, Dict]: - # 从 headers 中获取 Authorization 字段 - token = request.headers.get("Authorization") + token = await _extract_bearer_token(request) + return await _decode_and_load_user(token) - # 检查 token 是否存在且格式正确(Bearer 开头) - if not token or not token.startswith("Bearer "): - raise HTTPException(status_code=401, detail="未登录") - - raw_token = token[7:] - - # 黑名单校验 - if await redis_client.get(f"blacklist:{raw_token}") == "true": - raise HTTPException(status_code=401, detail="token 已失效") - - try: - # 去掉 "Bearer " 前缀后解析 JWT - payload = jwt.decode(token[7:], SECRET_KEY, algorithms=["HS256"]) # 自动校验exp - user_id = payload.get("user_id") - except ExpiredSignatureError: - # token 信息中的 exp 已经过期 - raise HTTPException(status_code=401, detail="登陆信息已过期") - except JWTError: - # JWT 格式错误或校验失败 - raise HTTPException(status_code=401, detail="无效的令牌") - - # 从数据库查找对应用户 - user = await User.get_or_none(id=user_id) - if not user: - raise HTTPException(status_code=401, detail="用户不存在") +async def is_admin_user(user_payload: Tuple[User, Dict] = Depends(get_current_user)) -> Tuple[User, Dict]: + user, payload=user_payload + if not getattr(user, "is_admin", False): + raise HTTPException(status_code=403, detail="Access denied") return user, payload oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/logout") -ALGORITHM = "HS256" async def get_current_user_with_OAuth(token: Annotated[str, Depends(oauth2_scheme)]): @@ -132,4 +147,3 @@ async def get_current_user_with_OAuth(token: Annotated[str, Depends(oauth2_schem raise HTTPException(status_code=401, detail="token 已过期") except JWTError: raise HTTPException(status_code=401, detail="") -