70 lines
2.0 KiB
Python
70 lines
2.0 KiB
Python
import os
|
||
import uuid
|
||
from datetime import datetime, timezone, timedelta
|
||
from typing import Tuple
|
||
|
||
from dotenv import load_dotenv
|
||
from fastapi import HTTPException
|
||
from jose import jwt, ExpiredSignatureError, JWTError
|
||
from redis.asyncio import Redis
|
||
|
||
# Populate os.environ from a local .env file (python-dotenv), if one is present,
# before the configuration lookups below.
load_dotenv()

# HMAC secret used to sign and verify password-reset JWTs.
# NOTE(review): this is None when the variable is unset; encoding/decoding will then fail.
RESET_SECRET_KEY = os.environ.get("RESET_SECRET_KEY")
# Signing algorithm shared by create_reset_token and verify_and_consume_reset_token.
ALGORITHM = "HS256"
|
||
|
||
|
||
class ResetTokenError(HTTPException):
    """HTTP 400 (Bad Request) error for a rejected password-reset token.

    Args:
        message: Human-readable reason, exposed as the exception's ``detail``.
    """

    def __init__(self, message: str):
        # Fixed status code; only the detail text varies per failure.
        super().__init__(detail=message, status_code=400)
|
||
|
||
|
||
def create_reset_token(user_id: int, expire_seconds: int = 300) -> Tuple[str, str]:
    """Create a password-reset JWT and its unique ``jti`` identifier.

    Args:
        user_id: Id of the user the token is issued for; stored as the ``sub`` claim.
        expire_seconds: Token lifetime in seconds, used for the ``exp`` claim.
            Pass the same value to ``save_reset_jti`` so the JWT and its Redis
            record expire together.

    Returns:
        A ``(token, jti)`` tuple. The ``jti`` is meant to be stored with
        ``save_reset_jti`` so the token can be redeemed exactly once.
    """
    jti = uuid.uuid4().hex
    issued_at = datetime.now(timezone.utc)
    payload = {
        # JWT "sub" is a string claim, so the integer id is stringified.
        "sub": str(user_id),
        # Distinguishes reset tokens from other tokens signed with the same key.
        "purpose": "reset_pw",
        # Honour expire_seconds (previously hard-coded to 2 hours and ignored).
        "exp": issued_at + timedelta(seconds=expire_seconds),
        "jti": jti,
    }

    token = jwt.encode(payload, RESET_SECRET_KEY, algorithm=ALGORITHM)
    return token, jti
|
||
|
||
|
||
async def save_reset_jti(redis: Redis, user_id: int, jti: str, expire_seconds: int = 300):
    """Record *jti* as the active reset-token id for *user_id* in Redis.

    The value is written under the key ``reset:<user_id>`` with a TTL of
    *expire_seconds* seconds via SETEX, replacing any jti previously stored
    for that user.
    """
    record_key = f"reset:{user_id}"
    await redis.setex(record_key, expire_seconds, jti)
|
||
|
||
|
||
async def verify_and_consume_reset_token(redis: Redis, token: str) -> int | None:
    """Validate a password-reset token and consume it so it can be used only once.

    Steps:
      1. Verify the JWT signature and its ``exp`` claim.
      2. Check that ``purpose`` is ``"reset_pw"`` and ``sub`` is an integer user id.
      3. Fetch-and-delete the jti stored in Redis for that user and compare it
         with the token's ``jti``.

    Args:
        redis: Async Redis client holding ``reset:<user_id>`` -> jti records.
        token: The encoded reset JWT supplied by the client.

    Returns:
        The user id on success, or ``None`` when the token's claims are not those
        of a reset token (wrong ``purpose``, or missing/non-numeric ``sub``).

    Raises:
        ExpiredSignatureError: The token's ``exp`` time has passed.
        JWTError: The token is malformed or its signature/claims are invalid.
        ResetTokenError: No matching jti is stored (already used, superseded,
            or its Redis record expired).
    """
    # 1. Verify signature AND expiry. Expiry checking used to be disabled
    #    (verify_exp=False), which contradicted this contract and left the
    #    ExpiredSignatureError path dead. jose exceptions propagate unchanged
    #    instead of being re-wrapped.
    payload = jwt.decode(token, RESET_SECRET_KEY, algorithms=[ALGORITHM])

    # 2. Only accept tokens minted for password reset.
    if payload.get("purpose") != "reset_pw":
        return None
    try:
        user_id = int(payload.get("sub"))
    except (TypeError, ValueError):
        # Missing or non-numeric subject: not a reset token we issued.
        return None
    jti = payload.get("jti")

    # 3. GETDEL reads and removes the record in one command, so the same token
    #    cannot be redeemed twice.
    stored = await redis.getdel(f"reset:{user_id}")
    if isinstance(stored, bytes):
        # Clients created without decode_responses=True return bytes, which
        # would never compare equal to the str jti.
        stored = stored.decode()
    if stored is None or stored != jti:
        raise ResetTokenError("Token 非法或已过期")

    return user_id
|