diff --git a/conftest.py b/conftest.py index 5f4e570c..e72673bd 100644 --- a/conftest.py +++ b/conftest.py @@ -39,6 +39,10 @@ async def async_client(test_app): async with AsyncClient( transport=ASGITransport(app=test_app), base_url="http://test" ) as client: + login_data = {"username": "admin@goosebit.local", "password": "admin"} + response = await client.post("/login", data=login_data, follow_redirects=True) + assert response.status_code == 200 + yield client @@ -68,20 +72,6 @@ async def test_data(db): hardware=compatibility, ) - device_latest = await Device.create( - uuid="device2", - last_state=UpdateStateEnum.REGISTERED, - update_mode=UpdateModeEnum.LATEST, - hardware=compatibility, - ) - - device_pinned = await Device.create( - uuid="device3", - last_state=UpdateStateEnum.REGISTERED, - update_mode=UpdateModeEnum.PINNED, - hardware=compatibility, - ) - temp_file_path = os.path.join(temp_dir, "firmware") with open(temp_file_path, "w") as temp_file: temp_file.write("Fake SWUpdate image") @@ -115,8 +105,6 @@ async def test_data(db): yield dict( device_rollout=device_rollout, - device_latest=device_latest, - device_pinned=device_pinned, firmware_latest=firmware_latest, rollout_default=rollout_default, ) diff --git a/goosebit/__init__.py b/goosebit/__init__.py index 8f840ad3..084a2a98 100644 --- a/goosebit/__init__.py +++ b/goosebit/__init__.py @@ -48,7 +48,7 @@ def root_redirect(request: Request): @app.get("/login", dependencies=[Depends(auto_redirect)], include_in_schema=False) async def login_ui(request: Request): - return templates.TemplateResponse("login.html", context={"request": request}) + return templates.TemplateResponse(request, "login.html") @app.post("/login", include_in_schema=False, dependencies=[Depends(authenticate_user)]) diff --git a/goosebit/api/devices.py b/goosebit/api/devices.py index 7047a4b8..2362b80f 100644 --- a/goosebit/api/devices.py +++ b/goosebit/api/devices.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from goosebit.auth import validate_user_permissions -from goosebit.models import Device, UpdateModeEnum +from goosebit.models import Device, Firmware, UpdateModeEnum from goosebit.permissions import Permissions from goosebit.updater.manager import delete_device, get_update_manager @@ -68,23 +68,18 @@ class UpdateDevicesModel(BaseModel): async def devices_update(request: Request, config: UpdateDevicesModel) -> dict: for uuid in config.devices: updater = await get_update_manager(uuid) - device = await updater.get_device() if config.firmware is not None: if config.firmware == "rollout": - device.update_mode = UpdateModeEnum.ROLLOUT - device.assigned_firmware_id = None + await updater.update_update(UpdateModeEnum.ROLLOUT, None) elif config.firmware == "latest": - device.update_mode = UpdateModeEnum.LATEST - device.assigned_firmware_id = None + await updater.update_update(UpdateModeEnum.LATEST, None) else: - device.update_mode = UpdateModeEnum.ASSIGNED - device.assigned_firmware_id = config.firmware + firmware = await Firmware.get_or_none(id=config.firmware) + await updater.update_update(UpdateModeEnum.ASSIGNED, firmware) if config.pinned: - device.update_mode = UpdateModeEnum.PINNED - device.assigned_firmware_id = None + await updater.update_update(UpdateModeEnum.PINNED, None) if config.name is not None: - device.name = config.name - await updater.save() + await updater.update_name(config.name) return {"success": True} diff --git a/goosebit/settings.py b/goosebit/settings.py index 7c3e5373..590cdae8 100644 --- a/goosebit/settings.py +++ b/goosebit/settings.py @@ -4,6 +4,7 @@ import yaml from argon2 import PasswordHasher +from joserfc.rfc7518.oct_key import OctKey from goosebit.permissions import Permissions @@ -13,7 +14,7 @@ UPDATES_DIR = BASE_DIR.joinpath("updates") DB_MIGRATIONS_LOC = BASE_DIR.joinpath("migrations") -SECRET = secrets.token_hex(16) +SECRET = OctKey.import_key(secrets.token_hex(16)) PWD_CXT = PasswordHasher() with open(BASE_DIR.joinpath("settings.yaml"), "r") as f: diff --git a/goosebit/ui/routes.py b/goosebit/ui/routes.py index 955fba89..5f003b39 100644 --- a/goosebit/ui/routes.py +++ b/goosebit/ui/routes.py @@ -31,7 +31,7 @@ async def ui_root(request: Request): ) async def firmware_ui(request: Request): return templates.TemplateResponse( - "firmware.html", context={"request": request, "title": "Firmware"} + request, "firmware.html", context={"title": "Firmware"} ) @@ -84,9 +84,7 @@ async def upload_update_remote(request: Request, url: str = Form(...)): dependencies=[Security(validate_user_permissions, scopes=[Permissions.HOME.READ])], ) async def home_ui(request: Request): - return templates.TemplateResponse( - "index.html", context={"request": request, "title": "Home"} - ) + return templates.TemplateResponse(request, "index.html", context={"title": "Home"}) @router.get( @@ -97,7 +95,7 @@ async def home_ui(request: Request): ) async def devices_ui(request: Request): return templates.TemplateResponse( - "devices.html", context={"request": request, "title": "Devices"} + request, "devices.html", context={"title": "Devices"} ) @@ -109,7 +107,7 @@ async def devices_ui(request: Request): ) async def rollouts_ui(request: Request): return templates.TemplateResponse( - "rollouts.html", context={"request": request, "title": "Rollouts"} + request, "rollouts.html", context={"title": "Rollouts"} ) @@ -121,5 +119,5 @@ async def rollouts_ui(request: Request): ) async def logs_ui(request: Request, dev_id: str): return templates.TemplateResponse( - "logs.html", context={"request": request, "title": "Log", "device": dev_id} + request, "logs.html", context={"title": "Log", "device": dev_id} ) diff --git a/goosebit/updater/controller/v1/routes.py b/goosebit/updater/controller/v1/routes.py index abe54171..ed4c54da 100644 --- a/goosebit/updater/controller/v1/routes.py +++ b/goosebit/updater/controller/v1/routes.py @@ -174,5 +174,4 @@ async def deployment_feedback( except KeyError: logging.warning(f"No details to update update log, device={dev_id}") - await updater.save() return {"id": str(action_id)} diff --git a/goosebit/updater/manager.py b/goosebit/updater/manager.py index 9439394a..a9e86e9e 100644 --- a/goosebit/updater/manager.py +++ b/goosebit/updater/manager.py @@ -30,7 +30,6 @@ class HandlingType(StrEnum): class UpdateManager(ABC): def __init__(self, dev_id: str): self.dev_id = dev_id - self.config_data = {} self.device = None self.force_update = False self.update_complete = False @@ -40,9 +39,6 @@ def __init__(self, dev_id: str): async def get_device(self) -> Device | None: return - async def save(self) -> None: - return - async def update_fw_version(self, version: str) -> None: return @@ -52,28 +48,20 @@ async def update_hardware(self, hardware: Hardware) -> None: async def update_device_state(self, state: UpdateStateEnum) -> None: return - async def update_last_seen(self, last_seen: int) -> None: + async def update_last_connection(self, last_seen: int, last_ip: str) -> None: return - async def update_last_ip(self, last_ip: str) -> None: + async def update_update(self, update_mode: UpdateModeEnum, firmware: Firmware): return - async def get_rollout(self) -> Optional[Rollout]: - return None + async def update_name(self, name: str): + return async def update_config_data(self, **kwargs): - model = kwargs.get("hw_model") or "default" - revision = kwargs.get("hw_revision") or "default" - hardware = (await Hardware.get_or_create(model=model, revision=revision))[0] - - await self.update_hardware(hardware) - - device = await self.get_device() - if device.last_state == UpdateStateEnum.UNKNOWN: - await self.update_device_state(UpdateStateEnum.REGISTERED) - await self.save() + return - self.config_data.update(kwargs) + async def get_rollout(self) -> Optional[Rollout]: + return None @asynccontextmanager async def subscribe_log(self, callback: Callable): @@ -133,31 +121,59 @@ async def get_device(self) -> Device: return self.device - async def save(self) -> None: - await self.device.save() - async def update_fw_version(self, version: str) -> None: device = await self.get_device() device.fw_version = version + await device.save(update_fields=["fw_version"]) async def update_hardware(self, hardware: Hardware) -> None: device = await self.get_device() device.hardware = hardware + await device.save(update_fields=["hardware"]) async def update_device_state(self, state: UpdateStateEnum) -> None: device = await self.get_device() device.last_state = state + await device.save(update_fields=["last_state"]) - async def update_last_seen(self, last_seen: int) -> None: + async def update_last_connection(self, last_seen: int, last_ip: str) -> None: device = await self.get_device() device.last_seen = last_seen - - async def update_last_ip(self, last_ip: str) -> None: - device = await self.get_device() if ":" in last_ip: device.last_ipv6 = last_ip + await device.save(update_fields=["last_seen", "last_ipv6"]) else: device.last_ip = last_ip + await device.save(update_fields=["last_seen", "last_ip"]) + + async def update_update(self, update_mode: UpdateModeEnum, firmware: Firmware): + device = await self.get_device() + device.assigned_firmware = firmware + device.update_mode = update_mode + await device.save(update_fields=["assigned_firmware_id", "update_mode"]) + + async def update_name(self, name: str): + device = await self.get_device() + device.name = name + await device.save(update_fields=["name"]) + + async def update_config_data(self, **kwargs): + model = kwargs.get("hw_model") or "default" + revision = kwargs.get("hw_revision") or "default" + hardware = (await Hardware.get_or_create(model=model, revision=revision))[0] + device = await self.get_device() + modified = False + + if device.hardware != hardware: + device.hardware = hardware + modified = True + + if device.last_state == UpdateStateEnum.UNKNOWN: + device.last_state = UpdateStateEnum.REGISTERED + modified = True + + if modified: + await device.save(update_fields=["hardware_id", "last_state"]) async def get_rollout(self) -> Optional[Rollout]: device = await self.get_device() @@ -225,25 +241,32 @@ async def update_log(self, log_data: str) -> None: if log_data is None: return device = await self.get_device() + + if device.last_log is None: + device.last_log = "" + matches = re.findall(r"Downloaded (\d+)%", log_data) if matches: device.progress = matches[-1] - if device.last_log is None: - device.last_log = "" + if log_data.startswith("Installing Update Chunk Artifacts."): - await self.clear_log() - if log_data == "All Chunks Installed.": + # clear log + device.last_log = "" + await self.publish_log(None) + elif log_data == "All Chunks Installed.": self.force_update = False self.update_complete = True + if not log_data == "Skipped Update.": device.last_log += f"{log_data}\n" await self.publish_log(f"{log_data}\n") - await device.save() + + await device.save(update_fields=["progress", "last_log"]) async def clear_log(self) -> None: device = await self.get_device() device.last_log = "" - await device.save() + await device.save(update_fields=["last_log"]) await self.publish_log(None) @@ -258,17 +281,10 @@ async def get_update_manager(dev_id: str) -> UpdateManager: return device_managers[dev_id] -def get_update_manager_sync(dev_id: str) -> UpdateManager: - global device_managers - if device_managers.get(dev_id) is None: - device_managers[dev_id] = DeviceUpdateManager(dev_id) - return device_managers[dev_id] - - async def delete_device(dev_id: str) -> None: global device_managers try: - updater = get_update_manager_sync(dev_id) + updater = await get_update_manager(dev_id) await (await updater.get_device()).delete() del device_managers[dev_id] except KeyError as e: diff --git a/goosebit/updater/routes.py b/goosebit/updater/routes.py index fc3cf7df..9d8eca20 100644 --- a/goosebit/updater/routes.py +++ b/goosebit/updater/routes.py @@ -6,7 +6,7 @@ from goosebit.settings import TENANT from . import controller -from .manager import get_update_manager_sync +from .manager import get_update_manager async def verify_tenant(tenant: str): @@ -17,10 +17,8 @@ async def verify_tenant(tenant: str): async def log_last_connection(request: Request, dev_id: str): host = request.client.host - updater = get_update_manager_sync(dev_id) - await updater.update_last_ip(host) - await updater.update_last_seen(round(time.time())) - await updater.save() + updater = await get_update_manager(dev_id) + await updater.update_last_connection(round(time.time()), host) router = APIRouter( diff --git a/tests/updater/controller/v1/test_routes.py b/tests/updater/controller/v1/test_routes.py index 0d083f1d..6d8f37d9 100644 --- a/tests/updater/controller/v1/test_routes.py +++ b/tests/updater/controller/v1/test_routes.py @@ -1,11 +1,31 @@ import pytest -from goosebit.models import Firmware, Hardware, UpdateStateEnum +from goosebit.models import Firmware, Hardware from goosebit.updater.manager import get_update_manager UUID = "221326d9-7873-418e-960c-c074026a3b7c" +async def _api_device_update(async_client, device, update_attribute, update_value): + response = await async_client.post( + f"/api/devices/update", + json={"devices": [f"{device.uuid}"], update_attribute: update_value}, + ) + assert response.status_code == 200 + + +async def _api_devices_get(async_client): + response = await async_client.get("/api/devices/all") + assert response.status_code == 200 + return response.json() + + +async def _api_rollouts_get(async_client): + response = await async_client.get("/api/rollouts/all") + assert response.status_code == 200 + return response.json() + + async def _poll_first_time(async_client): response = await async_client.get(f"/DEFAULT/controller/v1/{UUID}") assert response.status_code == 200 @@ -102,6 +122,9 @@ async def test_register_device(async_client, test_data): await _poll(async_client, UUID, None, False) + devices = await _api_devices_get(async_client) + assert devices[0]["state"] == "Registered" + @pytest.mark.asyncio async def test_rollout_full(async_client, test_data): @@ -115,17 +138,19 @@ async def test_rollout_full(async_client, test_data): # confirm installation start (in reality: several of similar posts) await _feedback(async_client, device.uuid, firmware, "none", "proceeding") - await device.refresh_from_db() - assert device.last_state == UpdateStateEnum.RUNNING + devices = await _api_devices_get(async_client) + assert devices[0]["state"] == "Running" # report finished installation await _feedback(async_client, device.uuid, firmware, "success", "closed") - await device.refresh_from_db() - assert device.last_state == UpdateStateEnum.FINISHED + devices = await _api_devices_get(async_client) + assert devices[0]["state"] == "Finished" + assert devices[0]["fw"] == firmware.version await rollout.refresh_from_db() - assert rollout.success_count == 1 - assert rollout.failure_count == 0 + rollouts = await _api_rollouts_get(async_client) + assert rollouts[0]["success_count"] == 1 + assert rollouts[0]["failure_count"] == 0 @pytest.mark.asyncio @@ -139,8 +164,8 @@ async def test_rollout_signalling_download_failure(async_client, test_data): # confirm installation start (in reality: several of similar posts) await _feedback(async_client, device.uuid, firmware, "none", "proceeding") - await device.refresh_from_db() - assert device.last_state == UpdateStateEnum.RUNNING + devices = await _api_devices_get(async_client) + assert devices[0]["state"] == "Running" # HEAD /api/download/1 HTTP/1.1 (reason not clear) response = await async_client.head(firmware_url) @@ -153,21 +178,38 @@ async def test_rollout_signalling_download_failure(async_client, test_data): # report failure await _feedback(async_client, device.uuid, firmware, "failure", "closed") - await device.refresh_from_db() - assert device.last_state == UpdateStateEnum.ERROR + devices = await _api_devices_get(async_client) + assert devices[0]["state"] == "Error" @pytest.mark.asyncio async def test_latest(async_client, test_data): - device = test_data["device_latest"] + device = test_data["device_rollout"] firmware = test_data["firmware_latest"] - await _poll(async_client, device.uuid, firmware) + await _api_device_update(async_client, device, "firmware", "latest") + + deployment_base = await _poll(async_client, device.uuid, firmware) + + await _retrieve_firmware_url(async_client, deployment_base, firmware) + + # confirm installation start (in reality: several of similar posts) + await _feedback(async_client, device.uuid, firmware, "none", "proceeding") + devices = await _api_devices_get(async_client) + assert devices[0]["state"] == "Running" + + # report finished installation + await _feedback(async_client, device.uuid, firmware, "success", "closed") + devices = await _api_devices_get(async_client) + assert devices[0]["state"] == "Finished" + assert devices[0]["fw"] == firmware.version @pytest.mark.asyncio async def test_latest_with_no_firmware_available(async_client, test_data): - device = test_data["device_latest"] + device = test_data["device_rollout"] + + await _api_device_update(async_client, device, "firmware", "latest") fake_hardware = await Hardware.create(model="does-not-exist", revision="default") device.hardware_id = fake_hardware.id @@ -178,15 +220,20 @@ async def test_latest_with_no_firmware_available(async_client, test_data): @pytest.mark.asyncio async def test_pinned(async_client, test_data): - device = test_data["device_pinned"] + device = test_data["device_rollout"] + + await _api_device_update(async_client, device, "pinned", True) await _poll(async_client, device.uuid, None, False) @pytest.mark.asyncio async def test_up_to_date(async_client, test_data): - device = test_data["device_latest"] + device = test_data["device_rollout"] firmware = test_data["firmware_latest"] + + await _api_device_update(async_client, device, "firmware", "latest") + manager = await get_update_manager(dev_id=device.uuid) await manager.update_fw_version(firmware.version)