Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Immediate device updates & limit attributes to be updated #34

Merged
20 changes: 4 additions & 16 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ async def async_client(test_app):
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
login_data = {"username": "admin@goosebit.local", "password": "admin"}
response = await client.post("/login", data=login_data, follow_redirects=True)
assert response.status_code == 200

yield client


Expand Down Expand Up @@ -68,20 +72,6 @@ async def test_data(db):
hardware=compatibility,
)

device_latest = await Device.create(
uuid="device2",
last_state=UpdateStateEnum.REGISTERED,
update_mode=UpdateModeEnum.LATEST,
hardware=compatibility,
)

device_pinned = await Device.create(
uuid="device3",
last_state=UpdateStateEnum.REGISTERED,
update_mode=UpdateModeEnum.PINNED,
hardware=compatibility,
)

temp_file_path = os.path.join(temp_dir, "firmware")
with open(temp_file_path, "w") as temp_file:
temp_file.write("Fake SWUpdate image")
Expand Down Expand Up @@ -115,8 +105,6 @@ async def test_data(db):

yield dict(
device_rollout=device_rollout,
device_latest=device_latest,
device_pinned=device_pinned,
firmware_latest=firmware_latest,
rollout_default=rollout_default,
)
2 changes: 1 addition & 1 deletion goosebit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def root_redirect(request: Request):

@app.get("/login", dependencies=[Depends(auto_redirect)], include_in_schema=False)
async def login_ui(request: Request):
return templates.TemplateResponse("login.html", context={"request": request})
return templates.TemplateResponse(request, "login.html")


@app.post("/login", include_in_schema=False, dependencies=[Depends(authenticate_user)])
Expand Down
19 changes: 7 additions & 12 deletions goosebit/api/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import BaseModel

from goosebit.auth import validate_user_permissions
from goosebit.models import Device, UpdateModeEnum
from goosebit.models import Device, Firmware, UpdateModeEnum
from goosebit.permissions import Permissions
from goosebit.updater.manager import delete_device, get_update_manager

Expand Down Expand Up @@ -68,23 +68,18 @@ class UpdateDevicesModel(BaseModel):
async def devices_update(request: Request, config: UpdateDevicesModel) -> dict:
for uuid in config.devices:
updater = await get_update_manager(uuid)
device = await updater.get_device()
if config.firmware is not None:
if config.firmware == "rollout":
device.update_mode = UpdateModeEnum.ROLLOUT
device.assigned_firmware_id = None
await updater.update_update(UpdateModeEnum.ROLLOUT, None)
elif config.firmware == "latest":
device.update_mode = UpdateModeEnum.LATEST
device.assigned_firmware_id = None
await updater.update_update(UpdateModeEnum.LATEST, None)
else:
device.update_mode = UpdateModeEnum.ASSIGNED
device.assigned_firmware_id = config.firmware
firmware = await Firmware.get_or_none(id=config.firmware)
await updater.update_update(UpdateModeEnum.ASSIGNED, firmware)
if config.pinned:
device.update_mode = UpdateModeEnum.PINNED
device.assigned_firmware_id = None
await updater.update_update(UpdateModeEnum.PINNED, None)
if config.name is not None:
device.name = config.name
await updater.save()
await updater.update_name(config.name)
return {"success": True}


Expand Down
3 changes: 2 additions & 1 deletion goosebit/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import yaml
from argon2 import PasswordHasher
from joserfc.rfc7518.oct_key import OctKey

from goosebit.permissions import Permissions

Expand All @@ -13,7 +14,7 @@
UPDATES_DIR = BASE_DIR.joinpath("updates")
DB_MIGRATIONS_LOC = BASE_DIR.joinpath("migrations")

SECRET = secrets.token_hex(16)
SECRET = OctKey.import_key(secrets.token_hex(16))
PWD_CXT = PasswordHasher()

with open(BASE_DIR.joinpath("settings.yaml"), "r") as f:
Expand Down
12 changes: 5 additions & 7 deletions goosebit/ui/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def ui_root(request: Request):
)
async def firmware_ui(request: Request):
return templates.TemplateResponse(
"firmware.html", context={"request": request, "title": "Firmware"}
request, "firmware.html", context={"title": "Firmware"}
)


Expand Down Expand Up @@ -84,9 +84,7 @@ async def upload_update_remote(request: Request, url: str = Form(...)):
dependencies=[Security(validate_user_permissions, scopes=[Permissions.HOME.READ])],
)
async def home_ui(request: Request):
return templates.TemplateResponse(
"index.html", context={"request": request, "title": "Home"}
)
return templates.TemplateResponse(request, "index.html", context={"title": "Home"})


@router.get(
Expand All @@ -97,7 +95,7 @@ async def home_ui(request: Request):
)
async def devices_ui(request: Request):
return templates.TemplateResponse(
"devices.html", context={"request": request, "title": "Devices"}
request, "devices.html", context={"title": "Devices"}
)


Expand All @@ -109,7 +107,7 @@ async def devices_ui(request: Request):
)
async def rollouts_ui(request: Request):
return templates.TemplateResponse(
"rollouts.html", context={"request": request, "title": "Rollouts"}
request, "rollouts.html", context={"title": "Rollouts"}
)


