Skip to content

Commit

Permalink
Fix race in stunnel port selection
Browse files Browse the repository at this point in the history
  • Loading branch information
tsmetana authored and Cappuccinuo committed Dec 2, 2022
1 parent e629560 commit 478f009
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 87 deletions.
171 changes: 92 additions & 79 deletions src/mount_efs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,7 @@ def get_tls_port_range(config):
return lower_bound, upper_bound


def choose_tls_port(config, options):
def choose_tls_port_and_bind_sock(state_file_dir, fs_id, mountpoint, config, options):
if "tlsport" in options:
ports_to_try = [int(options["tlsport"])]
else:
Expand All @@ -954,13 +954,14 @@ def choose_tls_port(config, options):
assert len(tls_ports) == len(ports_to_try)

if "netns" not in options:
tls_port = find_tls_port_in_range(ports_to_try)
tls_port_sock = find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try)
else:
with NetNS(nspath=options["netns"]):
tls_port = find_tls_port_in_range(ports_to_try)
tls_port_sock = find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try)

if tls_port:
return tls_port
if tls_port_sock:
tls_port = tls_port_sock.getsockname()[1]
return tls_port_sock, tls_port

if "tlsport" in options:
fatal_error(
Expand All @@ -974,14 +975,18 @@ def choose_tls_port(config, options):
)


def find_tls_port_in_range(ports_to_try):
def find_tls_port_in_range(state_file_dir, fs_id, mountpoint, ports_to_try):
sock = socket.socket()
for tls_port in ports_to_try:
mount_filename = get_mount_specific_filename(fs_id, mountpoint, tls_port)
config_file = get_stunnel_config_filename(state_file_dir, mount_filename)
if os.access(config_file, os.R_OK):
logging.info("confifguration for port %s already exists, trying another port", tls_port)
continue
try:
logging.info("binding %s", tls_port)
sock.bind(("localhost", tls_port))
sock.close()
return tls_port
return sock
except socket.error as e:
logging.info(e)
continue
Expand Down Expand Up @@ -1262,9 +1267,7 @@ def write_stunnel_config_file(
)
logging.debug("Writing stunnel configuration:\n%s", stunnel_config)

stunnel_config_file = os.path.join(
state_file_dir, "stunnel-config.%s" % mount_filename
)
stunnel_config_file = get_stunnel_config_filename(state_file_dir, mount_filename)

with open(stunnel_config_file, "w") as f:
f.write(stunnel_config)
Expand Down Expand Up @@ -1464,6 +1467,10 @@ def create_required_directory(config, directory):
raise


def get_stunnel_config_filename(state_file_dir, mount_filename):
return os.path.join(state_file_dir, "stunnel-config.%s" % mount_filename)


