diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 5b534006a8..2a649ff100 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -51,6 +51,7 @@ jobs: timeout-minutes: 30 strategy: max-parallel: 15 + fail-fast: false matrix: python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8'] test-type: ['standalone', 'cluster'] @@ -108,6 +109,7 @@ jobs: name: Install package from commit hash runs-on: ubuntu-latest strategy: + fail-fast: false matrix: python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7'] steps: diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 0bd21d3fab..3fd0636830 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -493,24 +493,34 @@ async def _disconnect_raise(self, conn: Connection, error: Exception): ): raise error - # COMMAND EXECUTION AND PROTOCOL PARSING - async def execute_command(self, *args, **options): - """Execute a command and return a parsed response""" - await self.initialize() - pool = self.connection_pool - command_name = args[0] - conn = self.connection or await pool.get_connection(command_name, **options) - + async def _try_send_command_parse_response(self, conn, *args, **options): try: return await conn.retry.call_with_retry( lambda: self._send_command_parse_response( - conn, command_name, *args, **options + conn, args[0], *args, **options ), lambda error: self._disconnect_raise(conn, error), ) + except asyncio.CancelledError: + await conn.disconnect(nowait=True) + raise finally: + if self.single_connection_client: + self._single_conn_lock.release() if not self.connection: - await pool.release(conn) + await self.connection_pool.release(conn) + + # COMMAND EXECUTION AND PROTOCOL PARSING + async def execute_command(self, *args, **options): + """Execute a command and return a parsed response""" + await self.initialize() + pool = self.connection_pool + command_name = args[0] + conn = self.connection or await pool.get_connection(command_name, **options) + + return await asyncio.shield( + self._try_send_command_parse_response(conn, *args, **options) + ) async def parse_response( self, connection: Connection, command_name: Union[str, bytes], **options @@ -749,10 +759,18 @@ async def _disconnect_raise_connect(self, conn, error): is not a TimeoutError. Otherwise, try to reconnect """ await conn.disconnect() + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): raise error await conn.connect() + async def _try_execute(self, conn, command, *arg, **kwargs): + try: + return await command(*arg, **kwargs) + except asyncio.CancelledError: + await conn.disconnect() + raise + async def _execute(self, conn, command, *args, **kwargs): """ Connect manually upon disconnection. If the Redis server is down, @@ -761,9 +779,11 @@ async def _execute(self, conn, command, *args, **kwargs): called by the # connection to resubscribe us to any channels and patterns we were previously listening to """ - return await conn.retry.call_with_retry( - lambda: command(*args, **kwargs), - lambda error: self._disconnect_raise_connect(conn, error), + return await asyncio.shield( + conn.retry.call_with_retry( + lambda: self._try_execute(conn, command, *args, **kwargs), + lambda error: self._disconnect_raise_connect(conn, error), + ) ) async def parse_response(self, block: bool = True, timeout: float = 0): @@ -1165,6 +1185,18 @@ async def _disconnect_reset_raise(self, conn, error): await self.reset() raise + async def _try_send_command_parse_response(self, conn, *args, **options): + try: + return await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, args[0], *args, **options + ), + lambda error: self._disconnect_reset_raise(conn, error), + ) + except asyncio.CancelledError: + await conn.disconnect() + raise + async def immediate_execute_command(self, *args, **options): """ Execute a command immediately, but don't auto-retry on a @@ -1180,13 +1212,13 @@ async def immediate_execute_command(self, *args, **options): command_name, self.shard_hint ) self.connection = conn - - return await conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_reset_raise(conn, error), - ) + try: + return await asyncio.shield( + self._try_send_command_parse_response(conn, *args, **options) + ) + except asyncio.CancelledError: + await conn.disconnect() + raise def pipeline_execute_command(self, *args, **options): """ @@ -1353,6 +1385,19 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception): await self.reset() raise + async def _try_execute(self, conn, execute, stack, raise_on_error): + try: + return await conn.retry.call_with_retry( + lambda: execute(conn, stack, raise_on_error), + lambda error: self._disconnect_raise_reset(conn, error), + ) + except asyncio.CancelledError: + # not supposed to be possible, yet here we are + await conn.disconnect(nowait=True) + raise + finally: + await self.reset() + async def execute(self, raise_on_error: bool = True): """Execute all the commands in the current pipeline""" stack = self.command_stack @@ -1375,15 +1420,10 @@ async def execute(self, raise_on_error: bool = True): try: return await asyncio.shield( - conn.retry.call_with_retry( - lambda: execute(conn, stack, raise_on_error), - lambda error: self._disconnect_raise_reset(conn, error), - ) + self._try_execute(conn, execute, stack, raise_on_error) ) - except asyncio.CancelledError: - # not supposed to be possible, yet here we are - await conn.disconnect(nowait=True) - raise + except RuntimeError: + await self.reset() finally: await self.reset() diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 569a0765f8..a4a9561cf1 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1016,6 +1016,19 @@ async def _parse_and_release(self, connection, *args, **kwargs): finally: self._free.append(connection) + async def _try_parse_response(self, cmd, connection, ret): + try: + cmd.result = await asyncio.shield( + self.parse_response(connection, cmd.args[0], **cmd.kwargs) + ) + except asyncio.CancelledError: + await connection.disconnect(nowait=True) + raise + except Exception as e: + cmd.result = e + ret = True + return ret + async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: # Acquire connection connection = self.acquire_connection() @@ -1028,13 +1041,7 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: # Read responses ret = False for cmd in commands: - try: - cmd.result = await self.parse_response( - connection, cmd.args[0], **cmd.kwargs - ) - except Exception as e: - cmd.result = e - ret = True + ret = await asyncio.shield(self._try_parse_response(cmd, connection, ret)) # Release connection self._free.append(connection) diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 0857c056c2..13e5e26ae3 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -340,23 +340,6 @@ async def test_from_url(self, request: FixtureRequest) -> None: rc = RedisCluster.from_url("rediss://localhost:16379") assert rc.connection_kwargs["connection_class"] is SSLConnection - async def test_asynckills(self, r) -> None: - - await r.set("foo", "foo") - await r.set("bar", "bar") - - t = asyncio.create_task(r.get("foo")) - await asyncio.sleep(1) - t.cancel() - try: - await t - except asyncio.CancelledError: - pytest.fail("connection is left open with unread response") - - assert await r.get("bar") == b"bar" - assert await r.ping() - assert await r.get("foo") == b"foo" - async def test_max_connections( self, create_redis: Callable[..., RedisCluster] ) -> None: diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index ccd6cdb67a..8e4fdac309 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -6,6 +6,7 @@ import pytest import redis +from redis.asyncio import Redis from redis.asyncio.connection import ( BaseParser, Connection, @@ -42,25 +43,47 @@ async def test_invalid_response(create_redis): @pytest.mark.onlynoncluster -async def test_asynckills(create_redis): - - for b in [True, False]: - r = await create_redis(single_connection_client=b) - - await r.set("foo", "foo") - await r.set("bar", "bar") - - t = asyncio.create_task(r.get("foo")) - await asyncio.sleep(1) - t.cancel() - try: - await t - except asyncio.CancelledError: - pytest.fail("connection left open with unread response") - - assert await r.get("bar") == b"bar" - assert await r.ping() - assert await r.get("foo") == b"foo" +async def test_single_connection(): + """Test that concurrent requests on a single client are synchronised.""" + r = Redis(single_connection_client=True) + + init_call_count = 0 + command_call_count = 0 + in_use = False + + class Retry_: + async def call_with_retry(self, _, __): + # If we remove the single-client lock, this error gets raised as two + # coroutines will be vying for the `in_use` flag due to the two + # asymmetric sleep calls + nonlocal command_call_count + nonlocal in_use + if in_use is True: + raise ValueError("Commands should be executed one at a time.") + in_use = True + await asyncio.sleep(0.01) + command_call_count += 1 + await asyncio.sleep(0.03) + in_use = False + return "foo" + + mock_conn = mock.MagicMock() + mock_conn.retry = Retry_() + + async def get_conn(_): + # Validate only one client is created in single-client mode when + # concurrent requests are made + nonlocal init_call_count + await asyncio.sleep(0.01) + init_call_count += 1 + return mock_conn + + with mock.patch.object(r.connection_pool, "get_connection", get_conn): + with mock.patch.object(r.connection_pool, "release"): + await asyncio.gather(r.set("a", "b"), r.set("c", "d")) + + assert init_call_count == 1 + assert command_call_count == 2 @skip_if_server_version_lt("4.0.0") diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py new file mode 100644 index 0000000000..668344042d --- /dev/null +++ b/tests/test_asyncio/test_cwe_404.py @@ -0,0 +1,146 @@ +import asyncio +import sys + +import pytest + +from redis.asyncio import Redis +from redis.asyncio.cluster import RedisCluster + + +async def pipe( + reader: asyncio.StreamReader, writer: asyncio.StreamWriter, delay: float, name="" +): + while True: + data = await reader.read(1000) + if not data: + break + await asyncio.sleep(delay) + writer.write(data) + await writer.drain() + + +class DelayProxy: + def __init__(self, addr, redis_addr, delay: float): + self.addr = addr + self.redis_addr = redis_addr + self.delay = delay + + async def start(self): + self.server = await asyncio.start_server(self.handle, *self.addr) + self.ROUTINE = asyncio.create_task(self.server.serve_forever()) + + async def handle(self, reader, writer): + # establish connection to redis + redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr) + pipe1 = asyncio.create_task(pipe(reader, redis_writer, self.delay, "to redis:")) + pipe2 = asyncio.create_task( + pipe(redis_reader, writer, self.delay, "from redis:") + ) + await asyncio.gather(pipe1, pipe2) + + async def stop(self): + # clean up enough so that we can reuse the looper + self.ROUTINE.cancel() + loop = self.server.get_loop() + await loop.shutdown_asyncgens() + + +@pytest.mark.onlynoncluster +@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2]) +async def test_standalone(delay): + + # create a tcp socket proxy that relays data to Redis and back, + # inserting 0.1 seconds of delay + dp = DelayProxy( + addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=delay * 2 + ) + await dp.start() + + for b in [True, False]: + # note that we connect to proxy, rather than to Redis directly + async with Redis(host="localhost", port=5380, single_connection_client=b) as r: + + await r.set("foo", "foo") + await r.set("bar", "bar") + + t = asyncio.create_task(r.get("foo")) + await asyncio.sleep(delay) + t.cancel() + try: + await t + sys.stderr.write("try again, we did not cancel the task in time\n") + except asyncio.CancelledError: + sys.stderr.write( + "canceled task, connection is left open with unread response\n" + ) + + assert await r.get("bar") == b"bar" + assert await r.ping() + assert await r.get("foo") == b"foo" + + await dp.stop() + + +@pytest.mark.onlynoncluster +@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2]) +async def test_standalone_pipeline(delay): + dp = DelayProxy( + addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=delay * 2 + ) + await dp.start() + async with Redis(host="localhost", port=5380) as r: + await r.set("foo", "foo") + await r.set("bar", "bar") + + pipe = r.pipeline() + + pipe2 = r.pipeline() + pipe2.get("bar") + pipe2.ping() + pipe2.get("foo") + + t = asyncio.create_task(pipe.get("foo").execute()) + await asyncio.sleep(delay) + t.cancel() + + pipe.get("bar") + pipe.ping() + pipe.get("foo") + pipe.reset() + + assert await pipe.execute() is None + + # validating that the pipeline can be used as it could previously + pipe.get("bar") + pipe.ping() + pipe.get("foo") + assert await pipe.execute() == [b"bar", True, b"foo"] + assert await pipe2.execute() == [b"bar", True, b"foo"] + + await dp.stop() + + +@pytest.mark.onlycluster +async def test_cluster(request): + + dp = DelayProxy(addr=("localhost", 5381), redis_addr=("localhost", 6372), delay=0.1) + await dp.start() + + r = RedisCluster.from_url("redis://localhost:5381") + await r.initialize() + await r.set("foo", "foo") + await r.set("bar", "bar") + + t = asyncio.create_task(r.get("foo")) + await asyncio.sleep(0.050) + t.cancel() + try: + await t + except asyncio.CancelledError: + pytest.fail("connection is left open with unread response") + + assert await r.get("bar") == b"bar" + assert await r.ping() + assert await r.get("foo") == b"foo" + + await dp.stop() diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index c2a9130e83..243d410fac 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -967,6 +967,9 @@ async def get_msg_or_timeout(timeout=0.1): # the timeout on the read should not cause disconnect assert pubsub.connection.is_connected + @pytest.mark.skipif( + sys.version_info < (3, 8), reason="requires python 3.8 or higher" + ) async def test_base_exception(self, r: redis.Redis): """ Manually trigger a BaseException inside the parser's .read_response method