Skip to content

Commit

Permalink
Merge pull request #231 from dhalbert/partial-send
Browse files Browse the repository at this point in the history
handle partial socket send()'s
  • Loading branch information
dhalbert authored Jan 2, 2025
2 parents 76f8c28 + 9be1a4c commit 75f3845
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 24 deletions.
51 changes: 33 additions & 18 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,21 @@ def connect(
raise MMQTTException(exc_msg) from last_exception
raise MMQTTException(exc_msg)

def _send_bytes(
self,
buffer: Union[bytes, bytearray, memoryview],
):
bytes_sent: int = 0
bytes_to_send = len(buffer)
view = memoryview(buffer)
while bytes_sent < bytes_to_send:
try:
bytes_sent += self._sock.send(view[bytes_sent:])
except OSError as exc:
if exc.errno == EAGAIN:
continue
raise

def _connect( # noqa: PLR0912, PLR0915, Too many branches, Too many statements
self,
clean_session: bool = True,
Expand Down Expand Up @@ -529,8 +544,8 @@ def _connect( # noqa: PLR0912, PLR0915, Too many branches, Too many statements
self.logger.debug("Sending CONNECT to broker...")
self.logger.debug(f"Fixed Header: {fixed_header}")
self.logger.debug(f"Variable Header: {var_header}")
self._sock.send(fixed_header)
self._sock.send(var_header)
self._send_bytes(fixed_header)
self._send_bytes(var_header)
# [MQTT-3.1.3-4]
self._send_str(self.client_id)
if self._lw_topic:
Expand Down Expand Up @@ -591,7 +606,7 @@ def disconnect(self) -> None:
self._connected()
self.logger.debug("Sending DISCONNECT packet to broker")
try:
self._sock.send(MQTT_DISCONNECT)
self._send_bytes(MQTT_DISCONNECT)
except (MemoryError, OSError, RuntimeError) as e:
self.logger.warning(f"Unable to send DISCONNECT packet: {e}")
self._close_socket()
Expand All @@ -608,7 +623,7 @@ def ping(self) -> list[int]:
"""
self._connected()
self.logger.debug("Sending PINGREQ")
self._sock.send(MQTT_PINGREQ)
self._send_bytes(MQTT_PINGREQ)
ping_timeout = self.keep_alive
stamp = ticks_ms()

Expand Down Expand Up @@ -683,9 +698,9 @@ def publish( # noqa: PLR0912, Too many branches
qos,
retain,
)
self._sock.send(pub_hdr_fixed)
self._sock.send(pub_hdr_var)
self._sock.send(msg)
self._send_bytes(pub_hdr_fixed)
self._send_bytes(pub_hdr_var)
self._send_bytes(msg)
self._last_msg_sent_timestamp = ticks_ms()
if qos == 0 and self.on_publish is not None:
self.on_publish(self, self.user_data, topic, self._pid)
Expand Down Expand Up @@ -749,12 +764,12 @@ def subscribe( # noqa: PLR0912, PLR0915, Too many branches, Too many statements
packet_length += sum(len(topic.encode("utf-8")) for topic, qos in topics)
self._encode_remaining_length(fixed_header, remaining_length=packet_length)
self.logger.debug(f"Fixed Header: {fixed_header}")
self._sock.send(fixed_header)
self._send_bytes(fixed_header)
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
packet_id_bytes = self._pid.to_bytes(2, "big")
var_header = packet_id_bytes
self.logger.debug(f"Variable Header: {var_header}")
self._sock.send(var_header)
self._send_bytes(var_header)
# attaching topic and QOS level to the packet
payload = b""
for t, q in topics:
Expand All @@ -764,7 +779,7 @@ def subscribe( # noqa: PLR0912, PLR0915, Too many branches, Too many statements
for t, q in topics:
self.logger.debug(f"SUBSCRIBING to topic {t} with QoS {q}")
self.logger.debug(f"payload: {payload}")
self._sock.send(payload)
self._send_bytes(payload)
stamp = ticks_ms()
self._last_msg_sent_timestamp = stamp
while True:
Expand Down Expand Up @@ -829,19 +844,19 @@ def unsubscribe( # noqa: PLR0912, Too many branches
packet_length += sum(len(topic.encode("utf-8")) for topic in topics)
self._encode_remaining_length(fixed_header, remaining_length=packet_length)
self.logger.debug(f"Fixed Header: {fixed_header}")
self._sock.send(fixed_header)
self._send_bytes(fixed_header)
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
packet_id_bytes = self._pid.to_bytes(2, "big")
var_header = packet_id_bytes
self.logger.debug(f"Variable Header: {var_header}")
self._sock.send(var_header)
self._send_bytes(var_header)
payload = b""
for t in topics:
topic_size = len(t.encode("utf-8")).to_bytes(2, "big")
payload += topic_size + t.encode()
for t in topics:
self.logger.debug(f"UNSUBSCRIBING from topic {t}")
self._sock.send(payload)
self._send_bytes(payload)
self._last_msg_sent_timestamp = ticks_ms()
self.logger.debug("Waiting for UNSUBACK...")
while True:
Expand Down Expand Up @@ -1028,7 +1043,7 @@ def _wait_for_msg( # noqa: PLR0912, Too many branches
if res[0] & 0x06 == 0x02:
pkt = bytearray(b"\x40\x02\0\0")
struct.pack_into("!H", pkt, 2, pid)
self._sock.send(pkt)
self._send_bytes(pkt)
elif res[0] & 6 == 4:
assert 0

Expand Down Expand Up @@ -1109,11 +1124,11 @@ def _send_str(self, string: str) -> None:
"""
if isinstance(string, str):
self._sock.send(struct.pack("!H", len(string.encode("utf-8"))))
self._sock.send(str.encode(string, "utf-8"))
self._send_bytes(struct.pack("!H", len(string.encode("utf-8"))))
self._send_bytes(str.encode(string, "utf-8"))
else:
self._sock.send(struct.pack("!H", len(string)))
self._sock.send(string)
self._send_bytes(struct.pack("!H", len(string)))
self._send_bytes(string)

@staticmethod
def _valid_topic(topic: str) -> None:
Expand Down
10 changes: 4 additions & 6 deletions tests/test_recv_timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from unittest import TestCase, main
from unittest.mock import Mock

from mocket import Mocket

import adafruit_minimqtt.adafruit_minimqtt as MQTT


Expand All @@ -34,7 +36,7 @@ def test_recv_timeout_vs_keepalive(self) -> None:
)

# Create a mock socket that will accept anything and return nothing.
socket_mock = Mock()
socket_mock = Mocket(b"")
socket_mock.recv_into = Mock(side_effect=side_effect)
mqtt_client._sock = socket_mock

Expand All @@ -43,12 +45,8 @@ def test_recv_timeout_vs_keepalive(self) -> None:
with self.assertRaises(MQTT.MMQTTException):
mqtt_client.ping()

# Verify the mock interactions.
socket_mock.send.assert_called_once()
socket_mock.recv_into.assert_called()

now = time.monotonic()
assert recv_timeout <= (now - start) <= (keep_alive + 0.1)
assert recv_timeout <= (now - start) <= (keep_alive + 0.2)


if __name__ == "__main__":
Expand Down

0 comments on commit 75f3845

Please sign in to comment.