Skip to content

Commit

Permalink
Fixing cancelled async futures (#2666)
Browse files Browse the repository at this point in the history
Co-authored-by: James R T <jamestiotio@gmail.com>
Co-authored-by: dvora-h <dvora.heller@redis.com>
  • Loading branch information
3 people committed Mar 29, 2023
1 parent 7b48b1b commit 6cd5173
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 75 deletions.
94 changes: 66 additions & 28 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,24 +475,32 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
):
raise error

# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
await self.initialize()
pool = self.connection_pool
command_name = args[0]
conn = self.connection or await pool.get_connection(command_name, **options)

async def _try_send_command_parse_response(self, conn, *args, **options):
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
conn, args[0], *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
except asyncio.CancelledError:
await conn.disconnect(nowait=True)
raise
finally:
if not self.connection:
await pool.release(conn)
await self.connection_pool.release(conn)

# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
await self.initialize()
pool = self.connection_pool
command_name = args[0]
conn = self.connection or await pool.get_connection(command_name, **options)

return await asyncio.shield(
self._try_send_command_parse_response(conn, *args, **options)
)

async def parse_response(
self, connection: Connection, command_name: Union[str, bytes], **options
Expand Down Expand Up @@ -726,10 +734,18 @@ async def _disconnect_raise_connect(self, conn, error):
is not a TimeoutError. Otherwise, try to reconnect
"""
await conn.disconnect()

if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
raise error
await conn.connect()

async def _try_execute(self, conn, command, *arg, **kwargs):
try:
return await command(*arg, **kwargs)
except asyncio.CancelledError:
await conn.disconnect()
raise

async def _execute(self, conn, command, *args, **kwargs):
"""
Connect manually upon disconnection. If the Redis server is down,
Expand All @@ -738,9 +754,11 @@ async def _execute(self, conn, command, *args, **kwargs):
called by the # connection to resubscribe us to any channels and
patterns we were previously listening to
"""
return await conn.retry.call_with_retry(
lambda: command(*args, **kwargs),
lambda error: self._disconnect_raise_connect(conn, error),
return await asyncio.shield(
conn.retry.call_with_retry(
lambda: self._try_execute(conn, command, *args, **kwargs),
lambda error: self._disconnect_raise_connect(conn, error),
)
)

async def parse_response(self, block: bool = True, timeout: float = 0):
Expand Down Expand Up @@ -1140,6 +1158,18 @@ async def _disconnect_reset_raise(self, conn, error):
await self.reset()
raise

async def _try_send_command_parse_response(self, conn, *args, **options):
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, args[0], *args, **options
),
lambda error: self._disconnect_reset_raise(conn, error),
)
except asyncio.CancelledError:
await conn.disconnect()
raise

async def immediate_execute_command(self, *args, **options):
"""
Execute a command immediately, but don't auto-retry on a
Expand All @@ -1155,13 +1185,13 @@ async def immediate_execute_command(self, *args, **options):
command_name, self.shard_hint
)
self.connection = conn

return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_reset_raise(conn, error),
)
try:
return await asyncio.shield(
self._try_send_command_parse_response(conn, *args, **options)
)
except asyncio.CancelledError:
await conn.disconnect()
raise

def pipeline_execute_command(self, *args, **options):
"""
Expand Down Expand Up @@ -1328,6 +1358,19 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
await self.reset()
raise

async def _try_execute(self, conn, execute, stack, raise_on_error):
try:
return await conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_reset(conn, error),
)
except asyncio.CancelledError:
# not supposed to be possible, yet here we are
await conn.disconnect(nowait=True)
raise
finally:
await self.reset()

async def execute(self, raise_on_error: bool = True):
"""Execute all the commands in the current pipeline"""
stack = self.command_stack
Expand All @@ -1350,15 +1393,10 @@ async def execute(self, raise_on_error: bool = True):

try:
return await asyncio.shield(
conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_reset(conn, error),
)
self._try_execute(conn, execute, stack, raise_on_error)
)
except asyncio.CancelledError:
# not supposed to be possible, yet here we are
await conn.disconnect(nowait=True)
raise
except RuntimeError:
await self.reset()
finally:
await self.reset()

Expand Down
21 changes: 14 additions & 7 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,19 @@ async def _parse_and_release(self, connection, *args, **kwargs):
finally:
self._free.append(connection)

async def _try_parse_response(self, cmd, connection, ret):
try:
cmd.result = await asyncio.shield(
self.parse_response(connection, cmd.args[0], **cmd.kwargs)
)
except asyncio.CancelledError:
await connection.disconnect(nowait=True)
raise
except Exception as e:
cmd.result = e
ret = True
return ret

async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Acquire connection
connection = self.acquire_connection()
Expand All @@ -905,13 +918,7 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Read responses
ret = False
for cmd in commands:
try:
cmd.result = await self.parse_response(
connection, cmd.args[0], **cmd.kwargs
)
except Exception as e:
cmd.result = e
ret = True
ret = await asyncio.shield(self._try_parse_response(cmd, connection, ret))

# Release connection
self._free.append(connection)
Expand Down
17 changes: 0 additions & 17 deletions tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,23 +333,6 @@ async def test_execute_command_node_flag_random(self, r: RedisCluster) -> None:
called_count += 1
assert called_count == 1

async def test_asynckills(self, r) -> None:

await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(1)
t.cancel()
try:
await t
except asyncio.CancelledError:
pytest.fail("connection is left open with unread response")

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"

async def test_execute_command_default_node(self, r: RedisCluster) -> None:
"""
Test command execution without node flag is being executed on the
Expand Down
23 changes: 0 additions & 23 deletions tests/test_asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,6 @@ async def test_invalid_response(create_redis):
assert str(cm.value) == f"Protocol Error: {raw!r}"


