更新get_current_user()函数校验内容

新增管理员校验函数is_admin_user()
This commit is contained in:
Miyamizu-MitsuhaSang 2025-08-16 17:55:04 +08:00
parent 0d46aaa3f0
commit 7ecfa03c7f
1 changed files with 47 additions and 33 deletions

View File

@ -13,6 +13,7 @@ from app.models.base import ReservedWords, User
from settings import SECRET_KEY
redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True)
ALGORITHM = "HS256"
async def validate_username(username: str):
@ -67,48 +68,62 @@ def hash_password(raw_password: str) -> str:
return bcrypt.hashpw(raw_password.encode("utf-8"), salt).decode("utf-8")
@asynccontextmanager
async def redis_pool():
client = redis.Redis(host="localhost", port=6379, decode_responses=True)
yield client
# @asynccontextmanager
# async def redis_pool():
# client = redis.Redis(host="localhost", port=6379, decode_responses=True)
# yield client
async def _extract_bearer_token(request: Request) -> str:
"""
小工具提取 Bearer Token兼容大小写/多空格
:return: token
"""
auth = request.headers.get("Authorization")
if not auth:
raise HTTPException(status_code=401, detail="未登录")
# 兼容 "bearer" / "Bearer" 等写法,并裁剪多余空格
parts = auth.strip().split()
if len(parts) != 2 or parts[0].lower() != "bearer":
raise HTTPException(status_code=401, detail="无效的授权头")
return parts[1] # token 内容
async def _decode_and_load_user(token: str) -> Tuple[User, Dict]:
# 黑名单校验(登出或主动失效)
if await redis_client.get(f"blacklist:{token}") == "true":
raise HTTPException(status_code=401, detail="token 已失效")
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
except ExpiredSignatureError:
raise HTTPException(status_code=401, detail="登陆信息已过期")
except JWTError:
raise HTTPException(status_code=401, detail="无效的令牌")
user_id = payload.get("user_id")
if not user_id:
raise HTTPException(status_code=401, detail="无效 token 载荷")
user = await User.get_or_none(id=user_id)
if not user:
raise HTTPException(status_code=401, detail="用户不存在")
return user, payload
# 从请求头中获取当前用户信息
async def get_current_user(request: Request) -> Tuple[User, Dict]:
# 从 headers 中获取 Authorization 字段
token = request.headers.get("Authorization")
token = await _extract_bearer_token(request)
return await _decode_and_load_user(token)
# 检查 token 是否存在且格式正确Bearer 开头)
if not token or not token.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未登录")
raw_token = token[7:]
# 黑名单校验
if await redis_client.get(f"blacklist:{raw_token}") == "true":
raise HTTPException(status_code=401, detail="token 已失效")
try:
# 去掉 "Bearer " 前缀后解析 JWT
payload = jwt.decode(token[7:], SECRET_KEY, algorithms=["HS256"]) # 自动校验exp
user_id = payload.get("user_id")
except ExpiredSignatureError:
# token 信息中的 exp 已经过期
raise HTTPException(status_code=401, detail="登陆信息已过期")
except JWTError:
# JWT 格式错误或校验失败
raise HTTPException(status_code=401, detail="无效的令牌")
# 从数据库查找对应用户
user = await User.get_or_none(id=user_id)
if not user:
raise HTTPException(status_code=401, detail="用户不存在")
async def is_admin_user(user_payload: Tuple[User, Dict] = Depends(get_current_user)) -> Tuple[User, Dict]:
user, payload=user_payload
if not getattr(user, "is_admin", False):
raise HTTPException(status_code=403, detail="Access denied")
return user, payload
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/logout")
ALGORITHM = "HS256"
async def get_current_user_with_OAuth(token: Annotated[str, Depends(oauth2_scheme)]):
@ -132,4 +147,3 @@ async def get_current_user_with_OAuth(token: Annotated[str, Depends(oauth2_schem
raise HTTPException(status_code=401, detail="token 已过期")
except JWTError:
raise HTTPException(status_code=401, detail="")