Skip to content

Commit

Permalink
Merge branch 'main' into azure-entra-ID-rbac
Browse files Browse the repository at this point in the history
  • Loading branch information
viniciusdc authored Jan 15, 2025
2 parents 81bb71b + 062529b commit 10bfa09
Show file tree
Hide file tree
Showing 25 changed files with 963 additions and 217 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ repos:
exclude: "^src/_nebari/template/"

- repo: /~https://github.com/crate-ci/typos
rev: typos-dict-v0.11.37
rev: dictgen-v0.3.1
hooks:
- id: typos

Expand All @@ -61,7 +61,7 @@ repos:
args: ["--line-length=88", "--exclude=/src/_nebari/template/"]

- repo: /~https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.1
rev: v0.8.6
hooks:
- id: ruff
args: ["--fix"]
Expand All @@ -77,7 +77,7 @@ repos:

# terraform
- repo: /~https://github.com/antonbabenko/pre-commit-terraform
rev: v1.96.2
rev: v1.96.3
hooks:
- id: terraform_fmt
args:
Expand Down
262 changes: 148 additions & 114 deletions docs-sphinx/cli.html

Large diffs are not rendered by default.

54 changes: 54 additions & 0 deletions src/_nebari/config_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import logging
import pathlib
from typing import Optional

from packaging.requirements import SpecifierSet
from pydantic import BaseModel, ConfigDict, field_validator

from _nebari._version import __version__
from _nebari.utils import yaml

logger = logging.getLogger(__name__)


class ConfigSetMetadata(BaseModel):
model_config: ConfigDict = ConfigDict(extra="allow", arbitrary_types_allowed=True)
name: str # for use with guided init
description: Optional[str] = None
nebari_version: str | SpecifierSet

@field_validator("nebari_version")
@classmethod
def validate_version_requirement(cls, version_req):
if isinstance(version_req, str):
version_req = SpecifierSet(version_req, prereleases=True)

return version_req

def check_version(self, version):
if not self.nebari_version.contains(version, prereleases=True):
raise ValueError(
f'Nebari version "{version}" is not compatible with '
f'version requirement {self.nebari_version} for "{self.name}" config set.'
)


class ConfigSet(BaseModel):
metadata: ConfigSetMetadata
config: dict


def read_config_set(config_set_filepath: str):
"""Read a config set from a config file."""

filename = pathlib.Path(config_set_filepath)

with filename.open() as f:
config_set_yaml = yaml.load(f)

config_set = ConfigSet(**config_set_yaml)

# validation
config_set.metadata.check_version(__version__)

return config_set
2 changes: 1 addition & 1 deletion src/_nebari/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
DEFAULT_NEBARI_IMAGE_TAG = CURRENT_RELEASE
DEFAULT_NEBARI_WORKFLOW_CONTROLLER_IMAGE_TAG = CURRENT_RELEASE

DEFAULT_CONDA_STORE_IMAGE_TAG = "2024.3.1"
DEFAULT_CONDA_STORE_IMAGE_TAG = "2024.11.2"

LATEST_SUPPORTED_PYTHON_VERSION = "3.10"

Expand Down
10 changes: 8 additions & 2 deletions src/_nebari/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pydantic
import requests

