原课程代码是基于 Anthropic API 编写的,下面的代码改写为使用 OpenAI 兼容接口,模型则使用阿里巴巴的通义千问模型进行测试。
.env 文件为:
OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
OPENAI_API_BASE=https://dashscope.aliyuncs.com/compatible-mode/v1
完整代码
import arxiv
import json
import os
from typing import List
from dotenv import load_dotenv
import openai
PAPER_DIR = "papers"
def search_papers(topic: str, max_results: int = 5) -> List[str]:
    """
    Search for papers on arXiv based on a topic and store their information.

    Args:
        topic: The topic to search for
        max_results: Maximum number of results to retrieve (default: 5)

    Returns:
        List of paper IDs found in the search
    """
    # Use arxiv to find the most relevant articles matching the queried topic.
    client = arxiv.Client()
    search = arxiv.Search(
        query=topic,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance,
    )
    papers = client.results(search)

    # One sub-directory per topic, e.g. papers/machine_learning
    path = os.path.join(PAPER_DIR, topic.lower().replace(" ", "_"))
    os.makedirs(path, exist_ok=True)
    file_path = os.path.join(path, "papers_info.json")

    # Load any previously saved info so repeated searches accumulate results.
    try:
        with open(file_path, "r") as json_file:
            papers_info = json.load(json_file)
    except (FileNotFoundError, json.JSONDecodeError):
        papers_info = {}

    # Collect metadata for each result, keyed by the paper's short arXiv id.
    paper_ids = []
    for paper in papers:
        paper_id = paper.get_short_id()  # compute once instead of twice
        paper_ids.append(paper_id)
        papers_info[paper_id] = {
            'title': paper.title,
            'authors': [author.name for author in paper.authors],
            'summary': paper.summary,
            'pdf_url': paper.pdf_url,
            'published': str(paper.published.date())
        }

    # Persist the merged info back to disk.
    with open(file_path, "w") as json_file:
        json.dump(papers_info, json_file, indent=2)

    print(f"Results are saved in: {file_path}")
    return paper_ids
def extract_info(paper_id: str) -> str:
    """
    Search for information about a specific paper across all topic directories.

    Args:
        paper_id: The ID of the paper to look for

    Returns:
        JSON string with paper information if found, error message if not found
    """
    # If no search has ever been run, the base directory may not exist yet;
    # treat that the same as "paper not found" instead of raising
    # FileNotFoundError from os.listdir.
    if not os.path.isdir(PAPER_DIR):
        return f"There's no saved information related to paper {paper_id}."

    for item in os.listdir(PAPER_DIR):
        item_path = os.path.join(PAPER_DIR, item)
        if not os.path.isdir(item_path):
            continue
        file_path = os.path.join(item_path, "papers_info.json")
        if not os.path.isfile(file_path):
            continue
        try:
            with open(file_path, "r") as json_file:
                papers_info = json.load(json_file)
        except (FileNotFoundError, json.JSONDecodeError) as e:
            # Skip unreadable/corrupt topic files; keep searching the rest.
            print(f"Error reading {file_path}: {str(e)}")
            continue
        if paper_id in papers_info:
            return json.dumps(papers_info[paper_id], indent=2)

    return f"There's no saved information related to paper {paper_id}."
# OpenAI function-calling tool schemas. Each entry describes one local Python
# function (name, description, JSON-schema parameters) so the model can
# request a call to it; the names must match keys in mapping_tool_function.
tools = [
    {
        "type": "function",
        "function": {
            "name": "search_papers",
            "description": "Search for papers on arXiv based on a topic and store their information",
            "parameters": {
                "type": "object",
                "properties": {
                    "topic": {
                        "type": "string",
                        "description": "The topic to search for"
                    },
                    "max_results": {
                        "type": "integer",
                        "description": "Maximum number of results to retrieve",
                        "default": 5
                    }
                },
                "required": ["topic"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "extract_info",
            "description": "Search for information about a specific paper across all topic directories",
            "parameters": {
                "type": "object",
                "properties": {
                    "paper_id": {
                        "type": "string",
                        "description": "The ID of the paper to look for"
                    }
                },
                "required": ["paper_id"]
            }
        }
    }
]
# Dispatch table: tool name (as reported by the model) -> local callable.
mapping_tool_function = {
    "search_papers": search_papers,
    "extract_info": extract_info
}
def execute_tool(tool_name, tool_args):
    """
    Execute the local function registered under tool_name and normalize its
    return value into a string suitable for a tool-result message.

    Args:
        tool_name: Key into mapping_tool_function.
        tool_args: Keyword arguments decoded from the model's tool call.

    Returns:
        The tool's result rendered as a string.

    Raises:
        KeyError: If tool_name is not a registered tool.
    """
    result = mapping_tool_function[tool_name](**tool_args)

    if result is None:
        result = "The operation completed but didn't return any results."
    elif isinstance(result, list):
        # str() each item so a list containing non-string values
        # cannot make join() raise TypeError.
        result = ', '.join(map(str, result))
    elif isinstance(result, dict):
        # Convert dictionaries to formatted JSON strings
        result = json.dumps(result, indent=2)
    else:
        # For any other type, convert using str()
        result = str(result)
    return result
# Load OPENAI_API_KEY / OPENAI_API_BASE from the .env file, then build an
# OpenAI-compatible client pointed at the configured base URL
# (here: Alibaba DashScope's OpenAI-compatible endpoint).
load_dotenv()
client = openai.OpenAI(
    api_key = os.getenv("OPENAI_API_KEY"),
    base_url= os.getenv("OPENAI_API_BASE")
)
def process_query(query):
    """
    Send the user's query to the model and run the tool-calling loop until
    the model produces a plain-text answer, which is printed to stdout.

    Args:
        query: The user's natural-language request.
    """
    messages = [{"role": "user", "content": query}]

    response = client.chat.completions.create(
        model="qwen-turbo",  # any OpenAI-compatible model works here
        max_tokens=2024,
        tools=tools,
        messages=messages
    )

    # NOTE: the original used a local flag named `process_query`, shadowing
    # this function's own name; a plain loop with `break` avoids that.
    while True:
        message = response.choices[0].message

        # Check tool calls first: a reply may carry both explanatory text
        # and tool calls, and stopping on the text would drop the calls.
        if message.tool_calls:
            # Record the assistant turn that requested the calls (keep any
            # accompanying text instead of discarding it).
            messages.append({
                "role": "assistant",
                "content": message.content,
                "tool_calls": message.tool_calls
            })

            # Execute each requested tool and append its result.
            for tool_call in message.tool_calls:
                tool_name = tool_call.function.name
                tool_args = json.loads(tool_call.function.arguments)
                print(f"Calling tool {tool_name} with args {tool_args}")

                result = execute_tool(tool_name, tool_args)
                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": result
                })

            # Ask the model to continue now that tool results are available;
            # the next loop iteration decides whether more calls follow.
            response = client.chat.completions.create(
                model="qwen-turbo",
                max_tokens=2024,
                tools=tools,
                messages=messages
            )
        else:
            # Plain-text reply: print it and stop. Breaking here also guards
            # against an infinite loop when the reply has no content at all.
            if message.content:
                print(message.content)
            break
