Skip to content

Commit

Permalink
Extract device info function and test error checking
Browse files Browse the repository at this point in the history
  • Loading branch information
balloob committed Jul 6, 2023
1 parent 775555f commit 34bfdb1
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 82 deletions.
164 changes: 82 additions & 82 deletions homeassistant/helpers/entity_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from homeassistant.exceptions import (
HomeAssistantError,
PlatformNotReady,
RequiredParameterMissing,
)
from homeassistant.generated import languages
from homeassistant.setup import async_start_setup
Expand All @@ -42,14 +41,13 @@
service,
translation,
)
from .device_registry import DeviceRegistry
from .entity_registry import EntityRegistry, RegistryEntryDisabler, RegistryEntryHider
from .event import async_call_later, async_track_time_interval
from .issue_registry import IssueSeverity, async_create_issue
from .typing import UNDEFINED, ConfigType, DiscoveryInfoType

if TYPE_CHECKING:
from .entity import Entity
from .entity import DeviceInfo, Entity


SLOW_SETUP_WARNING = 10
Expand Down Expand Up @@ -513,12 +511,9 @@ async def async_add_entities(

hass = self.hass

device_registry = dev_reg.async_get(hass)
entity_registry = ent_reg.async_get(hass)
tasks = [
self._async_add_entity(
entity, update_before_add, entity_registry, device_registry
)
self._async_add_entity(entity, update_before_add, entity_registry)
for entity in new_entities
]

Expand Down Expand Up @@ -580,7 +575,6 @@ async def _async_add_entity( # noqa: C901
entity: Entity,
update_before_add: bool,
entity_registry: EntityRegistry,
device_registry: DeviceRegistry,
) -> None:
"""Add an entity to the platform."""
if entity is None:
Expand Down Expand Up @@ -637,81 +631,10 @@ async def _async_add_entity( # noqa: C901
return

device_info = entity.device_info
device_id = None
device = None
device: dev_reg.DeviceEntry | None = None

if self.config_entry and device_info is not None:
processed_dev_info: dict[str, str | None] = {}
for key in (
"connections",
"default_manufacturer",
"default_model",
"default_name",
"entry_type",
"identifiers",
"manufacturer",
"model",
"name",
"suggested_area",
"sw_version",
"hw_version",
"via_device",
):
if key in device_info:
processed_dev_info[key] = device_info[
key # type: ignore[literal-required]
]

keys = set(processed_dev_info)
entity_type: str | None = None

for possible_type, allowed_keys in DEVICE_INFO_TYPES.items():
if keys <= allowed_keys:
entity_type = possible_type
break

if entity_type is None:
raise HomeAssistantError(
"Device info needs to either describe a device, "
"link to existing device or provide extra information."
)

if (
# device info that is purely meant for linking doesn't need default name
any(
key not in {"identifiers", "connections"}
for key in (processed_dev_info)
)
and "default_name" not in processed_dev_info
and not processed_dev_info.get("name")
):
processed_dev_info["name"] = self.config_entry.title

if "configuration_url" in device_info:
if device_info["configuration_url"] is None:
processed_dev_info["configuration_url"] = None
else:
configuration_url = str(device_info["configuration_url"])
if urlparse(configuration_url).scheme in [
"http",
"https",
"homeassistant",
]:
processed_dev_info["configuration_url"] = configuration_url
else:
_LOGGER.warning(
"Ignoring invalid device configuration_url '%s'",
configuration_url,
)

try:
device = device_registry.async_get_or_create(
config_entry_id=self.config_entry.entry_id,
**processed_dev_info, # type: ignore[arg-type]
)
device_id = device.id
except RequiredParameterMissing:
pass
device = self._async_process_device_info(device_info)

# An entity may suggest the entity_id by setting entity_id itself
suggested_entity_id: str | None = entity.entity_id
Expand Down Expand Up @@ -746,7 +669,7 @@ async def _async_add_entity( # noqa: C901
entity.unique_id,
capabilities=entity.capability_attributes,
config_entry=self.config_entry,
device_id=device_id,
device_id=device.id if device else None,
disabled_by=disabled_by,
entity_category=entity.entity_category,
get_initial_options=entity.get_initial_entity_options,
Expand Down Expand Up @@ -834,6 +757,83 @@ def remove_entity_cb() -> None:

await entity.add_to_platform_finish()

@callback
def _async_process_device_info(
self, device_info: DeviceInfo
) -> dev_reg.DeviceEntry | None:
"""Process a device info."""
processed_dev_info: DeviceInfo = {}
for key in (
"connections",
"default_manufacturer",
"default_model",
"default_name",
"entry_type",
"hw_version",
"identifiers",
"manufacturer",
"model",
"name",
"suggested_area",
"sw_version",
"via_device",
):
if key in device_info:
processed_dev_info[key] = device_info[key] # type: ignore [literal-required]

if "configuration_url" in device_info:
if device_info["configuration_url"] is None:
processed_dev_info["configuration_url"] = None
else:
configuration_url = str(device_info["configuration_url"])
if urlparse(configuration_url).scheme in [
"http",
"https",
"homeassistant",
]:
processed_dev_info["configuration_url"] = configuration_url
else:
_LOGGER.warning(
"Ignoring invalid device configuration_url '%s'",
configuration_url,
)

keys = set(processed_dev_info)

# If no keys or not enough info to match up, abort
if not keys or len(keys & {"connections", "identifiers"}) == 0:
return None

device_info_type: str | None = None

for possible_type, allowed_keys in DEVICE_INFO_TYPES.items():
if keys <= allowed_keys:
device_info_type = possible_type
break

if device_info_type is None:
self.logger.error(
"Device info for %s needs to either describe a device, "
"link to existing device or provide extra information.",
device_info,
)
return None

assert self.config_entry is not None

if (
# device info that is purely meant for linking doesn't need default name
device_info_type != "link"
and "default_name" not in processed_dev_info
and not processed_dev_info.get("name")
):
processed_dev_info["name"] = self.config_entry.title

return dev_reg.async_get(self.hass).async_get_or_create(
config_entry_id=self.config_entry.entry_id,
**processed_dev_info,
)

async def async_reset(self) -> None:
"""Remove all entities and reset data.
Expand Down
41 changes: 41 additions & 0 deletions tests/helpers/test_entity_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1878,3 +1878,44 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
device = dev_reg.async_get_device(set(), {(dr.CONNECTION_NETWORK_MAC, "1234")})
assert device is not None
assert device.name == expected_device_name


@pytest.mark.parametrize(
("device_info"),
[
{},
{"name": "bla"},
{"default_name": "bla"},
{
"name": "bla",
"default_name": "yo",
},
],
)
async def test_device_type_checking(
hass: HomeAssistant,
device_info: dict,
) -> None:
"""Test catching invalid device info."""

class DeviceNameEntity(Entity):
_attr_unique_id = "qwer"
_attr_device_info = device_info

async def async_setup_entry(hass, config_entry, async_add_entities):
"""Mock setup entry method."""
async_add_entities([DeviceNameEntity()])
return True

platform = MockPlatform(async_setup_entry=async_setup_entry)
config_entry = MockConfigEntry(
title="Mock Config Entry Title", entry_id="super-mock-id"
)
entity_platform = MockEntityPlatform(
hass, platform_name=config_entry.domain, platform=platform
)

assert await entity_platform.async_setup_entry(config_entry)

dev_reg = dr.async_get(hass)
assert len(dev_reg.devices) == 0

0 comments on commit 34bfdb1

Please sign in to comment.