Skip to content

Commit

Permalink
Fixing cancelled async futures (#2666)
Browse files Browse the repository at this point in the history
Co-authored-by: James R T <jamestiotio@gmail.com>
Co-authored-by: dvora-h <dvora.heller@redis.com>
  • Loading branch information
3 people committed Mar 29, 2023
1 parent b3c89ac commit af2ca45
Show file tree
Hide file tree
Showing 7 changed files with 275 additions and 71 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ jobs:
timeout-minutes: 30
strategy:
max-parallel: 15
fail-fast: false
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8']
test-type: ['standalone', 'cluster']
Expand Down Expand Up @@ -108,6 +109,7 @@ jobs:
name: Install package from commit hash
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7']
steps:
Expand Down
96 changes: 68 additions & 28 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,24 +493,34 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
):
raise error

# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
await self.initialize()
pool = self.connection_pool
command_name = args[0]
conn = self.connection or await pool.get_connection(command_name, **options)

async def _try_send_command_parse_response(self, conn, *args, **options):
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
conn, args[0], *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
except asyncio.CancelledError:
await conn.disconnect(nowait=True)
raise
finally:
if self.single_connection_client:
self._single_conn_lock.release()
if not self.connection:
await pool.release(conn)
await self.connection_pool.release(conn)

# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
await self.initialize()
pool = self.connection_pool
command_name = args[0]
conn = self.connection or await pool.get_connection(command_name, **options)

return await asyncio.shield(
self._try_send_command_parse_response(conn, *args, **options)
)

async def parse_response(
self, connection: Connection, command_name: Union[str, bytes], **options
Expand Down Expand Up @@ -749,10 +759,18 @@ async def _disconnect_raise_connect(self, conn, error):
is not a TimeoutError. Otherwise, try to reconnect
"""
await conn.disconnect()

if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
raise error
await conn.connect()

async def _try_execute(self, conn, command, *arg, **kwargs):
try:
return await command(*arg, **kwargs)
except asyncio.CancelledError:
await conn.disconnect()
raise

async def _execute(self, conn, command, *args, **kwargs):
"""
Connect manually upon disconnection. If the Redis server is down,
Expand All @@ -761,9 +779,11 @@ async def _execute(self, conn, command, *args, **kwargs):
called by the # connection to resubscribe us to any channels and
patterns we were previously listening to
"""
return await conn.retry.call_with_retry(
lambda: command(*args, **kwargs),
lambda error: self._disconnect_raise_connect(conn, error),
return await asyncio.shield(
conn.retry.call_with_retry(
lambda: self._try_execute(conn, command, *args, **kwargs),
lambda error: self._disconnect_raise_connect(conn, error),
)
)

async def parse_response(self, block: bool = True, timeout: float = 0):
Expand Down Expand Up @@ -1165,6 +1185,18 @@ async def _disconnect_reset_raise(self, conn, error):
await self.reset()
raise

async def _try_send_command_parse_response(self, conn, *args, **options):
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, args[0], *args, **options
),
lambda error: self._disconnect_reset_raise(conn, error),
)
except asyncio.CancelledError:
await conn.disconnect()
raise

async def immediate_execute_command(self, *args, **options):
"""
Execute a command immediately, but don't auto-retry on a
Expand All @@ -1180,13 +1212,13 @@ async def immediate_execute_command(self, *args, **options):
command_name, self.shard_hint
)
self.connection = conn

return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_reset_raise(conn, error),
)
try:
return await asyncio.shield(
self._try_send_command_parse_response(conn, *args, **options)
)
except asyncio.CancelledError:
await conn.disconnect()
raise

