diff --git a/src/_nebari/stages/infrastructure/__init__.py b/src/_nebari/stages/infrastructure/__init__.py index e046712334..1e34cb05ef 100644 --- a/src/_nebari/stages/infrastructure/__init__.py +++ b/src/_nebari/stages/infrastructure/__init__.py @@ -5,8 +5,7 @@ import re import sys import tempfile -import typing -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type, Union import pydantic @@ -52,9 +51,9 @@ class DigitalOceanInputVars(schema.Base): name: str environment: str region: str - tags: typing.List[str] + tags: List[str] kubernetes_version: str - node_groups: typing.Dict[str, DigitalOceanNodeGroup] + node_groups: Dict[str, DigitalOceanNodeGroup] kubeconfig_filename: str = get_kubeconfig_filename() @@ -143,6 +142,7 @@ class AWSInputVars(schema.Base): vpc_cidr_block: str permissions_boundary: Optional[str] = None kubeconfig_filename: str = get_kubeconfig_filename() + tags: Dict[str, str] = {} def _calculate_node_groups(config: schema.Main): @@ -216,7 +216,7 @@ class DigitalOceanProvider(schema.Base): region: str kubernetes_version: str # Digital Ocean image slugs are listed here https://slugs.do-api.dev/ - node_groups: typing.Dict[str, DigitalOceanNodeGroup] = { + node_groups: Dict[str, DigitalOceanNodeGroup] = { "general": DigitalOceanNodeGroup( instance="g-8vcpu-32gb", min_nodes=1, max_nodes=1 ), @@ -227,7 +227,7 @@ class DigitalOceanProvider(schema.Base): instance="g-4vcpu-16gb", min_nodes=1, max_nodes=5 ), } - tags: typing.Optional[typing.List[str]] = [] + tags: Optional[List[str]] = [] @pydantic.validator("region") def _validate_region(cls, value): @@ -289,7 +289,7 @@ class GCPCIDRBlock(schema.Base): class GCPMasterAuthorizedNetworksConfig(schema.Base): - cidr_blocks: typing.List[GCPCIDRBlock] + cidr_blocks: List[GCPCIDRBlock] class GCPPrivateClusterConfig(schema.Base): @@ -314,34 +314,28 @@ class GCPNodeGroup(schema.Base): min_nodes: pydantic.conint(ge=0) = 0 max_nodes: pydantic.conint(ge=1) = 1 preemptible: bool = False - labels: typing.Dict[str, str] = {} - guest_accelerators: typing.List[GCPGuestAccelerator] = [] + labels: Dict[str, str] = {} + guest_accelerators: List[GCPGuestAccelerator] = [] class GoogleCloudPlatformProvider(schema.Base): region: str project: str kubernetes_version: str - availability_zones: typing.Optional[typing.List[str]] = [] + availability_zones: Optional[List[str]] = [] release_channel: str = constants.DEFAULT_GKE_RELEASE_CHANNEL - node_groups: typing.Dict[str, GCPNodeGroup] = { + node_groups: Dict[str, GCPNodeGroup] = { "general": GCPNodeGroup(instance="n1-standard-8", min_nodes=1, max_nodes=1), "user": GCPNodeGroup(instance="n1-standard-4", min_nodes=0, max_nodes=5), "worker": GCPNodeGroup(instance="n1-standard-4", min_nodes=0, max_nodes=5), } - tags: typing.Optional[typing.List[str]] = [] + tags: Optional[List[str]] = [] networking_mode: str = "ROUTE" network: str = "default" - subnetwork: typing.Optional[typing.Union[str, None]] = None - ip_allocation_policy: typing.Optional[ - typing.Union[GCPIPAllocationPolicy, None] - ] = None - master_authorized_networks_config: typing.Optional[ - typing.Union[GCPCIDRBlock, None] - ] = None - private_cluster_config: typing.Optional[ - typing.Union[GCPPrivateClusterConfig, None] - ] = None + subnetwork: Optional[Union[str, None]] = None + ip_allocation_policy: Optional[Union[GCPIPAllocationPolicy, None]] = None + master_authorized_networks_config: Optional[Union[GCPCIDRBlock, None]] = None + private_cluster_config: Optional[Union[GCPPrivateClusterConfig, None]] = None @pydantic.root_validator def validate_all(cls, values): @@ -381,18 +375,18 @@ class AzureProvider(schema.Base): kubernetes_version: str storage_account_postfix: str resource_group_name: str = None - node_groups: typing.Dict[str, AzureNodeGroup] = { + node_groups: Dict[str, AzureNodeGroup] = { "general": AzureNodeGroup(instance="Standard_D8_v3", min_nodes=1, max_nodes=1), "user": AzureNodeGroup(instance="Standard_D4_v3", min_nodes=0, max_nodes=5), "worker": AzureNodeGroup(instance="Standard_D4_v3", min_nodes=0, max_nodes=5), } storage_account_postfix: str - vnet_subnet_id: typing.Optional[typing.Union[str, None]] = None + vnet_subnet_id: Optional[Union[str, None]] = None private_cluster_enabled: bool = False - resource_group_name: typing.Optional[str] = None - tags: typing.Optional[typing.Dict[str, str]] = {} - network_profile: typing.Optional[typing.Dict[str, str]] = None - max_pods: typing.Optional[int] = None + resource_group_name: Optional[str] = None + tags: Optional[Dict[str, str]] = {} + network_profile: Optional[Dict[str, str]] = None + max_pods: Optional[int] = None @pydantic.validator("kubernetes_version") def _validate_kubernetes_version(cls, value): @@ -440,8 +434,8 @@ class AWSNodeGroup(schema.Base): class AmazonWebServicesProvider(schema.Base): region: str kubernetes_version: str - availability_zones: typing.Optional[typing.List[str]] - node_groups: typing.Dict[str, AWSNodeGroup] = { + availability_zones: Optional[List[str]] + node_groups: Dict[str, AWSNodeGroup] = { "general": AWSNodeGroup(instance="m5.2xlarge", min_nodes=1, max_nodes=1), "user": AWSNodeGroup( instance="m5.xlarge", min_nodes=1, max_nodes=5, single_subnet=False @@ -450,10 +444,11 @@ class AmazonWebServicesProvider(schema.Base): instance="m5.xlarge", min_nodes=1, max_nodes=5, single_subnet=False ), } - existing_subnet_ids: typing.List[str] = None - existing_security_group_ids: str = None + existing_subnet_ids: List[str] = None + existing_security_group_id: str = None vpc_cidr_block: str = "10.10.0.0/16" permissions_boundary: Optional[str] = None + tags: Optional[Dict[str, str]] = {} @pydantic.root_validator def validate_all(cls, values): @@ -491,8 +486,8 @@ def validate_all(cls, values): class LocalProvider(schema.Base): - kube_context: typing.Optional[str] - node_selectors: typing.Dict[str, KeyValueDict] = { + kube_context: Optional[str] + node_selectors: Dict[str, KeyValueDict] = { "general": KeyValueDict(key="kubernetes.io/os", value="linux"), "user": KeyValueDict(key="kubernetes.io/os", value="linux"), "worker": KeyValueDict(key="kubernetes.io/os", value="linux"), @@ -500,8 +495,8 @@ class LocalProvider(schema.Base): class ExistingProvider(schema.Base): - kube_context: typing.Optional[str] - node_selectors: typing.Dict[str, KeyValueDict] = { + kube_context: Optional[str] + node_selectors: Dict[str, KeyValueDict] = { "general": KeyValueDict(key="kubernetes.io/os", value="linux"), "user": KeyValueDict(key="kubernetes.io/os", value="linux"), "worker": KeyValueDict(key="kubernetes.io/os", value="linux"), @@ -532,12 +527,12 @@ class ExistingProvider(schema.Base): class InputSchema(schema.Base): - local: typing.Optional[LocalProvider] - existing: typing.Optional[ExistingProvider] - google_cloud_platform: typing.Optional[GoogleCloudPlatformProvider] - amazon_web_services: typing.Optional[AmazonWebServicesProvider] - azure: typing.Optional[AzureProvider] - digital_ocean: typing.Optional[DigitalOceanProvider] + local: Optional[LocalProvider] + existing: Optional[ExistingProvider] + google_cloud_platform: Optional[GoogleCloudPlatformProvider] + amazon_web_services: Optional[AmazonWebServicesProvider] + azure: Optional[AzureProvider] + digital_ocean: Optional[DigitalOceanProvider] @pydantic.root_validator(pre=True) def check_provider(cls, values): @@ -580,20 +575,20 @@ class NodeSelectorKeyValue(schema.Base): class KubernetesCredentials(schema.Base): host: str cluster_ca_certifiate: str - token: typing.Optional[str] - username: typing.Optional[str] - password: typing.Optional[str] - client_certificate: typing.Optional[str] - client_key: typing.Optional[str] - config_path: typing.Optional[str] - config_context: typing.Optional[str] + token: Optional[str] + username: Optional[str] + password: Optional[str] + client_certificate: Optional[str] + client_key: Optional[str] + config_path: Optional[str] + config_context: Optional[str] class OutputSchema(schema.Base): node_selectors: Dict[str, NodeSelectorKeyValue] kubernetes_credentials: KubernetesCredentials kubeconfig_filename: str - nfs_endpoint: typing.Optional[str] + nfs_endpoint: Optional[str] class KubernetesInfrastructureStage(NebariTerraformStage): @@ -760,7 +755,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): name=self.config.escaped_project_name, environment=self.config.namespace, existing_subnet_ids=self.config.amazon_web_services.existing_subnet_ids, - existing_security_group_id=self.config.amazon_web_services.existing_security_group_ids, + existing_security_group_id=self.config.amazon_web_services.existing_security_group_id, region=self.config.amazon_web_services.region, kubernetes_version=self.config.amazon_web_services.kubernetes_version, node_groups=[ @@ -779,6 +774,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): availability_zones=self.config.amazon_web_services.availability_zones, vpc_cidr_block=self.config.amazon_web_services.vpc_cidr_block, permissions_boundary=self.config.amazon_web_services.permissions_boundary, + tags=self.config.amazon_web_services.tags, ).dict() else: raise ValueError(f"Unknown provider: {self.config.provider}") diff --git a/src/_nebari/stages/infrastructure/template/aws/locals.tf b/src/_nebari/stages/infrastructure/template/aws/locals.tf index d2a065dd75..c414a4b5a0 100644 --- a/src/_nebari/stages/infrastructure/template/aws/locals.tf +++ b/src/_nebari/stages/infrastructure/template/aws/locals.tf @@ -1,9 +1,11 @@ locals { - additional_tags = { - Project = var.name - Owner = "terraform" - Environment = var.environment - } - + additional_tags = merge( + { + Project = var.name + Owner = "terraform" + Environment = var.environment + }, + var.tags, + ) cluster_name = "${var.name}-${var.environment}" } diff --git a/src/_nebari/stages/infrastructure/template/aws/variables.tf b/src/_nebari/stages/infrastructure/template/aws/variables.tf index 0510dec5ac..c07c8f60f2 100644 --- a/src/_nebari/stages/infrastructure/template/aws/variables.tf +++ b/src/_nebari/stages/infrastructure/template/aws/variables.tf @@ -71,3 +71,9 @@ variable "permissions_boundary" { type = string default = null } + +variable "tags" { + description = "Additional tags to add to resources" + type = map(string) + default = {} +}