Skip to content

Commit

Permalink
Store rate limiter-related metadata in the database for more resilience (#629)
Browse files Browse the repository at this point in the history

* Store rate limiter-related metadata in the database for more resilience
- This helps maintain state even between server restarts
- Allows you to scale up workers on your service without having to implement sticky routing
* Make the usage exceeded message less abrasive
* Fix rate limiter for specific conversation commands and improve the copy
  • Loading branch information
sabaimran authored Jan 29, 2024
1 parent 71cbe51 commit 4fb8d5c
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 31 deletions.
7 changes: 7 additions & 0 deletions src/khoj/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
aget_or_create_user_by_phone_number,
aget_user_by_phone_number,
aget_user_subscription_state,
delete_user_requests,
get_all_users,
get_or_create_search_models,
)
Expand Down Expand Up @@ -328,3 +329,9 @@ def upload_telemetry():
logger.error(f"📡 Error uploading telemetry: {e}", exc_info=True)
else:
state.telemetry = []


@schedule.repeat(schedule.every(31).minutes)
def delete_old_user_requests():
    """Purge stale user request records on a 31-minute schedule.

    Calls ``delete_user_requests()`` with its default window and logs how
    many rows were removed.
    """
    result = delete_user_requests()
    # Django's QuerySet.delete() returns a (total_deleted, per_model_counts)
    # tuple; log only the total rather than the raw tuple. A plain integer
    # return is also accepted.
    num_deleted = result[0] if isinstance(result, tuple) else result
    logger.info(f"🔥 Deleted {num_deleted} day-old user requests")
5 changes: 5 additions & 0 deletions src/khoj/database/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
Subscription,
TextToImageModelConfig,
UserConversationConfig,
UserRequests,
UserSearchModelConfig,
)
from khoj.search_filter.date_filter import DateFilter
Expand Down Expand Up @@ -284,6 +285,10 @@ def get_user_notion_config(user: KhojUser):
return config


def delete_user_requests(window: timedelta = timedelta(days=1)):
    """Delete user request records older than the given window.

    Args:
        window: Age threshold. Records whose ``created_at`` is at or before
            ``now (UTC) - window`` are deleted. Defaults to one day.

    Returns:
        Whatever the queryset's ``delete()`` call returns.
    """
    cutoff = datetime.now(tz=timezone.utc) - window
    stale_requests = UserRequests.objects.filter(created_at__lte=cutoff)
    return stale_requests.delete()


async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None
deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None
Expand Down
27 changes: 27 additions & 0 deletions src/khoj/database/migrations/0029_userrequests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Generated by Django 4.2.7 on 2024-01-29 08:55

import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models


# Schema migration that creates the "UserRequests" model.
class Migration(migrations.Migration):
# Depends on migration 0028 of the "database" app.
dependencies = [
("database", "0028_khojuser_verified_phone_number"),
]

operations = [
migrations.CreateModel(
name="UserRequests",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
# Timestamp set when the row is first created (auto_now_add).
("created_at", models.DateTimeField(auto_now_add=True)),
# Timestamp refreshed on each save (auto_now).
("updated_at", models.DateTimeField(auto_now=True)),
# Text label of up to 200 characters.
("slug", models.CharField(max_length=200)),
# Owning user; rows are deleted together with the user (CASCADE).
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
]
5 changes: 5 additions & 0 deletions src/khoj/database/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,8 @@ class Meta:
indexes = [
models.Index(fields=["date"]),
]


# Record of a single request made by a user, tagged with a slug.
# NOTE(review): presumably inherits created_at/updated_at from BaseModel (the
# migration for this model lists both fields) — confirm against BaseModel.
class UserRequests(BaseModel):
# User who made the request; rows are deleted together with the user (CASCADE).
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
# Label identifying what the request counts against (max 200 characters).
slug = models.CharField(max_length=200)
22 changes: 16 additions & 6 deletions src/khoj/routers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@
# Initialize Router
api = APIRouter()
logger = logging.getLogger(__name__)
conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100)
conversation_command_rate_limiter = ConversationCommandRateLimiter(
trial_rate_limit=2, subscribed_rate_limit=100, slug="command"
)


