Skip to content

Commit

Permalink
add start
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede committed Feb 28, 2025
1 parent 51b4671 commit 0e778e6
Showing 1 changed file with 111 additions and 13 deletions.
124 changes: 111 additions & 13 deletions fedn/cli/session_cmd.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import click
import requests

from .main import main
from .shared import CONTROLLER_DEFAULTS, get_response, print_response
from .shared import CONTROLLER_DEFAULTS, get_api_url, get_response, get_token, print_response


@main.group("session")
@click.pass_context
def session_cmd(ctx):
""":param ctx:"""
"""Session commands."""
pass


Expand All @@ -19,12 +20,7 @@ def session_cmd(ctx):
@session_cmd.command("list")
@click.pass_context
def list_sessions(ctx, protocol: str, host: str, port: str, token: str = None, n_max: int = None):
"""Return:
------
- count: number of sessions
- result: list of sessions
"""
"""List sessions."""
headers = {}

if n_max:
Expand All @@ -42,10 +38,112 @@ def list_sessions(ctx, protocol: str, host: str, port: str, token: str = None, n
@session_cmd.command("get")
@click.pass_context
def get_session(ctx, protocol: str, host: str, port: str, token: str = None, id: str = None):
"""Return:
------
- result: session with given session id
"""
"""Get session by id."""
response = get_response(protocol=protocol, host=host, port=port, endpoint=f"sessions/{id}", token=token, headers={}, usr_api=False, usr_token=False)
print_response(response, "session", id)


@click.option("-p", "--protocol", required=False, default=CONTROLLER_DEFAULTS["protocol"], help="Communication protocol of controller (api)")
@click.option("-H", "--host", required=False, default=CONTROLLER_DEFAULTS["host"], help="Hostname of controller (api)")
@click.option("-P", "--port", required=False, default=CONTROLLER_DEFAULTS["port"], help="Port of controller (api)")
@click.option("-t", "--token", required=False, help="Authentication token")
@click.option("-n", "--name", required=False, help="Name of the session")
@click.option("-a", "--aggregator", required=False, default="fedavg", help="The aggregator plugin to use")
@click.option("-ak", "--aggregator_kwargs", required=False, type=dict, help="Aggregator keyword arguments")
@click.option("-m", "--model_id", required=False, help="The id of the initial model")
@click.option("-rt", "--round_timeout", required=False, default=180, type=int, help="The round timeout to use in seconds")
@click.option("-r", "--rounds", required=False, default=5, type=int, help="The number of rounds to perform")
@click.option("-rb", "--round_buffer_size", required=False, default=-1, type=int, help="The round buffer size to use")
@click.option("-d", "--delete_models", required=False, default=True, type=bool, help="Whether to delete models after each round at combiner (save storage)")
@click.option("-v", "--validate", required=False, default=True, type=bool, help="Whether to validate the model after each round")
@click.option("-hp", "--helper", required=False, help="The helper type to use")
@click.option("-mc", "--min_clients", required=False, default=1, type=int, help="The minimum number of clients required")
@click.option("-rc", "--requested_clients", required=False, default=8, type=int, help="The requested number of clients")
@session_cmd.command("start")
@click.pass_context
def start_session(
ctx,
protocol: str,
host: str,
port: str,
token: str,
name: str = None,
aggregator: str = "fedavg",
aggregator_kwargs: dict = None,
model_id: str = None,
round_timeout: int = 180,
rounds: int = 5,
round_buffer_size: int = -1,
delete_models: bool = True,
validate: bool = True,
helper: str = None,
min_clients: int = 1,
requested_clients: int = 8,
):
"""Start a new session."""
headers = {}
_token = get_token(token=token, usr_token=False)
if _token:
headers = {"Authorization": _token}

if model_id is None:
url = get_api_url(protocol, host, port, "models/active", usr_api=False)
response = requests.get(url, headers=headers)
if response.status_code == 200:
model_id = response.json()
else:
click.secho(f"Failed to get active model: {response.json()}", fg="red")
return

if helper is None:
url = get_api_url(protocol, host, port, "helpers/active", usr_api=False)
response = requests.get(url, headers=headers)
if response.status_code == 400:
helper = "numpyhelper"
elif response.status_code == 200:
helper = response.json()
else:
click.secho("An unexpected error occurred when getting the active helper", fg="red")
return

url = get_api_url(protocol, host, port, "sessions/", usr_api=False)
response = requests.post(
url,
json={
"name": name,
"session_config": {
"aggregator": aggregator,
"aggregator_kwargs": aggregator_kwargs,
"round_timeout": round_timeout,
"buffer_size": round_buffer_size,
"model_id": model_id,
"delete_models_storage": delete_models,
"clients_required": min_clients,
"requested_clients": requested_clients,
"validate": validate,
"helper_type": helper,
"server_functions": None,
},
},
headers=headers,
verify=False,
)

if response.status_code == 201:
session_id = response.json()["session_id"]
url = get_api_url(protocol, host, port, "sessions/start", usr_api=False)
response = requests.post(
url,
json={
"session_id": session_id,
"rounds": rounds,
"round_timeout": round_timeout,
},
headers=headers,
verify=False,
)
response_json = response.json()
response_json["session_id"] = session_id
click.secho(f"Session started successfully: {response_json}", fg="green")
else:
click.secho(f"Failed to start session: {response.json()}", fg="red")

0 comments on commit 0e778e6

Please sign in to comment.