diff --git a/src/_nebari/provider/terraform.py b/src/_nebari/provider/terraform.py index 6f6ad6930b..e3ff32489f 100644 --- a/src/_nebari/provider/terraform.py +++ b/src/_nebari/provider/terraform.py @@ -114,8 +114,10 @@ def download_terraform_binary(version=constants.TERRAFORM_VERSION): def run_terraform_subprocess(processargs, **kwargs): terraform_path = download_terraform_binary() logger.info(f" terraform at {terraform_path}") - if run_subprocess_cmd([terraform_path] + processargs, **kwargs): + exit_code, output = run_subprocess_cmd([terraform_path] + processargs, **kwargs) + if exit_code != 0: raise TerraformException("Terraform returned an error") + return output def version(): @@ -183,6 +185,23 @@ def tfimport(addr, id, directory=None, var_files=None, exist_ok=False): raise e +def show(directory=None) -> dict: + logger.info(f"terraform show directory={directory}") + command = ["show", "-json"] + with timer(logger, "terraform show"): + try: + output = json.loads( + run_terraform_subprocess( + command, + cwd=directory, + capture_output=True, + ) + ) + return output + except TerraformException as e: + raise e + + def refresh(directory=None, var_files=None): var_files = var_files or [] diff --git a/src/_nebari/stages/kubernetes_services/__init__.py b/src/_nebari/stages/kubernetes_services/__init__.py index ba32ca6186..382fac7a89 100644 --- a/src/_nebari/stages/kubernetes_services/__init__.py +++ b/src/_nebari/stages/kubernetes_services/__init__.py @@ -61,7 +61,10 @@ class DefaultImages(schema.Base): class Storage(schema.Base): - type: SharedFsEnum = None + type: SharedFsEnum = Field( + default=None, + json_schema_extra={"immutable": True}, + ) conda_store: str = "200Gi" shared_filesystem: str = "200Gi" diff --git a/src/_nebari/stages/terraform_state/__init__.py b/src/_nebari/stages/terraform_state/__init__.py index edd4b9ed8a..d9afff36e4 100644 --- a/src/_nebari/stages/terraform_state/__init__.py +++ b/src/_nebari/stages/terraform_state/__init__.py @@ -1,5 +1,6 @@ import contextlib import enum +import functools import inspect import os import pathlib @@ -8,9 +9,11 @@ from pydantic import field_validator +from _nebari import utils from _nebari.provider import terraform from _nebari.provider.cloud import azure_cloud from _nebari.stages.base import NebariTerraformStage +from _nebari.stages.tf_objects import NebariConfig from _nebari.utils import ( AZURE_TF_STATE_RESOURCE_GROUP_SUFFIX, construct_azure_resource_group_name, @@ -170,8 +173,9 @@ def state_imports(self) -> List[Tuple[str, str]]: return [] def tf_objects(self) -> List[Dict]: + resources = [NebariConfig(self.config)] if self.config.provider == schema.ProviderEnum.gcp: - return [ + return resources + [ terraform.Provider( "google", project=self.config.google_cloud_platform.project, @@ -179,13 +183,13 @@ def tf_objects(self) -> List[Dict]: ), ] elif self.config.provider == schema.ProviderEnum.aws: - return [ + return resources + [ terraform.Provider( "aws", region=self.config.amazon_web_services.region ), ] else: - return [] + return resources def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): if self.config.provider == schema.ProviderEnum.do: @@ -231,6 +235,8 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): def deploy( self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False ): + self.check_immutable_fields() + with super().deploy(stage_outputs, disable_prompt): env_mapping = {} # DigitalOcean terraform remote state using Spaces Bucket @@ -246,6 +252,50 @@ def deploy( with modified_environ(**env_mapping): yield + def check_immutable_fields(self): + nebari_config_state = self.get_nebari_config_state() + if not nebari_config_state: + return + + # compute diff of remote/prior and current nebari config + nebari_config_diff = utils.JsonDiff( + nebari_config_state.model_dump(), self.config.model_dump() + ) + + # check if any changed fields are immutable + for keys, old, new in nebari_config_diff.modified(): + bottom_level_schema = self.config + if len(keys) > 1: + bottom_level_schema = functools.reduce( + lambda m, k: getattr(m, k), keys[:-1], self.config + ) + extra_field_schema = schema.ExtraFieldSchema( + **bottom_level_schema.model_fields[keys[-1]].json_schema_extra or {} + ) + if extra_field_schema.immutable: + key_path = ".".join(keys) + raise ValueError( + f'Attempting to change immutable field "{key_path}" ("{old}"->"{new}") in Nebari config file. Immutable fields cannot be changed after initial deployment.' + ) + + def get_nebari_config_state(self): + directory = str(self.output_directory / self.stage_prefix) + tf_state = terraform.show(directory) + nebari_config_state = None + + # get nebari config from state + for resource in ( + tf_state.get("values", {}).get("root_module", {}).get("resources", []) + ): + if resource["address"] == "terraform_data.nebari_config": + from nebari.plugins import nebari_plugin_manager + + nebari_config_state = nebari_plugin_manager.config_schema( + **resource["values"]["input"] + ) + break + return nebari_config_state + @contextlib.contextmanager def destroy( self, stage_outputs: Dict[str, Dict[str, Any]], status: Dict[str, bool] diff --git a/src/_nebari/stages/tf_objects.py b/src/_nebari/stages/tf_objects.py index 76f05e5f9c..04c6d434aa 100644 --- a/src/_nebari/stages/tf_objects.py +++ b/src/_nebari/stages/tf_objects.py @@ -1,4 +1,4 @@ -from _nebari.provider.terraform import Data, Provider, TerraformBackend +from _nebari.provider.terraform import Data, Provider, Resource, TerraformBackend from _nebari.utils import ( AZURE_TF_STATE_RESOURCE_GROUP_SUFFIX, construct_azure_resource_group_name, @@ -115,3 +115,7 @@ def NebariTerraformState(directory: str, nebari_config: schema.Main): ) else: raise NotImplementedError("state not implemented") + + +def NebariConfig(nebari_config: schema.Main): + return Resource("terraform_data", "nebari_config", input=nebari_config.model_dump()) diff --git a/src/_nebari/utils.py b/src/_nebari/utils.py index 6b33b1efbb..84eb376c2a 100644 --- a/src/_nebari/utils.py +++ b/src/_nebari/utils.py @@ -1,5 +1,7 @@ import contextlib +import enum import functools +import json import os import re import secrets @@ -11,7 +13,7 @@ import time import warnings from pathlib import Path -from typing import Dict, List, Set +from typing import Any, Dict, List, Set from ruamel.yaml import YAML @@ -44,7 +46,7 @@ def change_directory(directory): os.chdir(current_directory) -def run_subprocess_cmd(processargs, **kwargs): +def run_subprocess_cmd(processargs, capture_output=False, **kwargs): """Runs subprocess command with realtime stdout logging with optional line prefix.""" if "prefix" in kwargs: line_prefix = f"[{kwargs['prefix']}]: ".encode("utf-8") @@ -78,6 +80,7 @@ def kill_process(): timeout_timer = threading.Timer(timeout, kill_process) timeout_timer.start() + output = [] for line in iter(lambda: process.stdout.readline(), b""): full_line = line_prefix + line if strip_errors: @@ -87,17 +90,25 @@ def kill_process(): ) # Remove red ANSI escape code full_line = full_line.encode("utf-8") - sys.stdout.buffer.write(full_line) - sys.stdout.flush() + if capture_output: + output.append(full_line) + else: + sys.stdout.buffer.write(full_line) + sys.stdout.flush() if timeout_timer is not None: timeout_timer.cancel() process.stdout.close() - return process.wait( + exit_code = process.wait( timeout=10 ) # Should already have finished because we have drained stdout + if capture_output: + return exit_code, b"".join(output) + else: + return exit_code, None + def load_yaml(config_filename: Path): """ @@ -406,3 +417,57 @@ def byte_unit_conversion(byte_size_str: str, output_unit: str = "B") -> float: ) return value * units_multiplier[input_unit] / units_multiplier[output_unit] + + +class JsonDiffEnum(str, enum.Enum): + ADDED = "+" + REMOVED = "-" + MODIFIED = "!" + + +class JsonDiff: + def __init__(self, obj1: Dict[str, Any], obj2: Dict[str, Any]): + self.diff = self.json_diff(obj1, obj2) + + @staticmethod + def json_diff(obj1: Dict[str, Any], obj2: Dict[str, Any]) -> Dict[str, Any]: + """Calculates the diff between two json-like objects + + # Example usage + obj1 = {"a": 1, "b": {"c": 2, "d": 3}} + obj2 = {"a": 1, "b": {"c": 2, "e": 4}, "f": 5} + + result = json_diff(obj1, obj2) + """ + diff = {} + for key in set(obj1.keys()) | set(obj2.keys()): + if key not in obj1: + diff[key] = {JsonDiffEnum.ADDED: obj2[key]} + elif key not in obj2: + diff[key] = {JsonDiffEnum.REMOVED: obj1[key]} + elif obj1[key] != obj2[key]: + if isinstance(obj1[key], dict) and isinstance(obj2[key], dict): + nested_diff = JsonDiff.json_diff(obj1[key], obj2[key]) + if nested_diff: + diff[key] = nested_diff + else: + diff[key] = {JsonDiffEnum.MODIFIED: (obj1[key], obj2[key])} + return diff + + @staticmethod + def walk_dict(d, path, sentinel): + for key, value in d.items(): + if key is not sentinel: + if not isinstance(value, dict): + continue + yield from JsonDiff.walk_dict(value, path + [key], sentinel) + else: + yield path, value + + def modified(self): + """Generator that yields the path, old value, and new value of changed items""" + for path, (old, new) in self.walk_dict(self.diff, [], JsonDiffEnum.MODIFIED): + yield path, old, new + + def __repr__(self): + return f"{self.__class__.__name__}(diff={json.dumps(self.diff)})" diff --git a/src/nebari/schema.py b/src/nebari/schema.py index 2cc1c1ea3f..e138b4c247 100644 --- a/src/nebari/schema.py +++ b/src/nebari/schema.py @@ -45,10 +45,24 @@ def to_yaml(cls, representer, node): return representer.represent_str(node.value) +class ExtraFieldSchema(Base): + model_config = ConfigDict( + extra="allow", + validate_assignment=True, + populate_by_name=True, + ) + immutable: bool = ( + False # Whether field supports being changed after initial deployment + ) + + class Main(Base): - project_name: project_name_pydantic + project_name: project_name_pydantic = Field(json_schema_extra={"immutable": True}) namespace: namespace_pydantic = "dev" - provider: ProviderEnum = ProviderEnum.local + provider: ProviderEnum = Field( + default=ProviderEnum.local, + json_schema_extra={"immutable": True}, + ) # In nebari_version only use major.minor.patch version - drop any pre/post/dev suffixes nebari_version: Annotated[str, Field(validate_default=True)] = __version__ diff --git a/tests/tests_unit/test_stages.py b/tests/tests_unit/test_stages.py new file mode 100644 index 0000000000..e0e254a7d6 --- /dev/null +++ b/tests/tests_unit/test_stages.py @@ -0,0 +1,73 @@ +import pathlib +from unittest.mock import patch + +import pytest + +from _nebari.stages.terraform_state import TerraformStateStage +from _nebari.utils import yaml +from _nebari.version import __version__ +from nebari import schema +from nebari.plugins import nebari_plugin_manager + +HERE = pathlib.Path(__file__).parent + + +@pytest.fixture +def mock_config(): + with open(HERE / "./cli_validate/local.happy.yaml", "r") as f: + mock_config_file = yaml.load(f) + mock_config_file["nebari_version"] = __version__ + + config = nebari_plugin_manager.config_schema.model_validate(mock_config_file) + return config + + +@pytest.fixture +def terraform_state_stage(mock_config, tmp_path): + return TerraformStateStage(tmp_path, mock_config) + + +@patch.object(TerraformStateStage, "get_nebari_config_state") +def test_check_immutable_fields_no_changes(mock_get_state, terraform_state_stage): + mock_get_state.return_value = terraform_state_stage.config + + # This should not raise an exception + terraform_state_stage.check_immutable_fields() + + +@patch.object(TerraformStateStage, "get_nebari_config_state") +def test_check_immutable_fields_mutable_change( + mock_get_state, terraform_state_stage, mock_config +): + old_config = mock_config.model_copy() + old_config.namespace = "old-namespace" + mock_get_state.return_value = old_config + + # This should not raise an exception (namespace is mutable) + terraform_state_stage.check_immutable_fields() + + +@patch.object(TerraformStateStage, "get_nebari_config_state") +@patch.object(schema.Main, "model_fields") +def test_check_immutable_fields_immutable_change( + mock_model_fields, mock_get_state, terraform_state_stage, mock_config +): + old_config = mock_config.model_copy() + old_config.provider = schema.ProviderEnum.gcp + mock_get_state.return_value = old_config + + # Mock the provider field to be immutable + mock_model_fields.__getitem__.return_value.json_schema_extra = {"immutable": True} + + with pytest.raises(ValueError) as exc_info: + terraform_state_stage.check_immutable_fields() + + assert 'Attempting to change immutable field "provider"' in str(exc_info.value) + + +@patch.object(TerraformStateStage, "get_nebari_config_state") +def test_check_immutable_fields_no_prior_state(mock_get_state, terraform_state_stage): + mock_get_state.return_value = None + + # This should not raise an exception + terraform_state_stage.check_immutable_fields() diff --git a/tests/tests_unit/test_utils.py b/tests/tests_unit/test_utils.py index c2ae2d4965..678cd1f230 100644 --- a/tests/tests_unit/test_utils.py +++ b/tests/tests_unit/test_utils.py @@ -1,6 +1,6 @@ import pytest -from _nebari.utils import byte_unit_conversion +from _nebari.utils import JsonDiff, JsonDiffEnum, byte_unit_conversion @pytest.mark.parametrize( @@ -42,3 +42,25 @@ ) def test_byte_unit_conversion(value, from_unit, to_unit, expected): assert byte_unit_conversion(f"{value} {from_unit}", to_unit) == expected + + +def test_JsonDiff_diff(): + obj1 = {"a": 1, "b": {"c": 2, "d": 3}} + obj2 = {"a": 1, "b": {"c": 3, "e": 4}, "f": 5} + diff = JsonDiff(obj1, obj2) + assert diff.diff == { + "b": { + "e": {JsonDiffEnum.ADDED: 4}, + "c": {JsonDiffEnum.MODIFIED: (2, 3)}, + "d": {JsonDiffEnum.REMOVED: 3}, + }, + "f": {JsonDiffEnum.ADDED: 5}, + } + + +def test_JsonDiff_modified(): + obj1 = {"a": 1, "b": {"!": 2, "-": 3}, "+": 4} + obj2 = {"a": 1, "b": {"!": 3, "+": 4}, "+": 5} + diff = JsonDiff(obj1, obj2) + modifieds = diff.modified() + assert sorted(modifieds) == sorted([(["b", "!"], 2, 3), (["+"], 4, 5)])