一个基于SQLite的Memory实现。我们可以创建一个新的SQLiteMemory
类,继承自NormalMemory
的接口。以下是实现代码:
import sqlite3
import json
import logging
from GeneralAgent.utils import encode_image
class SQLiteMemory:
def __init__(self, db_path='./memory.db', messages=[]):
"""
@db_path: str, SQLite数据库路径,默认为'./memory.db'
"""
self.db_path = db_path
self.conn = sqlite3.connect(db_path)
self._create_table()
if len(messages) > 0:
self._validate_messages(messages)
for msg in messages:
self.add_message(msg['role'], msg['content'])
def _create_table(self):
"""创建messages表"""
cursor = self.conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
role TEXT NOT NULL,
content TEXT NOT NULL,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)
''')
self.conn.commit()
def add_message(self, role, content):
"""添加新消息"""
assert role in ['user', 'system', 'assistant']
if isinstance(content, list):
# 处理多模态内容
r = []
for c in content:
if isinstance(c, dict):
if 'image' in c:
r.append({'type': 'image_url', 'image_url': {'url': encode_image(c['image'])}})
elif 'text' in c:
r.append({'type': 'text', 'text': c['text']})
else:
raise Exception('message type wrong')
else:
r.append({'type': 'text', 'text': c})
content = json.dumps(r)
cursor = self.conn.cursor()
cursor.execute('INSERT INTO messages (role, content) VALUES (?, ?)', (role, content))
self.conn.commit()
return cursor.lastrowid
def append_message(self, role, content, message_id=None):
"""追加消息内容"""
cursor = self.conn.cursor()
if message_id is not None:
# 更新指定消息
cursor.execute('SELECT role, content FROM messages WHERE id = ?', (message_id,))
result = cursor.fetchone()
if not result or result[0] != role:
raise ValueError("Invalid message_id or role mismatch")
new_content = result[1] + '\n' + content
cursor.execute('UPDATE messages SET content = ? WHERE id = ?', (new_content, message_id))
# 删除该消息之后的所有消息
cursor.execute('DELETE FROM messages WHERE id > ?', (message_id,))
else:
# 追加到最后一条相同role的消息
cursor.execute('SELECT id, content FROM messages WHERE role = ? ORDER BY id DESC LIMIT 1', (role,))
last_message = cursor.fetchone()
if last_message:
new_content = last_message[1] + '\n' + content
cursor.execute('UPDATE messages SET content = ? WHERE id = ?', (new_content, last_message[0]))
else:
cursor.execute('INSERT INTO messages (role, content) VALUES (?, ?)', (role, content))
self.conn.commit()
return cursor.lastrowid
def get_messages(self):
"""获取所有消息"""
cursor = self.conn.cursor()
cursor.execute('SELECT role, content FROM messages ORDER BY id')
messages = []
for row in cursor.fetchall():
content = row[1]
try:
# 尝试解析JSON格式的多模态内容
content = json.loads(content)
except json.JSONDecodeError:
pass
messages.append({'role': row[0], 'content': content})
return messages
def recover(self, index):
"""恢复到指定索引的状态"""
cursor = self.conn.cursor()
cursor.execute('DELETE FROM messages WHERE id > ?', (index,))
self.conn.commit()
def show_messages(self):
"""显示所有消息"""
logging.info('-' * 50 + '<Memory>' + '-' * 50)
for message in self.get_messages():
content = message['content']
if isinstance(content, (dict, list)):
content = json.dumps(content)
logging.info('[[' + message['role'] + ']]: ' + str(content)[:100])
logging.info('-' * 50 + '</Memory>' + '-' * 50)
def _validate_messages(self, messages):
"""验证消息格式"""
for message in messages:
assert isinstance(message, dict), 'message format wrong'
assert 'role' in message, 'message format wrong'
assert 'content' in message, 'message format wrong'
assert message['role'] in ['user', 'assistant'], 'message format wrong'
def __del__(self):
"""析构函数,确保关闭数据库连接"""
if hasattr(self, 'conn'):
self.conn.close()
主要特点和改进:
- 使用SQLite数据库存储消息,而不是JSON文件
- 每条消息都有自动递增的ID和时间戳
- 支持与原始
NormalMemory
相同的接口
- 支持多模态内容(图片和文本)
- 自动管理数据库连接
- 提供了消息的持久化存储
使用示例:
# 创建SQLite内存实例
memory = SQLiteMemory(db_path='./chat_memory.db')
# 添加消息
memory.add_message('user', 'Hello')
memory.add_message('assistant', 'Hi there!')
# 追加消息
memory.append_message('user', 'How are you?')
# 获取所有消息
messages = memory.get_messages()
# 显示消息
memory.show_messages()
这个实现相比原始的NormalMemory
有以下优势:
- 更好的数据持久性和可靠性
- 支持更大的数据量
- 更好的并发访问控制
- 支持按时间戳查询(可以扩展功能)
- 数据库级别的事务支持
你可以根据需要扩展更多功能,比如添加消息搜索、按时间范围获取消息等。