diff --git a/airflow/www/views.py b/airflow/www/views.py index a1e61721f28ae..907c103fad509 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -22,6 +22,7 @@ import json import logging import math +import re import socket import sys import traceback @@ -78,6 +79,7 @@ from pygments import highlight, lexers from pygments.formatters import HtmlFormatter # noqa pylint: disable=no-name-in-module from sqlalchemy import Date, and_, desc, func, or_, union_all +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import joinedload from wtforms import SelectField, validators from wtforms.validators import InputRequired @@ -3124,6 +3126,7 @@ class ConnectionModelView(AirflowModelView): 'edit': 'edit', 'delete': 'delete', 'action_muldelete': 'delete', + 'action_mulduplicate': 'create', } base_permissions = [ @@ -3177,6 +3180,56 @@ def action_muldelete(self, items): self.update_redirect() return redirect(self.get_redirect()) + @action( + 'mulduplicate', + 'Duplicate', + 'Are you sure you want to duplicate the selected connections?', + single=False, + ) + @provide_session + @auth.has_access( + [ + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION), + ] + ) + def action_mulduplicate(self, connections, session=None): + """Duplicate Multiple connections""" + for selected_conn in connections: + new_conn_id = selected_conn.conn_id + match = re.search(r"_copy(\d+)$", selected_conn.conn_id) + if match: + conn_id_prefix = selected_conn.conn_id[: match.start()] + new_conn_id = f"{conn_id_prefix}_copy{int(match.group(1)) + 1}" + else: + new_conn_id += '_copy1' + + dup_conn = Connection( + new_conn_id, + selected_conn.conn_type, + selected_conn.description, + selected_conn.host, + selected_conn.login, + selected_conn.password, + selected_conn.schema, + selected_conn.port, + selected_conn.extra, + ) + + try: + session.add(dup_conn) + session.commit() + flash(f"Connection {new_conn_id} added successfully.", "success") + except IntegrityError: + flash( + f"Connection {new_conn_id} can't be added. Integrity error, probably unique constraint.", + "warning", + ) + session.rollback() + + self.update_redirect() + return redirect(self.get_redirect()) + def process_form(self, form, is_created): """Process form data.""" conn_type = form.data['conn_type'] diff --git a/tests/www/views/test_views_connection.py b/tests/www/views/test_views_connection.py index f22645e42e012..557697729a1be 100644 --- a/tests/www/views/test_views_connection.py +++ b/tests/www/views/test_views_connection.py @@ -54,3 +54,45 @@ def test_prefill_form_null_extra(): cmv = ConnectionModelView() cmv.prefill_form(form=mock_form, pk=1) + + +def test_duplicate_connection(admin_client): + """Test Duplicate multiple connection with suffix""" + conn1 = Connection( + conn_id='test_duplicate_gcp_connection', + conn_type='Google Cloud', + description='Google Cloud Connection', + ) + conn2 = Connection( + conn_id='test_duplicate_mysql_connection', + conn_type='FTP', + description='MongoDB2', + host='localhost', + schema='airflow', + port=3306, + ) + conn3 = Connection( + conn_id='test_duplicate_postgres_connection_copy1', + conn_type='FTP', + description='Postgres', + host='localhost', + schema='airflow', + port=3306, + ) + with create_session() as session: + session.query(Connection).delete() + session.add_all([conn1, conn2, conn3]) + session.commit() + + data = {"action": "mulduplicate", "rowid": [conn1.id, conn3.id]} + resp = admin_client.post('/connection/action_post', data=data, follow_redirects=True) + expected_result = { + 'test_duplicate_gcp_connection', + 'test_duplicate_gcp_connection_copy1', + 'test_duplicate_mysql_connection', + 'test_duplicate_postgres_connection_copy1', + 'test_duplicate_postgres_connection_copy2', + } + response = {conn[0] for conn in session.query(Connection.conn_id).all()} + assert resp.status_code == 200 + assert expected_result == response