@contextmanager
def bootstrap_tls(
config,
Expand All @@ -1475,82 +1482,88 @@ def bootstrap_tls(
state_file_dir=STATE_FILE_DIR,
fallback_ip_address=None,
):
tls_port = choose_tls_port(config, options)
# override the tlsport option so that we can later override the port the NFS client uses to connect to stunnel.
# if the user has specified tlsport=X at the command line this will just re-set tlsport to X.
options["tlsport"] = tls_port

use_iam = "iam" in options
ap_id = options.get("accesspoint")
cert_details = {}
security_credentials = None
client_info = get_client_info(config)
region = get_target_region(config)

if use_iam:
aws_creds_uri = options.get("awscredsuri")
if aws_creds_uri:
kwargs = {"aws_creds_uri": aws_creds_uri}
else:
kwargs = {"awsprofile": get_aws_profile(options, use_iam)}
sock, tls_port = choose_tls_port_and_bind_sock(state_file_dir, fs_id, mountpoint, config, options)
try:
# override the tlsport option so that we can later override the port the NFS client uses to connect to stunnel.
# if the user has specified tlsport=X at the command line this will just re-set tlsport to X.
options["tlsport"] = tls_port

use_iam = "iam" in options
ap_id = options.get("accesspoint")
cert_details = {}
security_credentials = None
client_info = get_client_info(config)
region = get_target_region(config)

if use_iam:
aws_creds_uri = options.get("awscredsuri")
if aws_creds_uri:
kwargs = {"aws_creds_uri": aws_creds_uri}
else:
kwargs = {"awsprofile": get_aws_profile(options, use_iam)}

security_credentials, credentials_source = get_aws_security_credentials(
config, use_iam, region, **kwargs
)
security_credentials, credentials_source = get_aws_security_credentials(
config, use_iam, region, **kwargs
)

if credentials_source:
cert_details["awsCredentialsMethod"] = credentials_source
if credentials_source:
cert_details["awsCredentialsMethod"] = credentials_source

if ap_id:
cert_details["accessPoint"] = ap_id
if ap_id:
cert_details["accessPoint"] = ap_id

# additional symbol appended to avoid naming collisions
cert_details["mountStateDir"] = (
get_mount_specific_filename(fs_id, mountpoint, tls_port) + "+"
)
# common name for certificate signing request is max 64 characters
cert_details["commonName"] = socket.gethostname()[0:64]
region = get_target_region(config)
cert_details["region"] = region
cert_details["certificateCreationTime"] = create_certificate(
config,
cert_details["mountStateDir"],
cert_details["commonName"],
cert_details["region"],
fs_id,
security_credentials,
ap_id,
client_info,
base_path=state_file_dir,
)
cert_details["certificate"] = os.path.join(
state_file_dir, cert_details["mountStateDir"], "certificate.pem"
)
cert_details["privateKey"] = get_private_key_path()
cert_details["fsId"] = fs_id
# additional symbol appended to avoid naming collisions
cert_details["mountStateDir"] = (
get_mount_specific_filename(fs_id, mountpoint, tls_port) + "+"
)
# common name for certificate signing request is max 64 characters
cert_details["commonName"] = socket.gethostname()[0:64]
region = get_target_region(config)
cert_details["region"] = region
cert_details["certificateCreationTime"] = create_certificate(
config,
cert_details["mountStateDir"],
cert_details["commonName"],
cert_details["region"],
fs_id,
security_credentials,
ap_id,
client_info,
base_path=state_file_dir,
)
cert_details["certificate"] = os.path.join(
state_file_dir, cert_details["mountStateDir"], "certificate.pem"
)
cert_details["privateKey"] = get_private_key_path()
cert_details["fsId"] = fs_id

start_watchdog(init_system)
start_watchdog(init_system)

if not os.path.exists(state_file_dir):
create_required_directory(config, state_file_dir)
if not os.path.exists(state_file_dir):
create_required_directory(config, state_file_dir)

verify_level = int(options.get("verify", DEFAULT_STUNNEL_VERIFY_LEVEL))
ocsp_enabled = is_ocsp_enabled(config, options)
verify_level = int(options.get("verify", DEFAULT_STUNNEL_VERIFY_LEVEL))
ocsp_enabled = is_ocsp_enabled(config, options)

stunnel_config_file = write_stunnel_config_file(
config,
state_file_dir,
fs_id,
mountpoint,
tls_port,
dns_name,
verify_level,
ocsp_enabled,
options,
region,
cert_details=cert_details,
fallback_ip_address=fallback_ip_address,
)
stunnel_config_file = write_stunnel_config_file(
config,
state_file_dir,
fs_id,
mountpoint,
tls_port,
dns_name,
verify_level,
ocsp_enabled,
options,
region,
cert_details=cert_details,
fallback_ip_address=fallback_ip_address,
)
except Exception as e:
logging.error("Error while creating the configuration file: %s" % e)
finally:
# close the socket now, so the stunnel process can bind to the port
sock.close()
tunnel_args = [_stunnel_bin(), stunnel_config_file]
if "netns" in options:
tunnel_args = ["nsenter", "--net=" + options["netns"]] + tunnel_args
Expand Down
3 changes: 3 additions & 0 deletions test/mount_efs_test/test_bootstrap_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ def test_bootstrap_tls_non_default_port(mocker, tmpdir):
popen_mock, write_config_mock = setup_mocks(mocker)
mocker.patch("os.rename")
state_file_dir = str(tmpdir)
fake_sock = MagicMock()
fake_sock.getsockname.return_value = ("localhost", 1000)
mocker.patch("socket.socket", return_value=fake_sock)

tls_port = 1000
mocker.patch("mount_efs._stunnel_bin", return_value="/usr/bin/stunnel")
Expand Down
45 changes: 37 additions & 8 deletions test/mount_efs_test/test_choose_tls_port.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# the License.

import socket
import random
from unittest.mock import MagicMock

import pytest
Expand All @@ -20,6 +21,9 @@

DEFAULT_TLS_PORT_RANGE_LOW = 20049
DEFAULT_TLS_PORT_RANGE_HIGH = 20449
FS_ID = "fs-deadbeef"
MOUNT_POINT = "/mnt"
STATE_FILE_DIR = "/tmp"


def _get_config():
Expand All @@ -42,22 +46,27 @@ def _get_config():


def test_choose_tls_port_first_try(mocker):
mocker.patch("socket.socket", return_value=MagicMock())
fake_sock = MagicMock()
tls_port = random.randrange(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH)
fake_sock.getsockname.return_value = ("localhost", tls_port)
mocker.patch("socket.socket", return_value=fake_sock)
options = {}

tls_port = mount_efs.choose_tls_port(_get_config(), options)
sock, tls_port = mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)

