diff --git a/.circleci/config.yml b/.circleci/config.yml index c59a7922..1d7fda93 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -37,10 +37,20 @@ commands: name: Repo update command: | apt-get update + - run: + name: Install curl + command: | + apt-get -y install curl + - run: + name: Install latest Rust + command: | + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + . "$HOME/.cargo/env" - run: name: Install dependencies command: | - apt-get -y install binutils git + DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends tzdata + apt-get -y install binutils git rustc cargo pkg-config libssl-dev - run: name: Add local build repo as safe git directory command: | @@ -49,6 +59,9 @@ commands: - run: name: Build DEB command: | + . "$HOME/.cargo/env" + rustc --version + cargo --version ./build-deb.sh - run: name: Install package @@ -64,7 +77,7 @@ commands: - run: name: Install dependencies command: | - yum -y install rpm-build make systemd + yum -y install rpm-build make systemd rust cargo openssl-devel - run: name: Build RPM command: | @@ -81,6 +94,40 @@ commands: name: Check changelog command: | rpm -q --changelog amazon-efs-utils + build-rpm-rustup: + steps: + - run: + name: Install dependencies + command: | + yum install -y curl + - run: + name: Install latest Rust + command: | + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + - checkout + - run: + name: Install dependencies + command: | + yum -y install rpm-build make systemd rust cargo openssl-devel + - run: + name: Build RPM + command: | + . "$HOME/.cargo/env" + rustc --version + make rpm + - run: + name: Install package + command: | + yum -y install build/amazon-efs-utils*rpm + - run: + name: Check installed successfully + command: | + mount.efs --version + - run: + name: Check changelog + command: | + rpm -q --changelog amazon-efs-utils + build-suse-rpm: steps: - checkout @@ -88,14 +135,24 @@ commands: name: Refresh source command: | zypper refresh + - run: + name: Install curl + command: | + zypper install -y curl + - run: + name: Install latest Rust + command: | + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y - run: name: Install dependencies command: | zypper install -y --force-resolution rpm-build - zypper install -y make systemd + zypper install -y make systemd rust cargo openssl-devel - run: name: Build RPM command: | + . "$HOME/.cargo/env" + rustc --version make rpm - run: name: Install package @@ -116,14 +173,6 @@ commands: command: | sed -i 's/mirrorlist/#mirrorlist/g' /etc/yum.repos.d/CentOS-* sed -i 's|#baseurl=http://mirror.centos.org|baseurl=http://vault.centos.org|g' /etc/yum.repos.d/CentOS-* - build-debian-eol-repo: - steps: - - run: - name: change repo url to archive.debian.org and remove updates repo for EOL versions - command: | - sed -i 's/deb.debian.org/archive.debian.org/g' /etc/apt/sources.list - sed -i 's/security.debian.org/archive.debian.org/g' /etc/apt/sources.list - sed -i '/stretch-updates/d' /etc/apt/sources.list jobs: test: parameters: @@ -152,7 +201,7 @@ jobs: image: << parameters.image >> steps: - build-rpm - build-suse-rpm-package: + build-rpm-package-rustup: parameters: image: type: string @@ -160,8 +209,8 @@ jobs: name: linux image: << parameters.image >> steps: - - build-suse-rpm - build-centos-rpm-package: + - build-rpm-rustup + build-suse-rpm-package: parameters: image: type: string @@ -169,9 +218,8 @@ jobs: name: linux image: << parameters.image >> steps: - - build-centos-repo - - build-rpm - build-debian-eol-rpm-package: + - build-suse-rpm + build-centos-rpm-package: parameters: image: type: string @@ -179,8 +227,8 @@ jobs: name: linux image: << parameters.image >> steps: - - build-debian-eol-repo - - build-deb + - build-centos-repo + - build-rpm-rustup workflows: workflow: jobs: @@ -217,21 +265,12 @@ workflows: - build-deb-package: name: ubuntu22 image: ubuntu:22.04 - - build-debian-eol-rpm-package: - name: debian9 - image: debian:stretch - - build-deb-package: - name: debian10 - image: debian:buster - build-deb-package: name: debian11 image: debian:bullseye - build-centos-rpm-package: name: centos-latest image: centos:latest - - build-rpm-package: - name: centos7 - image: centos:centos7 - build-centos-rpm-package: name: centos8 image: centos:centos8 @@ -244,31 +283,25 @@ workflows: - build-rpm-package: name: amazon-linux-2 image: amazonlinux:2 - - build-rpm-package: - name: amazon-linux - image: amazonlinux:1 - build-rpm-package: name: fedora-latest image: fedora:latest - - build-rpm-package: - name: fedora28 - image: fedora:28 - - build-rpm-package: + - build-rpm-package-rustup: name: fedora29 image: fedora:29 - - build-rpm-package: + - build-rpm-package-rustup: name: fedora30 image: fedora:30 - - build-rpm-package: + - build-rpm-package-rustup: name: fedora31 image: fedora:31 - - build-rpm-package: + - build-rpm-package-rustup: name: fedora32 image: fedora:32 - - build-rpm-package: + - build-rpm-package-rustup: name: fedora33 image: fedora:33 - - build-rpm-package: + - build-rpm-package-rustup: name: fedora34 image: fedora:34 - build-rpm-package: diff --git a/Makefile b/Makefile index fd48bfc1..3cc4e47f 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,7 @@ PACKAGE_NAME = amazon-efs-utils SOURCE_TARBALL = $(PACKAGE_NAME).tar.gz SPECFILE = $(PACKAGE_NAME).spec BUILD_DIR = build/rpmbuild +PROXY_VERSION = 2.0.0 export PYTHONPATH := $(shell pwd)/src .PHONY: clean @@ -31,6 +32,7 @@ tarball: clean mkdir -p $(PACKAGE_NAME)/src cp -rp src/mount_efs $(PACKAGE_NAME)/src cp -rp src/watchdog $(PACKAGE_NAME)/src + cp -rp src/proxy $(PACKAGE_NAME)/src mkdir -p ${PACKAGE_NAME}/man cp -rp man/mount.efs.8 ${PACKAGE_NAME}/man @@ -45,7 +47,8 @@ rpm-only: mkdir -p $(BUILD_DIR)/{SPECS,COORD_SOURCES,DATA_SOURCES,BUILD,RPMS,SOURCES,SRPMS} cp $(SPECFILE) $(BUILD_DIR)/SPECS cp $(SOURCE_TARBALL) $(BUILD_DIR)/SOURCES - rpmbuild -ba --define "_topdir `pwd`/$(BUILD_DIR)" $(BUILD_DIR)/SPECS/$(SPECFILE) + cp config.toml $(BUILD_DIR)/SOURCES + rpmbuild -ba --define "_topdir `pwd`/$(BUILD_DIR)" --define "include_vendor_tarball false" $(BUILD_DIR)/SPECS/$(SPECFILE) cp $(BUILD_DIR)/RPMS/*/*rpm build .PHONY: rpm diff --git a/README.md b/README.md index 88691852..dac59fae 100644 --- a/README.md +++ b/README.md @@ -8,21 +8,17 @@ The `efs-utils` package has been verified against the following Linux distributi | Distribution | Package Type | `init` System | |----------------------| ----- | --------- | -| Amazon Linux 2017.09 | `rpm` | `upstart` | | Amazon Linux 2 | `rpm` | `systemd` | | Amazon Linux 2023 | `rpm` | `systemd` | -| CentOS 7 | `rpm` | `systemd` | | CentOS 8 | `rpm` | `systemd` | | RHEL 7 | `rpm` | `systemd` | | RHEL 8 | `rpm` | `systemd` | | RHEL 9 | `rpm` | `systemd` | -| Fedora 28 | `rpm` | `systemd` | | Fedora 29 | `rpm` | `systemd` | | Fedora 30 | `rpm` | `systemd` | | Fedora 31 | `rpm` | `systemd` | | Fedora 32 | `rpm` | `systemd` | -| Debian 9 | `deb` | `systemd` | -| Debian 10 | `deb` | `systemd` | +| Debian 11 | `deb` | `systemd` | | Ubuntu 16.04 | `deb` | `systemd` | | Ubuntu 18.04 | `deb` | `systemd` | | Ubuntu 20.04 | `deb` | `systemd` | @@ -55,6 +51,7 @@ The `efs-utils` package has been verified against the following MacOS distributi - [MacOS](#macos) - [amazon-efs-mount-watchdog](#amazon-efs-mount-watchdog) - [Troubleshooting](#troubleshooting) + - [Upgrading to efs-utils v2.0.0](#upgrading-from-efs-utils-v1-to-v2) - [Upgrading stunnel for RHEL/CentOS](#upgrading-stunnel-for-rhelcentos) - [Upgrading stunnel for SLES12](#upgrading-stunnel-for-sles12) - [Upgrading stunnel for MacOS](#upgrading-stunnel-for-macos) @@ -81,9 +78,11 @@ The `efs-utils` package has been verified against the following MacOS distributi ## Prerequisites * `nfs-utils` (RHEL/CentOS/Amazon Linux/Fedora) or `nfs-common` (Debian/Ubuntu) -* OpenSSL 1.0.2+ +* OpenSSL-devel 1.0.2+ * Python 3.4+ * `stunnel` 4.56+ +- `rust` 1.68+ +- `cargo` ## Optional @@ -93,7 +92,7 @@ The `efs-utils` package has been verified against the following MacOS distributi ### On Amazon Linux distributions -For those using Amazon Linux or Amazon Linux 2, the easiest way to install `efs-utils` is from Amazon's repositories: +For those using Amazon Linux, the easiest way to install `efs-utils` is from Amazon's repositories: ```bash $ sudo yum -y install amazon-efs-utils @@ -121,7 +120,7 @@ Other distributions require building the package from source and installing it. If the distribution is not OpenSUSE or SLES ```bash -$ sudo yum -y install git rpm-build make +$ sudo yum -y install git rpm-build make rust cargo openssl-devel $ git clone /~https://github.com/aws/efs-utils $ cd efs-utils $ make rpm @@ -132,7 +131,7 @@ Otherwise ```bash $ sudo zypper refresh -$ sudo zypper install -y git rpm-build make +$ sudo zypper install -y git rpm-build make rust cargo openssl-devel $ git clone /~https://github.com/aws/efs-utils $ cd efs-utils $ make rpm @@ -152,13 +151,20 @@ sudo zypper refresh ```bash $ sudo apt-get update -$ sudo apt-get -y install git binutils +$ sudo apt-get -y install git binutils rustc cargo pkg-config libssl-dev $ git clone /~https://github.com/aws/efs-utils $ cd efs-utils $ ./build-deb.sh $ sudo apt-get -y install ./build/amazon-efs-utils*deb ``` +If your Debian distribution doesn't provide a rust or cargo package, or your distribution provides versions +that are older than 1.68, then you can install rust and cargo through rustup: +```bash +$ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh +. "$HOME/.cargo/env" +``` + ### On MacOS Big Sur, macOS Monterey, macOS Sonoma and macOS Ventura distribution For EC2 Mac instances running macOS Big Sur, macOS Monterey, macOS Sonoma and macOS Ventura, you can install amazon-efs-utils from the @@ -194,8 +200,10 @@ $ make test ## Usage ### mount.efs +`efs-utils` includes a mount helper utility, `mount.efs`, that simplifies and improves the performance of EFS file system mounts. -`efs-utils` includes a mount helper utility to simplify mounting and using EFS file systems. +`mount.efs` launches a proxy process that forwards NFS traffic from the kernel's NFS client to EFS. +This proxy is responsible for TLS encryption, and for providing improved throughput performance. To mount with the recommended default options, simply run: @@ -318,6 +326,16 @@ You can also enable stunnel debug logs with Make sure to perform the failed mount again after running the prior commands before pulling the logs. +## Upgrading from efs-utils v1 to v2 +Efs-utils v2.0.0 replaces stunnel, which provides TLS encryptions for mounts, with efs-proxy, a component built in-house at AWS. +Efs-proxy lays the foundation for upcoming feature launches at EFS. + +To utilize the improved performance benefits of efs-proxy, you must re-mount any existing mounts. + +Efs-proxy is not compatible with OCSP or Mac clients. In these cases, efs-utils will automatically revert back to using stunnel. + +If you are building efs-utils v2.0.0 from source, then you need Rust and Cargo >= 1.68. + ## Upgrading stunnel for RHEL/CentOS By default, when using the EFS mount helper with TLS, it enforces certificate hostname checking. The EFS mount helper uses the `stunnel` program for its TLS functionality. Please note that some versions of Linux do not include a version of `stunnel` that supports TLS features by default. When using such a Linux version, mounting an EFS file system using TLS will fail. diff --git a/amazon-efs-utils.spec b/amazon-efs-utils.spec index ff6eabdd..658ad89b 100644 --- a/amazon-efs-utils.spec +++ b/amazon-efs-utils.spec @@ -6,6 +6,8 @@ # the License. # +%bcond_without check + %if 0%{?amzn1} %global python_requires python36 %else @@ -34,8 +36,13 @@ %global efs_bindir /sbin %endif +%global proxy_name efs-proxy +%global proxy_version 2.0.0 + +%{?!include_vendor_tarball:%define include_vendor_tarball true} + Name : amazon-efs-utils -Version : 1.36.0 +Version : 2.0.0 Release : 1%{platform} Summary : This package provides utilities for simplifying the use of EFS file systems @@ -43,8 +50,7 @@ Group : Amazon/Tools License : MIT URL : https://aws.amazon.com/efs - -BuildArch : noarch +BuildArchitectures: x86_64 aarch64 Requires : nfs-utils %if 0%{?amzn2} @@ -67,13 +73,32 @@ Requires(preun) : /sbin/service /sbin/chkconfig Requires(postun) : /sbin/service %endif -Source : %{name}.tar.gz +BuildRequires : cargo rust +BuildRequires: openssl-devel + +Source0 : %{name}.tar.gz +%if "%{include_vendor_tarball}" == "true" +Source1 : %{proxy_name}-%{proxy_version}-vendor.tar.xz +Source2 : config.toml +%endif %description This package provides utilities for simplifying the use of EFS file systems +%global debug_package %{nil} + %prep %setup -n %{name} +mkdir -p %{_builddir}/%{name}/src/proxy/.cargo +%if "%{include_vendor_tarball}" == "true" +cp %{SOURCE2} %{_builddir}/%{name}/src/proxy/.cargo/ +tar xf %{SOURCE1} +mv vendor %{_builddir}/%{name}/src/proxy/ +%endif + +%build +cd %{_builddir}/%{name}/src/proxy +cargo build --release --manifest-path %{_builddir}/%{name}/src/proxy/Cargo.toml %install mkdir -p %{buildroot}%{_sysconfdir}/amazon/efs @@ -95,6 +120,7 @@ install -p -m 444 %{_builddir}/%{name}/dist/efs-utils.crt %{buildroot}%{_sysconf install -p -m 755 %{_builddir}/%{name}/src/mount_efs/__init__.py %{buildroot}%{efs_bindir}/mount.efs install -p -m 755 %{_builddir}/%{name}/src/watchdog/__init__.py %{buildroot}%{_bindir}/amazon-efs-mount-watchdog install -p -m 644 %{_builddir}/%{name}/man/mount.efs.8 %{buildroot}%{_mandir}/man8 +install -p -m 755 %{_builddir}/%{name}/src/proxy/target/release/efs-proxy %{buildroot}%{efs_bindir}/efs-proxy %files %defattr(-,root,root,-) @@ -105,6 +131,7 @@ install -p -m 644 %{_builddir}/%{name}/man/mount.efs.8 %{buildroot}%{_mandir}/ma %endif %{_sysconfdir}/amazon/efs/efs-utils.crt %{efs_bindir}/mount.efs +%{efs_bindir}/efs-proxy %{_bindir}/amazon-efs-mount-watchdog /var/log/amazon %{_mandir}/man8/mount.efs.8.gz @@ -138,6 +165,9 @@ fi %clean %changelog +* Mon Apr 08 2024 Ryan Stankiewicz - 2.0.0 +- Replace stunnel, which provides TLS encryptions for mounts, with efs-proxy, a component built in-house at AWS. Efs-proxy lays the foundation for upcoming feature launches at EFS. + * Mon Mar 18 2024 Sean Zatz - 1.36.0 - Support new mount option: crossaccount, conduct cross account mounts via ip address. Use client AZ-ID to choose mount target. diff --git a/build-deb.sh b/build-deb.sh index f2e499d6..d97ab37a 100755 --- a/build-deb.sh +++ b/build-deb.sh @@ -11,7 +11,7 @@ set -ex BASE_DIR=$(pwd) BUILD_ROOT=${BASE_DIR}/build/debbuild -VERSION=1.36.0 +VERSION=2.0.0 RELEASE=1 DEB_SYSTEM_RELEASE_PATH=/etc/os-release @@ -28,12 +28,18 @@ mkdir -p ${BUILD_ROOT}/usr/bin mkdir -p ${BUILD_ROOT}/var/log/amazon/efs mkdir -p ${BUILD_ROOT}/usr/share/man/man8 +echo 'Building efs-proxy' +cd src/proxy +cargo build --release --manifest-path ${BASE_DIR}/src/proxy/Cargo.toml +cd ${BASE_DIR} + echo 'Copying application files' install -p -m 644 dist/amazon-efs-mount-watchdog.conf ${BUILD_ROOT}/etc/init install -p -m 644 dist/amazon-efs-mount-watchdog.service ${BUILD_ROOT}/etc/systemd/system install -p -m 444 dist/efs-utils.crt ${BUILD_ROOT}/etc/amazon/efs install -p -m 644 dist/efs-utils.conf ${BUILD_ROOT}/etc/amazon/efs install -p -m 755 src/mount_efs/__init__.py ${BUILD_ROOT}/sbin/mount.efs +install -p -m 755 src/proxy/target/release/efs-proxy ${BUILD_ROOT}/usr/bin/efs-proxy install -p -m 755 src/watchdog/__init__.py ${BUILD_ROOT}/usr/bin/amazon-efs-mount-watchdog echo 'Copying install scripts' diff --git a/config.ini b/config.ini index 49ba7066..6944864d 100644 --- a/config.ini +++ b/config.ini @@ -7,5 +7,5 @@ # [global] -version=1.36.0 +version=2.0.0 release=1 diff --git a/config.toml b/config.toml new file mode 100644 index 00000000..ba8862ec --- /dev/null +++ b/config.toml @@ -0,0 +1,12 @@ +[source] + +# Under the `source` table are a number of other tables whose keys are a +# name for the relevant source. For example this section defines a new +# source, called `my-vendor-source`, which comes from a directory +# located at `vendor` relative to the directory containing this `.cargo/config.toml` +# file +[source.my-vendor-source] +directory = "vendor" + +[source.crates-io] +replace-with = "my-vendor-source" \ No newline at end of file diff --git a/dist/amazon-efs-utils.control b/dist/amazon-efs-utils.control index d40419ca..d7d71c79 100644 --- a/dist/amazon-efs-utils.control +++ b/dist/amazon-efs-utils.control @@ -1,6 +1,6 @@ Package: amazon-efs-utils Architecture: all -Version: 1.36.0 +Version: 2.0.0 Section: utils Depends: python3, nfs-common, stunnel4 (>= 4.56), openssl (>= 1.0.2), util-linux Priority: optional diff --git a/dist/efs-utils.crt b/dist/efs-utils.crt index fdcc5585..11caee70 100644 --- a/dist/efs-utils.crt +++ b/dist/efs-utils.crt @@ -1,12 +1,3 @@ -# -# Copyright 2017-2018 Amazon.com, Inc. and its affiliates. All Rights Reserved. -# -# Licensed under the MIT License. See the LICENSE accompanying this file -# for the specific language governing permissions and limitations under -# the License. -# - -# Amazon Root CA 1 -----BEGIN CERTIFICATE----- MIIDQTCCAimgAwIBAgITBmyfz5m/jAo54vB4ikPmljZbyjANBgkqhkiG9w0BAQsF ADA5MQswCQYDVQQGEwJVUzEPMA0GA1UEChMGQW1hem9uMRkwFwYDVQQDExBBbWF6 @@ -28,7 +19,6 @@ o/ufQJVtMVT8QtPHRh8jrdkPSHCa2XV4cdFyQzR1bldZwgJcJmApzyMZFo6IQ6XU rqXRfboQnoZsG4q5WTP468SQvvG5 -----END CERTIFICATE----- -# Amazon Root CA 2 -----BEGIN CERTIFICATE----- MIIFQTCCAymgAwIBAgITBmyf0pY1hp8KD+WGePhbJruKNzANBgkqhkiG9w0BAQwF ADA5MQswCQYDVQQGEwJVUzEPMA0GA1UEChMGQW1hem9uMRkwFwYDVQQDExBBbWF6 @@ -61,7 +51,6 @@ n749sSmvZ6ES8lgQGVMDMBu4Gon2nL2XA46jCfMdiyHxtN/kHNGfZQIG6lzWE7OE 4PsJYGw= -----END CERTIFICATE----- -# Amazon Root CA 3 -----BEGIN CERTIFICATE----- MIIBtjCCAVugAwIBAgITBmyf1XSXNmY/Owua2eiedgPySjAKBggqhkjOPQQDAjA5 MQswCQYDVQQGEwJVUzEPMA0GA1UEChMGQW1hem9uMRkwFwYDVQQDExBBbWF6b24g @@ -75,7 +64,6 @@ BqWTrBqYaGFy+uGh0PsceGCmQ5nFuMQCIQCcAu/xlJyzlvnrxir4tiz+OpAUFteM YyRIHN8wfdVoOw== -----END CERTIFICATE----- -# Amazon Root CA 4 -----BEGIN CERTIFICATE----- MIIB8jCCAXigAwIBAgITBmyf18G7EEwpQ+Vxe3ssyBrBDjAKBggqhkjOPQQDAzA5 MQswCQYDVQQGEwJVUzEPMA0GA1UEChMGQW1hem9uMRkwFwYDVQQDExBBbWF6b24g diff --git a/man/mount.efs.8 b/man/mount.efs.8 index 3a18f30a..f962fd97 100644 --- a/man/mount.efs.8 +++ b/man/mount.efs.8 @@ -7,11 +7,14 @@ .SH "DESCRIPTION" .sp \fBmount\&.efs\fR is part of the \fBamazon\-efs\-utils\fR \ -package, which simplifies using EFS file systems\&. +package. It improves mount performance and simplifies using EFS file systems\&. .sp \fBmount\&.efs\fR is meant to be used through the \ \fBmount\fR(8) command for mounting EFS file systems\&. .sp +\fBmount\&.efs\fR launches a proxy process that forwards NFS traffic from the kernel's NFS client to EFS. \ +This proxy is responsible for TLS encryption, and for providing improved throughput performance. +.sp \fIfs-id-or-dns-name\fR has to be of one of the following \ two forms: .P @@ -77,8 +80,9 @@ this option is by default passed and the EFS file system is mounted over TLS\&. Mounts the EFS file system without TLS, applies for Mac distributions only\&. .TP \fBtlsport=\fR\fIn\fR -Configure the TLS relay to listen on the specified port\&. By default, the \ -tlsport is choosing randomly from port range defined in the config file located \ +Configures the proxy process to listen for connections from the NFS client on the specified port\&. This is applicable to both non-tls and tls mounts. + By default, the \ +tlsport is chosen randomly from port range defined in the config file located \ at \fI/etc/amazon/efs/efs\-utils\&.conf\&\fR. .TP \fBverify=\fR\fIn\fR @@ -88,7 +92,9 @@ more information, see \fBstunnel(8)\fR\&. \fBocsp / noocsp\fR Selects whether to perform OCSP validation on TLS certificates\&, \ overriding /etc/amazon/efs/efs-utils.conf. By default OCSP is disabled. \ -For more information, see \fBstunnel(8)\fR\&. +For more information, see \fBstunnel(8)\fR\&. \ +The ocsp mount option is incompatible with the efs-proxy process, and will revert efs-utils \ +to the legacy "stunnel" mode, which does not support improved per-client throughput performance. .TP \fBiam\fR Use the system's IAM identity to authenticate with EFS. The mount helper will try \ @@ -132,6 +138,12 @@ Use the port 2049 to bypass portmapper daemon on EC2 Mac instances running macOS .TP \fBmounttargetip\fR Mount the EFS file system to the specified mount target ip address\&. +.TP +\fBstunnel\fR +Forward NFS traffic from the local NFS client to EFS using stunnel instead of efs-proxy. +This will enable compatibility with the ocsp mount option, but will not +deliver the increased throughput performance provided by efs-proxy. \ +This option is enabled by default for Mac clients. .if n \{\ .RE .\} diff --git a/src/mount_efs/__init__.py b/src/mount_efs/__init__.py index b6a91462..c8b7566c 100755 --- a/src/mount_efs/__init__.py +++ b/src/mount_efs/__init__.py @@ -85,7 +85,7 @@ BOTOCORE_PRESENT = False -VERSION = "1.36.0" +VERSION = "2.0.0" SERVICE = "elasticfilesystem" AMAZON_LINUX_2_RELEASE_ID = "Amazon Linux release 2 (Karoo)" @@ -222,9 +222,13 @@ DEFAULT_STUNNEL_VERIFY_LEVEL = 2 DEFAULT_STUNNEL_CAFILE = "/etc/amazon/efs/efs-utils.crt" +LEGACY_STUNNEL_MOUNT_OPTION = "stunnel" + NOT_BEFORE_MINS = 15 NOT_AFTER_HOURS = 3 +EFS_PROXY_TLS_OPTION = "--tls" + EFS_ONLY_OPTIONS = [ "accesspoint", "awscredsuri", @@ -244,6 +248,7 @@ "jwtpath", "fsap", "crossaccount", + LEGACY_STUNNEL_MOUNT_OPTION, ] UNSUPPORTED_OPTIONS = ["capath"] @@ -1039,6 +1044,11 @@ def get_resp_obj(request_resp, url, unsuccessful_resp): def parse_options(options): + """ + Parses a comma delineated string of key=value options (e.g. 'opt1,opt2=val'). + Returns a dictionary of key,value pairs, where value = None if + it was not provided. + """ opts = {} for o in options.split(","): if "=" in o: @@ -1172,7 +1182,8 @@ def serialize_stunnel_config(config, header=None): return lines -def add_stunnel_ca_options(efs_config, config, options, region): +# These options are used by both stunnel and efs-proxy for TLS mounts +def add_tunnel_ca_options(efs_config, config, options, region): if "cafile" in options: stunnel_cafile = options["cafile"] else: @@ -1257,6 +1268,11 @@ def _stunnel_bin(): return find_command_path("stunnel", installation_message) +def _efs_proxy_bin(): + error_message = "The efs-proxy binary is packaged with efs-utils. It was deleted or not installed correctly." + return find_command_path("efs-proxy", error_message) + + def find_command_path(command, install_method): # If not running on macOS, use linux paths if not check_if_platform_is_mac(): @@ -1314,6 +1330,7 @@ def write_stunnel_config_file( log_dir=LOG_DIR, cert_details=None, fallback_ip_address=None, + efs_proxy_enabled=True, ): """ Serializes stunnel configuration to a file. Unfortunately this does not conform to Python's config file format, so we have to @@ -1326,12 +1343,13 @@ def write_stunnel_config_file( system_release_version = get_system_release_version() global_config = dict(STUNNEL_GLOBAL_CONFIG) - if is_stunnel_option_supported( + if not efs_proxy_enabled and is_stunnel_option_supported( stunnel_options, b"foreground", b"quiet", emit_warning_log=False ): # Do not log to stderr of subprocess in addition to the destinations specified with syslog and output. # Only support in stunnel version 5.25+. global_config["foreground"] = "quiet" + if any( release in system_release_version for release in SKIP_NO_SO_BINDTODEVICE_RELEASES @@ -1350,12 +1368,17 @@ def write_stunnel_config_file( CONFIG_SECTION, "stunnel_logs_file" ).replace("{fs_id}", fs_id) else: + proxy_log_file = ( + "%s.efs-proxy.log" if efs_proxy_enabled else "%s.stunnel.log" + ) global_config["output"] = os.path.join( - log_dir, "%s.stunnel.log" % mount_filename + log_dir, proxy_log_file % mount_filename ) + global_config["pid"] = os.path.join( state_file_dir, mount_filename + "+", "stunnel.pid" ) + if get_fips_config(config): global_config["fips"] = "yes" @@ -1367,9 +1390,11 @@ def write_stunnel_config_file( else: efs_config["connect"] = efs_config["connect"] % dns_name - efs_config["verify"] = verify_level - if verify_level > 0: - add_stunnel_ca_options(efs_config, config, options, region) + # Verify level is only valid for tls mounts + if (verify_level is not None) and tls_enabled(options): + efs_config["verify"] = verify_level + if verify_level > 0: + add_tunnel_ca_options(efs_config, config, options, region) if cert_details: efs_config["cert"] = cert_details["certificate"] @@ -1381,27 +1406,30 @@ def write_stunnel_config_file( % (CONFIG_FILE, "https://docs.aws.amazon.com/console/efs/troubleshooting-tls") ) - if get_boolean_config_item_value( - config, CONFIG_SECTION, "stunnel_check_cert_hostname", default_value=True - ): - if is_stunnel_option_supported(stunnel_options, b"checkHost"): - # Stunnel checkHost option checks if the specified DNS host name or wildcard matches any of the provider in peer - # certificate's CN fields, after introducing the AZ field in dns name, the host name in the stunnel config file - # is not valid, remove the az info there - efs_config["checkHost"] = dns_name[dns_name.index(fs_id) :] - else: - fatal_error(tls_controls_message % "stunnel_check_cert_hostname") + if tls_enabled(options): + # These config options are not applicable to non-tls mounts with efs-proxy + if get_boolean_config_item_value( + config, CONFIG_SECTION, "stunnel_check_cert_hostname", default_value=True + ): + if (not efs_proxy_enabled) and ( + not is_stunnel_option_supported(stunnel_options, b"checkHost") + ): + fatal_error(tls_controls_message % "stunnel_check_cert_hostname") + else: + efs_config["checkHost"] = dns_name[dns_name.index(fs_id) :] - # Only use the config setting if the override is not set - if ocsp_enabled: - if is_stunnel_option_supported(stunnel_options, b"OCSPaia"): - efs_config["OCSPaia"] = "yes" - else: - fatal_error(tls_controls_message % "stunnel_check_cert_validity") + # Only use the config setting if the override is not set + if not efs_proxy_enabled and ocsp_enabled: + if is_stunnel_option_supported(stunnel_options, b"OCSPaia"): + efs_config["OCSPaia"] = "yes" + else: + fatal_error(tls_controls_message % "stunnel_check_cert_validity") # If the stunnel libwrap option is supported, we disable the usage of /etc/hosts.allow and /etc/hosts.deny by # setting the option to no - if is_stunnel_option_supported(stunnel_options, b"libwrap"): + if not efs_proxy_enabled and is_stunnel_option_supported( + stunnel_options, b"libwrap" + ): efs_config["libwrap"] = "no" stunnel_config = "\n".join( @@ -1420,7 +1448,7 @@ def write_stunnel_config_file( return stunnel_config_file -def write_tls_tunnel_state_file( +def write_tunnel_state_file( fs_id, mountpoint, tls_port, @@ -1433,6 +1461,8 @@ def write_tls_tunnel_state_file( """ Return the name of the temporary file containing TLS tunnel state, prefixed with a '~'. This file needs to be renamed to a non-temporary version following a successful mount. + + The "tunnel" here refers to efs-proxy, or stunnel. """ state_file = "~" + get_mount_specific_filename(fs_id, mountpoint, tls_port) @@ -1453,19 +1483,19 @@ def write_tls_tunnel_state_file( return state_file -def rewrite_tls_tunnel_state_file(state, state_file_dir, state_file): +def rewrite_tunnel_state_file(state, state_file_dir, state_file): with open(os.path.join(state_file_dir, state_file), "w") as f: json.dump(state, f) return state_file -def update_tls_tunnel_temp_state_file_with_tunnel_pid( +def update_tunnel_temp_state_file_with_tunnel_pid( temp_tls_state_file, state_file_dir, stunnel_pid ): with open(os.path.join(state_file_dir, temp_tls_state_file), "r") as f: state = json.load(f) state["pid"] = stunnel_pid - temp_tls_state_file = rewrite_tls_tunnel_state_file( + temp_tls_state_file = rewrite_tunnel_state_file( state, state_file_dir, temp_tls_state_file ) return temp_tls_state_file @@ -1476,9 +1506,9 @@ def test_tunnel_process(tunnel_proc, fs_id): if tunnel_proc.returncode is not None: _, err = tunnel_proc.communicate() fatal_error( - "Failed to initialize TLS tunnel for %s, please check mount.log for the failure reason." + "Failed to initialize tunnel for %s, please check mount.log for the failure reason." % fs_id, - 'Failed to start TLS tunnel (errno=%d), stderr="%s". If the stderr is lacking enough details, please ' + 'Failed to start tunnel (errno=%d), stderr="%s". If the stderr is lacking enough details, please ' "enable stunnel debug log in efs-utils config file and retry the mount to capture more info." % (tunnel_proc.returncode, err.strip()), ) @@ -1642,8 +1672,12 @@ def get_tls_port_from_sock(tls_port_sock): return tls_port_sock.getsockname()[1] +def tls_enabled(options): + return "tls" in options + + @contextmanager -def bootstrap_tls( +def bootstrap_proxy( config, init_system, dns_name, @@ -1652,85 +1686,115 @@ def bootstrap_tls( options, state_file_dir=STATE_FILE_DIR, fallback_ip_address=None, + efs_proxy_enabled=True, ): - tls_port_sock = choose_tls_port_and_get_bind_sock(config, options, state_file_dir) - tls_port = get_tls_port_from_sock(tls_port_sock) + """ + Generates a TLS private key and client-side certificate, a stunnel configuration file, and a state file + that is used to pass information to the Watchdog process. + + This function will spin up a stunnel or efs-proxy process, and pass it the stunnel configuration file. + The client-side certificate generated by this function contains IAM information that can be used by the EFS backend to enforce + file system policies. + + The state file passes information about the mount and the associated proxy process (whether that's stunnel or efs-proxy) to + the Watchdog daemon service. This allows Watchdog to monitor the proxy process's health. + + This function will yield a handle on the proxy process, whether it's efs-proxy or stunnel. + """ + + proxy_listen_sock = choose_tls_port_and_get_bind_sock( + config, options, state_file_dir + ) + proxy_listen_port = get_tls_port_from_sock(proxy_listen_sock) try: # override the tlsport option so that we can later override the port the NFS client uses to connect to stunnel. # if the user has specified tlsport=X at the command line this will just re-set tlsport to X. - options["tlsport"] = tls_port + options["tlsport"] = proxy_listen_port use_iam = "iam" in options ap_id = options.get("accesspoint") - cert_details = {} + cert_details = None security_credentials = None client_info = get_client_info(config) region = get_target_region(config) - if use_iam: - aws_creds_uri = options.get("awscredsuri") - role_arn = options.get("rolearn") - jwt_path = options.get("jwtpath") - if aws_creds_uri: - kwargs = {"aws_creds_uri": aws_creds_uri} - elif role_arn and jwt_path: - kwargs = {"role_arn": role_arn, "jwt_path": jwt_path} - else: - kwargs = {"awsprofile": get_aws_profile(options, use_iam)} - - security_credentials, credentials_source = get_aws_security_credentials( - config, use_iam, region, **kwargs - ) + if tls_enabled(options): + cert_details = {} + # IAM can only be used for tls mounts + if use_iam: + aws_creds_uri = options.get("awscredsuri") + role_arn = options.get("rolearn") + jwt_path = options.get("jwtpath") + if aws_creds_uri: + kwargs = {"aws_creds_uri": aws_creds_uri} + elif role_arn and jwt_path: + kwargs = {"role_arn": role_arn, "jwt_path": jwt_path} + else: + kwargs = {"awsprofile": get_aws_profile(options, use_iam)} - if credentials_source: - cert_details["awsCredentialsMethod"] = credentials_source - logging.debug( - "AWS credentials source used for IAM authentication: ", - credentials_source, + security_credentials, credentials_source = get_aws_security_credentials( + config, use_iam, region, **kwargs ) - if ap_id: - cert_details["accessPoint"] = ap_id + if credentials_source: + cert_details["awsCredentialsMethod"] = credentials_source + logging.debug( + "AWS credentials source used for IAM authentication: ", + credentials_source, + ) - # additional symbol appended to avoid naming collisions - cert_details["mountStateDir"] = ( - get_mount_specific_filename(fs_id, mountpoint, tls_port) + "+" - ) - # common name for certificate signing request is max 64 characters - cert_details["commonName"] = socket.gethostname()[0:64] - cert_details["region"] = region - cert_details["certificateCreationTime"] = create_certificate( - config, - cert_details["mountStateDir"], - cert_details["commonName"], - cert_details["region"], - fs_id, - security_credentials, - ap_id, - client_info, - base_path=state_file_dir, - ) - cert_details["certificate"] = os.path.join( - state_file_dir, cert_details["mountStateDir"], "certificate.pem" - ) - cert_details["privateKey"] = get_private_key_path() - cert_details["fsId"] = fs_id + # Access points must be mounted over TLS + if ap_id: + cert_details["accessPoint"] = ap_id + + # additional symbol appended to avoid naming collisions + cert_details["mountStateDir"] = ( + get_mount_specific_filename(fs_id, mountpoint, proxy_listen_port) + "+" + ) + # common name for certificate signing request is max 64 characters + cert_details["commonName"] = socket.gethostname()[0:64] + cert_details["region"] = region + cert_details["certificateCreationTime"] = create_certificate( + config, + cert_details["mountStateDir"], + cert_details["commonName"], + cert_details["region"], + fs_id, + security_credentials, + ap_id, + client_info, + base_path=state_file_dir, + ) + cert_details["certificate"] = os.path.join( + state_file_dir, cert_details["mountStateDir"], "certificate.pem" + ) + cert_details["privateKey"] = get_private_key_path() + cert_details["fsId"] = fs_id if not os.path.exists(state_file_dir): create_required_directory(config, state_file_dir) start_watchdog(init_system) - verify_level = int(options.get("verify", DEFAULT_STUNNEL_VERIFY_LEVEL)) + verify_level = ( + int(options.get("verify", DEFAULT_STUNNEL_VERIFY_LEVEL)) + if tls_enabled(options) + else None + ) + ocsp_enabled = is_ocsp_enabled(config, options) + if ocsp_enabled: + assert ( + not efs_proxy_enabled + ), "OCSP is not supported by efs-proxy, and efs-utils failed to revert to stunnel-mode." stunnel_config_file = write_stunnel_config_file( config, state_file_dir, fs_id, mountpoint, - tls_port, + proxy_listen_port, dns_name, verify_level, ocsp_enabled, @@ -1738,16 +1802,31 @@ def bootstrap_tls( region, cert_details=cert_details, fallback_ip_address=fallback_ip_address, - ) - tunnel_args = [_stunnel_bin(), stunnel_config_file] + efs_proxy_enabled=efs_proxy_enabled, + ) + if efs_proxy_enabled: + if "tls" in options: + tunnel_args = [ + _efs_proxy_bin(), + stunnel_config_file, + EFS_PROXY_TLS_OPTION, + ] + else: + tunnel_args = [ + _efs_proxy_bin(), + stunnel_config_file, + ] + else: + tunnel_args = [_stunnel_bin(), stunnel_config_file] + if "netns" in options: tunnel_args = ["nsenter", "--net=" + options["netns"]] + tunnel_args # This temp state file is acting like a tlsport lock file, which is why pid =- 1 - temp_tls_state_file = write_tls_tunnel_state_file( + temp_tls_state_file = write_tunnel_state_file( fs_id, mountpoint, - tls_port, + proxy_listen_port, -1, tunnel_args, [stunnel_config_file], @@ -1755,13 +1834,21 @@ def bootstrap_tls( cert_details=cert_details, ) finally: - # Always close the socket we created when choosing TLS port only until now to - # 1. avoid concurrent TLS mount port collision 2. enable stunnel process to bind the port - logging.debug("Closing socket used to choose TLS port %s.", tls_port) - tls_port_sock.close() + # When choosing a TLS port for efs-proxy/stunnel to listen on, we open the port to ensure it is free. + # However, we must free it again so efs-proxy/stunnel can bind to it. We make sure to only free it after + # we write the temporary state file, which acts like a tlsport lock file. This ensures we don't encounter + # any race conditions when choosing tls ports during concurrent mounts. + logging.debug( + "Closing socket used to choose proxy listen port %s.", proxy_listen_port + ) + proxy_listen_sock.close() # launch the tunnel in a process group so if it has any child processes, they can be killed easily by the mount watchdog - logging.info('Starting TLS tunnel: "%s"', " ".join(tunnel_args)) + logging.info( + 'Starting %s: "%s"', + "efs-proxy" if efs_proxy_enabled else "stunnel", + " ".join(tunnel_args), + ) tunnel_proc = subprocess.Popen( tunnel_args, stdout=subprocess.DEVNULL, @@ -1769,9 +1856,13 @@ def bootstrap_tls( preexec_fn=os.setsid, close_fds=True, ) - logging.info("Started TLS tunnel, pid: %d", tunnel_proc.pid) + logging.info( + "Started %s, pid: %d", + "efs-proxy" if efs_proxy_enabled else "stunnel", + tunnel_proc.pid, + ) - update_tls_tunnel_temp_state_file_with_tunnel_pid( + update_tunnel_temp_state_file_with_tunnel_pid( temp_tls_state_file, state_file_dir, tunnel_proc.pid ) @@ -1784,6 +1875,8 @@ def bootstrap_tls( try: yield tunnel_proc finally: + # The caller of this function should use this function in the context of a `with` statement + # so that the state file is correctly renamed. os.rename( os.path.join(state_file_dir, temp_tls_state_file), os.path.join(state_file_dir, temp_tls_state_file[1:]), @@ -1813,7 +1906,17 @@ def check_if_nfsvers_is_compatible_with_macos(options): fatal_error("NFSv4.1 is not supported on MacOS, please switch to NFSv4.0") -def get_nfs_mount_options(options): +# Use stunnel instead of efs-proxy for tls mounts, +# and attach non-tls mounts directly to the mount target. +def legacy_stunnel_mode_enabled(options, config): + return ( + LEGACY_STUNNEL_MOUNT_OPTION in options + or check_if_platform_is_mac() + or is_ocsp_enabled(config, options) + ) + + +def get_nfs_mount_options(options, config): # If you change these options, update the man page as well at man/mount.efs.8 if "nfsvers" not in options and "vers" not in options: options["nfsvers"] = "4.1" if not check_if_platform_is_mac() else "4.0" @@ -1838,7 +1941,11 @@ def get_nfs_mount_options(options): if check_if_platform_is_mac(): options["mountport"] = "2049" - if "tls" in options: + if legacy_stunnel_mode_enabled(options, config): + # Non-tls mounts in stunnel mode should not re-map the port + if "tls" in options: + options["port"] = options["tlsport"] + else: options["port"] = options["tlsport"] def to_nfs_option(k, v): @@ -1854,12 +1961,15 @@ def to_nfs_option(k, v): def mount_nfs(config, dns_name, path, mountpoint, options, fallback_ip_address=None): - if "tls" in options: - mount_path = "127.0.0.1:%s" % path - elif fallback_ip_address: - mount_path = "%s:%s" % (fallback_ip_address, path) + if legacy_stunnel_mode_enabled(options, config): + if "tls" in options: + mount_path = "127.0.0.1:%s" % path + elif fallback_ip_address: + mount_path = "%s:%s" % (fallback_ip_address, path) + else: + mount_path = "%s:%s" % (dns_name, path) else: - mount_path = "%s:%s" % (dns_name, path) + mount_path = "127.0.0.1:%s" % path if not check_if_platform_is_mac(): command = [ @@ -1867,13 +1977,13 @@ def mount_nfs(config, dns_name, path, mountpoint, options, fallback_ip_address=N mount_path, mountpoint, "-o", - get_nfs_mount_options(options), + get_nfs_mount_options(options, config), ] else: command = [ "/sbin/mount_nfs", "-o", - get_nfs_mount_options(options), + get_nfs_mount_options(options, config), mount_path, mountpoint, ] @@ -2448,7 +2558,8 @@ def read_config(config_file=CONFIG_FILE): return p -def bootstrap_logging(config, log_dir=LOG_DIR): +# Retrieve and parse the logging level from the config file. +def get_log_level_from_config(config): raw_level = config.get(CONFIG_SECTION, "logging_level") levels = { "debug": logging.DEBUG, @@ -2465,6 +2576,26 @@ def bootstrap_logging(config, log_dir=LOG_DIR): level_error = True level = logging.INFO + return (level, raw_level, level_error) + + +# Convert the log level provided in the config into a log level string +# that is understandable by efs-proxy +def get_efs_proxy_log_level(config): + level, raw_level, level_error = get_log_level_from_config(config) + if level_error: + return "info" + + # Efs-proxy does not have a CRITICAL log level + if level == logging.CRITICAL: + return "error" + + return raw_level.lower() + + +def bootstrap_logging(config, log_dir=LOG_DIR): + level, raw_level, level_error = get_log_level_from_config(config) + max_bytes = config.getint(CONFIG_SECTION, "logging_max_bytes") file_count = config.getint(CONFIG_SECTION, "logging_file_count") @@ -3023,7 +3154,7 @@ def is_nfs_mount(mountpoint): return False -def mount_tls( +def mount_with_proxy( config, init_system, dns_name, @@ -3033,6 +3164,11 @@ def mount_tls( options, fallback_ip_address=None, ): + """ + This function is responsible for launching a efs-proxy process and attaching a NFS mount to that process + over the loopback interface. Efs-proxy is responsible for forwarding NFS operations to EFS. + When the legacy 'stunnel' mount option is used, this function will launch a stunnel process instead of efs-proxy. + """ if os.path.ismount(mountpoint) and is_nfs_mount(mountpoint): sys.stdout.write( "%s is already mounted, please run 'mount' command to verify\n" % mountpoint @@ -3040,7 +3176,10 @@ def mount_tls( logging.warning("%s is already mounted, mount aborted" % mountpoint) return - with bootstrap_tls( + efs_proxy_enabled = not legacy_stunnel_mode_enabled(options, config) + logging.debug("mount_with_proxy: efs_proxy_enabled = %s", efs_proxy_enabled) + + with bootstrap_proxy( config, init_system, dns_name, @@ -3048,6 +3187,7 @@ def mount_tls( mountpoint, options, fallback_ip_address=fallback_ip_address, + efs_proxy_enabled=efs_proxy_enabled, ) as tunnel_proc: mount_completed = threading.Event() t = threading.Thread( @@ -3907,22 +4047,22 @@ def main(): if check_if_platform_is_mac() and "notls" not in options: options["tls"] = None - if "tls" in options: - mount_tls( + if "tls" not in options and legacy_stunnel_mode_enabled(options, config): + mount_nfs( config, - init_system, dns_name, path, - fs_id, mountpoint, options, fallback_ip_address=fallback_ip_address, ) else: - mount_nfs( + mount_with_proxy( config, + init_system, dns_name, path, + fs_id, mountpoint, options, fallback_ip_address=fallback_ip_address, diff --git a/src/proxy/Cargo.toml b/src/proxy/Cargo.toml new file mode 100644 index 00000000..4bbf3042 --- /dev/null +++ b/src/proxy/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "efs-proxy" +edition = "2021" +build = "build.rs" +# The version of efs-proxy is tied to efs-utils. +version = "2.0.0" +publish = false + +[dependencies] +anyhow = "1.0.72" +async-trait = "0.1" +bytes = { version = "1.4.0" } +chrono = "0.4" +clap = { version = "=4.0.0", features = ["derive"] } +fern = "0.6" +futures = "0.3" +log = "0.4" +log4rs = { version = "0", features = ["rolling_file_appender", "compound_policy", "size_trigger", "fixed_window_roller"]} +nix = { version = "0.26.2", features = ["signal"]} +onc-rpc = "0.2.3" +rand = "0.8.5" +s2n-tls = "0.0" +s2n-tls-tokio = "0.0" +s2n-tls-sys = "0.0" +serde = {version="1.0.175",features=["derive"]} +serde_ini = "0.2.0" +thiserror = "1.0.44" +tokio = { version = "1.29.0", features = ["full"] } +tokio-util = "0.7.8" +uuid = { version = "1.4.1", features = ["v4", "fast-rng", "macro-diagnostics"]} +xdr-codec = "0.4.4" + +[dev-dependencies] +test-case = "*" +tokio = { version = "1.29.0", features = ["test-util"] } + +[build-dependencies] +xdrgen = "0.4.4" \ No newline at end of file diff --git a/src/proxy/build.rs b/src/proxy/build.rs new file mode 100644 index 00000000..71e8d0da --- /dev/null +++ b/src/proxy/build.rs @@ -0,0 +1,5 @@ +extern crate xdrgen; + +fn main() { + xdrgen::compile("src/efs_prot.x").expect("xdrgen efs_prot.x failed"); +} diff --git a/src/proxy/src/config_parser.rs b/src/proxy/src/config_parser.rs new file mode 100644 index 00000000..0e26757b --- /dev/null +++ b/src/proxy/src/config_parser.rs @@ -0,0 +1,211 @@ +use log::LevelFilter; +use serde::{Deserialize, Serialize}; +use std::{error::Error, path::Path, str::FromStr}; + +const DEFAULT_LOG_LEVEL: LevelFilter = LevelFilter::Warn; + +fn default_log_level() -> String { + DEFAULT_LOG_LEVEL.to_string() +} + +fn deserialize_bool<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + match s.to_lowercase().as_str() { + "true" | "yes" | "1" => Ok(true), + "false" | "no" | "0" => Ok(false), + _ => Err(serde::de::Error::custom(format!("Invalid value: {}", s))), + } +} + +#[derive(Default, Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +pub struct ProxyConfig { + #[serde(alias = "fips", deserialize_with = "deserialize_bool")] + pub fips: bool, + + /// Logging level. Values should correspond to the log::LevelFilter enum. + #[serde(alias = "debug", default = "default_log_level")] + pub debug: String, + + /// Output path for log files. Logging is disabled if this value is not provided. + #[serde(alias = "output")] + pub output: Option, + + /// The proxy process is responsible for writing it's PID into this file so that the Watchdog + /// process can monitor it + #[serde(alias = "pid")] + pub pid_file_path: String, + + /// This nested structure is required for backwards compatibility + #[serde(alias = "efs")] + pub nested_config: EfsConfig, +} + +impl FromStr for ProxyConfig { + type Err = serde_ini::de::Error; + + fn from_str(s: &str) -> Result { + serde_ini::from_str(s) + } +} + +impl ProxyConfig { + pub fn from_path(config_path: &Path) -> Result> { + let config_string = std::fs::read_to_string(config_path)?; + let config = ProxyConfig::from_str(&config_string)?; + Ok(config) + } +} + +#[derive(Default, Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +pub struct EfsConfig { + /// The mount target address - DNS name or IP address + #[serde(alias = "connect")] + pub mount_target_addr: String, + + /// Listen for and accept connections on the specified host:port + #[serde(alias = "accept")] + pub listen_addr: String, + + /// File path of the file that contains the client-side certificate and public key + #[serde(alias = "cert", default)] + pub client_cert_pem_file: String, + + /// File path of the file that contains the client private key + #[serde(alias = "key", default)] + pub client_private_key_pem_file: String, + + /// The hostname that is expected to be on the TLS certificate that the remote server presents + #[serde(alias = "checkHost", default)] + pub expected_server_hostname_tls: String, + + /// File path of the certificate authority file. + /// This is used to verify the EFS server-side TLS certificate. + #[serde(alias = "CAfile", default)] + pub ca_file: String, +} + +#[cfg(test)] +pub mod tests { + use super::*; + use std::{path::Path, string::String}; + + pub static TEST_CONFIG_PATH: &str = "tests/certs/test_config.ini"; + + pub fn get_test_config() -> ProxyConfig { + ProxyConfig::from_path(&Path::new(TEST_CONFIG_PATH)).expect("Could not parse test config.") + } + + #[test] + fn test_read_config_from_file() { + assert!(ProxyConfig::from_path(&Path::new(TEST_CONFIG_PATH)).is_ok()); + } + + #[test] + fn test_parse_config() { + let config_string = r#"fips = yes +foreground = quiet +socket = l:SO_REUSEADDR=yes +socket = a:SO_BINDTODEVICE=lo +debug = debug +output = /var/log/amazon/efs/fs-12341234.home.ec2-user.efs.21036.efs-proxy.log +pid = /var/run/efs/fs-12341234.home.ec2-user.efs.21036+/stunnel.pid +port = 8081 +initial_partition_ip = 127.0.0.1:2049 + +[efs] +accept = 127.0.0.1:21036 +connect = fs-12341234.efs.us-east-1.amazonaws.com:2049 +sslVersion = TLSv1.2 +renegotiation = no +TIMEOUTbusy = 20 +TIMEOUTclose = 0 +TIMEOUTidle = 70 +delay = yes +verify = 2 +CAfile = /etc/amazon/efs/efs-utils.crt +cert = /var/run/efs/fs-12341234.home.ec2-user.efs.21036+/certificate.pem +key = /etc/amazon/efs/privateKey.pem +checkHost = fs-12341234.efs.us-east-1.amazonaws.com +"#; + + let result_config = ProxyConfig::from_str(&config_string).unwrap(); + let expected_proxy_config = ProxyConfig { + fips: true, + pid_file_path: String::from( + "/var/run/efs/fs-12341234.home.ec2-user.efs.21036+/stunnel.pid", + ), + debug: LevelFilter::Debug.to_string().to_ascii_lowercase(), + output: Some(String::from( + "/var/log/amazon/efs/fs-12341234.home.ec2-user.efs.21036.efs-proxy.log", + )), + nested_config: EfsConfig { + listen_addr: String::from("127.0.0.1:21036"), + mount_target_addr: String::from("fs-12341234.efs.us-east-1.amazonaws.com:2049"), + ca_file: String::from("/etc/amazon/efs/efs-utils.crt"), + client_cert_pem_file: String::from( + "/var/run/efs/fs-12341234.home.ec2-user.efs.21036+/certificate.pem", + ), + client_private_key_pem_file: String::from("/etc/amazon/efs/privateKey.pem"), + expected_server_hostname_tls: String::from( + "fs-12341234.efs.us-east-1.amazonaws.com", + ), + }, + }; + + assert_eq!(result_config, expected_proxy_config); + } + + #[test] + fn test_parse_config_fips_disabled() { + let config_string = r#"fips = no +foreground = quiet +socket = l:SO_REUSEADDR=yes +socket = a:SO_BINDTODEVICE=lo +pid = /var/run/efs/fs-12341234.home.ec2-user.efs.21036+/stunnel.pid +port = 8081 +initial_partition_ip = 127.0.0.1:2049 + +[efs] +accept = 127.0.0.1:21036 +connect = fs-12341234.efs.us-east-1.amazonaws.com:2049 +sslVersion = TLSv1.2 +renegotiation = no +TIMEOUTbusy = 20 +TIMEOUTclose = 0 +TIMEOUTidle = 70 +delay = yes +verify = 2 +CAfile = /etc/amazon/efs/efs-utils.crt +cert = /var/run/efs/fs-12341234.home.ec2-user.efs.21036+/certificate.pem +key = /etc/amazon/efs/privateKey.pem +checkHost = fs-12341234.efs.us-east-1.amazonaws.com +"#; + + let result_config = ProxyConfig::from_str(&config_string).unwrap(); + let expected_proxy_config = ProxyConfig { + fips: false, + pid_file_path: String::from( + "/var/run/efs/fs-12341234.home.ec2-user.efs.21036+/stunnel.pid", + ), + debug: DEFAULT_LOG_LEVEL.to_string(), + output: None, + nested_config: EfsConfig { + listen_addr: String::from("127.0.0.1:21036"), + mount_target_addr: String::from("fs-12341234.efs.us-east-1.amazonaws.com:2049"), + ca_file: String::from("/etc/amazon/efs/efs-utils.crt"), + client_cert_pem_file: String::from( + "/var/run/efs/fs-12341234.home.ec2-user.efs.21036+/certificate.pem", + ), + client_private_key_pem_file: String::from("/etc/amazon/efs/privateKey.pem"), + expected_server_hostname_tls: String::from( + "fs-12341234.efs.us-east-1.amazonaws.com", + ), + }, + }; + + assert_eq!(result_config, expected_proxy_config); + } +} diff --git a/src/proxy/src/connections.rs b/src/proxy/src/connections.rs new file mode 100644 index 00000000..41b25a4f --- /dev/null +++ b/src/proxy/src/connections.rs @@ -0,0 +1,710 @@ +use crate::efs_prot::{BindClientResponse, BindResponse, ScaleUpConfig}; +use crate::efs_rpc::{self, PartitionId}; +use crate::error::{ConnectError, RpcError}; +use crate::proxy_identifier::ProxyIdentifier; +use crate::{ + controller::Event, shutdown::ShutdownHandle, tls::establish_tls_stream, tls::TlsConfig, +}; +use async_trait::async_trait; +use futures::future; +use log::{debug, info, warn}; +use s2n_tls_tokio::TlsStream; +use std::sync::Arc; +use std::{collections::HashMap, time::Duration}; +use tokio::task::JoinHandle; +use tokio::time::timeout; +use tokio::{ + io::AsyncWriteExt, + io::{AsyncRead, AsyncWrite}, + net::TcpStream, + sync::mpsc, +}; + +const CONCURRENT_ATTEMPT_COUNT: u32 = 3; + +pub const MAX_ATTEMPT_COUNT: u32 = 120; +const SINGLE_CONNECTION_TIMEOUT_SEC: u64 = 15; +pub const MULTIPLEX_CONNECTION_TIMEOUT_SEC: u64 = 15; + +pub trait ProxyStream: AsyncRead + AsyncWrite + Unpin + Send + 'static {} +impl ProxyStream for T {} + +#[async_trait] +pub trait PartitionFinder { + async fn establish_connection( + &self, + proxy_id: ProxyIdentifier, + ) -> Result<(S, Option, Option), ConnectError>; + + async fn spawn_establish_connection_task( + &self, + proxy_id: ProxyIdentifier, + ) -> JoinHandle), ConnectError>>; + + // Establish multiple connections to an EFS "Partition" to enable higher IO throughput. A + // `target` partition should be provided if the proxy owns an existing connection to EFS. When + // provided, the search will prefer to find a connection that maps to this `target` partition. + // This `target` does not represent a hard requirement, as connections mapping to a different + // partition can still be returned. + // + async fn inner_establish_multiplex_connection( + &self, + proxy_id: ProxyIdentifier, + target: Option, + shutdown_handle: ShutdownHandle, + ) -> Result<(PartitionId, Vec, ScaleUpConfig), (ConnectError, Option)> { + let mut connect_futures = Vec::with_capacity(CONCURRENT_ATTEMPT_COUNT as usize); + for _ in 0..CONCURRENT_ATTEMPT_COUNT { + connect_futures.push(self.spawn_establish_connection_task(proxy_id).await); + } + + let mut connected_partitions: HashMap> = HashMap::new(); + + let mut failure_count = 0; + let mut attempt_count = CONCURRENT_ATTEMPT_COUNT; + + let overall_timeout = + tokio::time::sleep(Duration::from_secs(MULTIPLEX_CONNECTION_TIMEOUT_SEC)); + tokio::pin!(overall_timeout); + + loop { + tokio::select! { + (join_result, index, _) = future::select_all(connect_futures.iter_mut()) => { + let Ok(connection_result) = join_result else { + warn!("JoinError encountered during connection search."); + tokio::spawn(shutdown_connections(connected_partitions)); + return Err((ConnectError::MultiplexFailure, None)); + }; + + let (stream, bind_result) = match connection_result { + Ok(r) => r, + Err(ConnectError::IoError(e)) => { + debug!("Retryable ConnectError encountered during connection search. Error: {:?}", e); + failure_count += 1; + self.retry_multiplex_connection_attempt(proxy_id, &mut attempt_count, index, &mut connect_futures).await?; + continue; + }, + Err(e) => { + warn!("Non-retryable ConnectError encountered during connection search. Error: {}", e); + tokio::spawn(shutdown_connections(connected_partitions)); + return Err((ConnectError::MultiplexFailure, None)) + } + }; + + let response = match bind_result { + Ok(r) => r, + Err(RpcError::IoError(e)) => { + debug!("Retryable RpcError encountered during connection search. Error: {:?}", e); + failure_count += 1; + self.retry_multiplex_connection_attempt(proxy_id, &mut attempt_count, index, &mut connect_futures).await?; + continue; + }, + Err(e) => { + warn!("Non-retryable RpcError encountered during connection search. Error: {}", e); + tokio::spawn(shutdown_connections(connected_partitions)); + return Err((ConnectError::MultiplexFailure, None)) + } + }; + + let bind_response = response.bind_response; + let new_scale_up_config = response.scale_up_config; + debug!("Received {}", get_bind_response_string(&bind_response)); + match bind_response { + BindResponse::READY(id) => { + let partition_id = PartitionId { id: id.0 }; + + if Some(partition_id) == target { + debug!("Connection to target partition found. Attempt Count: {}, Failure Count: {}", attempt_count, failure_count); + } else { + debug!("Connection to non-target partition found. Attempt Count: {}, Failure Count: {}", attempt_count, failure_count); + } + + if let Some(mut streams) = connected_partitions.remove(&partition_id) { + streams.push(stream); + + let target_connection_count = if Some(partition_id) == target { + (new_scale_up_config.max_multiplexed_connections - 1) as usize + } else { + new_scale_up_config.max_multiplexed_connections as usize + }; + + if streams.len() >= target_connection_count { + tokio::spawn(shutdown_connections(connected_partitions)); + return Ok((partition_id, streams, new_scale_up_config)); + } else { + connected_partitions.insert(partition_id, streams); + } + } else { + connected_partitions.insert(partition_id, vec!(stream)); + } + }, + BindResponse::RETRY(_) | BindResponse::PREFERRED(_) => (), + BindResponse::RETRY_LATER(_) | BindResponse::ERROR(_) | BindResponse::default => { + tokio::spawn(shutdown_connections(connected_partitions)); + return Err((ConnectError::MultiplexFailure, Some(new_scale_up_config))) + }, + }; + + debug!("Continuing partition search. Attempt Count: {}, Failure Count: {}, Partitions Found: {}", attempt_count, failure_count, connected_partitions.len()); + self.retry_multiplex_connection_attempt(proxy_id, &mut attempt_count, index, &mut connect_futures).await?; + }, + _ = &mut overall_timeout => { + tokio::spawn(shutdown_connections(connected_partitions)); + return Err((ConnectError::Timeout, None)); + }, + _ = shutdown_handle.cancellation_token.cancelled() => { + tokio::spawn(shutdown_connections(connected_partitions)); + return Err((ConnectError::Cancelled, None)); + } + } + } + } + + async fn retry_multiplex_connection_attempt( + &self, + proxy_id: ProxyIdentifier, + attempt_count: &mut u32, + last_failed_index: usize, + connect_futures: &mut Vec< + JoinHandle), ConnectError>>, + >, + ) -> Result<(), (ConnectError, Option)> { + if *attempt_count > MAX_ATTEMPT_COUNT { + return Err((ConnectError::MaxAttemptsExceeded, None)); + } else { + connect_futures.swap_remove(last_failed_index); + connect_futures.push(self.spawn_establish_connection_task(proxy_id).await); + *attempt_count += 1; + Ok(()) + } + } + + // Increase the number of connections to the EFS Service. + async fn scale_up_connection( + &self, + proxy_id: ProxyIdentifier, + partition_id: Option, + notification_queue: mpsc::Sender>, + shutdown_handle: ShutdownHandle, + ) { + let result = match self + .inner_establish_multiplex_connection(proxy_id, partition_id, shutdown_handle) + .await + { + Ok((id, proxy_streams, scale_up_config)) => { + notification_queue + .send(Event::ConnectionSuccess( + Some(id), + proxy_streams, + scale_up_config, + )) + .await + } + Err(e) => { + info!("Attempt to scale up failed: {}", e.0); + notification_queue.send(Event::ConnectionFail(e.1)).await + } + }; + result.unwrap_or_else(|_| warn!("Unable to notify event queue of established connections")); + } +} + +pub fn configure_stream(tcp_stream: TcpStream) -> TcpStream { + match tcp_stream.set_nodelay(true) { + Ok(_) => {} + Err(e) => warn!("Error setting TCP_NODELAY: {}", e), + } + tcp_stream +} + +// Allow for graceful closure of Tls connections +async fn shutdown_connections(connections: HashMap>) { + for streams in connections.into_values() { + for mut stream in streams.into_iter() { + tokio::spawn(async move { + if let Err(e) = stream.shutdown().await { + debug!("Failed to gracefully shutdown connection: {}", e); + } + }); + } + } +} + +// BindResponse in generated by xdrgen and does not implement the Debug or Display traits +pub fn get_bind_response_string(bind_response: &BindResponse) -> String { + match bind_response { + BindResponse::PREFERRED(_partition_id) => String::from("BindResponse::PREFERRED"), + BindResponse::READY(_partition_id) => String::from("BindResponse::READY"), + BindResponse::RETRY(m) => { + if m.is_empty() { + String::from("BindResponse::RETRY") + } else { + format!("BindResponse::RETRY. message: {m}") + } + } + BindResponse::RETRY_LATER(m) => { + if m.is_empty() { + String::from("BindResponse::RETRY_LATER") + } else { + format!("BindResponse::RETRY_LATER. message: {m}") + } + } + BindResponse::ERROR(m) => { + if m.is_empty() { + String::from("BindResponse::ERROR") + } else { + format!("BindResponse::ERROR. message: {m}") + } + } + BindResponse::default => String::from("BindResponse::default"), + } +} + +#[derive(Clone)] +pub struct PlainTextPartitionFinder { + pub mount_target_addr: String, +} + +impl PlainTextPartitionFinder { + async fn establish_plain_text_connection( + mount_target_addr: String, + proxy_id: ProxyIdentifier, + ) -> Result<(TcpStream, Result), ConnectError> { + timeout(Duration::from_secs(SINGLE_CONNECTION_TIMEOUT_SEC), async { + let mut tcp_stream = TcpStream::connect(mount_target_addr).await?; + let response = efs_rpc::bind_client_to_partition(proxy_id, &mut tcp_stream).await; + Ok((configure_stream(tcp_stream), response)) + }) + .await + .map_err(|_| ConnectError::Timeout)? + } +} + +#[async_trait] +impl PartitionFinder for PlainTextPartitionFinder { + async fn establish_connection( + &self, + proxy_id: ProxyIdentifier, + ) -> Result<(TcpStream, Option, Option), ConnectError> { + let (s, bind_result) = + Self::establish_plain_text_connection(self.mount_target_addr.clone(), proxy_id).await?; + match bind_result { + Ok(response) => { + debug!( + "EFS RPC call succeeded while establishing initial connection. Response: {}", + get_bind_response_string(&response.bind_response) + ); + let partition_id = match &response.bind_response { + BindResponse::READY(id) => Some(PartitionId { id: id.0 }), + _ => None, + }; + Ok((s, partition_id, Some(response.scale_up_config))) + } + Err(e) => { + warn!("EFS RPC call errored while establishing initial connection. Error {e}",); + let tcp_stream = TcpStream::connect(self.mount_target_addr.clone()).await?; + return Ok((configure_stream(tcp_stream), None, None)); + } + } + } + + async fn spawn_establish_connection_task( + &self, + proxy_id: ProxyIdentifier, + ) -> JoinHandle), ConnectError>> { + let addr = self.mount_target_addr.clone(); + tokio::spawn(Self::establish_plain_text_connection(addr, proxy_id)) + } +} + +pub struct TlsPartitionFinder { + tls_config: Arc>, +} + +impl TlsPartitionFinder { + pub fn new(tls_config: Arc>) -> Self { + TlsPartitionFinder { tls_config } + } + + async fn establish_tls_connection( + tls_config: TlsConfig, + proxy_id: ProxyIdentifier, + ) -> Result<(TlsStream, Result), ConnectError> { + timeout(Duration::from_secs(SINGLE_CONNECTION_TIMEOUT_SEC), async { + let mut tls_stream = establish_tls_stream(tls_config).await?; + let response = efs_rpc::bind_client_to_partition(proxy_id, &mut tls_stream).await; + Ok((tls_stream, response)) + }) + .await + .map_err(|_| ConnectError::Timeout)? + } +} + +#[async_trait] +impl PartitionFinder> for TlsPartitionFinder { + async fn establish_connection( + &self, + proxy_id: ProxyIdentifier, + ) -> Result< + ( + TlsStream, + Option, + Option, + ), + ConnectError, + > { + let tls_config_copy = self.tls_config.lock().await.clone(); + let (s, bind_result) = Self::establish_tls_connection(tls_config_copy, proxy_id).await?; + let (bind_response, scale_up_config) = match bind_result { + Ok(response) => { + warn!( + "EFS RPC call succeeded while establishing initial connection. Response: {}", + get_bind_response_string(&response.bind_response) + ); + (response.bind_response, Some(response.scale_up_config)) + } + Err(e) => { + warn!("EFS RPC call errored while establishing initial connection. Error {e}",); + let tls_stream = establish_tls_stream(self.tls_config.lock().await.clone()).await?; + return Ok((tls_stream, None, None)); + } + }; + + match bind_response { + BindResponse::READY(id) => Ok((s, Some(PartitionId { id: id.0 }), scale_up_config)), + _ => Ok((s, None, scale_up_config)), + } + } + + async fn spawn_establish_connection_task( + &self, + proxy_id: ProxyIdentifier, + ) -> JoinHandle< + Result<(TlsStream, Result), ConnectError>, + > { + let tls_config_copy = self.tls_config.lock().await.clone(); + tokio::spawn(Self::establish_tls_connection(tls_config_copy, proxy_id)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config_parser::tests::get_test_config; + use crate::connections::PartitionFinder; + use crate::controller::tests::{find_available_port, ServiceAction, TestService}; + use crate::controller::DEFAULT_SCALE_UP_CONFIG; + use crate::ProxyConfig; + use nix::sys::signal::kill; + use nix::sys::signal::Signal; + use std::path::Path; + use std::str::FromStr; + use tokio::signal; + use tokio::sync::Mutex; + use tokio_util::sync::CancellationToken; + use uuid::Uuid; + + const PROXY_ID: ProxyIdentifier = ProxyIdentifier { + uuid: Uuid::from_u128(1 as u128), + incarnation: 0, + }; + + struct MultiplexTest { + service: TestService, + partition_finder: TlsPartitionFinder, + initial_partition_id: PartitionId, + } + + impl MultiplexTest { + async fn new() -> Self { + let service = TestService::new(true).await; + MultiplexTest::new_with_service(service).await + } + + async fn new_with_service(service: TestService) -> Self { + let mut tls_config = TlsConfig::new_from_config(&get_test_config()) + .await + .expect("Failed to acquire TlsConfig."); + tls_config.remote_addr = format!("127.0.0.1:{}", service.listen_port); + + let partition_finder = TlsPartitionFinder::new(Arc::new(Mutex::new(tls_config))); + + let (_s, id, _) = partition_finder + .establish_connection(PROXY_ID.clone()) + .await + .expect("Failed to connect to server"); + + let Some(initial_partition_id) = id else { + panic!("Partition Id not found for initial connection.") + }; + + MultiplexTest { + service, + partition_finder: partition_finder, + initial_partition_id, + } + } + } + + #[tokio::test] + async fn test_establish_multiplex_same_partition_found() { + let test = MultiplexTest::new().await; + + let (shutdown_handle, _waiter) = ShutdownHandle::new(CancellationToken::new()); + + let (new_connnection_id, connections, _) = test + .partition_finder + .inner_establish_multiplex_connection( + PROXY_ID.clone(), + Some(test.initial_partition_id.clone()), + shutdown_handle, + ) + .await + .expect("Could not establish a multiplex connection"); + + assert_eq!(test.initial_partition_id, new_connnection_id); + assert_eq!( + DEFAULT_SCALE_UP_CONFIG.max_multiplexed_connections - 1, + connections.len() as i32 + ); + + test.service.shutdown().await; + } + + #[tokio::test] + async fn test_establish_multiplex_new_partition_found() { + let test = MultiplexTest::new().await; + + let (shutdown_handle, _waiter) = ShutdownHandle::new(CancellationToken::new()); + + test.service + .post_action(ServiceAction::StopPartitionAcceptor( + test.initial_partition_id.clone(), + )) + .await; + + let (new_connnection_id, connections, _) = test + .partition_finder + .inner_establish_multiplex_connection( + PROXY_ID.clone(), + Some(test.initial_partition_id.clone()), + shutdown_handle, + ) + .await + .expect("Could not establish a multiplex connection"); + + assert_eq!( + DEFAULT_SCALE_UP_CONFIG.max_multiplexed_connections, + connections.len() as i32 + ); + assert_ne!(test.initial_partition_id, new_connnection_id); + + test.service.shutdown().await; + } + + #[tokio::test] + async fn test_establish_multiplex_no_target() { + let test = MultiplexTest::new().await; + + let (shutdown_handle, _waiter) = ShutdownHandle::new(CancellationToken::new()); + + let (new_connnection_id, connections, _) = test + .partition_finder + .inner_establish_multiplex_connection(PROXY_ID.clone(), None, shutdown_handle) + .await + .expect("Could not establish a multiplex connection"); + + assert_eq!( + DEFAULT_SCALE_UP_CONFIG.max_multiplexed_connections, + connections.len() as i32 + ); + assert_ne!(test.initial_partition_id, new_connnection_id); + + test.service.shutdown().await; + } + + #[tokio::test] + async fn test_establish_connection_timeout() { + let (_listener, port) = find_available_port().await; + + let error = tokio::spawn(async move { + let partition_finder = PlainTextPartitionFinder { + mount_target_addr: format!("127.0.0.1:{}", port.clone()), + }; + partition_finder + .establish_connection(PROXY_ID.clone()) + .await + }) + .await + .expect("join err"); + + assert!(matches!(error, Err(ConnectError::Timeout))); + } + + #[tokio::test] + async fn test_establish_multiplex_timeout() { + let (_listener, port) = find_available_port().await; + + let error = tokio::spawn(async move { + let (shutdown_handle, _waiter) = ShutdownHandle::new(CancellationToken::new()); + + let partition_finder = PlainTextPartitionFinder { + mount_target_addr: format!("127.0.0.1:{}", port.clone()), + }; + partition_finder + .inner_establish_multiplex_connection(PROXY_ID.clone(), None, shutdown_handle) + .await + }) + .await + .expect("join err"); + + assert!(matches!(error, Err((ConnectError::Timeout, None)))); + } + + #[tokio::test] + async fn test_establish_multiplex_shutdown() { + let (_listener, port) = find_available_port().await; + + let (shutdown_handle, _waiter) = ShutdownHandle::new(CancellationToken::new()); + + let shutdown_handle_clone = shutdown_handle.clone(); + let task = tokio::spawn(async move { + let partition_finder = PlainTextPartitionFinder { + mount_target_addr: format!("127.0.0.1:{}", port.clone()), + }; + partition_finder + .inner_establish_multiplex_connection(PROXY_ID.clone(), None, shutdown_handle_clone) + .await + }); + + shutdown_handle.exit(None).await; + let error = task.await.expect("Unexpected join error"); + + assert!(matches!(error, Err((ConnectError::Cancelled, None)))); + } + + #[tokio::test] + async fn test_scale_up_max_attempts() { + // Create a service in which the all calls of bind_client_to_partition will return a + // different value. Our "TestService" returns these PartitionIds in a round robin fashion, + // and this service will have more PartitionId than MAX_ATTEMPT_COUNT + let service = + TestService::new_with_partition_count((MAX_ATTEMPT_COUNT + 2) as usize, true).await; + + let test = MultiplexTest::new_with_service(service).await; + + let (shutdown_handle, _waiter) = ShutdownHandle::new(CancellationToken::new()); + + let error = test + .partition_finder + .inner_establish_multiplex_connection( + PROXY_ID.clone(), + Some(test.initial_partition_id.clone()), + shutdown_handle.clone(), + ) + .await; + + assert!(matches!( + error, + Err((ConnectError::MaxAttemptsExceeded, None)) + )); + } + + enum BrokenPartitionFinderType { + _ConnectIoError, + _RpcIoError, + RpcNonIoError, + } + + struct BrokenPartitionFinder { + finder_type: BrokenPartitionFinderType, + } + + impl BrokenPartitionFinder { + fn new(finder_type: BrokenPartitionFinderType) -> Self { + Self { finder_type } + } + } + + #[async_trait] + impl PartitionFinder for BrokenPartitionFinder { + async fn establish_connection( + &self, + _proxy_id: ProxyIdentifier, + ) -> Result<(TcpStream, Option, Option), ConnectError> { + unimplemented!() + } + + async fn spawn_establish_connection_task( + &self, + _proxy_id: ProxyIdentifier, + ) -> JoinHandle), ConnectError>> + { + let (_listener, port) = find_available_port().await; + let tcp_stream = TcpStream::connect(("127.0.0.1", port)) + .await + .expect("Could not establish TCP stream."); + let error = match self.finder_type { + BrokenPartitionFinderType::_ConnectIoError => Err(ConnectError::IoError( + tokio::io::ErrorKind::BrokenPipe.into(), + )), + BrokenPartitionFinderType::_RpcIoError => Ok(( + tcp_stream, + Err(RpcError::IoError(tokio::io::ErrorKind::BrokenPipe.into())), + )), + BrokenPartitionFinderType::RpcNonIoError => { + Ok((tcp_stream, Err(RpcError::GarbageArgs))) + } + }; + tokio::spawn(async { error }) + } + } + + #[tokio::test] + async fn test_scale_up_rpc_error() { + let partition_finder = BrokenPartitionFinder::new(BrokenPartitionFinderType::RpcNonIoError); + + let (shutdown_handle, _waiter) = ShutdownHandle::new(CancellationToken::new()); + let error = partition_finder + .inner_establish_multiplex_connection(PROXY_ID.clone(), None, shutdown_handle.clone()) + .await; + + assert!(matches!(error, Err((ConnectError::MultiplexFailure, None)))); + } + + #[tokio::test] + async fn test_reload_certificate() { + let (tx, rx) = tokio::sync::oneshot::channel(); + let mut sigs_hangup_listener = + signal::unix::signal(signal::unix::SignalKind::hangup()).unwrap(); + let config_file_path = Path::new("tests/certs/test_config.ini"); + let config_contents = std::fs::read_to_string(&config_file_path).unwrap(); + let proxy_config = ProxyConfig::from_str(&config_contents).unwrap(); + let mut tls_config = TlsConfig::new_from_config(&proxy_config).await.unwrap(); + tls_config.client_cert = vec![1, 2]; + let old_cert = tls_config.client_cert.clone(); + let tls_config_ptr = Arc::new(Mutex::new(tls_config)); + let cloned_tls_config_ptr = Arc::clone(&tls_config_ptr); + tokio::spawn(async move { + loop { + // Check if the SIGHUP signal is received + if (sigs_hangup_listener.recv().await).is_some() { + //Reloading the TLS configuration + let mut locked_config = cloned_tls_config_ptr.lock().await; + *locked_config = crate::get_tls_config(&proxy_config).await.unwrap(); + tx.send(()).unwrap(); + break; + } + } + }); + let tls_partition_finder = TlsPartitionFinder { + tls_config: tls_config_ptr.clone(), + }; + let _ = kill(nix::unistd::Pid::this(), Signal::SIGHUP); + rx.await.unwrap(); + assert_ne!( + old_cert, + tls_partition_finder.tls_config.lock().await.client_cert + ); + } +} diff --git a/src/proxy/src/controller.rs b/src/proxy/src/controller.rs new file mode 100644 index 00000000..46406e5e --- /dev/null +++ b/src/proxy/src/controller.rs @@ -0,0 +1,1614 @@ +use crate::connections::configure_stream; +use crate::efs_prot::ScaleUpConfig; +use crate::efs_rpc::PartitionId; +use crate::shutdown::ShutdownReason; +use crate::status_reporter::{self, StatusReporter}; +use crate::{ + connections::{PartitionFinder, ProxyStream}, + proxy::{PerformanceStats, Proxy}, + proxy_identifier::ProxyIdentifier, + shutdown::ShutdownHandle, +}; +use log::{debug, error, info, warn}; +use std::{sync::Arc, time::Duration}; +use tokio::{net::TcpListener, sync::mpsc, time::Instant}; +use tokio_util::sync::CancellationToken; + +pub const DEFAULT_SCALE_UP_BACKOFF: Duration = Duration::from_secs(300); + +pub const DEFAULT_SCALE_UP_CONFIG: ScaleUpConfig = ScaleUpConfig { + max_multiplexed_connections: 5, + scale_up_bytes_per_sec_threshold: 300 * 1024 * 1024, + scale_up_threshold_breached_duration_sec: 1, +}; + +#[derive(Debug)] +pub enum Event { + ProxyUpdate(PerformanceStats), + ConnectionSuccess(Option, Vec, ScaleUpConfig), + ConnectionFail(Option), +} + +enum EventResult { + Restart((Option, Vec, Option)), + Ok, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum ConnectionSearchState { + SearchingAdditional(Option), + Stop(Instant), + Idle, +} + +struct IncarnationState { + pub proxy_id: ProxyIdentifier, + pub last_proxy_update: Option<(Instant, PerformanceStats)>, + pub partition_id: Option, + connection_state: ConnectionSearchState, + pub num_connections: u16, + events_tx: mpsc::Sender>, +} + +impl IncarnationState { + fn new( + proxy_id: ProxyIdentifier, + partition_id: Option, + events_tx: mpsc::Sender>, + num_connections: u16, + ) -> Self { + Self { + proxy_id, + last_proxy_update: None, + partition_id, + connection_state: ConnectionSearchState::Idle, + num_connections, + events_tx, + } + } +} + +pub struct Controller { + listener: TcpListener, + partition_finder: Arc + Sync + Send>, + proxy_id: ProxyIdentifier, + scale_up_attempt_count: u64, + restart_count: u64, + scale_up_config: ScaleUpConfig, + status_reporter: StatusReporter, +} + +impl Controller { + pub async fn new( + listen_addr: &str, + partition_finder: Arc + Sync + Send + 'static>, + status_reporter: StatusReporter, + ) -> Self { + let Ok(listener) = TcpListener::bind(listen_addr).await else { + panic!("Failed to bind {}", listen_addr); + }; + + Self { + listener, + partition_finder, + proxy_id: ProxyIdentifier::new(), + scale_up_attempt_count: 0, + restart_count: 0, + scale_up_config: DEFAULT_SCALE_UP_CONFIG, + status_reporter, + } + } + + pub async fn run(mut self, token: CancellationToken) -> Option { + let mut ready_connections = None; + loop { + info!("Starting new incarnation of proxy"); + let nfs_client = match self.listener.accept().await { + Ok((client, socket_addr)) => { + self.proxy_id.increment(); + info!( + "Accepted new connection {:?}, {:?} ", + socket_addr, self.proxy_id + ); + configure_stream(client) + } + Err(e) => { + error!("Failed to establish connection to NFS client. {e}"); + continue; + } + }; + + let peek_result = nfs_client.peek(&mut [0; 1]).await; + if let Ok(0) = peek_result { + // efs-utils performs a test in which it checks if a connection to the proxy port + // can be established. This connection is never used and is immediately closed. + // When this behavior is detected, this loops should be restarted so that another + // connection to the port can be established + debug!("Connection to nfs client was closed before any data was sent to the proxy. This is expected. Restarting controller"); + continue; + } else if let Err(e) = peek_result { + error!("Failed to check if data was sent by the NFS client. {}", e); + return Some(ShutdownReason::UnexpectedError); + } + + let (events_tx, mut events_rx) = mpsc::channel(1024); + let (shutdown, mut waiter) = ShutdownHandle::new(token.child_token()); + + let (partition_id, partition_servers, scale_up_config) = match ready_connections { + Some(connections) => { + ready_connections = None; + connections + } + None => { + match self + .partition_finder + .establish_connection(self.proxy_id) + .await + { + Ok((s, partition_id, scale_up_config)) => { + (partition_id, vec![s], scale_up_config) + } + Err(e) => { + warn!("Failed to establish an initial connection to EFS. Error: {e}",); + continue; + } + } + } + }; + + match partition_id { + Some(id) => debug!("Established initial connection with PartitionId: {id:?}"), + None => debug!("Established initial connection without a PartitionId"), + } + + self.scale_up_config = scale_up_config.unwrap_or(self.scale_up_config); + debug!("ScaleUpConfig: {:#?}", self.scale_up_config); + + let mut state = IncarnationState::new( + self.proxy_id, + partition_id, + events_tx.clone(), + partition_servers.len() as u16, + ); + + let mut proxy = Proxy::new(nfs_client, partition_servers, events_tx, shutdown.clone()); + + loop { + let mut err = Ok(()); + tokio::select! { + _ = self.status_reporter.await_report_request() => { + let report = status_reporter::Report { + proxy_id: state.proxy_id, + partition_id: state.partition_id, + connection_state: state.connection_state.clone(), + num_connections: state.num_connections as usize, + last_proxy_update: state.last_proxy_update, + scale_up_attempt_count: self.scale_up_attempt_count, + restart_count: self.restart_count + }; + self.status_reporter.publish_status(report).await; + } + event = events_rx.recv() => { + if let Some(next_event) = event { + match self.handle_event(next_event, &mut proxy, &mut state, shutdown.clone()).await { + Ok(EventResult::Restart(connections)) => { + debug!("Restarting proxy to use multiple connections"); + ready_connections = Some(connections); + shutdown.exit(Some(ShutdownReason::NeedsRestart)).await; + break; + }, + Ok(EventResult::Ok) => continue, + Err(e) => err = Err(e), + }; + + } else { + err = Err("All senders have closed"); + } + } + _ = shutdown.cancellation_token.cancelled() => { + debug!("Controller exiting due to child exit"); + break; + } + _ = self.listener.accept() => { + warn!("Unexpected connection, ignoring") + } + } + if err.is_err() { + info!("Starting proxy restart due to {}", err.unwrap_err()); + break; + } + } + + if let Some(count) = self.restart_count.checked_add(1) { + self.restart_count = count; + } + + // Ensure that connection(s) to EFS is closed. If we can't successfully stop the proxy, + // then exit from this process and allow the watchdog to restart the efs-proxy program. + // + if let Err(e) = proxy.shutdown().await { + error!("Proxy shutdown failed. {}", e); + return Some(ShutdownReason::UnexpectedError); + }; + + let shutdown_reason = waiter.recv().await; + match shutdown_reason { + Some(ShutdownReason::NeedsRestart) => { + debug!("Proxy restarting with ShutdownReason::NeedsRestart") + } + Some(ShutdownReason::Unmount) => { + debug!("Proxy restarting with ShutdownReason::Unmount") + } + reason => return reason, + } + } + } + + fn should_scale_up(&self, state: &mut IncarnationState, stats: PerformanceStats) -> bool { + if let ConnectionSearchState::Stop(last_failure) = state.connection_state { + if Instant::now().duration_since(last_failure) > DEFAULT_SCALE_UP_BACKOFF { + state.connection_state = ConnectionSearchState::Idle; + } + } + + state.num_connections == 1 + && state.connection_state == ConnectionSearchState::Idle + && stats.get_total_throughput() + >= self.scale_up_config.scale_up_bytes_per_sec_threshold as u64 + } + + async fn handle_event( + &mut self, + event: Event, + proxy: &mut Proxy, + state: &mut IncarnationState, + shutdown_handle: ShutdownHandle, + ) -> Result, &str> { + match event { + Event::ProxyUpdate(stats) => { + info!("Proxy performance: {:?}", stats); + + if self.should_scale_up(state, stats) { + info!("Searching for a new connection"); + if let Some(count) = self.scale_up_attempt_count.checked_add(1) { + self.scale_up_attempt_count = count; + } + + state.connection_state = + ConnectionSearchState::SearchingAdditional(state.partition_id); + self.partition_finder + .scale_up_connection( + state.proxy_id, + state.partition_id, + state.events_tx.clone(), + shutdown_handle, + ) + .await; + } + } + Event::ConnectionSuccess(id, streams, scale_up_config) => { + info!("Established new TCP connection to {:?}", id); + if state.partition_id == id { + assert_eq!( + (self.scale_up_config.max_multiplexed_connections - 1) as usize, + streams.len() + ); + for stream in streams { + proxy.add_connection(stream).await; + } + } else { + assert_eq!( + self.scale_up_config.max_multiplexed_connections as usize, + streams.len() + ); + assert!(id.is_some()); + assert_ne!(state.partition_id, id); + + return Ok(EventResult::Restart((id, streams, Some(scale_up_config)))); + } + state.num_connections = self.scale_up_config.max_multiplexed_connections as u16; + state.connection_state = ConnectionSearchState::Idle; + self.scale_up_config = scale_up_config; + } + Event::ConnectionFail(scale_up_config) => { + state.connection_state = ConnectionSearchState::Stop(Instant::now()); + self.scale_up_config = scale_up_config.unwrap_or(self.scale_up_config); + info!("Connection failed"); + } + } + debug!("ScaleUpConfig: {:#?}", self.scale_up_config); + Ok(EventResult::Ok) + } +} + +#[cfg(test)] +pub mod tests { + use crate::config_parser::tests::get_test_config; + use crate::connections::PlainTextPartitionFinder; + use crate::connections::ProxyStream; + use crate::connections::MULTIPLEX_CONNECTION_TIMEOUT_SEC; + use crate::controller::ConnectionSearchState; + use crate::controller::DEFAULT_SCALE_UP_BACKOFF; + use crate::efs_prot; + use crate::efs_prot::BindResponse; + use crate::efs_prot::ScaleUpConfig; + use crate::efs_rpc; + use crate::efs_rpc::PartitionId; + use crate::proxy; + use crate::proxy_identifier::ProxyIdentifier; + use crate::proxy_identifier::INITIAL_INCARNATION; + use crate::rpc; + use crate::rpc::RPC_HEADER_SIZE; + use crate::shutdown::ShutdownReason; + use crate::status_reporter; + use crate::status_reporter::Report; + use crate::status_reporter::StatusRequester; + use crate::tls::tests::get_server_config; + use crate::tls::TlsConfig; + use crate::{connections::TlsPartitionFinder, controller::Controller}; + + use bytes::BytesMut; + use log::debug; + use onc_rpc::RpcMessage; + use rand::Rng; + use std::collections::HashMap; + use std::collections::HashSet; + use std::io::ErrorKind; + use std::sync::atomic::AtomicU32; + use std::time::Duration; + use std::{self, io::Error, sync::Arc}; + use test_case::test_case; + use tokio::time::error::Elapsed; + use tokio::time::timeout; + use tokio::{ + io::AsyncWriteExt, + net::{TcpListener, TcpStream}, + sync::oneshot, + sync::Mutex, + task::JoinHandle, + }; + use tokio_util::sync::CancellationToken; + + use super::DEFAULT_SCALE_UP_CONFIG; + + #[derive(Copy, Clone, Debug, PartialEq)] + pub enum ServiceAction { + // Server will reject the next incoming TCP connection. Further attempts will succeed. + // + RejectNextNewConnectionRequest, + + // The server will close the next connection that receives a request from the proxy. + // + CloseOnNextRequest, + + // The server will close a random connection without waiting for any incoming request. + // + CloseRandomConnection, + + // This service will restart accepting connections to the given PartitionId + // + _RestartPartitionAcceptor(PartitionId), + + // This service will not accept connections to the given PartitionId + // + StopPartitionAcceptor(PartitionId), + + // This service will close the connection if a bind_client_to_partition request is received + // + CloseOnNextBindClientToPartitionRequest, + + // The service will send BindResponse::RETRY_LATER on subsequent bind_client_to_partition requests + // + DisableScaleUp, + + // The service will allow re-enabling scale up after the DisableScaleUp action is posted. + // + EnableScaleUp, + + // The service will respond with BindResponse::RETRY on the next n bind_client_to_partition requests + SendRetries(u32), + } + + const PARTITION_COUNT: usize = 3; + + pub struct TestService { + pub listen_port: u16, + posted_action: Arc>>, + shutdown_tx: oneshot::Sender<()>, + join_handle: JoinHandle<()>, + pub partition_ids: Vec, + pub stopped_partitions: Arc>>, + pub request_counter: Arc>>>>, + } + + impl TestService { + const ALWAYS_SCALE_UP_THRESHOLD_BYTES_PER_SEC: i32 = 0; + const NEVER_SCALE_UP_THRESHOLD_BYTES_PER_SEC: i32 = i32::MAX; + + pub async fn new(tls: bool) -> Self { + TestService::new_with_partition_count(PARTITION_COUNT, tls).await + } + + pub async fn new_with_partition_count(count: usize, tls: bool) -> Self { + TestService::new_with_partition_count_and_scale_up_config( + count, + super::DEFAULT_SCALE_UP_CONFIG, + tls, + ) + .await + } + + pub async fn new_with_throughput_scale_up_threshold(threshold: i32, tls: bool) -> Self { + let mut config = super::DEFAULT_SCALE_UP_CONFIG.clone(); + config.scale_up_bytes_per_sec_threshold = threshold; + TestService::new_with_partition_count_and_scale_up_config(PARTITION_COUNT, config, tls) + .await + } + + pub async fn new_with_partition_count_and_scale_up_threshold( + count: usize, + threshold: i32, + tls: bool, + ) -> Self { + let mut config = super::DEFAULT_SCALE_UP_CONFIG.clone(); + config.scale_up_bytes_per_sec_threshold = threshold; + TestService::new_with_partition_count_and_scale_up_config(count, config, tls).await + } + + pub async fn new_with_partition_count_and_scale_up_config( + count: usize, + scale_up_config: ScaleUpConfig, + tls: bool, + ) -> Self { + let (tcp_listener, listen_port) = find_available_port().await; + + let partition_ids = (0..count) + .map(|_| PartitionId { + id: efs_rpc::tests::generate_partition_id().0, + }) + .collect::>(); + + let stopped_partitions = Arc::new(Mutex::new(HashSet::new())); + + let mut counter = HashMap::new(); + for id in partition_ids.iter() { + counter.insert(id.clone(), Vec::new()); + } + let request_counter = Arc::new(Mutex::new(counter)); + + let posted_action = Arc::new(Mutex::new(Option::None)); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let service_handle = TestService::run( + tcp_listener, + scale_up_config, + partition_ids.clone(), + stopped_partitions.clone(), + request_counter.clone(), + posted_action.clone(), + tls, + shutdown_rx, + ); + + TestService { + listen_port, + posted_action, + shutdown_tx, + join_handle: service_handle, + partition_ids, + stopped_partitions, + request_counter, + } + } + + pub async fn post_action(&self, new_action: ServiceAction) { + match new_action { + ServiceAction::_RestartPartitionAcceptor(id) => { + let mut stopped = self.stopped_partitions.lock().await; + assert!(stopped.remove(&id), "Partition is not stopped"); + return; + } + ServiceAction::StopPartitionAcceptor(id) => { + let mut stopped = self.stopped_partitions.lock().await; + stopped.insert(id); + return; + } + ServiceAction::EnableScaleUp => { + TestService::check_and_consume_action( + &self.posted_action, + ServiceAction::DisableScaleUp, + ) + .await; + return; + } + _ => (), + }; + + let mut consumable_action = self.posted_action.lock().await; + if consumable_action.is_some() { + panic!("Previous action was not consumed"); + } + *consumable_action = Some(new_action); + } + + fn run( + listener: TcpListener, + scale_up_config: ScaleUpConfig, + partition_ids: Vec, + stopped_partitions: Arc>>, + request_counter: Arc>>>>, + posted_action: Arc>>, + tls: bool, + mut shutdown_rx: oneshot::Receiver<()>, + ) -> JoinHandle<()> { + tokio::spawn(async move { + let mut partition_idx = 0; + loop { + tokio::select! { + socket = listener.accept() => { + let Ok((tcp_stream, _socket_addr)) = socket else { + panic!("Failed to establish connection to client"); + }; + + if tls { + let tls_acceptor = s2n_tls_tokio::TlsAcceptor::new(get_server_config().await.expect("Could not get config")); + let tls_stream = match tls_acceptor.accept(tcp_stream).await { + Ok(conn) => conn, + Err(e) => { + panic!("Failed to establish TLS connection: {}", e); + } + }; + Self::inner_run(tls_stream, scale_up_config, &mut partition_idx, &partition_ids, stopped_partitions.clone(), request_counter.clone(), posted_action.clone()).await; + } else { + Self::inner_run(tcp_stream, scale_up_config, &mut partition_idx, &partition_ids, stopped_partitions.clone(), request_counter.clone(), posted_action.clone()).await; + } + }, + _ = &mut shutdown_rx => { + break; + } + }; + } + }) + } + + async fn inner_run( + stream: S, + scale_up_config: ScaleUpConfig, + partition_idx: &mut usize, + partition_ids: &Vec, + stopped_partitions: Arc>>, + request_counter: Arc>>>>, + posted_action: Arc>>, + ) { + if TestService::check_and_consume_action( + &posted_action, + ServiceAction::RejectNextNewConnectionRequest, + ) + .await + || TestService::check_and_consume_action( + &posted_action, + ServiceAction::CloseRandomConnection, + ) + .await + { + debug!("RejectNextNewConnectionRequest processed"); + drop(stream); + } else { + let stopped = stopped_partitions.lock().await; + let mut next_id = None; + for i in 0..partition_ids.len() { + *partition_idx = (*partition_idx + i + 1) % partition_ids.len(); + if !stopped.contains(&partition_ids[*partition_idx]) { + next_id = Some(partition_ids[*partition_idx].clone()); + break; + } + } + let Some(id) = next_id else { + panic!("No available PartitionIds") + }; + + let request_count = Arc::new(AtomicU32::new(0)); + request_counter + .lock() + .await + .get_mut(&id) + .expect("Counter for partition not found") + .push(request_count.clone()); + + tokio::spawn(TestService::new_connection( + stream, + scale_up_config, + posted_action.clone(), + id, + request_count.clone(), + )); + } + } + + async fn check_and_consume_action( + posted_action: &Arc>>, + to_check: ServiceAction, + ) -> bool { + let mut action = posted_action.lock().await; + if *action == Some(to_check) { + *action = Option::None; + true + } else { + false + } + } + + async fn check_action( + posted_action: &Arc>>, + to_check: ServiceAction, + ) -> bool { + let action = posted_action.lock().await; + *action == Some(to_check) + } + + async fn new_connection( + mut stream: S, + scale_up_config: ScaleUpConfig, + posted_action: Arc>>, + partition_id: PartitionId, + request_count: Arc, + ) { + loop { + let Ok(message) = rpc::read_rpc_bytes(&mut stream).await else { + break; + }; + + request_count.fetch_add(1, std::sync::atomic::Ordering::AcqRel); + + if TestService::check_and_consume_action( + &posted_action, + ServiceAction::CloseOnNextRequest, + ) + .await + { + debug!("CloseOnNextRequest processed"); + break; + } + + let response = match TestService::parse_bind_client_to_partition_request(&message) { + Ok(rpc_message) => { + if TestService::check_and_consume_action( + &posted_action, + ServiceAction::CloseOnNextBindClientToPartitionRequest, + ) + .await + { + debug!("CloseOnNextBindClientToPartitionRequest processed"); + break; + } + + let mut bind_response = + BindResponse::READY(efs_prot::PartitionId(partition_id.id)); + + if TestService::check_action(&posted_action, ServiceAction::DisableScaleUp) + .await + { + bind_response = BindResponse::RETRY_LATER( + "Returning BindResponse::RETRY_LATER".into(), + ); + } + + let mut action = posted_action.lock().await; + if let Some(ServiceAction::SendRetries(count)) = *action { + bind_response = + BindResponse::RETRY("Returning BindResponse::RETRY".into()); + if count > 1 { + *action = Some(ServiceAction::SendRetries(count - 1)); + } else { + *action = None; + } + } + + efs_rpc::tests::create_bind_client_to_partition_response( + rpc_message.xid(), + bind_response, + scale_up_config, + ) + .expect("Could not create response") + } + Err(_) => { + // If the test server doesn't parse a `bind_client_to_partition` request, + // then echo request back to the client + message + } + }; + + stream + .write_all(&response) + .await + .expect("Could not write to stream"); + } + } + + fn parse_bind_client_to_partition_request( + request: &Vec, + ) -> Result, Box> { + let rpc_message = onc_rpc::RpcMessage::try_from(request.as_slice())?; + efs_rpc::tests::parse_bind_client_to_partition_request(&rpc_message)?; + Ok(rpc_message) + } + + pub async fn shutdown(self) { + drop(self.shutdown_tx); + self.join_handle.await.unwrap(); + } + } + + struct TestClient { + stream: TcpStream, + next_xid: u32, + } + + impl TestClient { + async fn new(proxy_port: u16) -> Self { + let stream = TcpStream::connect(("127.0.0.1", proxy_port)).await.unwrap(); + Self { + stream, + next_xid: 0, + } + } + + async fn send_message_with_size(&mut self, size: usize) -> Result<(), Error> { + self.next_xid += 1; + let (request, expected_data) = rpc::test::generate_msg_fragments(size, 1); + self.stream.write_all(&request).await?; + + let response = rpc::read_rpc_bytes(&mut self.stream).await?; + + let payload_result = + rpc::RpcBatch::parse_batch(&mut BytesMut::from(response.as_slice())) + .expect("No message found") + .expect("failed to parse"); + + let rpc = payload_result.rpcs.get(0).expect("No RPCs found"); + assert_eq!(expected_data, rpc.to_vec()[RPC_HEADER_SIZE..]); + Ok(()) + } + + async fn send_partial_message_with_size(&mut self, size: usize) -> Result<(), Error> { + self.next_xid += 1; + let (_, m1) = rpc::test::generate_msg_fragments(size, 1); + let mut rng = rand::thread_rng(); + self.stream + .write_all(&m1[0..rng.gen_range(1..size - 1)]) + .await?; + Ok(()) + } + } + + pub struct ProxyUnderTest { + listen_port: u16, + handle: JoinHandle>, + status_requester: StatusRequester, + scale_up_config: ScaleUpConfig, + } + + impl ProxyUnderTest { + pub async fn new(tls: bool, server_port: u16) -> Self { + let scale_up_config = DEFAULT_SCALE_UP_CONFIG; + let (tcp_listener, listen_port) = find_available_port().await; + + let (status_requester, status_reporter) = status_reporter::create_status_channel(); + + let handle = if tls { + let mut tls_config = TlsConfig::new_from_config(&get_test_config()) + .await + .expect("Failed to acquire TlsConfig."); + tls_config.remote_addr = format!("127.0.0.1:{}", server_port); + + let partition_finder = + Arc::new(TlsPartitionFinder::new(Arc::new(Mutex::new(tls_config)))); + + let controller = Controller { + listener: tcp_listener, + partition_finder, + proxy_id: ProxyIdentifier::new(), + scale_up_attempt_count: 0, + restart_count: 0, + scale_up_config: scale_up_config, + status_reporter, + }; + + let token = CancellationToken::new(); + tokio::spawn(controller.run(token)) + } else { + let partition_finder = Arc::new(PlainTextPartitionFinder { + mount_target_addr: format!("127.0.0.1:{}", server_port), + }); + + let controller = Controller { + listener: tcp_listener, + partition_finder, + proxy_id: ProxyIdentifier::new(), + scale_up_attempt_count: 0, + restart_count: 0, + scale_up_config: scale_up_config, + status_reporter, + }; + + let token = CancellationToken::new(); + tokio::spawn(controller.run(token)) + }; + + Self { + listen_port, + handle, + status_requester, + scale_up_config, + } + } + + pub async fn poll_scale_up(&mut self) -> Result<(), Elapsed> { + timeout(Duration::from_secs(5), async { + loop { + let num_connections = self.get_num_connections().await; + if num_connections == self.scale_up_config.max_multiplexed_connections as usize + { + break; + } else { + tokio::time::sleep(Duration::from_millis(500)).await; + } + } + }) + .await + } + + pub async fn get_report(&mut self) -> Report { + self.status_requester + ._request_status() + .await + .expect("Could not get report") + } + + pub async fn get_proxy_id(&mut self) -> ProxyIdentifier { + let report = self.get_report().await; + report.proxy_id + } + + async fn get_num_connections(&mut self) -> usize { + let report = self.get_report().await; + report.num_connections + } + } + + pub async fn find_available_port() -> (TcpListener, u16) { + for port in 10000..15000 { + match TcpListener::bind(("127.0.0.1", port)).await { + Ok(v) => { + return (v, port); + } + Err(_) => continue, + } + } + panic!("Failed to find port"); + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_basic(tls_enabled: bool) { + let service = TestService::new(tls_enabled).await; + let mut proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + let mut client = TestClient::new(proxy.listen_port).await; + client.send_message_with_size(10).await.unwrap(); + client.send_message_with_size(1024).await.unwrap(); + + let report = proxy.get_report().await; + assert!(report.partition_id.is_some()); + + service.shutdown().await; + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_success_after_connection_closed_on_bind_client_to_partition_request( + tls_enabled: bool, + ) { + let service = TestService::new(tls_enabled).await; + let mut proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + let mut client = TestClient::new(proxy.listen_port).await; + + service + .post_action(ServiceAction::CloseOnNextBindClientToPartitionRequest) + .await; + + client.send_message_with_size(10).await.unwrap(); + client.send_message_with_size(1024).await.unwrap(); + + let report = proxy.get_report().await; + assert!(report.partition_id.is_none()); + + service.shutdown().await; + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_success_after_bind_client_to_partition_stop_response_on_initial_connection( + tls_enabled: bool, + ) { + let service = TestService::new(tls_enabled).await; + let mut proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + let mut client = TestClient::new(proxy.listen_port).await; + + service.post_action(ServiceAction::DisableScaleUp).await; + + client.send_message_with_size(10).await.unwrap(); + client.send_message_with_size(1024).await.unwrap(); + + let report = proxy.get_report().await; + assert!(report.partition_id.is_none()); + + service.shutdown().await; + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_closed_connection(tls_enabled: bool) { + let service = TestService::new(tls_enabled).await; + let proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + let mut client = TestClient::new(proxy.listen_port).await; + client.send_message_with_size(10).await.unwrap(); + service.post_action(ServiceAction::CloseOnNextRequest).await; + let result = client.send_message_with_size(10).await; + assert!(result.is_err()); + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_closed_connection_after_scale_up(tls_enabled: bool) { + // Use a single partition so that the same PartitionId is return on each + // bind_client_to_partition request. This prevents a controller "reset", which simplifies + // testing that the proxy will retry scale up after the backoff time as elapsed. + // + let scale_up_threshold = 10; + let service = TestService::new_with_partition_count_and_scale_up_threshold( + 1, + scale_up_threshold, + tls_enabled, + ) + .await; + + let mut proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + + let mut client = TestClient::new(proxy.listen_port).await; + client.send_message_with_size(100).await.unwrap(); + + // Expect that scale up does not occur + proxy.poll_scale_up().await.expect("Scale up did not occur"); + + // Close one proxy connection. The subsequent requests should fail. + service.post_action(ServiceAction::CloseOnNextRequest).await; + client.send_message_with_size(100).await.unwrap_err(); + + // Wait some time for proxy to reset + tokio::time::sleep(Duration::from_secs(5)).await; + + for _ in 0..5 { + client.send_message_with_size(100).await.unwrap_err(); + } + + // Reconnecting with the client should result in successful requests + let mut new_client = TestClient::new(proxy.listen_port).await; + new_client.send_message_with_size(5).await.unwrap(); + + let num_connections = proxy.get_report().await.num_connections; + assert_eq!(1, num_connections); + + service.shutdown().await; + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_closed_connection_when_big_frame_sent(tls_enabled: bool) { + let service = TestService::new(tls_enabled).await; + let proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + let mut client = TestClient::new(proxy.listen_port).await; + let result = client.send_message_with_size(22222220).await; + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!( + error.kind() == ErrorKind::BrokenPipe || error.kind() == ErrorKind::ConnectionReset + ); + let reason_opt = proxy.handle.await.unwrap(); + assert_eq!(reason_opt, Some(ShutdownReason::FrameSizeExceeded)); + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_message_too_small(tls_enabled: bool) { + let service = TestService::new(tls_enabled).await; + let proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + let mut client = TestClient::new(proxy.listen_port).await; + let _ = client.send_message_with_size(1).await; + let reason_opt = proxy.handle.await.unwrap(); + assert_eq!(reason_opt, Some(ShutdownReason::FrameSizeTooSmall)); + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_client_disconnects(tls_enabled: bool) { + let service = TestService::new(tls_enabled).await; + let proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + let mut initial_client = TestClient::new(proxy.listen_port).await; + let _ = initial_client.send_partial_message_with_size(1000).await; + // Drop has been implemented to simulate client disconnection + drop(initial_client); + + // After initial_client is disconnects, the proxy should still accept new connection + let mut client = TestClient::new(proxy.listen_port).await; + assert!(matches!( + client.send_partial_message_with_size(1000).await, + Ok(()) + )); + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_client_disconnects_without_send(tls_enabled: bool) { + let service = TestService::new(tls_enabled).await; + let proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + + // Drop this client to simulate a connection to the proxy port that immediately closes + let disconnecting_client = TestClient::new(proxy.listen_port).await; + drop(disconnecting_client); + + // After the connection to the disconnecting_client is dropped, the proxy should still accept new connection + let mut client = TestClient::new(proxy.listen_port).await; + assert!(matches!( + client.send_partial_message_with_size(1000).await, + Ok(()) + )); + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_handle_server_disconnect(tls_enabled: bool) { + let service = TestService::new(tls_enabled).await; + let mut proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + + let mut client = TestClient::new(proxy.listen_port).await; + assert!(client.send_message_with_size(10).await.is_ok()); + + // Incarnation is incremented when connection with NFS client is established + assert_eq!( + INITIAL_INCARNATION + 1, + proxy.get_proxy_id().await.incarnation + ); + + service.post_action(ServiceAction::CloseOnNextRequest).await; + + assert!(client.send_message_with_size(10).await.is_err()); + + // Reconnect + client = TestClient::new(proxy.listen_port).await; + assert!(client.send_message_with_size(10).await.is_ok()); + + // Incarnation is incremented when connection with NFS client is reestablished + assert_eq!( + INITIAL_INCARNATION + 2, + proxy.get_proxy_id().await.incarnation + ); + + proxy.handle.abort(); + service.shutdown().await; + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_scale_up_same_partition(tls_enabled: bool) { + let service = TestService::new_with_partition_count_and_scale_up_threshold( + 1, + TestService::ALWAYS_SCALE_UP_THRESHOLD_BYTES_PER_SEC, + tls_enabled, + ) + .await; + let mut proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + + // A request from the client will cause the proxy to establish an addition connection to the NFS server + let mut client = TestClient::new(proxy.listen_port).await; + client.send_message_with_size(10).await.unwrap(); + + proxy + .poll_scale_up() + .await + .expect("Timeout exceeded while awaiting scale up"); + + service.shutdown().await; + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_scale_up_periodic_workload(tls_enabled: bool) { + // Requests of 15 bytes every 100 milliseconds should result in 300 bytes of traffic (150 + // bytes sent, 150 bytes received) every second. This exceeds the scale_up_threshold of 299 + // bytes/s. + let scale_up_threshold = 299; + let num_requests = 60; + let request_size = 30; + let request_interval_millis = 100; + + let service = TestService::new_with_partition_count_and_scale_up_threshold( + 1, + scale_up_threshold, + tls_enabled, + ) + .await; + + let mut proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + + let mut client = TestClient::new(proxy.listen_port).await; + for _ in 0..num_requests { + client.send_message_with_size(request_size).await.unwrap(); + tokio::time::sleep(Duration::from_millis(request_interval_millis)).await; + } + + proxy + .poll_scale_up() + .await + .expect("Timeout exceeded while awaiting scale up"); + + service.shutdown().await; + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_no_scale_up_periodic_workload(tls_enabled: bool) { + // Requests of 10 bytes every 100 milliseconds should result in 200 bytes of traffic (100 + // bytes sent, 100 bytes received) every seconds. This does not exceeds the + // scale_up_threshold of 300 bytes/s. + // + let scale_up_threshold = 300; + let num_requests = 60; + let request_size = 10; + let request_interval_millis = 100; + + let service = TestService::new_with_partition_count_and_scale_up_threshold( + 1, + scale_up_threshold, + tls_enabled, + ) + .await; + let mut proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + + // Only requests proxied within the monitoring window will be considered when determining + // when to scale up. The following requests should not result in a scale up attempt. + // + let mut client = TestClient::new(proxy.listen_port).await; + for _ in 0..num_requests { + client.send_message_with_size(request_size).await.unwrap(); + tokio::time::sleep(Duration::from_millis(request_interval_millis)).await; + } + + proxy + .poll_scale_up() + .await + .expect_err("Unexpected Scale Up"); + + service.shutdown().await; + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_scale_up_new_partition(tls_enabled: bool) { + let service = TestService::new_with_throughput_scale_up_threshold( + TestService::ALWAYS_SCALE_UP_THRESHOLD_BYTES_PER_SEC, + tls_enabled, + ) + .await; + let mut proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + + // A request from the client will cause the proxy to establish an addition connection to + // the NFS server + // + let mut client = TestClient::new(proxy.listen_port).await; + client.send_message_with_size(10).await.unwrap(); + + let report = proxy.get_report().await; + let initial_partition_id = report.partition_id.expect("No PartitionId"); + + service + .post_action(ServiceAction::StopPartitionAcceptor(initial_partition_id)) + .await; + + // After scale up, we need to wait for the controller to reset and to listen to a new + // connection from the client + // + tokio::time::sleep(Duration::from_secs(5)).await; + + let mut new_client = TestClient::new(proxy.listen_port).await; + new_client.send_message_with_size(10).await.unwrap(); + + proxy + .poll_scale_up() + .await + .expect("Timeout exceeded while awaiting scale up"); + + let connection_state = proxy.get_report().await.connection_state; + assert_eq!(ConnectionSearchState::Idle, connection_state); + + service.shutdown().await; + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_successful_scale_up_with_retries(tls_enabled: bool) { + let scale_up_threshold = 10; + let service = TestService::new_with_partition_count_and_scale_up_threshold( + 1, + scale_up_threshold, + tls_enabled, + ) + .await; + let mut proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + + // A request from the client will cause the proxy to establish an addition connection to the NFS server + let mut client = TestClient::new(proxy.listen_port).await; + client.send_message_with_size(5).await.unwrap(); + + service + .post_action(ServiceAction::SendRetries(std::cmp::min( + 5, + crate::connections::MAX_ATTEMPT_COUNT - 5, + ))) + .await; + + client.send_message_with_size(100).await.unwrap(); + + proxy + .poll_scale_up() + .await + .expect("Timeout exceeded while awaiting scale up"); + + service.shutdown().await; + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_no_scale_up_threshold_not_exceed(tls_enabled: bool) { + let service = TestService::new_with_throughput_scale_up_threshold( + TestService::NEVER_SCALE_UP_THRESHOLD_BYTES_PER_SEC, + tls_enabled, + ) + .await; + let mut proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + + // Requests from the client below the throughput threshold should not cause new connections + // to the NFS server to be established + let mut client = TestClient::new(proxy.listen_port).await; + client.send_message_with_size(10).await.unwrap(); + + proxy + .poll_scale_up() + .await + .expect_err("Unexpected scale up occured"); + + let connection_state = proxy.get_report().await.connection_state; + assert_eq!(ConnectionSearchState::Idle, connection_state); + + service.shutdown().await; + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_no_scale_up_if_already_scaled_up(tls_enabled: bool) { + let scale_up_threshold = 10; + let service = TestService::new_with_partition_count_and_scale_up_threshold( + 5, + scale_up_threshold, + tls_enabled, + ) + .await; + + let mut proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + + // Requests from the client below the throughput threshold should not cause scale up + let mut client = TestClient::new(proxy.listen_port).await; + client + .send_message_with_size((scale_up_threshold - 1) as usize) + .await + .unwrap(); + + // Stop initial partition so that the proxy resets after scale up + let initial_report = proxy.get_report().await; + let initial_partition_id = initial_report.partition_id.expect("No PartitionId"); + assert_eq!(0, initial_report.scale_up_attempt_count); + assert_eq!(0, initial_report.restart_count); + + service + .post_action(ServiceAction::StopPartitionAcceptor(initial_partition_id)) + .await; + + // This requests should cause scale up to be attempted + client + .send_message_with_size((scale_up_threshold + 10) as usize) + .await + .unwrap(); + + tokio::time::sleep(Duration::from_secs(5)).await; + let mut client = TestClient::new(proxy.listen_port).await; + client + .send_message_with_size((scale_up_threshold - 1) as usize) + .await + .unwrap(); + + proxy + .poll_scale_up() + .await + .expect("Timeout exceeded while awaiting scale up"); + + let second_report = proxy.get_report().await; + assert_eq!(ConnectionSearchState::Idle, second_report.connection_state); + assert_eq!( + DEFAULT_SCALE_UP_CONFIG.max_multiplexed_connections as usize, + second_report.num_connections + ); + assert_eq!(1, second_report.scale_up_attempt_count); + assert_eq!(1, second_report.restart_count); + + // Additional requests from the client should not cause additional scale up attempts + for _ in 0..5 { + client + .send_message_with_size((scale_up_threshold + 10) as usize) + .await + .unwrap(); + tokio::time::sleep(Duration::from_secs(1)).await; + } + + let third_report = proxy.get_report().await; + assert_eq!(ConnectionSearchState::Idle, third_report.connection_state); + assert_eq!( + DEFAULT_SCALE_UP_CONFIG.max_multiplexed_connections as usize, + third_report.num_connections + ); + assert_eq!(1, third_report.scale_up_attempt_count); + assert_eq!(1, third_report.restart_count); + + service.shutdown().await; + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_scale_up_failed_too_many_retries(tls_enabled: bool) { + // Use a single partition so that the same PartitionId is return on each + // bind_client_to_partition request. This prevents a controller "reset", which simplifies + // testing that the proxy will retry scale up after the backoff time as elapsed. + // + let scale_up_threshold = 10; + let service = TestService::new_with_partition_count_and_scale_up_threshold( + 1, + scale_up_threshold, + tls_enabled, + ) + .await; + + let mut proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + + let mut client = TestClient::new(proxy.listen_port).await; + + // Send an initial request in which the bind_client_to_partition request succeeds, and the + // main controller loop starts, but scale up is not requested + // + client + .send_message_with_size((scale_up_threshold - 1) as usize) + .await + .unwrap(); + + // Update the server to return BindResponse::RETRY until scale up attempt fails + service + .post_action(ServiceAction::SendRetries( + crate::connections::MAX_ATTEMPT_COUNT + 1, + )) + .await; + + // This request will cause the proxy to attempt scale up, in which bind_client_to_partition + // requests will fail + // + client.send_message_with_size(100).await.unwrap(); + + // Wait for scale up to fail + tokio::time::sleep(Duration::from_secs(5)).await; + + // Expect that scale up does not occur + proxy + .poll_scale_up() + .await + .expect_err("Unexpected scale up occured"); + + let report = proxy.get_report().await; + assert!(matches!( + report.connection_state, + ConnectionSearchState::Stop(_) + )); + + // Advance time and assert that scale up occurs after backoff duration elapsed + tokio::time::pause(); + tokio::time::advance( + DEFAULT_SCALE_UP_BACKOFF + Duration::from_secs(MULTIPLEX_CONNECTION_TIMEOUT_SEC), + ) + .await; + tokio::time::resume(); + + service.post_action(ServiceAction::EnableScaleUp).await; + client.send_message_with_size(100).await.unwrap(); + + proxy.poll_scale_up().await.expect("Scale up failed"); + + let connection_state = proxy.get_report().await.connection_state; + assert_eq!(ConnectionSearchState::Idle, connection_state); + + service.shutdown().await; + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_scale_up_failed_retry_later(tls_enabled: bool) { + // Use a single partition so that the same PartitionId is return on each + // bind_client_to_partition request. This prevents a controller "reset", which simplifies + // testing that the proxy will retry scale up after the backoff time as elapsed. + // + let scale_up_threshold = 10; + let service = TestService::new_with_partition_count_and_scale_up_threshold( + 1, + scale_up_threshold, + tls_enabled, + ) + .await; + + let mut proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + + let mut client = TestClient::new(proxy.listen_port).await; + + // Send an initial request in which the bind_client_to_partition request succeeds, and the + // main controller loop starts, but scale up is not requested + // + client + .send_message_with_size((scale_up_threshold - 1) as usize) + .await + .unwrap(); + + // Update the server to return BindResponse::RETRY_LATER on the next bind_client_to_partition rpc + // request + // + service.post_action(ServiceAction::DisableScaleUp).await; + + // This request will cause the proxy to attempt scale up, in which bind_client_to_partition + // requests will fail + // + client + .send_message_with_size((scale_up_threshold) as usize) + .await + .unwrap(); + + // Expect that scale up does not occur + proxy + .poll_scale_up() + .await + .expect_err("Unexpected scale up occured"); + + let report = proxy.get_report().await; + assert!(matches!( + report.connection_state, + ConnectionSearchState::Stop(_) + )); + + // Advance time and assert that scale up occurs after backoff duration elapsed + tokio::time::pause(); + tokio::time::advance( + DEFAULT_SCALE_UP_BACKOFF + Duration::from_secs(MULTIPLEX_CONNECTION_TIMEOUT_SEC), + ) + .await; + tokio::time::resume(); + + service.post_action(ServiceAction::EnableScaleUp).await; + client + .send_message_with_size( + (scale_up_threshold * proxy::REPORT_INTERVAL_SECS as i32) as usize, + ) + .await + .unwrap(); + + proxy.poll_scale_up().await.expect("Scale up failed"); + + let connection_state = proxy.get_report().await.connection_state; + assert_eq!(ConnectionSearchState::Idle, connection_state); + + service.shutdown().await; + } + + #[test_case(true; "tls enabled")] + #[test_case(false; "tls disabled")] + #[tokio::test] + async fn test_scale_up_connection_usage(tls_enabled: bool) { + // Prevent controller reset after scale up by using existing partition + let service = TestService::new_with_partition_count_and_scale_up_threshold( + 1, + TestService::ALWAYS_SCALE_UP_THRESHOLD_BYTES_PER_SEC, + tls_enabled, + ) + .await; + + let mut proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + + let mut client = TestClient::new(proxy.listen_port).await; + client.send_message_with_size(10).await.unwrap(); + + proxy + .poll_scale_up() + .await + .expect("Timeout exceeded while awaiting scale up"); + + let request_to_send_per_connection = 10; + for _ in + 0..(request_to_send_per_connection * proxy.scale_up_config.max_multiplexed_connections) + { + client.send_message_with_size(10).await.unwrap(); + } + + // Check that requests are routed over multiple connections + let partition_id = proxy + .get_report() + .await + .partition_id + .expect("Missing PartitionId"); + + let request_counter = service.request_counter.lock().await; + let counts = request_counter + .get(&partition_id) + .expect("Missing request counts"); + + assert!(counts.len() >= proxy.scale_up_config.max_multiplexed_connections as usize); + for count in counts { + let operation_count = count.load(std::sync::atomic::Ordering::Acquire); + // Unused connections to a partition can be established during connection search. For + // this connections, the operation count will be 1 + // + assert!( + operation_count >= request_to_send_per_connection as u32 || operation_count == 1 + ); + } + + drop(request_counter); + service.shutdown().await; + } + + #[test_case(true; "tls enabled")] + #[tokio::test] + async fn test_efs_utils_port_test(tls_enabled: bool) { + let service = TestService::new(tls_enabled).await; + let mut proxy = ProxyUnderTest::new(tls_enabled, service.listen_port).await; + let mut port_health_check = TestClient::new(proxy.listen_port).await; + // Mimic efs-utils's port test which checks whether efs-proxy is alive. + let _ = port_health_check.stream.shutdown().await.unwrap(); + let mut client = TestClient::new(proxy.listen_port).await; + client.send_message_with_size(10).await.unwrap(); + client.send_message_with_size(1024).await.unwrap(); + + let report = proxy.get_report().await; + assert!(report.partition_id.is_some()); + + service.shutdown().await; + } +} diff --git a/src/proxy/src/efs_prot.x b/src/proxy/src/efs_prot.x new file mode 100644 index 00000000..d0faeb4f --- /dev/null +++ b/src/proxy/src/efs_prot.x @@ -0,0 +1,57 @@ +/* +* EFS program V1 +*/ + +const PROXY_ID_LENGTH = 16; +const PROXY_INCARNATION_LENGTH = 8; +const PARTITION_ID_LENGTH = 64; + +enum OperationType { + OP_BIND_CLIENT_TO_PARTITION = 1 +}; + +typedef opaque PartitionId[PARTITION_ID_LENGTH]; + +struct ProxyIdentifier { + opaque identifier; + opaque incarnation; +}; + +struct ScaleUpConfig { + int max_multiplexed_connections; + int scale_up_bytes_per_sec_threshold; + int scale_up_threshold_breached_duration_sec; +}; + +enum BindResponseType { + RETRY = 0, + RETRY_LATER = 1, + PREFERRED = 2, + READY = 3, + ERROR = 4 +}; + +union BindResponse switch (BindResponseType type) { + case PREFERRED: + case READY: + PartitionId partition_id; + case RETRY: + case RETRY_LATER: + String stop_msg; + case ERROR: + String error_msg; + default: + void; +}; + +struct BindClientResponse { + BindResponse bind_response; + ScaleUpConfig scale_up_config; +}; + +union OperationResponse switch (OperationType operation_type) { + case OP_BIND_CLIENT_TO_PARTITION: + BindClientResponse response; + default: + void; +}; diff --git a/src/proxy/src/efs_rpc.rs b/src/proxy/src/efs_rpc.rs new file mode 100644 index 00000000..5e464734 --- /dev/null +++ b/src/proxy/src/efs_rpc.rs @@ -0,0 +1,318 @@ +use std::io::Cursor; +use tokio::io::AsyncWriteExt; + +use crate::connections::ProxyStream; +use crate::efs_prot; +use crate::efs_prot::BindClientResponse; +use crate::efs_prot::OperationType; +use crate::error::RpcError; +use crate::proxy_identifier::ProxyIdentifier; +use crate::rpc; + +const PROGRAM_NUMBER: u32 = 100200; +const PROGRAM_VERSION: u32 = 1; + +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] +pub struct PartitionId { + pub id: [u8; 64], +} + +pub async fn bind_client_to_partition( + proxy_id: ProxyIdentifier, + stream: &mut dyn ProxyStream, +) -> Result { + let request = create_bind_client_to_partition_request(&proxy_id)?; + stream.write_all(&request).await?; + stream.flush().await?; + + let response_bytes = rpc::read_rpc_bytes(stream).await?; + let response = onc_rpc::RpcMessage::try_from(response_bytes.as_slice())?; + + parse_bind_client_to_partition_response(&response) +} + +pub fn create_bind_client_to_partition_request( + proxy_id: &ProxyIdentifier, +) -> Result, RpcError> { + let payload = efs_prot::ProxyIdentifier { + identifier: proxy_id.uuid.as_bytes().to_vec(), + incarnation: proxy_id.incarnation.to_be_bytes().to_vec(), + }; + let mut payload_buf = Vec::new(); + xdr_codec::pack(&payload, &mut payload_buf)?; + + let call_body = onc_rpc::CallBody::new( + PROGRAM_NUMBER, + PROGRAM_VERSION, + OperationType::OP_BIND_CLIENT_TO_PARTITION as u32, + onc_rpc::auth::AuthFlavor::AuthNone::>(None), + onc_rpc::auth::AuthFlavor::AuthNone::>(None), + payload_buf, + ); + + let xid = rand::random::(); + onc_rpc::RpcMessage::new(xid, onc_rpc::MessageType::Call(call_body)) + .serialise() + .map_err(|e| e.into()) +} + +pub fn parse_bind_client_to_partition_response( + response: &onc_rpc::RpcMessage<&[u8], &[u8]>, +) -> Result { + let Some(reply_body) = response.reply_body() else { + Err(RpcError::MalformedResponse)? + }; + + let accepted_status = match reply_body { + onc_rpc::ReplyBody::Accepted(reply) => reply.status(), + onc_rpc::ReplyBody::Denied(_m) => Err(RpcError::Denied)?, + }; + + let payload = match accepted_status { + onc_rpc::AcceptedStatus::Success(p) => p, + onc_rpc::AcceptedStatus::GarbageArgs => Err(RpcError::GarbageArgs)?, + onc_rpc::AcceptedStatus::ProgramUnavailable => Err(RpcError::ProgramUnavailable)?, + onc_rpc::AcceptedStatus::ProgramMismatch { low, high } => Err(RpcError::ProgramMismatch { + low: *low, + high: *high, + })?, + onc_rpc::AcceptedStatus::ProcedureUnavailable => Err(RpcError::ProcedureUnavailable)?, + onc_rpc::AcceptedStatus::SystemError => Err(RpcError::SystemError)?, + }; + + xdr_codec::unpack::<_, BindClientResponse>(&mut Cursor::new(payload)).map_err(|e| e.into()) +} + +#[cfg(test)] +pub mod tests { + use super::*; + use crate::controller::tests::TestService; + use crate::controller::DEFAULT_SCALE_UP_CONFIG; + use crate::efs_prot::BindResponse; + use crate::efs_prot::ScaleUpConfig; + use crate::tls::tests::get_client_config; + use onc_rpc::{AuthError, RejectedReply}; + use rand::RngCore; + use s2n_tls_tokio::TlsConnector; + use tokio::net::TcpStream; + + const XID: u32 = 1; + + pub fn parse_bind_client_to_partition_request( + request: &onc_rpc::RpcMessage<&[u8], &[u8]>, + ) -> Result { + let call_body = request.call_body().expect("not a call rpc"); + + if PROGRAM_NUMBER != call_body.program() || PROGRAM_VERSION != call_body.program_version() { + return Err(RpcError::GarbageArgs); + } + + let mut payload = Cursor::new(call_body.payload()); + let raw_proxy_id = xdr_codec::unpack::<_, efs_prot::ProxyIdentifier>(&mut payload)?; + + Ok(ProxyIdentifier { + uuid: uuid::Builder::from_bytes( + raw_proxy_id + .identifier + .try_into() + .expect("Failed not convert vec to sized array"), + ) + .into_uuid(), + incarnation: i64::from_be_bytes( + raw_proxy_id + .incarnation + .try_into() + .expect("Failed to convert vec to sized array"), + ), + }) + } + + pub fn create_bind_client_to_partition_response( + xid: u32, + bind_response: BindResponse, + scale_up_config: ScaleUpConfig, + ) -> Result, RpcError> { + let mut payload_buf = Vec::new(); + + let response = BindClientResponse { + bind_response: bind_response, + scale_up_config: scale_up_config, + }; + xdr_codec::pack(&response, &mut payload_buf)?; + + create_bind_client_to_partition_response_from_accepted_status( + xid, + onc_rpc::AcceptedStatus::Success(payload_buf), + ) + } + + pub fn create_bind_client_to_partition_response_from_accepted_status( + xid: u32, + accepted_status: onc_rpc::AcceptedStatus>, + ) -> Result, RpcError> { + let reply_body = onc_rpc::ReplyBody::Accepted(onc_rpc::AcceptedReply::new( + onc_rpc::auth::AuthFlavor::AuthNone::>(None), + accepted_status, + )); + + onc_rpc::RpcMessage::new(xid, onc_rpc::MessageType::Reply(reply_body)) + .serialise() + .map_err(|e| e.into()) + } + + fn generate_parse_bind_client_to_partition_response_result( + accepted_status: onc_rpc::AcceptedStatus>, + ) -> Result { + let response = + create_bind_client_to_partition_response_from_accepted_status(XID, accepted_status)?; + let deserialized = onc_rpc::RpcMessage::try_from(response.as_slice())?; + parse_bind_client_to_partition_response(&deserialized) + } + + pub fn generate_partition_id() -> efs_prot::PartitionId { + let mut bytes = [0u8; efs_prot::PARTITION_ID_LENGTH as usize]; + rand::thread_rng().fill_bytes(&mut bytes); + efs_prot::PartitionId(bytes) + } + + #[tokio::test] + async fn test_bind_client_to_partition() { + let server = TestService::new(true).await; + let tcp_stream = TcpStream::connect(("127.0.0.1", server.listen_port)) + .await + .expect("Could not connect to test server."); + + let connector = + TlsConnector::new(get_client_config().await.expect("Failed to read config")); + let mut tls_stream = connector + .connect("localhost", tcp_stream) + .await + .expect("Failed to establish TLS Connection"); + + let response = bind_client_to_partition(ProxyIdentifier::new(), &mut tls_stream) + .await + .expect("bind_client_to_partition request failed"); + + let partition_id = match response.bind_response { + BindResponse::READY(id) => PartitionId { id: id.0 }, + _ => panic!(), + }; + + assert_eq!( + server + .partition_ids + .get(1) + .expect("Service has no partition IDs"), + &partition_id + ); + server.shutdown().await; + } + + #[test] + fn test_request_serde() -> Result<(), RpcError> { + let proxy_id = ProxyIdentifier::new(); + let request = create_bind_client_to_partition_request(&proxy_id)?; + + let deserialized = onc_rpc::RpcMessage::try_from(request.as_slice())?; + let deserialized_proxy_id = parse_bind_client_to_partition_request(&deserialized)?; + + assert_eq!(proxy_id.uuid, deserialized_proxy_id.uuid); + assert_eq!(proxy_id.incarnation, deserialized_proxy_id.incarnation); + Ok(()) + } + + #[test] + fn test_response_serde() -> Result<(), RpcError> { + let partition_id = generate_partition_id(); + let partition_id_copy = efs_prot::PartitionId(partition_id.0.clone()); + + let response = create_bind_client_to_partition_response( + XID, + BindResponse::READY(partition_id_copy), + DEFAULT_SCALE_UP_CONFIG, + )?; + + let deserialized = onc_rpc::RpcMessage::try_from(response.as_slice())?; + let deserialized_response = parse_bind_client_to_partition_response(&deserialized)?; + + assert!( + matches!(deserialized_response.bind_response, BindResponse::READY(id) if id.0 == partition_id.0) + ); + Ok(()) + } + + #[test] + fn test_parse_bind_client_to_partition_response_missing_reply() -> Result<(), RpcError> { + // Create a call message, which will error when parsed as a response + let malformed_response = create_bind_client_to_partition_request(&ProxyIdentifier::new())?; + let deserialized = onc_rpc::RpcMessage::try_from(malformed_response.as_slice())?; + + let result = parse_bind_client_to_partition_response(&deserialized); + assert!(matches!(result, Err(RpcError::MalformedResponse))); + Ok(()) + } + + #[test] + fn test_parse_bind_client_to_partition_response_denied() -> Result<(), RpcError> { + let reply_body = + onc_rpc::ReplyBody::Denied(RejectedReply::AuthError(AuthError::BadCredentials)); + let rpc_message = onc_rpc::RpcMessage::new(XID, onc_rpc::MessageType::Reply(reply_body)); + + let result = parse_bind_client_to_partition_response(&rpc_message); + assert!(matches!(result, Err(RpcError::Denied))); + Ok(()) + } + + #[test] + fn test_parse_bind_client_to_partition_response_garbage_args() -> Result<(), RpcError> { + let parse_result = generate_parse_bind_client_to_partition_response_result( + onc_rpc::AcceptedStatus::GarbageArgs, + ); + assert!(matches!(parse_result, Err(RpcError::GarbageArgs))); + Ok(()) + } + + #[test] + fn test_parse_bind_client_to_partition_response_program_unavailable() -> Result<(), RpcError> { + let parse_result = generate_parse_bind_client_to_partition_response_result( + onc_rpc::AcceptedStatus::ProcedureUnavailable, + ); + assert!(matches!(parse_result, Err(RpcError::ProcedureUnavailable))); + Ok(()) + } + + #[test] + fn test_parse_bind_client_to_partition_response_program_mismatch() -> Result<(), RpcError> { + let program_version_low = 10; + let program_version_high = 100; + let parse_result = generate_parse_bind_client_to_partition_response_result( + onc_rpc::AcceptedStatus::ProgramMismatch { + low: program_version_low, + high: program_version_high, + }, + ); + assert!(matches!( + parse_result, + Err(RpcError::ProgramMismatch { low: l, high: h }) if program_version_low == l && program_version_high == h)); + Ok(()) + } + + #[test] + fn test_parse_bind_client_to_partition_response_procedure_unavailable() -> Result<(), RpcError> + { + let parse_result = generate_parse_bind_client_to_partition_response_result( + onc_rpc::AcceptedStatus::ProcedureUnavailable, + ); + assert!(matches!(parse_result, Err(RpcError::ProcedureUnavailable))); + Ok(()) + } + + #[test] + fn test_parse_bind_client_to_partition_response_system_error() -> Result<(), RpcError> { + let parse_result = generate_parse_bind_client_to_partition_response_result( + onc_rpc::AcceptedStatus::SystemError, + ); + assert!(matches!(parse_result, Err(RpcError::SystemError))); + Ok(()) + } +} diff --git a/src/proxy/src/error.rs b/src/proxy/src/error.rs new file mode 100644 index 00000000..f6eee0d4 --- /dev/null +++ b/src/proxy/src/error.rs @@ -0,0 +1,41 @@ +use thiserror::Error as ThisError; + +#[derive(Debug, ThisError)] +pub enum ConnectError { + #[error("Connect attempt cancelled")] + Cancelled, + #[error("{0}")] + IoError(#[from] tokio::io::Error), + #[error("Connect attempt failed - Maximum attempt count exceeded")] + MaxAttemptsExceeded, + #[error("Attempt to acquire additional connections to EFS failed.")] + MultiplexFailure, + #[error(transparent)] + Tls(#[from] s2n_tls::error::Error), + #[error("Connect attempt failed - Timeout")] + Timeout, +} + +#[derive(Debug, ThisError)] +pub enum RpcError { + #[error("not a rpc response")] + MalformedResponse, + #[error("rpc reply_stat: MSG_DENIED")] + Denied, + #[error("rpc accept_stat: GARBAGE_ARGS")] + GarbageArgs, + #[error("rpc accept_stat: PROG_UNAVAIL")] + ProgramUnavailable, + #[error("rpc accept_stat: PROG_MISMATCH low: {} high: {}", .low, .high)] + ProgramMismatch { low: u32, high: u32 }, + #[error("rpc accept_stat: PROC_UNAVAIL")] + ProcedureUnavailable, + #[error("rpc accept_stat: SystemError")] + SystemError, + #[error(transparent)] + IoError(#[from] tokio::io::Error), + #[error(transparent)] + XdrCodecError(#[from] xdr_codec::Error), + #[error(transparent)] + OncRpc(#[from] onc_rpc::Error), +} diff --git a/src/proxy/src/lib.rs b/src/proxy/src/lib.rs new file mode 100644 index 00000000..008bcdf6 --- /dev/null +++ b/src/proxy/src/lib.rs @@ -0,0 +1,4 @@ +//! One-sentence summary of your crate. +//! +//! Followed by more detailed Markdown documentation of your crate. +#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)] diff --git a/src/proxy/src/logger.rs b/src/proxy/src/logger.rs new file mode 100644 index 00000000..1d153fe9 --- /dev/null +++ b/src/proxy/src/logger.rs @@ -0,0 +1,65 @@ +use log::LevelFilter; +use log4rs::{ + append::{ + console::{ConsoleAppender, Target}, + rolling_file::{ + policy::compound::{ + roll::fixed_window::FixedWindowRoller, trigger::size::SizeTrigger, CompoundPolicy, + }, + RollingFileAppender, + }, + }, + config::{Appender, Config, Root}, + encode::pattern::PatternEncoder, + filter::threshold::ThresholdFilter, +}; +use std::{path::Path, str::FromStr}; + +use crate::config_parser::ProxyConfig; + +const LOG_FILE_MAX_BYTES: u64 = 1048576; +const LOG_FILE_COUNT: u32 = 10; + +pub fn init(config: &ProxyConfig) { + let log_file_path_string = config + .output + .clone() + .expect("config value `output` is not set"); + let log_file_path = Path::new(&log_file_path_string); + let level_filter = + LevelFilter::from_str(&config.debug).expect("config value for `debug` is invalid"); + + let stderr = ConsoleAppender::builder().target(Target::Stderr).build(); + + let trigger = SizeTrigger::new(LOG_FILE_MAX_BYTES); + let mut pattern = log_file_path_string.clone(); + pattern.push_str(".{}"); + let roller = FixedWindowRoller::builder() + .build(&pattern, LOG_FILE_COUNT) + .expect("Unable to create roller"); + let policy = CompoundPolicy::new(Box::new(trigger), Box::new(roller)); + + let log_file = RollingFileAppender::builder() + .encoder(Box::new(PatternEncoder::new( + "{d(%Y-%m-%dT%H:%M:%S%.3fZ)(utc)} {l} {M} {m}{n}", + ))) + .build(log_file_path, Box::new(policy)) + .expect("Unable to create log file"); + + let config = Config::builder() + .appender(Appender::builder().build("logfile", Box::new(log_file))) + .appender( + Appender::builder() + .filter(Box::new(ThresholdFilter::new(LevelFilter::Error))) + .build("stderr", Box::new(stderr)), + ) + .build( + Root::builder() + .appender("logfile") + .appender("stderr") + .build(level_filter), + ) + .expect("Invalid logger config"); + + let _ = log4rs::init_config(config).expect("Unable to initialize logger"); +} diff --git a/src/proxy/src/main.rs b/src/proxy/src/main.rs new file mode 100644 index 00000000..acc82e15 --- /dev/null +++ b/src/proxy/src/main.rs @@ -0,0 +1,138 @@ +use crate::config_parser::ProxyConfig; +use crate::connections::{PlainTextPartitionFinder, TlsPartitionFinder}; +use crate::tls::TlsConfig; +use clap::Parser; +use controller::Controller; +use log::{debug, error, info}; +use std::path::Path; +use std::sync::Arc; +use tokio::signal; +use tokio::sync::Mutex; +use tokio_util::sync::CancellationToken; + +mod config_parser; +mod connections; +mod controller; +mod efs_rpc; +mod error; +mod logger; +mod proxy; +mod proxy_identifier; +mod rpc; +mod shutdown; +mod status_reporter; +mod tls; + +#[allow(clippy::all)] +#[allow(deprecated)] +#[allow(invalid_value)] +#[allow(non_camel_case_types)] +#[allow(unused_assignments)] +mod efs_prot { + include!(concat!(env!("OUT_DIR"), "/efs_prot_xdr.rs")); +} + +#[tokio::main] +async fn main() { + let args = Args::parse(); + + let proxy_config = match ProxyConfig::from_path(Path::new(&args.proxy_config_path)) { + Ok(config) => config, + Err(e) => panic!("Failed to read configuration. {}", e), + }; + + if let Some(_log_file_path) = &proxy_config.output { + logger::init(&proxy_config) + } + + info!("Running with configuration: {:?}", proxy_config); + + // This "status reporter" is currently only used in tests + let (_status_requester, status_reporter) = status_reporter::create_status_channel(); + + let sigterm_cancellation_token = CancellationToken::new(); + let mut sigterm_listener = match signal::unix::signal(signal::unix::SignalKind::terminate()) { + Ok(listener) => listener, + Err(e) => panic!("Failed to create SIGTERM listener. {}", e), + }; + + let controller_handle = if args.tls { + let tls_config = match get_tls_config(&proxy_config).await { + Ok(config) => Arc::new(Mutex::new(config)), + Err(e) => panic!("Failed to obtain TLS config:{}", e), + }; + + run_sighup_handler(proxy_config.clone(), tls_config.clone()); + + let controller = Controller::new( + &proxy_config.nested_config.listen_addr, + Arc::new(TlsPartitionFinder::new(tls_config)), + status_reporter, + ) + .await; + tokio::spawn(controller.run(sigterm_cancellation_token.clone())) + } else { + let controller = Controller::new( + &proxy_config.nested_config.listen_addr, + Arc::new(PlainTextPartitionFinder { + mount_target_addr: proxy_config.nested_config.mount_target_addr.clone(), + }), + status_reporter, + ) + .await; + tokio::spawn(controller.run(sigterm_cancellation_token.clone())) + }; + + tokio::select! { + shutdown_reason = controller_handle => error!("Shutting down. {:?}", shutdown_reason), + _ = sigterm_listener.recv() => { + info!("Received SIGTERM"); + sigterm_cancellation_token.cancel(); + }, + } +} + +async fn get_tls_config(proxy_config: &ProxyConfig) -> Result { + let tls_config = TlsConfig::new( + proxy_config.fips, + Path::new(&proxy_config.nested_config.ca_file), + Path::new(&proxy_config.nested_config.client_cert_pem_file), + Path::new(&proxy_config.nested_config.client_private_key_pem_file), + &proxy_config.nested_config.mount_target_addr, + &proxy_config.nested_config.expected_server_hostname_tls, + ) + .await; + let tls_config = tls_config?; + Ok(tls_config) +} + +fn run_sighup_handler(proxy_config: ProxyConfig, tls_config: Arc>) { + tokio::spawn(async move { + let mut sighup_listener = match signal::unix::signal(signal::unix::SignalKind::hangup()) { + Ok(listener) => listener, + Err(e) => panic!("Failed to create SIGHUP listener. {}", e), + }; + + loop { + sighup_listener + .recv() + .await + .expect("SIGHUP listener stream is closed"); + + debug!("Received SIGHUP"); + let mut locked_config = tls_config.lock().await; + match get_tls_config(&proxy_config).await { + Ok(config) => *locked_config = config, + Err(e) => panic!("Failed to acquire TLS config. {}", e), + } + } + }); +} + +#[derive(Parser, Debug, Clone)] +pub struct Args { + pub proxy_config_path: String, + + #[arg(long, default_value_t = false)] + pub tls: bool, +} diff --git a/src/proxy/src/proxy.rs b/src/proxy/src/proxy.rs new file mode 100644 index 00000000..d686e144 --- /dev/null +++ b/src/proxy/src/proxy.rs @@ -0,0 +1,525 @@ +use std::{ + error::Error, + marker::PhantomData, + sync::{atomic::AtomicU64, Arc}, + time::{Duration, Instant}, +}; + +use bytes::BytesMut; +use log::{debug, error, info, trace}; +use tokio::{ + io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, + net::{ + tcp::{OwnedReadHalf, OwnedWriteHalf}, + TcpStream, + }, + sync::{ + mpsc::{self}, + Mutex, + }, + task::JoinHandle, +}; +use tokio_util::sync::CancellationToken; + +use crate::rpc::{RpcFragmentParseError, RPC_MAX_SIZE}; +use crate::{ + connections::ProxyStream, + controller::Event, + rpc::RpcBatch, + shutdown::{ShutdownHandle, ShutdownReason}, +}; + +pub const REPORT_INTERVAL_SECS: u64 = 3; + +#[derive(Copy, Clone, Debug)] +pub struct PerformanceStats { + _num_connections: usize, + pub read_bytes: u64, + pub write_bytes: u64, + pub time_delta: Duration, +} + +impl PerformanceStats { + pub fn new( + num_connections: usize, + read_bytes: u64, + write_bytes: u64, + time_delta: Duration, + ) -> Self { + PerformanceStats { + _num_connections: num_connections, + read_bytes, + write_bytes, + time_delta, + } + } + + // Return total throughput in bytes per second + pub fn get_total_throughput(&self) -> u64 { + let time_delta_seconds = self.time_delta.as_secs(); + if time_delta_seconds == 0 { + 0 + } else { + let total_bytes = self.read_bytes + self.write_bytes; + total_bytes / time_delta_seconds + } + } +} +pub struct Proxy { + partition_to_nfs_cli_queue: mpsc::Sender, + partition_senders: Arc>>>, + shutdown: ShutdownHandle, + proxy_task_handle: JoinHandle<()>, + phantom: PhantomData, +} + +impl Proxy { + const SHUTDOWN_TIMEOUT: u64 = 15; + + pub fn new( + nfs_client: TcpStream, + partition_servers: Vec, + notification_queue: mpsc::Sender>, + shutdown: ShutdownHandle, + ) -> Self { + let (tx, rx) = mpsc::channel(64); + + let senders = partition_servers + .into_iter() + .map(|stream| Proxy::create_connection(stream, tx.clone(), shutdown.clone())) + .collect::>>(); + + let partition_senders = Arc::new(Mutex::new(senders)); + + let proxy = ProxyTask::new( + nfs_client, + notification_queue, + partition_senders.clone(), + rx, + shutdown.clone(), + ); + let proxy_task_handle = tokio::spawn(proxy.run()); + Self { + partition_to_nfs_cli_queue: tx, + partition_senders, + shutdown, + proxy_task_handle, + phantom: PhantomData, + } + } + + pub async fn add_connection(&self, stream: S) { + let conn = Proxy::create_connection( + stream, + self.partition_to_nfs_cli_queue.clone(), + self.shutdown.clone(), + ); + let mut f = self.partition_senders.lock().await; + f.push(conn); + } + + fn create_connection( + stream: S, + proxy: mpsc::Sender, + shutdown: ShutdownHandle, + ) -> mpsc::Sender { + let (tx, rx) = mpsc::channel(64); + tokio::spawn(ConnectionTask::new(stream, rx, proxy).run(shutdown)); + tx + } + + pub async fn shutdown(self) -> Result<(), Box> { + self.shutdown.cancellation_token.cancel(); + match tokio::time::timeout( + Duration::from_secs(Self::SHUTDOWN_TIMEOUT), + self.proxy_task_handle, + ) + .await? + { + Ok(()) => Ok(()), + Err(join_err) => Err(join_err.into()), + } + } +} + +const BUFFER_SIZE: usize = RPC_MAX_SIZE; + +struct ProxyTask { + nfs_client: TcpStream, + notification_queue: mpsc::Sender>, + partition_senders: Arc>>>, + response_queue: mpsc::Receiver, + shutdown: ShutdownHandle, +} + +enum ConnectionMessage { + Response(RpcBatch), +} + +impl ProxyTask { + pub fn new( + nfs_client: TcpStream, + notification_queue: mpsc::Sender>, + partition_senders: Arc>>>, + response_queue: mpsc::Receiver, + shutdown: ShutdownHandle, + ) -> Self { + Self { + nfs_client, + notification_queue, + partition_senders, + response_queue, + shutdown, + } + } + + async fn run(self) { + // Runs Proxy between NFS Client and the EFS Service. + // + // This function returns when it is cancelled by the `ShutdownHandle`, or if an error + // causes the `ProxyTask`'s `reader`, `writer`, or `reporter` task to return. In any of + // these cases, the `tokio::select!` block will cancel all of the tasks run by this object. + // + // An unused `mspc::Sender` is passed to each task spawned, so that we can await task + // shutdown with `mspc::Receiver::recv`. See https://tokio.rs/tokio/topics/shutdown. + + trace!("Starting proxy task"); + + let (shutdown_sender, mut shutdown_receiver) = mpsc::channel::(1); + + let write_byte_count = Arc::new(AtomicU64::new(0)); + let read_byte_count = Arc::new(AtomicU64::new(0)); + + let (read_half, write_half) = self.nfs_client.into_split(); + + let reader = Self::run_reader( + read_half, + read_byte_count.clone(), + self.partition_senders.clone(), + self.shutdown.clone(), + shutdown_sender.clone(), + ); + let shutdown = self.shutdown.clone(); + tokio::spawn(async move { + tokio::select! { + _ = reader => trace!("Proxy reader stopped"), + _ = shutdown.cancellation_token.cancelled() => trace!("Proxy reader stopped by ShutdownHandle"), + } + }); + + let writer = Self::run_writer( + write_half, + write_byte_count.clone(), + self.response_queue, + self.shutdown.clone(), + shutdown_sender.clone(), + ); + let shutdown = self.shutdown.clone(); + tokio::spawn(async move { + tokio::select! { + _ = writer => trace!("Proxy writer stopped"), + _ = shutdown.cancellation_token.cancelled() => trace!("Proxy writer stopped by ShutdownHandle"), + } + }); + + let reporter = Self::run_reporter( + read_byte_count, + write_byte_count, + self.partition_senders.clone(), + self.notification_queue.clone(), + shutdown_sender.clone(), + ); + let shutdown = self.shutdown.clone(); + tokio::spawn(async move { + tokio::select! { + _ = reporter => trace!("Proxy reporter stopped"), + _ = shutdown.cancellation_token.cancelled() => trace!("Proxy reporter stopped by ShutdownHandle"), + } + }); + + drop(shutdown_sender); + shutdown_receiver.recv().await; + } + + // NFS client to Proxy + async fn run_reader( + mut read_half: OwnedReadHalf, + read_count: Arc, + partition_senders: Arc>>>, + shutdown: ShutdownHandle, + _shutdown_sender: mpsc::Sender, + ) { + trace!("Starting proxy reader"); + let mut buffer = BytesMut::with_capacity(BUFFER_SIZE); + let reason; + let mut next_conn = 0; + + loop { + match read_half.read_buf(&mut buffer).await { + Ok(n_read) => { + if n_read == 0 { + reason = Some(ShutdownReason::Unmount); + break; + } else { + read_count.fetch_add(n_read as u64, std::sync::atomic::Ordering::AcqRel); + } + } + Err(e) => { + info!("Error reading from NFS client {:?}", e); + reason = Some(ShutdownReason::Unmount); + break; + } + } + + match RpcBatch::parse_batch(&mut buffer) { + Ok(Some(batch)) => { + let f = partition_senders.lock().await; + let r = f[next_conn].send(batch).await; + next_conn = (next_conn + 1) % f.len(); + if let Err(e) = r { + debug!("Error sending RPC batch to connection task {:?}", e); + reason = Some(ShutdownReason::UnexpectedError); + break; + }; + } + Err(RpcFragmentParseError::InvalidSizeTooSmall) => { + drop(read_half); + error!("NFS Client Error: invalid RPC size - size too small"); + reason = Some(ShutdownReason::FrameSizeTooSmall); + break; + } + Err(RpcFragmentParseError::SizeLimitExceeded) => { + drop(read_half); + error!("NFS Client Error: invalid RPC size - size limit exceeded"); + reason = Some(ShutdownReason::FrameSizeExceeded); + break; + } + Ok(None) | Err(RpcFragmentParseError::Incomplete) => (), + } + + if buffer.capacity() == 0 { + buffer.reserve(BUFFER_SIZE) + } + } + trace!("cli_to_server exiting!"); + shutdown.exit(reason).await; + } + + // Proxy to NFS Client + async fn run_writer( + mut write_half: OwnedWriteHalf, + write_count: Arc, + mut response_queue: mpsc::Receiver, + shutdown: ShutdownHandle, + _shutdown_sender: mpsc::Sender, + ) { + trace!("Starting proxy writer"); + + let mut reason = None; + loop { + match response_queue.recv().await { + Some(ConnectionMessage::Response(batch)) => { + let mut total_written = 0; + + for b in &batch.rpcs { + match write_half.write_all(b).await { + Ok(_) => total_written += b.len(), + Err(e) => { + debug!("Error writing to nfs_client. {:?}", e); + reason = Some(ShutdownReason::Unmount); + break; + } + }; + } + + write_count + .fetch_add(total_written as u64, std::sync::atomic::Ordering::AcqRel); + } + None => { + info!("Exiting server_to_cli"); + break; + } + } + } + shutdown.exit(reason).await; + } + + async fn run_reporter( + read_count: Arc, + write_count: Arc, + partition_senders: Arc>>>, + notification_queue: mpsc::Sender>, + _shutdown_sender: mpsc::Sender, + ) { + trace!("Starting reporter task"); + + let mut last = Instant::now(); + loop { + tokio::time::sleep(Duration::from_secs(REPORT_INTERVAL_SECS)).await; + + let num_connections; + { + let t = partition_senders.lock().await; + num_connections = t.len(); + drop(t); + } + + let now = Instant::now(); + let delta = now - last; + last = now; + let read = read_count.swap(0, std::sync::atomic::Ordering::AcqRel); + let write = write_count.swap(0, std::sync::atomic::Ordering::AcqRel); + let result = notification_queue + .send(Event::ProxyUpdate(PerformanceStats::new( + num_connections, + read, + write, + delta, + ))) + .await; + if result.is_err() { + break; + } + } + } +} + +struct ConnectionTask { + stream: S, + proxy_receiver: mpsc::Receiver, + proxy_sender: mpsc::Sender, +} + +impl ConnectionTask { + fn new( + stream: S, + proxy_receiver: mpsc::Receiver, + proxy_sender: mpsc::Sender, + ) -> Self { + Self { + stream, + proxy_receiver, + proxy_sender, + } + } + + async fn run(self, shutdown_handle: ShutdownHandle) { + let (r, w) = split(self.stream); + + let shutdown = shutdown_handle.clone(); + + // This CancellationToken facilitates graceful TLS connection closures by ensuring that + // that the ReadHalf is dropped only after the WriteHalf.shutdown() has returned + let connection_cancellation_token = CancellationToken::new(); + + let writer = Self::run_writer( + w, + self.proxy_receiver, + shutdown_handle.clone(), + connection_cancellation_token.clone(), + ); + tokio::spawn(async move { + tokio::select! { + _ = shutdown.cancellation_token.cancelled() => trace!("Cancelled"), + _ = writer => {}, + } + }); + + let reader = Self::run_reader(r, self.proxy_sender, shutdown_handle.clone()); + tokio::spawn(async move { + tokio::select! { + _ = connection_cancellation_token.cancelled() => trace!("Cancelled"), + _ = reader => {}, + } + }); + } + + // EFS to Proxy + async fn run_reader( + mut server_read_half: ReadHalf, + sender: mpsc::Sender, + shutdown: ShutdownHandle, + ) { + let reason; + let mut buffer = BytesMut::with_capacity(BUFFER_SIZE); + loop { + match server_read_half.read_buf(&mut buffer).await { + Ok(n_read) => { + if n_read == 0 { + reason = Option::Some(ShutdownReason::NeedsRestart); + break; + } + } + Err(e) => { + debug!("Error reading from server: {:?}", e); + reason = Option::Some(ShutdownReason::NeedsRestart); + break; + } + }; + + match RpcBatch::parse_batch(&mut buffer) { + Ok(Some(batch)) => { + if let Err(e) = sender.send(ConnectionMessage::Response(batch)).await { + debug!("Error sending result back: {:?}", e); + reason = Some(ShutdownReason::UnexpectedError); + break; + } + } + Err(RpcFragmentParseError::InvalidSizeTooSmall) => { + drop(server_read_half); + error!("Server Error: invalid RPC size - size too small"); + reason = Some(ShutdownReason::UnexpectedError); + break; + } + Err(RpcFragmentParseError::SizeLimitExceeded) => { + drop(server_read_half); + error!("Server Error: invalid RPC size - size limit exceeded"); + reason = Some(ShutdownReason::UnexpectedError); + break; + } + Ok(None) | Err(RpcFragmentParseError::Incomplete) => (), + } + + if buffer.capacity() == 0 { + buffer.reserve(BUFFER_SIZE) + } + } + shutdown.exit(reason).await; + } + + // Proxy to EFS + async fn run_writer( + mut server_write_half: WriteHalf, + mut receiver: mpsc::Receiver, + shutdown: ShutdownHandle, + connection_cancellation_token: CancellationToken, + ) { + let mut reason = Option::None; + loop { + let Some(batch) = receiver.recv().await else { + debug!("sender dropped"); + break; + }; + + for b in &batch.rpcs { + match server_write_half.write_all(b).await { + Ok(_) => (), + Err(e) => { + debug!("Error writing to server: {:?}", e); + reason = Option::Some(ShutdownReason::NeedsRestart); + break; + } + }; + } + } + + tokio::spawn(async move { + match server_write_half.shutdown().await { + Ok(_) => (), + Err(e) => debug!("Failed to gracefully shutdown connection: {}", e), + }; + connection_cancellation_token.cancel(); + }); + shutdown.exit(reason).await; + } +} diff --git a/src/proxy/src/proxy_identifier.rs b/src/proxy/src/proxy_identifier.rs new file mode 100644 index 00000000..e8e08e06 --- /dev/null +++ b/src/proxy/src/proxy_identifier.rs @@ -0,0 +1,54 @@ +use uuid::Uuid; + +pub const INITIAL_INCARNATION: i64 = 0; + +#[derive(Eq, PartialEq, Clone, Copy, Debug)] +pub struct ProxyIdentifier { + pub uuid: Uuid, + pub incarnation: i64, +} + +impl ProxyIdentifier { + pub fn new() -> Self { + ProxyIdentifier { + uuid: Uuid::new_v4(), + incarnation: INITIAL_INCARNATION, + } + } + + pub fn increment(&mut self) { + if self.incarnation == i64::MAX { + self.incarnation = 0; + return; + } + self.incarnation += 1; + } +} + +#[cfg(test)] +mod tests { + use super::ProxyIdentifier; + use super::INITIAL_INCARNATION; + + #[test] + fn test_increment() { + let mut proxy_id = ProxyIdentifier::new(); + let proxy_id_original = proxy_id.clone(); + for i in 0..5 { + assert_eq!(i, proxy_id.incarnation); + proxy_id.increment(); + } + assert_eq!(proxy_id_original.uuid, proxy_id.uuid); + assert_eq!(INITIAL_INCARNATION, proxy_id_original.incarnation); + } + + #[test] + fn test_wrap_around() { + let mut proxy_id = ProxyIdentifier::new(); + let proxy_id_original = proxy_id.clone(); + proxy_id.incarnation = i64::MAX; + proxy_id.increment(); + assert_eq!(proxy_id_original.uuid, proxy_id.uuid); + assert_eq!(INITIAL_INCARNATION, proxy_id.incarnation); + } +} diff --git a/src/proxy/src/rpc.rs b/src/proxy/src/rpc.rs new file mode 100644 index 00000000..f167839d --- /dev/null +++ b/src/proxy/src/rpc.rs @@ -0,0 +1,242 @@ +use std::io::Cursor; + +use bytes::{Buf, Bytes, BytesMut}; +use tokio::io::AsyncReadExt; + +use crate::connections::ProxyStream; + +// Each element is an RPC call. +pub struct RpcBatch { + pub rpcs: Vec, +} + +#[derive(Debug, PartialEq)] +pub enum RpcFragmentParseError { + InvalidSizeTooSmall, + SizeLimitExceeded, + Incomplete, +} + +pub const RPC_LAST_FRAG: u32 = 0x80000000; +pub const RPC_SIZE_MASK: u32 = 0x7FFFFFFF; +pub const RPC_HEADER_SIZE: usize = 4; + +/* The sunrpc server implementation in linux has a maximum payload of 1MB + 1 page + * (see include/linux/sunrpc/svc.h#RPCSVC_MAXPAYLOAD and sv_max_mesg). + */ +pub const RPC_MAX_SIZE: usize = 1024 * 1024 + 4 * 1024; +pub const RPC_MIN_SIZE: usize = 2; + +impl RpcBatch { + pub fn parse_batch(buffer: &mut BytesMut) -> Result, RpcFragmentParseError> { + let mut batch = RpcBatch { rpcs: Vec::new() }; + + loop { + match Self::check_rpc_message(Cursor::new(&buffer[..])) { + Ok(len) => { + let rpc_message = buffer.split_to(len); + batch.rpcs.push(rpc_message.freeze()); + } + Err(RpcFragmentParseError::Incomplete) => break, + Err(e) => return Err(e), + } + } + + if batch.rpcs.is_empty() { + Ok(None) + } else { + Ok(Some(batch)) + } + } + + pub fn check_rpc_message(mut src: Cursor<&[u8]>) -> Result { + loop { + if src.remaining() < RPC_HEADER_SIZE { + return Err(RpcFragmentParseError::Incomplete); + } + + let fragment_header = src.get_u32(); + let fragment_size = (fragment_header & RPC_SIZE_MASK) as usize; + let is_last_fragment = (fragment_header & RPC_LAST_FRAG) != 0; + + if fragment_size <= RPC_MIN_SIZE { + return Err(RpcFragmentParseError::InvalidSizeTooSmall); + } + + if fragment_size >= RPC_MAX_SIZE { + return Err(RpcFragmentParseError::SizeLimitExceeded); + } + + if src.remaining() < fragment_size { + return Err(RpcFragmentParseError::Incomplete); + } + + src.advance(fragment_size); + + if is_last_fragment { + return Ok(src.position() as usize); + } + } + } +} + +pub async fn read_rpc_bytes(stream: &mut dyn ProxyStream) -> Result, tokio::io::Error> { + let mut header = [0; RPC_HEADER_SIZE]; + stream.read_exact(&mut header).await?; + + // NOTE: onc-rpc crate does not support fragmentation out of the box. Add 4 to include the header. + let len = (RPC_SIZE_MASK & extract_u32_from_bytes(&header)) + RPC_HEADER_SIZE as u32; + + let mut payload = vec![0; len as usize]; + payload[0..RPC_HEADER_SIZE].clone_from_slice(&header); + + stream.read_exact(&mut payload[RPC_HEADER_SIZE..]).await?; + + Ok(payload) +} + +fn extract_u32_from_bytes(header: &[u8]) -> u32 { + u32::from_be_bytes([header[0], header[1], header[2], header[3]]) +} + +#[cfg(test)] +pub mod test { + use crate::rpc::RPC_MAX_SIZE; + + use super::{RpcBatch, RpcFragmentParseError, RPC_HEADER_SIZE, RPC_LAST_FRAG}; + use bytes::{BufMut, BytesMut}; + use rand::Rng; + + // Generates message fragments for tests + // + // This function generates a set of message fragments from random data. The fragments are constructed + // in a way that they can be later assembled into the full long message data + // function. + // + // # Arguments + // * `size` - The total size of the message. + // * `num_fragments` - The number of fragments to generate. + // + pub fn generate_msg_fragments(size: usize, num_fragments: usize) -> (bytes::BytesMut, Vec) { + let mut rng = rand::thread_rng(); + let data: Vec = (0..size).map(|_| rng.gen()).collect(); + + let fragment_data_size = data.len() / num_fragments; + + let mut data_buffer = bytes::BytesMut::new(); + for i in 0..num_fragments { + let start_idx = i * fragment_data_size; + let end_idx = std::cmp::min(size, start_idx + fragment_data_size); + let fragment_data = &data[start_idx..end_idx]; + + let mut header = (end_idx - start_idx) as u32; + if end_idx == size { + header |= 1 << 31; + } + + data_buffer.extend_from_slice(&header.to_be_bytes()); + data_buffer.extend_from_slice(fragment_data); + } + assert_eq!(data_buffer.len(), (num_fragments * 4) + data.len()); + + (data_buffer, data) + } + + #[test] + fn multiple_messages() { + let mut b = BytesMut::with_capacity(8); + b.put_u32(RPC_LAST_FRAG | 4); + b.put_u32(42); + b.put_u32(RPC_LAST_FRAG | 4); + + let batch = RpcBatch::parse_batch(&mut b); + let batch = batch.unwrap().unwrap(); + assert_eq!(batch.rpcs[0].len(), 8); + assert_eq!(batch.rpcs.len(), 1); + + b.put_u32(43); + let batch = RpcBatch::parse_batch(&mut b); + let batch = batch.unwrap().unwrap(); + assert_eq!(batch.rpcs[0].len(), 8); + assert_eq!(batch.rpcs.len(), 1); + + let batch = RpcBatch::parse_batch(&mut b); + assert!(matches!(batch, Ok(None))); + } + + #[test] + fn test_invalid_rpc_small_fragment() { + let num_fragments = 1; + let (mut input_buffer, _) = generate_msg_fragments(1, num_fragments); + let result = RpcBatch::parse_batch(&mut input_buffer); + assert!(matches!( + result, + Err(RpcFragmentParseError::InvalidSizeTooSmall) + )); + } + + #[test] + fn test_invalid_rpc_big_fragment() { + let num_fragments = 1; + let (mut input_buffer, _) = generate_msg_fragments(RPC_MAX_SIZE + 1, num_fragments); + let result = RpcBatch::parse_batch(&mut input_buffer); + assert!(matches!( + result, + Err(RpcFragmentParseError::SizeLimitExceeded) + )); + } + + #[test] + fn test_parse_batch_single_message() { + // Create an input buffer with multiple RPC fragments + let num_fragments = 3; + let message_size = 12; + let (mut input_buffer, _) = generate_msg_fragments(message_size, num_fragments); + let mut rpc_batch = RpcBatch::parse_batch(&mut input_buffer) + .expect("parse batch failed") + .expect("no rpc messages found"); + + assert_eq!(1, rpc_batch.rpcs.len()); + let rpc_message = rpc_batch.rpcs.pop().expect("No RPC messages"); + + let expected_message_size = num_fragments * RPC_HEADER_SIZE + message_size; + assert_eq!(expected_message_size, rpc_message.len()); + } + + #[test] + fn test_parse_batch_multiple_message() { + // Create an input buffer with multiple RPC messages + let num_fragments_1 = 3; + let message_size_1 = 12; + let (mut input_buffer, _) = generate_msg_fragments(message_size_1, num_fragments_1); + + let num_fragments_2 = 6; + let message_size_2 = 24; + let (input_buffer_2, _) = generate_msg_fragments(message_size_2, num_fragments_2); + + let num_fragments_3 = 1; + let message_size_3 = 50; + let (input_buffer_3, _) = generate_msg_fragments(message_size_3, num_fragments_3); + + input_buffer.extend_from_slice(&input_buffer_2); + input_buffer.extend_from_slice(&input_buffer_3); + + let mut rpc_batch = RpcBatch::parse_batch(&mut input_buffer) + .expect("parse batch failed") + .expect("no rpc messages found"); + + assert_eq!(3, rpc_batch.rpcs.len()); + + let rpc_message_3 = rpc_batch.rpcs.pop().expect("No RPC messages"); + let expected_message_size_3 = num_fragments_3 * RPC_HEADER_SIZE + message_size_3; + assert_eq!(expected_message_size_3, rpc_message_3.len()); + + let rpc_message_2 = rpc_batch.rpcs.pop().expect("No RPC messages"); + let expected_message_size_2 = num_fragments_2 * RPC_HEADER_SIZE + message_size_2; + assert_eq!(expected_message_size_2, rpc_message_2.len()); + + let rpc_message_1 = rpc_batch.rpcs.pop().expect("No RPC messages"); + let expected_message_size_1 = num_fragments_1 * RPC_HEADER_SIZE + message_size_1; + assert_eq!(expected_message_size_1, rpc_message_1.len()); + } +} diff --git a/src/proxy/src/shutdown.rs b/src/proxy/src/shutdown.rs new file mode 100644 index 00000000..5c8f148a --- /dev/null +++ b/src/proxy/src/shutdown.rs @@ -0,0 +1,85 @@ +use log::debug; +use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio_util::sync::CancellationToken; + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum ShutdownReason { + NeedsRestart, + UnexpectedError, + Unmount, + FrameSizeExceeded, + FrameSizeTooSmall, +} + +#[derive(Clone)] +pub struct ShutdownHandle { + pub cancellation_token: CancellationToken, + notifier: Sender, +} + +impl ShutdownHandle { + pub fn new(cancellation_token: CancellationToken) -> (Self, Receiver) { + let (notifier, r) = mpsc::channel(1024); + let h = Self { + cancellation_token, + notifier, + }; + (h, r) + } + + pub async fn exit(self, reason: Option) { + debug!("Exiting: {:?}", reason); + self.cancellation_token.cancel(); + if let Some(reason) = reason { + let _ = self.notifier.send(reason).await; + } + } +} + +#[cfg(test)] +mod test { + use log::info; + use std::time::Duration; + + use tokio::sync::mpsc; + use tokio_util::sync::CancellationToken; + + use super::ShutdownHandle; + + #[tokio::test] + async fn test() { + let (t, mut r) = mpsc::channel(1); + let token = CancellationToken::new(); + + let s1 = ShutdownHandle { + cancellation_token: token.clone(), + notifier: t.clone(), + }; + let s2 = ShutdownHandle { + cancellation_token: token.clone(), + notifier: t.clone(), + }; + + tokio::spawn(run_task(s1, false)); + tokio::spawn(run_task(s2, true)); + drop(t); + + let _ = r.recv().await; + info!("Done"); + } + + async fn run_task(shutdown: ShutdownHandle, to_cancel: bool) { + let f = async { + if to_cancel { + shutdown.cancellation_token.clone().cancel() + } else { + tokio::time::sleep(Duration::from_secs(10)).await; + } + }; + tokio::select! { + _ = shutdown.cancellation_token.cancelled() => {}, + _ = f => {} + } + info!("Task exiting"); + } +} diff --git a/src/proxy/src/status_reporter.rs b/src/proxy/src/status_reporter.rs new file mode 100644 index 00000000..ac9f6a9c --- /dev/null +++ b/src/proxy/src/status_reporter.rs @@ -0,0 +1,110 @@ +use crate::controller::ConnectionSearchState; +use crate::efs_rpc::PartitionId; +use crate::{proxy::PerformanceStats, proxy_identifier::ProxyIdentifier}; +use anyhow::{Error, Result}; +use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio::time::Instant; + +pub struct Report { + pub proxy_id: ProxyIdentifier, + pub partition_id: Option, + pub connection_state: ConnectionSearchState, + pub num_connections: usize, + pub last_proxy_update: Option<(Instant, PerformanceStats)>, + pub scale_up_attempt_count: u64, + pub restart_count: u64, +} + +type Request = (); +type Response = Report; + +pub struct StatusReporter { + pub sender: Sender, + pub receiver: Receiver, +} + +impl StatusReporter { + pub async fn await_report_request(&mut self) -> Result<()> { + self.receiver + .recv() + .await + .ok_or_else(|| Error::msg("Request channel closed"))?; + Ok(()) + } + + // Note: This should only be called when a message is received by the receiver. + pub async fn publish_status(&mut self, report: Report) { + match self.sender.send(report).await { + Ok(_) => (), + Err(e) => panic!("StatusReporter could not send report {}", e), + } + } +} + +pub struct StatusRequester { + _sender: Sender, + _receiver: Receiver, +} + +impl StatusRequester { + pub async fn _request_status(&mut self) -> Result { + self._sender.send(()).await?; + self._receiver + .recv() + .await + .ok_or_else(|| Error::msg("Response channel closed")) + } +} + +pub fn create_status_channel() -> (StatusRequester, StatusReporter) { + let (call_sender, call_receiver) = mpsc::channel::(1); + let (reply_sender, reply_receiver) = mpsc::channel::(1); + + let status_requester = StatusRequester { + _sender: call_sender, + _receiver: reply_receiver, + }; + + let status_reporter = StatusReporter { + sender: reply_sender, + receiver: call_receiver, + }; + + (status_requester, status_reporter) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_basic() -> Result<()> { + let proxy_id = ProxyIdentifier::new(); + + let (mut status_requester, mut status_reporter) = create_status_channel(); + tokio::spawn(async move { + status_reporter + .await_report_request() + .await + .expect("Request channel closed"); + let report = Report { + proxy_id: proxy_id.clone(), + partition_id: None, + connection_state: ConnectionSearchState::Idle, + num_connections: 1, + last_proxy_update: None, + scale_up_attempt_count: 0, + restart_count: 0, + }; + status_reporter.publish_status(report).await + }); + + let r = status_requester._request_status().await?; + assert_eq!(proxy_id, r.proxy_id); + assert!(matches!(r.partition_id, None)); + assert_eq!(r.connection_state, ConnectionSearchState::Idle); + assert!(matches!(r.last_proxy_update, None)); + assert_eq!(1, r.num_connections); + Ok(()) + } +} diff --git a/src/proxy/src/tls.rs b/src/proxy/src/tls.rs new file mode 100644 index 00000000..6a6f062e --- /dev/null +++ b/src/proxy/src/tls.rs @@ -0,0 +1,230 @@ +use anyhow::{Context, Result}; +use log::*; +use nix::NixPath; +use s2n_tls::enums::ClientAuthType::Optional; +use s2n_tls::security::Policy; +use s2n_tls::{config::Config, security::DEFAULT_TLS13}; +use s2n_tls_tokio::TlsConnector; +use s2n_tls_tokio::TlsStream; +use std::path::Path; +use tokio::net::TcpStream; + +use crate::error::ConnectError; + +pub const FIPS_COMPLIANT_POLICY_VERSION: &str = "20230317"; +pub struct InsecureAcceptAllCertificatesHandler; +impl s2n_tls::callbacks::VerifyHostNameCallback for InsecureAcceptAllCertificatesHandler { + fn verify_host_name(&self, _host_name: &str) -> bool { + true + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct TlsConfig { + pub fips_enabled: bool, + + /// Contents of the certificate authority file. E.g. /etc/amazon/efs/efs-utils.crt + pub ca_file_contents: Vec, + + /// The client-side certificate and public key + pub client_cert: Vec, + + /// The client private key + pub client_private_key: Vec, + + /// The remote address to establish the TLS connection with + pub remote_addr: String, + + /// The hostname that is expected to be on the remote server's TLS certificate + pub server_domain: String, +} + +// s2n-tls errors if there are comments in the certificate files. This function removes comments if +// they are present. +async fn read_file_with_comments_removed(path: &Path) -> Result> { + let file = tokio::fs::File::open(path).await?; + let reader = tokio::io::BufReader::new(file); + let mut lines = tokio::io::AsyncBufReadExt::lines(reader); + + let mut output = Vec::new(); + while let Ok(Some(line)) = lines.next_line().await { + if !line.starts_with("# ") { + if !output.is_empty() { + output.push(b'\n'); + } + + output.extend_from_slice(line.as_bytes()); + } + } + Ok(output) +} + +impl TlsConfig { + /// Create an instance of TlsConfig. + /// + /// This will return an error if the files could not be read or the remote address could not be resolved. + /// + /// # Arguments + /// * `ca_file` - File path of the certificate authority file. E.g. /etc/amazon/efs/efs-utils.crt + /// * `client_cert_pem_file` - File path of the file that contains the client-side certificate and public key + /// * `client_private_key_pem_file` - File path of the file that contains the client private key + /// * `remote_addr` - The remote address to establish the TLS connection with + /// * `server_domain` - The hostname that is expected to be on the certificate that the remote server presents + /// + pub async fn new( + fips_enabled: bool, + ca_file: &Path, + client_cert_pem_file: &Path, + client_private_key_pem_file: &Path, + remote_addr: &str, + server_domain: &str, + ) -> Result { + let mut ca_file_contents: Vec = Vec::new(); + if !ca_file.is_empty() { + ca_file_contents = read_file_with_comments_removed(ca_file).await.context( + String::from("Error in TlsConfig::new. Unable to the CA File. Make sure it does not have any comments (lines that start with #)."))?; + } + let client_cert = read_file_with_comments_removed(client_cert_pem_file) + .await + .context(String::from( + "Error in TlsConfig::new. Unable to read the client certificate file.", + ))?; + let client_private_key = read_file_with_comments_removed(client_private_key_pem_file) + .await + .context(String::from( + "Error in TlsConfig::new. Unable to read private key file.", + ))?; + let server_domain = server_domain.to_string(); + let remote_addr = remote_addr.to_string(); + + Ok(TlsConfig { + fips_enabled, + ca_file_contents, + client_cert, + client_private_key, + remote_addr, + server_domain, + }) + } + + #[cfg(test)] + pub async fn new_from_config(config: &crate::ProxyConfig) -> Result { + let efs_config = &config.nested_config; + + let ca_file = Path::new(&efs_config.ca_file); + let ca_cert_pem = Path::new(&efs_config.client_cert_pem_file); + let private_key_pem = Path::new(&efs_config.client_private_key_pem_file); + if !ca_file.exists() || !ca_cert_pem.exists() || !private_key_pem.exists() { + let error_msg = "One or more required files for TLS config are missing"; + return Err(anyhow::Error::msg(error_msg)); + } + TlsConfig::new( + config.fips, + &ca_file, + &ca_cert_pem, + &private_key_pem, + efs_config.mount_target_addr.as_str(), + efs_config.expected_server_hostname_tls.as_str(), + ) + .await + } +} + +/// Establishes a TLS stream using the configuration and remote address specified in tls_config +pub async fn establish_tls_stream( + tls_config: TlsConfig, +) -> Result, ConnectError> { + let config = create_config_builder(&tls_config).build()?; + + let tls_connector = TlsConnector::new(config); + + let tcp_stream = TcpStream::connect(tls_config.remote_addr).await?; + + let tls_stream = tls_connector + .connect(&tls_config.server_domain, tcp_stream) + .await?; + + debug!("{:#?}", tls_stream); + Ok(tls_stream) +} + +fn create_config_builder(tls_config: &TlsConfig) -> s2n_tls::config::Builder { + let mut config = Config::builder(); + + let policy = if tls_config.fips_enabled { + Policy::from_version(FIPS_COMPLIANT_POLICY_VERSION).expect("Invalid policy") + } else { + DEFAULT_TLS13 + }; + config + .set_security_policy(&policy) + .expect("Error in create_tls_connector. Failed to set security policy."); + config + .set_client_auth_type(Optional) + .expect("Error in create_tls_connector. Failed to set client auth type."); + config + .load_pem(&tls_config.client_cert, &tls_config.client_private_key) + .expect( + "Error in create_tls_connector. Failed to load the client certificate and private key.", + ); + + // If the customer is using the verify=0 mount option, we want to disable cert verification. + if !tls_config.ca_file_contents.is_empty() { + config + .trust_pem(&tls_config.ca_file_contents) + .expect("Error in create_tls_connector. Failed to add the CA file to the trust store."); + } else { + unsafe { + config + .disable_x509_verification() + .expect("Error disabling x509 verification"); + }; + } + + // If stunnel_check_cert_hostname = false in efs-utils config, then we don't verify the hostname + if tls_config.server_domain.is_empty() { + config + .set_verify_host_callback(InsecureAcceptAllCertificatesHandler) + .expect("Unable to disable host name verification"); + } + + config +} + +#[cfg(test)] +pub mod tests { + + use crate::config_parser::tests::get_test_config; + + use super::*; + + pub async fn get_client_config() -> Result { + let tls_config = TlsConfig::new_from_config(&get_test_config()).await?; + let builder = create_config_builder(&tls_config); + + let config = builder.build()?; + Ok(config) + } + + pub async fn get_server_config() -> Result { + let tls_config = TlsConfig::new_from_config(&get_test_config()).await?; + let mut builder = create_config_builder(&tls_config); + + // Accept all client certificates + builder.set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})?; + + let config = builder.build()?; + Ok(config) + } + + #[tokio::test] + async fn test_remove_comments() { + let comment_file = Path::new("tests/certs/cert_with_comments.pem"); + let decommented_output = read_file_with_comments_removed(comment_file).await; + + let expected = tokio::fs::read(&Path::new("tests/certs/cert.pem")) + .await + .expect("Could not read certificate file"); + assert_eq!(expected.len(), decommented_output.unwrap().len()); + } +} diff --git a/src/watchdog/__init__.py b/src/watchdog/__init__.py index 0e9c439c..41d49f1d 100755 --- a/src/watchdog/__init__.py +++ b/src/watchdog/__init__.py @@ -56,7 +56,7 @@ AMAZON_LINUX_2_RELEASE_ID, AMAZON_LINUX_2_PRETTY_NAME, ] -VERSION = "1.36.0" +VERSION = "2.0.0" SERVICE = "elasticfilesystem" CONFIG_FILE = "/etc/amazon/efs/efs-utils.conf" @@ -180,6 +180,9 @@ SYSTEM_RELEASE_PATH = "/etc/system-release" OS_RELEASE_PATH = "/etc/os-release" STUNNEL_INSTALLATION_MESSAGE = "Please install it following the instructions at: https://docs.aws.amazon.com/efs/latest/ug/using-amazon-efs-utils.html#upgrading-stunnel" +EFS_PROXY_INSTALLATION_MESSAGE = "Please install it by reinstalling amazon-efs-utils" + +EFS_PROXY_BIN = "efs-proxy" def fatal_error(user_message, log_message=None): @@ -798,9 +801,10 @@ def get_pid_in_state_dir(state_file, state_file_dir): def is_mount_stunnel_proc_running(state_pid, state_file, state_file_dir): """ - Check whether a given stunnel process id in state file is running for the mount. To avoid we incorrectly checking - processes running by other applications and send signal further, the stunnel process in state file is counted as - running iff: + Check whether the PID in the state file corresponds to a running efs-proxy/stunnel process. + Although this code was originally written to check if stunnel is running, we've modified + it to support the efs-proxy process as well. + The proxy or stunnel process is counted as running iff: 1. The pid in state file is not None. 2. The process running with the pid is a stunnel process. This is validated through process command name. 3. The process can be reached via os.kill(pid, 0). @@ -818,9 +822,11 @@ def is_mount_stunnel_proc_running(state_pid, state_file, state_file_dir): return False process_name = check_process_name(state_pid) - if not process_name or "stunnel" not in str(process_name): + if not process_name or ( + "efs-proxy" not in str(process_name) and "stunnel" not in str(process_name) + ): logging.debug( - "Process running on %s is not a stunnel process, full command: %s.", + "Process running on %s is not an efs-proxy or stunnel process, full command: %s.", state_pid, str(process_name) if process_name else "", ) @@ -828,7 +834,7 @@ def is_mount_stunnel_proc_running(state_pid, state_file, state_file_dir): if not is_pid_running(state_pid): logging.debug( - "Stunnel process with pid %s is not running anymore for %s.", + "Stunnel or efs-proxy process with pid %s is not running anymore for %s.", state_pid, state_file, ) @@ -942,11 +948,39 @@ def update_stunnel_command_for_ecs_amazon_linux_2( return command +def command_uses_efs_proxy(command): + """ + Accepts a list of strings which represents the command that was used + to start or efs-proxy. If the command contains efs-proxy, return True. + + Since we control the filepath in which the efs-proxy executable is stored, we + know that we will not run into situations where a directory on the filepath is named + efs-proxy but the executable command is something else, like stunnel. + """ + for i in range(len(command)): + if EFS_PROXY_BIN in command[i]: + return True + + return False + + def start_tls_tunnel(child_procs, state, state_file_dir, state_file): - # launch the tunnel in a process group so if it has any child processes, they can be killed easily + """ + Reads the command from the state file, and uses it to start a subprocess. + This is the command that efs-utils used to spin up the efs-proxy or stunnel process. + + We launch the stunnel and efs-proxy process in a process group so that child processes can be easily killed. + :param child_procs: list that contains efs-proxy / stunnel processes that the Watchdog instance has spawned + :param state: the state corresponding to a given mount - the proxy process associated with this mount will be started + :param state_file_dir: the directory where mount state files are stored + :param state_file: this function may rewrite the command used to start up the proxy or stunnel process, and thus needs a handle on the state file to update it. + :return: the pid of the proxy or stunnel process that was spawned + """ command = state["cmd"] logging.info('Starting TLS tunnel: "%s"', " ".join(command)) + efs_proxy_enabled = command_uses_efs_proxy(command) + command = update_stunnel_command_for_ecs_amazon_linux_2( command, state, state_file_dir, state_file ) @@ -960,44 +994,57 @@ def start_tls_tunnel(child_procs, state, state_file_dir, state_file): close_fds=True, ) except FileNotFoundError as e: - logging.warning("Watchdog failed to start stunnel due to %s", e) + if efs_proxy_enabled: + logging.warning("Watchdog failed to start efs-proxy due to %s", e) + else: + logging.warning("Watchdog failed to start stunnel due to %s", e) + + # /~https://github.com/kubernetes-sigs/aws-efs-csi-driver/issues/812 It is possible that the stunnel is not + # present anymore and replaced by stunnel5 on AL2, meanwhile watchdog is attempting to restart stunnel for + # mount using old efs-utils based on old state file generated during previous mount, which has stale command + # using stunnel bin. Update the state file if the stunnel does not exist anymore, and use stunnel5 on Al2. + # + if get_system_release_version() in AMAZON_LINUX_2_RELEASE_VERSIONS: + for i in range(len(command)): + if "stunnel" in command[i] and "stunnel-config" not in command[i]: + command[i] = find_command_path( + "stunnel5", STUNNEL_INSTALLATION_MESSAGE + ) + break + + tunnel = subprocess.Popen( + command, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + preexec_fn=os.setsid, + close_fds=True, + ) - # /~https://github.com/kubernetes-sigs/aws-efs-csi-driver/issues/812 It is possible that the stunnel is not - # present anymore and replaced by stunnel5 on AL2, meanwhile watchdog is attempting to restart stunnel for - # mount using old efs-utils based on old state file generated during previous mount, which has stale command - # using stunnel bin. Update the state file if the stunnel does not exist anymore, and use stunnel5 on Al2. - # - if get_system_release_version() in AMAZON_LINUX_2_RELEASE_VERSIONS: - for i in range(len(command)): - if "stunnel" in command[i] and "stunnel-config" not in command[i]: - command[i] = find_command_path( - "stunnel5", STUNNEL_INSTALLATION_MESSAGE - ) - break - - tunnel = subprocess.Popen( - command, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - preexec_fn=os.setsid, - close_fds=True, - ) + state["cmd"] = command + logging.info( + "Rewriting %s with new stunnel cmd: %s for Amazon Linux 2 platform.", + state_file, + " ".join(state["cmd"]), + ) + rewrite_state_file(state, state_file_dir, state_file) - state["cmd"] = command - logging.info( - "Rewriting %s with new stunnel cmd: %s for Amazon Linux 2 platform.", - state_file, - " ".join(state["cmd"]), - ) - rewrite_state_file(state, state_file_dir, state_file) + # We may have used either stunnel or efs-proxy as the TLS tunnel. + # We want to make it clear in the logs which was used. + tunnel_process_name = "stunnel" + if efs_proxy_enabled: + tunnel_process_name = "efs-proxy" if tunnel is None or not is_pid_running(tunnel.pid): fatal_error( - "Failed to initialize TLS tunnel for %s" % state_file, - "Failed to start TLS tunnel.", + "Failed to initialize %s for %s" % (tunnel_process_name, state_file), + "Failed to start %s." % tunnel_process_name, + ) + fatal_error( + "Failed to initialize %s for %s" % (tunnel_process_name, state_file), + "Failed to start %s." % tunnel_process_name, ) - logging.info("Started TLS tunnel, pid: %d", tunnel.pid) + logging.info("Started %s, pid: %d", tunnel_process_name, tunnel.pid) child_procs.append(tunnel) return tunnel.pid @@ -1148,18 +1195,6 @@ def check_efs_mounts( if is_mount_stunnel_proc_running( state.get("pid"), state_file, state_file_dir ): - # /~https://github.com/kubernetes-sigs/aws-efs-csi-driver/issues/616 We have seen EFS hanging issue caused - # by stuck stunnel (version: 4.56) process. Apart from checking whether stunnel is running or not, we - # need to check whether the stunnel connection established is healthy periodically. - # - # The way to check the stunnel health is by `df` the mountpoint, i.e. check the file system information, - # which will trigger a remote GETATTR on the root of the file system. Normally the command will finish - # in 10 milliseconds, thus if the command hang for certain period (defined as 30 sec as of now), the - # stunnel connection is likely to be unhealthy. Watchdog will kill the old stunnel process and restart - # a new one for the unhealthy mount. The health check will run every 5 min since mount. - # - # Both the command hang timeout and health check interval are configurable in efs-utils config file. - # check_stunnel_health( config, state, state_file_dir, state_file, child_procs, nfs_mounts ) @@ -1171,6 +1206,21 @@ def check_efs_mounts( def check_stunnel_health( config, state, state_file_dir, state_file, child_procs, nfs_mounts ): + """ + Check the health of efs-proxy, or stunnel (older versions of efs-utils), by executing `df` on the mountpoint. + + /~https://github.com/kubernetes-sigs/aws-efs-csi-driver/issues/616 We have seen EFS hanging issue caused + by stuck stunnel (version: 4.56) process. Apart from checking whether stunnel is running or not, we + need to check whether the stunnel connection established is healthy periodically. + + The way to check the stunnel health is by `df` the mountpoint, i.e. check the file system information, + which will trigger a remote GETATTR on the root of the file system. Normally the command will finish + in 10 milliseconds, thus if the command hang for certain period (defined as 30 sec as of now), the + stunnel connection is likely to be unhealthy. Watchdog will kill the old stunnel process and restart + a new one for the unhealthy mount. The health check will run every 5 min since mount. + + Both the command hang timeout and health check interval are configurable in efs-utils config file. + """ if not get_boolean_config_item_value( config, CONFIG_SECTION, "stunnel_health_check_enabled", default_value=True ): diff --git a/test/mount_efs_test/test_add_stunnel_ca_options.py b/test/mount_efs_test/test_add_stunnel_ca_options.py index 5ac7ccd2..f528c439 100644 --- a/test/mount_efs_test/test_add_stunnel_ca_options.py +++ b/test/mount_efs_test/test_add_stunnel_ca_options.py @@ -41,7 +41,7 @@ def test_use_existing_cafile(tmpdir): options = {"cafile": str(_create_temp_file(tmpdir))} efs_config = {} - mount_efs.add_stunnel_ca_options(efs_config, _get_config(), options, DEFAULT_REGION) + mount_efs.add_tunnel_ca_options(efs_config, _get_config(), options, DEFAULT_REGION) assert options["cafile"] == efs_config.get("CAfile") assert "CApath" not in efs_config @@ -52,7 +52,7 @@ def test_use_missing_cafile(capsys): efs_config = {} with pytest.raises(SystemExit) as ex: - mount_efs.add_stunnel_ca_options( + mount_efs.add_tunnel_ca_options( efs_config, _get_config(), options, DEFAULT_REGION ) @@ -68,7 +68,7 @@ def test_stunnel_cafile_configuration_in_option(mocker): mocker.patch("os.path.exists", return_value=True) - mount_efs.add_stunnel_ca_options(efs_config, _get_config(), options, DEFAULT_REGION) + mount_efs.add_tunnel_ca_options(efs_config, _get_config(), options, DEFAULT_REGION) assert CAFILE == efs_config.get("CAfile") @@ -82,7 +82,7 @@ def test_stunnel_cafile_configuration_in_config(mocker): mocker.patch("os.path.exists", return_value=True) - mount_efs.add_stunnel_ca_options(efs_config, config, options, DEFAULT_REGION) + mount_efs.add_tunnel_ca_options(efs_config, config, options, DEFAULT_REGION) assert CAFILE == efs_config.get("CAfile") @@ -93,7 +93,7 @@ def test_stunnel_cafile_not_configured(mocker): mocker.patch("os.path.exists", return_value=True) - mount_efs.add_stunnel_ca_options(efs_config, _get_config(), options, DEFAULT_REGION) + mount_efs.add_tunnel_ca_options(efs_config, _get_config(), options, DEFAULT_REGION) assert mount_efs.DEFAULT_STUNNEL_CAFILE == efs_config.get("CAfile") @@ -110,6 +110,6 @@ def test_stunnel_cafile_configured_in_mount_region_section(mocker): mocker.patch("os.path.exists", return_value=True) - mount_efs.add_stunnel_ca_options(efs_config, config, options, ISOLATED_REGION) + mount_efs.add_tunnel_ca_options(efs_config, config, options, ISOLATED_REGION) assert ISOLATED_REGION_STUNNEL_CAFILE == efs_config.get("CAfile") diff --git a/test/mount_efs_test/test_bootstrap_tls.py b/test/mount_efs_test/test_bootstrap_proxy.py similarity index 51% rename from test/mount_efs_test/test_bootstrap_tls.py rename to test/mount_efs_test/test_bootstrap_proxy.py index e3060766..d56f3e53 100644 --- a/test/mount_efs_test/test_bootstrap_tls.py +++ b/test/mount_efs_test/test_bootstrap_proxy.py @@ -40,15 +40,17 @@ def setup_mocks(mocker): return_value=(DNS_NAME, None), ) mocker.patch("mount_efs.get_target_region", return_value=REGION) - mocker.patch("mount_efs.write_tls_tunnel_state_file", return_value="~mocktempfile") + mocker.patch("mount_efs.write_tunnel_state_file", return_value="~mocktempfile") mocker.patch("mount_efs.create_certificate") mocker.patch("os.rename") mocker.patch("os.kill") mocker.patch( - "mount_efs.update_tls_tunnel_temp_state_file_with_tunnel_pid", + "mount_efs.update_tunnel_temp_state_file_with_tunnel_pid", return_value="~mocktempfile", ) + mocker.patch("mount_efs.get_efs_proxy_log_level", return_value="info") + process_mock = MagicMock() process_mock.communicate.return_value = ( "stdout", @@ -74,10 +76,10 @@ def setup_mocks_without_popen(mocker): "mount_efs.get_dns_name_and_fallback_mount_target_ip_address", return_value=(DNS_NAME, None), ) - mocker.patch("mount_efs.write_tls_tunnel_state_file", return_value="~mocktempfile") + mocker.patch("mount_efs.write_tunnel_state_file", return_value="~mocktempfile") mocker.patch("os.kill") mocker.patch( - "mount_efs.update_tls_tunnel_temp_state_file_with_tunnel_pid", + "mount_efs.update_tunnel_temp_state_file_with_tunnel_pid", return_value="~mocktempfile", ) @@ -87,12 +89,12 @@ def setup_mocks_without_popen(mocker): return write_config_mock -def test_bootstrap_tls_state_file_dir_exists(mocker, tmpdir): +def test_bootstrap_proxy_state_file_dir_exists(mocker, tmpdir): popen_mock, _ = setup_mocks(mocker) state_file_dir = str(tmpdir) - - mocker.patch("mount_efs._stunnel_bin", return_value="/usr/bin/stunnel") - with mount_efs.bootstrap_tls( + mocker.patch("mount_efs.is_ocsp_enabled", return_value=False) + mocker.patch("mount_efs._efs_proxy_bin", return_value="/usr/bin/efs-proxy") + with mount_efs.bootstrap_proxy( MOCK_CONFIG, INIT_SYSTEM, DNS_NAME, FS_ID, MOUNT_POINT, {}, state_file_dir ): pass @@ -100,11 +102,11 @@ def test_bootstrap_tls_state_file_dir_exists(mocker, tmpdir): args, _ = popen_mock.call_args args = args[0] - assert "/usr/bin/stunnel" in args + assert "/usr/bin/efs-proxy" in args assert EXPECTED_STUNNEL_CONFIG_FILE in args -def test_bootstrap_tls_state_file_nonexistent_dir(mocker, tmpdir): +def test_bootstrap_proxy_state_file_nonexistent_dir(mocker, tmpdir): popen_mock, _ = setup_mocks(mocker) state_file_dir = str(tmpdir.join(tempfile.mkdtemp()[1])) @@ -122,9 +124,10 @@ def config_get_side_effect(section, field): assert not os.path.exists(state_file_dir) - mocker.patch("mount_efs._stunnel_bin", return_value="/usr/bin/stunnel") + mocker.patch("mount_efs.is_ocsp_enabled", return_value=False) + mocker.patch("mount_efs._efs_proxy_bin", return_value="/usr/bin/efs-proxy") mocker.patch("mount_efs.find_existing_mount_using_tls_port", return_value=None) - with mount_efs.bootstrap_tls( + with mount_efs.bootstrap_proxy( MOCK_CONFIG, INIT_SYSTEM, DNS_NAME, FS_ID, MOUNT_POINT, {}, state_file_dir ): pass @@ -132,13 +135,13 @@ def config_get_side_effect(section, field): assert os.path.exists(state_file_dir) -def test_bootstrap_tls_cert_created(mocker, tmpdir): +def test_bootstrap_proxy_cert_created_tls_mount(mocker, tmpdir): setup_mocks_without_popen(mocker) mocker.patch("mount_efs.get_mount_specific_filename", return_value=DNS_NAME) mocker.patch("mount_efs.get_target_region", return_value=REGION) state_file_dir = str(tmpdir) tls_dict = mount_efs.tls_paths_dictionary(DNS_NAME + "+", state_file_dir) - + mocker.patch("mount_efs.is_ocsp_enabled", return_value=False) pk_path = os.path.join(str(tmpdir), "privateKey.pem") mocker.patch("mount_efs.get_private_key_path", return_value=pk_path) @@ -147,6 +150,8 @@ def config_get_side_effect(section, field): return "0755" elif section == mount_efs.CONFIG_SECTION and field == "dns_name_format": return "{fs_id}.efs.{region}.amazonaws.com" + elif section == mount_efs.CONFIG_SECTION and field == "logging_level": + return "info" elif section == mount_efs.CLIENT_INFO_SECTION and field == "source": return CLIENT_SOURCE else: @@ -154,15 +159,15 @@ def config_get_side_effect(section, field): MOCK_CONFIG.get.side_effect = config_get_side_effect - mocker.patch("mount_efs._stunnel_bin", return_value="/usr/bin/stunnel") + mocker.patch("mount_efs._efs_proxy_bin", return_value="/usr/bin/efs-proxy") try: - with mount_efs.bootstrap_tls( + with mount_efs.bootstrap_proxy( MOCK_CONFIG, INIT_SYSTEM, DNS_NAME, FS_ID, MOUNT_POINT, - {"accesspoint": AP_ID}, + {"accesspoint": AP_ID, "tls": None}, state_file_dir, ): pass @@ -175,7 +180,53 @@ def config_get_side_effect(section, field): assert os.path.exists(pk_path) -def test_bootstrap_tls_non_default_port(mocker, tmpdir): +def test_bootstrap_proxy_cert_not_created_non_tls_mount(mocker, tmpdir): + setup_mocks_without_popen(mocker) + mocker.patch("mount_efs.get_mount_specific_filename", return_value=DNS_NAME) + mocker.patch("mount_efs.get_target_region", return_value=REGION) + state_file_dir = str(tmpdir) + tls_dict = mount_efs.tls_paths_dictionary(DNS_NAME + "+", state_file_dir) + + pk_path = os.path.join(str(tmpdir), "privateKey.pem") + mocker.patch("mount_efs.get_private_key_path", return_value=pk_path) + + def config_get_side_effect(section, field): + if section == mount_efs.CONFIG_SECTION and field == "state_file_dir_mode": + return "0755" + elif section == mount_efs.CONFIG_SECTION and field == "dns_name_format": + return "{fs_id}.efs.{region}.amazonaws.com" + elif section == mount_efs.CONFIG_SECTION and field == "logging_level": + return "info" + elif section == mount_efs.CLIENT_INFO_SECTION and field == "source": + return CLIENT_SOURCE + else: + raise ValueError("Unexpected arguments") + + MOCK_CONFIG.get.side_effect = config_get_side_effect + + mocker.patch("mount_efs.is_ocsp_enabled", return_value=False) + mocker.patch("mount_efs._efs_proxy_bin", return_value="/usr/bin/efs-proxy") + try: + with mount_efs.bootstrap_proxy( + MOCK_CONFIG, + INIT_SYSTEM, + DNS_NAME, + FS_ID, + MOUNT_POINT, + {"accesspoint": AP_ID}, + state_file_dir, + ): + pass + except OSError as e: + assert "[Errno 2] No such file or directory" in str(e) + + assert not os.path.exists(os.path.join(tls_dict["mount_dir"], "certificate.pem")) + assert not os.path.exists(os.path.join(tls_dict["mount_dir"], "request.csr")) + assert not os.path.exists(os.path.join(tls_dict["mount_dir"], "config.conf")) + assert not os.path.exists(pk_path) + + +def test_bootstrap_proxy_non_default_port(mocker, tmpdir): popen_mock, write_config_mock = setup_mocks(mocker) mocker.patch("os.rename") state_file_dir = str(tmpdir) @@ -185,9 +236,9 @@ def test_bootstrap_tls_non_default_port(mocker, tmpdir): tls_port_sock_mock.getsockname.return_value = ("local_host", tls_port) tls_port_sock_mock.close.side_effect = None mocker.patch("socket.socket", return_value=tls_port_sock_mock) - - mocker.patch("mount_efs._stunnel_bin", return_value="/usr/bin/stunnel") - with mount_efs.bootstrap_tls( + mocker.patch("mount_efs.is_ocsp_enabled", return_value=False) + mocker.patch("mount_efs._efs_proxy_bin", return_value="/usr/bin/efs-proxy") + with mount_efs.bootstrap_proxy( MOCK_CONFIG, INIT_SYSTEM, DNS_NAME, @@ -202,29 +253,55 @@ def test_bootstrap_tls_non_default_port(mocker, tmpdir): popen_args = popen_args[0] write_config_args, _ = write_config_mock.call_args - assert "/usr/bin/stunnel" in popen_args + assert "/usr/bin/efs-proxy" in popen_args assert EXPECTED_STUNNEL_CONFIG_FILE in popen_args assert tls_port == write_config_args[4] # positional argument for tls_port - # Ensure tls port socket is closed in bootstrap_tls + # Ensure tls port socket is closed in bootstrap_proxy # The number is two here, the first one is the actual socket when choosing tls port, the second one is a socket to # verify tls port can be connected after establishing TLS stunnel. They share the same mock. assert 2 == tls_port_sock_mock.close.call_count -def test_bootstrap_tls_non_default_verify_level(mocker, tmpdir): +def test_bootstrap_proxy_non_tls_verify_ignored(mocker, tmpdir): popen_mock, write_config_mock = setup_mocks(mocker) state_file_dir = str(tmpdir) + mocker.patch("mount_efs.is_ocsp_enabled", return_value=False) + mocker.patch("mount_efs._efs_proxy_bin", return_value="/usr/bin/efs-proxy") + with mount_efs.bootstrap_proxy( + MOCK_CONFIG, + INIT_SYSTEM, + DNS_NAME, + FS_ID, + MOUNT_POINT, + {}, + state_file_dir, + ): + pass + popen_args, _ = popen_mock.call_args + popen_args = popen_args[0] + write_config_args, _ = write_config_mock.call_args + + assert "/usr/bin/efs-proxy" in popen_args + assert EXPECTED_STUNNEL_CONFIG_FILE in popen_args + assert None == write_config_args[6] # positional argument for verify_level + + +def test_bootstrap_proxy_non_default_verify_level_stunnel(mocker, tmpdir): + popen_mock, write_config_mock = setup_mocks(mocker) + state_file_dir = str(tmpdir) + mocker.patch("mount_efs.is_ocsp_enabled", return_value=False) verify = 0 mocker.patch("mount_efs._stunnel_bin", return_value="/usr/bin/stunnel") - with mount_efs.bootstrap_tls( + with mount_efs.bootstrap_proxy( MOCK_CONFIG, INIT_SYSTEM, DNS_NAME, FS_ID, MOUNT_POINT, - {"verify": verify}, + {"verify": verify, "tls": None}, state_file_dir, + efs_proxy_enabled=False, ): pass @@ -237,12 +314,11 @@ def test_bootstrap_tls_non_default_verify_level(mocker, tmpdir): assert 0 == write_config_args[6] # positional argument for verify_level -def test_bootstrap_tls_ocsp_option(mocker, tmpdir): +def test_bootstrap_proxy_ocsp_option(mocker, tmpdir): popen_mock, write_config_mock = setup_mocks(mocker) state_file_dir = str(tmpdir) - mocker.patch("mount_efs._stunnel_bin", return_value="/usr/bin/stunnel") - with mount_efs.bootstrap_tls( + with mount_efs.bootstrap_proxy( MOCK_CONFIG, INIT_SYSTEM, DNS_NAME, @@ -250,6 +326,7 @@ def test_bootstrap_tls_ocsp_option(mocker, tmpdir): MOUNT_POINT, {"ocsp": None}, state_file_dir, + efs_proxy_enabled=False, ): pass @@ -263,12 +340,11 @@ def test_bootstrap_tls_ocsp_option(mocker, tmpdir): assert write_config_args[7] is True -def test_bootstrap_tls_noocsp_option(mocker, tmpdir): +def test_bootstrap_proxy_noocsp_option(mocker, tmpdir): popen_mock, write_config_mock = setup_mocks(mocker) state_file_dir = str(tmpdir) - mocker.patch("mount_efs._stunnel_bin", return_value="/usr/bin/stunnel") - with mount_efs.bootstrap_tls( + with mount_efs.bootstrap_proxy( MOCK_CONFIG, INIT_SYSTEM, DNS_NAME, @@ -276,6 +352,7 @@ def test_bootstrap_tls_noocsp_option(mocker, tmpdir): MOUNT_POINT, {"noocsp": None}, state_file_dir, + efs_proxy_enabled=False, ): pass @@ -287,3 +364,114 @@ def test_bootstrap_tls_noocsp_option(mocker, tmpdir): assert EXPECTED_STUNNEL_CONFIG_FILE in popen_args # positional argument for ocsp_override assert write_config_args[7] is False + + +def test_bootstrap_proxy_efs_proxy_enabled_tls(mocker, tmpdir): + popen_mock, _ = setup_mocks(mocker) + mocker.patch("os.rename") + state_file_dir = str(tmpdir) + mocker.patch("mount_efs.is_ocsp_enabled", return_value=False) + mocker.patch("mount_efs._efs_proxy_bin", return_value="/usr/bin/efs-proxy") + with mount_efs.bootstrap_proxy( + MOCK_CONFIG, + INIT_SYSTEM, + DNS_NAME, + FS_ID, + MOUNT_POINT, + {"tls": None}, + state_file_dir, + efs_proxy_enabled=True, + ): + pass + + popen_args, _ = popen_mock.call_args + popen_args = popen_args[0] + + assert "/usr/bin/efs-proxy" in popen_args + assert "--tls" in popen_args + assert EXPECTED_STUNNEL_CONFIG_FILE in popen_args + + +def test_bootstrap_proxy_efs_proxy_enabled_non_tls(mocker, tmpdir): + popen_mock, _ = setup_mocks(mocker) + mocker.patch("os.rename") + state_file_dir = str(tmpdir) + mocker.patch("mount_efs.is_ocsp_enabled", return_value=False) + mocker.patch("mount_efs._efs_proxy_bin", return_value="/usr/bin/efs-proxy") + with mount_efs.bootstrap_proxy( + MOCK_CONFIG, + INIT_SYSTEM, + DNS_NAME, + FS_ID, + MOUNT_POINT, + {}, + state_file_dir, + efs_proxy_enabled=True, + ): + pass + + popen_args, _ = popen_mock.call_args + popen_args = popen_args[0] + + assert "/usr/bin/stunnel" not in popen_args + assert "--tls" not in popen_args + + assert "/usr/bin/efs-proxy" in popen_args + assert EXPECTED_STUNNEL_CONFIG_FILE in popen_args + + +def test_bootstrap_proxy_stunnel_enabled(mocker, tmpdir): + popen_mock, _ = setup_mocks(mocker) + mocker.patch("os.rename") + state_file_dir = str(tmpdir) + + mocker.patch("mount_efs._stunnel_bin", return_value="/usr/bin/stunnel") + with mount_efs.bootstrap_proxy( + MOCK_CONFIG, + INIT_SYSTEM, + DNS_NAME, + FS_ID, + MOUNT_POINT, + {}, + state_file_dir, + efs_proxy_enabled=False, + ): + pass + + popen_args, _ = popen_mock.call_args + popen_args = popen_args[0] + + assert "/usr/bin/efs-proxy" not in popen_args + assert "info" not in popen_args + + assert "/usr/bin/stunnel" in popen_args + assert EXPECTED_STUNNEL_CONFIG_FILE in popen_args + + +def test_bootstrap_proxy_netns_option(mocker, tmpdir): + popen_mock, write_config_mock = setup_mocks(mocker) + state_file_dir = str(tmpdir) + + netns = "/proc/1/net/ns" + mocker.patch("mount_efs._efs_proxy_bin", return_value="/usr/bin/efs-proxy") + mocker.patch("mount_efs.NetNS") + mocker.patch("mount_efs.is_ocsp_enabled", return_value=False) + with mount_efs.bootstrap_proxy( + MOCK_CONFIG, + INIT_SYSTEM, + DNS_NAME, + FS_ID, + MOUNT_POINT, + {"netns": netns}, + state_file_dir, + ): + pass + + popen_args, _ = popen_mock.call_args + popen_args = popen_args[0] + write_config_args, _ = write_config_mock.call_args + + assert "/usr/bin/efs-proxy" in popen_args + assert EXPECTED_STUNNEL_CONFIG_FILE in popen_args + assert "nsenter" in popen_args + assert "--net=" + netns in popen_args diff --git a/test/mount_efs_test/test_get_nfs_mount_options.py b/test/mount_efs_test/test_get_nfs_mount_options.py index 4675f3d6..cf7d1d5d 100644 --- a/test/mount_efs_test/test_get_nfs_mount_options.py +++ b/test/mount_efs_test/test_get_nfs_mount_options.py @@ -6,10 +6,49 @@ from unittest.mock import MagicMock +try: + import ConfigParser +except ImportError: + from configparser import ConfigParser + import pytest import mount_efs +DEFAULT_OPTIONS = {"tlsport": "3030"} + + +def _get_config(ocsp_enabled=False): + try: + config = ConfigParser.SafeConfigParser() + except AttributeError: + config = ConfigParser() + + mount_nfs_command_retry_count = 4 + mount_nfs_command_retry_timeout = 10 + mount_nfs_command_retry = "false" + config.add_section(mount_efs.CONFIG_SECTION) + config.set( + mount_efs.CONFIG_SECTION, "retry_nfs_mount_command", mount_nfs_command_retry + ) + config.set( + mount_efs.CONFIG_SECTION, + "retry_nfs_mount_command_count", + str(mount_nfs_command_retry_count), + ) + config.set( + mount_efs.CONFIG_SECTION, + "retry_nfs_mount_command_timeout_sec", + str(mount_nfs_command_retry_timeout), + ) + if ocsp_enabled: + config.set( + mount_efs.CONFIG_SECTION, + "stunnel_check_cert_validity", + "true", + ) + return config + def _mock_popen(mocker, returncode=0, stdout="stdout", stderr="stderr"): popen_mock = MagicMock() @@ -23,7 +62,7 @@ def _mock_popen(mocker, returncode=0, stdout="stdout", stderr="stderr"): def test_get_default_nfs_mount_options(): - nfs_opts = mount_efs.get_nfs_mount_options({}) + nfs_opts = mount_efs.get_nfs_mount_options(dict(DEFAULT_OPTIONS), _get_config()) assert "nfsvers=4.1" in nfs_opts assert "rsize=1048576" in nfs_opts @@ -31,17 +70,22 @@ def test_get_default_nfs_mount_options(): assert "hard" in nfs_opts assert "timeo=600" in nfs_opts assert "retrans=2" in nfs_opts + assert "port=3030" in nfs_opts def test_override_nfs_version(): - nfs_opts = mount_efs.get_nfs_mount_options({"nfsvers": 4.0}) + options = dict(DEFAULT_OPTIONS) + options["nfsvers"] = 4.0 + nfs_opts = mount_efs.get_nfs_mount_options(options, _get_config()) assert "nfsvers=4.0" in nfs_opts assert "nfsvers=4.1" not in nfs_opts def test_override_nfs_version_alternate_option(): - nfs_opts = mount_efs.get_nfs_mount_options({"vers": 4.0}) + options = dict(DEFAULT_OPTIONS) + options["vers"] = 4.0 + nfs_opts = mount_efs.get_nfs_mount_options(options, _get_config()) assert "vers=4.0" in nfs_opts assert "nfsvers=4.0" not in nfs_opts @@ -49,21 +93,27 @@ def test_override_nfs_version_alternate_option(): def test_override_rsize(): - nfs_opts = mount_efs.get_nfs_mount_options({"rsize": 1}) + options = dict(DEFAULT_OPTIONS) + options["rsize"] = 1 + nfs_opts = mount_efs.get_nfs_mount_options(options, _get_config()) assert "rsize=1" in nfs_opts assert "rsize=1048576" not in nfs_opts def test_override_wsize(): - nfs_opts = mount_efs.get_nfs_mount_options({"wsize": 1}) + options = dict(DEFAULT_OPTIONS) + options["wsize"] = 1 + nfs_opts = mount_efs.get_nfs_mount_options(options, _get_config()) assert "wsize=1" in nfs_opts assert "wsize=1048576" not in nfs_opts def test_override_recovery_soft(): - nfs_opts = mount_efs.get_nfs_mount_options({"soft": None}) + options = dict(DEFAULT_OPTIONS) + options["soft"] = None + nfs_opts = mount_efs.get_nfs_mount_options(options, _get_config()) assert "soft" in nfs_opts assert "soft=" not in nfs_opts @@ -71,35 +121,43 @@ def test_override_recovery_soft(): def test_override_timeo(): - nfs_opts = mount_efs.get_nfs_mount_options({"timeo": 1}) + options = dict(DEFAULT_OPTIONS) + options["timeo"] = 1 + nfs_opts = mount_efs.get_nfs_mount_options(options, _get_config()) assert "timeo=1" in nfs_opts assert "timeo=600" not in nfs_opts def test_override_retrans(): - nfs_opts = mount_efs.get_nfs_mount_options({"retrans": 1}) + options = dict(DEFAULT_OPTIONS) + options["retrans"] = 1 + nfs_opts = mount_efs.get_nfs_mount_options(options, _get_config()) assert "retrans=1" in nfs_opts assert "retrans=2" not in nfs_opts def test_tlsport(): - nfs_opts = mount_efs.get_nfs_mount_options({"tls": None, "tlsport": 3030}) + options = dict(DEFAULT_OPTIONS) + options["tls"] = None + nfs_opts = mount_efs.get_nfs_mount_options(options, _get_config()) assert "port=3030" in nfs_opts assert "tls" not in nfs_opts def test_fsap_efs_only(): - nfs_opts = mount_efs.get_nfs_mount_options({"fsap": None}) + options = dict(DEFAULT_OPTIONS) + options["fsap"] = None + nfs_opts = mount_efs.get_nfs_mount_options(options, _get_config()) assert "fsap" not in nfs_opts def test_get_default_nfs_mount_options_macos(mocker): mocker.patch("mount_efs.check_if_platform_is_mac", return_value=True) - nfs_opts = mount_efs.get_nfs_mount_options({}) + nfs_opts = mount_efs.get_nfs_mount_options(dict(DEFAULT_OPTIONS), _get_config()) assert "nfsvers=4.0" in nfs_opts assert "rsize=1048576" in nfs_opts @@ -108,13 +166,14 @@ def test_get_default_nfs_mount_options_macos(mocker): assert "timeo=600" in nfs_opts assert "retrans=2" in nfs_opts assert "mountport=2049" in nfs_opts + assert not "port=3030" in nfs_opts def _test_unsupported_mount_options_macos(mocker, capsys, options={}): mocker.patch("mount_efs.check_if_platform_is_mac", return_value=True) _mock_popen(mocker, stdout="nfs") with pytest.raises(SystemExit) as ex: - mount_efs.get_nfs_mount_options(options) + mount_efs.get_nfs_mount_options(options, _get_config()) assert 0 != ex.value.code diff --git a/test/mount_efs_test/test_main.py b/test/mount_efs_test/test_main.py index 73321ab3..789c0b68 100644 --- a/test/mount_efs_test/test_main.py +++ b/test/mount_efs_test/test_main.py @@ -44,6 +44,8 @@ def _test_main( awscredsuri=None, notls=False, crossaccount=False, + stunnel=False, + macos=False, ): options = {} @@ -69,6 +71,8 @@ def _test_main( options["awscredsuri"] = awscredsuri if crossaccount: options["crossaccount"] = None + if stunnel: + options["stunnel"] = None if root: mocker.patch("os.geteuid", return_value=0) @@ -90,8 +94,8 @@ def _test_main( parse_arguments_mock = mocker.patch( "mount_efs.parse_arguments", return_value=("fs-deadbeef", "/", "/mnt", options) ) - bootstrap_tls_mock = mocker.patch( - "mount_efs.bootstrap_tls", side_effect=dummy_contextmanager + bootstrap_proxy_mock = mocker.patch( + "mount_efs.bootstrap_proxy", side_effect=dummy_contextmanager ) if tls: @@ -106,10 +110,19 @@ def _test_main( utils.assert_called_once(parse_arguments_mock) utils.assert_called_once(mount_mock) - if tls: - utils.assert_called_once(bootstrap_tls_mock) + stunnel_mode_enabled = stunnel or macos or ocsp + + if stunnel_mode_enabled: + if tls: + utils.assert_called_once(bootstrap_proxy_mock) + kwargs = bootstrap_proxy_mock.call_args[1] + assert kwargs["efs_proxy_enabled"] is False + else: + utils.assert_not_called(bootstrap_proxy_mock) else: - utils.assert_not_called(bootstrap_tls_mock) + utils.assert_called_once(bootstrap_proxy_mock) + kwargs = bootstrap_proxy_mock.call_args[1] + assert kwargs["efs_proxy_enabled"] is True def _test_main_assert_error(mocker, capsys, expected_err, **kwargs): @@ -128,7 +141,7 @@ def _test_main_macos(mocker, is_supported_macos_version, **kwargs): "mount_efs.check_if_mac_version_is_supported", return_value=is_supported_macos_version, ) - _test_main(mocker, **kwargs) + _test_main(mocker, macos=True, **kwargs) def _test_main_macos_assert_error( @@ -232,8 +245,12 @@ def test_main_awscredsuri_without_iam(mocker, capsys): ) -def test_main_tls_ocsp_option(mocker): - _test_main(mocker, tls=True, ocsp=True, tlsport=TLS_PORT) +def test_main_tls_ocsp_option_with_stunnel(mocker): + _test_main(mocker, tls=True, ocsp=True, stunnel=True, tlsport=TLS_PORT) + + +def test_main_tls_ocsp_option_should_revert_to_stunnel(mocker): + _test_main(mocker, tls=True, ocsp=True, stunnel=False, tlsport=TLS_PORT) def test_main_tls_noocsp_option(mocker): diff --git a/test/mount_efs_test/test_mount_nfs.py b/test/mount_efs_test/test_mount_nfs.py index 5e2782cf..d9070846 100644 --- a/test/mount_efs_test/test_mount_nfs.py +++ b/test/mount_efs_test/test_mount_nfs.py @@ -54,6 +54,8 @@ NETNS = "/proc/1/net/ns" +LOCAL_HOST = "127.0.0.1" + def _get_config( mount_nfs_command_retry="true", @@ -107,6 +109,30 @@ def test_mount_nfs(mocker): args, _ = mock.call_args args = args[0] + assert "/sbin/mount.nfs4" == args[NFS_BIN_ARG_IDX] + assert LOCAL_HOST in args[NFS_MOUNT_PATH_IDX] + assert "/mnt" == args[NFS_MOUNT_POINT_IDX] + + utils.assert_called_once(optimize_readahead_window_mock) + + +def test_mount_nfs_stunnel_enabled(mocker): + mock = _mock_popen(mocker) + optimize_readahead_window_mock = mocker.patch("mount_efs.optimize_readahead_window") + options = dict(DEFAULT_OPTIONS) + options["stunnel"] = None + + mount_efs.mount_nfs( + _get_config(mount_nfs_command_retry="false"), + DNS_NAME, + "/", + "/mnt", + options, + ) + + args, _ = mock.call_args + args = args[0] + assert "/sbin/mount.nfs4" == args[NFS_BIN_ARG_IDX] assert DNS_NAME in args[NFS_MOUNT_PATH_IDX] assert "/mnt" == args[NFS_MOUNT_POINT_IDX] @@ -114,16 +140,18 @@ def test_mount_nfs(mocker): utils.assert_called_once(optimize_readahead_window_mock) -def test_mount_nfs_with_fallback_ip_address(mocker): +def test_mount_nfs_stunnel_with_fallback_ip_address(mocker): mock = _mock_popen(mocker) optimize_readahead_window_mock = mocker.patch("mount_efs.optimize_readahead_window") + options = dict(DEFAULT_OPTIONS) + options["stunnel"] = None mount_efs.mount_nfs( _get_config(mount_nfs_command_retry="false"), DNS_NAME, "/", "/mnt", - DEFAULT_OPTIONS, + options, fallback_ip_address=FALLBACK_IP_ADDRESS, ) @@ -138,12 +166,13 @@ def test_mount_nfs_with_fallback_ip_address(mocker): utils.assert_called_once(optimize_readahead_window_mock) -def test_mount_nfs_tls(mocker): +def test_mount_nfs_tls_stunnel_enabled(mocker): mock = _mock_popen(mocker) optimize_readahead_window_mock = mocker.patch("mount_efs.optimize_readahead_window") options = dict(DEFAULT_OPTIONS) options["tls"] = None + options["stunnel"] = None mount_efs.mount_nfs( _get_config(mount_nfs_command_retry="false"), DNS_NAME, "/", "/mnt", options @@ -205,11 +234,11 @@ def test_mount_tls_mountpoint_mounted_with_nfs(mocker, capsys): options = dict(DEFAULT_OPTIONS) options["tls"] = None - bootstrap_tls_mock = mocker.patch("mount_efs.bootstrap_tls") + bootstrap_proxy_mock = mocker.patch("mount_efs.bootstrap_proxy") optimize_readahead_window_mock = mocker.patch("mount_efs.optimize_readahead_window") mocker.patch("os.path.ismount", return_value=True) _mock_popen(mocker, stdout="nfs") - mount_efs.mount_tls( + mount_efs.mount_with_proxy( _get_config(mount_nfs_command_retry="false"), INIT_SYSTEM, DNS_NAME, @@ -220,7 +249,7 @@ def test_mount_tls_mountpoint_mounted_with_nfs(mocker, capsys): ) out, err = capsys.readouterr() assert "is already mounted" in out - utils.assert_not_called(bootstrap_tls_mock) + utils.assert_not_called(bootstrap_proxy_mock) utils.assert_not_called(optimize_readahead_window_mock) diff --git a/test/mount_efs_test/test_mount_with_proxy.py b/test/mount_efs_test/test_mount_with_proxy.py new file mode 100644 index 00000000..ad5821c7 --- /dev/null +++ b/test/mount_efs_test/test_mount_with_proxy.py @@ -0,0 +1,221 @@ +import subprocess +from unittest.mock import MagicMock + +import pytest + +import mount_efs + +from .. import common, utils + +try: + import ConfigParser +except ImportError: + from configparser import ConfigParser + +try: + import ConfigParser +except ImportError: + from configparser import ConfigParser + +DNS_NAME = "fs-deadbeef.efs.us-east-1.amazonaws.com" +FS_ID = "fs-deadbeef" +INIT_SYSTEM = "upstart" +FALLBACK_IP_ADDRESS = "192.0.0.1" +MOUNT_POINT = "/mnt" +PATH = "/" + +DEFAULT_OPTIONS = { + "nfsvers": 4.1, + "rsize": 1048576, + "wsize": 1048576, + "hard": None, + "timeo": 600, + "retrans": 2, + "tlsport": 3049, +} + +# indices of different arguments to the NFS call +NFS_BIN_ARG_IDX = 0 +NFS_MOUNT_PATH_IDX = 1 +NFS_MOUNT_POINT_IDX = 2 +NFS_OPTION_FLAG_IDX = 3 +NFS_OPTIONS_IDX = 4 + +# indices of different arguments to the NFS call to certain network namespace +NETNS_NSENTER_ARG_IDX = 0 +NETNS_PATH_ARG_IDX = 1 +NETNS_NFS_OFFSET = 2 + +# indices of different arguments to the NFS call for MACOS +NFS_MOUNT_PATH_IDX_MACOS = -2 +NFS_MOUNT_POINT_IDX_MACOS = -1 + +NETNS = "/proc/1/net/ns" + + +def _get_config(ocsp_enabled=False): + try: + config = ConfigParser.SafeConfigParser() + except AttributeError: + config = ConfigParser() + + mount_nfs_command_retry_count = 4 + mount_nfs_command_retry_timeout = 10 + mount_nfs_command_retry = "false" + config.add_section(mount_efs.CONFIG_SECTION) + config.set( + mount_efs.CONFIG_SECTION, "retry_nfs_mount_command", mount_nfs_command_retry + ) + config.set( + mount_efs.CONFIG_SECTION, + "retry_nfs_mount_command_count", + str(mount_nfs_command_retry_count), + ) + config.set( + mount_efs.CONFIG_SECTION, + "retry_nfs_mount_command_timeout_sec", + str(mount_nfs_command_retry_timeout), + ) + if ocsp_enabled: + config.set( + mount_efs.CONFIG_SECTION, + "stunnel_check_cert_validity", + "true", + ) + return config + + +def _mock_popen(mocker, returncode=0, stdout="stdout", stderr="stderr"): + popen_mock = MagicMock() + popen_mock.communicate.return_value = ( + stdout, + stderr, + ) + popen_mock.returncode = returncode + + return mocker.patch("subprocess.Popen", return_value=popen_mock) + + +def test_mount_with_proxy_efs_proxy_enabled(mocker, capsys): + options = dict(DEFAULT_OPTIONS) + options["tls"] = None + + bootstrap_proxy_mock = mocker.patch("mount_efs.bootstrap_proxy") + mocker.patch("os.path.ismount", return_value=False) + mocker.patch("threading.Thread.start") + mocker.patch("threading.Thread.join") + mocker.patch("mount_efs.mount_nfs") + _mock_popen(mocker, stdout="nfs") + mount_efs.mount_with_proxy( + _get_config(), + INIT_SYSTEM, + DNS_NAME, + PATH, + FS_ID, + MOUNT_POINT, + options, + ) + utils.assert_called_once(bootstrap_proxy_mock) + + kwargs = bootstrap_proxy_mock.call_args[1] + assert kwargs["efs_proxy_enabled"] == True + + +def test_mount_with_proxy_ocsp_config_enabled(mocker, capsys): + options = dict(DEFAULT_OPTIONS) + options["tls"] = None + + bootstrap_proxy_mock = mocker.patch("mount_efs.bootstrap_proxy") + mocker.patch("os.path.ismount", return_value=False) + mocker.patch("threading.Thread.start") + mocker.patch("threading.Thread.join") + mocker.patch("mount_efs.mount_nfs") + _mock_popen(mocker, stdout="nfs") + mount_efs.mount_with_proxy( + _get_config(ocsp_enabled=True), + INIT_SYSTEM, + DNS_NAME, + PATH, + FS_ID, + MOUNT_POINT, + options, + ) + utils.assert_called_once(bootstrap_proxy_mock) + + kwargs = bootstrap_proxy_mock.call_args[1] + assert kwargs["efs_proxy_enabled"] == False + + +def test_mount_with_proxy_ocsp_option_enabled(mocker, capsys): + options = dict(DEFAULT_OPTIONS) + options["tls"] = None + options["ocsp"] = None + + bootstrap_proxy_mock = mocker.patch("mount_efs.bootstrap_proxy") + mocker.patch("os.path.ismount", return_value=False) + mocker.patch("threading.Thread.start") + mocker.patch("threading.Thread.join") + mocker.patch("mount_efs.mount_nfs") + _mock_popen(mocker, stdout="nfs") + mount_efs.mount_with_proxy( + _get_config(), + INIT_SYSTEM, + DNS_NAME, + PATH, + FS_ID, + MOUNT_POINT, + options, + ) + utils.assert_called_once(bootstrap_proxy_mock) + + kwargs = bootstrap_proxy_mock.call_args[1] + assert kwargs["efs_proxy_enabled"] == False + + +def test_mount_with_proxy_efs_proxy_enabled_non_tls_mount(mocker, capsys): + options = dict(DEFAULT_OPTIONS) + + bootstrap_proxy_mock = mocker.patch("mount_efs.bootstrap_proxy") + mocker.patch("os.path.ismount", return_value=False) + mocker.patch("threading.Thread.start") + mocker.patch("threading.Thread.join") + mocker.patch("mount_efs.mount_nfs") + _mock_popen(mocker, stdout="nfs") + mount_efs.mount_with_proxy( + _get_config(), + INIT_SYSTEM, + DNS_NAME, + PATH, + FS_ID, + MOUNT_POINT, + options, + ) + utils.assert_called_once(bootstrap_proxy_mock) + + kwargs = bootstrap_proxy_mock.call_args[1] + assert kwargs["efs_proxy_enabled"] == True + + +def test_mount_with_proxy_stunnel_enabled(mocker, capsys): + options = dict(DEFAULT_OPTIONS) + options["stunnel"] = None + + bootstrap_proxy_mock = mocker.patch("mount_efs.bootstrap_proxy") + mocker.patch("os.path.ismount", return_value=False) + mocker.patch("threading.Thread.start") + mocker.patch("threading.Thread.join") + mocker.patch("mount_efs.mount_nfs") + _mock_popen(mocker, stdout="nfs") + mount_efs.mount_with_proxy( + _get_config(), + INIT_SYSTEM, + DNS_NAME, + PATH, + FS_ID, + MOUNT_POINT, + options, + ) + utils.assert_called_once(bootstrap_proxy_mock) + + kwargs = bootstrap_proxy_mock.call_args[1] + assert kwargs["efs_proxy_enabled"] == False diff --git a/test/mount_efs_test/test_write_stunnel_config_file.py b/test/mount_efs_test/test_write_stunnel_config_file.py index b01f88d7..7bb5e267 100644 --- a/test/mount_efs_test/test_write_stunnel_config_file.py +++ b/test/mount_efs_test/test_write_stunnel_config_file.py @@ -27,6 +27,7 @@ OCSP_ENABLED = False DEFAULT_REGION = "us-east-1" STUNNEL_LOGS_FILE = "/var/log/amazon/efs/%s.stunnel.log" % FS_ID +PROXY_LOGS_FILE = "/var/log/amazon/efs/%s.efs-proxy.log" % FS_ID def _get_config( @@ -98,7 +99,12 @@ def _get_config( return config -def _get_mount_options(port=PORT): +def _get_mount_options_tls(port=PORT): + options = {"tlsport": port, "tls": None} + return options + + +def _get_mount_options_non_tls(port=PORT): options = { "tlsport": port, } @@ -152,7 +158,7 @@ def _validate_config(stunnel_config_file, expected_global_config, expected_efs_c assert expected_efs_config == actual_efs_config -def _get_expected_efs_config( +def _get_expected_efs_config_tls( port=PORT, dns_name=DNS_NAME, verify=mount_efs.DEFAULT_STUNNEL_VERIFY_LEVEL, @@ -161,6 +167,7 @@ def _get_expected_efs_config( check_cert_validity=False, disable_libwrap=True, fallback_ip_address=None, + efs_proxy_enabled=True, ): expected_efs_config = dict(mount_efs.STUNNEL_EFS_CONFIG) expected_efs_config["accept"] = expected_efs_config["accept"] % port @@ -172,26 +179,44 @@ def _get_expected_efs_config( ) expected_efs_config["verify"] = str(verify) - if check_cert_hostname: + if check_cert_hostname or efs_proxy_enabled: expected_efs_config["checkHost"] = dns_name[dns_name.index(FS_ID) :] - if check_cert_validity and ocsp_override: + if check_cert_validity and ocsp_override and (not efs_proxy_enabled): expected_efs_config["OCSPaia"] = "yes" - if disable_libwrap: + if disable_libwrap and (not efs_proxy_enabled): expected_efs_config["libwrap"] = "no" return expected_efs_config -def _test_check_cert_hostname( +def _get_expected_efs_config_non_tls( + port=PORT, + dns_name=DNS_NAME, + fallback_ip_address=None, +): + expected_efs_config = dict(mount_efs.STUNNEL_EFS_CONFIG) + expected_efs_config["accept"] = expected_efs_config["accept"] % port + if not fallback_ip_address: + expected_efs_config["connect"] = expected_efs_config["connect"] % dns_name + else: + expected_efs_config["connect"] = ( + expected_efs_config["connect"] % fallback_ip_address + ) + + return expected_efs_config + + +# Check the hostname behavior when using stunnel instead of efs-proxy. +def _test_check_cert_hostname_stunnel( mocker, tmpdir, stunnel_check_cert_hostname_supported, stunnel_check_cert_hostname, expected_check_cert_hostname_config_value, ): - ca_mocker = mocker.patch("mount_efs.add_stunnel_ca_options") + ca_mocker = mocker.patch("mount_efs.add_tunnel_ca_options") state_file_dir = str(tmpdir) config_file = mount_efs.write_stunnel_config_file( _get_config( @@ -206,8 +231,9 @@ def _test_check_cert_hostname( DNS_NAME, VERIFY_LEVEL, OCSP_ENABLED, - _get_mount_options(), + _get_mount_options_tls(), DEFAULT_REGION, + efs_proxy_enabled=False, ) utils.assert_called_once(ca_mocker) @@ -215,8 +241,9 @@ def _test_check_cert_hostname( _validate_config( config_file, _get_expected_global_config(FS_ID, MOUNT_POINT, PORT, state_file_dir), - _get_expected_efs_config( - check_cert_hostname=expected_check_cert_hostname_config_value + _get_expected_efs_config_tls( + check_cert_hostname=expected_check_cert_hostname_config_value, + efs_proxy_enabled=False, ), ) @@ -228,7 +255,7 @@ def _test_check_cert_validity( stunnel_check_cert_validity, expected_check_cert_validity_config_value, ): - ca_mocker = mocker.patch("mount_efs.add_stunnel_ca_options") + ca_mocker = mocker.patch("mount_efs.add_tunnel_ca_options") state_file_dir = str(tmpdir) config_file = mount_efs.write_stunnel_config_file( _get_config( @@ -242,8 +269,9 @@ def _test_check_cert_validity( DNS_NAME, VERIFY_LEVEL, stunnel_check_cert_validity, - _get_mount_options(), + _get_mount_options_tls(), DEFAULT_REGION, + efs_proxy_enabled=True, ) utils.assert_called_once(ca_mocker) @@ -251,14 +279,14 @@ def _test_check_cert_validity( _validate_config( config_file, _get_expected_global_config(FS_ID, MOUNT_POINT, PORT, state_file_dir), - _get_expected_efs_config( + _get_expected_efs_config_tls( check_cert_validity=expected_check_cert_validity_config_value ), ) def test_write_stunnel_config_file(mocker, tmpdir): - ca_mocker = mocker.patch("mount_efs.add_stunnel_ca_options") + ca_mocker = mocker.patch("mount_efs.add_tunnel_ca_options") state_file_dir = str(tmpdir) config_file = mount_efs.write_stunnel_config_file( @@ -270,20 +298,21 @@ def test_write_stunnel_config_file(mocker, tmpdir): DNS_NAME, VERIFY_LEVEL, OCSP_ENABLED, - _get_mount_options(), + _get_mount_options_tls(), DEFAULT_REGION, + efs_proxy_enabled=True, ) utils.assert_called_once(ca_mocker) _validate_config( config_file, _get_expected_global_config(FS_ID, MOUNT_POINT, PORT, state_file_dir), - _get_expected_efs_config(), + _get_expected_efs_config_tls(), ) def test_write_stunnel_config_file_with_az_as_dns_name(mocker, tmpdir): - ca_mocker = mocker.patch("mount_efs.add_stunnel_ca_options") + ca_mocker = mocker.patch("mount_efs.add_tunnel_ca_options") state_file_dir = str(tmpdir) config_file = mount_efs.write_stunnel_config_file( @@ -295,15 +324,16 @@ def test_write_stunnel_config_file_with_az_as_dns_name(mocker, tmpdir): DNS_NAME_WITH_AZ, VERIFY_LEVEL, OCSP_ENABLED, - _get_mount_options(), + _get_mount_options_tls(), DEFAULT_REGION, + efs_proxy_enabled=True, ) utils.assert_called_once(ca_mocker) _validate_config( config_file, _get_expected_global_config(FS_ID, MOUNT_POINT, PORT, state_file_dir), - _get_expected_efs_config(dns_name=DNS_NAME_WITH_AZ), + _get_expected_efs_config_tls(dns_name=DNS_NAME_WITH_AZ), ) @@ -313,7 +343,7 @@ def _test_enable_disable_libwrap( system_release="unknown", libwrap_supported=True, ): - mocker.patch("mount_efs.add_stunnel_ca_options") + mocker.patch("mount_efs.add_tunnel_ca_options") state_file_dir = str(tmpdir) ver_mocker = mocker.patch( "mount_efs.get_system_release_version", return_value=system_release @@ -328,20 +358,21 @@ def _test_enable_disable_libwrap( DNS_NAME, VERIFY_LEVEL, OCSP_ENABLED, - _get_mount_options(), + _get_mount_options_tls(), DEFAULT_REGION, + efs_proxy_enabled=True, ) utils.assert_called_once(ver_mocker) _validate_config( config_file, _get_expected_global_config(FS_ID, MOUNT_POINT, PORT, state_file_dir), - _get_expected_efs_config(disable_libwrap=libwrap_supported), + _get_expected_efs_config_tls(disable_libwrap=libwrap_supported), ) def test_write_stunnel_config_with_debug(mocker, tmpdir): - ca_mocker = mocker.patch("mount_efs.add_stunnel_ca_options") + ca_mocker = mocker.patch("mount_efs.add_tunnel_ca_options") state_file_dir = str(tmpdir) config_file = mount_efs.write_stunnel_config_file( @@ -353,8 +384,9 @@ def test_write_stunnel_config_with_debug(mocker, tmpdir): DNS_NAME, VERIFY_LEVEL, OCSP_ENABLED, - _get_mount_options(), + _get_mount_options_tls(), DEFAULT_REGION, + efs_proxy_enabled=True, ) utils.assert_called_once(ca_mocker) @@ -364,19 +396,21 @@ def test_write_stunnel_config_with_debug(mocker, tmpdir): expected_global_config["debug"] = "debug" expected_global_config["output"] = os.path.join( mount_efs.LOG_DIR, - "%s.stunnel.log" + "%s.efs-proxy.log" % mount_efs.get_mount_specific_filename(FS_ID, MOUNT_POINT, PORT), ) - _validate_config(config_file, expected_global_config, _get_expected_efs_config()) + _validate_config( + config_file, expected_global_config, _get_expected_efs_config_tls() + ) def test_write_stunnel_config_with_debug_and_logs_file(mocker, tmpdir): - ca_mocker = mocker.patch("mount_efs.add_stunnel_ca_options") + ca_mocker = mocker.patch("mount_efs.add_tunnel_ca_options") state_file_dir = str(tmpdir) config_file = mount_efs.write_stunnel_config_file( _get_config( - mocker, stunnel_debug_enabled=True, stunnel_logs_file=STUNNEL_LOGS_FILE + mocker, stunnel_debug_enabled=True, stunnel_logs_file=PROXY_LOGS_FILE ), state_file_dir, FS_ID, @@ -385,8 +419,9 @@ def test_write_stunnel_config_with_debug_and_logs_file(mocker, tmpdir): DNS_NAME, VERIFY_LEVEL, OCSP_ENABLED, - _get_mount_options(), + _get_mount_options_tls(), DEFAULT_REGION, + efs_proxy_enabled=True, ) utils.assert_called_once(ca_mocker) @@ -394,15 +429,48 @@ def test_write_stunnel_config_with_debug_and_logs_file(mocker, tmpdir): _get_expected_global_config(FS_ID, MOUNT_POINT, PORT, state_file_dir) ) expected_global_config["debug"] = "debug" - expected_global_config["output"] = STUNNEL_LOGS_FILE + expected_global_config["output"] = PROXY_LOGS_FILE + + _validate_config( + config_file, expected_global_config, _get_expected_efs_config_tls() + ) + + +# We should always write "checkHost" into the stunnel config when using efs-proxy for TLS mounts. +def test_write_stunnel_config_efs_proxy_check_cert_hostname_tls(mocker, tmpdir): + ca_mocker = mocker.patch("mount_efs.add_tunnel_ca_options") + supported_opt_mock = mocker.patch("mount_efs.is_stunnel_option_supported") + state_file_dir = str(tmpdir) + config_file = mount_efs.write_stunnel_config_file( + _get_config(mocker), + state_file_dir, + FS_ID, + MOUNT_POINT, + PORT, + DNS_NAME, + VERIFY_LEVEL, + OCSP_ENABLED, + _get_mount_options_tls(), + DEFAULT_REGION, + efs_proxy_enabled=True, + ) + + utils.assert_called_once(ca_mocker) + utils.assert_not_called(supported_opt_mock) - _validate_config(config_file, expected_global_config, _get_expected_efs_config()) + _validate_config( + config_file, + _get_expected_global_config(FS_ID, MOUNT_POINT, PORT, state_file_dir), + _get_expected_efs_config_tls( + efs_proxy_enabled=True, + ), + ) def test_write_stunnel_config_check_cert_hostname_supported_flag_not_set( mocker, tmpdir ): - _test_check_cert_hostname( + _test_check_cert_hostname_stunnel( mocker, tmpdir, stunnel_check_cert_hostname_supported=True, @@ -414,7 +482,7 @@ def test_write_stunnel_config_check_cert_hostname_supported_flag_not_set( def test_write_stunnel_config_check_cert_hostname_supported_flag_set_false( mocker, capsys, tmpdir ): - _test_check_cert_hostname( + _test_check_cert_hostname_stunnel( mocker, tmpdir, stunnel_check_cert_hostname_supported=True, @@ -426,7 +494,7 @@ def test_write_stunnel_config_check_cert_hostname_supported_flag_set_false( def test_write_stunnel_config_check_cert_hostname_supported_flag_set_true( mocker, tmpdir ): - _test_check_cert_hostname( + _test_check_cert_hostname_stunnel( mocker, tmpdir, stunnel_check_cert_hostname_supported=True, @@ -438,7 +506,7 @@ def test_write_stunnel_config_check_cert_hostname_supported_flag_set_true( def test_write_stunnel_config_check_cert_hostname_not_supported_flag_not_specified( mocker, capsys, tmpdir ): - _test_check_cert_hostname( + _test_check_cert_hostname_stunnel( mocker, tmpdir, stunnel_check_cert_hostname_supported=False, @@ -450,7 +518,7 @@ def test_write_stunnel_config_check_cert_hostname_not_supported_flag_not_specifi def test_write_stunnel_config_check_cert_hostname_not_supported_flag_set_false( mocker, capsys, tmpdir ): - _test_check_cert_hostname( + _test_check_cert_hostname_stunnel( mocker, tmpdir, stunnel_check_cert_hostname_supported=False, @@ -462,7 +530,7 @@ def test_write_stunnel_config_check_cert_hostname_not_supported_flag_set_false( def test_write_stunnel_config_check_cert_hostname_not_supported_flag_set_true( mocker, capsys, tmpdir ): - mocker.patch("mount_efs.add_stunnel_ca_options") + mocker.patch("mount_efs.add_tunnel_ca_options") with pytest.raises(SystemExit) as ex: mount_efs.write_stunnel_config_file( @@ -478,8 +546,9 @@ def test_write_stunnel_config_check_cert_hostname_not_supported_flag_set_true( DNS_NAME, VERIFY_LEVEL, OCSP_ENABLED, - _get_mount_options(), + _get_mount_options_tls(), DEFAULT_REGION, + efs_proxy_enabled=False, ) assert 0 != ex.value.code @@ -528,7 +597,7 @@ def test_write_stunnel_config_check_cert_validity_not_supported_ocsp_disabled( def test_write_stunnel_config_check_cert_validity_not_supported_ocsp_enabled( mocker, capsys, tmpdir ): - mocker.patch("mount_efs.add_stunnel_ca_options") + mocker.patch("mount_efs.add_tunnel_ca_options") with pytest.raises(SystemExit) as ex: mount_efs.write_stunnel_config_file( @@ -544,8 +613,9 @@ def test_write_stunnel_config_check_cert_validity_not_supported_ocsp_enabled( DNS_NAME, VERIFY_LEVEL, True, - _get_mount_options(), + _get_mount_options_tls(), DEFAULT_REGION, + efs_proxy_enabled=False, ) assert 0 != ex.value.code @@ -556,7 +626,7 @@ def test_write_stunnel_config_check_cert_validity_not_supported_ocsp_enabled( def test_write_stunnel_config_with_verify_level(mocker, tmpdir): - ca_mocker = mocker.patch("mount_efs.add_stunnel_ca_options") + ca_mocker = mocker.patch("mount_efs.add_tunnel_ca_options") state_file_dir = str(tmpdir) verify = 0 config_file = mount_efs.write_stunnel_config_file( @@ -568,7 +638,7 @@ def test_write_stunnel_config_with_verify_level(mocker, tmpdir): DNS_NAME, verify, OCSP_ENABLED, - _get_mount_options(), + _get_mount_options_tls(), DEFAULT_REGION, ) utils.assert_not_called(ca_mocker) @@ -576,7 +646,7 @@ def test_write_stunnel_config_with_verify_level(mocker, tmpdir): _validate_config( config_file, _get_expected_global_config(FS_ID, MOUNT_POINT, PORT, state_file_dir), - _get_expected_efs_config(check_cert_validity=False, verify=verify), + _get_expected_efs_config_tls(check_cert_validity=False, verify=verify), ) @@ -589,7 +659,7 @@ def test_write_stunnel_config_libwrap_supported(mocker, tmpdir): def test_write_stunnel_config_with_fall_back_ip_address(mocker, tmpdir): - ca_mocker = mocker.patch("mount_efs.add_stunnel_ca_options") + ca_mocker = mocker.patch("mount_efs.add_tunnel_ca_options") state_file_dir = str(tmpdir) config_file = mount_efs.write_stunnel_config_file( @@ -601,7 +671,7 @@ def test_write_stunnel_config_with_fall_back_ip_address(mocker, tmpdir): DNS_NAME, VERIFY_LEVEL, OCSP_ENABLED, - _get_mount_options(), + _get_mount_options_tls(), DEFAULT_REGION, fallback_ip_address=FALLBACK_IP_ADDRESS, ) @@ -611,7 +681,7 @@ def test_write_stunnel_config_with_fall_back_ip_address(mocker, tmpdir): _validate_config( config_file, _get_expected_global_config(FS_ID, MOUNT_POINT, PORT, state_file_dir), - _get_expected_efs_config(fallback_ip_address=FALLBACK_IP_ADDRESS), + _get_expected_efs_config_tls(fallback_ip_address=FALLBACK_IP_ADDRESS), ) @@ -633,10 +703,16 @@ def test_write_stunnel_config_foreground_quiet_supported_debug_enabled(mocker, t ) +def test_write_stunnel_config_foreground_quiet_supported_debug_enabled(mocker, tmpdir): + _test_stunnel_config_foreground_quiet_helper( + mocker, tmpdir, foreground_quiet_supported=True, stunnel_debug_enabled=True + ) + + def _test_stunnel_config_foreground_quiet_helper( mocker, tmpdir, foreground_quiet_supported, stunnel_debug_enabled ): - ca_mocker = mocker.patch("mount_efs.add_stunnel_ca_options") + ca_mocker = mocker.patch("mount_efs.add_tunnel_ca_options") state_file_dir = str(tmpdir) config_file = mount_efs.write_stunnel_config_file( @@ -652,8 +728,9 @@ def _test_stunnel_config_foreground_quiet_helper( DNS_NAME, VERIFY_LEVEL, OCSP_ENABLED, - _get_mount_options(), + _get_mount_options_tls(), DEFAULT_REGION, + efs_proxy_enabled=False, ) utils.assert_called_once(ca_mocker) @@ -670,11 +747,15 @@ def _test_stunnel_config_foreground_quiet_helper( "%s.stunnel.log" % mount_efs.get_mount_specific_filename(FS_ID, MOUNT_POINT, PORT), ) - _validate_config(config_file, expected_global_config, _get_expected_efs_config()) + _validate_config( + config_file, + expected_global_config, + _get_expected_efs_config_tls(efs_proxy_enabled=False), + ) def test_write_stunnel_config_fips_enabled(mocker, tmpdir): - ca_mocker = mocker.patch("mount_efs.add_stunnel_ca_options") + ca_mocker = mocker.patch("mount_efs.add_tunnel_ca_options") state_file_dir = str(tmpdir) config_file = mount_efs.write_stunnel_config_file( @@ -686,7 +767,7 @@ def test_write_stunnel_config_fips_enabled(mocker, tmpdir): DNS_NAME, VERIFY_LEVEL, OCSP_ENABLED, - _get_mount_options(), + _get_mount_options_tls(), DEFAULT_REGION, ) utils.assert_called_once(ca_mocker) @@ -699,5 +780,35 @@ def test_write_stunnel_config_fips_enabled(mocker, tmpdir): _validate_config( config_file, expected_global_config, - _get_expected_efs_config(), + _get_expected_efs_config_tls(), + ) + + +def test_non_tls_mount_with_proxy(mocker, tmpdir): + ca_mocker = mocker.patch("mount_efs.add_tunnel_ca_options") + state_file_dir = str(tmpdir) + + config_file = mount_efs.write_stunnel_config_file( + _get_config(mocker), + state_file_dir, + FS_ID, + MOUNT_POINT, + PORT, + DNS_NAME, + VERIFY_LEVEL, + OCSP_ENABLED, + _get_mount_options_non_tls(), + DEFAULT_REGION, + efs_proxy_enabled=True, + ) + utils.assert_not_called(ca_mocker) + + expected_global_config = dict( + _get_expected_global_config(FS_ID, MOUNT_POINT, PORT, state_file_dir) + ) + + _validate_config( + config_file, + expected_global_config, + _get_expected_efs_config_non_tls(), ) diff --git a/test/mount_efs_test/test_write_tls_tunnel_state_file.py b/test/mount_efs_test/test_write_tls_tunnel_state_file.py index ac6426e2..57d7915e 100644 --- a/test/mount_efs_test/test_write_tls_tunnel_state_file.py +++ b/test/mount_efs_test/test_write_tls_tunnel_state_file.py @@ -22,7 +22,7 @@ DATETIME_FORMAT = "%y%m%d%H%M%SZ" -def test_write_tls_tunnel_state_file_netns(tmpdir): +def test_write_tunnel_state_file_netns(tmpdir): state_file_dir = str(tmpdir) mount_point = "/home/user/foo/mount" @@ -42,7 +42,7 @@ def test_write_tls_tunnel_state_file_netns(tmpdir): "useIam": True, } - state_file = mount_efs.write_tls_tunnel_state_file( + state_file = mount_efs.write_tunnel_state_file( FS_ID, mount_point, PORT, @@ -80,7 +80,7 @@ def test_write_tls_tunnel_state_file_netns(tmpdir): assert cert_details["useIam"] == state.get("useIam") -def test_write_tls_tunnel_state_file(tmpdir): +def test_write_tunnel_state_file(tmpdir): state_file_dir = str(tmpdir) mount_point = "/home/user/foo/mount" @@ -100,7 +100,7 @@ def test_write_tls_tunnel_state_file(tmpdir): "useIam": True, } - state_file = mount_efs.write_tls_tunnel_state_file( + state_file = mount_efs.write_tunnel_state_file( FS_ID, mount_point, PORT, PID, COMMAND, FILES, state_file_dir, cert_details ) @@ -131,12 +131,12 @@ def test_write_tls_tunnel_state_file(tmpdir): assert cert_details["useIam"] == state.get("useIam") -def test_write_tls_tunnel_state_file_no_cert(tmpdir): +def test_write_tunnel_state_file_no_cert(tmpdir): state_file_dir = str(tmpdir) mount_point = "/home/user/foo/mount" - state_file = mount_efs.write_tls_tunnel_state_file( + state_file = mount_efs.write_tunnel_state_file( FS_ID, mount_point, PORT, PID, COMMAND, FILES, state_file_dir ) diff --git a/test/watchdog_test/test_send_signal_to_stunnel_processes.py b/test/watchdog_test/test_send_signal_to_stunnel_processes.py index b0573fe1..77cff775 100644 --- a/test/watchdog_test/test_send_signal_to_stunnel_processes.py +++ b/test/watchdog_test/test_send_signal_to_stunnel_processes.py @@ -93,7 +93,7 @@ def test_is_mount_stunnel_proc_running_process_not_stunnel(mocker, tmpdir): assert False == watchdog.is_mount_stunnel_proc_running(PID, STATE_FILE, tmpdir) debug_log = mock_log_debug.call_args[0][0] - assert "is not a stunnel process" in debug_log + assert "is not an efs-proxy or stunnel process" in debug_log def test_is_mount_stunnel_proc_running_process_not_running(mocker, tmpdir): diff --git a/test/watchdog_test/test_start_tls_tunnel.py b/test/watchdog_test/test_start_tls_tunnel.py index cfa6af7c..3b546fee 100644 --- a/test/watchdog_test/test_start_tls_tunnel.py +++ b/test/watchdog_test/test_start_tls_tunnel.py @@ -29,13 +29,17 @@ def _mock_popen(mocker): return mocker.patch("subprocess.Popen", return_value=_get_popen_mock()) -def _initiate_state_file(tmpdir, cmd=None): +def _initiate_state_file(tmpdir, cmd=None, efs_proxy_enabled=False): + tunnel_executable = "/usr/bin/stunnel" + if efs_proxy_enabled: + tunnel_executable = "/usr/bin/efs-proxy" + state = { "pid": PID - 1, "cmd": cmd if cmd else [ - "/usr/bin/stunnel", + tunnel_executable, "/var/run/efs/stunnel-config.fs-deadbeef.mnt.21007", ], } @@ -57,6 +61,18 @@ def test_start_tls_tunnel(mocker, tmpdir): assert 1 == len(procs) +def test_start_tls_tunnel_efs_proxy(mocker, tmpdir): + _mock_popen(mocker) + mocker.patch("watchdog.is_pid_running", return_value=True) + + state, state_file = _initiate_state_file(tmpdir, efs_proxy_enabled=True) + procs = [] + pid = watchdog.start_tls_tunnel(procs, state, str(tmpdir), state_file) + + assert PID == pid + assert 1 == len(procs) + + def test_start_tls_tunnel_fails(mocker, capsys, tmpdir): _mock_popen(mocker) mocker.patch("watchdog.is_pid_running", return_value=False) @@ -70,7 +86,23 @@ def test_start_tls_tunnel_fails(mocker, capsys, tmpdir): assert 0 != ex.value.code out, err = capsys.readouterr() - assert "Failed to initialize TLS tunnel" in err + assert "Failed to initialize stunnel" in err + + +def test_start_tls_tunnel_fails_proxy_enabled(mocker, capsys, tmpdir): + _mock_popen(mocker) + mocker.patch("watchdog.is_pid_running", return_value=False) + + state, state_file = _initiate_state_file(tmpdir, efs_proxy_enabled=True) + procs = [] + with pytest.raises(SystemExit) as ex: + watchdog.start_tls_tunnel(procs, state, str(tmpdir), state_file) + + assert 0 == len(procs) + assert 0 != ex.value.code + + out, err = capsys.readouterr() + assert "Failed to initialize efs-proxy" in err # /~https://github.com/kubernetes-sigs/aws-efs-csi-driver/issues/812 The watchdog is trying to launch stunnel on AL2 for @@ -159,3 +191,39 @@ def test_start_tls_tunnel_for_mount_via_older_version_of_efs_utils_on_ecs_amazon assert " ".join(["nsenter", namespace, "/usr/sbin/stunnel5"]) in " ".join( state["cmd"] ) + + +def test_start_tls_tunnel_efs_proxy_enabled(mocker, tmpdir): + """ + This test makes sure that when efs_proxy is enabled, we will start efs_proxy and not stunnel, + even if the existing command used stunnel. + """ + popen_mock = _mock_popen(mocker) + mocker.patch("watchdog.is_pid_running", return_value=True) + mocker.patch("watchdog.find_command_path", return_value="/usr/bin/efs-proxy") + + proxy_command = [ + "/usr/bin/efs-proxy", + "/var/run/efs/stunnel-config.fs-deadbeef.mnt.21007", + ] + state, state_file = _initiate_state_file(tmpdir, proxy_command) + procs = [] + pid = watchdog.start_tls_tunnel(procs, state, str(tmpdir), state_file) + + args, _ = popen_mock.call_args + args = args[0] + assert "/usr/bin/efs-proxy" == args[0] + assert "/var/run/efs/stunnel-config.fs-deadbeef.mnt.21007" == args[1] + + assert PID == pid + assert 1 == len(procs) + + +def test_command_uses_efs_proxy(): + cmd = [ + "/usr/bin/stunnel", + "/var/run/efs/stunnel-config.fs-deadbeef.mnt.21007", + ] + assert watchdog.command_uses_efs_proxy(cmd) == False + cmd[0] = "/usr/bin/efs-proxy" + assert watchdog.command_uses_efs_proxy(cmd) == True