## 项目地址

(项目仓库地址见原文链接)

## 服务端
import json
import uuid
import time
import torch
from src.model import RWKV_RNN
from src.sampler import sample_logits
from src.rwkv_tokenizer import RWKV_TOKENIZER
from flask import Flask, request, jsonify, Response

# Flask application exposing the RWKV chat-completion API.
app = Flask(__name__)
# Initialize the model and the tokenizer.
def init_model():
    """Load the RWKV model and tokenizer onto the configured device.

    Returns:
        tuple: (model, tokenizer, device) ready for inference.
    """
    # Model configuration; the weight path is hard-coded here.
    config = {
        'MODEL_NAME': 'E:/RWKV_Pytorch/weight/RWKV-x060-World-1B6-v2-20240208-ctx4096',
        'vocab_size': 65536,
        'device': "cpu",
        'onnx_opset': '18',
    }
    target_device = config['device']
    assert target_device in ('cpu', 'cuda', 'musa', 'npu')
    # Vendor-specific torch backends must be imported before .to() works.
    if target_device == "musa":
        import torch_musa
    elif target_device == "npu":
        import torch_npu
    rwkv_model = RWKV_RNN(config).to(target_device)
    rwkv_tokenizer = RWKV_TOKENIZER("asset/rwkv_vocab_v20230424.txt")
    return rwkv_model, rwkv_tokenizer, target_device
def format_messages_to_prompt(messages):
    """Render a chat message list into the RWKV prompt format.

    Each message becomes "<Role>: <content>\n\n"; a trailing
    "Assistant: " invites the model to produce the next reply.
    Unknown roles are labeled "Unknown".
    """
    role_names = {
        "system": "System",
        "assistant": "Assistant",
        "user": "User",
    }
    parts = [
        f"{role_names.get(msg['role'], 'Unknown')}: {msg['content']}\n\n"
        for msg in messages
    ]
    return "".join(parts) + "Assistant: "
def generate_text_stream(prompt: str, temperature=1.5, top_p=0.1, max_tokens=2048, stop=('\n\nUser',)):
    """Generate model output as an SSE stream of chat.completion chunks.

    Yields "data: {json}\n\n" frames: one content chunk per token, then a
    final chunk whose finish_reason is "stop" (a stop sequence was hit) or
    "length" (max_tokens exhausted), then a terminating "data:[DONE]" frame.

    Args:
        prompt: Fully formatted prompt string.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff.
        max_tokens: Upper bound on generated tokens.
        stop: Stop sequences; generation halts when the accumulated output
            ends with any of them. (Tuple default avoids a mutable default.)
    """
    encoded_input = tokenizer.encode([prompt])
    token = torch.tensor(encoded_input).long().to(device)
    state = torch.zeros(1, model.state_size[0], model.state_size[1]).to(device)
    with torch.no_grad():
        token_out, state_out = model.forward_parallel(token, state)
    del token  # free the prompt tensor; only the last-step logits are needed
    out = token_out[:, -1]
    generated_tokens = ''
    completion_tokens = 0
    if_max_token = True
    for step in range(max_tokens):
        token_sampled = sample_logits(out, temperature, top_p)
        with torch.no_grad():
            out, state = model.forward(token_sampled, state)
        last_token = tokenizer.decode(token_sampled.unsqueeze(1).tolist())[0]
        generated_tokens += last_token
        completion_tokens += 1
        if generated_tokens.endswith(tuple(stop)):
            if_max_token = False
            response = {
                "object": "chat.completion.chunk",
                "model": "rwkv",
                "choices": [{
                    "delta": "",
                    "index": 0,
                    "finish_reason": "stop"
                }]
            }
            yield f"data: {json.dumps(response)}\n\n"
            # BUGFIX: stop generating once a stop sequence is seen; the
            # original loop kept sampling tokens after yielding the
            # finish_reason="stop" chunk.
            break
        else:
            response = {
                "object": "chat.completion.chunk",
                "model": "rwkv",
                "choices": [{
                    "delta": {"content": last_token},
                    "index": 0,
                    "finish_reason": None
                }]
            }
            yield f"data: {json.dumps(response)}\n\n"
    if if_max_token:
        # Loop ran to completion without hitting a stop sequence.
        response = {
            "object": "chat.completion.chunk",
            "model": "rwkv",
            "choices": [{
                "delta": "",
                "index": 0,
                "finish_reason": "length"
            }]
        }
        yield f"data: {json.dumps(response)}\n\n"
    yield "data:[DONE]\n\n"