Expand All @@ -121,5 +119,5 @@ async def rollouts_ui(request: Request):
)
async def logs_ui(request: Request, dev_id: str):
return templates.TemplateResponse(
"logs.html", context={"request": request, "title": "Log", "device": dev_id}
request, "logs.html", context={"title": "Log", "device": dev_id}
)
1 change: 0 additions & 1 deletion goosebit/updater/controller/v1/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,5 +174,4 @@ async def deployment_feedback(
except KeyError:
logging.warning(f"No details to update update log, device={dev_id}")

await updater.save()
return {"id": str(action_id)}
96 changes: 56 additions & 40 deletions goosebit/updater/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class HandlingType(StrEnum):
class UpdateManager(ABC):
def __init__(self, dev_id: str):
self.dev_id = dev_id
self.config_data = {}
self.device = None
self.force_update = False
self.update_complete = False
Expand All @@ -40,9 +39,6 @@ def __init__(self, dev_id: str):
async def get_device(self) -> Device | None:
return

async def save(self) -> None:
return

async def update_fw_version(self, version: str) -> None:
return

Expand All @@ -52,28 +48,20 @@ async def update_hardware(self, hardware: Hardware) -> None:
async def update_device_state(self, state: UpdateStateEnum) -> None:
return

async def update_last_seen(self, last_seen: int) -> None:
async def update_last_connection(self, last_seen: int, last_ip: str) -> None:
return

async def update_last_ip(self, last_ip: str) -> None:
async def update_update(self, update_mode: UpdateModeEnum, firmware: Firmware):
return

async def get_rollout(self) -> Optional[Rollout]:
return None
async def update_name(self, name: str):
return

async def update_config_data(self, **kwargs):
model = kwargs.get("hw_model") or "default"
revision = kwargs.get("hw_revision") or "default"
hardware = (await Hardware.get_or_create(model=model, revision=revision))[0]

await self.update_hardware(hardware)

device = await self.get_device()
if device.last_state == UpdateStateEnum.UNKNOWN:
await self.update_device_state(UpdateStateEnum.REGISTERED)
await self.save()
return

self.config_data.update(kwargs)
async def get_rollout(self) -> Optional[Rollout]:
return None

@asynccontextmanager
async def subscribe_log(self, callback: Callable):
Expand Down Expand Up @@ -133,31 +121,59 @@ async def get_device(self) -> Device:

return self.device

async def save(self) -> None:
await self.device.save()

async def update_fw_version(self, version: str) -> None:
device = await self.get_device()
device.fw_version = version
await device.save(update_fields=["fw_version"])

async def update_hardware(self, hardware: Hardware) -> None:
device = await self.get_device()
device.hardware = hardware
await device.save(update_fields=["hardware"])

async def update_device_state(self, state: UpdateStateEnum) -> None:
device = await self.get_device()
device.last_state = state
await device.save(update_fields=["last_state"])

async def update_last_seen(self, last_seen: int) -> None:
async def update_last_connection(self, last_seen: int, last_ip: str) -> None:
device = await self.get_device()
device.last_seen = last_seen

async def update_last_ip(self, last_ip: str) -> None:
device = await self.get_device()
if ":" in last_ip:
device.last_ipv6 = last_ip
await device.save(update_fields=["last_seen", "last_ipv6"])
else:
device.last_ip = last_ip
await device.save(update_fields=["last_seen", "last_ip"])

async def update_update(self, update_mode: UpdateModeEnum, firmware: Firmware):
device = await self.get_device()
device.assigned_firmware = firmware
device.update_mode = update_mode
await device.save(update_fields=["assigned_firmware_id", "update_mode"])
tsagadar marked this conversation as resolved.
Show resolved Hide resolved

async def update_name(self, name: str):
device = await self.get_device()
device.name = name
await device.save(update_fields=["name"])

async def update_config_data(self, **kwargs):
model = kwargs.get("hw_model") or "default"
revision = kwargs.get("hw_revision") or "default"
hardware = (await Hardware.get_or_create(model=model, revision=revision))[0]
device = await self.get_device()
modified = False

if device.hardware != hardware:
device.hardware = hardware
modified = True

if device.last_state == UpdateStateEnum.UNKNOWN:
device.last_state = UpdateStateEnum.REGISTERED
modified = True

if modified:
await device.save(update_fields=["hardware_id", "last_state"])

async def get_rollout(self) -> Optional[Rollout]:
device = await self.get_device()
Expand Down Expand Up @@ -225,25 +241,32 @@ async def update_log(self, log_data: str) -> None:
if log_data is None:
return
device = await self.get_device()

if device.last_log is None:
device.last_log = ""

matches = re.findall(r"Downloaded (\d+)%", log_data)
if matches:
device.progress = matches[-1]
if device.last_log is None:
device.last_log = ""

if log_data.startswith("Installing Update Chunk Artifacts."):
await self.clear_log()
if log_data == "All Chunks Installed.":
# clear log
device.last_log = ""
await self.publish_log(None)
elif log_data == "All Chunks Installed.":
self.force_update = False
self.update_complete = True

if not log_data == "Skipped Update.":
device.last_log += f"{log_data}\n"
await self.publish_log(f"{log_data}\n")
await device.save()

await device.save(update_fields=["progress", "last_log"])

async def clear_log(self) -> None:
device = await self.get_device()
device.last_log = ""
await device.save()
await device.save(update_fields=["last_log"])
await self.publish_log(None)


Expand All @@ -258,17 +281,10 @@ async def get_update_manager(dev_id: str) -> UpdateManager:
return device_managers[dev_id]


def get_update_manager_sync(dev_id: str) -> UpdateManager:
global device_managers
if device_managers.get(dev_id) is None:
device_managers[dev_id] = DeviceUpdateManager(dev_id)
return device_managers[dev_id]


async def delete_device(dev_id: str) -> None:
global device_managers
try:
updater = get_update_manager_sync(dev_id)
updater = await get_update_manager(dev_id)
await (await updater.get_device()).delete()
del device_managers[dev_id]
except KeyError as e:
Expand Down
8 changes: 3 additions & 5 deletions goosebit/updater/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from goosebit.settings import TENANT

from . import controller
from .manager import get_update_manager_sync
from .manager import get_update_manager


async def verify_tenant(tenant: str):
Expand All @@ -17,10 +17,8 @@ async def verify_tenant(tenant: str):

async def log_last_connection(request: Request, dev_id: str):
host = request.client.host
updater = get_update_manager_sync(dev_id)
await updater.update_last_ip(host)
await updater.update_last_seen(round(time.time()))
await updater.save()
updater = await get_update_manager(dev_id)
await updater.update_last_connection(round(time.time()), host)


router = APIRouter(
Expand Down
Loading