Skip to content

Commit

Permalink
Added retry to ECS Operator (apache#14263)
Browse files Browse the repository at this point in the history
* Added retry to ECS Operator

* ...

* Remove airflow/www/yarn-error.log

* Update decorator to not accept any params

* ...

* ...

* ...

* lint

* Add predicate argument in retry decorator

* Add wraps and fixed test

* ...

* Remove unnecessary retry_if_permissible_error and fix lint errors

* Static check fixes

* Fix TestECSOperator.test_execute_with_failures
  • Loading branch information
markhopson authored Mar 26, 2021
1 parent a7f2cc2 commit 614be87
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 4 deletions.
29 changes: 29 additions & 0 deletions airflow/providers/amazon/aws/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# Note: Any AirflowException raised is expected to cause the TaskInstance
# to be marked in an ERROR state


class ECSOperatorError(Exception):
    """
    Raise when ECS cannot handle the request.

    :param failures: The failure entries reported by ECS for the request.
    :param message: Payload handed to ``Exception`` and kept on ``message``.
    """

    def __init__(self, failures: list, message: str):
        self.failures = failures
        self.message = message
        super().__init__(message)

    def __reduce__(self):
        """
        Describe how to rebuild this exception when it is pickled or copied.

        ``Exception`` would otherwise be re-created from ``self.args`` (which only
        holds ``message``), and the two-argument ``__init__`` would fail.
        """
        return type(self), (self.failures, self.message)
35 changes: 34 additions & 1 deletion airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
import configparser
import datetime
import logging
from typing import Any, Dict, Optional, Tuple, Union
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple, Union

import boto3
import botocore
import botocore.session
import tenacity
from botocore.config import Config
from botocore.credentials import ReadOnlyCredentials

Expand Down Expand Up @@ -488,6 +490,37 @@ def expand_role(self, role: str) -> str:
else:
return self.get_client_type("iam").get_role(RoleName=role)["Role"]["Arn"]

@staticmethod
def retry(should_retry: Callable[[Exception], bool]):
    """
    A decorator that provides a mechanism to repeat requests in response to exceeding a temporary quota
    limit.

    The decorated callable must be a method: its first argument (``self``) is inspected for
    a ``retry_args`` attribute. When that attribute is missing or ``None`` the method is
    called once without any retry handling. Otherwise ``retry_args`` is read as a dict with
    the optional keys ``multiplier`` (default 1), ``min`` (default 1), ``max`` (default 1)
    and ``stop_after_delay`` (default 10), which configure the tenacity wait and stop policies.

    :param should_retry: Predicate that receives the raised exception and returns ``True``
        when the call should be attempted again.
    :return: A decorator wrapping a method with the retry behaviour described above.
    """

    def retry_decorator(fun: Callable):
        @wraps(fun)
        def decorator_f(self, *args, **kwargs):
            retry_args = getattr(self, 'retry_args', None)
            if retry_args is None:
                # Retrying is opt-in: without configuration, behave like the plain method.
                return fun(self, *args, **kwargs)
            multiplier = retry_args.get('multiplier', 1)
            min_limit = retry_args.get('min', 1)
            max_limit = retry_args.get('max', 1)
            stop_after_delay = retry_args.get('stop_after_delay', 10)
            # Only attach logging hooks when the host object actually provides a logger.
            before_logger = tenacity.before_log(self.log, logging.DEBUG) if self.log else None
            after_logger = tenacity.after_log(self.log, logging.DEBUG) if self.log else None
            default_kwargs = {
                'wait': tenacity.wait_exponential(multiplier=multiplier, max=max_limit, min=min_limit),
                'retry': tenacity.retry_if_exception(should_retry),
                'stop': tenacity.stop_after_delay(stop_after_delay),
                'before': before_logger,
                'after': after_logger,
            }
            # Forward every caller-supplied argument so methods with parameters can be decorated.
            return tenacity.retry(**default_kwargs)(fun)(self, *args, **kwargs)

        return decorator_f

    return retry_decorator


def _parse_s3_config(
config_file_name: str, config_format: Optional[str] = "boto", profile: Optional[str] = None
Expand Down
19 changes: 18 additions & 1 deletion airflow/providers/amazon/aws/operators/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,24 @@

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.exceptions import ECSOperatorError
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.typing_compat import Protocol, runtime_checkable
from airflow.utils.decorators import apply_defaults


def should_retry(exception: Exception):
    """
    Check if exception is related to ECS resource quota (CPU, MEM).

    :param exception: The exception raised while starting the task.
    :return: ``True`` when ``exception`` is an ``ECSOperatorError`` and at least one of its
        failures has a reason containing ``RESOURCE:MEMORY`` or ``RESOURCE:CPU``;
        ``False`` otherwise.
    """
    if not isinstance(exception, ECSOperatorError):
        return False
    quota_reasons = ('RESOURCE:MEMORY', 'RESOURCE:CPU')
    for failure in exception.failures:
        # A failure entry without a (non-empty) reason is simply not quota related;
        # raising KeyError here would hide the original ECS error from the caller.
        reason = failure.get('reason') or ''
        if any(quota_reason in reason for quota_reason in quota_reasons):
            return True
    return False


@runtime_checkable
class ECSProtocol(Protocol):
"""
Expand Down Expand Up @@ -125,6 +137,8 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
:param reattach: If set to True, will check if a task from the same family is already running.
If so, the operator will attach to it instead of starting a new task.
:type reattach: bool
:param quota_retry: Config if and how to retry _start_task() for transient errors.
:type quota_retry: dict
"""

ui_color = '#f0ede4'
Expand All @@ -150,6 +164,7 @@ def __init__(
awslogs_region: Optional[str] = None,
awslogs_stream_prefix: Optional[str] = None,
propagate_tags: Optional[str] = None,
quota_retry: Optional[dict] = None,
reattach: bool = False,
**kwargs,
):
Expand Down Expand Up @@ -180,6 +195,7 @@ def __init__(
self.hook: Optional[AwsBaseHook] = None
self.client: Optional[ECSProtocol] = None
self.arn: Optional[str] = None
self.retry_args = quota_retry

def execute(self, context):
self.log.info(
Expand All @@ -206,6 +222,7 @@ def execute(self, context):

return None

@AwsBaseHook.retry(should_retry)
def _start_task(self):
run_opts = {
'cluster': self.cluster,
Expand Down Expand Up @@ -235,7 +252,7 @@ def _start_task(self):

failures = response['failures']
if len(failures) > 0:
raise AirflowException(response)
raise ECSOperatorError(failures, response)
self.log.info('ECS Task started: %s', response)

self.arn = response['tasks'][0]['taskArn']
Expand Down
80 changes: 80 additions & 0 deletions tests/providers/amazon/aws/hooks/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from unittest import mock

import boto3
import pytest

from airflow.models import Connection
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
Expand Down Expand Up @@ -266,3 +267,82 @@ def test_use_default_boto3_behaviour_without_conn_id(self):
hook = AwsBaseHook(aws_conn_id=conn_id, client_type='s3')
# should cause no exception
hook.get_client_type('s3')


class ThrowErrorUntilCount:
    """Callable stub that fails a fixed number of times before it succeeds."""

    def __init__(self, count, quota_retry, **kwargs):
        self.kwargs = kwargs
        # Looked up by the retry decorator on the object it receives as ``self``.
        self.retry_args = quota_retry
        self.log = None
        # ``counter`` tracks how many failures have been raised so far, out of ``count``.
        self.count = count
        self.counter = 0

    def __call__(self):
        """
        Raise a plain ``Exception`` on each call until ``count`` failures have happened.
        Every call after that returns True.
        """
        if self.counter >= self.count:
            return True
        self.counter += 1
        raise Exception()


def _always_true_predicate(e: Exception):  # pylint: disable=unused-argument
    """Retry predicate that accepts every exception."""
    return True


@AwsBaseHook.retry(_always_true_predicate)
def _retryable_test(thing):
    """Invoke ``thing`` under a retry policy whose predicate accepts every exception."""
    outcome = thing()
    return outcome


def _always_false_predicate(e: Exception):  # pylint: disable=unused-argument
    """Retry predicate that rejects every exception."""
    return False


@AwsBaseHook.retry(_always_false_predicate)
def _non_retryable_test(thing):
    """Invoke ``thing`` under a retry policy whose predicate rejects every exception."""
    outcome = thing()
    return outcome


class TestRetryDecorator(unittest.TestCase):  # pylint: disable=invalid-name
    """Tests for the ``AwsBaseHook.retry`` decorator."""

    def test_do_nothing_on_non_exception(self):
        """A callable that does not raise is invoked once and its value is returned."""
        result = _retryable_test(lambda: 42)
        # ``assert result, 42`` would only check truthiness (42 being the failure message).
        assert result == 42

    def test_retry_on_exception(self):
        """With retry settings and an accepting predicate, failures are retried until success."""
        quota_retry = {
            'stop_after_delay': 2,
            'multiplier': 1,
            'min': 1,
            'max': 10,
        }
        custom_fn = ThrowErrorUntilCount(
            count=2,
            quota_retry=quota_retry,
        )
        result = _retryable_test(custom_fn)
        assert custom_fn.counter == 2
        assert result

    def test_no_retry_on_exception(self):
        """With a rejecting predicate the exception propagates despite retry settings."""
        quota_retry = {
            'stop_after_delay': 2,
            'multiplier': 1,
            'min': 1,
            'max': 10,
        }
        custom_fn = ThrowErrorUntilCount(
            count=2,
            quota_retry=quota_retry,
        )
        with pytest.raises(Exception):
            _non_retryable_test(custom_fn)

    def test_raise_exception_when_no_retry_args(self):
        """Without retry settings the exception propagates even if the predicate accepts it."""
        custom_fn = ThrowErrorUntilCount(count=2, quota_retry=None)
        with pytest.raises(Exception):
            _retryable_test(custom_fn)
13 changes: 11 additions & 2 deletions tests/providers/amazon/aws/operators/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from parameterized import parameterized

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.ecs import ECSOperator
from airflow.providers.amazon.aws.exceptions import ECSOperatorError
from airflow.providers.amazon.aws.operators.ecs import ECSOperator, should_retry

# fmt: off
RESPONSE_WITHOUT_FAILURES = {
Expand Down Expand Up @@ -145,7 +146,7 @@ def test_execute_with_failures(self):
resp_failures['failures'].append('dummy error')
client_mock.run_task.return_value = resp_failures

with pytest.raises(AirflowException):
with pytest.raises(ECSOperatorError):
self.ecs.execute(None)

self.aws_hook_mock.return_value.get_conn.assert_called_once()
Expand Down Expand Up @@ -326,3 +327,11 @@ def test_execute_xcom_with_no_log(self, mock_cloudwatch_log_message):
def test_execute_xcom_disabled(self, mock_cloudwatch_log_message):
self.ecs.do_xcom_push = False
assert self.ecs.execute(None) is None


class TestShouldRetry(unittest.TestCase):
    """Tests for the ``should_retry`` predicate used by the ECS operator."""

    def test_return_true_on_valid_reason(self):
        """A failure whose reason is a memory quota problem is retryable."""
        quota_error = ECSOperatorError([{'reason': 'RESOURCE:MEMORY'}], 'Foo')
        self.assertTrue(should_retry(quota_error))

    def test_return_false_on_invalid_reason(self):
        """A failure whose reason is unrelated to quota is not retryable."""
        other_error = ECSOperatorError([{'reason': 'CLUSTER_NOT_FOUND'}], 'Foo')
        self.assertFalse(should_retry(other_error))

0 comments on commit 614be87

Please sign in to comment.