parent
0d46aaa3f0
commit
7ecfa03c7f
|
|
@ -13,6 +13,7 @@ from app.models.base import ReservedWords, User
|
|||
from settings import SECRET_KEY
|
||||
|
||||
redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True)
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
async def validate_username(username: str):
|
||||
|
|
@ -67,48 +68,62 @@ def hash_password(raw_password: str) -> str:
|
|||
return bcrypt.hashpw(raw_password.encode("utf-8"), salt).decode("utf-8")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def redis_pool():
|
||||
client = redis.Redis(host="localhost", port=6379, decode_responses=True)
|
||||
yield client
|
||||
# @asynccontextmanager
|
||||
# async def redis_pool():
|
||||
# client = redis.Redis(host="localhost", port=6379, decode_responses=True)
|
||||
# yield client
|
||||
|
||||
|
||||
async def _extract_bearer_token(request: Request) -> str:
|
||||
"""
|
||||
小工具:提取 Bearer Token(兼容大小写/多空格)
|
||||
:return: token
|
||||
"""
|
||||
auth = request.headers.get("Authorization")
|
||||
if not auth:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
# 兼容 "bearer" / "Bearer" 等写法,并裁剪多余空格
|
||||
parts = auth.strip().split()
|
||||
if len(parts) != 2 or parts[0].lower() != "bearer":
|
||||
raise HTTPException(status_code=401, detail="无效的授权头")
|
||||
return parts[1] # token 内容
|
||||
|
||||
|
||||
async def _decode_and_load_user(token: str) -> Tuple[User, Dict]:
|
||||
# 黑名单校验(登出或主动失效)
|
||||
if await redis_client.get(f"blacklist:{token}") == "true":
|
||||
raise HTTPException(status_code=401, detail="token 已失效")
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
except ExpiredSignatureError:
|
||||
raise HTTPException(status_code=401, detail="登陆信息已过期")
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=401, detail="无效的令牌")
|
||||
|
||||
user_id = payload.get("user_id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="无效 token 载荷")
|
||||
user = await User.get_or_none(id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="用户不存在")
|
||||
return user, payload
|
||||
|
||||
|
||||
# 从请求头中获取当前用户信息
|
||||
async def get_current_user(request: Request) -> Tuple[User, Dict]:
|
||||
# 从 headers 中获取 Authorization 字段
|
||||
token = request.headers.get("Authorization")
|
||||
token = await _extract_bearer_token(request)
|
||||
return await _decode_and_load_user(token)
|
||||
|
||||
# 检查 token 是否存在且格式正确(Bearer 开头)
|
||||
if not token or not token.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
raw_token = token[7:]
|
||||
|
||||
# 黑名单校验
|
||||
if await redis_client.get(f"blacklist:{raw_token}") == "true":
|
||||
raise HTTPException(status_code=401, detail="token 已失效")
|
||||
|
||||
try:
|
||||
# 去掉 "Bearer " 前缀后解析 JWT
|
||||
payload = jwt.decode(token[7:], SECRET_KEY, algorithms=["HS256"]) # 自动校验exp
|
||||
user_id = payload.get("user_id")
|
||||
except ExpiredSignatureError:
|
||||
# token 信息中的 exp 已经过期
|
||||
raise HTTPException(status_code=401, detail="登陆信息已过期")
|
||||
except JWTError:
|
||||
# JWT 格式错误或校验失败
|
||||
raise HTTPException(status_code=401, detail="无效的令牌")
|
||||
|
||||
# 从数据库查找对应用户
|
||||
user = await User.get_or_none(id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="用户不存在")
|
||||
|
||||
async def is_admin_user(user_payload: Tuple[User, Dict] = Depends(get_current_user)) -> Tuple[User, Dict]:
|
||||
user, payload=user_payload
|
||||
if not getattr(user, "is_admin", False):
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
return user, payload
|
||||
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/logout")
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
async def get_current_user_with_OAuth(token: Annotated[str, Depends(oauth2_scheme)]):
|
||||
|
|
@ -132,4 +147,3 @@ async def get_current_user_with_OAuth(token: Annotated[str, Depends(oauth2_schem
|
|||
raise HTTPException(status_code=401, detail="token 已过期")
|
||||
except JWTError:
|
||||
raise HTTPException(status_code=401, detail="")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue