Skip to content

Commit

Permalink
fix: use disconnect to reauthorize, add logically disconnected state …
Browse files Browse the repository at this point in the history
…to ReconnectStage
  • Loading branch information
BertKleewein committed Jun 3, 2020
1 parent 10cd6fa commit 123c045
Show file tree
Hide file tree
Showing 6 changed files with 460 additions and 641 deletions.
35 changes: 3 additions & 32 deletions azure-iot-device/azure/iot/device/common/mqtt_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def on_disconnect(client, userdata, rc):
if not this:
# Paho will sometimes call this after we've been garbage collected. If so, we have to
# stop the loop to make sure the Paho thread shuts down.
logger.info("disconnected after garbage collection. stopping loop")
logger.info(
"on_disconnect called with this=None. Transport must have been garbage collected. stopping loop"
)
client.loop_stop()
else:
if this.on_mqtt_disconnected_handler:
Expand Down Expand Up @@ -428,37 +430,6 @@ def connect(self, password=None):
raise _create_error_from_rc_code(rc)
self._mqtt_client.loop_start()

def reauthorize_connection(self, password=None):
"""
Reauthorize with the MQTT broker, using username set at instantiation.
Connect should have previously been called in order to use this function.
The password is not required if the transport was instantiated with an x509 certificate.
:param str password: The password for reauthorizing with the MQTT broker (Optional).
:raises: ConnectionFailedError if connection could not be established.
:raises: ConnectionDroppedError if connection is dropped during execution.
:raises: UnauthorizedError if there is an error authenticating.
:raises: ProtocolClientError if there is some other client error.
"""
logger.info("reauthorizing MQTT client")
self._mqtt_client.username_pw_set(username=self._username, password=password)
try:
rc = self._mqtt_client.reconnect()
except Exception as e:
logger.info("reconnect raised {}".format(e))
self._cleanup_transport_on_error()
raise exceptions.ConnectionDroppedError(
message="Unexpected Paho failure during reconnect", cause=e
)
logger.debug("_mqtt_client.reconnect returned rc={}".format(rc))
if rc:
# This could result in ConnectionFailedError, ConnectionDroppedError, UnauthorizedError
# or ProtocolClientError
raise _create_error_from_rc_code(rc)

