Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuan325 committed Apr 4, 2024
1 parent 523ecfb commit 7e4609c
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 7 deletions.
Binary file removed llm_demo/orchestrator/.orchestrator.py.swp
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,12 @@ def get_base_history(self, session: dict[str, Any]):
return base_history
return BASE_HISTORY

async def user_session_signout(self, uuid: str):
user_session = self.get_user_session(uuid)
if user_session:
await user_session.close()
del self._user_sessions[uuid]

def close_clients(self):
close_client_tasks = [
asyncio.create_task(a.close()) for a in self._user_sessions.values()
Expand Down
12 changes: 5 additions & 7 deletions llm_demo/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def get_user_session(self, uuid: str) -> Any:
async def user_session_insert_ticket(self, uuid: str, params: str) -> Any:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def user_session_signout(self, uuid: str):
"""Sign out from user session. Clear and restart session."""
raise NotImplementedError("Subclass should implement this!")

def set_user_session_header(self, uuid: str, user_id_token: str):
user_session = self.get_user_session(uuid)
user_session.client.headers["User-Id-Token"] = f"Bearer {user_id_token}"
Expand All @@ -76,13 +81,6 @@ def get_user_id_token(self, uuid: str) -> Optional[str]:
return parts[1]
return None

async def user_session_signout(self, uuid: str):
"""Sign out from user session. Clear and restart session."""
user_session = self.get_user_session(uuid)
if user_session:
await user_session.close()
del user_session


def createOrchestrator(orchestration_type: str) -> "BaseOrchestrator":
for cls in BaseOrchestrator.__subclasses__():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,12 @@ def get_base_history(self, session: dict[str, Any]):
return base_history
return BASE_HISTORY

async def user_session_signout(self, uuid: str):
user_session = self.get_user_session(uuid)
if user_session:
await user_session.close()
del self._user_sessions[uuid]

def close_clients(self):
close_client_tasks = [
asyncio.create_task(a.close()) for a in self._user_sessions.values()
Expand Down

0 comments on commit 7e4609c

Please sign in to comment.