Skip to content

Commit

Permalink
Fix race in stunnel port selection
Browse files Browse the repository at this point in the history
  • Loading branch information
tsmetana committed May 31, 2022
1 parent b1b91c5 commit 69d41d6
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 20 deletions.
33 changes: 21 additions & 12 deletions src/mount_efs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,7 +929,7 @@ def get_tls_port_range(config):
return lower_bound, upper_bound


def choose_tls_port(config, options):
def choose_tls_port(state_file_dir, fs_id, mountpoint, config, options):
if "tlsport" in options:
ports_to_try = [int(options["tlsport"])]
else:
Expand All @@ -944,13 +944,13 @@ def choose_tls_port(config, options):
assert len(tls_ports) == len(ports_to_try)

if "netns" not in options:
tls_port = find_tls_port_in_range(ports_to_try)
sock = find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try)
else:
with NetNS(nspath=options["netns"]):
tls_port = find_tls_port_in_range(ports_to_try)
sock = find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try)

if tls_port:
return tls_port
if sock:
return sock

if "tlsport" in options:
fatal_error(
Expand All @@ -964,14 +964,18 @@ def choose_tls_port(config, options):
)


def find_tls_port_in_range(ports_to_try):
def find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try):
sock = socket.socket()
for tls_port in ports_to_try:
mount_filename = get_mount_specific_filename(fs_id, mountpoint, tls_port)
config_file = get_stunnel_config_filename(state_file_dir, mount_filename)
if os.access(config_file, os.R_OK):
logging.info("confifguration for port %s already exists, trying another port", tls_port)
continue
try:
logging.info("binding %s", tls_port)
sock.bind(("localhost", tls_port))
sock.close()
return tls_port
return sock
except socket.error as e:
logging.info(e)
continue
Expand Down Expand Up @@ -1219,9 +1223,7 @@ def write_stunnel_config_file(
)
logging.debug("Writing stunnel configuration:\n%s", stunnel_config)

stunnel_config_file = os.path.join(
state_file_dir, "stunnel-config.%s" % mount_filename
)
stunnel_config_file = get_stunnel_config_filename(state_file_dir, mount_filename)

with open(stunnel_config_file, "w") as f:
f.write(stunnel_config)
Expand Down Expand Up @@ -1419,6 +1421,10 @@ def create_required_directory(config, directory):
raise


def get_stunnel_config_filename(state_file_dir, mount_filename):
return os.path.join(state_file_dir, "stunnel-config.%s" % mount_filename)


@contextmanager
def bootstrap_tls(
config,
Expand All @@ -1430,7 +1436,8 @@ def bootstrap_tls(
state_file_dir=STATE_FILE_DIR,
fallback_ip_address=None,
):
tls_port = choose_tls_port(config, options)
sock = choose_tls_port(state_file_dir, fs_id, mountpoint, config, options)
tls_port = sock.getsockname()[1]
# override the tlsport option so that we can later override the port the NFS client uses to connect to stunnel.
# if the user has specified tlsport=X at the command line this will just re-set tlsport to X.
options["tlsport"] = tls_port
Expand Down Expand Up @@ -1506,6 +1513,8 @@ def bootstrap_tls(
cert_details=cert_details,
fallback_ip_address=fallback_ip_address,
)
# close the socket now, so the stunnel process can bind to the port
sock.close()
tunnel_args = [_stunnel_bin(), stunnel_config_file]
if "netns" in options:
tunnel_args = ["nsenter", "--net=" + options["netns"]] + tunnel_args
Expand Down
3 changes: 3 additions & 0 deletions test/mount_efs_test/test_bootstrap_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ def test_bootstrap_tls_non_default_port(mocker, tmpdir):
popen_mock, write_config_mock = setup_mocks(mocker)
mocker.patch("os.rename")
state_file_dir = str(tmpdir)
fake_sock = MagicMock()
fake_sock.getsockname.return_value = ("localhost", 1000)
mocker.patch("socket.socket", return_value=fake_sock)

tls_port = 1000
mocker.patch("mount_efs._stunnel_bin", return_value="/usr/bin/stunnel")
Expand Down
48 changes: 40 additions & 8 deletions test/mount_efs_test/test_choose_tls_port.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# the License.

import socket
import random
from unittest.mock import MagicMock

import pytest
Expand All @@ -20,6 +21,9 @@

DEFAULT_TLS_PORT_RANGE_LOW = 20049
DEFAULT_TLS_PORT_RANGE_HIGH = 20449
FS_ID = "fs-deadbeef"
MOUNT_POINT = "/mnt"
STATE_FILE_DIR = "/tmp"


def _get_config():
Expand All @@ -42,22 +46,29 @@ def _get_config():


def test_choose_tls_port_first_try(mocker):
mocker.patch("socket.socket", return_value=MagicMock())
fake_sock = MagicMock()
tls_port = random.randrange(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH)
fake_sock.getsockname.return_value = ("localhost", tls_port)
mocker.patch("socket.socket", return_value=fake_sock)
options = {}

tls_port = mount_efs.choose_tls_port(_get_config(), options)
sock = mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)
tls_port = sock.getsockname()[1]

