diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 5c0b546bf5..f5d8b0150e 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -475,24 +475,32 @@ async def _disconnect_raise(self, conn: Connection, error: Exception): ): raise error - # COMMAND EXECUTION AND PROTOCOL PARSING - async def execute_command(self, *args, **options): - """Execute a command and return a parsed response""" - await self.initialize() - pool = self.connection_pool - command_name = args[0] - conn = self.connection or await pool.get_connection(command_name, **options) - + async def _try_send_command_parse_response(self, conn, *args, **options): try: return await conn.retry.call_with_retry( lambda: self._send_command_parse_response( - conn, command_name, *args, **options + conn, args[0], *args, **options ), lambda error: self._disconnect_raise(conn, error), ) + except asyncio.CancelledError: + await conn.disconnect(nowait=True) + raise finally: if not self.connection: - await pool.release(conn) + await self.connection_pool.release(conn) + + # COMMAND EXECUTION AND PROTOCOL PARSING + async def execute_command(self, *args, **options): + """Execute a command and return a parsed response""" + await self.initialize() + pool = self.connection_pool + command_name = args[0] + conn = self.connection or await pool.get_connection(command_name, **options) + + return await asyncio.shield( + self._try_send_command_parse_response(conn, *args, **options) + ) async def parse_response( self, connection: Connection, command_name: Union[str, bytes], **options @@ -726,10 +734,18 @@ async def _disconnect_raise_connect(self, conn, error): is not a TimeoutError. Otherwise, try to reconnect """ await conn.disconnect() + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): raise error await conn.connect() + async def _try_execute(self, conn, command, *arg, **kwargs): + try: + return await command(*arg, **kwargs) + except asyncio.CancelledError: + await conn.disconnect() + raise + async def _execute(self, conn, command, *args, **kwargs): """ Connect manually upon disconnection. If the Redis server is down, @@ -738,9 +754,11 @@ async def _execute(self, conn, command, *args, **kwargs): called by the # connection to resubscribe us to any channels and patterns we were previously listening to """ - return await conn.retry.call_with_retry( - lambda: command(*args, **kwargs), - lambda error: self._disconnect_raise_connect(conn, error), + return await asyncio.shield( + conn.retry.call_with_retry( + lambda: self._try_execute(conn, command, *args, **kwargs), + lambda error: self._disconnect_raise_connect(conn, error), + ) ) async def parse_response(self, block: bool = True, timeout: float = 0): @@ -1140,6 +1158,18 @@ async def _disconnect_reset_raise(self, conn, error): await self.reset() raise + async def _try_send_command_parse_response(self, conn, *args, **options): + try: + return await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, args[0], *args, **options + ), + lambda error: self._disconnect_reset_raise(conn, error), + ) + except asyncio.CancelledError: + await conn.disconnect() + raise + async def immediate_execute_command(self, *args, **options): """ Execute a command immediately, but don't auto-retry on a @@ -1155,13 +1185,13 @@ async def immediate_execute_command(self, *args, **options): command_name, self.shard_hint ) self.connection = conn - - return await conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_reset_raise(conn, error), - ) + try: + return await asyncio.shield( + self._try_send_command_parse_response(conn, *args, **options) + ) + except asyncio.CancelledError: + await conn.disconnect() + raise def pipeline_execute_command(self, *args, **options): """ @@ -1328,6 +1358,19 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception): await self.reset() raise + async def _try_execute(self, conn, execute, stack, raise_on_error): + try: + return await conn.retry.call_with_retry( + lambda: execute(conn, stack, raise_on_error), + lambda error: self._disconnect_raise_reset(conn, error), + ) + except asyncio.CancelledError: + # not supposed to be possible, yet here we are + await conn.disconnect(nowait=True) + raise + finally: + await self.reset() + async def execute(self, raise_on_error: bool = True): """Execute all the commands in the current pipeline""" stack = self.command_stack @@ -1350,15 +1393,10 @@ async def execute(self, raise_on_error: bool = True): try: return await asyncio.shield( - conn.retry.call_with_retry( - lambda: execute(conn, stack, raise_on_error), - lambda error: self._disconnect_raise_reset(conn, error), - ) + self._try_execute(conn, execute, stack, raise_on_error) ) - except asyncio.CancelledError: - # not supposed to be possible, yet here we are - await conn.disconnect(nowait=True) - raise + except RuntimeError: + await self.reset() finally: await self.reset() diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 8dfb1cbdb8..50f1b6bc07 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -893,6 +893,19 @@ async def _parse_and_release(self, connection, *args, **kwargs): finally: self._free.append(connection) + async def _try_parse_response(self, cmd, connection, ret): + try: + cmd.result = await asyncio.shield( + self.parse_response(connection, cmd.args[0], **cmd.kwargs) + ) + except asyncio.CancelledError: + await connection.disconnect(nowait=True) + raise + except Exception as e: + cmd.result = e + ret = True + return ret + async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: # Acquire connection connection = self.acquire_connection() @@ -905,13 +918,7 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: # Read responses ret = False for cmd in commands: - try: - cmd.result = await self.parse_response( - connection, cmd.args[0], **cmd.kwargs - ) - except Exception as e: - cmd.result = e - ret = True + ret = await asyncio.shield(self._try_parse_response(cmd, connection, ret)) # Release connection self._free.append(connection) diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 2e44cdde3e..d6e01f79dc 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -333,23 +333,6 @@ async def test_execute_command_node_flag_random(self, r: RedisCluster) -> None: called_count += 1 assert called_count == 1 - async def test_asynckills(self, r) -> None: - - await r.set("foo", "foo") - await r.set("bar", "bar") - - t = asyncio.create_task(r.get("foo")) - await asyncio.sleep(1) - t.cancel() - try: - await t - except asyncio.CancelledError: - pytest.fail("connection is left open with unread response") - - assert await r.get("bar") == b"bar" - assert await r.ping() - assert await r.get("foo") == b"foo" - async def test_execute_command_default_node(self, r: RedisCluster) -> None: """ Test command execution without node flag is being executed on the diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index c414ee05cc..f6259adbd2 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -28,29 +28,6 @@ async def test_invalid_response(create_redis): assert str(cm.value) == f"Protocol Error: {raw!r}" -@pytest.mark.onlynoncluster -async def test_asynckills(): - from redis.asyncio.client import Redis - - for b in [True, False]: - r = Redis(single_connection_client=b) - - await r.set("foo", "foo") - await r.set("bar", "bar") - - t = asyncio.create_task(r.get("foo")) - await asyncio.sleep(1) - t.cancel() - try: - await t - except asyncio.CancelledError: - pytest.fail("connection left open with unread response") - - assert await r.get("bar") == b"bar" - assert await r.ping() - assert await r.get("foo") == b"foo" - - @skip_if_server_version_lt("4.0.0") @pytest.mark.redismod @pytest.mark.onlynoncluster diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py new file mode 100644 index 0000000000..668344042d --- /dev/null +++ b/tests/test_asyncio/test_cwe_404.py @@ -0,0 +1,146 @@ +import asyncio +import sys + +import pytest + +from redis.asyncio import Redis +from redis.asyncio.cluster import RedisCluster + + +async def pipe( + reader: asyncio.StreamReader, writer: asyncio.StreamWriter, delay: float, name="" +): + while True: + data = await reader.read(1000) + if not data: + break + await asyncio.sleep(delay) + writer.write(data) + await writer.drain() + + +class DelayProxy: + def __init__(self, addr, redis_addr, delay: float): + self.addr = addr + self.redis_addr = redis_addr + self.delay = delay + + async def start(self): + self.server = await asyncio.start_server(self.handle, *self.addr) + self.ROUTINE = asyncio.create_task(self.server.serve_forever()) + + async def handle(self, reader, writer): + # establish connection to redis + redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr) + pipe1 = asyncio.create_task(pipe(reader, redis_writer, self.delay, "to redis:")) + pipe2 = asyncio.create_task( + pipe(redis_reader, writer, self.delay, "from redis:") + ) + await asyncio.gather(pipe1, pipe2) + + async def stop(self): + # clean up enough so that we can reuse the looper + self.ROUTINE.cancel() + loop = self.server.get_loop() + await loop.shutdown_asyncgens() + + +@pytest.mark.onlynoncluster +@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2]) +async def test_standalone(delay): + + # create a tcp socket proxy that relays data to Redis and back, + # inserting 0.1 seconds of delay + dp = DelayProxy( + addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=delay * 2 + ) + await dp.start() + + for b in [True, False]: + # note that we connect to proxy, rather than to Redis directly + async with Redis(host="localhost", port=5380, single_connection_client=b) as r: + + await r.set("foo", "foo") + await r.set("bar", "bar") + + t = asyncio.create_task(r.get("foo")) + await asyncio.sleep(delay) + t.cancel() + try: + await t + sys.stderr.write("try again, we did not cancel the task in time\n") + except asyncio.CancelledError: + sys.stderr.write( + "canceled task, connection is left open with unread response\n" + ) + + assert await r.get("bar") == b"bar" + assert await r.ping() + assert await r.get("foo") == b"foo" + + await dp.stop() + + +@pytest.mark.onlynoncluster +@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2]) +async def test_standalone_pipeline(delay): + dp = DelayProxy( + addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=delay * 2 + ) + await dp.start() + async with Redis(host="localhost", port=5380) as r: + await r.set("foo", "foo") + await r.set("bar", "bar") + + pipe = r.pipeline() + + pipe2 = r.pipeline() + pipe2.get("bar") + pipe2.ping() + pipe2.get("foo") + + t = asyncio.create_task(pipe.get("foo").execute()) + await asyncio.sleep(delay) + t.cancel() + + pipe.get("bar") + pipe.ping() + pipe.get("foo") + pipe.reset() + + assert await pipe.execute() is None + + # validating that the pipeline can be used as it could previously + pipe.get("bar") + pipe.ping() + pipe.get("foo") + assert await pipe.execute() == [b"bar", True, b"foo"] + assert await pipe2.execute() == [b"bar", True, b"foo"] + + await dp.stop() + + +@pytest.mark.onlycluster +async def test_cluster(request): + + dp = DelayProxy(addr=("localhost", 5381), redis_addr=("localhost", 6372), delay=0.1) + await dp.start() + + r = RedisCluster.from_url("redis://localhost:5381") + await r.initialize() + await r.set("foo", "foo") + await r.set("bar", "bar") + + t = asyncio.create_task(r.get("foo")) + await asyncio.sleep(0.050) + t.cancel() + try: + await t + except asyncio.CancelledError: + pytest.fail("connection is left open with unread response") + + assert await r.get("bar") == b"bar" + assert await r.ping() + assert await r.get("foo") == b"foo" + + await dp.stop()