parent
0d46aaa3f0
commit
7ecfa03c7f
|
|
@ -13,6 +13,7 @@ from app.models.base import ReservedWords, User
|
||||||
from settings import SECRET_KEY
|
from settings import SECRET_KEY
|
||||||
|
|
||||||
redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True)
|
redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True)
|
||||||
|
ALGORITHM = "HS256"
|
||||||
|
|
||||||
|
|
||||||
async def validate_username(username: str):
|
async def validate_username(username: str):
|
||||||
|
|
@ -67,48 +68,62 @@ def hash_password(raw_password: str) -> str:
|
||||||
return bcrypt.hashpw(raw_password.encode("utf-8"), salt).decode("utf-8")
|
return bcrypt.hashpw(raw_password.encode("utf-8"), salt).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
# @asynccontextmanager
|
||||||
async def redis_pool():
|
# async def redis_pool():
|
||||||
client = redis.Redis(host="localhost", port=6379, decode_responses=True)
|
# client = redis.Redis(host="localhost", port=6379, decode_responses=True)
|
||||||
yield client
|
# yield client
|
||||||
|
|
||||||
|
|
||||||
|
async def _extract_bearer_token(request: Request) -> str:
|
||||||
|
"""
|
||||||
|
小工具:提取 Bearer Token(兼容大小写/多空格)
|
||||||
|
:return: token
|
||||||
|
"""
|
||||||
|
auth = request.headers.get("Authorization")
|
||||||
|
if not auth:
|
||||||
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
# 兼容 "bearer" / "Bearer" 等写法,并裁剪多余空格
|
||||||
|
parts = auth.strip().split()
|
||||||
|
if len(parts) != 2 or parts[0].lower() != "bearer":
|
||||||
|
raise HTTPException(status_code=401, detail="无效的授权头")
|
||||||
|
return parts[1] # token 内容
|
||||||
|
|
||||||
|
|
||||||
|
async def _decode_and_load_user(token: str) -> Tuple[User, Dict]:
|
||||||
|
# 黑名单校验(登出或主动失效)
|
||||||
|
if await redis_client.get(f"blacklist:{token}") == "true":
|
||||||
|
raise HTTPException(status_code=401, detail="token 已失效")
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
except ExpiredSignatureError:
|
||||||
|
raise HTTPException(status_code=401, detail="登陆信息已过期")
|
||||||
|
except JWTError:
|
||||||
|
raise HTTPException(status_code=401, detail="无效的令牌")
|
||||||
|
|
||||||
|
user_id = payload.get("user_id")
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(status_code=401, detail="无效 token 载荷")
|
||||||
|
user = await User.get_or_none(id=user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=401, detail="用户不存在")
|
||||||
|
return user, payload
|
||||||
|
|
||||||
|
|
||||||
# 从请求头中获取当前用户信息
|
# 从请求头中获取当前用户信息
|
||||||
async def get_current_user(request: Request) -> Tuple[User, Dict]:
|
async def get_current_user(request: Request) -> Tuple[User, Dict]:
|
||||||
# 从 headers 中获取 Authorization 字段
|
token = await _extract_bearer_token(request)
|
||||||
token = request.headers.get("Authorization")
|
return await _decode_and_load_user(token)
|
||||||
|
|
||||||
# 检查 token 是否存在且格式正确(Bearer 开头)
|
|
||||||
if not token or not token.startswith("Bearer "):
|
|
||||||
raise HTTPException(status_code=401, detail="未登录")
|
|
||||||
|
|
||||||
raw_token = token[7:]
|
|
||||||
|
|
||||||
# 黑名单校验
|
|
||||||
if await redis_client.get(f"blacklist:{raw_token}") == "true":
|
|
||||||
raise HTTPException(status_code=401, detail="token 已失效")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 去掉 "Bearer " 前缀后解析 JWT
|
|
||||||
payload = jwt.decode(token[7:], SECRET_KEY, algorithms=["HS256"]) # 自动校验exp
|
|
||||||
user_id = payload.get("user_id")
|
|
||||||
except ExpiredSignatureError:
|
|
||||||
# token 信息中的 exp 已经过期
|
|
||||||
raise HTTPException(status_code=401, detail="登陆信息已过期")
|
|
||||||
except JWTError:
|
|
||||||
# JWT 格式错误或校验失败
|
|
||||||
raise HTTPException(status_code=401, detail="无效的令牌")
|
|
||||||
|
|
||||||
# 从数据库查找对应用户
|
|
||||||
user = await User.get_or_none(id=user_id)
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(status_code=401, detail="用户不存在")
|
|
||||||
|
|
||||||
|
async def is_admin_user(user_payload: Tuple[User, Dict] = Depends(get_current_user)) -> Tuple[User, Dict]:
|
||||||
|
user, payload=user_payload
|
||||||
|
if not getattr(user, "is_admin", False):
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
return user, payload
|
return user, payload
|
||||||
|
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/logout")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/logout")
|
||||||
ALGORITHM = "HS256"
|
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user_with_OAuth(token: Annotated[str, Depends(oauth2_scheme)]):
|
async def get_current_user_with_OAuth(token: Annotated[str, Depends(oauth2_scheme)]):
|
||||||
|
|
@ -132,4 +147,3 @@ async def get_current_user_with_OAuth(token: Annotated[str, Depends(oauth2_schem
|
||||||
raise HTTPException(status_code=401, detail="token 已过期")
|
raise HTTPException(status_code=401, detail="token 已过期")
|
||||||
except JWTError:
|
except JWTError:
|
||||||
raise HTTPException(status_code=401, detail="")
|
raise HTTPException(status_code=401, detail="")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue