背景
spring-ai官方提供了多种方式来存储对话:
- InMemoryChatMemory:内存存储
- CassandraChatMemory:在 Cassandra 中带有过期时间的持久化存储
- Neo4jChatMemory:在 Neo4j 中没有过期时间限制的持久化存储
- JdbcChatMemory:在 JDBC 中没有过期时间限制的持久化存储
基于jdbc的持久化目前支持:
- PostgreSQL
- MySQL / MariaDB
- SQL Server
- HSQLDB
但是官方文档(https://docs.spring.io/spring-ai/reference/api/chat-memory.html)对于jdbc持久化方式的介绍只是草草带过😩。而如果直接使用InMemoryChatMemory把对话放到内存中,项目重启后数据就会丢失,并且不断地往内存中存数据,后面可能会导致OOM。
于是打算自己通过实现ChatMemory
来实现基于MySQL
的持久化机制。😏
配置
依赖配置
jdk用的是
21版本
<!-- Spring Boot parent: supplies dependency management and plugin defaults. -->
<parent>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-parent</artifactId>
    <version>3.5.0</version>
    <relativePath/>
</parent>

<dependencies>
    <!-- Web MVC stack used by the REST controller. -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <!-- Spring AI Alibaba starter (DashScope chat model). -->
    <dependency>
        <groupId>com.alibaba.cloud.ai</groupId>
        <artifactId>spring-ai-alibaba-starter</artifactId>
        <version>1.0.0-M5.1</version>
    </dependency>
    <!--
      MySQL JDBC driver. Since 8.0.31 the artifact is published as
      com.mysql:mysql-connector-j; the old mysql:mysql-connector-java
      coordinates are only a relocation stub.
    -->
    <dependency>
        <groupId>com.mysql</groupId>
        <artifactId>mysql-connector-j</artifactId>
        <version>8.0.32</version>
    </dependency>
    <!-- MyBatis-Plus starter for Spring Boot 3. -->
    <dependency>
        <groupId>com.baomidou</groupId>
        <artifactId>mybatis-plus-spring-boot3-starter</artifactId>
        <version>3.5.12</version>
    </dependency>
    <!-- Hutool utility library. -->
    <dependency>
        <groupId>cn.hutool</groupId>
        <artifactId>hutool-all</artifactId>
        <version>5.8.16</version>
    </dependency>
</dependencies>
❗️注意:
mybatis-plus
版本可能会和springboot
冲突,出现:Invalid value type for attribute 'factoryBeanObjectType'
,切换mybatis-plus
版本即可。
yml配置
api-key从阿里云百炼平台获取:
https://bailian.console.aliyun.com/?tab=model#/api-key
spring:
  # AI settings
  ai:
    dashscope:
      api-key: your-api-key
    chat:
      client:
        enabled: false # disable auto-configuration of ChatClient.Builder
  # MySQL connection settings
  datasource:
    url: jdbc:mysql://localhost:3306/super_ai_agent?characterEncoding=utf-8&serverTimezone=Asia/Shanghai
    username: root
    password: root
    driver-class-name: com.mysql.cj.jdbc.Driver

# MyBatis-Plus settings
mybatis-plus:
  configuration:
    # map snake_case columns to camelCase properties
    map-underscore-to-camel-case: true
  # global settings
  global-config:
    db-config:
      # primary-key strategy: database auto-increment
      id-type: auto
      # Global logic-delete field. MyBatis-Plus compares this value with the ENTITY
      # property name, not the column name.
      # NOTE(review): assumes the generated entity maps column is_del to property isDel - confirm.
      logic-delete-field: isDel
      logic-delete-value: 1 # value meaning "deleted" (optional, defaults to 1)
      logic-not-delete-value: 0 # value meaning "not deleted" (optional, defaults to 0)
  # also picks up mapper XML files placed in sub-directories of mapper/
  mapper-locations: classpath*:/mapper/**/*.xml
库表设计
✨Tip:
spring-ai
默认生成的会话Id为default
,不是UUID。
-- One row per chat message; a conversation is the set of rows sharing the same chat_id.
-- chat_id is sized at 64 so that a hyphenated UUID (36 characters) fits as well as 'default'.
-- utf8mb4 is declared explicitly so message content may contain emoji.
CREATE TABLE `ai_chat_memory` (
  `id` BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY COMMENT 'id',
  `chat_id` varchar(64) NOT NULL COMMENT '会话id',
  `type` varchar(10) NOT NULL DEFAULT 'user' COMMENT '消息类型',
  `content` text NOT NULL COMMENT '消息内容',
  `create_time` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
  `update_time` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
  `is_del` tinyint(1) NOT NULL DEFAULT '0' COMMENT '删除标记,0-未删除;1-已删除'
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
索引
👉我这里选择将单条消息作为数据库的一行数据,而不是单次会话,因此
chat_id
不是唯一的。
-- Non-unique index: many message rows share one chat_id.
-- create_time is included because history is read as
-- "WHERE chat_id = ? ORDER BY create_time DESC LIMIT n", which this index serves without a filesort.
CREATE INDEX idx_chat_id ON ai_chat_memory (chat_id, create_time);
InMySqlChatMemory
AiChatMemory
为实体类,AiChatMemoryService
为服务层,直接用mybatis-plus
插件生成就行。
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.*;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
/**
 * MySQL-backed {@link ChatMemory} implementation.
 *
 * <p>Each message is persisted as one row through {@code AiChatMemoryService}; rows that
 * belong to the same conversation share the same {@code chat_id}.
 */
@Component
public class InMySqlChatMemory implements ChatMemory {

    @Resource
    private AiChatMemoryService aiChatMemoryService;

    /**
     * Persists the given messages for a conversation, one row per message.
     *
     * @param conversationId conversation the messages belong to
     * @param messages       messages to store; a {@code null} or empty list is a no-op
     */
    @Override
    public void add(String conversationId, List<Message> messages) {
        if (CollectionUtils.isEmpty(messages)) {
            return;
        }
        List<AiChatMemory> rows = new ArrayList<>(messages.size());
        for (Message message : messages) {
            AiChatMemory row = new AiChatMemory();
            row.setChatId(conversationId);
            row.setType(message.getMessageType().getValue());
            row.setContent(message.getText());
            rows.add(row);
        }
        aiChatMemoryService.saveBatch(rows);
    }

    /**
     * Loads the most recent {@code lastN} messages of a conversation in chronological order
     * (oldest first), which is the order a chat prompt has to be replayed in.
     *
     * @param conversationId conversation to read
     * @param lastN          maximum number of messages to return; values {@code <= 0} yield an empty list
     * @return an unmodifiable list of the latest messages, oldest first; never {@code null}
     * @throws IllegalArgumentException if a stored row has a message type other than
     *                                  {@code system}, {@code user} or {@code assistant}
     */
    @Override
    public List<Message> get(String conversationId, int lastN) {
        if (lastN <= 0) {
            return List.of();
        }
        // Newest rows first so that LIMIT keeps the latest N. The id is a tie-breaker because
        // rows written in one batch can share the same second-resolution create_time.
        List<AiChatMemory> newestFirst = aiChatMemoryService.list(new QueryWrapper<AiChatMemory>()
                .eq("chat_id", conversationId)
                .orderByDesc("create_time")
                .orderByDesc("id")
                .last("limit " + lastN));
        if (CollectionUtils.isEmpty(newestFirst)) {
            return List.of();
        }
        // Flip the page back to chronological order before handing it to the caller.
        List<Message> chronological = new ArrayList<>(newestFirst.size());
        for (int i = newestFirst.size() - 1; i >= 0; i--) {
            chronological.add(toMessage(newestFirst.get(i)));
        }
        return List.copyOf(chronological);
    }

    /**
     * Removes all stored messages of one conversation.
     *
     * @param conversationId conversation to clear; must not be {@code null}
     * @throws IllegalArgumentException if {@code conversationId} is {@code null}; an
     *                                  unconditional delete would wipe every conversation
     */
    @Override
    public void clear(String conversationId) {
        if (conversationId == null) {
            throw new IllegalArgumentException("conversationId must not be null");
        }
        aiChatMemoryService.remove(new QueryWrapper<AiChatMemory>()
                .eq("chat_id", conversationId));
    }

    /** Converts a stored row back into the Spring AI message matching its type column. */
    private Message toMessage(AiChatMemory row) {
        String type = row.getType();
        String content = row.getContent();
        return switch (type) {
            case "system" -> new SystemMessage(content);
            case "user" -> new UserMessage(content);
            case "assistant" -> new AssistantMessage(content);
            default -> throw new IllegalArgumentException("Unknown message type: " + type);
        };
    }
}
使用
ChatClient配置
👉
LOVE_PROMPT
为系统预设prompt
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
 * Builds the {@link ChatClient} bean used by the demo: a DashScope {@code qwen-plus} model
 * with the {@code LOVE_PROMPT} system prompt and a chat-memory advisor backed by the
 * injected {@link ChatMemory}.
 */
@Configuration
public class ChatClientConfig {

    /** DashScope API key, read from {@code spring.ai.dashscope.api-key}. */
    @Value("${spring.ai.dashscope.api-key}")
    private String apiKey;

    /**
     * Chat memory used by the advisor. The field name follows the default bean name of
     * {@code InMySqlChatMemory} so that {@code @Resource} resolves it by name.
     */
    @Resource
    private ChatMemory inMySqlChatMemory;

    /**
     * Creates the chat client. The bean name is kept as-is because other components inject it by name.
     *
     * @return a chat client whose conversation history goes through the injected chat memory
     * @throws IllegalStateException if the API key is missing or blank
     */
    @Bean
    public ChatClient qwenPlusInMemoryChatClient() {
        if (apiKey == null || apiKey.isBlank()) {
            throw new IllegalStateException("spring.ai.dashscope.api-key must be configured");
        }
        DashScopeChatModel chatModel = new DashScopeChatModel(
                new DashScopeApi(apiKey),
                DashScopeChatOptions.builder().withModel("qwen-plus").build());
        return ChatClient.builder(chatModel)
                .defaultSystem(LOVE_PROMPT)
                .defaultAdvisors(
                        // custom persistent-memory advisor
                        new MessageChatMemoryAdvisor(inMySqlChatMemory)
                )
                .build();
    }
}
测试
import jakarta.annotation.Nonnull;
import jakarta.annotation.Resource;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;
@RestController
@RequestMapping("/love")
public class LoveDemoController {
@Resource
private ChatClient qwenPlusInMemoryChatClient;
@GetMapping("chat")
public String simpleChat(String message) {
return qwenPlusInMemoryChatClient.prompt()
.user(message)
.call().content();
}
}
运行结果
数据库:
可以看到,我们没有传
chatId
默认就是default👌