Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reasoning (extended thinking) for claude 3.7 #750

Open
wants to merge 27 commits into
base: v2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d484103
WIP
fsatsuki Feb 19, 2025
1137b11
Merge branch 'v2' into issue-714
fsatsuki Feb 19, 2025
753d117
CI対応
fsatsuki Feb 19, 2025
cb0b31a
CI対応
fsatsuki Feb 19, 2025
a5e6c2d
discriminatedの対応
fsatsuki Feb 20, 2025
adcde72
model_validate
fsatsuki Feb 20, 2025
a5735b1
CIのビルドエラー修正
fsatsuki Feb 21, 2025
e0dfcca
レビューコメント反映
fsatsuki Feb 26, 2025
615c09e
CIエラー修正
fsatsuki Feb 26, 2025
2c2b0ab
UIのインデント調整
fsatsuki Feb 27, 2025
6eba78d
reformat
fsatsuki Feb 27, 2025
f15da66
feat: add Pyright configuration to exclude specific directories
statefb Feb 27, 2025
a8aeea3
chore: update boto3 to support reasoning
statefb Feb 27, 2025
b646edb
feat: add core functionality for reasoning
statefb Feb 27, 2025
95be571
feat: bot feature
statefb Feb 27, 2025
5fa3a9f
feat: frontend
statefb Feb 27, 2025
3114bf8
add test suit
statefb Feb 27, 2025
9a3b288
chore: fix lint err
statefb Feb 27, 2025
bc68257
delete_secret_managerの実施場所変更.
fsatsuki Feb 27, 2025
9f228f9
search_engineがfirecrawlなのに、firecrawl_configが未入力の場合はエラーにする
fsatsuki Feb 27, 2025
77dc964
apikeyのvalidation
fsatsuki Feb 27, 2025
d5f4f15
Refactor API key handling and update tool models in the bot framework
statefb Feb 28, 2025
e7321dd
Rename delete_secret_manager to delete_api_key_from_secret_manager
statefb Feb 28, 2025
ca06bc9
fix: raise error when failed for internet search tool
statefb Feb 28, 2025
d2ad5f1
feat: enhance Firecrawl integration with improved validation and lega…
statefb Feb 28, 2025
0eec2df
chore: lint
statefb Feb 28, 2025
f36b9f5
resolve conflicts
statefb Feb 28, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 117 additions & 16 deletions backend/app/agents/tools/internet_search.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import logging

from app.agents.tools.agent_tool import AgentTool
from app.repositories.models.custom_bot import BotModel
from app.repositories.models.custom_bot import BotModel, InternetToolModel
from app.routes.schemas.conversation import type_model_name
from duckduckgo_search import DDGS
from firecrawl.firecrawl import FirecrawlApp
from pydantic import BaseModel, Field, root_validator

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class InternetSearchInput(BaseModel):
query: str = Field(description="The query to search for on the internet.")
Expand Down Expand Up @@ -33,38 +39,133 @@ def validate_country(cls, values):
return values


def internet_search(
tool_input: InternetSearchInput, bot: BotModel | None, model: type_model_name | None
) -> list:
query = tool_input.query
time_limit = tool_input.time_limit
country = tool_input.country

def _search_with_duckduckgo(query: str, time_limit: str, country: str) -> list:
    """Search the internet with DuckDuckGo and return formatted results.

    Args:
        query: Search query string.
        time_limit: DuckDuckGo time-limit filter (presumably "d"/"w"/"m"
            style codes — TODO confirm against InternetSearchInput).
        country: DuckDuckGo region code for the search.

    Returns:
        A list of dicts with ``content``, ``source_name`` and
        ``source_link`` keys, one per search hit (at most 20).
    """
    REGION = country
    SAFE_SEARCH = "moderate"
    MAX_RESULTS = 20
    BACKEND = "api"
    logger.info(
        f"Executing DuckDuckGo search with query: {query}, region: {REGION}, time_limit: {time_limit}"
    )
    with DDGS() as ddgs:
        # Materialize the generator while the DDGS session is still open.
        results = list(
            ddgs.text(
                keywords=query,
                region=REGION,
                safesearch=SAFE_SEARCH,
                timelimit=time_limit,
                max_results=MAX_RESULTS,
                backend=BACKEND,
            )
        )
    logger.info(f"DuckDuckGo search completed. Found {len(results)} results")
    return [
        {
            "content": result["body"],
            "source_name": result["title"],
            "source_link": result["href"],
        }
        for result in results
    ]


