Make inside-module functions private
dandansamax committed Oct 29, 2024
1 parent 23d4ad6 commit 2d8b87e
Showing 3 changed files with 33 additions and 33 deletions.
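
For context, the rename below applies Python's single-leading-underscore convention: a method named _helper is flagged as internal to its class or module, not part of the public API. The interpreter does not enforce this; it is a signal to callers, linters, and documentation tools. A minimal sketch of the convention (the ExampleModel class is illustrative, not taken from this commit):

# Illustrative only: mirrors the rename pattern applied in this commit.
class ExampleModel:
    def chat(self, message: str) -> str:
        # Public entry point: name stays unprefixed.
        return self._call_api(message)

    def _call_api(self, message: str) -> str:
        # Leading underscore marks this as internal; external code should
        # go through chat() rather than calling this directly.
        return f"response to {message!r}"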
20 changes: 10 additions & 10 deletions crab/agents/backend_models/claude_model.py
@@ -62,14 +62,14 @@ def reset(self, system_message: str, action_space: list[Action] | None) -> None:
     def chat(self, message: list[Message] | Message) -> BackendOutput:
         if isinstance(message, tuple):
             message = [message]
-        request = self.fetch_from_memory()
-        new_message = self.construct_new_message(message)
+        request = self._fetch_from_memory()
+        new_message = self._construct_new_message(message)
         request.append(new_message)
-        response_message = self.call_api(request)
-        self.record_message(new_message, response_message)
-        return self.generate_backend_output(response_message)
+        response_message = self._call_api(request)
+        self._record_message(new_message, response_message)
+        return self._generate_backend_output(response_message)
 
-    def construct_new_message(self, message: list[Message]) -> dict[str, Any]:
+    def _construct_new_message(self, message: list[Message]) -> dict[str, Any]:
         parts: list[dict] = []
         for content, msg_type in message:
             match msg_type:
@@ -96,7 +96,7 @@ def construct_new_message(self, message: list[Message]) -> dict[str, Any]:
             "content": parts,
         }
 
-    def fetch_from_memory(self) -> list[dict]:
+    def _fetch_from_memory(self) -> list[dict]:
         request: list[dict] = []
         if self.history_messages_len > 0:
             fetch_history_len = min(self.history_messages_len, len(self.chat_history))
@@ -107,7 +107,7 @@ def get_token_usage(self):
     def get_token_usage(self):
         return self.token_usage
 
-    def record_message(
+    def _record_message(
         self, new_message: dict, response_message: anthropic.types.Message
     ) -> None:
         self.chat_history.append([new_message])
@@ -145,7 +145,7 @@ def record_message(
             )
         ),
     )
-    def call_api(self, request_messages: list[dict]) -> anthropic.types.Message:
+    def _call_api(self, request_messages: list[dict]) -> anthropic.types.Message:
         request_messages = _merge_request(request_messages)
         if self.action_schema is not None:
             response = self.client.messages.create(
@@ -169,7 +169,7 @@ def call_api(self, request_messages: list[dict]) -> anthropic.types.Message:
         self.token_usage += response.usage.input_tokens + response.usage.output_tokens
         return response
 
-    def generate_backend_output(
+    def _generate_backend_output(
         self, response_message: anthropic.types.Message
     ) -> BackendOutput:
         message = ""
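
The chat() body changed above follows the same five-step flow in all three backend files: replay a window of history, build the new user message, call the provider, record both sides of the exchange, and convert the provider reply into a BackendOutput. A condensed sketch of that shared flow, with signatures simplified from the diff:

# Condensed from the diff: the chat() flow each backend implements.
def chat(self, message):
    if isinstance(message, tuple):  # normalize a single Message to a list
        message = [message]
    request = self._fetch_from_memory()                 # replay history window
    new_message = self._construct_new_message(message)  # provider-format message
    request.append(new_message)
    response_message = self._call_api(request)          # provider-specific call
    self._record_message(new_message, response_message)
    return self._generate_backend_output(response_message)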
20 changes: 10 additions & 10 deletions crab/agents/backend_models/gemini_model.py
@@ -71,14 +71,14 @@ def reset(self, system_message: str, action_space: list[Action] | None) -> None:
     def chat(self, message: list[Message] | Message) -> BackendOutput:
         if isinstance(message, tuple):
             message = [message]
-        request = self.fetch_from_memory()
-        new_message = self.construct_new_message(message)
+        request = self._fetch_from_memory()
+        new_message = self._construct_new_message(message)
         request.append(new_message)
-        response_message = self.call_api(request)
-        self.record_message(new_message, response_message)
-        return self.generate_backend_output(response_message)
+        response_message = self._call_api(request)
+        self._record_message(new_message, response_message)
+        return self._generate_backend_output(response_message)
 
-    def construct_new_message(self, message: list[Message]) -> dict[str, Any]:
+    def _construct_new_message(self, message: list[Message]) -> dict[str, Any]:
         parts: list[str | Image] = []
         for content, msg_type in message:
             match msg_type:
@@ -91,7 +91,7 @@ def construct_new_message(self, message: list[Message]) -> dict[str, Any]:
             "parts": parts,
         }
 
-    def generate_backend_output(self, response_message: Content) -> BackendOutput:
+    def _generate_backend_output(self, response_message: Content) -> BackendOutput:
         tool_calls: list[ActionOutput] = []
         for part in response_message.parts:
             if "function_call" in Part.to_dict(part):
@@ -108,7 +108,7 @@ def generate_backend_output(self, response_message: Content) -> BackendOutput:
             action_list=tool_calls or None,
         )
 
-    def fetch_from_memory(self) -> list[dict]:
+    def _fetch_from_memory(self) -> list[dict]:
         request: list[dict] = []
         if self.history_messages_len > 0:
             fetch_history_len = min(self.history_messages_len, len(self.chat_history))
@@ -119,7 +119,7 @@ def get_token_usage(self):
     def get_token_usage(self):
         return self.token_usage
 
-    def record_message(
+    def _record_message(
         self, new_message: dict[str, Any], response_message: Content
     ) -> None:
         self.chat_history.append([new_message])
@@ -132,7 +132,7 @@ def record_message(
         stop=stop_after_attempt(7),
         retry=retry_if_exception_type(ResourceExhausted),
     )
-    def call_api(self, request_messages: list) -> Content:
+    def _call_api(self, request_messages: list) -> Content:
         if self.action_schema is not None:
             tool_config = content_types.to_tool_config(
                 {
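
The decorator tail visible above _call_api is a tenacity retry configuration. Only stop_after_attempt(7) and retry_if_exception_type(ResourceExhausted) appear in the hunk, so the wait policy in the sketch below is an assumption; the function name call_gemini is also illustrative:

from google.api_core.exceptions import ResourceExhausted
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

@retry(
    wait=wait_exponential(multiplier=1, max=60),       # assumed backoff bounds
    stop=stop_after_attempt(7),                        # from the diff
    retry=retry_if_exception_type(ResourceExhausted),  # from the diff
)
def call_gemini(request_messages: list):
    # Placeholder for the provider call; tenacity re-invokes it whenever a
    # ResourceExhausted (quota) error propagates out, up to 7 attempts.
    ...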
26 changes: 13 additions & 13 deletions crab/agents/backend_models/openai_model.py
@@ -72,17 +72,17 @@ def reset(self, system_message: str, action_space: list[Action] | None) -> None:
     def chat(self, message: list[Message] | Message) -> BackendOutput:
         if isinstance(message, tuple):
             message = [message]
-        request = self.fetch_from_memory()
-        new_message = self.construct_new_message(message)
+        request = self._fetch_from_memory()
+        new_message = self._construct_new_message(message)
         request.append(new_message)
-        response_message = self.call_api(request)
-        self.record_message(new_message, response_message)
-        return self.generate_backend_output(response_message)
+        response_message = self._call_api(request)
+        self._record_message(new_message, response_message)
+        return self._generate_backend_output(response_message)
 
     def get_token_usage(self):
         return self.token_usage
 
-    def record_message(
+    def _record_message(
         self, new_message: dict, response_message: ChatCompletionMessage
     ) -> None:
         self.chat_history.append([new_message])
@@ -99,7 +99,7 @@ def record_message(
             }
         ) # extend conversation with function response
 
-    def call_api(
+    def _call_api(
         self, request_messages: list[ChatCompletionMessage | dict]
     ) -> ChatCompletionMessage:
         if self.action_schema is not None:
@@ -120,15 +120,15 @@ def call_api(
         self.token_usage += response.usage.total_tokens
         return response.choices[0].message
 
-    def fetch_from_memory(self) -> list[ChatCompletionMessage | dict]:
+    def _fetch_from_memory(self) -> list[ChatCompletionMessage | dict]:
         request: list[ChatCompletionMessage | dict] = [self.openai_system_message]
         if self.history_messages_len > 0:
             fetch_history_len = min(self.history_messages_len, len(self.chat_history))
             for history_message in self.chat_history[-fetch_history_len:]:
                 request = request + history_message
         return request
 
-    def construct_new_message(self, message: list[Message]) -> dict[str, Any]:
+    def _construct_new_message(self, message: list[Message]) -> dict[str, Any]:
         new_message_content: list[dict[str, Any]] = []
         for content, msg_type in message:
             match msg_type:
@@ -152,7 +152,7 @@ def construct_new_message(self, message: list[Message]) -> dict[str, Any]:
 
         return {"role": "user", "content": new_message_content}
 
-    def generate_backend_output(
+    def _generate_backend_output(
         self, response_message: ChatCompletionMessage
     ) -> BackendOutput:
         if response_message.tool_calls is None:
@@ -205,15 +205,15 @@ def reset(self, system_message: str, action_space: list[Action] | None) -> None:
         super().reset(system_message, action_space)
         self.action_schema = None
 
-    def record_message(
+    def _record_message(
         self, new_message: dict, response_message: ChatCompletionMessage
     ) -> None:
         self.chat_history.append([new_message])
         self.chat_history[-1].append(
             {"role": "assistant", "content": response_message.content}
         )
 
-    def generate_backend_output(
+    def _generate_backend_output(
         self, response_message: ChatCompletionMessage
     ) -> BackendOutput:
         content = response_message.content
@@ -240,7 +240,7 @@ def generate_backend_output(
 
 
 class SGlangOpenAIModelJSON(OpenAIModelJSON):
-    def construct_new_message(self, message: list[Message]) -> dict[str, Any]:
+    def _construct_new_message(self, message: list[Message]) -> dict[str, Any]:
         new_message_content: list[dict[str, Any]] = []
         image_count = 0
         for _, msg_type in message:
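
One reason the commit uses a single underscore rather than double-underscore name mangling: OpenAIModelJSON and SGlangOpenAIModelJSON override several of the renamed methods, and mangling would silently break those overrides. A short illustration (Base and Child are hypothetical names, not from the diff):

class Base:
    def run(self) -> str:
        return self._step()  # dispatches to the subclass override

    def _step(self) -> str:
        return "base"

class Child(Base):
    def _step(self) -> str:  # cleanly replaces Base._step
        return "child"

assert Child().run() == "child"
# Had the method been named __step, Python would mangle it to _Base__step
# inside Base, and a Child.__step definition would not override it.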
