Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing cancelled async futures #2666

Merged
merged 18 commits into from
Mar 29, 2023
2 changes: 2 additions & 0 deletions .github/workflows/integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ jobs:
timeout-minutes: 30
strategy:
max-parallel: 15
fail-fast: false
chayim marked this conversation as resolved.
Show resolved Hide resolved
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9']
test-type: ['standalone', 'cluster']
Expand Down Expand Up @@ -108,6 +109,7 @@ jobs:
name: Install package from commit hash
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9']
steps:
Expand Down
104 changes: 69 additions & 35 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,28 +500,37 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
):
raise error

# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
await self.initialize()
pool = self.connection_pool
command_name = args[0]
conn = self.connection or await pool.get_connection(command_name, **options)

if self.single_connection_client:
await self._single_conn_lock.acquire()
async def _try_send_command_parse_response(self, conn, *args, **options):
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
conn, args[0], *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
except asyncio.CancelledError:
await conn.disconnect(nowait=True)
raise
finally:
if self.single_connection_client:
self._single_conn_lock.release()
if not self.connection:
await pool.release(conn)
await self.connection_pool.release(conn)

# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
await self.initialize()
pool = self.connection_pool
command_name = args[0]
conn = self.connection or await pool.get_connection(command_name, **options)

if self.single_connection_client:
await self._single_conn_lock.acquire()

return await asyncio.shield(
self._try_send_command_parse_response(conn, *args, **options)
)

async def parse_response(
self, connection: Connection, command_name: Union[str, bytes], **options
Expand Down Expand Up @@ -760,10 +769,18 @@ async def _disconnect_raise_connect(self, conn, error):
is not a TimeoutError. Otherwise, try to reconnect
"""
await conn.disconnect()

if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
raise error
await conn.connect()

async def _try_execute(self, conn, command, *arg, **kwargs):
try:
return await command(*arg, **kwargs)
except asyncio.CancelledError:
await conn.disconnect()
raise

async def _execute(self, conn, command, *args, **kwargs):
"""
Connect manually upon disconnection. If the Redis server is down,
Expand All @@ -772,9 +789,11 @@ async def _execute(self, conn, command, *args, **kwargs):
called by the # connection to resubscribe us to any channels and
patterns we were previously listening to
"""
return await conn.retry.call_with_retry(
lambda: command(*args, **kwargs),
lambda error: self._disconnect_raise_connect(conn, error),
return await asyncio.shield(
conn.retry.call_with_retry(
lambda: self._try_execute(conn, command, *args, **kwargs),
lambda error: self._disconnect_raise_connect(conn, error),
)
)

async def parse_response(self, block: bool = True, timeout: float = 0):
Expand Down Expand Up @@ -1176,6 +1195,18 @@ async def _disconnect_reset_raise(self, conn, error):
await self.reset()
raise

async def _try_send_command_parse_response(self, conn, *args, **options):
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, args[0], *args, **options
),
lambda error: self._disconnect_reset_raise(conn, error),
)
except asyncio.CancelledError:
await conn.disconnect()
raise

async def immediate_execute_command(self, *args, **options):
"""
Execute a command immediately, but don't auto-retry on a
Expand All @@ -1191,13 +1222,13 @@ async def immediate_execute_command(self, *args, **options):
command_name, self.shard_hint
)
self.connection = conn

return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_reset_raise(conn, error),
)
try:
return await asyncio.shield(
self._try_send_command_parse_response(conn, *args, **options)
)
except asyncio.CancelledError:
chayim marked this conversation as resolved.
Show resolved Hide resolved
await conn.disconnect()
raise

def pipeline_execute_command(self, *args, **options):
"""
Expand Down Expand Up @@ -1364,6 +1395,19 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
await self.reset()
raise

async def _try_execute(self, conn, execute, stack, raise_on_error):
try:
return await conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_reset(conn, error),
)
except asyncio.CancelledError:
# not supposed to be possible, yet here we are
await conn.disconnect(nowait=True)
raise
finally:
await self.reset()

async def execute(self, raise_on_error: bool = True):
"""Execute all the commands in the current pipeline"""
stack = self.command_stack
Expand All @@ -1384,19 +1428,9 @@ async def execute(self, raise_on_error: bool = True):
self.connection = conn
conn = cast(Connection, conn)

try:
return await asyncio.shield(
conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_reset(conn, error),
)
)
except asyncio.CancelledError:
# not supposed to be possible, yet here we are
await conn.disconnect(nowait=True)
raise
finally:
await self.reset()
return await asyncio.shield(
self._try_execute(conn, execute, stack, raise_on_error)
)

async def discard(self):
"""Flushes all previously queued commands
Expand Down
21 changes: 14 additions & 7 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,19 @@ async def _parse_and_release(self, connection, *args, **kwargs):
finally:
self._free.append(connection)

