处理一些简单的语音不在话下，关键是可以在 CPU 模式下运行，速度也挺快。实测内存占用大约 2 GB，这意味着在一台 4 核 4 GB 的服务器上就可以跑起来。初始化时需要下载模型文件。
# pip install kokoro>=0.8.1 "misaki[zh]>=0.8.1"
from flask import Flask, request, jsonify, send_file
from kokoro import KModel, KPipeline
from pathlib import Path
import numpy as np
import soundfile as sf
import torch
import io
import base64
# Model configuration shared by init_model() and generate_tts(). The model
# itself is loaded exactly once, by the module-level init_model() call.
REPO_ID: str = 'hexgrad/Kokoro-82M-v1.1-zh'  # repo id passed to KModel / KPipeline
SAMPLE_RATE: int = 24000  # Hz; sample rate written into the generated WAV data
VOICE: str = 'zf_005'  # voice id handed to the Chinese pipeline

app = Flask(__name__)
def init_model():
    """Load the Kokoro model and build the Chinese pipeline.

    Publishes two module-level globals:

    * ``model`` -- ``KModel`` for ``REPO_ID``, moved to CUDA when available
      (CPU otherwise) and put in eval mode.
    * ``zh_pipeline`` -- Chinese (``lang_code='z'``) ``KPipeline`` driven by
      ``model``; English words inside the Chinese text are converted to
      phonemes by ``en_callable`` below. generate_tts() uses this pipeline.
    """
    global model, zh_pipeline

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = KModel(repo_id=REPO_ID).to(device).eval()

    # English pipeline with no model attached (model=False); it is only
    # queried for phonemes in en_callable, never for audio.
    en_pipeline = KPipeline(lang_code='a', repo_id=REPO_ID, model=False)

    # Hand-tuned pronunciations that bypass the English G2P entirely.
    fixed_phonemes = {
        'Kokoro': 'kˈOkəɹO',
        'Sol': 'sˈOl',
    }

    def en_callable(text):
        """Return phonemes for an English span embedded in Chinese text.

        Falls back to the raw text when the English pipeline fails, so a
        single unpronounceable word does not abort the whole request.
        """
        if text in fixed_phonemes:
            return fixed_phonemes[text]
        try:
            return next(en_pipeline(text)).phonemes
        except Exception:
            # Was a bare ``except:``, which also swallowed KeyboardInterrupt
            # and SystemExit. StopIteration (no segment produced) is still
            # caught because it is an Exception subclass.
            return text

    zh_pipeline = KPipeline(lang_code='z', repo_id=REPO_ID, model=model, en_callable=en_callable)
# Initialise the model once, at import time, so request handlers such as
# generate_tts() can use the module-level ``zh_pipeline``.
init_model()
def speed_callable(len_ps):
    """Return the speaking-speed factor for a segment of ``len_ps`` phonemes.

    HACK: mitigates the "rushing" caused by a lack of training data beyond
    roughly 100 tokens. A piecewise-linear schedule slows longer segments:

    * ``len_ps <= 83``      -> base 1
    * ``83 < len_ps < 183`` -> base decreases linearly from 1 towards 0.8
    * otherwise             -> base 0.8

    The base is then multiplied by 1.5, this service's overall speed-up.
    The pipeline receives this function as ``speed=speed_callable``.
    """
    if len_ps <= 83:
        base = 1
    elif len_ps < 183:
        base = 1 - (len_ps - 83) / 500
    else:
        base = 0.8
    return base * 1.5
def _audio_to_numpy(audio):
    """Return ``audio`` (a torch tensor or array-like) as a NumPy array on the CPU."""
    if isinstance(audio, torch.Tensor):
        return audio.detach().cpu().numpy()
    return np.asarray(audio)


def generate_tts(text):
    """Synthesise ``text`` with the Chinese pipeline and return WAV bytes.

    The pipeline yields one result per text segment. The audio of every
    segment is concatenated. Previously only the first segment was used,
    which silently truncated longer inputs.

    Args:
        text: Text to speak (Chinese, optionally mixed with English).

    Returns:
        ``io.BytesIO`` rewound to offset 0, holding a WAV file at
        ``SAMPLE_RATE`` Hz.

    Raises:
        RuntimeError: (an ``Exception`` subclass, so existing ``except
            Exception`` callers still catch it) if synthesis fails or yields
            no audio. The message is prefixed with "TTS生成失败: "
            ("TTS generation failed").
    """
    try:
        chunks = [
            _audio_to_numpy(result.audio)
            for result in zh_pipeline(text, voice=VOICE, speed=speed_callable)
            if result.audio is not None
        ]
        if not chunks:
            raise ValueError('pipeline produced no audio')
        wav = np.concatenate(chunks)

        # Encode the samples as an in-memory WAV file.
        buffer = io.BytesIO()
        sf.write(buffer, wav, SAMPLE_RATE, format='WAV')
        buffer.seek(0)
        return buffer
    except Exception as e:
        raise RuntimeError(f"TTS生成失败: {str(e)}") from e