def disconnect(self):
"""
Disconnect from the MQTT broker.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -897,28 +897,33 @@ class ReconnectState(object):
"""
Class which holds reconnect states as class variables. Created to make code that reads like an enum without using an enum.
NEVER_CONNECTED: Transport has never been connected. This state is necessary because some errors might be fatal or transient,
depending on whether the transport has been connected. For example, a failed connection is a transient error if we've connected
before, but it's fatal if we've never connected.
WAITING_TO_RECONNECT: This stage is in a waiting period before reconnecting. This state implies
that the user wants the pipeline to be connected. ie. After a successful connection, the
state will change to LOGICALLY_CONNECTED
WAITING_TO_RECONNECT: This stage is in a waiting period before reconnecting.
LOGICALLY_CONNECTED: The client wants the pipeline to be connected. This state is independent
of the actual connection state since the pipeline could be logically connected but
physically disconnected.
CONNECTED_OR_DISCONNECTED: The transport is either connected or disconnected. This stage doesn't really care which one, so
it doesn't keep track.
LOGICALLY_DISCONNECTED: The client does not want the pipeline to be connected or the pipeline had
a fatal error and was forced to disconnect. If the state is LOGICALLY_DISCONNECTED, then the pipeline
should be physically disconnected since there is no reason to leave the pipeline connected in this state.
"""

NEVER_CONNECTED = "NEVER_CONNECTED"
WAITING_TO_RECONNECT = "WAITING_TO_RECONNECT"
CONNECTED_OR_DISCONNECTED = "CONNECTED_OR_DISCONNECTED"
LOGICALLY_CONNECTED = "LOGICALLY_CONNECTED"
LOGICALLY_DISCONNECTED = "LOGICALLY_DISCONNECTED"


class ReconnectStage(PipelineStage):
def __init__(self):
super(ReconnectStage, self).__init__()
self.reconnect_timer = None
self.state = ReconnectState.NEVER_CONNECTED
self.state = ReconnectState.LOGICALLY_DISCONNECTED
self.never_connected = True
# connect delay is hardcoded for now. Later, this comes from a retry policy
self.reconnect_delay = 10
self.delayed_reconnect_delay = 10
self.immediate_reconnect_delay = 0.01
self.waiting_connect_ops = []

@pipeline_thread.runs_on_pipeline_thread
Expand All @@ -937,6 +942,14 @@ def _run_op(self, op):
self.name, op.name, self.state
)
)
self.state = ReconnectState.LOGICALLY_CONNECTED
# We don't send this op down. Instead, we send a new connect op down. This way,
# we can distinguish between connect ops that we're handling (they go into the
# queue) and connect ops that we are sending down.
#
# Once we finally connect, we only have to complete the ops in the queue and we
# never have to worry about completing the op that we sent down. The code is much
# cleaner this way, especially when you take retries into account, trust me.
self.waiting_connect_ops.append(op)
self._send_new_connect_op_down()

Expand All @@ -947,17 +960,18 @@ def _run_op(self, op):
self.name, op.name, self.state
)
)
self.state = ReconnectState.LOGICALLY_DISCONNECTED
self._clear_reconnect_timer()
self._complete_waiting_connect_ops(
pipeline_exceptions.OperationCancelled("Explicit disconnect invoked")
)
self.state = ReconnectState.CONNECTED_OR_DISCONNECTED
op.complete()

else:
logger.info(
"{}({}): State is {}. Sending op down.".format(self.name, op.name, self.state)
)
self.state = ReconnectState.LOGICALLY_DISCONNECTED
self.send_op_down(op)

else:
Expand All @@ -966,18 +980,27 @@ def _run_op(self, op):
@pipeline_thread.runs_on_pipeline_thread
def _handle_pipeline_event(self, event):
if isinstance(event, pipeline_events_base.DisconnectedEvent):
if self.pipeline_root.connected:
logger.info(
"{}({}): State is {}. Triggering reconnect timer".format(
self.name, event.name, self.state
)
logger.debug(
"{}({}): State is {} Connected is {}.".format(
self.name, event.name, self.state, self.pipeline_root.connected
)
)

if self.pipeline_root.connected and self.state == ReconnectState.LOGICALLY_CONNECTED:
# when we get disconnected, we try to immediately reconnect. If that fails,
# then we wait a while before retrying.
#
# For now, we use immediate_reconnect_delay to mean "close enough to immediate"
# because it saves us from refactoring this code when we know that another change
# is coming. Specifically, we know that this stage is broken in cases where the
# user calls disconnect(). In that case, this stage doesn't currently know that
# it shouldn't reconnect automatically.
self.state = ReconnectState.WAITING_TO_RECONNECT
self._start_reconnect_timer()
self._start_reconnect_timer(self.immediate_reconnect_delay)

else:
logger.info(
"{}({}): State is {}. Doing nothing".format(self.name, event.name, self.state)
)
# do nothing
pass

self.send_event_up(event)

Expand All @@ -992,60 +1015,52 @@ def _send_new_connect_op_down(self):
def on_connect_complete(op, error):
this = self_weakref()
if this:
logger.debug(
"{}({}): on_connect_complete error={} state={} never_conencted={} connected={} ".format(
this.name,
op.name,
error,
this.state,
this.never_connected,
this.pipeline_root.connected,
)
)
if error:
if this.state == ReconnectState.NEVER_CONNECTED:
logger.info(
"{}({}): error on first connection. Not triggering reconnection".format(
this.name, op.name
)
)
if this.never_connected:
# any error on a first connection is fatal
this.state = ReconnectState.LOGICALLY_DISCONNECTED
this._clear_reconnect_timer()
this._complete_waiting_connect_ops(error)
elif type(error) in transient_connect_errors:
logger.info(
"{}({}): State is {}. Connect failed with transient error. Triggering reconnect timer".format(
self.name, op.name, self.state
)
)
# transient errors cause a reconnect attempt
self.state = ReconnectState.WAITING_TO_RECONNECT
self._start_reconnect_timer()

elif this.state == ReconnectState.WAITING_TO_RECONNECT:
logger.info(
"{}({}): non-tranient error. Failing all waiting ops.n".format(
this.name, op.name
)
)
self.state = ReconnectState.CONNECTED_OR_DISCONNECTED
self._clear_reconnect_timer()
this._complete_waiting_connect_ops(error)

self._start_reconnect_timer(self.delayed_reconnect_delay)
else:
logger.info(
"{}({}): State is {}. Connection failed. Not triggering reconnection".format(
this.name, op.name, this.state
)
)
# all others are fatal
this.state = ReconnectState.LOGICALLY_DISCONNECTED
this._clear_reconnect_timer()
this._complete_waiting_connect_ops(error)
else:
logger.info(
"{}({}): State is {}. Connection succeeded".format(
this.name, op.name, this.state
)
)
self.state = ReconnectState.CONNECTED_OR_DISCONNECTED
self._clear_reconnect_timer()
self._complete_waiting_connect_ops()
# successfully connected
this.never_connected = False
this.state = ReconnectState.LOGICALLY_CONNECTED
this._clear_reconnect_timer()
this._complete_waiting_connect_ops()

logger.info("{}: sending new connect op down".format(self.name))
logger.debug("{}: sending new connect op down".format(self.name))
op = pipeline_ops_base.ConnectOperation(callback=on_connect_complete)
self.send_op_down(op)

@pipeline_thread.runs_on_pipeline_thread
def _start_reconnect_timer(self):
def _start_reconnect_timer(self, delay):
"""
Set a timer to reconnect after some period of time
"""
logger.info("{}: State is {}. Starting reconnect timer".format(self.name, self.state))
logger.debug(
"{}: State is {}. Connected={} Starting reconnect timer".format(
self.name, self.state, self.pipeline_root.connected
)
)

self._clear_reconnect_timer()

Expand All @@ -1054,23 +1069,22 @@ def _start_reconnect_timer(self):
@pipeline_thread.invoke_on_pipeline_thread_nowait
def on_reconnect_timer_expired():
this = self_weakref()
this.reconnect_timer = None
if this.state == ReconnectState.WAITING_TO_RECONNECT:
logger.info(
"{}: State is {}. Reconnect timer expired. Sending connect op down".format(
this.name, this.state
)
logger.debug(
"{}: Reconnect timer expired. State is {} Connected is {}.".format(
self.name, self.state, self.pipeline_root.connected
)
this.state = ReconnectState.CONNECTED_OR_DISCONNECTED
)

this.reconnect_timer = None
if (
this.state == ReconnectState.WAITING_TO_RECONNECT
and not self.pipeline_root.connected
):
# if we're waiting to reconnect and not connected, we try again
this.state = ReconnectState.LOGICALLY_CONNECTED
this._send_new_connect_op_down()
else:
logger.info(
"{}: State is {}. Reconnect timer expired. Doing nothing".format(
this.name, this.state
)
)

self.reconnect_timer = threading.Timer(self.reconnect_delay, on_reconnect_timer_expired)
self.reconnect_timer = threading.Timer(delay, on_reconnect_timer_expired)
self.reconnect_timer.start()

@pipeline_thread.runs_on_pipeline_thread
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

logger = logging.getLogger(__name__)

# Maximum amount of time we wait for ConnectOperation or ReauthorizeConnectionOperation to complete
# Maximum amount of time we wait for ConnectOperation to complete
WATCHDOG_INTERVAL = 10


Expand Down Expand Up @@ -187,34 +187,10 @@ def _run_op(self, op):
self._pending_connection_op = None
op.complete(error=e)

elif isinstance(op, pipeline_ops_base.ReauthorizeConnectionOperation):
logger.info("{}({}): reauthorizing".format(self.name, op.name))

# We set _active_connect_op here because reauthorizing the connection is the same as a connect for "active operation" tracking purposes.
self._cancel_pending_connection_op()
self._pending_connection_op = op
self._start_connection_watchdog(op)
try:
self.transport.reauthorize_connection(
password=str(self.pipeline_root.pipeline_configuration.sastoken)
)
except Exception as e:
logger.error("transport.reauthorize_connection raised error")
logger.error(traceback.format_exc())
self._cancel_connection_watchdog(op)
self._pending_connection_op = None
# Send up a DisconnectedEvent. If we ran a ReauthorizeConnectionOperation,
# some code must think we're still connected. If we got an exception here,
# we're not connected, and we need to notify upper layers. (Paho should do this,
# but it only causes a DisconnectedEvent on manual disconnect or if a PINGRESP
# failed, and it's possible to hit this code without either of those things
# happening.)
if isinstance(e, transport_exceptions.ConnectionDroppedError):
self.send_event_up(pipeline_events_base.DisconnectedEvent())
op.complete(error=e)

elif isinstance(op, pipeline_ops_base.DisconnectOperation):
logger.info("{}({}): disconnecting".format(self.name, op.name))
elif isinstance(op, pipeline_ops_base.DisconnectOperation) or isinstance(
op, pipeline_ops_base.ReauthorizeConnectionOperation
):
logger.info("{}({}): disconnecting or reauthorizing".format(self.name, op.name))

self._cancel_pending_connection_op()
self._pending_connection_op = op
Expand Down Expand Up @@ -300,11 +276,7 @@ def _on_mqtt_connected(self):
# we do anything else (in case upper stages have any "are we connected" logic.
self.send_event_up(pipeline_events_base.ConnectedEvent())

if isinstance(
self._pending_connection_op, pipeline_ops_base.ConnectOperation
) or isinstance(
self._pending_connection_op, pipeline_ops_base.ReauthorizeConnectionOperation
):
if isinstance(self._pending_connection_op, pipeline_ops_base.ConnectOperation):
logger.debug("completing connect op")
op = self._pending_connection_op
self._cancel_connection_watchdog(op)
Expand All @@ -326,11 +298,7 @@ def _on_mqtt_connection_failure(self, cause):

logger.info("{}: _on_mqtt_connection_failure called: {}".format(self.name, cause))

if isinstance(
self._pending_connection_op, pipeline_ops_base.ConnectOperation
) or isinstance(
self._pending_connection_op, pipeline_ops_base.ReauthorizeConnectionOperation
):
if isinstance(self._pending_connection_op, pipeline_ops_base.ConnectOperation):
logger.debug("{}: failing connect op".format(self.name))
op = self._pending_connection_op
self._cancel_connection_watchdog(op)
Expand Down Expand Up @@ -369,7 +337,9 @@ def _on_mqtt_disconnected(self, cause=None):
self._cancel_connection_watchdog(op)
self._pending_connection_op = None

if isinstance(op, pipeline_ops_base.DisconnectOperation):
if isinstance(op, pipeline_ops_base.DisconnectOperation) or isinstance(
op, pipeline_ops_base.ReauthorizeConnectionOperation
):
# Swallow any errors if we intended to disconnect - even if something went wrong, we
# got to the state we wanted to be in!
if cause:
Expand Down
Loading

0 comments on commit 123c045

Please sign in to comment.