摘要:”大模型的对话记忆”这一概念,根植于人工智能与自然语言处理领域,特别是针对具有深度学习能力的大型语言模型而言,它指的是模型在与用户进行交互式对话过程中,能够追踪、理解并利用先前对话上下文的能力。 此机制使得大模型不仅能够响应即时的输入请求,还能基于之前的交流内
”大模型的对话记忆”这一概念,根植于人工智能与自然语言处理领域,特别是针对具有深度学习能力的大型语言模型而言,它指的是模型在与用户进行交互式对话过程中,能够追踪、理解并利用先前对话上下文的能力。 此机制使得大模型不仅能够响应即时的输入请求,还能基于之前的交流内容能够在对话中记住先前的对话内容,并根据这些信息进行后续的响应。这种记忆机制使得模型能够在对话中持续跟踪和理解用户的意图和上下文,从而实现更自然和连贯的对话。
举个栗子:
如果大模型没有记忆,会发生如下情况:
第一轮对话:
用户:我是张三,记住我了吗?
大模型:好的,你是张三,我记住了
第二轮对话:
用户:我是谁?
大模型:可以补充更多背景信息让我知道你是谁
可以看到,如果大模型没有记忆,就无法知道此前跟用户有过什么交流、用户当前提问的背景是什么,也就没法跟用户正常对话了。
给大模型加上记忆,会发生什么呢?
第一轮对话:
用户:我是张三,记住我了吗?
大模型:好的,你是张三,我记住了
第二轮对话:
用户:我是谁?
大模型:你是张三,我还记得你呦
大模型就可以非常清楚的理解用户提问意图,可以从之前的对话中寻找答案。
其实大模型本身并不会去记录与用户的对话信息,之所以大模型看起来好像有记忆一样,是因为每次去跟大模型交流的时候,程序会把与大模型对话的所有内容都告诉大模型,这样大模型就知道之前与大模型对话的所有数据了。就可以根据当前用户的对话内容,适时的去之前的对话中寻找信息。
从程序角度是按以下步骤实现的:
import java.util.ArrayList;List messages = new ArrayList;//第一轮对话messages.add(new SystemMessage("你是一个旅游规划师"));messages.add(new UserMessage("我想去新疆"));ChatResponse response = chatModel.call(new Prompt(messages));String content = response.getResult.getOutput.getContent;messages.add(new AssistantMessage(content));//第二轮对话messages.add(new UserMessage("能帮我推荐一些旅游景点吗?"));response = chatModel.call(new Prompt(messages));content = response.getResult.getOutput.getContent;messages.add(new AssistantMessage(content));//第三轮对话messages.add(new UserMessage("那里这两天的天气如何?"));response = chatModel.call(new Prompt(messages));content = response.getResult.getOutput.getContent;System.out.printf("content: %s\n", content);每一次跟大模型交互的时候,都会把之前用户请求和大模型响应信息都发给大模型。这样循环往复手动维护上下文信息着实比较麻烦,并且这还没有涉及到对话内容的持久化,不然会更麻烦。
能不能比较方便的实现大模型的会话记忆功能呢?
借助Spring AI的会话记忆(ChatMemory)可以快速实现基于内存的会话记忆功能,实现如下:
@Configurationpublic class ChatClientConfig {@Beanpublic ChatClient chatClient(ChatModel chatModel) {return ChatClient.builder(chatModel)// InMemoryChatMemory 就是框架提供的基于内存的存储方案.defaultAdvisors(new MessageChatMemoryAdvisor(new InMemoryChatMemory)).defaultAdvisors(new SimpleLoggerAdvisor).defaultOptions(DashScopeChatOptions.builder.withTopP(0.7).build).build;}}这样配置之后,框架就会在每次请求大模型的时候,自动将同一个会话的上下文信息一起发送给大模型,也就实现了我们上一步手动维护会话记忆的功能,按下面方式请求大模型即可:
//对话记忆的唯一标识String conversantId = UUID.randomUUID.toString;chatClient.prompt(prompt).advisors(spec -> // 配置当前会话IDspec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, conversantId)// 请求大模型时,携带最后n条对话记录.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10)).call.chatResponse.getResults.get(0).getOutput.getText;但如果仅仅将历史会话保存在内存中,每次服务器重启都会导致历史会话被清空,在大多数场景可能都不能满足我们的业务需求,我们更多是想要一种持久化的效果。但是Spring AI默认只提供了这种基于内存的方式,如果想基于其他方式来存储历史会话(比如:数据库、文件、NoSQL方式),需要我们来自己开发相关的实现。
下面以基于Redis的实现方式来抛砖引玉。
上来就手动实现,大多数小伙伴可能一下子不知如何去做。既然框架提供了一个默认实现,我们就可以从中汲取灵感。
先看一下InMemoryChatMemory是如何实现历史会话存储的。
public class InMemoryChatMemory implements ChatMemory {Map> conversationHistory = new ConcurrentHashMap;@Overridepublic void add(String conversationId, List messages) {this.conversationHistory.putIfAbsent(conversationId, new ArrayList);this.conversationHistory.get(conversationId).addAll(messages);}@Overridepublic List get(String conversationId, int lastN) {List all = this.conversationHistory.get(conversationId);return all != null ? all.stream.skip(Math.max(0, all.size - lastN)).toList : List.of;}@Overridepublic void clear(String conversationId) {this.conversationHistory.remove(conversationId);}}看下其父类ChatMemory的结构
public interface ChatMemory { // 新增会话default void add(String conversationId, Message message) {this.add(conversationId, List.of(message));}void add(String conversationId, List messages); // 获取最后n条会话List get(String conversationId, int lastN); // 清除会话内容void clear(String conversationId);}可以看到其实会话管理的思路是非常清晰的,无非就是
存储会话 -> 获取会话 -> 清除会话
在InMemoryChatMemory类中也是这样实现的,我们如果基于Reids来处理会话,也是需要实现ChatMemory,在相关方法中实现我们自己的逻辑。
在具体实现之前,我们要想清楚这个会话列表我们应该如何存到Reids中。我们日常使用的时候一般有两种方式:
将会话列表转成JSON格式存储,获取的时候再从JSON转为Message对象将Message对象序列化之后存储到Redis中其实如果使用方式1,会特别麻烦,因为Message是一个接口,其下又有若干种不同的消息实现,在转换JSON和Message对象之间,需要处理很多类型的数据,不推荐使用这种方式。
而对于方式2也有一个问题,我们如果需要将Message序列化,首先Message的子类需要实现Serializable接口,但很遗憾,所有的子类都没有实现该接口。所以我们不能使用平常的序列化方式,这里借助开源的序列化框架:kryo。
首先引入依赖:
com.esotericsoftwarekryo5.6.2基于Redis的历史会话管理实现如下。
既然要序列化,首先要定义Redis的序列化方式: @Beanpublic RedisTemplate> redisChatTemplate(RedisConnectionFactory factory) {RedisTemplate> template = new RedisTemplate;template.setConnectionFactory(factory);// redis的key使用stringtemplate.setKeySerializer(new StringRedisSerializer);template.setHashKeySerializer(new StringRedisSerializer);// redis的value使用我们自定义的序列化方式KryoRedisSerializer serializer = new KryoRedisSerializer;template.setValueSerializer(serializer);template.setHashValueSerializer(serializer);template.afterPropertiesSet;return template;}/*使用Kryo进行数据序列化与反序列化*/public static class KryoRedisSerializer implements RedisSerializer> {private static final Kryo kryo = new Kryo;static {kryo.setRegistrationRequired(false);kryo.setInstantiatorStrategy(new StdInstantiatorStrategy);}@Overridepublic byte serialize(List value) throws SerializationException {ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream;Output output = new Output(byteArrayOutputStream);kryo.writeObject(output, value);output.close;return byteArrayOutputStream.toByteArray;}@Overridepublic List deserialize(byte bytes) throws SerializationException {if (bytes == null) {return new ArrayList;}Input input = new Input(bytes);List list = kryo.readObject(input, ArrayList.class);input.close;return list;}}实现ChatMemory接口,实现我们自己的会话管理工具@Componentpublic class RedisChatMemory implements ChatMemory {@Resource@Lazyprivate RedisTemplate> redisChatTemplate;@Overridepublic void add(String conversationId, List messages) {getAndSetAll(conversationId, messages);}@Overridepublic List get(String conversationId, int lastN) {List all = getAndSetAll(conversationId, null);return all.stream.skip(Math.max(0, all.size - lastN)).toList;}private List getAndSetAll(String conversationId, List newMessages) {String key = getConversationKey(conversationId);List messages = redisChatTemplate.opsForValue.get(key);if (messages == null) {return List.of;}if (newMessages != null) {messages.addAll(newMessages);redisChatTemplate.opsForValue.set(key, messages);redisChatTemplate.expire(key, 1, TimeUnit.DAYS);}return messages;}@Overridepublic void clear(String conversationId) {String key = getConversationKey(conversationId);redisChatTemplate.delete(key);}@NotNullprivate static String getConversationKey(String conversationId) {return RedisUtil.buildKey(RedisKeyConstant.CHAT_MEMORY_KEY, conversationId);}}通过以上处理之后,就可以将所有会话内容存储到Redis中了。
程序源码:walter/spring-ai
来源:记忆旅途