Skip to content

Commit

Permalink
handle typing when already responded
Browse files Browse the repository at this point in the history
  • Loading branch information
lmaotrigine committed Oct 13, 2024
1 parent 263916a commit 2ba62c5
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
1 change: 1 addition & 0 deletions cogs/mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -4179,6 +4179,7 @@ async def start_lockdown(
records = []
success, failures = [], []
reason = f'Lockdown request by {ctx.author} (ID: {ctx.author.id})'
await ctx.defer()
async with ctx.typing():
for channel in channels:
overwrite = channel.overwrites_for(default_role)
Expand Down
41 changes: 33 additions & 8 deletions utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from types import TracebackType

from aiohttp import ClientSession
from discord.context_managers import Typing
from discord.ext.commands.context import DeferTyping

from bot import Ayaka

Expand All @@ -28,6 +30,16 @@

T = TypeVar('T')

# async with typing() when already responded should early return


class NoOpTyping:
async def __aenter__(self) -> None:
...

async def __aexit__(self, *_: Any) -> None:
...


# For typing purposes, `Context.db` returns a Protocol type
# that allows us to properly type the return values via narrowing
Expand All @@ -36,25 +48,33 @@


class ConnectionContextManager(Protocol):
async def __aenter__(self) -> asyncpg.Connection: ...
async def __aenter__(self) -> asyncpg.Connection:
...

async def __aexit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
) -> None: ...
) -> None:
...


class DatabaseProtocol(Protocol):
async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str: ...
async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str:
...

async def fetch(self, query: str, *args: Any, timeout: float | None = None) -> list[asyncpg.Record]: ...
async def fetch(self, query: str, *args: Any, timeout: float | None = None) -> list[asyncpg.Record]:
...

async def fetchrow(self, query: str, *args: Any, timeout: float | None = None) -> asyncpg.Record | None: ...
async def fetchrow(self, query: str, *args: Any, timeout: float | None = None) -> asyncpg.Record | None:
...

async def fetchval(self, query: str, *args: Any, timeout: float | None = None) -> Any | None: ...
async def fetchval(self, query: str, *args: Any, timeout: float | None = None) -> Any | None:
...

def acquire(self, *, timeout: float | None = None) -> ConnectionContextManager: ...
def acquire(self, *, timeout: float | None = None) -> ConnectionContextManager:
...

def release(self, connection: asyncpg.Connection) -> None: ...
def release(self, connection: asyncpg.Connection) -> None:
...


class Context(commands.Context['Ayaka']):
Expand Down Expand Up @@ -162,6 +182,11 @@ async def safe_send(self, content: str, *, escape_mentions: bool = True, **kwarg
else:
return await self.send(content)

def typing(self, *, ephemeral: bool = False) -> Typing | DeferTyping | NoOpTyping:
if self.interaction and not self.interaction.response.is_done():
return super().typing(ephemeral=ephemeral)
return NoOpTyping()


class GuildContext(Context):
author: discord.Member
Expand Down

0 comments on commit 2ba62c5

Please sign in to comment.