@app.route('/tts', methods=['POST'])
def text_to_speech():
    """Synthesise speech from a JSON body and return it as a WAV download.

    Request body: ``{"text": "<text to speak>"}``.

    Responses:
        200 -- ``audio/wav`` attachment named ``output.wav``.
        400 -- the body is not a JSON object, or ``text`` is missing, empty,
               blank or not a string.
        500 -- synthesis failed; body is ``{"error": <message>}``.
    """
    # silent=True: a wrong Content-Type or malformed JSON yields None (-> 400
    # below) instead of raising an HTTP error that the broad handler further
    # down would have reported as a misleading 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'text' not in data:
        return jsonify({'error': 'Missing text parameter'}), 400
    text = data['text']
    if not text:
        return jsonify({'error': 'Text cannot be empty'}), 400
    if not isinstance(text, str):
        return jsonify({'error': 'Text must be a string'}), 400
    if not text.strip():
        # Whitespace-only input has nothing to speak; reject it up front.
        return jsonify({'error': 'Text cannot be empty'}), 400
    try:
        buffer = generate_tts(text)
        return send_file(
            buffer,
            mimetype='audio/wav',
            as_attachment=True,
            download_name='output.wav'
        )
    except Exception as e:
        # Keep the traceback in the server log; the client only gets the message.
        app.logger.exception('TTS request failed')
        return jsonify({'error': str(e)}), 500
@app.route('/tts_get')
def text_to_speech_get():
    """Synthesise speech from the ``text`` query parameter and return a WAV download.

    Responses:
        200 -- ``audio/wav`` attachment named ``output.wav``.
        400 -- ``text`` is missing, empty or whitespace-only.
        500 -- synthesis failed; body is ``{"error": <message>}``.
    """
    text = request.args.get('text', '')
    if not text.strip():
        # Whitespace-only input has nothing to speak; treat it like a missing parameter.
        return jsonify({'error': 'Missing text parameter'}), 400
    try:
        buffer = generate_tts(text)
        return send_file(
            buffer,
            mimetype='audio/wav',
            as_attachment=True,
            download_name='output.wav'
        )
    except Exception as e:
        # Keep the traceback in the server log; the client only gets the message.
        app.logger.exception('TTS request failed')
        return jsonify({'error': str(e)}), 500
@app.route('/tts_base64', methods=['POST'])
def text_to_speech_base64():
    """Synthesise speech from a JSON body and return the WAV as base64 JSON.

    Request body: ``{"text": "<text to speak>"}``.

    Responses:
        200 -- ``{"status": "success", "audio": <base64 WAV>, "format": "wav"}``.
        400 -- the body is not a JSON object, or ``text`` is missing, empty,
               blank or not a string.
        500 -- synthesis failed; body is ``{"error": <message>}``.
    """
    # silent=True: a wrong Content-Type or malformed JSON yields None (-> 400
    # below) instead of raising an HTTP error that the broad handler further
    # down would have reported as a misleading 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'text' not in data:
        return jsonify({'error': 'Missing text parameter'}), 400
    text = data['text']
    if not text:
        return jsonify({'error': 'Text cannot be empty'}), 400
    if not isinstance(text, str):
        return jsonify({'error': 'Text must be a string'}), 400
    if not text.strip():
        # Whitespace-only input has nothing to speak; reject it up front.
        return jsonify({'error': 'Text cannot be empty'}), 400
    try:
        buffer = generate_tts(text)
        wav_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
        return jsonify({
            'status': 'success',
            'audio': wav_base64,
            'format': 'wav'
        })
    except Exception as e:
        # Keep the traceback in the server log; the client only gets the message.
        app.logger.exception('TTS request failed')
        return jsonify({'error': str(e)}), 500
@app.route('/tts_base64_get')
def text_to_speech_base64_get():
    """Synthesise speech from the ``text`` query parameter and return base64 JSON.

    Responses:
        200 -- ``{"status": "success", "audio": <base64 WAV>, "format": "wav"}``.
        400 -- ``text`` is missing, empty or whitespace-only.
        500 -- synthesis failed; body is ``{"error": <message>}``.
    """
    text = request.args.get('text', '')
    if not text.strip():
        # Whitespace-only input has nothing to speak; treat it like a missing parameter.
        return jsonify({'error': 'Missing text parameter'}), 400
    try:
        buffer = generate_tts(text)
        wav_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
        return jsonify({
            'status': 'success',
            'audio': wav_base64,
            'format': 'wav'
        })
    except Exception as e:
        # Keep the traceback in the server log; the client only gets the message.
        app.logger.exception('TTS request failed')
        return jsonify({'error': str(e)}), 500
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: report that the web process is serving requests.

    The model is not exercised; the response is always ``{"status": "healthy"}``.
    """
    return jsonify(status='healthy')
if __name__ == '__main__':
    # Debug mode stays off, which also leaves Flask's auto-reloader off. The
    # reloader re-runs this script in a child process, so the module-level
    # model initialisation would otherwise run more than once.
    server_options = dict(host='0.0.0.0', port=5000, debug=False)
    app.run(**server_options)