diff --git a/src/_nebari/config_set.py b/src/_nebari/config_set.py new file mode 100644 index 0000000000..95413ea1a7 --- /dev/null +++ b/src/_nebari/config_set.py @@ -0,0 +1,54 @@ +import logging +import pathlib +from typing import Optional + +from packaging.requirements import SpecifierSet +from pydantic import BaseModel, ConfigDict, field_validator + +from _nebari._version import __version__ +from _nebari.utils import yaml + +logger = logging.getLogger(__name__) + + +class ConfigSetMetadata(BaseModel): + model_config: ConfigDict = ConfigDict(extra="allow", arbitrary_types_allowed=True) + name: str # for use with guided init + description: Optional[str] = None + nebari_version: str | SpecifierSet + + @field_validator("nebari_version") + @classmethod + def validate_version_requirement(cls, version_req): + if isinstance(version_req, str): + version_req = SpecifierSet(version_req, prereleases=True) + + return version_req + + def check_version(self, version): + if not self.nebari_version.contains(version, prereleases=True): + raise ValueError( + f'Nebari version "{version}" is not compatible with ' + f'version requirement {self.nebari_version} for "{self.name}" config set.' + ) + + +class ConfigSet(BaseModel): + metadata: ConfigSetMetadata + config: dict + + +def read_config_set(config_set_filepath: str): + """Read a config set from a config file.""" + + filename = pathlib.Path(config_set_filepath) + + with filename.open() as f: + config_set_yaml = yaml.load(f) + + config_set = ConfigSet(**config_set_yaml) + + # validation + config_set.metadata.check_version(__version__) + + return config_set diff --git a/src/_nebari/initialize.py b/src/_nebari/initialize.py index 4b41f2c5a1..7566fe7b44 100644 --- a/src/_nebari/initialize.py +++ b/src/_nebari/initialize.py @@ -8,7 +8,8 @@ import pydantic import requests -from _nebari import constants +from _nebari import constants, utils +from _nebari.config_set import read_config_set from _nebari.provider import git from _nebari.provider.cicd import github from _nebari.provider.cloud import amazon_web_services, azure_cloud, google_cloud @@ -47,6 +48,7 @@ def render_config( region: str = None, disable_prompt: bool = False, ssl_cert_email: str = None, + config_set: str = None, ) -> Dict[str, Any]: config = { "provider": cloud_provider, @@ -176,13 +178,17 @@ def render_config( config["certificate"] = {"type": CertificateEnum.letsencrypt.value} config["certificate"]["acme_email"] = ssl_cert_email + if config_set: + config_set = read_config_set(config_set) + config = utils.deep_merge(config, config_set.config) + # validate configuration and convert to model from nebari.plugins import nebari_plugin_manager try: config_model = nebari_plugin_manager.config_schema.model_validate(config) except pydantic.ValidationError as e: - print(str(e)) + raise e if repository_auto_provision: match = re.search(github_url_regex, repository) diff --git a/src/_nebari/stages/infrastructure/__init__.py b/src/_nebari/stages/infrastructure/__init__.py index 243abd1608..b75412bd64 100644 --- a/src/_nebari/stages/infrastructure/__init__.py +++ b/src/_nebari/stages/infrastructure/__init__.py @@ -6,6 +6,7 @@ import re import sys import tempfile +import warnings from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union from pydantic import ConfigDict, Field, field_validator, model_validator @@ -613,11 +614,23 @@ def check_provider(cls, data: Any) -> Any: data[provider] = provider_enum_model_map[provider]() else: # if the provider field is invalid, it won't be set when this validator is called - # so we need to check for it explicitly here, and set the `pre` to True + # so we need to check for it explicitly here, and set mode to "before" # TODO: this is a workaround, check if there is a better way to do this in Pydantic v2 raise ValueError( f"'{provider}' is not a valid enumeration member; permitted: local, existing, aws, gcp, azure" ) + set_providers = { + provider + for provider in provider_name_abbreviation_map.keys() + if provider in data and data[provider] + } + expected_provider_config = provider_enum_name_map[provider] + extra_provider_config = set_providers - {expected_provider_config} + if extra_provider_config: + warnings.warn( + f"Provider is set to {getattr(provider, 'value', provider)}, but configuration defined for other providers: {extra_provider_config}" + ) + else: set_providers = [ provider @@ -631,6 +644,7 @@ def check_provider(cls, data: Any) -> Any: data["provider"] = provider_name_abbreviation_map[set_providers[0]] elif num_providers == 0: data["provider"] = schema.ProviderEnum.local.value + return data diff --git a/src/_nebari/subcommands/init.py b/src/_nebari/subcommands/init.py index e794841ea7..c2f8d416e9 100644 --- a/src/_nebari/subcommands/init.py +++ b/src/_nebari/subcommands/init.py @@ -93,6 +93,7 @@ class InitInputs(schema.Base): region: Optional[str] = None ssl_cert_email: Optional[schema.email_pydantic] = None disable_prompt: bool = False + config_set: Optional[str] = None output: pathlib.Path = pathlib.Path("nebari-config.yaml") explicit: int = 0 @@ -134,6 +135,7 @@ def handle_init(inputs: InitInputs, config_schema: BaseModel): terraform_state=inputs.terraform_state, ssl_cert_email=inputs.ssl_cert_email, disable_prompt=inputs.disable_prompt, + config_set=inputs.config_set, ) try: @@ -496,6 +498,12 @@ def init( False, is_eager=True, ), + config_set: str = typer.Option( + None, + "--config-set", + "-s", + help="Apply a pre-defined set of nebari configuration options.", + ), output: str = typer.Option( pathlib.Path("nebari-config.yaml"), "--output", @@ -554,6 +562,7 @@ def init( inputs.terraform_state = terraform_state inputs.ssl_cert_email = ssl_cert_email inputs.disable_prompt = disable_prompt + inputs.config_set = config_set inputs.output = output inputs.explicit = explicit diff --git a/src/_nebari/utils.py b/src/_nebari/utils.py index f3d62f353d..48b8a91e9b 100644 --- a/src/_nebari/utils.py +++ b/src/_nebari/utils.py @@ -160,7 +160,7 @@ def modified_environ(*remove: List[str], **update: Dict[str, str]): def deep_merge(*args): - """Deep merge multiple dictionaries. + """Deep merge multiple dictionaries. Preserves order in dicts and lists. >>> value_1 = { 'a': [1, 2], @@ -190,7 +190,7 @@ def deep_merge(*args): if isinstance(d1, dict) and isinstance(d2, dict): d3 = {} - for key in d1.keys() | d2.keys(): + for key in tuple(d1.keys()) + tuple(d2.keys()): if key in d1 and key in d2: d3[key] = deep_merge(d1[key], d2[key]) elif key in d1: diff --git a/tests/tests_unit/test_config_set.py b/tests/tests_unit/test_config_set.py new file mode 100644 index 0000000000..81f5a8a11c --- /dev/null +++ b/tests/tests_unit/test_config_set.py @@ -0,0 +1,73 @@ +from unittest.mock import patch + +import pytest +from packaging.requirements import SpecifierSet + +from _nebari.config_set import ConfigSetMetadata, read_config_set + +test_version = "2024.12.2" + + +@pytest.mark.parametrize( + "version_input,test_version,should_pass", + [ + # Standard version tests + (">=2024.12.0,<2025.0.0", "2024.12.2", True), + (SpecifierSet(">=2024.12.0,<2025.0.0"), "2024.12.2", True), + # Pre-release version requirement tests + (">=2024.12.0rc1,<2025.0.0", "2024.12.0rc1", True), + (SpecifierSet(">=2024.12.0rc1"), "2024.12.0rc2", True), + # Pre-release test version against standard requirement + (">=2024.12.0,<2025.0.0", "2024.12.1rc1", True), + (SpecifierSet(">=2024.12.0,<2025.0.0"), "2024.12.1rc1", True), + # Failing cases + (">=2025.0.0", "2024.12.2rc1", False), + (SpecifierSet(">=2025.0.0rc1"), "2024.12.2", False), + ], +) +def test_version_requirement(version_input, test_version, should_pass): + metadata = ConfigSetMetadata(name="test-config", nebari_version=version_input) + + if should_pass: + metadata.check_version(test_version) + else: + with pytest.raises(ValueError) as exc_info: + metadata.check_version(test_version) + assert "Nebari version" in str(exc_info.value) + + +def test_read_config_set_valid(tmp_path): + config_set_yaml = """ + metadata: + name: test-config + nebari_version: ">=2024.12.0" + config: + key: value + """ + config_set_filepath = tmp_path / "config_set.yaml" + config_set_filepath.write_text(config_set_yaml) + with patch("_nebari.config_set.__version__", "2024.12.2"): + config_set = read_config_set(str(config_set_filepath)) + assert config_set.metadata.name == "test-config" + assert config_set.config["key"] == "value" + + +def test_read_config_set_invalid_version(tmp_path): + config_set_yaml = """ + metadata: + name: test-config + nebari_version: ">=2025.0.0" + config: + key: value + """ + config_set_filepath = tmp_path / "config_set.yaml" + config_set_filepath.write_text(config_set_yaml) + + with patch("_nebari.config_set.__version__", "2024.12.2"): + with pytest.raises(ValueError) as exc_info: + read_config_set(str(config_set_filepath)) + assert "Nebari version" in str(exc_info.value) + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/tests_unit/test_schema.py b/tests/tests_unit/test_schema.py index 5c21aef8d6..e445ba37da 100644 --- a/tests/tests_unit/test_schema.py +++ b/tests/tests_unit/test_schema.py @@ -161,3 +161,13 @@ def test_set_provider(config_schema, provider): result_config_dict = config.model_dump() assert provider in result_config_dict assert result_config_dict[provider]["kube_context"] == "some_context" + + +def test_provider_config_mismatch_warning(config_schema): + config_dict = { + "project_name": "test", + "provider": "local", + "existing": {"kube_context": "some_context"}, # <-- Doesn't match the provider + } + with pytest.warns(UserWarning, match="configuration defined for other providers"): + config_schema(**config_dict) diff --git a/tests/tests_unit/test_stages.py b/tests/tests_unit/test_stages.py index c716d93030..c15aa6d9fc 100644 --- a/tests/tests_unit/test_stages.py +++ b/tests/tests_unit/test_stages.py @@ -53,6 +53,7 @@ def test_check_immutable_fields_immutable_change( mock_model_fields, mock_get_state, terraform_state_stage, mock_config ): old_config = mock_config.model_copy(deep=True) + old_config.local = None old_config.provider = schema.ProviderEnum.gcp mock_get_state.return_value = old_config.model_dump() diff --git a/tests/tests_unit/test_utils.py b/tests/tests_unit/test_utils.py index 678cd1f230..88b911ff60 100644 --- a/tests/tests_unit/test_utils.py +++ b/tests/tests_unit/test_utils.py @@ -1,6 +1,6 @@ import pytest -from _nebari.utils import JsonDiff, JsonDiffEnum, byte_unit_conversion +from _nebari.utils import JsonDiff, JsonDiffEnum, byte_unit_conversion, deep_merge @pytest.mark.parametrize( @@ -64,3 +64,75 @@ def test_JsonDiff_modified(): diff = JsonDiff(obj1, obj2) modifieds = diff.modified() assert sorted(modifieds) == sorted([(["b", "!"], 2, 3), (["+"], 4, 5)]) + + +def test_deep_merge_order_preservation_dict(): + value_1 = { + "a": [1, 2], + "b": {"c": 1, "z": [5, 6]}, + "e": {"f": {"g": {}}}, + "m": 1, + } + + value_2 = { + "a": [3, 4], + "b": {"d": 2, "z": [7]}, + "e": {"f": {"h": 1}}, + "m": [1], + } + + expected_result = { + "a": [1, 2, 3, 4], + "b": {"c": 1, "z": [5, 6, 7], "d": 2}, + "e": {"f": {"g": {}, "h": 1}}, + "m": 1, + } + + result = deep_merge(value_1, value_2) + assert result == expected_result + assert list(result.keys()) == list(expected_result.keys()) + assert list(result["b"].keys()) == list(expected_result["b"].keys()) + assert list(result["e"]["f"].keys()) == list(expected_result["e"]["f"].keys()) + + +def test_deep_merge_order_preservation_list(): + value_1 = { + "a": [1, 2], + "b": {"c": 1, "z": [5, 6]}, + } + + value_2 = { + "a": [3, 4], + "b": {"d": 2, "z": [7]}, + } + + expected_result = { + "a": [1, 2, 3, 4], + "b": {"c": 1, "z": [5, 6, 7], "d": 2}, + } + + result = deep_merge(value_1, value_2) + assert result == expected_result + assert result["a"] == expected_result["a"] + assert result["b"]["z"] == expected_result["b"]["z"] + + +def test_deep_merge_single_dict(): + value_1 = { + "a": [1, 2], + "b": {"c": 1, "z": [5, 6]}, + } + + expected_result = value_1 + + result = deep_merge(value_1) + assert result == expected_result + assert list(result.keys()) == list(expected_result.keys()) + assert list(result["b"].keys()) == list(expected_result["b"].keys()) + + +def test_deep_merge_empty(): + expected_result = {} + + result = deep_merge() + assert result == expected_result