def _search_with_firecrawl(
    query: str, api_key: str, country: str, max_results: int = 10
) -> list:
    """Search the internet via the Firecrawl API and return formatted results.

    Args:
        query: Search query string.
        api_key: Firecrawl API key.
        country: Locale code passed to Firecrawl as ``lang``.
        max_results: Maximum number of results to request.

    Returns:
        A list of dicts with ``content``, ``source_name`` and
        ``source_link`` keys; empty list when Firecrawl returns nothing.

    Raises:
        Exception: Any error raised by the Firecrawl client is logged and
            re-raised to the caller.
    """
    logger.info(f"Searching with Firecrawl. Query: {query}, Max Results: {max_results}")

    try:
        app = FirecrawlApp(api_key=api_key)

        # Search using Firecrawl
        # SearchParams: /~https://github.com/mendableai/firecrawl/blob/main/apps/python-sdk/firecrawl/firecrawl.py#L24
        results = app.search(
            query,
            {
                "limit": max_results,
                "lang": country,
                "scrapeOptions": {"formats": ["markdown"], "onlyMainContent": True},
            },
        )

        if not results:
            logger.warning("No results found")
            return []
        logger.info(f"results of firecrawl: {results}")

        # Format search results. "markdown" holds the scraped page text, so
        # the missing-key default is an empty string (was a dict before,
        # which put a non-string into the "content" field).
        search_results = [
            {
                "content": data.get("markdown", ""),
                "source_name": data.get("title", ""),
                "source_link": data.get("metadata", {}).get("sourceURL", ""),
            }
            for data in results.get("data", [])
            if isinstance(data, dict)
        ]

        logger.info(f"Found {len(search_results)} results from Firecrawl")
        return search_results

    except Exception as e:
        logger.error(f"Error searching with Firecrawl: {e}")
        # Bare raise is the idiom for re-raising the active exception.
        raise


def _internet_search(
    tool_input: InternetSearchInput, bot: BotModel | None, model: type_model_name | None
) -> list:
    """Route an internet search to the engine configured on the bot.

    Falls back to DuckDuckGo when no bot is supplied, when the bot has no
    internet tool configured, or when the configured engine is unknown.

    Args:
        tool_input: Validated search input (query, time_limit, country).
        bot: Bot whose agent tools may carry an ``InternetToolModel``.
        model: Model name (unused here; part of the AgentTool signature).

    Returns:
        A list of result dicts with ``content``, ``source_name`` and
        ``source_link`` keys.

    Raises:
        ValueError: If Firecrawl is selected but its configuration or API
            key is missing.
    """
    query = tool_input.query
    time_limit = tool_input.time_limit
    country = tool_input.country

    logger.info(
        f"Internet search request - Query: {query}, Time Limit: {time_limit}, Country: {country}"
    )

    if bot is None:
        logger.warning("Bot is None, defaulting to DuckDuckGo search")
        return _search_with_duckduckgo(query, time_limit, country)

    # Find internet search tool
    internet_tool = next(
        (tool for tool in bot.agent.tools if isinstance(tool, InternetToolModel)),
        None,
    )

    # If no internet tool found or search engine is duckduckgo, use DuckDuckGo
    if not internet_tool or internet_tool.search_engine == "duckduckgo":
        logger.info("No internet tool found or search engine is DuckDuckGo")
        return _search_with_duckduckgo(query, time_limit, country)

    # Handle Firecrawl search
    if internet_tool.search_engine == "firecrawl":
        if not internet_tool.firecrawl_config:
            raise ValueError("Firecrawl configuration is not set in the bot.")

        # Validate configuration outside the try so config errors are not
        # mislabeled as search failures.
        api_key = internet_tool.firecrawl_config.api_key
        if not api_key:
            raise ValueError("Firecrawl API key is empty")

        try:
            return _search_with_firecrawl(
                query=query,
                api_key=api_key,
                country=country,
                max_results=internet_tool.firecrawl_config.max_results,
            )
        except Exception as e:
            logger.error(f"Error with Firecrawl search: {e}")
            # Bare raise is the idiom for re-raising the active exception.
            raise

    # Fallback to DuckDuckGo for any unexpected cases
    logger.warning("Unexpected search engine configuration, falling back to DuckDuckGo")
    return _search_with_duckduckgo(query, time_limit, country)