@api.get("/search", response_model=List[SearchResponse])
Expand Down Expand Up @@ -301,8 +303,12 @@ async def transcribe(
request: Request,
common: CommonQueryParams,
file: UploadFile = File(...),
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=1, subscribed_requests=10, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=1, subscribed_requests=10, window=60, slug="transcribe_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24, slug="transcribe_day")
),
):
user: KhojUser = request.user.object
audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm"
Expand Down Expand Up @@ -361,16 +367,20 @@ async def chat(
n: Optional[int] = 5,
d: Optional[float] = 0.18,
stream: Optional[bool] = False,
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24)),
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
) -> Response:
user: KhojUser = request.user.object
q = unquote(q)

await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True)

conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
await conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)

q = q.replace(f"/{conversation_command.value}", "").strip()

Expand Down
56 changes: 31 additions & 25 deletions src/khoj/routers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from datetime import datetime, timedelta, timezone
from functools import partial
from time import time
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
Expand All @@ -19,6 +19,7 @@
KhojUser,
Subscription,
TextToImageModelConfig,
UserRequests,
)
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.chat_model import (
Expand Down Expand Up @@ -336,11 +337,11 @@ async def text_to_image(message: str, conversation_log: dict) -> Tuple[Optional[


class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int):
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
self.requests = requests
self.subscribed_requests = subscribed_requests
self.window = window
self.cache: dict[str, list[float]] = defaultdict(list)
self.slug = slug

def __call__(self, request: Request):
# Rate limiting is disabled if user unauthenticated.
Expand All @@ -350,31 +351,32 @@ def __call__(self, request: Request):

user: KhojUser = request.user.object
subscribed = has_required_scope(request, ["premium"])
user_requests = self.cache[user.uuid]

# Remove requests outside of the time window
cutoff = time() - self.window
while user_requests and user_requests[0] < cutoff:
user_requests.pop(0)
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window)
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()

# Check if the user has exceeded the rate limit
if subscribed and len(user_requests) >= self.subscribed_requests:
raise HTTPException(status_code=429, detail="Too Many Requests")
if not subscribed and len(user_requests) >= self.requests:
raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")
if subscribed and count_requests >= self.subscribed_requests:
raise HTTPException(status_code=429, detail="Slow down! Too Many Requests")
if not subscribed and count_requests >= self.requests:
raise HTTPException(
status_code=429,
detail="We're glad you're enjoying Khoj! You've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your rate limit via [your settings](https://app.khoj.dev/config).",
)

# Add the current request to the cache
user_requests.append(time())
UserRequests.objects.create(user=user, slug=self.slug)


class ConversationCommandRateLimiter:
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int):
self.cache: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
self.slug = slug
self.trial_rate_limit = trial_rate_limit
self.subscribed_rate_limit = subscribed_rate_limit
self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image]

def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
async def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
if state.billing_enabled is False:
return

Expand All @@ -385,19 +387,23 @@ def update_and_check_if_valid(self, request: Request, conversation_command: Conv
return

user: KhojUser = request.user.object
user_cache = self.cache[user.uuid]
subscribed = has_required_scope(request, ["premium"])
user_cache[conversation_command].append(time())

# Remove requests outside of the 24-hr time window
cutoff = time() - 60 * 60 * 24
while user_cache[conversation_command] and user_cache[conversation_command][0] < cutoff:
user_cache[conversation_command].pop(0)

if subscribed and len(user_cache[conversation_command]) > self.subscribed_rate_limit:
raise HTTPException(status_code=429, detail="Too Many Requests")
if not subscribed and len(user_cache[conversation_command]) > self.trial_rate_limit:
raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=60 * 60 * 24)
command_slug = f"{self.slug}_{conversation_command.value}"
count_requests = await UserRequests.objects.filter(
user=user, created_at__gte=cutoff, slug=command_slug
).acount()

if subscribed and count_requests >= self.subscribed_rate_limit:
raise HTTPException(status_code=429, detail="Slow down! Too Many Requests")
if not subscribed and count_requests >= self.trial_rate_limit:
raise HTTPException(
status_code=429,
detail=f"We're glad you're enjoying Khoj! You've exceeded your `/{conversation_command.value}` command usage limit for today. You can increase your rate limit via [your settings](https://app.khoj.dev/config).",
)
await UserRequests.objects.acreate(user=user, slug=command_slug)
return


Expand Down

0 comments on commit 4fb8d5c

Please sign in to comment.