feat: IP & device rate limits

This commit is contained in:
2026-03-14 01:07:26 +03:00
parent 9b87eb74d7
commit ad711d1daf
7 changed files with 212 additions and 7 deletions
+86 -6
View File
@@ -1,6 +1,8 @@
import logging
import os
import secrets
from hashlib import sha256
import ipaddress
from dataclasses import dataclass
from datetime import timedelta
@@ -33,6 +35,18 @@ class OtpCooldownError(RuntimeError):
self.retry_after_seconds = retry_after_seconds
class OtpIpRateLimitError(RuntimeError):
def __init__(self, retry_after_seconds: int):
super().__init__(_("Too many OTP requests from this IP. Try again later."))
self.retry_after_seconds = retry_after_seconds
class OtpDeviceRateLimitError(RuntimeError):
def __init__(self, retry_after_seconds: int):
super().__init__(_("Too many OTP requests from this device. Try again later."))
self.retry_after_seconds = retry_after_seconds
class BaseOtpProvider:
uses_provider_otp = False
@@ -166,7 +180,73 @@ def generate_code(length: int = 6) -> str:
return "".join(secrets.choice(digits) for _ in range(length))
def create_and_send_otp(phone_number: str, channel: str, purpose: str = OtpPurpose.AUTH) -> OtpSendResult:
def _compute_retry_after_seconds(oldest_recent, now, window_minutes: int, floor_seconds: int = 1) -> int:
if oldest_recent:
retry_after = int((oldest_recent.created_at + timedelta(minutes=window_minutes) - now).total_seconds())
else:
retry_after = floor_seconds
return max(retry_after, floor_seconds)
def normalize_request_ip(raw_ip: str | None) -> str | None:
if not raw_ip:
return None
candidate = raw_ip.split(",")[0].strip()
try:
return str(ipaddress.ip_address(candidate))
except ValueError:
return None
def build_device_signal(device_id: str | None, user_agent: str | None, accept_language: str | None) -> str:
# Prefer explicit device id; fallback to passive headers for coarse signal.
source = (device_id or "").strip()
if not source:
source = f"{(user_agent or '').strip()}|{(accept_language or '').strip()}"
if not source.strip("|"):
return ""
return f"sha256:{sha256(source.encode('utf-8')).hexdigest()[:24]}"
def enforce_phone_auth_request_limits(request_ip: str | None, device_signal: str | None) -> None:
now = timezone.now()
window_minutes = getattr(settings, "PHONE_AUTH_RISK_WINDOW_MINUTES", 15)
window_start = now - timedelta(minutes=window_minutes)
ip_limit = getattr(settings, "PHONE_AUTH_IP_MAX_PER_WINDOW", 20)
device_limit = getattr(settings, "PHONE_AUTH_DEVICE_MAX_PER_WINDOW", 20)
if request_ip and ip_limit > 0:
ip_recent = PhoneOTP.objects.filter(
purpose=OtpPurpose.AUTH,
request_ip=request_ip,
created_at__gte=window_start,
)
if ip_recent.count() >= ip_limit:
oldest_recent = ip_recent.order_by("created_at").first()
raise OtpIpRateLimitError(
retry_after_seconds=_compute_retry_after_seconds(oldest_recent, now, window_minutes)
)
if device_signal and device_limit > 0:
device_recent = PhoneOTP.objects.filter(
purpose=OtpPurpose.AUTH,
device_signal=device_signal,
created_at__gte=window_start,
)
if device_recent.count() >= device_limit:
oldest_recent = device_recent.order_by("created_at").first()
raise OtpDeviceRateLimitError(
retry_after_seconds=_compute_retry_after_seconds(oldest_recent, now, window_minutes)
)
def create_and_send_otp(
phone_number: str,
channel: str,
purpose: str = OtpPurpose.AUTH,
request_ip: str | None = None,
device_signal: str = "",
) -> OtpSendResult:
provider = get_provider()
now = timezone.now()
window_minutes = getattr(settings, "OTP_WINDOW_MINUTES", 15)
@@ -177,11 +257,9 @@ def create_and_send_otp(phone_number: str, channel: str, purpose: str = OtpPurpo
recent_qs = PhoneOTP.objects.filter(phone_number=phone_number, created_at__gte=window_start)
if recent_qs.count() >= max_per_window:
oldest_recent = recent_qs.order_by("created_at").first()
if oldest_recent:
retry_after = int((oldest_recent.created_at + timedelta(minutes=window_minutes) - now).total_seconds())
else:
retry_after = cooldown_seconds
raise OtpRateLimitError(retry_after_seconds=max(retry_after, cooldown_seconds))
raise OtpRateLimitError(
retry_after_seconds=_compute_retry_after_seconds(oldest_recent, now, window_minutes, cooldown_seconds)
)
latest = (
PhoneOTP.objects.filter(phone_number=phone_number, channel=channel)
@@ -210,6 +288,8 @@ def create_and_send_otp(phone_number: str, channel: str, purpose: str = OtpPurpo
provider=settings.OTP_PROVIDER,
code_hash=code_hash,
expires_at=PhoneOTP.expiry_at(),
request_ip=request_ip,
device_signal=device_signal,
)
if provider.uses_provider_otp: