工具集成与外部API调用
目录
工具集成概述
什么是工具集成
工具集成允许Claude与外部系统、API和服务进行交互,扩展其基础能力。通过工具集成,Claude可以执行计算、查询数据库、调用第三方服务等操作。
核心优势
能力扩展
- 实时数据访问:获取最新的外部数据
- 计算能力增强:执行复杂的数学和统计计算
- 系统集成:与企业系统和数据库集成
- 服务编排:协调多个外部服务
灵活性提升
- 动态功能添加:根据需要添加新的工具
- 自定义业务逻辑:实现特定的业务需求
- 工作流自动化:自动化复杂的工作流程
- 响应式交互:根据上下文智能选择工具
实用价值
- 决策支持:基于实时数据做出决策
- 效率提升:自动化重复性任务
- 准确性提升:借助外部计算和权威数据源校验结果,减少模型自行推算带来的误差
- 用户体验:提供更丰富的交互体验
工具定义与配置
基本工具定义
简单计算工具
import json
import re
import time

import anthropic
def define_calculator_tool():
    """Return the Anthropic tool definition for the calculator tool.

    The tool takes a single required string argument, ``expression``.
    """
    expression_property = {
        "type": "string",
        "description": "要计算的数学表达式,如 '2 + 3 * 4' 或 'sqrt(16)'",
    }
    input_schema = {
        "type": "object",
        "properties": {"expression": expression_property},
        "required": ["expression"],
    }
    return {
        "name": "calculator",
        "description": "执行数学计算,支持基本运算、三角函数、对数等",
        "input_schema": input_schema,
    }
