Spring AI Java程序员的AI之Spring AI(二)

发布于:2024-10-17 ⋅ 阅读:(39) ⋅ 点赞:(0)

历史Spring AI文章

Spring AI Java程序员的AI之Spring AI(一)

一丶Spring AI 函数调用

定义工具函数Function

在Spring AI中,如果一个Bean实现了Function接口,那么它就是一个工具函数,并且通过@Description注解可以描述该工具的作用是什么,如果工具有需要接收参数,也可以通过@Schema注解来对参数进行定义,比如以下工具是用来获取指定地点的当前时间的,并且address参数用来接收具体的地点:

package com.qjc.demo.service;

import io.swagger.v3.oas.annotations.media.Schema;
import org.springframework.context.annotation.Description;
import org.springframework.stereotype.Component;

import java.time.LocalDateTime;
import java.util.function.Function;

/***
 * @projectName spring-ai-demo
 * @packageName com.qjc.demo.service
 * @author qjc
 * @description TODO
 * @Email qjc1024@aliyun.com
 * @date 2024-10-16 09:50
 **/
@Component
@Description("获取指定地点的当前时间")
public class DateService implements Function<DateService.Request, DateService.Response> {

    public record Request(@Schema(description = "地点") String address) { }

    public record Response(String date) { }

    @Override
    public Response apply(Request request) {
        System.out.println(request.address);
        return new Response(String.format("%s的当前时间是%s", request.address, LocalDateTime.now()));
    }
}

工具函数调用

当向大模型提问时,需要指定所要调用的工具函数,利用OpenAiChatOptions指定对应的beanName就可以了,比如:

@GetMapping("/function")
public String function(@RequestParam String message) {
    Prompt prompt = new Prompt(message, OpenAiChatOptions.builder().withFunction("dateService").build());
    Generation generation = chatClient.call(prompt).getResult();
    return (generation != null) ? generation.getOutput().getContent() : "";
}

FunctionCallback工具函数

还可以直接在提问时直接定义并调用工具,比如:

@GetMapping("/functionCallback")
public String functionCallback(@RequestParam String message) {
    Prompt prompt = new Prompt(message, OpenAiChatOptions.builder().withFunctionCallbacks(
        List.of(FunctionCallbackWrapper.builder(new DateService())
                .withName("dateService")
                .withDescription("获取指定地点的当前时间").build())
    ).build());
    Generation generation = chatClient.call(prompt).getResult();
    return (generation != null) ? generation.getOutput().getContent() : "";
}

通过这种方式,就不需要将DateService定义为Bean了,当然这样定义的工具只能functionCallback接口单独使用了,而定义Bean则可以让多个接口共享使用。

不过有时候,大模型给你的答案或工具参数可能是英文的

那么可以使用SystemMessage来设置系统提示词,比如:

@GetMapping("/functionCallback")
public String functionCallback(@RequestParam String message) {

    SystemMessage systemMessage = new SystemMessage("请用中文回答我");
    UserMessage userMessage = new UserMessage(message);

    Prompt prompt = new Prompt(List.of(systemMessage, userMessage), OpenAiChatOptions.builder().withFunctionCallbacks(
        List.of(FunctionCallbackWrapper.builder(new DateService())
                .withName("dateService")
                .withDescription("获取指定地点的当前时间").build())
    ).build());
    Generation generation = chatClient.call(prompt).getResult();
    return (generation != null) ? generation.getOutput().getContent() : "";
}

这样就能控制答案了。

二丶 Spring AI 函数调用源码解析

在OpenAiChatClient的call()方法中,会进行:

  1. 请求的处理
  2. 工具的调用
  3. 响应的处理
  4. 重试机制

比如call()方法的大体代码为:

@Override
public ChatResponse call(Prompt prompt) {
    // 请求处理
    ChatCompletionRequest request = createRequest(prompt, false);

    // 重试机制
    return this.retryTemplate.execute(ctx -> {
    
        // 请求调用
        ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
    
        // 返回响应
        return new ChatResponse(...);
    });
}

请求处理

请求处理核心是把Prompt对象转换成ChatCompletionRequest对象,包括Prompt中设置的SystemMessage、UserMessage和工具函数。

如果采用Bean的方式来使用工具函数,其底层其实对应的仍然是FunctionCallback,在OpenAiAutoConfiguration自动配置中,定义了一个FunctionCallbackContext的Bean,该Bean提供了一个getFunctionCallback()方法,用来生成beanName对应的FunctionCallback对象,源码为:

public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable String defaultDescription) {

    // 获取Bean类型
    Type beanType = FunctionContextUtils.findType(this.applicationContext.getBeanFactory(), beanName);

    if (beanType == null) {
        throw new IllegalArgumentException(
            "Functional bean with name: " + beanName + " does not exist in the context.");
    }

    // Bean类型必须是Function类型
    if (!Function.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))) {
        throw new IllegalArgumentException(
            "Function call Bean must be of type Function. Found: " + beanType.getTypeName());
    }

    // 获取Function的第一个泛型的类型,比如Request
    Type functionInputType = TypeResolverHelper.getFunctionArgumentType(beanType, 0);
    Class<?> functionInputClass = FunctionTypeUtils.getRawType(functionInputType);
   
    String functionName = beanName;
    String functionDescription = defaultDescription;

    if (!StringUtils.hasText(functionDescription)) {

        // 获取@Description设置的描述信息
        // Look for a Description annotation on the bean
        Description descriptionAnnotation = applicationContext.findAnnotationOnBean(beanName, Description.class);

        if (descriptionAnnotation != null) {
            functionDescription = descriptionAnnotation.value();
        }

        // 获取Request参数前的@JsonClassDescription设置的描述信息
        if (!StringUtils.hasText(functionDescription)) {
            // Look for a JsonClassDescription annotation on the input class
            JsonClassDescription jsonClassDescriptionAnnotation = functionInputClass
            .getAnnotation(JsonClassDescription.class);
            if (jsonClassDescriptionAnnotation != null) {
                functionDescription = jsonClassDescriptionAnnotation.value();
            }
        }

        if (!StringUtils.hasText(functionDescription)) {
            throw new IllegalStateException("Could not determine function description."
                                            + "Please provide a description either as a default parameter, via @Description annotation on the bean "
                                            + "or @JsonClassDescription annotation on the input class.");
        }
    }

    // 获取Bean对象
    Object bean = this.applicationContext.getBean(beanName);

    // 构建为FunctionCallback对象
    if (bean instanceof Function<?, ?> function) {
        return FunctionCallbackWrapper.builder(function)
        .withName(functionName)
        .withSchemaType(this.schemaType)
        .withDescription(functionDescription)
        .withInputType(functionInputClass)
        .build();
    }
    else {
        throw new IllegalArgumentException("Bean must be of type Function");
    }
}

以上代码的核心逻辑为:

  1. 获取Bean类型
  2. 获取Function的第一个泛型的类型,比如Request
  3. 获取@Description设置的描述信息
  4. 构造FunctionCallback对象

在OpenAiChatClient就会注入FunctionCallbackContext这个Bean对象,从而使得OpenAiChatClient可以通过Prompt中指定的beanName获取到对应的FunctionCallback对象。

所以,在createRequest()方法中,就可以得到从FunctionCallbackContext找到的或者直接在Prompt对象中设置的FunctionCallback对象,然后将FunctionCallback对象转成OpenAiApi.FunctionTool对象,最终将FunctionTool设置到ChatCompletionRequest中。

请求调用

请求调用源码如下:

protected Resp callWithFunctionSupport(Req request) {
    Resp response = this.doChatCompletion(request);
    return this.handleFunctionCallOrReturn(request, response);
}
  1. 先发送请求得到响应
  2. 解析响应是否需要调用工具还是直接返回

doChatCompletion()方法比较简单,就是直接把请求发送给OpenAi,重要的是handleFunctionCallOrReturn()方法。

handleFunctionCallOrReturn()方法需要解析响应,比如判断OpenAi返回的响应中是否需要调用工具,比如:

if (!this.isToolFunctionCall(response)) {
    return response;
}

OpenAi中,如果一个响应的finishReason为TOOL_CALLS则表示,当前响应其实是OpenAi的一个工具调用请求。

然后就去执行工具:

protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
                                                            ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {

    // 遍历每个要调用的工具
    // Every tool-call item requires a separate function call and a response (TOOL)
    // message.
    for (ToolCall toolCall : responseMessage.toolCalls()) {

        // 工具名和参数
        var functionName = toolCall.function().name();
        String functionArguments = toolCall.function().arguments();

        if (!this.functionCallbackRegister.containsKey(functionName)) {
            throw new IllegalStateException("No function callback found for function name: " + functionName);
        }

        // 找到FunctionCallback并进行调用,得到工具执行结果
        String functionResponse = this.functionCallbackRegister.get(functionName)
            .call(functionArguments);

        // 将工具执行结果添加到对话历史
        // Add the function response to the conversation.
        conversationHistory
        .add(new ChatCompletionMessage(functionResponse, Role.TOOL, functionName, toolCall.id(), null));
    }

    // 构造新的请求,将工具执行结果传递给OpenAi
    // Recursively call chatCompletionWithTools until the model doesn't call a
    // functions anymore.
    ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationHistory, false);
    newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class);

    return newRequest;
}

以上源码的核心流程为:

  1. 遍历每个要调用的工具
  2. 根据工具名找到FunctionCallback并进行调用,得到工具执行结果
  3. 将工具执行结果添加到对话历史
  4. 构造新的请求,将工具执行结果传递给OpenAi

得到新的请求对象后,又会调用callWithFunctionSupport()方法,所以这里出现了递归调用。

函数调用