from _nebari import constants
from _nebari import constants, utils
from _nebari.config_set import read_config_set
from _nebari.provider import git
from _nebari.provider.cicd import github
from _nebari.provider.cloud import amazon_web_services, azure_cloud, google_cloud
Expand Down Expand Up @@ -47,6 +48,7 @@ def render_config(
region: str = None,
disable_prompt: bool = False,
ssl_cert_email: str = None,
config_set: str = None,
) -> Dict[str, Any]:
config = {
"provider": cloud_provider,
Expand Down Expand Up @@ -176,13 +178,17 @@ def render_config(
config["certificate"] = {"type": CertificateEnum.letsencrypt.value}
config["certificate"]["acme_email"] = ssl_cert_email

if config_set:
config_set = read_config_set(config_set)
config = utils.deep_merge(config, config_set.config)

# validate configuration and convert to model
from nebari.plugins import nebari_plugin_manager

try:
config_model = nebari_plugin_manager.config_schema.model_validate(config)
except pydantic.ValidationError as e:
print(str(e))
raise e

if repository_auto_provision:
match = re.search(github_url_regex, repository)
Expand Down
19 changes: 18 additions & 1 deletion src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import re
import sys
import tempfile
import warnings
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union

from pydantic import ConfigDict, Field, field_validator, model_validator
Expand Down Expand Up @@ -110,6 +111,7 @@ class AzureInputVars(schema.Base):
name: str
environment: str
region: str
authorized_ip_ranges: List[str] = ["0.0.0.0/0"]
kubeconfig_filename: str = get_kubeconfig_filename()
kubernetes_version: str
node_groups: Dict[str, AzureNodeGroupInputVars]
Expand Down Expand Up @@ -378,6 +380,7 @@ class AzureProvider(schema.Base):
region: str
kubernetes_version: Optional[str] = None
storage_account_postfix: str
authorized_ip_ranges: Optional[List[str]] = ["0.0.0.0/0"]
resource_group_name: Optional[str] = None
node_groups: Dict[str, AzureNodeGroup] = DEFAULT_AZURE_NODE_GROUPS
storage_account_postfix: str
Expand Down Expand Up @@ -633,11 +636,23 @@ def check_provider(cls, data: Any) -> Any:
data[provider] = provider_enum_model_map[provider]()
else:
# if the provider field is invalid, it won't be set when this validator is called
# so we need to check for it explicitly here, and set the `pre` to True
# so we need to check for it explicitly here, and set mode to "before"
# TODO: this is a workaround, check if there is a better way to do this in Pydantic v2
raise ValueError(
f"'{provider}' is not a valid enumeration member; permitted: local, existing, aws, gcp, azure"
)
set_providers = {
provider
for provider in provider_name_abbreviation_map.keys()
if provider in data and data[provider]
}
expected_provider_config = provider_enum_name_map[provider]
extra_provider_config = set_providers - {expected_provider_config}
if extra_provider_config:
warnings.warn(
f"Provider is set to {getattr(provider, 'value', provider)}, but configuration defined for other providers: {extra_provider_config}"
)

else:
set_providers = [
provider
Expand All @@ -651,6 +666,7 @@ def check_provider(cls, data: Any) -> Any:
data["provider"] = provider_name_abbreviation_map[set_providers[0]]
elif num_providers == 0:
data["provider"] = schema.ProviderEnum.local.value

return data


Expand Down Expand Up @@ -804,6 +820,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
environment=self.config.namespace,
region=self.config.azure.region,
kubernetes_version=self.config.azure.kubernetes_version,
authorized_ip_ranges=self.config.azure.authorized_ip_ranges,
node_groups={
name: AzureNodeGroupInputVars(
instance=node_group.instance,
Expand Down
1 change: 1 addition & 0 deletions src/_nebari/stages/infrastructure/template/azure/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ module "kubernetes" {
kubernetes_version = var.kubernetes_version
tags = var.tags
max_pods = var.max_pods
authorized_ip_ranges = var.authorized_ip_ranges

network_profile = var.network_profile

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ resource "azurerm_kubernetes_cluster" "main" {
location = var.location
resource_group_name = var.resource_group_name
tags = var.tags
api_server_access_profile {
authorized_ip_ranges = var.authorized_ip_ranges
}

# To enable Azure AD Workload Identity oidc_issuer_enabled must be set to true.
oidc_issuer_enabled = var.workload_identity_enabled
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ variable "aad_access_control" {
nullable = false
}

variable "authorized_ip_ranges" {
description = "The ip range allowed to access the Kubernetes API server, defaults to 0.0.0.0/0"
type = list(string)
default = ["0.0.0.0/0"]
}

variable "azure_policy_enabled" {
description = "Enable Azure Policy"
type = bool
Expand Down
6 changes: 6 additions & 0 deletions src/_nebari/stages/infrastructure/template/azure/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ variable "aad_access_control" {
}
}

variable "authorized_ip_ranges" {
description = "The ip range allowed to access the Kubernetes API server, defaults to 0.0.0.0/0"
type = list(string)
default = ["0.0.0.0/0"]
}

variable "azure_policy_enabled" {
description = "Enable Azure Policy"
type = bool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from pathlib import Path

import requests
from conda_store_server import api, orm, schema
from conda_store_server import api
from conda_store_server._internal import schema
from conda_store_server._internal.server.dependencies import get_conda_store
from conda_store_server.server.auth import GenericOAuthAuthentication
from conda_store_server.server.dependencies import get_conda_store
from conda_store_server.storage import S3Storage


Expand Down Expand Up @@ -422,8 +423,7 @@ async def authenticate(self, request):
for namespace in namespaces:
_namespace = api.get_namespace(db, name=namespace)
if _namespace is None:
db.add(orm.Namespace(name=namespace))
db.commit()
api.ensure_namespace(db, name=namespace)

return schema.AuthenticationToken(
primary_namespace=username,
Expand Down
15 changes: 13 additions & 2 deletions src/_nebari/subcommands/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@

@hookimpl
def nebari_subcommand(cli: typer.Typer):
EXTERNAL_PLUGIN_STYLE = "cyan"

@cli.command()
def info(ctx: typer.Context):
from nebari.plugins import nebari_plugin_manager

rich.print(f"Nebari version: {__version__}")

external_plugins = nebari_plugin_manager.get_external_plugins()

hooks = collections.defaultdict(list)
for plugin in nebari_plugin_manager.plugin_manager.get_plugins():
for hook in nebari_plugin_manager.plugin_manager.get_hookcallers(plugin):
Expand All @@ -27,7 +31,8 @@ def info(ctx: typer.Context):

for hook_name, modules in hooks.items():
for module in modules:
table.add_row(hook_name, module)
style = EXTERNAL_PLUGIN_STYLE if module in external_plugins else None
table.add_row(hook_name, module, style=style)

rich.print(table)

Expand All @@ -36,8 +41,14 @@ def info(ctx: typer.Context):
table.add_column("priority")
table.add_column("module")
for stage in nebari_plugin_manager.ordered_stages:
style = (
EXTERNAL_PLUGIN_STYLE if stage.__module__ in external_plugins else None
)
table.add_row(
stage.name, str(stage.priority), f"{stage.__module__}.{stage.__name__}"
stage.name,
str(stage.priority),
f"{stage.__module__}.{stage.__name__}",
style=style,
)

rich.print(table)
9 changes: 9 additions & 0 deletions src/_nebari/subcommands/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class InitInputs(schema.Base):
region: Optional[str] = None
ssl_cert_email: Optional[schema.email_pydantic] = None
disable_prompt: bool = False
config_set: Optional[str] = None
output: pathlib.Path = pathlib.Path("nebari-config.yaml")
explicit: int = 0

Expand Down Expand Up @@ -134,6 +135,7 @@ def handle_init(inputs: InitInputs, config_schema: BaseModel):
terraform_state=inputs.terraform_state,
ssl_cert_email=inputs.ssl_cert_email,
disable_prompt=inputs.disable_prompt,
config_set=inputs.config_set,
)

try:
Expand Down Expand Up @@ -496,6 +498,12 @@ def init(
False,
is_eager=True,
),
config_set: str = typer.Option(
None,
"--config-set",
"-s",
help="Apply a pre-defined set of nebari configuration options.",
),
output: str = typer.Option(
pathlib.Path("nebari-config.yaml"),
"--output",
Expand Down Expand Up @@ -554,6 +562,7 @@ def init(
inputs.terraform_state = terraform_state
inputs.ssl_cert_email = ssl_cert_email
inputs.disable_prompt = disable_prompt
inputs.config_set = config_set
inputs.output = output
inputs.explicit = explicit

Expand Down
42 changes: 42 additions & 0 deletions src/_nebari/subcommands/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from importlib.metadata import version

import rich
import typer
from rich.table import Table

from nebari.hookspecs import hookimpl


@hookimpl
def nebari_subcommand(cli: typer.Typer):
plugin_cmd = typer.Typer(
add_completion=False,
no_args_is_help=True,
rich_markup_mode="rich",
context_settings={"help_option_names": ["-h", "--help"]},
)

cli.add_typer(
plugin_cmd,
name="plugin",
help="Interact with nebari plugins",
rich_help_panel="Additional Commands",
)

@plugin_cmd.command()
def list(ctx: typer.Context):
"""
List installed plugins
"""
from nebari.plugins import nebari_plugin_manager

external_plugins = nebari_plugin_manager.get_external_plugins()

table = Table(title="Plugins")
table.add_column("name", justify="left", no_wrap=True)
table.add_column("version", justify="left", no_wrap=True)

for plugin in external_plugins:
table.add_row(plugin, version(plugin))

rich.print(table)
Loading

0 comments on commit 10bfa09

Please sign in to comment.