Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support disallowed nebari config changes #2660

Merged
merged 12 commits into from
Aug 30, 2024
21 changes: 20 additions & 1 deletion src/_nebari/provider/terraform.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,10 @@ def download_terraform_binary(version=constants.TERRAFORM_VERSION):
def run_terraform_subprocess(processargs, **kwargs):
    """Run the terraform binary with ``processargs`` and return its output.

    Args:
        processargs: List of command-line arguments placed after the
            terraform executable path.
        **kwargs: Forwarded unchanged to ``run_subprocess_cmd``.

    Returns:
        The second element of the ``(exit_code, output)`` pair returned by
        ``run_subprocess_cmd``.

    Raises:
        TerraformException: If the subprocess exit code is not zero.
    """
    terraform_path = download_terraform_binary()
    logger.info(f" terraform at {terraform_path}")
    # Invoke terraform exactly once; the pair is unpacked so a non-zero exit
    # code is detected explicitly instead of relying on tuple truthiness.
    exit_code, output = run_subprocess_cmd([terraform_path] + processargs, **kwargs)
    if exit_code != 0:
        raise TerraformException("Terraform returned an error")
    return output


def version():
Expand Down Expand Up @@ -183,6 +185,23 @@ def tfimport(addr, id, directory=None, var_files=None, exist_ok=False):
raise e


def show(directory=None) -> dict:
    """Return the parsed output of ``terraform show -json``.

    Args:
        directory: Working directory handed to the terraform subprocess as
            ``cwd``.

    Returns:
        The JSON document printed by terraform, decoded with ``json.loads``.

    Raises:
        TerraformException: Propagated from ``run_terraform_subprocess`` when
            terraform exits with a non-zero code.
        json.JSONDecodeError: If the captured output is not valid JSON.
    """
    logger.info(f"terraform show directory={directory}")
    command = ["show", "-json"]
    with timer(logger, "terraform show"):
        # capture_output=True makes the subprocess helper return the output
        # instead of echoing it, so it can be decoded here.
        raw_state = run_terraform_subprocess(
            command,
            cwd=directory,
            capture_output=True,
        )
        return json.loads(raw_state)


def refresh(directory=None, var_files=None):
var_files = var_files or []

Expand Down
5 changes: 4 additions & 1 deletion src/_nebari/stages/kubernetes_services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ class DefaultImages(schema.Base):


class Storage(schema.Base):
type: SharedFsEnum = None
type: SharedFsEnum = Field(
default=None,
json_schema_extra={"immutable": True},
)
conda_store: str = "200Gi"
shared_filesystem: str = "200Gi"

Expand Down
56 changes: 53 additions & 3 deletions src/_nebari/stages/terraform_state/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import enum
import functools
import inspect
import os
import pathlib
Expand All @@ -8,9 +9,11 @@

from pydantic import field_validator