def execute_calculator(expression):
    """Safely evaluate an arithmetic expression for the calculator tool.

    The expression is parsed with :mod:`ast` and only a whitelist of node
    types is evaluated: int/float literals, the binary operators
    ``+ - * / // % **``, unary ``+``/``-``, the constants ``pi`` and ``e``,
    and calls to sqrt/sin/cos/tan/log/exp/abs/round with positional
    arguments. Nothing is handed to ``eval``, so model-supplied input cannot
    execute arbitrary code.

    Args:
        expression: the expression text, e.g. ``"2 + 3 * 4"`` or ``"sqrt(16)"``.

    Returns:
        ``{"result", "expression", "success": True}`` on success, or
        ``{"error", "expression", "success": False}`` on any failure
        (syntax error, disallowed construct, math domain error, ...).
    """
    import ast
    import math
    import operator

    functions = {
        'sqrt': math.sqrt,
        'sin': math.sin,
        'cos': math.cos,
        'tan': math.tan,
        'log': math.log,
        'exp': math.exp,
        'abs': abs,
        'round': round,
    }
    constants = {'pi': math.pi, 'e': math.e}

    def safe_pow(base, exponent):
        # Cap exponents so inputs such as 9 ** 9 ** 9 cannot hang the process
        # or exhaust memory building an enormous integer.
        too_big = abs(exponent) > 1000 or (
            isinstance(base, int)
            and isinstance(exponent, int)
            and base.bit_length() * abs(exponent) > 100_000
        )
        if too_big:
            raise ValueError("exponent too large")
        return operator.pow(base, exponent)

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: safe_pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def evaluate(node):
        """Recursively evaluate a whitelisted AST node; reject anything else."""
        if isinstance(node, ast.Expression):
            return evaluate(node.body)
        # type() (not isinstance) so bool/complex literals are rejected.
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            return binary_ops[type(node.op)](evaluate(node.left), evaluate(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        if isinstance(node, ast.Name) and node.id in constants:
            return constants[node.id]
        if (isinstance(node, ast.Call)
                and isinstance(node.func, ast.Name)
                and node.func.id in functions
                and not node.keywords):
            return functions[node.func.id](*[evaluate(arg) for arg in node.args])
        raise ValueError(f"unsupported expression element: {type(node).__name__}")

    try:
        # eval-mode parsing rejects leading indentation, so trim whitespace first.
        source = expression.strip() if isinstance(expression, str) else expression
        result = evaluate(ast.parse(source, mode="eval"))
        if isinstance(result, complex):
            # e.g. (-8) ** 0.5; complex numbers are not JSON-serializable.
            raise ValueError("complex results are not supported")
        return {
            "result": result,
            "expression": expression,
            "success": True
        }
    except Exception as e:
        return {
            "error": str(e),
            "expression": expression,
            "success": False
        }
数据库查询工具
def define_database_tool():
    """Return the tool schema for read-only SQL queries.

    ``query`` is the SQL text and ``database`` must be one of a fixed set of
    database names; both are required.
    """
    database_names = ["users", "products", "orders", "analytics"]
    properties = {
        "query": {"type": "string", "description": "SQL查询语句,仅支持SELECT操作"},
        "database": {"type": "string", "description": "数据库名称", "enum": database_names},
    }
    schema = {"type": "object", "properties": properties, "required": ["query", "database"]}
    return {
        "name": "database_query",
        "description": "执行SQL查询获取数据库中的信息",
        "input_schema": schema,
    }
def execute_database_query(query, database):
    """Run a read-only SQL query against one of the configured SQLite databases.

    Args:
        query: SQL text; it must pass ``is_safe_query`` (a SELECT-only heuristic).
        database: logical database name, resolved to a file via ``get_database_path``.

    Returns:
        ``{"results", "columns", "row_count", "success": True}`` on success,
        otherwise ``{"error", "success": False}``.
    """
    import sqlite3
    from contextlib import closing

    # Validate query safety before touching the database.
    if not is_safe_query(query):
        return {
            "error": "不安全的查询,仅允许SELECT操作",
            "success": False
        }
    try:
        db_path = get_database_path(database)
        # closing() guarantees the connection is released even when
        # execute()/fetchall() raise; the original leaked it on error.
        with closing(sqlite3.connect(db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(query)
            results = cursor.fetchall()
            # description is None for statements that return no result set.
            column_names = [column[0] for column in (cursor.description or [])]
        return {
            "results": results,
            "columns": column_names,
            "row_count": len(results),
            "success": True
        }
    except Exception as e:
        return {
            "error": str(e),
            "success": False
        }
def is_safe_query(query):
    """Heuristically check that ``query`` is a read-only SELECT statement.

    The query must start with ``select`` and must not contain any of the
    forbidden keywords as a whole word. Whole-word matching avoids
    rejecting ordinary identifiers such as ``created_at`` or ``updated_at``,
    which plain substring matching did.

    This is a blacklist heuristic, not a guarantee; prefer a read-only
    database connection as the real safeguard.

    Returns:
        True if the query looks safe, False otherwise (including non-str input).
    """
    import re

    if not isinstance(query, str):
        return False
    query_lower = query.lower().strip()
    # Only SELECT statements are allowed.
    if not query_lower.startswith('select'):
        return False
    forbidden_keywords = [
        'insert', 'update', 'delete', 'drop', 'create',
        'alter', 'truncate', 'exec', 'execute'
    ]
    forbidden = re.compile(r'\b(?:' + '|'.join(forbidden_keywords) + r')\b')
    return forbidden.search(query_lower) is None
网络API工具
def define_web_api_tool():
    """Return the tool schema for calling a whitelisted external HTTP API.

    ``url`` and ``method`` (GET/POST) are required; ``headers`` and the POST
    body ``data`` are optional objects.
    """
    def object_field(description):
        return {"type": "object", "description": description}

    fields = {
        "url": {"type": "string", "description": "API端点URL"},
        "method": {"type": "string", "description": "HTTP方法", "enum": ["GET", "POST"]},
        "headers": object_field("HTTP请求头"),
        "data": object_field("请求数据(POST方法时使用)"),
    }
    return {
        "name": "web_api_call",
        "description": "调用外部Web API获取数据",
        "input_schema": {"type": "object", "properties": fields, "required": ["url", "method"]},
    }
def execute_web_api_call(url, method, headers=None, data=None):
    """Perform a GET or POST request against a whitelisted URL.

    Args:
        url: target URL; must pass ``is_allowed_url``.
        method: "GET" or "POST" (case-insensitive).
        headers: optional extra request headers, merged over the defaults.
        data: optional JSON body, sent only for POST.

    Returns:
        ``{"status_code", "headers", "data", "text", "success"}`` where
        ``data`` is the parsed JSON body (or None) and ``text`` is the raw body
        only when it was not JSON. ``{"error", "success": False}`` on rejection
        or transport failure.
    """
    import requests

    # URL whitelist check.
    if not is_allowed_url(url):
        return {
            "error": "URL不在允许的白名单中",
            "success": False
        }
    normalized_method = str(method or "").upper()
    if normalized_method not in ("GET", "POST"):
        # Previously an unsupported method left `response` unbound and
        # surfaced as a confusing UnboundLocalError message.
        return {
            "error": f"不支持的HTTP方法: {method}",
            "success": False
        }
    request_headers = {
        "User-Agent": "Claude-Assistant/1.0",
        "Accept": "application/json"
    }
    if headers:
        request_headers.update(headers)
    try:
        response = requests.request(
            normalized_method,
            url,
            headers=request_headers,
            json=data if normalized_method == "POST" else None,
            timeout=30,
            # A redirect could send the request to a host outside the
            # whitelist (SSRF), so redirects are not followed.
            allow_redirects=False,
        )
    except Exception as e:
        return {
            "error": str(e),
            "success": False
        }
    try:
        json_data = response.json()
    except ValueError:
        # Body is not JSON; it is returned via "text" instead.
        json_data = None
    return {
        "status_code": response.status_code,
        "headers": dict(response.headers),
        "data": json_data,
        # `is None` so a legitimately empty JSON body ({} or []) is not
        # mistaken for a non-JSON response.
        "text": response.text if json_data is None else None,
        "success": response.status_code < 400
    }
def is_allowed_url(url):
    """Return True if ``url`` is an http(s) URL whose host is whitelisted.

    The host is taken from ``urlparse(url).hostname``, which is lowercased and
    excludes any port and ``user@`` prefix. Comparing raw ``netloc`` wrongly
    rejected ``api.github.com:443`` and mixed-case hosts. Non-http(s) schemes
    and unparsable input are rejected.
    """
    from urllib.parse import urlparse

    allowed_domains = {
        "api.openweathermap.org",
        "api.github.com",
        "jsonplaceholder.typicode.com",
        "httpbin.org"
    }
    try:
        parsed_url = urlparse(url)
        hostname = parsed_url.hostname
    except (ValueError, TypeError, AttributeError):
        return False
    if parsed_url.scheme not in ("http", "https"):
        return False
    return hostname in allowed_domains
复合工具系统
文件处理工具集
def define_file_tools():
    """Return the schemas for the file tools: read_file, write_file and list_directory."""
    def make_tool(name, description, properties, required):
        # Wrap a property map in the common tool/input_schema envelope.
        return {
            "name": name,
            "description": description,
            "input_schema": {"type": "object", "properties": properties, "required": required},
        }

    def string_field(description, **extra):
        return {"type": "string", "description": description, **extra}

    read_tool = make_tool(
        "read_file",
        "读取文件内容",
        {
            "file_path": string_field("文件路径"),
            "encoding": string_field("文件编码", default="utf-8"),
        },
        ["file_path"],
    )
    write_tool = make_tool(
        "write_file",
        "写入文件内容",
        {
            "file_path": string_field("文件路径"),
            "content": string_field("要写入的内容"),
            "mode": string_field("写入模式", enum=["write", "append"], default="write"),
        },
        ["file_path", "content"],
    )
    list_tool = make_tool(
        "list_directory",
        "列出目录内容",
        {
            "directory_path": string_field("目录路径"),
            "include_hidden": {"type": "boolean", "description": "是否包含隐藏文件", "default": False},
        },
        ["directory_path"],
    )
    return [read_tool, write_tool, list_tool]
class FileToolExecutor:
    """Execute the read_file / write_file / list_directory tools inside a sandbox directory.

    Every path argument is interpreted relative to ``base_path``; paths that
    resolve (after following symlinks) outside that directory are rejected.
    """

    def __init__(self, base_path="/safe/workspace"):
        # Root of the sandbox; all tool paths must stay inside it.
        self.base_path = base_path

    def execute_tool(self, tool_name, **kwargs):
        """Dispatch ``tool_name`` to the matching method with ``kwargs``."""
        handlers = {
            "read_file": self.read_file,
            "write_file": self.write_file,
            "list_directory": self.list_directory,
        }
        handler = handlers.get(tool_name)
        if handler is None:
            return {"error": f"未知工具: {tool_name}", "success": False}
        return handler(**kwargs)

    def read_file(self, file_path, encoding="utf-8"):
        """Read a text file; ``size`` in the result is the length in characters."""
        safe_path = self.get_safe_path(file_path)
        if not safe_path:
            return {"error": "不安全的文件路径", "success": False}
        try:
            with open(safe_path, 'r', encoding=encoding) as f:
                content = f.read()
            return {
                "content": content,
                "file_path": file_path,
                "size": len(content),
                "success": True
            }
        except Exception as e:
            return {"error": str(e), "success": False}

    def write_file(self, file_path, content, mode="write"):
        """Write (``mode="write"``) or append (``mode="append"``) UTF-8 text."""
        safe_path = self.get_safe_path(file_path)
        if not safe_path:
            return {"error": "不安全的文件路径", "success": False}
        # Any other value previously fell through to append mode silently.
        if mode not in ("write", "append"):
            return {"error": f"不支持的写入模式: {mode}", "success": False}
        try:
            file_mode = 'w' if mode == 'write' else 'a'
            with open(safe_path, file_mode, encoding='utf-8') as f:
                f.write(content)
            return {
                "file_path": file_path,
                "bytes_written": len(content.encode('utf-8')),
                "mode": mode,
                "success": True
            }
        except Exception as e:
            return {"error": str(e), "success": False}

    def list_directory(self, directory_path, include_hidden=False):
        """List a directory's entries, sorted by name.

        Dot-prefixed entries are skipped unless ``include_hidden`` is true.
        Each entry is ``{"name": str, "is_dir": bool}``.
        """
        import os

        safe_path = self.get_safe_path(directory_path)
        if not safe_path:
            return {"error": "不安全的文件路径", "success": False}
        try:
            with os.scandir(safe_path) as it:
                entries = [
                    {"name": entry.name, "is_dir": entry.is_dir()}
                    for entry in it
                    if include_hidden or not entry.name.startswith('.')
                ]
        except Exception as e:
            return {"error": str(e), "success": False}
        entries.sort(key=lambda item: item["name"])
        return {
            "directory_path": directory_path,
            "entries": entries,
            "count": len(entries),
            "success": True
        }

    def get_safe_path(self, file_path):
        """Resolve ``file_path`` inside ``base_path``, or return None if unsafe.

        Absolute paths, NUL bytes, and anything that resolves outside the
        sandbox (via ``..`` or symlinks) are rejected. ``commonpath`` is used
        instead of a string prefix test so ``/safe/workspace2`` is not treated
        as inside ``/safe/workspace``.
        """
        import os

        if not isinstance(file_path, str) or '\x00' in file_path or os.path.isabs(file_path):
            return None
        base = os.path.realpath(self.base_path)
        candidate = os.path.realpath(os.path.join(base, file_path))
        try:
            if os.path.commonpath([base, candidate]) != base:
                return None
        except ValueError:
            # e.g. paths on different drives on Windows.
            return None
        return candidate
API集成模式
RESTful API集成
天气API集成
def define_weather_tool():
    """Return the schema for the current-weather lookup tool.

    ``city`` is required; ``units`` selects metric or imperial (default metric).
    """
    city_field = {"type": "string", "description": "城市名称,如'北京'或'Beijing'"}
    units_field = {
        "type": "string",
        "description": "温度单位",
        "enum": ["metric", "imperial"],
        "default": "metric",
    }
    return dict(
        name="get_weather",
        description="获取指定城市的当前天气信息",
        input_schema=dict(
            type="object",
            properties=dict(city=city_field, units=units_field),
            required=["city"],
        ),
    )
class WeatherAPI:
    """Client for OpenWeatherMap's current-weather endpoint (``/weather``)."""

    def __init__(self, api_key):
        # Endpoint root and the key sent as the ``appid`` query parameter.
        self.api_key = api_key
        self.base_url = "https://api.openweathermap.org/data/2.5"

    def get_weather(self, city, units="metric"):
        """Fetch current conditions for ``city``.

        Returns a flat dict of readings plus ``success: True`` on HTTP 200.
        On any other status, or any exception while requesting/parsing,
        returns ``{"error": ..., "success": False}``.
        """
        import requests

        try:
            endpoint = self.base_url + "/weather"
            query = dict(q=city, appid=self.api_key, units=units, lang="zh_cn")
            reply = requests.get(endpoint, params=query, timeout=10)
            if reply.status_code != 200:
                return {
                    "error": f"API请求失败: {reply.status_code}",
                    "success": False
                }
            payload = reply.json()
            return {
                "city": payload["name"],
                "country": payload["sys"]["country"],
                "temperature": payload["main"]["temp"],
                "feels_like": payload["main"]["feels_like"],
                "humidity": payload["main"]["humidity"],
                "pressure": payload["main"]["pressure"],
                "description": payload["weather"][0]["description"],
                "wind_speed": payload["wind"]["speed"],
                "success": True
            }
        except Exception as exc:
            return {
                "error": str(exc),
                "success": False
            }
股票API集成
def define_stock_tool():
    """Return the schema for the stock price lookup tool.

    ``symbol`` is required; ``interval`` is one of 1d/1wk/1mo (default 1d).
    """
    properties = {}
    properties["symbol"] = {"type": "string", "description": "股票代码,如'AAPL'、'MSFT'"}
    properties["interval"] = {
        "type": "string",
        "description": "数据间隔",
        "enum": ["1d", "1wk", "1mo"],
        "default": "1d",
    }
    tool = {"name": "get_stock_price", "description": "获取股票价格和基本信息"}
    tool["input_schema"] = {"type": "object", "properties": properties, "required": ["symbol"]}
    return tool
class StockAPI:
    """Client for Yahoo Finance's public chart endpoint."""

    def __init__(self):
        self.base_url = "https://query1.finance.yahoo.com/v8/finance/chart"

    def _chart_url(self, symbol):
        """Build the chart URL for ``symbol``.

        The symbol is percent-encoded (including ``/``) so model-supplied input
        such as ``"../x"`` cannot change the request path.
        """
        from urllib.parse import quote

        return f"{self.base_url}/{quote(str(symbol).strip(), safe='')}"

    def get_stock_price(self, symbol, interval="1d"):
        """Fetch the latest quote metadata for ``symbol``.

        Returns price/volume/exchange fields plus ``success: True``, or
        ``{"error": ..., "success": False}`` on a non-200 status, an API
        error or empty result, or any exception.
        """
        import requests

        try:
            params = {
                "interval": interval,
                "range": "1d"
            }
            response = requests.get(self._chart_url(symbol), params=params, timeout=10)
            if response.status_code != 200:
                return {
                    "error": f"API请求失败: {response.status_code}",
                    "success": False
                }
            chart = response.json().get("chart") or {}
            results = chart.get("result")
            # An API error or a null/empty result previously crashed with a
            # TypeError/IndexError instead of this explicit message.
            if chart.get("error") or not results:
                return {
                    "error": "股票代码不存在或数据获取失败",
                    "success": False
                }
            meta = results[0]["meta"]
            return {
                "symbol": symbol,
                "company_name": meta.get("longName", symbol),
                "current_price": meta["regularMarketPrice"],
                # Fall back to chartPreviousClose when previousClose is absent.
                "previous_close": meta.get("previousClose", meta.get("chartPreviousClose")),
                "day_high": meta["regularMarketDayHigh"],
                "day_low": meta["regularMarketDayLow"],
                "volume": meta["regularMarketVolume"],
                "currency": meta["currency"],
                "exchange": meta["exchangeName"],
                "success": True
            }
        except Exception as e:
            return {
                "error": str(e),
                "success": False
            }
GraphQL API集成
GitHub API集成
def define_github_tool():
    """Return the schema for the GitHub repository query tool.

    ``owner`` and ``repo`` are required; ``query_type`` picks repository info,
    issues or commits (default repository).
    """
    def text(description):
        return {"type": "string", "description": description}

    query_type = text("查询类型")
    query_type["enum"] = ["repository", "issues", "commits"]
    query_type["default"] = "repository"

    return {
        "name": "github_query",
        "description": "查询GitHub仓库信息",
        "input_schema": {
            "type": "object",
            "properties": {
                "owner": text("仓库所有者"),
                "repo": text("仓库名称"),
                "query_type": query_type,
            },
            "required": ["owner", "repo"],
        },
    }
class GitHubAPI:
    """Minimal GitHub REST client backing the ``github_query`` tool."""

    def __init__(self, token=None):
        # Optional access token, sent as "Authorization: token <token>".
        self.token = token
        self.base_url = "https://api.github.com"
        # Not used by the REST methods below; kept for callers issuing GraphQL queries.
        self.graphql_url = "https://api.github.com/graphql"

    def github_query(self, owner, repo, query_type="repository"):
        """Dispatch a github_query tool call to the matching REST method.

        Returns an error dict (rather than None) for an unknown ``query_type``.
        """
        handlers = {
            "repository": self.get_repository_info,
            "issues": self.get_issues,
            "commits": self.get_commits,
        }
        handler = handlers.get(query_type)
        if handler is None:
            return {"error": f"不支持的查询类型: {query_type}", "success": False}
        return handler(owner, repo)

    def _headers(self):
        """Request headers, including authorization when a token is configured."""
        headers = {"Accept": "application/vnd.github+json"}
        if self.token:
            headers["Authorization"] = f"token {self.token}"
        return headers

    def _repo_url(self, owner, repo, suffix=""):
        """Build ``/repos/{owner}/{repo}{suffix}`` with owner/repo percent-encoded."""
        from urllib.parse import quote

        return (f"{self.base_url}/repos/{quote(str(owner), safe='')}/"
                f"{quote(str(repo), safe='')}{suffix}")

    def _get(self, url, params=None):
        """Issue an authenticated GET with a 10 s timeout."""
        import requests

        return requests.get(url, headers=self._headers(), params=params, timeout=10)

    @staticmethod
    def _http_error(response):
        return {
            "error": f"仓库不存在或无权访问: {response.status_code}",
            "success": False
        }

    def get_repository_info(self, owner, repo):
        """Return summary metadata for ``owner/repo``."""
        try:
            response = self._get(self._repo_url(owner, repo))
            if response.status_code != 200:
                return self._http_error(response)
            data = response.json()
            return {
                "name": data["name"],
                "full_name": data["full_name"],
                "description": data["description"],
                "language": data["language"],
                "stars": data["stargazers_count"],
                "forks": data["forks_count"],
                "issues": data["open_issues_count"],
                "created_at": data["created_at"],
                "updated_at": data["updated_at"],
                "license": data["license"]["name"] if data["license"] else None,
                "success": True
            }
        except Exception as e:
            return {"error": str(e), "success": False}

    def get_issues(self, owner, repo, state="open", limit=10):
        """Return up to ``limit`` issues in ``state``.

        GitHub's issues endpoint also lists pull requests; these are flagged
        via ``is_pull_request``.
        """
        try:
            response = self._get(self._repo_url(owner, repo, "/issues"),
                                 params={"state": state, "per_page": limit})
            if response.status_code != 200:
                return self._http_error(response)
            issues = [
                {
                    "number": item.get("number"),
                    "title": item.get("title"),
                    "state": item.get("state"),
                    "author": (item.get("user") or {}).get("login"),
                    "created_at": item.get("created_at"),
                    "is_pull_request": "pull_request" in item,
                }
                for item in response.json()
            ]
            return {"issues": issues, "count": len(issues), "success": True}
        except Exception as e:
            return {"error": str(e), "success": False}

    def get_commits(self, owner, repo, limit=10):
        """Return up to ``limit`` recent commits (sha, message, author name, date)."""
        try:
            response = self._get(self._repo_url(owner, repo, "/commits"),
                                 params={"per_page": limit})
            if response.status_code != 200:
                return self._http_error(response)
            commits = []
            for item in response.json():
                info = item.get("commit") or {}
                author = info.get("author") or {}
                commits.append({
                    "sha": item.get("sha"),
                    "message": info.get("message"),
                    "author": author.get("name"),
                    "date": author.get("date"),
                })
            return {"commits": commits, "count": len(commits), "success": True}
        except Exception as e:
            return {"error": str(e), "success": False}
工具执行流程
工具调用管理器
核心调用管理器
class ToolCallManager:
    """Registry that validates tool arguments against each tool's schema and dispatches calls."""

    # JSON-schema "type" name -> accepted Python type(s).
    _TYPE_MAPPING = {
        "string": str,
        "number": (int, float),
        "integer": int,
        "boolean": bool,
        "object": dict,
        "array": list,
        "null": type(None),
    }

    def __init__(self):
        self.tools = {}         # tool name -> tool definition (name/description/input_schema)
        self.executors = {}     # tool name -> callable invoked with the arguments as kwargs
        self.call_history = []  # one record per call that reached its executor

    def register_tool(self, tool_definition, executor):
        """Register ``tool_definition`` and the callable that executes it."""
        tool_name = tool_definition["name"]
        self.tools[tool_name] = tool_definition
        self.executors[tool_name] = executor
        print(f"工具 '{tool_name}' 已注册")

    def execute_tool_call(self, tool_name, **kwargs):
        """Validate ``kwargs`` and run the tool, recording the call in ``call_history``.

        Returns the executor's result, or ``{"error", "success": False}`` for
        an unknown tool, invalid arguments, or an exception from the executor.
        """
        import time  # `time` is not imported at module level in this file

        if tool_name not in self.tools:
            return {
                "error": f"未知工具: {tool_name}",
                "success": False
            }
        validation_result = self.validate_parameters(tool_name, kwargs)
        if not validation_result["valid"]:
            return {
                "error": f"参数验证失败: {validation_result['error']}",
                "success": False
            }
        call_record = {
            "tool_name": tool_name,
            "parameters": kwargs,
            "timestamp": time.time()
        }
        try:
            result = self.executors[tool_name](**kwargs)
        except Exception as e:
            call_record["error"] = str(e)
            call_record["success"] = False
            self.call_history.append(call_record)
            return {
                "error": str(e),
                "success": False
            }
        call_record["result"] = result
        # Non-dict results carry no success flag; a normal return counts as success.
        call_record["success"] = result.get("success", True) if isinstance(result, dict) else True
        self.call_history.append(call_record)
        return result

    def validate_parameters(self, tool_name, parameters):
        """Check required parameters, declared types and ``enum`` constraints.

        Returns ``{"valid": True}`` or ``{"valid": False, "error": message}``.
        """
        schema = self.tools[tool_name].get("input_schema", {})
        for param in schema.get("required", []):
            if param not in parameters:
                return {
                    "valid": False,
                    "error": f"缺少必需参数: {param}"
                }
        properties = schema.get("properties", {})
        for param_name, param_value in parameters.items():
            spec = properties.get(param_name)
            if spec is None:
                continue  # undeclared parameters are not checked
            expected_type = spec.get("type")
            if not self.check_parameter_type(param_value, expected_type):
                return {
                    "valid": False,
                    "error": f"参数 {param_name} 类型错误,期望 {expected_type}"
                }
            allowed_values = spec.get("enum")
            if allowed_values is not None and param_value not in allowed_values:
                return {
                    "valid": False,
                    "error": f"参数 {param_name} 取值无效,允许的值: {allowed_values}"
                }
        return {"valid": True}

    def check_parameter_type(self, value, expected_type):
        """Return True if ``value`` matches the JSON-schema ``expected_type``.

        A list of type names (e.g. ``["string", "null"]``) matches if any
        member matches. Unknown or missing types are accepted.
        """
        if isinstance(expected_type, list):
            return any(self.check_parameter_type(value, t) for t in expected_type)
        expected_python_type = self._TYPE_MAPPING.get(expected_type)
        if expected_python_type is None:
            return True
        # bool subclasses int in Python, but JSON booleans are not numbers.
        if isinstance(value, bool) and expected_type in ("number", "integer"):
            return False
        return isinstance(value, expected_python_type)
与Claude集成
完整的工具调用流程
def create_tool_enabled_conversation():
    """Build a ToolCallManager with the calculator and weather tools registered.

    Returns:
        ``(tools, manager)``, where ``tools`` is the list of tool definitions to
        pass to the Messages API and ``manager`` executes the calls.
    """
    manager = ToolCallManager()
    manager.register_tool(define_calculator_tool(), execute_calculator)
    weather_client = WeatherAPI("your_api_key")
    manager.register_tool(define_weather_tool(), weather_client.get_weather)
    tool_definitions = [definition for definition in manager.tools.values()]
    return tool_definitions, manager
def handle_tool_conversation(user_message):
    """Answer ``user_message`` with Claude, executing any tool calls it makes.

    Returns the final assistant text.
    """
    tools, tool_manager = create_tool_enabled_conversation()
    client = anthropic.Anthropic(api_key="your-key")
    response = client.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=2048,
        tools=tools,
        messages=[
            {
                "role": "user",
                "content": user_message
            }
        ]
    )
    # Pass the real user message so follow-up requests replay the actual
    # conversation rather than a placeholder prompt.
    return process_tool_calls(response, tool_manager, client, tools, user_message=user_message)


def process_tool_calls(response, tool_manager, client, tools,
                       user_message="请帮我处理这个请求",
                       model="claude-sonnet-4-20250514",
                       max_tokens=2048, max_rounds=5):
    """Run Claude's tool calls until it stops requesting tools, then return its text.

    Each round sends the assistant turn followed by ONE user turn that holds
    the ``tool_result`` blocks for every ``tool_use`` in that turn. Failed
    tool results are marked with ``is_error``.

    Args:
        response: the Messages API response to the first user turn.
        tool_manager: object whose ``execute_tool_call(name, **input)`` runs a tool.
        client: Anthropic client used for follow-up requests.
        tools: tool definitions to resend on every request.
        user_message: the original user prompt. The default keeps the old
            placeholder for backward compatibility.
        model, max_tokens: request parameters for follow-up calls.
        max_rounds: upper bound on tool-execution rounds. If it is reached, the
            text of the last response is returned even if it still asks for tools.

    Returns:
        The concatenated text blocks of the final response.
    """
    messages = [
        {
            "role": "user",
            "content": user_message
        }
    ]
    for _ in range(max_rounds):
        tool_calls = [block for block in response.content if block.type == "tool_use"]
        if not tool_calls:
            break
        messages.append({
            "role": "assistant",
            "content": response.content
        })
        tool_results = []
        for tool_call in tool_calls:
            tool_result = tool_manager.execute_tool_call(tool_call.name, **(tool_call.input or {}))
            failed = isinstance(tool_result, dict) and not tool_result.get("success", True)
            tool_results.append({
                "type": "tool_result",
                "tool_use_id": tool_call.id,
                "content": json.dumps(tool_result, ensure_ascii=False),
                "is_error": failed
            })
        # All results for this assistant turn go in a single user message.
        messages.append({
            "role": "user",
            "content": tool_results
        })
        response = client.messages.create(
            model=model,
            max_tokens=max_tokens,
            tools=tools,
            messages=messages
        )
    return _response_text(response)


def _response_text(response):
    """Concatenate the text blocks of a Messages API response (content[0] may not be text)."""
    return "".join(block.text for block in response.content if block.type == "text")
错误处理机制
错误分类和处理
错误类型定义
class ToolError(Exception):
    """Base class for tool errors.

    Subclasses ``Exception`` so instances can actually be raised and caught.
    The original plain class could not be raised.

    Attributes:
        message: human-readable description.
        error_type: category key ("general", "validation", "execution", ...).
        tool_name: name of the tool involved, if known.
        timestamp: creation time from ``time.time()``.
    """

    def __init__(self, message, error_type="general", tool_name=None):
        import time  # `time` is not imported at module level in this file

        super().__init__(message)
        self.message = message
        self.error_type = error_type
        self.tool_name = tool_name
        self.timestamp = time.time()


class ValidationError(ToolError):
    """Invalid or missing tool parameter."""

    def __init__(self, message, parameter=None, tool_name=None):
        super().__init__(message, "validation", tool_name)
        self.parameter = parameter  # offending parameter name, if known


class ExecutionError(ToolError):
    """Failure while running a tool."""

    def __init__(self, message, exception=None, tool_name=None):
        super().__init__(message, "execution", tool_name)
        self.exception = exception  # underlying exception, if any


class NetworkError(ToolError):
    """Failure of an outbound network request."""

    def __init__(self, message, status_code=None, tool_name=None):
        super().__init__(message, "network", tool_name)
        self.status_code = status_code  # HTTP status, if a response was received


class SecurityError(ToolError):
    """Input or request rejected by a security check."""

    def __init__(self, message, security_issue=None, tool_name=None):
        super().__init__(message, "security", tool_name)
        self.security_issue = security_issue  # short code describing the issue
错误处理策略
class ErrorHandler:
    """Turn tool-error objects into structured, user-facing error dicts.

    Errors are dispatched on their ``error_type`` attribute. Unknown types,
    including ToolError's default "general", go to ``handle_general_error``,
    which the original referenced but never defined.
    """

    # HTTP statuses treated as transient, so a retry may succeed.
    RETRYABLE_STATUS_CODES = (408, 429, 500, 502, 503, 504)

    def __init__(self):
        self.error_strategies = {
            "validation": self.handle_validation_error,
            "execution": self.handle_execution_error,
            "network": self.handle_network_error,
            "security": self.handle_security_error
        }
        # Per-category retry configuration.
        self.retry_settings = {
            "network": {"max_retries": 3, "backoff": 2},
            "execution": {"max_retries": 1, "backoff": 1}
        }

    def handle_error(self, error):
        """Route ``error`` to the handler for its ``error_type``."""
        strategy = self.error_strategies.get(
            getattr(error, "error_type", "general"),
            self.handle_general_error
        )
        return strategy(error)

    def handle_validation_error(self, error):
        return {
            "error": error.message,
            "error_type": "validation",
            "suggestion": "请检查参数格式和必需字段",
            "recoverable": True,
            "success": False
        }

    def handle_execution_error(self, error):
        """Recommend a retry when execution retries are configured (max_retries > 0).

        The original compared ``error.tool_name`` against the keys of the
        settings dict ("max_retries"/"backoff"), so it never matched.
        """
        settings = self.retry_settings.get("execution") or {}
        if settings.get("max_retries", 0) > 0:
            return {
                "error": error.message,
                "error_type": "execution",
                "suggestion": "工具执行失败,可以尝试重新执行",
                "recoverable": True,
                "retry_recommended": True,
                "success": False
            }
        return {
            "error": error.message,
            "error_type": "execution",
            "suggestion": "工具执行失败,请检查输入参数",
            "recoverable": False,
            "success": False
        }

    def handle_network_error(self, error):
        """Classify by HTTP status: transient (retry), auth failure, or other."""
        status_code = getattr(error, "status_code", None)
        if status_code in self.RETRYABLE_STATUS_CODES:
            return {
                "error": error.message,
                "error_type": "network",
                "suggestion": "网络服务暂时不可用,请稍后重试",
                "recoverable": True,
                "retry_recommended": True,
                "retry_delay": 5,
                "success": False
            }
        if status_code in (401, 403):
            # Authentication problems will not fix themselves on retry.
            return {
                "error": error.message,
                "error_type": "network",
                "suggestion": "API认证失败,请检查API密钥",
                "recoverable": False,
                "success": False
            }
        return {
            "error": error.message,
            "error_type": "network",
            "suggestion": "网络请求失败,请检查网络连接",
            "recoverable": True,
            "success": False
        }

    def handle_security_error(self, error):
        # Deliberately generic message: do not reveal which check fired.
        return {
            "error": "安全检查失败",
            "error_type": "security",
            "suggestion": "请求被安全策略阻止,请检查输入内容",
            "recoverable": False,
            "success": False
        }

    def handle_general_error(self, error):
        """Fallback for errors with no dedicated strategy."""
        return {
            "error": getattr(error, "message", str(error)),
            "error_type": getattr(error, "error_type", "general"),
            "suggestion": "发生未知错误,请稍后重试或联系管理员",
            "recoverable": False,
            "success": False
        }
重试机制
智能重试策略
class RetryManager:
    """Run callables with retries governed by named backoff policies.

    A call is retried when it raises, or when it returns a dict with
    ``error_type == "network"`` and a truthy ``recoverable``. A successful
    dict (``success`` missing or truthy), any other failure dict, and any
    non-dict value are returned to the caller immediately.
    """

    def __init__(self):
        # Policy name -> parameters; unknown names fall back to "default".
        self.retry_policies = {
            "default": {
                "max_retries": 3,
                "base_delay": 1,
                "max_delay": 60,
                "exponential_backoff": True,
                "jitter": True
            },
            "network": {
                "max_retries": 5,
                "base_delay": 2,
                "max_delay": 120,
                "exponential_backoff": True,
                "jitter": True
            },
            "rate_limit": {
                "max_retries": 10,
                "base_delay": 60,
                "max_delay": 600,
                "exponential_backoff": False,
                "jitter": False
            }
        }

    def execute_with_retry(self, func, *args, policy="default", **kwargs):
        """Call ``func(*args, **kwargs)`` with retries according to ``policy``.

        Returns the first non-retryable result. If a retryable network failure
        persists, the last such result is returned. If every attempt raised,
        returns ``{"error", "last_error", "success": False}``.
        """
        import time  # `time` is not imported at module level in this file

        retry_policy = self.retry_policies.get(policy, self.retry_policies["default"])
        max_retries = retry_policy["max_retries"]
        last_error = None
        for attempt in range(max_retries + 1):
            can_retry = attempt < max_retries
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                last_error = {
                    "error": str(e),
                    "error_type": "execution",
                    "success": False
                }
                if not can_retry:
                    break
                time.sleep(self.calculate_delay(attempt, retry_policy))
                continue
            if not isinstance(result, dict) or result.get("success", True):
                return result
            retryable = (result.get("error_type") == "network"
                         and result.get("recoverable", False))
            if retryable and can_retry:
                last_error = result
                time.sleep(self.calculate_delay(attempt, retry_policy))
                continue
            return result
        # Every attempt raised.
        return {
            "error": f"重试{retry_policy['max_retries']}次后仍然失败",
            "last_error": last_error,
            "success": False
        }

    def calculate_delay(self, attempt, policy):
        """Seconds to wait before the next attempt (0-based ``attempt``).

        Exponential backoff doubles ``base_delay`` per attempt. The delay is
        capped at ``max_delay``; jitter then scales it by a random factor in
        [0.5, 1.0).
        """
        import random

        if policy["exponential_backoff"]:
            delay = policy["base_delay"] * (2 ** attempt)
        else:
            delay = policy["base_delay"]
        delay = min(delay, policy["max_delay"])
        if policy["jitter"]:
            delay = delay * (0.5 + random.random() * 0.5)
        return delay
安全性考虑
输入验证和清理
安全验证器
class SecurityValidator:
def __init__(self):
self.dangerous_patterns = [
r'(?i)(drop|delete|truncate)\s+table',
r'(?i)exec(ute)?\s*\(',
r'(?i)script\s*>',
r'(?i)<\s*script',
r'(?i)javascript:',
r'(?i)on\w+\s*=',
r'\.\./|\.\\\.',
r'(?i)file:///',
r'(?i)http://localhost',
r'(?i)127\.0\.0\.1'
]
self.sql_injection_patterns = [
r"('|(\\+\+);|(--(\\+\+);)",
r"((\%27)|(\'))((\%6F)|o|(\%4F))((\%72)|r|(\%52))",
r"((\%27)|(\'))union",
r"exec(\s|\+)+(s|x)p\w+",
r"union\s+select",
r"insert\s+into",
r"delete\s+from"
]
def validate_input(self, input_data, input_type="general"):
"""验证输入数据"""
if isinstance(input_data, str):
return self.validate_string_input(input_data, input_type)
elif isinstance(input_data, dict):
return self.validate_dict_input(input_data)
elif isinstance(input_data, list):
return self.validate_list_input(input_data)
return {"valid": True}
def validate_string_input(self, text, input_type):
"""验证字符串输入"""
# 检查危险模式
for pattern in self.dangerous_patterns:
if re.search(pattern, text):
return {
"valid": False,
"error": "输入包含潜在危险内容",
"security_issue": "dangerous_pattern"
}
# 对SQL查询进行特殊检查
if input_type == "sql":
for pattern in self.sql_injection_patterns:
if re.search(pattern, text, re.IGNORECASE):
return {
"valid": False,
"error": "检测到潜在的SQL注入",
"security_issue": "sql_injection"
}
return {"valid": True}
def validate_dict_input(self, data):
"""验证字典输入"""
for key, value in data.items():
if isinstance(value, str):
result = self.validate_string_input(value)
if not result["valid"]:
return result
elif isinstance(value, (dict, list)):
result = self.validate_input(value)
if not result["valid"]:
return result
return {"valid": True}
def sanitize_input(self, input_data):
"""清理输入数据"""
if isinstance(input_data, str):
return self.sanitize_string(input_data)
elif isinstance(input_data, dict):
return {k: self.sanitize_input(v) for k, v in input_data.items()}
elif isinstance(input_data, list):
return [self.sanitize_input(item) for item in input_data]
return input_data
def sanitize_string(self, text):
"""清理字符串"""
# 移除潜在危险字符
import html
# HTML编码
sanitized = html.escape(text)
# 移除控制字符
sanitized = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', sanitized)
return sanitized
权限控制
工具权限管理
class ToolPermissionManager:
    """Role-based access control for tools: user -> role -> permissions -> tools.

    Users without an assigned role are treated as "guest".
    """

    def __init__(self):
        self.permissions = {}        # role -> permissions granted to that role
        self.user_roles = {}         # user_id -> role
        self.tool_requirements = {}  # tool name -> permissions the tool requires

    def define_tool_permissions(self, tool_name, required_permissions):
        """Set the permissions a caller needs in order to use ``tool_name``."""
        self.tool_requirements[tool_name] = required_permissions

    def assign_user_role(self, user_id, role):
        """Assign ``role`` to ``user_id``."""
        self.user_roles[user_id] = role

    def define_role_permissions(self, role, permissions):
        """Set the permissions granted to ``role``."""
        self.permissions[role] = permissions

    def _role_and_grants(self, user_id):
        # Resolve a user's role (default "guest") and that role's permissions.
        role = self.user_roles.get(user_id, "guest")
        return role, self.permissions.get(role, [])

    def check_permission(self, user_id, tool_name):
        """Report whether ``user_id`` may use ``tool_name``.

        Returns ``{"allowed": True}``, or ``allowed: False`` with the first
        missing permission (in requirement order) and the user's role.
        """
        role, granted = self._role_and_grants(user_id)
        not_found = object()
        missing = next(
            (perm for perm in self.tool_requirements.get(tool_name, []) if perm not in granted),
            not_found,
        )
        if missing is not_found:
            return {"allowed": True}
        return {
            "allowed": False,
            "missing_permission": missing,
            "user_role": role
        }

    def get_allowed_tools(self, user_id):
        """List the tools whose requirements are all granted to ``user_id``."""
        _, granted = self._role_and_grants(user_id)
        return [
            name
            for name, needed in self.tool_requirements.items()
            if all(perm in granted for perm in needed)
        ]
# Usage example
def setup_permissions():
    """Build a ToolPermissionManager preloaded with example roles and tool requirements.

    Roles: admin (everything), user (read files + network), guest (nothing).
    """
    perm_manager = ToolPermissionManager()
    role_grants = {
        "admin": [
            "read_files", "write_files", "execute_code",
            "network_access", "database_access"
        ],
        "user": ["read_files", "network_access"],
        "guest": [],
    }
    for role, granted in role_grants.items():
        perm_manager.define_role_permissions(role, granted)
    tool_needs = {
        "read_file": ["read_files"],
        "write_file": ["write_files"],
        "database_query": ["database_access"],
        "web_api_call": ["network_access"],
    }
    for tool_name, needed in tool_needs.items():
        perm_manager.define_tool_permissions(tool_name, needed)
    return perm_manager
最佳实践
工具设计原则
1. 单一职责原则
# Good example: a single-purpose tool
def define_temperature_converter():
    """Return the schema for a tool that only converts temperatures between C, F and K."""
    def unit_field():
        # A fresh dict per field so the two unit properties never share state.
        return {"type": "string", "enum": ["C", "F", "K"]}

    properties = {
        "value": {"type": "number"},
        "from_unit": unit_field(),
        "to_unit": unit_field(),
    }
    return {
        "name": "convert_temperature",
        "description": "转换温度单位",
        "input_schema": {
            "type": "object",
            "properties": properties,
            "required": ["value", "from_unit", "to_unit"],
        },
    }
# Anti-example: a tool that tries to do too much
def define_complex_tool():
    """Anti-pattern: a catch-all tool whose single schema would have to cover everything."""
    # Poor design: calculation, networking and file I/O behind one name.
    # A real version would also need a sprawling, hard-to-validate schema.
    description = "执行各种操作:计算、网络请求、文件操作等"
    return dict(name="do_everything", description=description)
2. 错误信息清晰化
def execute_with_clear_errors(func, **kwargs):
    """Call ``func(**kwargs)`` and translate exceptions into structured error dicts.

    ValueError becomes a "parameter_error" and KeyError a
    "missing_parameter". Any other Exception becomes an "execution_error".
    The function's own return value is passed through unchanged.
    """
    def failure(message, kind, hint):
        return {
            "error": message,
            "error_type": kind,
            "suggestion": hint,
            "success": False
        }

    try:
        return func(**kwargs)
    except ValueError as exc:
        return failure(f"参数值错误: {str(exc)}", "parameter_error", "请检查输入参数的格式和范围")
    except KeyError as exc:
        missing = str(exc)
        return failure(f"缺少必需参数: {missing}", "missing_parameter", f"请提供参数 {missing}")
    except Exception as exc:
        return failure(f"执行失败: {str(exc)}", "execution_error", "请检查输入或稍后重试")
3. 性能优化
class OptimizedToolExecutor:
    """Tool executor with a TTL result cache and a per-tool sliding-window rate limit.

    The original called ``check_rate_limit``/``execute_tool``, which it never
    defined, and used ``time`` without importing it. Both helpers are now
    implemented. ``execute_tool`` dispatches to the ``executors`` mapping and
    can be overridden by subclasses.
    """

    def __init__(self, executors=None, max_calls_per_minute=60, cache_ttl=300):
        self.cache = {}         # cache key -> {"result", "timestamp", "ttl"}
        self.rate_limiter = {}  # tool name -> timestamps of calls in the last 60 s
        self.executors = dict(executors or {})  # tool name -> callable(**kwargs)
        self.max_calls_per_minute = max_calls_per_minute
        self.cache_ttl = cache_ttl  # seconds; 300 = 5 minutes

    def execute_with_optimization(self, tool_name, **kwargs):
        """Return a cached result if fresh; otherwise rate-limit, execute and cache.

        Only results with ``success`` truthy are cached. Cache hits do not
        count against the rate limit.
        """
        import time

        cache_key = self.generate_cache_key(tool_name, kwargs)
        entry = self.cache.get(cache_key)
        if entry is not None:
            if not self.is_cache_expired(entry):
                return entry["result"]
            # Evict stale entries so the cache does not grow without bound.
            del self.cache[cache_key]
        if not self.check_rate_limit(tool_name):
            return {
                "error": "请求频率过高,请稍后重试",
                "error_type": "rate_limit",
                "success": False
            }
        result = self.execute_tool(tool_name, **kwargs)
        if isinstance(result, dict) and result.get("success", False):
            self.cache[cache_key] = {
                "result": result,
                "timestamp": time.time(),
                "ttl": self.cache_ttl
            }
        return result

    def execute_tool(self, tool_name, **kwargs):
        """Run the registered executor for ``tool_name`` and convert exceptions to error dicts."""
        executor = self.executors.get(tool_name)
        if executor is None:
            return {"error": f"未知工具: {tool_name}", "success": False}
        try:
            return executor(**kwargs)
        except Exception as e:
            return {"error": str(e), "success": False}

    def check_rate_limit(self, tool_name, window=60.0):
        """Record a call and return True if ``tool_name`` is under its limit for ``window`` seconds."""
        import time

        now = time.time()
        recent = [t for t in self.rate_limiter.get(tool_name, []) if now - t < window]
        allowed = len(recent) < self.max_calls_per_minute
        if allowed:
            recent.append(now)
        self.rate_limiter[tool_name] = recent
        return allowed

    def generate_cache_key(self, tool_name, kwargs):
        """Stable hash of the tool name and arguments.

        ``default=str`` keeps non-JSON arguments from crashing key generation.
        """
        import hashlib
        import json

        cache_data = {
            "tool": tool_name,
            "params": kwargs
        }
        cache_str = json.dumps(cache_data, sort_keys=True, default=str)
        return hashlib.sha256(cache_str.encode("utf-8")).hexdigest()

    def is_cache_expired(self, cache_entry):
        """True once more than ``ttl`` seconds have passed since the entry was stored."""
        import time

        return time.time() - cache_entry["timestamp"] > cache_entry["ttl"]
4. 监控和日志
class ToolMonitor:
    """Collect aggregate metrics and a bounded log of tool calls."""

    def __init__(self, max_logs=1000):
        self.metrics = {
            "total_calls": 0,
            "successful_calls": 0,
            "failed_calls": 0,
            "average_duration": 0,
            "error_rates": {}  # error_type -> number of failures (counts, not rates)
        }
        self.call_logs = []
        self.max_logs = max_logs  # newest entries kept in call_logs

    def log_tool_call(self, tool_name, parameters, result, duration):
        """Update metrics and append a log entry for one call.

        ``result`` is the tool's result dict; a missing ``success`` counts as
        a failure. ``duration`` is in seconds.
        """
        import time  # `time` is not imported at module level in this file

        metrics = self.metrics
        metrics["total_calls"] += 1
        succeeded = result.get("success", False)
        if succeeded:
            metrics["successful_calls"] += 1
        else:
            metrics["failed_calls"] += 1
            error_type = result.get("error_type", "unknown")
            metrics["error_rates"][error_type] = metrics["error_rates"].get(error_type, 0) + 1
        # Incremental running mean over all calls.
        metrics["average_duration"] += (duration - metrics["average_duration"]) / metrics["total_calls"]
        self.call_logs.append({
            "timestamp": time.time(),
            "tool_name": tool_name,
            "parameters": parameters,
            "result": result,
            "duration": duration,
            "success": succeeded
        })
        # Trim in place. Computing the cut explicitly also handles
        # max_logs == 0, where a [:-0] slice would delete nothing.
        overflow = len(self.call_logs) - self.max_logs
        if overflow > 0:
            del self.call_logs[:overflow]

    def get_performance_report(self):
        """Summarize call volume, success rate, mean duration and error breakdown."""
        total = self.metrics["total_calls"]
        if total == 0:
            return {"message": "暂无调用记录"}
        success_rate = self.metrics["successful_calls"] / total * 100
        return {
            "total_calls": total,
            "success_rate": f"{success_rate:.2f}%",
            "average_duration": f"{self.metrics['average_duration']:.3f}s",
            "error_breakdown": self.metrics["error_rates"],
            "most_common_errors": self.get_most_common_errors()
        }

    def get_most_common_errors(self):
        """Return up to five ``(error_type, count)`` pairs, most frequent first."""
        error_rates = self.metrics["error_rates"]
        if not error_rates:
            return []
        return sorted(error_rates.items(), key=lambda item: item[1], reverse=True)[:5]
通过合理的工具集成和外部API调用,可以显著扩展Claude的能力边界,创建更加强大和实用的AI应用系统。