# Tool registration for the agent framework. The pasted diff left both the
# old (`internet_search`) and new (`_internet_search`) function= lines in
# place, which would be a duplicate-keyword SyntaxError; keep the new one.
internet_search_tool = AgentTool(
    name="internet_search",
    description="Search the internet for information.",
    args_schema=InternetSearchInput,
    function=_internet_search,
)
101 changes: 72 additions & 29 deletions backend/app/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from mypy_boto3_bedrock_runtime.type_defs import (
ContentBlockTypeDef,
ConverseResponseTypeDef,
ConverseStreamRequestRequestTypeDef,
ConverseStreamRequestTypeDef,
GuardrailConverseContentBlockTypeDef,
InferenceConfigurationTypeDef,
MessageTypeDef,
Expand Down Expand Up @@ -105,7 +105,8 @@ def compose_args_for_converse_api(
grounding_source: GuardrailConverseContentBlockTypeDef | None = None,
tools: dict[str, AgentTool] | None = None,
stream: bool = True,
) -> ConverseStreamRequestRequestTypeDef:
enable_reasoning: bool = False,
) -> ConverseStreamRequestTypeDef:
def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
if c.content_type == "text":
if (
Expand Down Expand Up @@ -142,6 +143,7 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
inference_config: InferenceConfigurationTypeDef
additional_model_request_fields: dict[str, Any]
system_prompts: list[SystemContentBlockTypeDef]

if is_nova_model(model):
# Special handling for Nova models
inference_config, additional_model_request_fields = _prepare_nova_model_params(
Expand All @@ -159,35 +161,76 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:

else:
# Standard handling for non-Nova models
inference_config = {
"maxTokens": (
if enable_reasoning:
budget_tokens = (
generation_params.reasoning_params.budget_tokens
if generation_params and generation_params.reasoning_params
else DEFAULT_GENERATION_CONFIG["reasoning_params"]["budget_tokens"] # type: ignore
)
max_tokens = (
generation_params.max_tokens
if generation_params
else DEFAULT_GENERATION_CONFIG["max_tokens"]
),
"temperature": (
generation_params.temperature
if generation_params
else DEFAULT_GENERATION_CONFIG["temperature"]
),
"topP": (
generation_params.top_p
if generation_params
else DEFAULT_GENERATION_CONFIG["top_p"]
),
"stopSequences": (
generation_params.stop_sequences
if generation_params
else DEFAULT_GENERATION_CONFIG.get("stop_sequences", [])
),
}
additional_model_request_fields = {
"top_k": (
generation_params.top_k
if generation_params
else DEFAULT_GENERATION_CONFIG["top_k"]
)
}

if max_tokens <= budget_tokens:
logger.warning(
f"max_tokens ({max_tokens}) must be greater than budget_tokens ({budget_tokens}). "
f"Setting max_tokens to {budget_tokens + 1024}"
)
max_tokens = budget_tokens + 1024

inference_config = {
"maxTokens": max_tokens,
"temperature": 1.0, # Force temperature to 1.0 when reasoning is enabled
"topP": (
generation_params.top_p
if generation_params
else DEFAULT_GENERATION_CONFIG["top_p"]
),
"stopSequences": (
generation_params.stop_sequences
if generation_params
else DEFAULT_GENERATION_CONFIG.get("stop_sequences", [])
),
}
additional_model_request_fields = {
# top_k cannot be used with reasoning
"thinking": {
"type": "enabled",
"budget_tokens": budget_tokens,
},
}
else:
inference_config = {
"maxTokens": (
generation_params.max_tokens
if generation_params
else DEFAULT_GENERATION_CONFIG["max_tokens"]
),
"temperature": (
generation_params.temperature
if generation_params
else DEFAULT_GENERATION_CONFIG["temperature"]
),
"topP": (
generation_params.top_p
if generation_params
else DEFAULT_GENERATION_CONFIG["top_p"]
),
"stopSequences": (
generation_params.stop_sequences
if generation_params
else DEFAULT_GENERATION_CONFIG.get("stop_sequences", [])
),
}
additional_model_request_fields = {
"top_k": (
generation_params.top_k
if generation_params
else DEFAULT_GENERATION_CONFIG["top_k"]
),
}
system_prompts = [
{
"text": instruction,
Expand All @@ -197,7 +240,7 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
]

# Construct the base arguments
args: ConverseStreamRequestRequestTypeDef = {
args: ConverseStreamRequestTypeDef = {
"inferenceConfig": inference_config,
"modelId": get_model_id(model),
"messages": arg_messages,
Expand Down Expand Up @@ -230,7 +273,7 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:


def call_converse_api(
args: ConverseStreamRequestRequestTypeDef,
args: ConverseStreamRequestTypeDef,
) -> ConverseResponseTypeDef:
client = get_bedrock_runtime_client()

Expand Down
2 changes: 2 additions & 0 deletions backend/app/bot_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
find_usage_plan_by_id,
)
from app.repositories.common import RecordNotFoundError, decompose_bot_id
from app.utils import delete_api_key_from_secret_manager

DOCUMENT_BUCKET = os.environ.get("DOCUMENT_BUCKET", "documents")
BEDROCK_REGION = os.environ.get("BEDROCK_REGION", "us-east-1")
Expand Down Expand Up @@ -75,6 +76,7 @@ def handler(event: dict, context: Any) -> None:

delete_from_s3(user_id, bot_id)
delete_custom_bot_stack_by_bot_id(bot_id)
delete_api_key_from_secret_manager(user_id, bot_id, "firecrawl")

# Check if api published stack exists
try:
Expand Down
11 changes: 8 additions & 3 deletions backend/app/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing_extensions import TypedDict
from typing_extensions import NotRequired, TypedDict


class GenerationParams(TypedDict):
Expand All @@ -7,6 +7,7 @@ class GenerationParams(TypedDict):
top_p: float
temperature: float
stop_sequences: list[str]
reasoning_params: NotRequired[dict[str, int]]


class EmbeddingConfig(TypedDict):
Expand All @@ -20,11 +21,15 @@ class EmbeddingConfig(TypedDict):
# Adjust the values according to your application.
# See: https://docs.anthropic.com/claude/reference/complete_post
# Default LLM generation parameters. The pasted diff interleaved the old
# values ("max_tokens": 2000, "temperature": 0.6) with the new ones,
# producing duplicate dict keys; keep only the post-change values.
DEFAULT_GENERATION_CONFIG: GenerationParams = {
    # Minimum (Haiku) is 4096
    # Ref: https://docs.anthropic.com/en/docs/about-claude/models/all-models#model-comparison
    "max_tokens": 4096,
    "top_k": 250,
    "top_p": 0.999,
    "temperature": 1.0,
    "stop_sequences": ["Human: ", "Assistant: "],
    # Budget tokens must NOT exceed max_tokens
    "reasoning_params": {"budget_tokens": 1024},
}

# Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html#model-parameters-mistral-request-response
Expand Down
Loading