def pipeline_execute_command(self, *args, **options):
"""
Expand Down Expand Up @@ -1353,6 +1385,19 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
await self.reset()
raise

async def _try_execute(self, conn, execute, stack, raise_on_error):
try:
return await conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_reset(conn, error),
)
except asyncio.CancelledError:
# not supposed to be possible, yet here we are
await conn.disconnect(nowait=True)
raise
finally:
await self.reset()

async def execute(self, raise_on_error: bool = True):
"""Execute all the commands in the current pipeline"""
stack = self.command_stack
Expand All @@ -1375,15 +1420,10 @@ async def execute(self, raise_on_error: bool = True):

try:
return await asyncio.shield(
conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_reset(conn, error),
)
self._try_execute(conn, execute, stack, raise_on_error)
)
except asyncio.CancelledError:
# not supposed to be possible, yet here we are
await conn.disconnect(nowait=True)
raise
except RuntimeError:
await self.reset()
finally:
await self.reset()

Expand Down
21 changes: 14 additions & 7 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,19 @@ async def _parse_and_release(self, connection, *args, **kwargs):
finally:
self._free.append(connection)

async def _try_parse_response(self, cmd, connection, ret):
try:
cmd.result = await asyncio.shield(
self.parse_response(connection, cmd.args[0], **cmd.kwargs)
)
except asyncio.CancelledError:
await connection.disconnect(nowait=True)
raise
except Exception as e:
cmd.result = e
ret = True
return ret

async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Acquire connection
connection = self.acquire_connection()
Expand All @@ -1028,13 +1041,7 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Read responses
ret = False
for cmd in commands:
try:
cmd.result = await self.parse_response(
connection, cmd.args[0], **cmd.kwargs
)
except Exception as e:
cmd.result = e
ret = True
ret = await asyncio.shield(self._try_parse_response(cmd, connection, ret))

# Release connection
self._free.append(connection)
Expand Down
17 changes: 0 additions & 17 deletions tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,23 +340,6 @@ async def test_from_url(self, request: FixtureRequest) -> None:
rc = RedisCluster.from_url("rediss://localhost:16379")
assert rc.connection_kwargs["connection_class"] is SSLConnection

async def test_asynckills(self, r) -> None:

await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(1)
t.cancel()
try:
await t
except asyncio.CancelledError:
pytest.fail("connection is left open with unread response")

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"

async def test_max_connections(
self, create_redis: Callable[..., RedisCluster]
) -> None:
Expand Down
61 changes: 42 additions & 19 deletions tests/test_asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

import redis
from redis.asyncio import Redis
from redis.asyncio.connection import (
BaseParser,
Connection,
Expand Down Expand Up @@ -42,25 +43,47 @@ async def test_invalid_response(create_redis):


@pytest.mark.onlynoncluster
async def test_asynckills(create_redis):

for b in [True, False]:
r = await create_redis(single_connection_client=b)

await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(1)
t.cancel()
try:
await t
except asyncio.CancelledError:
pytest.fail("connection left open with unread response")

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"
async def test_single_connection():
"""Test that concurrent requests on a single client are synchronised."""
r = Redis(single_connection_client=True)

init_call_count = 0
command_call_count = 0
in_use = False

class Retry_:
async def call_with_retry(self, _, __):
# If we remove the single-client lock, this error gets raised as two
# coroutines will be vying for the `in_use` flag due to the two
# asymmetric sleep calls
nonlocal command_call_count
nonlocal in_use
if in_use is True:
raise ValueError("Commands should be executed one at a time.")
in_use = True
await asyncio.sleep(0.01)
command_call_count += 1
await asyncio.sleep(0.03)
in_use = False
return "foo"

mock_conn = mock.MagicMock()
mock_conn.retry = Retry_()

async def get_conn(_):
# Validate only one client is created in single-client mode when
# concurrent requests are made
nonlocal init_call_count
await asyncio.sleep(0.01)
init_call_count += 1
return mock_conn

with mock.patch.object(r.connection_pool, "get_connection", get_conn):
with mock.patch.object(r.connection_pool, "release"):
await asyncio.gather(r.set("a", "b"), r.set("c", "d"))

assert init_call_count == 1
assert command_call_count == 2


@skip_if_server_version_lt("4.0.0")
Expand Down
Loading

0 comments on commit af2ca45

Please sign in to comment.