Whisper + FastAPI: Transcribing a Complete Audio File

Published: 2024-06-27
import whisper
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import StreamingResponse
import torch
import numpy as np
import os
import tempfile

# Load the Whisper model
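# Other checkpoint sizes include "tiny", "base", "small", and "large";
# larger models are more accurate but slower, and benefit from a CUDA-capable GPU.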
model = whisper.load_model("medium")

# Create the FastAPI application
app = FastAPI()

def pad_or_trim(array, length: int = 16000 * 30, *, axis: int = -1):
    """
    Split the audio array into multiple segments of length N_SAMPLES if it exceeds N_SAMPLES.
    If it is shorter, pad the array to N_SAMPLES.
    """
    if torch.is_tensor(array):
        arrays = []
        for i in range(0, array.shape[axis], length):
            segment = array.index_select(
                dim=axis, index=torch.arange(i, min(i + length, array.shape[axis]), device=array.device)
            )
            if segment.shape[axis] < length:
                pad_widths = [(0, 0)] * segment.ndim
                pad_widths[axis] = (0, length - segment.shape[axis])
                segment = torch.nn.functional.pad(segment, [pad for sizes in pad_widths[::-1] for pad in sizes])
            arrays.append(segment)
        array = torch.stack(arrays)
    else:
        arrays = []
        for i in range(0, array.shape[axis], length):
            segment = array.take(indices=range(i, min(i + length, array.shape[axis])), axis=axis)
            if segment.shape[axis] < length:
                pad_widths = [(0, 0)] * segment.ndim
                pad_widths[axis] = (0, length - segment.shape[axis])
                segment = np.pad(segment, pad_widths)
            arrays.append(segment)
        array = np.stack(arrays)

    return array
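
# Example (illustrative): a 70-second mono recording sampled at 16 kHz has
# 70 * 16000 = 1,120,000 samples; pad_or_trim splits it into 3 segments of
# 480,000 samples each, zero-pads the last one, and returns shape (3, 480000).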

@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    # Save the uploaded file to a temporary file so Whisper/ffmpeg can read it from disk
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(await file.read())
        tmp_path = tmp.name

    audio = whisper.load_audio(tmp_path)

    # Split the audio into 30-second segments (the last one is zero-padded)
    audio = pad_or_trim(audio)

    async def generate():
        # Compute the log-Mel spectrogram of every segment and move it to the model's device
        mel = whisper.log_mel_spectrogram(audio).to(model.device)

        # Detect the spoken language of each segment
        _, probs = model.detect_language(mel)
        for segment_probs in probs:
            detected_language = max(segment_probs, key=segment_probs.get)
            yield f"data: Detected language: {detected_language}\n\n"

        # Decode all segments in one batch (use fp16 only when a GPU is available)
        options = whisper.DecodingOptions(fp16=torch.cuda.is_available())
        results = whisper.decode(model, mel, options)

        # Yield the transcribed text of each segment
        for r in results:
            yield f"data: {r.text}\n\n"

    # Remove the temporary file
    os.remove(tmp_path)

    return StreamingResponse(generate(), media_type="text/event-stream")

# Run the FastAPI application
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=6006)
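
The endpoint replies with a Server-Sent Events stream: one "Detected language" line is emitted per 30-second segment, followed by the decoded text of each segment. For a file that splits into two segments, the response body looks roughly like this (language code and text are illustrative):

data: Detected language: zh

data: Detected language: zh

data: <text of the first segment>

data: <text of the second segment>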

Testing

1. Using curl
curl -X POST "http://127.0.0.1:6006/transcribe/" -H "accept: text/event-stream" -H "Content-Type: multipart/form-data" -F "file=@path_to_your_audio_file"
2. Using a Python requests script
import requests

url = "http://localhost:6006/transcribe/"
file_path = "瓦解--史塔克.mp3"  # replace with the path to your own audio file

with open(file_path, "rb") as f:
    files = {"file": f}
    response = requests.post(url, files=files, headers={"accept": "text/event-stream"}, stream=True)
    print(response)
    for line in response.iter_lines():
        if line:
            print(line.decode("utf-8"))

3. Using Postman
  • Open Postman and create a new POST request.
  • Enter the URL: http://127.0.0.1:6006/transcribe/
  • Set the request headers:
    • accept: text/event-stream
    • Content-Type: multipart/form-data
  • In the Body tab, select form-data and add a new key-value pair:
    • Key: file
    • Type: File
    • Value: select your audio file
  • Send the request and watch the streamed data returned by the server.

Notes

  • Make sure the Whisper model is installed correctly and that your Python environment has all required dependencies (torch, numpy, fastapi, uvicorn, etc.); see the setup sketch below.
  • Make sure the FastAPI service is running; start it with uvicorn.
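
A minimal setup sketch, assuming the script above is saved as app.py (a hypothetical filename): openai-whisper pulls in torch and numpy, python-multipart is what FastAPI needs to parse file uploads, and ffmpeg must be installed separately because whisper.load_audio invokes it.

pip install -U openai-whisper fastapi uvicorn python-multipart
python app.py          # or: uvicorn app:app --host 0.0.0.0 --port 6006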