async def _try_parse_response(self, cmd, connection, ret):
try:
cmd.result = await asyncio.shield(
self.parse_response(connection, cmd.args[0], **cmd.kwargs)
)
except asyncio.CancelledError:
await connection.disconnect(nowait=True)
raise
except Exception as e:
cmd.result = e
ret = True
return ret

async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Acquire connection
connection = self.acquire_connection()
Expand All @@ -1028,13 +1041,7 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Read responses
ret = False
for cmd in commands:
try:
cmd.result = await self.parse_response(
connection, cmd.args[0], **cmd.kwargs
)
except Exception as e:
cmd.result = e
ret = True
ret = await asyncio.shield(self._try_parse_response(cmd, connection, ret))

# Release connection
self._free.append(connection)
Expand Down
17 changes: 0 additions & 17 deletions tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,23 +340,6 @@ async def test_from_url(self, request: FixtureRequest) -> None:
rc = RedisCluster.from_url("rediss://localhost:16379")
assert rc.connection_kwargs["connection_class"] is SSLConnection

async def test_asynckills(self, r) -> None:

await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(1)
t.cancel()
try:
await t
except asyncio.CancelledError:
pytest.fail("connection is left open with unread response")

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"

async def test_max_connections(
self, create_redis: Callable[..., RedisCluster]
) -> None:
Expand Down
21 changes: 0 additions & 21 deletions tests/test_asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,27 +44,6 @@ async def test_invalid_response(create_redis):
await r.connection.disconnect()


async def test_asynckills():

for b in [True, False]:
r = Redis(single_connection_client=b)

await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(1)
t.cancel()
try:
await t
except asyncio.CancelledError:
pytest.fail("connection left open with unread response")

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"


@pytest.mark.onlynoncluster
async def test_single_connection():
"""Test that concurrent requests on a single client are synchronised."""
Expand Down
129 changes: 129 additions & 0 deletions tests/test_asyncio/test_cwe_404.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import asyncio
import sys

import pytest

from redis.asyncio import Redis
from redis.asyncio.cluster import RedisCluster


async def pipe(
reader: asyncio.StreamReader, writer: asyncio.StreamWriter, delay: float, name=""
):
while True:
data = await reader.read(1000)
if not data:
break
await asyncio.sleep(delay)
writer.write(data)
await writer.drain()


class DelayProxy:
def __init__(self, addr, redis_addr, delay: float):
self.addr = addr
self.redis_addr = redis_addr
self.delay = delay

async def start(self):
self.server = await asyncio.start_server(self.handle, *self.addr)
self.ROUTINE = asyncio.create_task(self.server.serve_forever())

async def handle(self, reader, writer):
# establish connection to redis
redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr)
pipe1 = asyncio.create_task(pipe(reader, redis_writer, self.delay, "to redis:"))
pipe2 = asyncio.create_task(
pipe(redis_reader, writer, self.delay, "from redis:")
)
await asyncio.gather(pipe1, pipe2)

async def stop(self):
# clean up enough so that we can reuse the looper
self.ROUTINE.cancel()
loop = self.server.get_loop()
await loop.shutdown_asyncgens()


async def test_standalone():

# create a tcp socket proxy that relays data to Redis and back,
# inserting 0.1 seconds of delay
dp = DelayProxy(addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=0.1)
await dp.start()

for b in [True, False]:
# note that we connect to proxy, rather than to Redis directly
async with Redis(host="localhost", port=5380, single_connection_client=b) as r:

await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(0.050)
t.cancel()
try:
await t
sys.stderr.write("try again, we did not cancel the task in time\n")
except asyncio.CancelledError:
sys.stderr.write(
"canceled task, connection is left open with unread response\n"
)

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"

await dp.stop()


async def test_standalone_pipeline():
dp = DelayProxy(addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=0.1)
await dp.start()
async with Redis(host="localhost", port=5380) as r:
await r.set("foo", "foo")
await r.set("bar", "bar")

pipe = r.pipeline()

pipe2 = r.pipeline()
pipe2.get("bar")
pipe2.ping()
pipe2.get("foo")

t = asyncio.create_task(pipe.get("foo").execute())
await asyncio.sleep(0.050)
t.cancel()

pipe.get("bar")
pipe.ping()
pipe.get("foo")
assert await pipe.execute() == [b"foo", b"bar", True, b"foo"]
assert await pipe2.execute() == [b"bar", True, b"foo"]

await dp.stop()


async def test_cluster(request):

dp = DelayProxy(addr=("localhost", 5381), redis_addr=("localhost", 6372), delay=0.1)
await dp.start()

r = RedisCluster.from_url("redis://localhost:5381")
await r.initialize()
await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(0.050)
t.cancel()
try:
await t
except asyncio.CancelledError:
pytest.fail("connection is left open with unread response")

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"

await dp.stop()