Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support loading default_args from shared defaults.yml #330

Merged
merged 14 commits into from
Jan 3, 2025
30 changes: 26 additions & 4 deletions dagfactory/dagfactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,27 @@ class DagFactory:
:type config: dict
"""

def __init__(self, config_filepath: Optional[str] = None, config: Optional[dict] = None) -> None:
def __init__(
    self,
    config_filepath: Optional[str] = None,
    config: Optional[dict] = None,
    default_args_config_path: str = airflow_conf.get("core", "dags_folder"),
) -> None:
    """
    Create a factory from exactly one of a YAML file path or an in-memory config.

    :param config_filepath: path of the YAML config file to validate and load
    :param config: already-loaded DAG configuration
    :param default_args_config_path: folder that is searched for a shared ``defaults.yml``
    :raises AssertionError: if both or neither of ``config_filepath`` and ``config`` are given
    """
    # NOTE: the default for `default_args_config_path` is evaluated once, when the
    # function is defined, not on every call.
    # Raise explicitly instead of using `assert`, which is stripped under `python -O`;
    # the exception type is kept so existing callers are unaffected.
    if bool(config_filepath) == bool(config):
        raise AssertionError("Either `config_filepath` or `config` should be provided")

    self.default_args_config_path = default_args_config_path
    if config_filepath:
        DagFactory._validate_config_filepath(config_filepath=config_filepath)
        self.config: Dict[str, Any] = DagFactory._load_config(config_filepath=config_filepath)
    else:
        # The check above guarantees `config` is the truthy one on this branch.
        self.config = config

def _global_default_args(self) -> Optional[Dict[str, Any]]:
    """
    Load the shared ``defaults.yml`` located in ``self.default_args_config_path``.

    :returns: the parsed YAML content, or ``None`` when the file does not exist
    """
    default_args_yml = Path(self.default_args_config_path) / "defaults.yml"

    if not default_args_yml.exists():
        return None

    # Decode explicitly as UTF-8 so parsing does not depend on the platform locale.
    with open(default_args_yml, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)

@staticmethod
def _serialise_config_md(dag_name, dag_config, default_config):
# Remove empty task_groups if it exists
Expand Down Expand Up @@ -111,8 +124,15 @@ def get_default_config(self) -> Dict[str, Any]:
def build_dags(self) -> Dict[str, DAG]:
"""Build DAGs using the config file."""
dag_configs: Dict[str, Dict[str, Any]] = self.get_dag_configs()
global_default_args = self._global_default_args()
default_config: Dict[str, Any] = self.get_default_config()

if global_default_args is not None:
if "default_args" in default_config and "default_args" in global_default_args:
default_config = {
"default_args": {**global_default_args["default_args"], **default_config["default_args"]}
}

dags: Dict[str, Any] = {}

for dag_name, dag_config in dag_configs.items():
Expand Down Expand Up @@ -179,6 +199,7 @@ def clean_dags(self, globals: Dict[str, Any]) -> None:
def load_yaml_dags(
globals_dict: Dict[str, Any],
dags_folder: str = airflow_conf.get("core", "dags_folder"),
default_args_config_path: str = airflow_conf.get("core", "dags_folder"),
suffix=None,
):
"""
Expand All @@ -189,8 +210,9 @@ def load_yaml_dags(
interesting to load only a subset by setting a different suffix.

:param globals_dict: The globals() from the file used to generate DAGs
:dags_folder: Path to the folder you want to get recursively scanned
:suffix: file suffix to filter `in` what files to scan for dags
:param dags_folder: Path to the folder you want to get recursively scanned
:param default_args_config_path: Path to the folder that contains defaults.yml
:param suffix: file suffix to filter `in` what files to scan for dags
"""
# chain all file suffixes in a single iterator
logging.info("Loading DAGs from %s", dags_folder)
Expand All @@ -203,7 +225,7 @@ def load_yaml_dags(
config_file_abs_path = str(config_file_path.absolute())
logging.info("Loading %s", config_file_abs_path)
try:
factory = DagFactory(config_file_abs_path)
factory = DagFactory(config_file_abs_path, default_args_config_path=default_args_config_path)
factory.generate_dags(globals_dict)
except Exception: # pylint: disable=broad-except
logging.exception("Failed to load dag from %s", config_file_path)
Expand Down
3 changes: 3 additions & 0 deletions dev/dags/defaults.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
default_args:
start_date: "2025-01-01"
owner: "global_owner"
4 changes: 4 additions & 0 deletions tests/fixtures/defaults.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
default_args:
start_date: "2025-01-01"
owner: "global_owner"
depends_on_past: true
9 changes: 9 additions & 0 deletions tests/test_dagfactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
TEST_DAG_FACTORY = os.path.join(here, "fixtures/dag_factory.yml")
INVALID_YAML = os.path.join(here, "fixtures/invalid_yaml.yml")
INVALID_DAG_FACTORY = os.path.join(here, "fixtures/invalid_dag_factory.yml")
DEFAULT_ARGS_CONFIG_ROOT = os.path.join(here, "fixtures/")
DAG_FACTORY_KUBERNETES_POD_OPERATOR = os.path.join(here, "fixtures/dag_factory_kubernetes_pod_operator.yml")
DAG_FACTORY_VARIABLES_AS_ARGUMENTS = os.path.join(here, "fixtures/dag_factory_variables_as_arguments.yml")

Expand Down Expand Up @@ -448,6 +449,14 @@ def test_set_callback_after_loading_config():
td.generate_dags(globals())


def test_build_dag_with_global_default():
    """Check that DAGs built with a ``default_args_config_path`` get ``depends_on_past`` set on their tasks."""
    factory = dagfactory.DagFactory(
        config=DAG_FACTORY_CONFIG, default_args_config_path=DEFAULT_ARGS_CONFIG_ROOT
    )
    dags = factory.build_dags()

    # Index directly so a missing DAG fails with a KeyError naming it,
    # rather than an AttributeError on None.
    first_task = dags["example_dag"].tasks[0]

    # NOTE(review): presumably `depends_on_past` comes only from the fixture
    # defaults.yml and not from DAG_FACTORY_CONFIG -- confirm against the fixture.
    assert first_task.depends_on_past is True


def test_load_invalid_yaml_logs_error(caplog):
caplog.set_level(logging.ERROR)
load_yaml_dags(
Expand Down
Loading