def generate_text(prompt, temperature=1.5, top_p=0.1, max_tokens=2048, stop=('\n\nUser',)):
    """Generate a complete reply for the prompt in a single call.

    Args:
        prompt: Fully formatted prompt string.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff.
        max_tokens: Upper bound on generated tokens.
        stop: Stop sequences; generation halts when the accumulated output
            ends with any of them. (Tuple default avoids a mutable default.)

    Returns:
        tuple: (generated_text, if_max_token, usage) where if_max_token is
        True when max_tokens was exhausted (no stop sequence seen), and
        usage counts prompt/completion/total tokens.
    """
    encoded_input = tokenizer.encode([prompt])
    token = torch.tensor(encoded_input).long().to(device)
    state = torch.zeros(1, model.state_size[0], model.state_size[1]).to(device)
    prompt_tokens = len(encoded_input[0])
    with torch.no_grad():
        token_out, state_out = model.forward_parallel(token, state)
    del token  # free the prompt tensor; only the last-step logits are needed
    out = token_out[:, -1]
    completion_tokens = 0
    if_max_token = True
    generated_tokens = ''
    for step in range(max_tokens):
        token_sampled = sample_logits(out, temperature, top_p)
        with torch.no_grad():
            out, state = model.forward(token_sampled, state)
        last_token = tokenizer.decode(token_sampled.unsqueeze(1).tolist())[0]
        completion_tokens += 1
        print(last_token, end='')  # echo progress to the server console
        generated_tokens += last_token
        for stop_token in stop:
            if generated_tokens.endswith(stop_token):
                # BUGFIX: strip only the trailing stop sequence; the old
                # str.replace(stop_token, "") removed EVERY occurrence of
                # the stop text from the generated output.
                generated_tokens = generated_tokens[:-len(stop_token)]
                if_max_token = False
                break
        # A stop sequence was matched above: end generation.
        if not if_max_token:
            break
    total_tokens = prompt_tokens + completion_tokens
    usage = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}
    return generated_tokens, if_max_token, usage
@app.route('/events', methods=['POST'])
def sse_request():
    """Handle a chat-completion POST request.

    Reads generation parameters from the JSON body and dispatches to the
    streaming (SSE) generator or the one-shot generator depending on the
    'stream' flag. Returns an SSE response or a chat.completion JSON body.
    """
    try:
        # Pull parameters from the JSON request body.
        data = request.json
        messages = data.get('messages', [])
        # NOTE: only a literal JSON `true` enables streaming (kept as-is).
        stream = data.get('stream', True) == True
        temperature = float(data.get('temperature', 0.5))
        top_p = float(data.get('top_p', 0.9))
        max_tokens = int(data.get('max_tokens', 100))
        stop = data.get('stop', ['\n\nUser'])
        prompt = format_messages_to_prompt(messages)
        if stream:
            return Response(generate_text_stream(prompt=prompt, temperature=temperature, top_p=top_p,
                                                 max_tokens=max_tokens, stop=stop),
                            content_type='text/event-stream')
        else:
            completion, if_max_token, usage = generate_text(prompt, temperature=temperature, top_p=top_p,
                                                            max_tokens=max_tokens, stop=stop)
            # BUGFIX: if_max_token=True means max_tokens was exhausted,
            # which is finish_reason "length"; hitting a stop sequence is
            # "stop". The original mapping was inverted (the streaming
            # path already used the correct convention).
            finish_reason = "length" if if_max_token else "stop"
            unique_id = str(uuid.uuid4())
            current_timestamp = int(time.time())
            response = {
                "id": unique_id,
                "object": "chat.completion",
                "created": current_timestamp,
                "choices": [{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": completion,
                    },
                    "finish_reason": finish_reason
                }],
                "usage": usage
            }
            return json.dumps(response)
    except Exception as e:
        # Top-level request boundary: report the failure as JSON with HTTP 500.
        return json.dumps({"error": str(e)}), 500
if __name__ == '__main__':
    # Load model/tokenizer/device into the module globals used by the
    # generation functions, then start the Flask dev server (default port 5000).
    model, tokenizer, device = init_model()
    app.run(debug=False)
## 解释

- 首先引入了需要的库:`json` 用于处理JSON数据,`uuid` 用于生成唯一标识符,`time` 用于获取当前时间戳,`torch` 用于构建和运行模型,`Flask` 用于构建API。
- 定义了一个名为 `app` 的Flask应用。
- `init_model` 函数用于初始化模型和分词器。其中,模型参数通过字典 `args` 指定。
- `format_messages_to_prompt` 函数用于将消息格式化为提示字符串,以便于模型生成回复:遍历消息列表,获取每个消息的角色和内容,并依次追加到提示字符串中。
- `generate_text_stream` 函数用于以流的形式生成文本:首先将输入的提示字符串编码为张量,然后利用模型逐个生成token,并利用 `yield` 关键字将回复以SSE(服务器发送事件)的形式返回。
- `generate_text` 函数用于一次性生成完整的文本回复,与 `generate_text_stream` 函数类似,不同之处是返回完整的回复字符串。
- `sse_request` 函数是Flask应用的主要逻辑,用于处理POST请求:从请求的JSON数据中获取参数,并根据参数调用相应的生成函数。如果请求中设置了 `stream=True`,则返回流式生成的回复;否则返回一次性生成的完整回复。
- 在 `__main__` 入口中初始化模型和分词器,然后运行Flask应用。
## 客户端
import json
import requests
from requests import RequestException

# Server URL: assumes the Flask app is running on local port 5000.
url = 'http://localhost:5000/events'
# Example: streaming POST request.
def post_request_stream():
    """Send a streaming chat request and print the reply incrementally.

    Each SSE line has the form "data: {json}"; the stream terminates with
    a "data:[DONE]" sentinel line that carries no JSON payload.
    """
    # Build the request body.
    data = {
        'messages': [
            {'role': 'system', 'content': '你好!'},
            {'role': 'user', 'content': '你能告诉我今天的天气吗?'}
        ],
        'temperature': 0.5,
        'top_p': 0.9,
        'max_tokens': 100,
        'stop': ['\n\nUser'],
        'stream': True
    }
    try:
        with requests.post(url, json=data, stream=True) as r:
            for line in r.iter_lines():
                if not line:
                    continue
                decoded_line = line.decode('utf-8')
                payload = decoded_line[5:].strip()  # strip the "data:" prefix
                # BUGFIX: the terminating "data:[DONE]" frame is not JSON;
                # json.loads on it used to raise and abort the loop.
                if payload == '[DONE]':
                    break
                delta = json.loads(payload)["choices"][0]["delta"]
                # BUGFIX: content chunks carry {"content": ...}; printing the
                # delta directly printed the dict repr instead of the text.
                # The final stop chunk's delta is an empty string.
                if isinstance(delta, dict):
                    print(delta.get("content", ""), end="")
                else:
                    print(delta, end="")
    except RequestException as e:
        print(f'An error occurred: {e}')
def post_request():
    """Send a non-streaming chat request and print the parsed JSON reply."""
    # Build the request body (stream=False asks for a one-shot completion).
    payload = {
        'messages': [
            {'role': 'system', 'content': '你好!'},
            {'role': 'user', 'content': '你能告诉我今天的天气吗?'}
        ],
        'temperature': 0.5,
        'top_p': 0.9,
        'max_tokens': 100,
        'stop': ['\n\nUser'],
        'stream': False
    }
    try:
        # Connect to the server and iterate over whatever lines it returns.
        with requests.post(url, json=payload, stream=True) as resp:
            for raw_line in resp.iter_lines():
                if raw_line:
                    print(json.loads(raw_line.decode('utf-8')))
    except RequestException as e:
        print(f'An error occurred: {e}')
if __name__ == '__main__':
    # Run the streaming example; the one-shot example is left disabled.
    # post_request()
    post_request_stream()
## 解释

这段代码是一个用于向服务器发送POST请求的示例。

首先导入必要的库:`json` 用于处理JSON数据,`requests` 用于发送HTTP请求,`RequestException` 用于处理请求异常。

接着配置服务器的URL。在这个示例中,假设服务器运行在本地端口5000上。

代码中定义了两个函数 `post_request_stream` 和 `post_request`,分别用于发送期望流式响应和非流式响应的POST请求。

`post_request_stream` 函数构造一个包含各项参数的数据字典,并使用 `requests.post` 方法发送POST请求。请求参数中 `stream` 被设置为 `True`,表示希望获得流式响应。随后使用 `r.iter_lines()` 方法迭代获取服务器发送的消息,每收到一行消息就将其解码并打印出来。

`post_request` 函数的结构与 `post_request_stream` 相似,不同之处在于 `stream` 参数被设置为 `False`,表示希望获得一次性的非流式响应。

最后,在程序的主体部分调用 `post_request_stream` 函数发送流式POST请求,而 `post_request` 的调用被注释掉了。