from _nebari import utils
from _nebari.provider import terraform
from _nebari.provider.cloud import azure_cloud
from _nebari.stages.base import NebariTerraformStage
from _nebari.stages.tf_objects import NebariConfig
from _nebari.utils import (
AZURE_TF_STATE_RESOURCE_GROUP_SUFFIX,
construct_azure_resource_group_name,
Expand Down Expand Up @@ -170,22 +173,23 @@ def state_imports(self) -> List[Tuple[str, str]]:
return []

def tf_objects(self) -> List[Dict]:
    """Build the Terraform objects rendered for this stage.

    The ``NebariConfig`` resource is always present. A provider block is
    appended for GCP and AWS; every other provider gets the resource alone.

    Returns:
        List of Terraform objects, starting with the Nebari config resource.
    """
    resources = [NebariConfig(self.config)]
    provider = self.config.provider

    if provider == schema.ProviderEnum.gcp:
        resources.append(
            terraform.Provider(
                "google",
                project=self.config.google_cloud_platform.project,
                region=self.config.google_cloud_platform.region,
            )
        )
    elif provider == schema.ProviderEnum.aws:
        resources.append(
            terraform.Provider("aws", region=self.config.amazon_web_services.region)
        )

    return resources

def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
if self.config.provider == schema.ProviderEnum.do:
Expand Down Expand Up @@ -231,6 +235,8 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
def deploy(
self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False
):
self.check_immutable_fields()

with super().deploy(stage_outputs, disable_prompt):
env_mapping = {}
# DigitalOcean terraform remote state using Spaces Bucket
Expand All @@ -246,6 +252,50 @@ def deploy(
with modified_environ(**env_mapping):
yield

def check_immutable_fields(self):
    """Reject a deployment that changes a field marked as immutable.

    The config returned by ``get_nebari_config_state`` is diffed against the
    current config. For every modified leaf, the pydantic model that declares
    the field is looked up, and a ``json_schema_extra`` with ``immutable`` set
    aborts the check.

    Raises:
        ValueError: If a field marked immutable differs from the prior state.
    """
    nebari_config_state = self.get_nebari_config_state()
    if not nebari_config_state:
        # No prior config available, so there is nothing to compare against.
        return

    # compute diff of remote/prior and current nebari config
    nebari_config_diff = utils.JsonDiff(
        nebari_config_state.model_dump(), self.config.model_dump()
    )

    # check if any changed fields are immutable
    for keys, old, new in nebari_config_diff.modified():
        # Walk down to the object owning the changed leaf. Intermediate
        # values can be plain dicts (whose keys are not attributes) or None,
        # so attribute access alone would raise AttributeError here.
        owner = self.config
        for key in keys[:-1]:
            if isinstance(owner, dict):
                owner = owner.get(key)
            else:
                owner = getattr(owner, key, None)
            if owner is None:
                break

        # Only pydantic models carry per-field metadata; anything else
        # (dict, None, scalar) cannot declare a field immutable.
        model_fields = getattr(owner, "model_fields", None)
        if model_fields is None:
            continue
        try:
            field_info = model_fields[keys[-1]]
        except (KeyError, TypeError):
            # Undeclared (extra) field: no schema metadata to consult.
            continue

        # json_schema_extra may be None or a callable; only a dict can be
        # unpacked into the schema.
        extra = field_info.json_schema_extra
        extra_field_schema = schema.ExtraFieldSchema(
            **(extra if isinstance(extra, dict) else {})
        )
        if extra_field_schema.immutable:
            key_path = ".".join(keys)
            raise ValueError(
                f'Attempting to change immutable field "{key_path}" ("{old}"->"{new}") in Nebari config file. Immutable fields cannot be changed after initial deployment.'
            )

def get_nebari_config_state(self):
    """Load the Nebari config stored in this stage's Terraform state.

    Returns:
        A config schema instance built from the ``input`` of the
        ``terraform_data.nebari_config`` resource, or ``None`` when the state
        lists no such resource.
    """
    stage_directory = str(self.output_directory / self.stage_prefix)
    state = terraform.show(stage_directory)
    root_resources = (
        state.get("values", {}).get("root_module", {}).get("resources", [])
    )

    # get nebari config from state
    for entry in root_resources:
        if entry["address"] != "terraform_data.nebari_config":
            continue
        # NOTE(review): function-scope import, presumably to avoid a circular
        # import at module load time; confirm before hoisting.
        from nebari.plugins import nebari_plugin_manager

        return nebari_plugin_manager.config_schema(**entry["values"]["input"])

    return None

@contextlib.contextmanager
def destroy(
self, stage_outputs: Dict[str, Dict[str, Any]], status: Dict[str, bool]
Expand Down
6 changes: 5 additions & 1 deletion src/_nebari/stages/tf_objects.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from _nebari.provider.terraform import Data, Provider, TerraformBackend
from _nebari.provider.terraform import Data, Provider, Resource, TerraformBackend
from _nebari.utils import (
AZURE_TF_STATE_RESOURCE_GROUP_SUFFIX,
construct_azure_resource_group_name,
Expand Down Expand Up @@ -115,3 +115,7 @@ def NebariTerraformState(directory: str, nebari_config: schema.Main):
)
else:
raise NotImplementedError("state not implemented")


def NebariConfig(nebari_config: schema.Main):
    """Wrap the dumped Nebari config in a ``terraform_data`` resource.

    The resource is named ``nebari_config`` and receives the result of
    ``nebari_config.model_dump()`` as its ``input``.
    """
    dumped_config = nebari_config.model_dump()
    return Resource("terraform_data", "nebari_config", input=dumped_config)
75 changes: 70 additions & 5 deletions src/_nebari/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import contextlib
import enum
import functools
import json
import os
import re
import secrets
Expand All @@ -11,7 +13,7 @@
import time
import warnings
from pathlib import Path
from typing import Dict, List, Set
from typing import Any, Dict, List, Set

from ruamel.yaml import YAML

Expand Down Expand Up @@ -44,7 +46,7 @@ def change_directory(directory):
os.chdir(current_directory)


def run_subprocess_cmd(processargs, **kwargs):
def run_subprocess_cmd(processargs, capture_output=False, **kwargs):
"""Runs subprocess command with realtime stdout logging with optional line prefix."""
if "prefix" in kwargs:
line_prefix = f"[{kwargs['prefix']}]: ".encode("utf-8")
Expand Down Expand Up @@ -78,6 +80,7 @@ def kill_process():
timeout_timer = threading.Timer(timeout, kill_process)
timeout_timer.start()

output = []
for line in iter(lambda: process.stdout.readline(), b""):
full_line = line_prefix + line
if strip_errors:
Expand All @@ -87,17 +90,25 @@ def kill_process():
) # Remove red ANSI escape code
full_line = full_line.encode("utf-8")

sys.stdout.buffer.write(full_line)
sys.stdout.flush()
if capture_output:
output.append(full_line)
else:
sys.stdout.buffer.write(full_line)
sys.stdout.flush()

if timeout_timer is not None:
timeout_timer.cancel()

process.stdout.close()
return process.wait(
exit_code = process.wait(
timeout=10
) # Should already have finished because we have drained stdout

if capture_output:
return exit_code, b"".join(output)
else:
return exit_code, None


def load_yaml(config_filename: Path):
"""
Expand Down Expand Up @@ -406,3 +417,57 @@ def byte_unit_conversion(byte_size_str: str, output_unit: str = "B") -> float:
)

return value * units_multiplier[input_unit] / units_multiplier[output_unit]


class JsonDiffEnum(str, enum.Enum):
    """Marker keys used inside a ``JsonDiff`` tree to tag a change."""

    ADDED = "+"  # key exists only in the second object
    REMOVED = "-"  # key exists only in the first object
    MODIFIED = "!"  # key exists in both objects with different values


class JsonDiff:
    """Structural diff between two JSON-like dictionaries.

    The result is stored in ``diff`` as a nested dict that mirrors the input
    structure. Each changed key maps to a one-entry dict keyed by a
    ``JsonDiffEnum`` marker.
    """

    def __init__(self, obj1: Dict[str, Any], obj2: Dict[str, Any]):
        self.diff = self.json_diff(obj1, obj2)

    @staticmethod
    def json_diff(obj1: Dict[str, Any], obj2: Dict[str, Any]) -> Dict[str, Any]:
        """Calculate the diff between two json-like objects.

        Keys are visited in a deterministic order: the keys of ``obj1`` first,
        then keys found only in ``obj2``. Nested dicts are diffed recursively;
        any other differing pair is reported as ``(old, new)``.

        Example::

            obj1 = {"a": 1, "b": {"c": 2, "d": 3}}
            obj2 = {"a": 1, "b": {"c": 2, "e": 4}, "f": 5}

            JsonDiff.json_diff(obj1, obj2)
            # {"b": {"d": {REMOVED: 3}, "e": {ADDED: 4}}, "f": {ADDED: 5}}
        """
        diff = {}
        # dict.fromkeys de-duplicates while keeping first-seen order; a set
        # union would make the result order depend on string hashing.
        for key in dict.fromkeys([*obj1, *obj2]):
            if key not in obj1:
                diff[key] = {JsonDiffEnum.ADDED: obj2[key]}
            elif key not in obj2:
                diff[key] = {JsonDiffEnum.REMOVED: obj1[key]}
            elif obj1[key] != obj2[key]:
                if isinstance(obj1[key], dict) and isinstance(obj2[key], dict):
                    nested_diff = JsonDiff.json_diff(obj1[key], obj2[key])
                    if nested_diff:
                        diff[key] = nested_diff
                else:
                    diff[key] = {JsonDiffEnum.MODIFIED: (obj1[key], obj2[key])}
        return diff

    @staticmethod
    def walk_dict(d, path, sentinel):
        """Yield ``(path, value)`` for every entry of ``d`` keyed by ``sentinel``.

        ``path`` is the list of keys leading to the dict that holds the
        sentinel entry. The sentinel is matched by identity, so a plain string
        equal to a marker value is treated as ordinary data.
        """
        for key, value in d.items():
            if key is sentinel:
                yield path, value
            elif isinstance(key, JsonDiffEnum):
                # Payload of a different marker is user data, not diff
                # structure; it cannot contain the sentinel.
                continue
            elif isinstance(value, dict):
                yield from JsonDiff.walk_dict(value, path + [key], sentinel)

    def modified(self):
        """Generator that yields the path, old value, and new value of changed items"""
        for path, (old, new) in self.walk_dict(self.diff, [], JsonDiffEnum.MODIFIED):
            yield path, old, new

    def __repr__(self):
        # default=str keeps repr() from raising on values json cannot encode
        # (sets, paths, ...); this output is for debugging only.
        return f"{self.__class__.__name__}(diff={json.dumps(self.diff, default=str)})"
18 changes: 16 additions & 2 deletions src/nebari/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,24 @@ def to_yaml(cls, representer, node):
return representer.represent_str(node.value)


class ExtraFieldSchema(Base):
    """Schema for the metadata a config field carries in ``json_schema_extra``.

    Keys other than the ones declared here are accepted (``extra="allow"``).
    """

    model_config = ConfigDict(
        populate_by_name=True,
        validate_assignment=True,
        extra="allow",
    )

    # Whether field supports being changed after initial deployment
    immutable: bool = False


class Main(Base):
project_name: project_name_pydantic
project_name: project_name_pydantic = Field(json_schema_extra={"immutable": True})
namespace: namespace_pydantic = "dev"
provider: ProviderEnum = ProviderEnum.local
provider: ProviderEnum = Field(
default=ProviderEnum.local,
json_schema_extra={"immutable": True},
)
# In nebari_version only use major.minor.patch version - drop any pre/post/dev suffixes
nebari_version: Annotated[str, Field(validate_default=True)] = __version__

Expand Down
73 changes: 73 additions & 0 deletions tests/tests_unit/test_stages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import pathlib
from unittest.mock import patch

import pytest

from _nebari.stages.terraform_state import TerraformStateStage
from _nebari.utils import yaml
from _nebari.version import __version__
from nebari import schema
from nebari.plugins import nebari_plugin_manager

HERE = pathlib.Path(__file__).parent


@pytest.fixture
def mock_config():
    """Validated config built from the ``local.happy.yaml`` fixture file.

    ``nebari_version`` is overwritten with the running package version
    before the data is validated against the plugin manager's schema.
    """
    fixture_path = HERE / "./cli_validate/local.happy.yaml"
    with open(fixture_path, "r") as handle:
        raw_config = yaml.load(handle)

    raw_config["nebari_version"] = __version__
    return nebari_plugin_manager.config_schema.model_validate(raw_config)


@pytest.fixture
def terraform_state_stage(mock_config, tmp_path):
    """``TerraformStateStage`` built from ``tmp_path`` and the mock config."""
    stage = TerraformStateStage(tmp_path, mock_config)
    return stage


@patch.object(TerraformStateStage, "get_nebari_config_state")
def test_check_immutable_fields_no_changes(mock_get_state, terraform_state_stage):
    """A prior state identical to the current config passes the check."""
    current_config = terraform_state_stage.config
    mock_get_state.return_value = current_config

    # Returning normally is the assertion: any exception fails the test.
    terraform_state_stage.check_immutable_fields()


@patch.object(TerraformStateStage, "get_nebari_config_state")
def test_check_immutable_fields_mutable_change(
    mock_get_state, terraform_state_stage, mock_config
):
    """Changing ``namespace`` between deployments is allowed."""
    prior_config = mock_config.model_copy()
    prior_config.namespace = "old-namespace"
    mock_get_state.return_value = prior_config

    # namespace is mutable, so the check must return without raising.
    terraform_state_stage.check_immutable_fields()


@patch.object(TerraformStateStage, "get_nebari_config_state")
@patch.object(schema.Main, "model_fields")
def test_check_immutable_fields_immutable_change(
    mock_model_fields, mock_get_state, terraform_state_stage, mock_config
):
    """Changing ``provider`` after deployment raises ``ValueError``."""
    prior_config = mock_config.model_copy()
    prior_config.provider = schema.ProviderEnum.gcp
    mock_get_state.return_value = prior_config

    # Mock the provider field to be immutable
    field_info = mock_model_fields.__getitem__.return_value
    field_info.json_schema_extra = {"immutable": True}

    with pytest.raises(ValueError) as exc_info:
        terraform_state_stage.check_immutable_fields()

    message = str(exc_info.value)
    assert 'Attempting to change immutable field "provider"' in message


@patch.object(TerraformStateStage, "get_nebari_config_state")
def test_check_immutable_fields_no_prior_state(mock_get_state, terraform_state_stage):
    """With no prior config in the state, the check is a no-op."""
    mock_get_state.return_value = None

    # Nothing to compare against, so no exception is expected.
    terraform_state_stage.check_immutable_fields()
Loading
Loading