Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/SK-1462 | Add Combiner DTO #832

Open
wants to merge 2 commits into
base: feature/SK-1434
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions fedn/network/api/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def get_combiners(self):
:return: list of combiners objects
:rtype: list(:class:`fedn.network.combiner.interfaces.CombinerInterface`)
"""
data = self.combiner_store.list(limit=0, skip=0, sort_key=None)
result = self.combiner_store.select(limit=0, skip=0, sort_key=None)
combiners = []
for c in data["result"]:
name = c["name"].upper()
for combiner in result:
name = combiner.name.upper()
# General certificate handling, same for all combiners.
if os.environ.get("FEDN_GRPC_CERT_PATH"):
with open(os.environ.get("FEDN_GRPC_CERT_PATH"), "rb") as f:
Expand All @@ -63,7 +63,9 @@ def get_combiners(self):
cert = f.read()
else:
cert = None
combiners.append(CombinerInterface(c["parent"], c["name"], c["address"], c["fqdn"], c["port"], certificate=cert, ip=c["ip"]))
combiners.append(
CombinerInterface(combiner.parent, combiner.name, combiner.address, combiner.fqdn, combiner.port, certificate=cert, ip=combiner.ip)
)

return combiners

Expand Down
14 changes: 9 additions & 5 deletions fedn/network/api/v1/combiner_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def get_combiners():

kwargs = request.args.to_dict()

response = combiner_store.list(limit, skip, sort_key, sort_order, **kwargs)
combiners = combiner_store.select(limit, skip, sort_key, sort_order, **kwargs)
count = combiner_store.count(**kwargs)
response = {"count": count, "result": [combiner.to_dict() for combiner in combiners]}

return jsonify(response), 200
except Exception as e:
Expand Down Expand Up @@ -184,7 +186,9 @@ def list_combiners():

kwargs = get_post_data_to_kwargs(request)

response = combiner_store.list(limit, skip, sort_key, sort_order, **kwargs)
combiners = combiner_store.select(limit, skip, sort_key, sort_order, **kwargs)
count = combiner_store.count(**kwargs)
response = {"count": count, "result": [combiner.to_dict() for combiner in combiners]}

return jsonify(response), 200
except Exception as e:
Expand Down Expand Up @@ -327,8 +331,8 @@ def get_combiner(id: str):
try:
response = combiner_store.get(id)
if response is None:
return jsonify({"message": f"Entity with id: {id} not found"}), 404
return jsonify(response), 200
return jsonify({"message": f"Entity with id: {id} not found"}), 404
return jsonify(response.to_dict()), 200
except Exception as e:
logger.error(f"An unexpected error occurred: {e}")
return jsonify({"message": "An unexpected error occurred"}), 500
Expand Down Expand Up @@ -369,7 +373,7 @@ def delete_combiner(id: str):
try:
result: bool = combiner_store.delete(id)
if not result:
return jsonify({"message": f"Entity with id: {id} not found"}), 404
return jsonify({"message": f"Entity with id: {id} not found"}), 404
msg = "Combiner deleted" if result else "Combiner not deleted"
return jsonify({"message": msg}), 200
except Exception as e:
Expand Down
23 changes: 11 additions & 12 deletions fedn/network/combiner/combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from fedn.network.combiner.shared import client_store, combiner_store, prediction_store, repository, round_store, status_store, validation_store
from fedn.network.grpc.server import Server, ServerConfig
from fedn.network.storage.statestore.stores.dto import ClientDTO
from fedn.network.storage.statestore.stores.dto.combiner import CombinerDTO

VALID_NAME_REGEX = "^[a-zA-Z0-9_-]*$"

Expand Down Expand Up @@ -109,19 +110,17 @@ def __init__(self, config):

self.round_store = round_store

# Add combiner to statestore
interface_config = {
"port": config["port"],
"fqdn": config["fqdn"],
"name": config["name"],
"address": config["host"],
"parent": "localhost",
"ip": "",
"updated_at": str(datetime.now()),
}
# Check if combiner already exists in statestore
if combiner_store.get(config["name"]) is None:
combiner_store.add(interface_config)
if combiner_store.get_by_name(config["name"]) is None:
new_combiner = CombinerDTO()
new_combiner.port = config["port"]
new_combiner.fqdn = config["fqdn"]
new_combiner.name = config["name"]
new_combiner.address = config["host"]
new_combiner.parent = "localhost"
new_combiner.ip = ""
new_combiner.updated_at = str(datetime.now())
combiner_store.add(new_combiner)

# Fetch all clients previously connected to the combiner
# If a client and a combiner goes down at the same time,
Expand Down
254 changes: 72 additions & 182 deletions fedn/network/storage/statestore/stores/combiner_store.py
Original file line number Diff line number Diff line change
@@ -1,217 +1,107 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from abc import abstractmethod
from typing import Dict, List

import pymongo
from bson import ObjectId
from pymongo.database import Database
from sqlalchemy import String, func, or_, select
from sqlalchemy.orm import Mapped, mapped_column

from fedn.network.storage.statestore.stores.sql.shared import MyAbstractBase
from fedn.network.storage.statestore.stores.store import MongoDBStore, SQLStore, Store

from .shared import from_document


class Combiner:
def __init__(
self,
id: str,
name: str,
address: str,
certificate: str,
config: dict,
fqdn: str,
ip: str,
key: str,
parent: dict,
port: int,
status: str,
updated_at: str,
):
self.id = id
self.name = name
self.address = address
self.certificate = certificate
self.config = config
self.fqdn = fqdn
self.ip = ip
self.key = key
self.parent = parent
self.port = port
self.status = status
self.updated_at = updated_at


class CombinerStore(Store[Combiner]):
pass


class MongoDBCombinerStore(MongoDBStore[Combiner]):
def __init__(self, database: Database, collection: str):
super().__init__(database, collection)

def get(self, id: str) -> Combiner:
"""Get an entity by id
param id: The id of the entity
type: str
description: The id of the entity, can be either the id or the name (property)
return: The entity
"""
if ObjectId.is_valid(id):
id_obj = ObjectId(id)
document = self.database[self.collection].find_one({"_id": id_obj})
else:
document = self.database[self.collection].find_one({"name": id})

if document is None:
return None
from fedn.network.storage.statestore.stores.dto import CombinerDTO
from fedn.network.storage.statestore.stores.new_store import MongoDBStore, SQLStore, Store, from_document
from fedn.network.storage.statestore.stores.sql.shared import CombinerModel, from_orm_model

return from_document(document)

def update(self, id: str, item: Combiner) -> bool:
raise NotImplementedError("Update not implemented for CombinerStore")
class CombinerStore(Store[CombinerDTO]):
@abstractmethod
def get_by_name(name: str) -> CombinerDTO:
pass

def add(self, item: Combiner) -> Tuple[bool, Any]:
return super().add(item)

def delete(self, id: str) -> bool:
if ObjectId.is_valid(id):
kwargs = {"_id": ObjectId(id)}
else:
return False
class MongoDBCombinerStore(CombinerStore, MongoDBStore):
def __init__(self, database: Database, collection: str):
super().__init__(database, collection, "id")

document = self.database[self.collection].find_one(kwargs)
def get(self, id: str) -> CombinerDTO:
obj = self.mongo_get(id)
if obj is None:
return None
return self._dto_from_document(obj)

if document is None:
return False

return super().delete(document["_id"])

def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs) -> Dict[int, List[Combiner]]:
"""List entities
param limit: The maximum number of entities to return
type: int
param skip: The number of entities to skip
type: int
param sort_key: The key to sort by
type: str
param sort_order: The order to sort by
type: pymongo.DESCENDING | pymongo.ASCENDING
param kwargs: Additional query parameters
type: dict
example: {"key": "models"}
return: A dictionary with the count and the result
"""
response = super().list(limit, skip, sort_key or "updated_at", sort_order, **kwargs)

return response
def update(self, item: CombinerDTO):
raise NotImplementedError("Update not implemented for CombinerStore")

def count(self, **kwargs) -> int:
return super().count(**kwargs)
def add(self, item: CombinerDTO):
item_dict = item.to_db(exclude_unset=False)
success, obj = self.mongo_add(item_dict)
if success:
return success, self._dto_from_document(obj)
return success, obj

def delete(self, id: str) -> bool:
return self.mongo_delete(id)

class CombinerModel(MyAbstractBase):
__tablename__ = "combiners"
def select(self, limit: int = 0, skip: int = 0, sort_key: str = None, sort_order=pymongo.DESCENDING, **filter_kwargs) -> List[CombinerDTO]:
entities = self.mongo_select(limit, skip, sort_key, sort_order, **filter_kwargs)
result = []
for entity in entities:
result.append(self._dto_from_document(entity))
return result

address: Mapped[str] = mapped_column(String(255))
fqdn: Mapped[Optional[str]] = mapped_column(String(255))
ip: Mapped[Optional[str]] = mapped_column(String(255))
name: Mapped[str] = mapped_column(String(255))
parent: Mapped[Optional[str]] = mapped_column(String(255))
port: Mapped[int]
updated_at: Mapped[datetime] = mapped_column(default=datetime.now())
def count(self, **kwargs) -> int:
return self.mongo_count(**kwargs)

def get_by_name(self, name: str) -> CombinerDTO:
document = self.database[self.collection].find_one({"name": name})
if document is None:
return None
return self._dto_from_document(document)

def from_row(row: CombinerModel) -> Combiner:
return {
"id": row.id,
"committed_at": row.committed_at,
"address": row.address,
"ip": row.ip,
"name": row.name,
"parent": row.parent,
"fqdn": row.fqdn,
"port": row.port,
"updated_at": row.updated_at,
}
def _dto_from_document(self, document: Dict) -> CombinerDTO:
return CombinerDTO().populate_with(from_document(document))


class SQLCombinerStore(CombinerStore, SQLStore[Combiner]):
class SQLCombinerStore(CombinerStore, SQLStore[CombinerDTO]):
def __init__(self, Session):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing the get_by_name method? Is this solved in another way? If so great, we should remove get_by_name from the interface...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be there? It is a new method, so added it to both mongo and sql. I believe it is required but could be discussed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! Yes, I think in the future we can think of removing it... Name should not be used as key I think...

super().__init__(Session)
super().__init__(Session, CombinerModel)

def get(self, id: str) -> Combiner:
def get(self, id: str) -> CombinerDTO:
with self.Session() as session:
stmt = select(CombinerModel).where(or_(CombinerModel.id == id, CombinerModel.name == id))
item = session.scalars(stmt).first()
if item is None:
entity = self.sql_get(session, id)
if entity is None:
return None
return from_row(item)
return self._dto_from_orm_model(entity)

def update(self, id, item):
def update(self, item):
raise NotImplementedError

def add(self, item):
with self.Session() as session:
entity = CombinerModel(
address=item["address"],
fqdn=item["fqdn"],
ip=item["ip"],
name=item["name"],
parent=item["parent"],
port=item["port"],
)
session.add(entity)
session.commit()
return True, from_row(entity)
item_dict = item.to_db(exclude_unset=False)
item_dict = self._to_orm_dict(item_dict)
entity = CombinerModel(**item_dict)
success, obj = self.sql_add(session, entity)
if success:
return success, self._dto_from_orm_model(obj)
return success, obj

def delete(self, id: str) -> bool:
with self.Session() as session:
stmt = select(CombinerModel).where(CombinerModel.id == id)
item = session.scalars(stmt).first()
if item is None:
return False
session.delete(item)
return True

def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs):
with self.Session() as session:
stmt = select(CombinerModel)

for key, value in kwargs.items():
stmt = stmt.where(getattr(CombinerModel, key) == value)

_sort_order: str = "DESC" if sort_order == pymongo.DESCENDING else "ASC"
_sort_key: str = sort_key or "committed_at"

if _sort_key in CombinerModel.__table__.columns:
sort_obj = CombinerModel.__table__.columns.get(_sort_key) if _sort_order == "ASC" else CombinerModel.__table__.columns.get(_sort_key).desc()

stmt = stmt.order_by(sort_obj)

if limit:
stmt = stmt.offset(skip or 0).limit(limit)
elif skip:
stmt = stmt.offset(skip)
return self.sql_delete(id)

items = session.scalars(stmt).all()

result = []
for i in items:
result.append(from_row(i))

count = session.scalar(select(func.count()).select_from(CombinerModel))

return {"count": count, "result": result}
def select(self, limit=0, skip=0, sort_key=None, sort_order=pymongo.DESCENDING, **kwargs):
with self.Session() as session:
entities = self.sql_select(session, limit, skip, sort_key, sort_order, **kwargs)
return [self._dto_from_orm_model(item) for item in entities]

def count(self, **kwargs):
with self.Session() as session:
stmt = select(func.count()).select_from(CombinerModel)
return self.sql_count(**kwargs)

for key, value in kwargs.items():
stmt = stmt.where(getattr(CombinerModel, key) == value)
def get_by_name(self, name: str) -> CombinerDTO:
with self.Session() as session:
entity = session.query(CombinerModel).filter(CombinerModel.name == name).first()
if entity is None:
return None
return self._dto_from_orm_model(entity)

count = session.scalar(stmt)
def _to_orm_dict(self, item_dict: Dict) -> Dict:
return item_dict

return count
def _dto_from_orm_model(self, item: CombinerModel) -> CombinerDTO:
return CombinerDTO().populate_with(from_orm_model(item, CombinerModel))
Loading
Loading