assert DEFAULT_TLS_PORT_RANGE_LOW <= tls_port <= DEFAULT_TLS_PORT_RANGE_HIGH


def test_choose_tls_port_second_try(mocker):
bad_sock = MagicMock()
bad_sock.bind.side_effect = [socket.error, None]
tls_port = random.randrange(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH)
bad_sock.getsockname.return_value = ("localhost", tls_port)
options = {}

mocker.patch("socket.socket", return_value=bad_sock)

tls_port = mount_efs.choose_tls_port(_get_config(), options)
sock = mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)
tls_port = sock.getsockname()[1]

assert DEFAULT_TLS_PORT_RANGE_LOW <= tls_port <= DEFAULT_TLS_PORT_RANGE_HIGH
assert 2 == bad_sock.bind.call_count
Expand All @@ -71,7 +82,7 @@ def test_choose_tls_port_never_succeeds(mocker, capsys):
mocker.patch("socket.socket", return_value=bad_sock)

with pytest.raises(SystemExit) as ex:
mount_efs.choose_tls_port(_get_config(), options)
mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)

assert 0 != ex.value.code

Expand All @@ -85,10 +96,13 @@ def test_choose_tls_port_never_succeeds(mocker, capsys):


def test_choose_tls_port_option_specified(mocker):
mocker.patch("socket.socket", return_value=MagicMock())
fake_sock = MagicMock()
fake_sock.getsockname.return_value = ("localhost", 1000)
mocker.patch("socket.socket", return_value=fake_sock)
options = {"tlsport": 1000}

tls_port = mount_efs.choose_tls_port(_get_config(), options)
sock = mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)
tls_port = sock.getsockname()[1]

assert 1000 == tls_port

Expand All @@ -101,7 +115,7 @@ def test_choose_tls_port_option_specified_unavailable(mocker, capsys):
mocker.patch("socket.socket", return_value=bad_sock)

with pytest.raises(SystemExit) as ex:
mount_efs.choose_tls_port(_get_config(), options)
mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)

assert 0 != ex.value.code

Expand All @@ -117,7 +131,7 @@ def test_choose_tls_port_under_netns(mocker, capsys):
mocker.patch("socket.socket", return_value=MagicMock())
options = {"netns": "/proc/1000/ns/net"}

mount_efs.choose_tls_port(_get_config(), options)
mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)
utils.assert_called(setns_mock)


Expand All @@ -130,3 +144,21 @@ def test_verify_tls_port(mocker):
result = mount_efs.verify_tlsport_can_be_connected(1000)
assert result is True
assert 2 == sock.connect.call_count

def test_choose_tls_port_already_configured(mocker, capsys):
fake_sock = MagicMock()
tls_port = random.randrange(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH)
fake_sock.getsockname.return_value = ("localhost", tls_port)
mocker.patch("socket.socket", return_value=fake_sock)
access_mock = mocker.patch("os.access", return_value=True)
options = {}

with pytest.raises(SystemExit) as ex:
mount_efs.choose_tls_port(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)

assert 0 != ex.value.code

out, err = capsys.readouterr()
assert "Failed to locate an available port" in err

utils.assert_called_n_times(access_mock, DEFAULT_TLS_PORT_RANGE_HIGH - DEFAULT_TLS_PORT_RANGE_LOW)

0 comments on commit 69d41d6

Please sign in to comment.