def chat_loop():
    """Interactive REPL: read queries from stdin until the user types 'quit'."""
    print("Type your queries or 'quit' to exit.")
    running = True
    while running:
        try:
            query = input("\nQuery: ").strip()
            if query.lower() == 'quit':
                running = False
            else:
                process_query(query)
                print("\n")
        except Exception as e:
            # Report the error and keep the session alive.
            print(f"\nError: {str(e)}")
# Start the interactive chat loop only when run as a script (not on import).
if __name__ == "__main__":
    chat_loop()
代码解释
导入模块
import arxiv # 用于访问arXiv API搜索论文
import json # 处理JSON数据
import os # 操作系统功能,如文件路径处理
from typing import List # 类型提示
from dotenv import load_dotenv # 加载环境变量
import openai # OpenAI API客户端
核心功能函数
1. search_papers 函数
这个函数用于在arXiv上搜索特定主题的论文并保存信息:
def search_papers(topic: str, max_results: int = 5) -> List[str]:
- 参数:
  - topic:要搜索的主题
  - max_results:最大结果数量(默认 5 个)
- 返回值:找到的论文 ID 列表
功能流程:
- 创建arXiv客户端
- 按相关性搜索主题相关论文
- 为该主题创建目录(如 papers/machine_learning)
- 尝试加载已有的论文信息(如果存在)
- 处理每篇论文,提取标题、作者、摘要等信息
- 将论文信息保存到JSON文件中
- 返回论文ID列表
2. extract_info 函数
这个函数用于在所有主题目录中搜索特定论文的信息:
def extract_info(paper_id: str) -> str:
- 参数:paper_id - 要查找的论文 ID
- 返回值:包含论文信息的 JSON 字符串(如果找到),否则返回错误信息
功能流程:
- 遍历 papers 目录下的所有子目录
- 在每个子目录中查找 papers_info.json 文件
- 如果找到文件,检查是否包含指定的论文 ID
- 如果找到论文信息,返回格式化的JSON字符串
- 如果未找到,返回未找到的提示信息
工具定义
tools = [...]
定义了两个函数工具,用于OpenAI API的工具调用:
- search_papers - 搜索论文
- extract_info - 提取论文信息
每个工具都定义了名称、描述和参数规范。
工具执行函数
def execute_tool(tool_name, tool_args):
这个函数负责执行指定的工具函数,并处理返回结果:
- 将None结果转换为提示信息
- 将列表结果转换为逗号分隔的字符串
- 将字典结果转换为格式化的JSON字符串
- 其他类型转换为字符串
OpenAI客户端初始化
load_dotenv()
client = openai.OpenAI(
api_key = os.getenv("OPENAI_API_KEY"),
base_url= os.getenv("OPENAI_API_BASE")
)
从环境变量加载API密钥和基础URL,初始化OpenAI客户端。
查询处理函数
def process_query(query):
这个函数处理用户的查询:
- 创建包含用户查询的消息列表
- 调用OpenAI API创建聊天完成
- 处理助手的回复:
- 如果有普通文本内容,直接打印
- 如果有工具调用,执行工具并将结果添加到消息历史
- 如果执行了工具调用,获取下一个回复
- 如果最终回复只有文本,打印并结束处理
聊天循环函数
def chat_loop():
这个函数实现了一个简单的聊天循环:
- 提示用户输入查询或输入 'quit' 退出
- 处理用户的查询
- 捕获并显示任何错误
主程序
if __name__ == "__main__":
chat_loop()
当脚本直接运行时,启动聊天循环。
总结
这个脚本实现了一个基于OpenAI API的聊天机器人,它可以:
- 搜索arXiv上的论文并保存信息
- 提取已保存的论文信息
- 通过OpenAI API处理用户查询
- 支持工具调用功能,实现与arXiv的交互
运行示例
目录结构
运行结果