@pytest.mark.onlynoncluster
async def test_asynckills():
from redis.asyncio.client import Redis

for b in [True, False]:
r = Redis(single_connection_client=b)

await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(1)
t.cancel()
try:
await t
except asyncio.CancelledError:
pytest.fail("connection left open with unread response")

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"


@skip_if_server_version_lt("4.0.0")
@pytest.mark.redismod
@pytest.mark.onlynoncluster
Expand Down
146 changes: 146 additions & 0 deletions tests/test_asyncio/test_cwe_404.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import asyncio
import sys

import pytest

from redis.asyncio import Redis
from redis.asyncio.cluster import RedisCluster


async def pipe(
reader: asyncio.StreamReader, writer: asyncio.StreamWriter, delay: float, name=""
):
while True:
data = await reader.read(1000)
if not data:
break
await asyncio.sleep(delay)
writer.write(data)
await writer.drain()


class DelayProxy:
def __init__(self, addr, redis_addr, delay: float):
self.addr = addr
self.redis_addr = redis_addr
self.delay = delay

async def start(self):
self.server = await asyncio.start_server(self.handle, *self.addr)
self.ROUTINE = asyncio.create_task(self.server.serve_forever())

async def handle(self, reader, writer):
# establish connection to redis
redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr)
pipe1 = asyncio.create_task(pipe(reader, redis_writer, self.delay, "to redis:"))
pipe2 = asyncio.create_task(
pipe(redis_reader, writer, self.delay, "from redis:")
)
await asyncio.gather(pipe1, pipe2)

async def stop(self):
# clean up enough so that we can reuse the looper
self.ROUTINE.cancel()
loop = self.server.get_loop()
await loop.shutdown_asyncgens()


@pytest.mark.onlynoncluster
@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2])
async def test_standalone(delay):

# create a tcp socket proxy that relays data to Redis and back,
# inserting 0.1 seconds of delay
dp = DelayProxy(
addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=delay * 2
)
await dp.start()

for b in [True, False]:
# note that we connect to proxy, rather than to Redis directly
async with Redis(host="localhost", port=5380, single_connection_client=b) as r:

await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(delay)
t.cancel()
try:
await t
sys.stderr.write("try again, we did not cancel the task in time\n")
except asyncio.CancelledError:
sys.stderr.write(
"canceled task, connection is left open with unread response\n"
)

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"

await dp.stop()


@pytest.mark.onlynoncluster
@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2])
async def test_standalone_pipeline(delay):
dp = DelayProxy(
addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=delay * 2
)
await dp.start()
async with Redis(host="localhost", port=5380) as r:
await r.set("foo", "foo")
await r.set("bar", "bar")

pipe = r.pipeline()

pipe2 = r.pipeline()
pipe2.get("bar")
pipe2.ping()
pipe2.get("foo")

t = asyncio.create_task(pipe.get("foo").execute())
await asyncio.sleep(delay)
t.cancel()

pipe.get("bar")
pipe.ping()
pipe.get("foo")
pipe.reset()

assert await pipe.execute() is None

# validating that the pipeline can be used as it could previously
pipe.get("bar")
pipe.ping()
pipe.get("foo")
assert await pipe.execute() == [b"bar", True, b"foo"]
assert await pipe2.execute() == [b"bar", True, b"foo"]

await dp.stop()


@pytest.mark.onlycluster
async def test_cluster(request):

dp = DelayProxy(addr=("localhost", 5381), redis_addr=("localhost", 6372), delay=0.1)
await dp.start()

r = RedisCluster.from_url("redis://localhost:5381")
await r.initialize()
await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(0.050)
t.cancel()
try:
await t
except asyncio.CancelledError:
pytest.fail("connection is left open with unread response")

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"

await dp.stop()

0 comments on commit 6cd5173

Please sign in to comment.