当调用FunctionCallback的call方法时,就是在执行函数调用:

@Override
public String call(String functionArguments) {

    // 将OpenAi给的请求参数转成指定类,比如Request
    // Convert the tool calls JSON arguments into a Java function request object.
    I request = fromJson(functionArguments, inputType);

    // 然后执行apply方法
    // extend conversation with function response.
    return this.andThen(this.responseConverter).apply(request);
}

从这里可以发现,对于工具执行结果,还可以设置responseConverter来进行处理,比如:

@GetMapping("/functionCallback")
public String functionCallback(@RequestParam String message) {

    SystemMessage systemMessage = new SystemMessage("请用中文回答我");
    UserMessage userMessage = new UserMessage(message);

    Prompt prompt = new Prompt(List.of(systemMessage, userMessage), OpenAiChatOptions.builder().withFunctionCallbacks(
        List.of(FunctionCallbackWrapper.builder(new DateService())
                .withName("dateService")
                .withDescription("获取指定地点的当前时间")
                .withResponseConverter(response -> "2024年10月16日09:22")
                .build())
    ).build());
    Generation generation = chatClient.call(prompt).getResult();
    return (generation != null) ? generation.getOutput().getContent() : "";
}

这样做,最终函数执行结果被我固定成了"2024年10月16日09:22,因此最终OpenAi给我答案也是

交互流程图

在这里插入图片描述
为什么OpenAiChatClient不在本地先直接执行工具,然后再请求OpenAiServer呢?

以上场景比较简单,实际上的思想是:把OpenAi当做一个大脑,通过第一次请求告诉OpenAi我的需求任务,以及我们提供了哪些工具,然后由OpenAi:

  1. 先理解任务
  2. 然后制定策略,也就是OpenAi要完成任务,需要调用哪些工具,并且调用这些工具的具体参数是什么,调用工具的顺序是什么,这些都由OpenAi来进行分析
  3. 然后OpenAi就向OpenAiChatClient发送工具调用请求,并得到工具执行结果
  4. 然后OpenAi再基于任务和工具执行结果进行分析,看是否能完成任务了,还是需要继续调用工具。
  5. 如果能完成任务了,那就直接把任务的执行结果返回给OpenAiChatClient。

三丶 案例

需求:获取今天注册的新用户信息。

定义获取当前时间工具:

package com.qjc.demo.service;

import io.swagger.v3.oas.annotations.media.Schema;
import org.springframework.context.annotation.Description;
import org.springframework.stereotype.Component;

import java.time.LocalDateTime;
import java.util.function.Function;

/***
 * @projectName spring-ai-demo
 * @packageName com.qjc.demo.service
 * @author qjc
 * @description TODO
 * @Email qjc1024@aliyun.com
 * @date 2024-10-16 10:01
 **/
@Component
@Description("获取当前时间")
public class DateService implements Function<DateService.Request, String> {

    public record Request(String noUse) { }

    @Override
    public String apply(Request request) {
        System.out.println("执行DateService工具");
        return LocalDateTime.now().toString();
    }
}

定义获取用户信息服务:

package com.qjc.demo.service;

import org.springframework.context.annotation.Description;
import org.springframework.stereotype.Component;

import java.util.List;
import java.util.function.Function;

/***
 * @projectName spring-ai-demo
 * @packageName com.qjc.demo.service
 * @author qjc
 * @description TODO
 * @Email qjc1024@aliyun.com
 * @date 2024-10-16 10:05
 **/
@Component
@Description("获取指定时间的注册用户")
public class UserService implements Function<UserService.Request, List<UserService.User>> {

    public record Request(String date) { }


    @Override
    public List<User> apply(Request request) {
        System.out.println("执行OrderService工具, 入参为:" + request.date);
        return List.of(new User("小齐", "2024年10月16号"), new User("宇将军", "2024年10月16号"));
    }

    class User {
        private String username;
        private String registrationDate;

        public User(String username, String registrationDate) {
            this.username = username;
            this.registrationDate = registrationDate;
        }

        public String getUsername() {
            return username;
        }

        public void setUsername(String username) {
            this.username = username;
        }

        public String getRegistrationDate() {
            return registrationDate;
        }

        public void setRegistrationDate(String registrationDate) {
            this.registrationDate = registrationDate;
        }
    }


}

定义请求接口:

@GetMapping("/user")
public String user(@RequestParam String message) {
    SystemMessage systemMessage = new SystemMessage("将结果按JSON格式返回");
    UserMessage userMessage = new UserMessage(message);

    Prompt prompt = new Prompt(List.of(systemMessage, userMessage), OpenAiChatOptions.builder()
                               .withFunctions(Set.of("dateService", "userService"))
                               .build());
    Generation generation = chatClient.call(prompt).getResult();
    return (generation != null) ? generation.getOutput().getContent() : "";
}

总结

我感觉非常爽啊,我的大Spring 函数调用,没有冗余代码。