Skip to content

Commit

Permalink
chore: sync the change on main branch aws (#326)
Browse files Browse the repository at this point in the history
* chore: sync the change on main branch aws

* chore: change the author's name

* chore: add missing dependencies
  • Loading branch information
Nov1c444 authored Feb 28, 2025
1 parent a58259a commit 7923aae
Show file tree
Hide file tree
Showing 23 changed files with 2,831 additions and 82 deletions.
4 changes: 2 additions & 2 deletions tools/aws/manifest.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
version: 0.0.2
type: plugin
author: "aws"
author: langgenius
name: "aws_tools"
label:
en_US: "AWS"
Expand All @@ -20,7 +20,7 @@ plugins:
tools:
- "provider/aws.yaml"
meta:
version: 0.0.1
version: 0.0.2
arch:
- "amd64"
- "arm64"
Expand Down
7 changes: 7 additions & 0 deletions tools/aws/provider/aws.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,17 @@ identity:
credentials_for_provider: {}
tools:
- tools/apply_guardrail.yaml
- tools/bedrock_retrieve.yaml
- tools/bedrock_retrieve_and_generate.yaml
- tools/lambda_translate_utils.yaml
- tools/lambda_yaml_to_json.yaml
- tools/nova_canvas.yaml
- tools/nova_reel.yaml
- tools/s3_operator.yaml
- tools/sagemaker_chinese_toxicity_detector.yaml
- tools/sagemaker_text_rerank.yaml
- tools/sagemaker_tts.yaml
- tools/transcribe_asr.yaml
extra:
python:
source: provider/aws.py
5 changes: 3 additions & 2 deletions tools/aws/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
dify_plugin==0.0.1b65
boto3==1.35.26
dify_plugin==0.0.1b68
boto3==1.36.12
pillow==11.0.0
25 changes: 13 additions & 12 deletions tools/aws/tools/apply_guardrail.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import json
import logging
from typing import Any, Union

import boto3
from botocore.exceptions import BotoCoreError
from pydantic import BaseModel, Field
from typing import Any, Generator

import boto3 # type: ignore
from botocore.exceptions import BotoCoreError # type: ignore
from dify_plugin import Tool
from dify_plugin.entities.tool import ToolInvokeMessage
from pydantic import BaseModel, Field

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand All @@ -23,8 +22,8 @@ class GuardrailParameters(BaseModel):

class ApplyGuardrailTool(Tool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
self, tool_parameters: dict[str, Any]
) -> Generator[ToolInvokeMessage, None, None]:
"""
Invoke the ApplyGuardrail tool
"""
Expand All @@ -49,7 +48,7 @@ def _invoke(

# Check for empty response
if not response:
return self.create_text_message(
yield self.create_text_message(
text="Received empty response from AWS Bedrock."
)

Expand Down Expand Up @@ -84,17 +83,19 @@ def _invoke(
result += "Assessments:\n " + "\n ".join(formatted_assessments) + "\n "
# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}"

return self.create_text_message(text=result)
yield self.create_text_message(text=result)

except BotoCoreError as e:
error_message = f"AWS service error: {str(e)}"
logger.error(error_message, exc_info=True)
return self.create_text_message(text=error_message)
yield self.create_text_message(text=error_message)
return
except json.JSONDecodeError as e:
error_message = f"JSON parsing error: {str(e)}"
logger.error(error_message, exc_info=True)
return self.create_text_message(text=error_message)
yield self.create_text_message(text=error_message)
return
except Exception as e:
error_message = f"An unexpected error occurred: {str(e)}"
logger.error(error_message, exc_info=True)
return self.create_text_message(text=error_message)
yield self.create_text_message(text=error_message)
187 changes: 187 additions & 0 deletions tools/aws/tools/bedrock_retrieve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import json
import operator
from typing import Any, Generator, Optional

import boto3
from dify_plugin import Tool
from dify_plugin.entities.tool import ToolInvokeMessage


class BedrockRetrieveTool(Tool):
bedrock_client: Any = None
knowledge_base_id: str = ""
topk: int = 0

def _bedrock_retrieve(
self,
query_input: str,
knowledge_base_id: str,
num_results: int,
search_type: str,
rerank_model_id: str,
metadata_filter: Optional[dict] = None,
):
try:
retrieval_query = {"text": query_input}

if search_type not in ["HYBRID", "SEMANTIC"]:
raise ValueError("search_type should be HYBRID or SEMANTIC")

retrieval_configuration = {
"vectorSearchConfiguration": {
"numberOfResults": num_results,
"overrideSearchType": search_type,
}
}

if rerank_model_id != "default":
model_for_rerank_arn = (
f"arn:aws:bedrock:us-west-2::foundation-model/{rerank_model_id}"
)
rerankingConfiguration = {
"bedrockRerankingConfiguration": {
"numberOfRerankedResults": num_results,
"modelConfiguration": {"modelArn": model_for_rerank_arn},
},
"type": "BEDROCK_RERANKING_MODEL",
}

retrieval_configuration["vectorSearchConfiguration"][
"rerankingConfiguration"
] = rerankingConfiguration
retrieval_configuration["vectorSearchConfiguration"][
"numberOfResults"
] = num_results * 5

# 如果有元数据过滤条件,则添加到检索配置中
if metadata_filter:
retrieval_configuration["vectorSearchConfiguration"]["filter"] = (
metadata_filter
)

response = self.bedrock_client.retrieve(
knowledgeBaseId=knowledge_base_id,
retrievalQuery=retrieval_query,
retrievalConfiguration=retrieval_configuration,
)

results = []
for result in response.get("retrievalResults", []):
results.append(
{
"content": result.get("content", {}).get("text", ""),
"score": result.get("score", 0.0),
"metadata": result.get("metadata", {}),
}
)

return results
except Exception as e:
raise Exception(f"Error retrieving from knowledge base: {str(e)}")

def _invoke(
self,
tool_parameters: dict[str, Any],
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke tools
"""
try:
line = 0
# Initialize Bedrock client if not already initialized
if not self.bedrock_client:
aws_region = tool_parameters.get("aws_region")
aws_access_key_id = tool_parameters.get("aws_access_key_id")
aws_secret_access_key = tool_parameters.get("aws_secret_access_key")

client_kwargs = {
"service_name": "bedrock-agent-runtime",
"region_name": aws_region or None,
}

# Only add credentials if both access key and secret key are provided
if aws_access_key_id and aws_secret_access_key:
client_kwargs.update(
{
"aws_access_key_id": aws_access_key_id,
"aws_secret_access_key": aws_secret_access_key,
}
)

self.bedrock_client = boto3.client(**client_kwargs)
except Exception as e:
yield self.create_text_message(
f"Failed to initialize Bedrock client: {str(e)}"
)
return
line = 0
try:
line = 1
if not self.knowledge_base_id:
self.knowledge_base_id = tool_parameters.get("knowledge_base_id", "")
if not self.knowledge_base_id:
yield self.create_text_message("Please provide knowledge_base_id")

line = 2
if not self.topk:
self.topk = tool_parameters.get("topk", 5)

line = 3
query = tool_parameters.get("query", "")
if not query:
yield self.create_text_message("Please input query")

# 获取元数据过滤条件(如果存在)
metadata_filter_str = tool_parameters.get("metadata_filter")
metadata_filter = (
json.loads(metadata_filter_str) if metadata_filter_str else None
)

search_type = tool_parameters.get("search_type", "")
rerank_model_id = tool_parameters.get("rerank_model_id", "")

line = 4
retrieved_docs = self._bedrock_retrieve(
query_input=query,
knowledge_base_id=self.knowledge_base_id,
num_results=self.topk,
search_type=search_type,
rerank_model_id=rerank_model_id,
metadata_filter=metadata_filter,
)

line = 5
# Sort results by score in descending order
sorted_docs = sorted(
retrieved_docs, key=operator.itemgetter("score"), reverse=True
)

line = 6
result_type = tool_parameters.get("result_type")
if result_type == "json":
for res in sorted_docs:
yield self.create_json_message(res)
else:
text = ""
for i, res in enumerate(sorted_docs):
text += f"{i + 1}: {res['content']}\n"
yield self.create_text_message(text)

except Exception as e:
yield self.create_text_message(f"Exception {str(e)}, line : {line}")

def validate_parameters(self, parameters: dict[str, Any]) -> None:
"""
Validate the parameters
"""
if not parameters.get("knowledge_base_id"):
raise ValueError("knowledge_base_id is required")

if not parameters.get("query"):
raise ValueError("query is required")

metadata_filter_str = parameters.get("metadata_filter")
if metadata_filter_str and not isinstance(
json.loads(metadata_filter_str), dict
):
raise ValueError("metadata_filter must be a valid JSON object")
Loading

0 comments on commit 7923aae

Please sign in to comment.