Spring AI之函数调用实战与原理分析
历史Spring AI文章
Spring AI Java程序员的AI之Spring AI(一)
一丶Spring AI 函数调用
定义工具函数Function
在Spring AI中,如果一个Bean实现了Function接口,那么它就是一个工具函数,并且通过@Description注解可以描述该工具的作用是什么,如果工具有需要接收参数,也可以通过@Schema注解来对参数进行定义,比如以下工具是用来获取指定地点的当前时间的,并且address参数用来接收具体的地点:
package com.qjc.demo.service;
import io.swagger.v3.oas.annotations.media.Schema;
import org.springframework.context.annotation.Description;
import org.springframework.stereotype.Component;
import java.time.LocalDateTime;
import java.util.function.Function;
/***
* @projectName spring-ai-demo
* @packageName com.qjc.demo.service
* @author qjc
* @description TODO
* @Email qjc1024@aliyun.com
* @date 2024-10-16 09:50
**/
@Component
@Description("获取指定地点的当前时间")
public class DateService implements Function<DateService.Request, DateService.Response> {
public record Request(@Schema(description = "地点") String address) { }
public record Response(String date) { }
@Override
public Response apply(Request request) {
System.out.println(request.address);
return new Response(String.format("%s的当前时间是%s", request.address, LocalDateTime.now()));
}
}
工具函数调用
当向大模型提问时,需要指定所要调用的工具函数,利用OpenAiChatOptions指定对应的beanName就可以了,比如:
@GetMapping("/function")
public String function(@RequestParam String message) {
Prompt prompt = new Prompt(message, OpenAiChatOptions.builder().withFunction("dateService").build());
Generation generation = chatClient.call(prompt).getResult();
return (generation != null) ? generation.getOutput().getContent() : "";
}
FunctionCallback工具函数
还可以直接在提问时直接定义并调用工具,比如:
@GetMapping("/functionCallback")
public String functionCallback(@RequestParam String message) {
Prompt prompt = new Prompt(message, OpenAiChatOptions.builder().withFunctionCallbacks(
List.of(FunctionCallbackWrapper.builder(new DateService())
.withName("dateService")
.withDescription("获取指定地点的当前时间").build())
).build());
Generation generation = chatClient.call(prompt).getResult();
return (generation != null) ? generation.getOutput().getContent() : "";
}
通过这种方式,就不需要将DateService定义为Bean了,当然这样定义的工具只能functionCallback接口单独使用了,而定义Bean则可以让多个接口共享使用。
不过有时候,大模型给你的答案或工具参数可能是英文的
那么可以使用SystemMessage来设置系统提示词,比如:
@GetMapping("/functionCallback")
public String functionCallback(@RequestParam String message) {
SystemMessage systemMessage = new SystemMessage("请用中文回答我");
UserMessage userMessage = new UserMessage(message);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage), OpenAiChatOptions.builder().withFunctionCallbacks(
List.of(FunctionCallbackWrapper.builder(new DateService())
.withName("dateService")
.withDescription("获取指定地点的当前时间").build())
).build());
Generation generation = chatClient.call(prompt).getResult();
return (generation != null) ? generation.getOutput().getContent() : "";
}
这样就能控制答案了。
二丶 Spring AI 函数调用源码解析
在OpenAiChatClient的call()方法中,会进行:
- 请求的处理
- 工具的调用
- 响应的处理
- 重试机制
比如call()方法的大体代码为:
@Override
public ChatResponse call(Prompt prompt) {
// 请求处理
ChatCompletionRequest request = createRequest(prompt, false);
// 重试机制
return this.retryTemplate.execute(ctx -> {
// 请求调用
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
// 返回响应
return new ChatResponse(...);
});
}
请求处理
请求处理核心是把Prompt对象转换成ChatCompletionRequest对象,包括Prompt中设置的SystemMessage、UserMessage和工具函数。
如果采用Bean的方式来使用工具函数,其底层其实对应的仍然是FunctionCallback,在OpenAiAutoConfiguration自动配置中,定义了一个FunctionCallbackContext的Bean,该Bean提供了一个getFunctionCallback()方法,用来生成beanName对应的FunctionCallback对象,源码为:
public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable String defaultDescription) {
// 获取Bean类型
Type beanType = FunctionContextUtils.findType(this.applicationContext.getBeanFactory(), beanName);
if (beanType == null) {
throw new IllegalArgumentException(
"Functional bean with name: " + beanName + " does not exist in the context.");
}
// Bean类型必须是Function类型
if (!Function.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))) {
throw new IllegalArgumentException(
"Function call Bean must be of type Function. Found: " + beanType.getTypeName());
}
// 获取Function的第一个泛型的类型,比如Request
Type functionInputType = TypeResolverHelper.getFunctionArgumentType(beanType, 0);
Class<?> functionInputClass = FunctionTypeUtils.getRawType(functionInputType);
String functionName = beanName;
String functionDescription = defaultDescription;
if (!StringUtils.hasText(functionDescription)) {
// 获取@Description设置的描述信息
// Look for a Description annotation on the bean
Description descriptionAnnotation = applicationContext.findAnnotationOnBean(beanName, Description.class);
if (descriptionAnnotation != null) {
functionDescription = descriptionAnnotation.value();
}
// 获取Request参数前的@JsonClassDescription设置的描述信息
if (!StringUtils.hasText(functionDescription)) {
// Look for a JsonClassDescription annotation on the input class
JsonClassDescription jsonClassDescriptionAnnotation = functionInputClass
.getAnnotation(JsonClassDescription.class);
if (jsonClassDescriptionAnnotation != null) {
functionDescription = jsonClassDescriptionAnnotation.value();
}
}
if (!StringUtils.hasText(functionDescription)) {
throw new IllegalStateException("Could not determine function description."
+ "Please provide a description either as a default parameter, via @Description annotation on the bean "
+ "or @JsonClassDescription annotation on the input class.");
}
}
// 获取Bean对象
Object bean = this.applicationContext.getBean(beanName);
// 构建为FunctionCallback对象
if (bean instanceof Function<?, ?> function) {
return FunctionCallbackWrapper.builder(function)
.withName(functionName)
.withSchemaType(this.schemaType)
.withDescription(functionDescription)
.withInputType(functionInputClass)
.build();
}
else {
throw new IllegalArgumentException("Bean must be of type Function");
}
}
以上代码的核心逻辑为:
- 获取Bean类型
- 获取Function的第一个泛型的类型,比如Request
- 获取@Description设置的描述信息
- 构造FunctionCallback对象
在OpenAiChatClient就会注入FunctionCallbackContext这个Bean对象,从而使得OpenAiChatClient可以通过Prompt中指定的beanName获取到对应的FunctionCallback对象。
所以,在createRequest()方法中,就可以得到从FunctionCallbackContext找到的或者直接在Prompt对象中设置的FunctionCallback对象,然后将FunctionCallback对象转成OpenAiApi.FunctionTool对象,最终将FunctionTool设置到ChatCompletionRequest中。
请求调用
请求调用源码如下:
protected Resp callWithFunctionSupport(Req request) {
Resp response = this.doChatCompletion(request);
return this.handleFunctionCallOrReturn(request, response);
}
- 先发送请求得到响应
- 解析响应是否需要调用工具还是直接返回
doChatCompletion()方法比较简单,就是直接把请求发送给OpenAi,重要的是handleFunctionCallOrReturn()方法。
handleFunctionCallOrReturn()方法需要解析响应,比如判断OpenAi返回的响应中是否需要调用工具,比如:
if (!this.isToolFunctionCall(response)) {
return response;
}
OpenAi中,如果一个响应的finishReason为TOOL_CALLS则表示,当前响应其实是OpenAi的一个工具调用请求。
然后就去执行工具:
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {
// 遍历每个要调用的工具
// Every tool-call item requires a separate function call and a response (TOOL)
// message.
for (ToolCall toolCall : responseMessage.toolCalls()) {
// 工具名和参数
var functionName = toolCall.function().name();
String functionArguments = toolCall.function().arguments();
if (!this.functionCallbackRegister.containsKey(functionName)) {
throw new IllegalStateException("No function callback found for function name: " + functionName);
}
// 找到FunctionCallback并进行调用,得到工具执行结果
String functionResponse = this.functionCallbackRegister.get(functionName)
.call(functionArguments);
// 将工具执行结果添加到对话历史
// Add the function response to the conversation.
conversationHistory
.add(new ChatCompletionMessage(functionResponse, Role.TOOL, functionName, toolCall.id(), null));
}
// 构造新的请求,将工具执行结果传递给OpenAi
// Recursively call chatCompletionWithTools until the model doesn't call a
// functions anymore.
ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationHistory, false);
newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class);
return newRequest;
}
以上源码的核心流程为:
- 遍历每个要调用的工具
- 根据工具名找到FunctionCallback并进行调用,得到工具执行结果
- 将工具执行结果添加到对话历史
- 构造新的请求,将工具执行结果传递给OpenAi
得到新的请求对象后,又会调用callWithFunctionSupport()方法,所以这里出现了递归调用。
函数调用
当调用FunctionCallback的call方法时,就是在执行函数调用:
@Override
public String call(String functionArguments) {
// 将OpenAi给的请求参数转成指定类,比如Request
// Convert the tool calls JSON arguments into a Java function request object.
I request = fromJson(functionArguments, inputType);
// 然后执行apply方法
// extend conversation with function response.
return this.andThen(this.responseConverter).apply(request);
}
从这里可以发现,对于工具执行结果,还可以设置responseConverter来进行处理,比如:
@GetMapping("/functionCallback")
public String functionCallback(@RequestParam String message) {
SystemMessage systemMessage = new SystemMessage("请用中文回答我");
UserMessage userMessage = new UserMessage(message);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage), OpenAiChatOptions.builder().withFunctionCallbacks(
List.of(FunctionCallbackWrapper.builder(new DateService())
.withName("dateService")
.withDescription("获取指定地点的当前时间")
.withResponseConverter(response -> "2024年10月16日09:22")
.build())
).build());
Generation generation = chatClient.call(prompt).getResult();
return (generation != null) ? generation.getOutput().getContent() : "";
}
这样做,最终函数执行结果被我固定成了"2024年10月16日09:22,因此最终OpenAi给我答案也是
交互流程图
为什么OpenAiChatClient不在本地先直接执行工具,然后再请求OpenAiServer呢?
以上场景比较简单,实际上的思想是:把OpenAi当做一个大脑,通过第一次请求告诉OpenAi我的需求任务,以及我们提供了哪些工具,然后由OpenAi:
- 先理解任务
- 然后制定策略,也就是OpenAi要完成任务,需要调用哪些工具,并且调用这些工具的具体参数是什么,调用工具的顺序是什么,这些都由OpenAi来进行分析
- 然后OpenAi就向OpenAiChatClient发送工具调用请求,并得到工具执行结果
- 然后OpenAi再基于任务和工具执行结果进行分析,看是否能完成任务了,还是需要继续调用工具。
- 如果能完成任务了,那就直接把任务的执行结果返回给OpenAiChatClient。
三丶 案例
需求:获取今天注册的新用户信息。
定义获取当前时间工具:
package com.qjc.demo.service;
import io.swagger.v3.oas.annotations.media.Schema;
import org.springframework.context.annotation.Description;
import org.springframework.stereotype.Component;
import java.time.LocalDateTime;
import java.util.function.Function;
/***
* @projectName spring-ai-demo
* @packageName com.qjc.demo.service
* @author qjc
* @description TODO
* @Email qjc1024@aliyun.com
* @date 2024-10-16 10:01
**/
@Component
@Description("获取当前时间")
public class DateService implements Function<DateService.Request, String> {
public record Request(String noUse) { }
@Override
public String apply(Request request) {
System.out.println("执行DateService工具");
return LocalDateTime.now().toString();
}
}
定义获取用户信息服务:
package com.qjc.demo.service;
import org.springframework.context.annotation.Description;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.function.Function;
/***
* @projectName spring-ai-demo
* @packageName com.qjc.demo.service
* @author qjc
* @description TODO
* @Email qjc1024@aliyun.com
* @date 2024-10-16 10:05
**/
@Component
@Description("获取指定时间的注册用户")
public class UserService implements Function<UserService.Request, List<UserService.User>> {
public record Request(String date) { }
@Override
public List<User> apply(Request request) {
System.out.println("执行OrderService工具, 入参为:" + request.date);
return List.of(new User("小齐", "2024年10月16号"), new User("宇将军", "2024年10月16号"));
}
class User {
private String username;
private String registrationDate;
public User(String username, String registrationDate) {
this.username = username;
this.registrationDate = registrationDate;
}
public String getUsername() {
return username;
}
public void setUsername(String username) {
this.username = username;
}
public String getRegistrationDate() {
return registrationDate;
}
public void setRegistrationDate(String registrationDate) {
this.registrationDate = registrationDate;
}
}
}
定义请求接口:
@GetMapping("/user")
public String user(@RequestParam String message) {
SystemMessage systemMessage = new SystemMessage("将结果按JSON格式返回");
UserMessage userMessage = new UserMessage(message);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage), OpenAiChatOptions.builder()
.withFunctions(Set.of("dateService", "userService"))
.build());
Generation generation = chatClient.call(prompt).getResult();
return (generation != null) ? generation.getOutput().getContent() : "";
}
总结
我感觉非常爽啊,我的大Spring 函数调用,没有冗余代码。