assert DEFAULT_TLS_PORT_RANGE_LOW <= tls_port <= DEFAULT_TLS_PORT_RANGE_HIGH


def test_choose_tls_port_second_try(mocker):
bad_sock = MagicMock()
bad_sock.bind.side_effect = [socket.error, None]
tls_port = random.randrange(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH)
bad_sock.getsockname.return_value = ("localhost", tls_port)
options = {}

mocker.patch("socket.socket", return_value=bad_sock)

tls_port = mount_efs.choose_tls_port(_get_config(), options)
sock, tls_port = mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)

assert DEFAULT_TLS_PORT_RANGE_LOW <= tls_port <= DEFAULT_TLS_PORT_RANGE_HIGH
assert 2 == bad_sock.bind.call_count
Expand All @@ -71,7 +80,7 @@ def test_choose_tls_port_never_succeeds(mocker, capsys):
mocker.patch("socket.socket", return_value=bad_sock)

with pytest.raises(SystemExit) as ex:
mount_efs.choose_tls_port(_get_config(), options)
mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)

assert 0 != ex.value.code

Expand All @@ -85,10 +94,12 @@ def test_choose_tls_port_never_succeeds(mocker, capsys):


def test_choose_tls_port_option_specified(mocker):
mocker.patch("socket.socket", return_value=MagicMock())
fake_sock = MagicMock()
fake_sock.getsockname.return_value = ("localhost", 1000)
mocker.patch("socket.socket", return_value=fake_sock)
options = {"tlsport": 1000}

tls_port = mount_efs.choose_tls_port(_get_config(), options)
sock, tls_port = mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)

assert 1000 == tls_port

Expand All @@ -101,7 +112,7 @@ def test_choose_tls_port_option_specified_unavailable(mocker, capsys):
mocker.patch("socket.socket", return_value=bad_sock)

with pytest.raises(SystemExit) as ex:
mount_efs.choose_tls_port(_get_config(), options)
mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)

assert 0 != ex.value.code

Expand All @@ -117,7 +128,7 @@ def test_choose_tls_port_under_netns(mocker, capsys):
mocker.patch("socket.socket", return_value=MagicMock())
options = {"netns": "/proc/1000/ns/net"}

mount_efs.choose_tls_port(_get_config(), options)
mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)
utils.assert_called(setns_mock)


Expand All @@ -130,3 +141,21 @@ def test_verify_tls_port(mocker):
result = mount_efs.verify_tlsport_can_be_connected(1000)
assert result is True
assert 2 == sock.connect.call_count

def test_choose_tls_port_already_configured(mocker, capsys):
fake_sock = MagicMock()
tls_port = random.randrange(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH)
fake_sock.getsockname.return_value = ("localhost", tls_port)
mocker.patch("socket.socket", return_value=fake_sock)
access_mock = mocker.patch("os.access", return_value=True)
options = {}

with pytest.raises(SystemExit) as ex:
mount_efs.choose_tls_port_and_bind_sock(STATE_FILE_DIR, FS_ID, MOUNT_POINT, _get_config(), options)

assert 0 != ex.value.code

out, err = capsys.readouterr()
assert "Failed to locate an available port" in err

utils.assert_called_n_times(access_mock, DEFAULT_TLS_PORT_RANGE_HIGH - DEFAULT_TLS_PORT_RANGE_LOW)

0 comments on commit 478f009

Please sign in to comment.