Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: invalidate cache on bad connection info and IP lookup #1118

Merged
merged 2 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion google/cloud/sql/connector/connection_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,5 @@ def get_preferred_ip(self, ip_type: IPTypes) -> str:
return self.ip_addrs[ip_type.value]
raise CloudSQLIPTypeError(
"Cloud SQL instance does not have any IP addresses matching "
f"preference: {ip_type.value})"
f"preference: {ip_type.value}"
)
72 changes: 44 additions & 28 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,41 +340,46 @@ async def connect_async(
kwargs.pop("ssl", None)
kwargs.pop("port", None)

# attempt to make connection to Cloud SQL instance
# attempt to get connection info for Cloud SQL instance
try:
conn_info = await cache.connect_info()
# validate driver matches intended database engine
DriverMapping.validate_engine(driver, conn_info.database_version)
ip_address = conn_info.get_preferred_ip(ip_type)
# resolve DNS name into IP address for PSC
if ip_type.value == "PSC":
addr_info = await self._loop.getaddrinfo(
ip_address, None, family=socket.AF_INET, type=socket.SOCK_STREAM
)
# getaddrinfo returns a list of 5-tuples that contain socket
# connection info in the form
# (family, type, proto, canonname, sockaddr), where sockaddr is a
# 2-tuple in the form (ip_address, port)
try:
ip_address = addr_info[0][4][0]
except IndexError as e:
raise DnsNameResolutionError(
f"['{instance_connection_string}']: DNS name could not be resolved into IP address"
) from e
logger.debug(
f"['{instance_connection_string}']: Connecting to {ip_address}:3307"
except Exception:
# with an error from Cloud SQL Admin API call or IP type, invalidate
# the cache and re-raise the error
await self._remove_cached(instance_connection_string)
raise
# resolve DNS name into IP address for PSC
if ip_type.value == "PSC":
addr_info = await self._loop.getaddrinfo(
ip_address, None, family=socket.AF_INET, type=socket.SOCK_STREAM
)
# getaddrinfo returns a list of 5-tuples that contain socket
# connection info in the form
# (family, type, proto, canonname, sockaddr), where sockaddr is a
# 2-tuple in the form (ip_address, port)
try:
ip_address = addr_info[0][4][0]
except IndexError as e:
raise DnsNameResolutionError(
f"['{instance_connection_string}']: DNS name could not be resolved into IP address"
) from e
logger.debug(
f"['{instance_connection_string}']: Connecting to {ip_address}:3307"
)
# format `user` param for automatic IAM database authn
if enable_iam_auth:
formatted_user = format_database_user(
conn_info.database_version, kwargs["user"]
)
# format `user` param for automatic IAM database authn
if enable_iam_auth:
formatted_user = format_database_user(
conn_info.database_version, kwargs["user"]
if formatted_user != kwargs["user"]:
logger.debug(
f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}"
)
if formatted_user != kwargs["user"]:
logger.debug(
f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}"
)
kwargs["user"] = formatted_user

kwargs["user"] = formatted_user
try:
# async drivers are unblocking and can be awaited directly
if driver in ASYNC_DRIVERS:
return await connector(
Expand All @@ -396,6 +401,17 @@ async def connect_async(
await cache.force_refresh()
raise

async def _remove_cached(self, instance_connection_string: str) -> None:
    """Stop all background refreshes and delete the connection
    info cache from the map of caches.

    Args:
        instance_connection_string (str): The instance connection string
            of the Cloud SQL instance whose cache entry should be evicted.
    """
    logger.debug(
        f"['{instance_connection_string}']: Removing connection info from cache"
    )
    # pop with a default so a repeated invalidation (e.g. two concurrent
    # failing connects evicting the same entry) is a no-op instead of a
    # KeyError that would mask the original connection error
    cache = self._cache.pop(instance_connection_string, None)
    if cache is not None:
        # cancel the refresh-ahead background task and release resources
        await cache.close()

def __enter__(self) -> Any:
    """Enter the synchronous context manager scope.

    Returns:
        Any: This Connector instance, usable inside the ``with`` block.
    """
    return self
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import asyncio
from typing import Union

from aiohttp import ClientResponseError
from google.auth.credentials import Credentials
from mock import patch
import pytest # noqa F401 Needed to run the tests
Expand All @@ -25,6 +26,7 @@
from google.cloud.sql.connector import create_async_connector
from google.cloud.sql.connector import IPTypes
from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
from google.cloud.sql.connector.exceptions import ConnectorLoopError
from google.cloud.sql.connector.exceptions import IncompatibleDriverError
from google.cloud.sql.connector.instance import RefreshAheadCache
Expand Down Expand Up @@ -305,6 +307,60 @@ def test_Connector_close_called_multiple_times(fake_credentials: Credentials) ->
connector.close()


async def test_Connector_remove_cached_bad_instance(
    fake_credentials: Credentials, fake_client: CloudSQLClient
) -> None:
    """When a Connector attempts to retrieve connection info for a
    non-existent instance, it should delete the instance from
    the cache and ensure no background refresh happens (which would be
    wasted cycles).
    """
    conn_name = "bad-project:bad-region:bad-inst"
    async with Connector(
        credentials=fake_credentials, loop=asyncio.get_running_loop()
    ) as connector:
        # seed the cache with an entry for the bad instance
        connector._cache[conn_name] = RefreshAheadCache(
            conn_name, fake_client, connector._keys
        )
        # aiohttp client should throw a 404 ClientResponseError
        with pytest.raises(ClientResponseError):
            await connector.connect_async(
                conn_name,
                "pg8000",
            )
        # the failed lookup must have evicted the cache entry
        assert conn_name not in connector._cache


async def test_Connector_remove_cached_no_ip_type(
    fake_credentials: Credentials, fake_client: CloudSQLClient
) -> None:
    """When a Connector attempts to connect and preferred IP type is not present,
    it should delete the instance from the cache and ensure no background refresh
    happens (which would be wasted cycles).
    """
    # restrict the fake instance to a public address only
    fake_client.instance.ip_addrs = {"PRIMARY": "127.0.0.1"}
    conn_name = "test-project:test-region:test-instance"
    async with Connector(
        credentials=fake_credentials, loop=asyncio.get_running_loop()
    ) as connector:
        # seed the cache with an entry for the instance
        connector._cache[conn_name] = RefreshAheadCache(
            conn_name, fake_client, connector._keys
        )
        # requesting a private IP the instance lacks must raise and
        # invalidate the cache
        with pytest.raises(CloudSQLIPTypeError):
            await connector.connect_async(
                conn_name,
                "pg8000",
                user="my-user",
                password="my-pass",
                ip_type="private",
            )
        # the failed connect must have evicted the cache entry
        assert conn_name not in connector._cache


def test_default_universe_domain(fake_credentials: Credentials) -> None:
"""Test that default universe domain and constructed service endpoint are
formatted correctly.
Expand Down
Loading