基于 TensorFlow Lite Micro(TFLM)的关键词识别(KWS)示例,运行在 STM32F407VGT6 上,实现 yes/no 识别。
工程基于官方 micro_speech 示例思路改造,采用官方“两模型管线”:独立 Audio Preprocessor 模型 + Micro Speech 分类模型,并结合嵌入式实际对内存与接口做了适配和优化。
1、下载官方源码
github仓库:https://github.com/tensorflow/tflite-micro/
2、下载flatbuffers(v23.5.26),解压到tflite-micro\third_party目录下面,并把解压出的目录重命名为flatbuffers,以匹配第8步的头文件路径tflite-micro/third_party/flatbuffers/include
下载地址:https://github.com/google/flatbuffers/archive/v23.5.26.zip
3、下载pigweed放到tflite-micro\third_party目录下面(目录名保持为pigweed)
仓库(Google Git,不是GitHub):git clone https://pigweed.googlesource.com/pigweed/pigweed
克隆完成后在pigweed目录中切换到指定版本:git checkout 47268dff45019863e20438ca3746c6c62df6ef09
4、下载kissfft(v130),解压到tflite-micro\third_party目录下面,并把解压出的目录重命名为kissfft(与第8步的头文件路径一致)
下载地址:https://github.com/mborgerding/kissfft/archive/refs/tags/v130.zip
5、下载gemmlowp,解压到tflite-micro\third_party目录下面,并把解压出的目录重命名为gemmlowp(与第8步的头文件路径一致)
下载地址:https://github.com/google/gemmlowp/archive/719139ce755a0f31cbf1c37f7f98adcc7fc9f425.zip
6、下载ruy,解压到tflite-micro\third_party目录下面,并把解压出的目录重命名为ruy(与第8步的头文件路径一致)
下载地址:https://github.com/google/ruy/archive/d37128311b445e758136b8602d1bbd2a755e115d.zip
7、编译以下路径的文件
# TFLM experimental microfrontend (C sources).
# NOTE(review): the step-11 code says the C frontend was removed in favour of the Audio
# Preprocessor model; these files may no longer be needed - confirm before dropping them.
# The last entry has no trailing "\" so the next assignment (CXX_SOURCES_CC) is not
# swallowed into this list when both are pasted into one Makefile.
C_SOURCES := \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/window.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/window_util.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/filterbank.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/filterbank_util.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/noise_reduction.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_util.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_util.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/log_scale.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/log_scale_util.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/log_lut.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/frontend.c \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/frontend_util.c
# TFLM runtime, kernels, and signal-library (C++) sources.
# The last entry has no trailing "\" so whatever follows in the Makefile is not
# appended to this list.
CXX_SOURCES_CC := \
tflite-micro/tensorflow/lite/micro/micro_interpreter.cc \
tflite-micro/tensorflow/lite/micro/micro_utils.cc \
tflite-micro/tensorflow/lite/micro/cortex_m_generic/debug_log.cc \
tflite-micro/tensorflow/lite/micro/micro_allocator.cc \
tflite-micro/tensorflow/lite/micro/micro_allocation_info.cc \
tflite-micro/tensorflow/lite/micro/micro_interpreter_context.cc \
tflite-micro/tensorflow/lite/micro/micro_interpreter_graph.cc \
tflite-micro/tensorflow/lite/micro/memory_helpers.cc \
tflite-micro/tensorflow/lite/micro/flatbuffer_utils.cc \
tflite-micro/tensorflow/lite/micro/micro_log.cc \
tflite-micro/tensorflow/lite/micro/micro_time.cc \
tflite-micro/tensorflow/lite/micro/micro_profiler.cc \
tflite-micro/tensorflow/lite/micro/system_setup.cc \
tflite-micro/tensorflow/lite/micro/micro_resource_variable.cc \
tflite-micro/tensorflow/lite/micro/micro_op_resolver.cc \
tflite-micro/tensorflow/lite/micro/tflite_bridge/flatbuffer_conversions_bridge.cc \
tflite-micro/tensorflow/lite/micro/tflite_bridge/micro_error_reporter.cc \
tflite-micro/tensorflow/compiler/mlir/lite/core/api/error_reporter.cc \
tflite-micro/tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.cc \
tflite-micro/tensorflow/lite/micro/arena_allocator/persistent_arena_buffer_allocator.cc \
tflite-micro/tensorflow/lite/micro/arena_allocator/non_persistent_arena_buffer_allocator.cc \
tflite-micro/tensorflow/lite/micro/memory_planner/greedy_memory_planner.cc \
tflite-micro/tensorflow/lite/micro/memory_planner/linear_memory_planner.cc \
tflite-micro/tensorflow/lite/micro/kernels/kernel_util.cc \
tflite-micro/tensorflow/lite/micro/kernel_util_compat.cc \
tflite-micro/tensorflow/lite/kernels/internal/common.cc \
tflite-micro/tensorflow/lite/kernels/internal/quantization_util.cc \
tflite-micro/tensorflow/lite/micro/kernels/fully_connected_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/softmax_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/elementwise.cc \
tflite-micro/tensorflow/lite/micro/kernels/micro_tensor_utils.cc \
tflite-micro/tensorflow/lite/micro/kernels/activations_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/conv_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/depthwise_conv_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/reshape.cc \
tflite-micro/tensorflow/lite/micro/kernels/reshape_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/fully_connected.cc \
tflite-micro/tensorflow/lite/micro/kernels/depthwise_conv.cc \
tflite-micro/tensorflow/lite/micro/kernels/softmax.cc \
tflite-micro/tensorflow/lite/micro/kernels/cast.cc \
tflite-micro/tensorflow/lite/micro/kernels/add.cc \
tflite-micro/tensorflow/lite/micro/kernels/add_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/div.cc \
tflite-micro/tensorflow/lite/micro/kernels/strided_slice.cc \
tflite-micro/tensorflow/lite/micro/kernels/strided_slice_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/concatenation.cc \
tflite-micro/tensorflow/lite/micro/kernels/mul.cc \
tflite-micro/tensorflow/lite/micro/kernels/mul_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/pad.cc \
tflite-micro/tensorflow/lite/micro/kernels/pad_common.cc \
tflite-micro/tensorflow/lite/core/c/common.cc \
tflite-micro/tensorflow/lite/core/api/flatbuffer_conversions.cc \
tflite-micro/tensorflow/lite/kernels/internal/tensor_ctypes.cc \
tflite-micro/tensorflow/lite/kernels/internal/portable_tensor_utils.cc \
tflite-micro/tensorflow/lite/micro/micro_context.cc \
tflite-micro/tensorflow/compiler/mlir/lite/schema/schema_utils.cc \
tflite-micro/tensorflow/lite/micro/kernels/activations.cc \
tflite-micro/tensorflow/lite/micro/kernels/conv.cc \
tflite-micro/tensorflow/lite/micro/kernels/pooling.cc \
tflite-micro/tensorflow/lite/micro/kernels/pooling_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/dequantize.cc \
tflite-micro/tensorflow/lite/micro/kernels/dequantize_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/quantize.cc \
tflite-micro/tensorflow/lite/micro/kernels/quantize_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/maximum_minimum.cc \
tflite-micro/tensorflow/lite/micro/kernels/logistic.cc \
tflite-micro/tensorflow/lite/micro/kernels/logistic_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/reduce.cc \
tflite-micro/tensorflow/lite/micro/kernels/reduce_common.cc \
tflite-micro/tensorflow/lite/micro/kernels/sub.cc \
tflite-micro/tensorflow/lite/micro/kernels/sub_common.cc \
tflite-micro/signal/micro/kernels/window.cc \
tflite-micro/signal/micro/kernels/fft_auto_scale_kernel.cc \
tflite-micro/signal/micro/kernels/rfft.cc \
tflite-micro/signal/micro/kernels/energy.cc \
tflite-micro/signal/micro/kernels/filter_bank.cc \
tflite-micro/signal/micro/kernels/filter_bank_square_root.cc \
tflite-micro/signal/micro/kernels/filter_bank_square_root_common.cc \
tflite-micro/signal/micro/kernels/filter_bank_spectral_subtraction.cc \
tflite-micro/signal/micro/kernels/filter_bank_log.cc \
tflite-micro/signal/micro/kernels/pcan.cc \
tflite-micro/signal/src/window.cc \
tflite-micro/signal/src/fft_auto_scale.cc \
tflite-micro/signal/src/irfft_int16.cc \
tflite-micro/signal/src/irfft_int32.cc \
tflite-micro/signal/src/irfft_float.cc \
tflite-micro/signal/src/rfft_int16.cc \
tflite-micro/signal/src/rfft_int32.cc \
tflite-micro/signal/src/rfft_float.cc \
tflite-micro/signal/src/energy.cc \
tflite-micro/signal/src/filter_bank.cc \
tflite-micro/signal/src/filter_bank_square_root.cc \
tflite-micro/signal/src/filter_bank_spectral_subtraction.cc \
tflite-micro/signal/src/filter_bank_log.cc \
tflite-micro/signal/src/log.cc \
tflite-micro/signal/src/kiss_fft_wrappers/kiss_fft_int16.cc \
tflite-micro/signal/src/msb_32.cc \
tflite-micro/signal/src/max_abs.cc \
tflite-micro/signal/src/square_root_32.cc \
tflite-micro/signal/src/square_root_64.cc \
tflite-micro/signal/src/pcan_argc_fixed.cc \
tflite-micro/signal/micro/kernels/fft_auto_scale_common.cc \
tflite-micro/signal/src/msb_64.cc \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/fft.cc \
tflite-micro/tensorflow/lite/experimental/microfrontend/lib/fft_util.cc
8、添加以下头文件路径
# Include search paths, relative to the project root. The third_party sub-directory
# names must match the folder names of the downloaded dependencies exactly.
# The last entry has no trailing "\" so the following Makefile line is not appended.
INCLUDE_DIRS := \
. \
tflite-micro/third_party/kissfft \
tflite-micro/third_party/gemmlowp \
tflite-micro/third_party/flatbuffers/include \
tflite-micro/third_party/ruy \
tflite-micro/third_party \
tflite-micro \
tflite-micro/signal
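9、转换预处理模型tflite为C语言数组
在tflite-micro\tensorflow\lite\micro\examples\micro_speech\models\路径下面有一个audio_preprocessor_int8.tflite的预处理模型数据,
运行: xxd -i audio_preprocessor_int8.tflite > audio_preprocessor_int8_model_data.c
把audio_preprocessor_int8_model_data.c添加编译。
注意:
(1) xxd生成的是非const数组,会被放进RAM;建议改为const并按16字节对齐(如 alignas(16) const unsigned char ...[] 或 __attribute__((aligned(16)))),让模型留在Flash中并满足TFLM对模型数据的对齐要求。
(2) 需要自行编写头文件audio_preprocessor_int8_model_data.h,用extern声明数组和长度,供第11步代码包含;若该.c文件按C编译而识别代码按C++编译,声明需放在extern "C"中。
(3) xxd生成的数组名由传入的路径决定(非字母数字字符都会变成下划线):上面的命令生成audio_preprocessor_int8_tflite,而第11步代码使用的是__audio_preprocessor_int8_tflite(即执行 xxd -i ./audio_preprocessor_int8.tflite 时得到的名字),两处名字必须一致。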
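10、转换识别分类模型tflite为C语言数组
在tflite-micro\tensorflow\lite\micro\examples\micro_speech\models\路径下面有一个micro_speech_quantized.tflite的识别模型数据,
运行: xxd -i micro_speech_quantized.tflite > micro_speech_quantized_model_data.c
把micro_speech_quantized_model_data.c添加编译。
注意:建议把生成的数组改为const并按16字节对齐,使其放在Flash中;并自行编写头文件micro_speech_quantized_model_data.h,用extern声明micro_speech_quantized_tflite和micro_speech_quantized_tflite_len(第11步代码直接使用这两个名字;若.c按C编译,声明需放在extern "C"中)。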
11、编写模型识别代码
// Public KWS interfaces and minimal pipeline implementation
#include "micro_speech_quantized_model_data.h"
// #include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
// #include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
// (C frontend removed; using Audio Preprocessor model)
// Error reporter
#include "tensorflow/lite/micro/tflite_bridge/micro_error_reporter.h"
// Schema version macro
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "signal/micro/kernels/rfft.h"
// TFLM DebugLog callback registration
#include "tensorflow/lite/micro/cortex_m_generic/debug_log_callback.h"
// #include "audio_frontend.h" // TFLM 音频前端
#include "log_debug.h"
#include <stdint.h>
#include <string.h>
#include <new>
#define KWS_LOG_INFO log_debug
// (uvprojx has included required kernel source files)
// ======== TensorFlow Lite Micro persistent objects ========
// STM32F407 has 192KB RAM, but Keil config only uses 128KB.
// Tensor arena for the MicroSpeech classifier model: currently 10 KB (an older comment
// said 64 KB), leaving RAM for the other components. If AllocateTensors() fails,
// init_tflm() logs arena_used_bytes(), which can be used to resize this.
constexpr int kTensorArenaSize = 10 * 1024;
alignas(16) static uint8_t tensor_arena[kTensorArenaSize];
// Obtained in init_tflm(); not passed to the MicroInterpreter constructor.
static tflite::ErrorReporter* g_error_reporter = nullptr;
static const tflite::Model* g_model = nullptr;
// Resolver capacity (max ops). init_tflm() registers 21 builtin + 9 custom ops, below this limit.
static tflite::MicroMutableOpResolver<48> g_resolver;
static tflite::MicroInterpreter* g_interpreter = nullptr;
static TfLiteTensor* g_input = nullptr;   // classifier input(0), set by init_tflm()
static TfLiteTensor* g_output = nullptr;  // classifier output(0), set by init_tflm()
static bool g_inited = false;             // set to true only at the end of a successful init_tflm()
// Static storage for the classifier interpreter; it is constructed there with placement new
// in init_tflm(), so no heap is used.
alignas(alignof(tflite::MicroInterpreter)) static uint8_t g_interpreter_buffer[sizeof(tflite::MicroInterpreter)];
// Model / feature parameters (micro_speech-like)
static constexpr int kSampleRate = 16000; // 16 kHz
static constexpr int kFrameLenMs = 30; // 30 ms window
static constexpr int kFrameStrideMs = 20; // 20 ms hop
static constexpr int kNumMelBins = 40; // 40 mel bands
// Frames per inference: init_tflm() overwrites this with (input elements / kNumMelBins)
// when the classifier input size is a positive multiple of kNumMelBins.
static int g_frames_per_inference = 49; // fallback typical value
// Size in bytes of one int8 feature block (kNumMelBins x frames); not referenced in the visible code.
static inline int feature_size_bytes() { return kNumMelBins * g_frames_per_inference; }
// ======== Two-model pipeline (Audio Preprocessor + MicroSpeech) ========
// The Audio Preprocessor int8 model is always used to produce the 49x40 feature block
// (one 40-value int8 feature row per 480-sample frame).
#include "audio_preprocessor_int8_model_data.h"
// Interpreter and op resolver for the preprocessor model (modelled on micro_speech_test.cc).
// Arena for the preprocessor interpreter; init_preprocessor() logs how much of it is used.
static constexpr size_t kPreprocArenaSize = 12 * 1024;
alignas(16) static uint8_t s_preproc_arena[kPreprocArenaSize];
// Capacity 18 = the 9 builtin + 9 Signal* custom ops registered in init_preprocessor().
using PreprocOpResolver = tflite::MicroMutableOpResolver<18>;
// Non-null only after init_preprocessor() succeeded; kws_close() resets it to nullptr.
static tflite::MicroInterpreter* g_preproc_interpreter = nullptr;
static PreprocOpResolver* g_preproc_resolver = nullptr;
static const tflite::Model* g_preproc_model = nullptr;
// Resolver instance and static storage for the preprocessor interpreter (placement new).
static PreprocOpResolver g_preproc_resolver_inst;
alignas(alignof(tflite::MicroInterpreter)) static uint8_t g_preproc_interpreter_buffer[sizeof(tflite::MicroInterpreter)];
static bool init_preprocessor()
{
if (g_preproc_interpreter) {
KWS_LOG_INFO("preproc interpreter already inited\n");
return true;
}
g_preproc_model = tflite::GetModel(__audio_preprocessor_int8_tflite);
if (!g_preproc_model) {
KWS_LOG_INFO("error: preproc model null\n");
return false;
}
g_preproc_resolver = &g_preproc_resolver_inst;
LINE_INFO
// 注册与 README/test 一致的信号算子
g_preproc_resolver->AddReshape();
g_preproc_resolver->AddCast();
g_preproc_resolver->AddStridedSlice();
g_preproc_resolver->AddConcatenation();
g_preproc_resolver->AddMul();
g_preproc_resolver->AddAdd();
g_preproc_resolver->AddDiv();
g_preproc_resolver->AddMinimum();
g_preproc_resolver->AddMaximum();
g_preproc_resolver->AddCustom("SignalWindow", tflite::tflm_signal::Register_WINDOW());
g_preproc_resolver->AddCustom("SignalFftAutoScale", tflite::tflm_signal::Register_FFT_AUTO_SCALE());
g_preproc_resolver->AddCustom("SignalRfft", tflite::tflm_signal::Register_RFFT());
g_preproc_resolver->AddCustom("SignalEnergy", tflite::tflm_signal::Register_ENERGY());
g_preproc_resolver->AddCustom("SignalFilterBank", tflite::tflm_signal::Register_FILTER_BANK());
g_preproc_resolver->AddCustom("SignalFilterBankSquareRoot", tflite::tflm_signal::Register_FILTER_BANK_SQUARE_ROOT());
g_preproc_resolver->AddCustom("SignalFilterBankSpectralSubtraction", tflite::tflm_signal::Register_FILTER_BANK_SPECTRAL_SUBTRACTION());
g_preproc_resolver->AddCustom("SignalPCAN", tflite::tflm_signal::Register_PCAN());
g_preproc_resolver->AddCustom("SignalFilterBankLog", tflite::tflm_signal::Register_FILTER_BANK_LOG());
g_preproc_interpreter = new (g_preproc_interpreter_buffer)
tflite::MicroInterpreter(g_preproc_model, *g_preproc_resolver,
s_preproc_arena, kPreprocArenaSize);
LINE_INFO
TfLiteStatus status;
KWS_LOG_INFO("preproc AllocateTensors start\n");
status = g_preproc_interpreter->AllocateTensors();
if (status != kTfLiteOk) {
KWS_LOG_INFO("preproc AllocateTensors failed, status %d\n", status);
size_t used_bytes = g_preproc_interpreter->arena_used_bytes();
KWS_LOG_INFO("preproc Arena used bytes: %d\n", used_bytes);
g_preproc_interpreter = nullptr;
return false;
}
size_t used_bytes = g_preproc_interpreter->arena_used_bytes();
KWS_LOG_INFO("preproc Arena used bytes: %d / %d (%.1f%%)\n",
used_bytes, kPreprocArenaSize,
(100.0f * used_bytes) / kPreprocArenaSize);
KWS_LOG_INFO("preproc AllocateTensors succeeded!\n");
return true;
}
// 以 30ms(480) 窗/20ms(320) 步生成 49 帧,每帧 40 维,输出按行优先写入 out_features(长度需≥49*40)
static bool generate_features_with_preproc(const int16_t* pcm, int num_samples, int8_t* out_features)
{
// 预处理必须已在 init_tflm() 中完成初始化,运行期不再懒初始化
if (!g_preproc_interpreter) {
KWS_LOG_INFO("error: preproc interpreter null\n");
return false;
}
TfLiteTensor* pin = g_preproc_interpreter->input(0);
TfLiteTensor* pout = g_preproc_interpreter->output(0);
if (!pin || !pout) {
KWS_LOG_INFO("error: preproc input or output null\n");
return false;
}
const int frame_samples = (kSampleRate * kFrameLenMs) / 1000; // 480
const int stride_samples = (kSampleRate * kFrameStrideMs) / 1000; // 320
const int frames = g_frames_per_inference; // 49
int produced = 0;
const int16_t* cursor = pcm;
int remaining = num_samples;
while (remaining >= frame_samples && produced < frames) {
// 拷贝单帧 PCM
memcpy(tflite::GetTensorData<int16_t>(pin), cursor, frame_samples * sizeof(int16_t));
TfLiteStatus status = g_preproc_interpreter->Invoke();
if (status != kTfLiteOk) {
KWS_LOG_INFO("preproc Invoke failed, status %d\n", status);
return false;
}
// 拷贝 40 维 int8 特征
memcpy(out_features + produced * kNumMelBins,
tflite::GetTensorData<int8_t>(pout),
kNumMelBins * sizeof(int8_t));
produced++;
cursor += stride_samples;
remaining -= stride_samples;
}
// 若不足 49 帧,后续填 0
for (int f = produced; f < frames; ++f) {
memset(out_features + f * kNumMelBins, 0, kNumMelBins);
}
return true;
}
// ======== Feature computation via Audio Preprocessor model ========
// Fills mfcc_output with kNumMelBins x g_frames_per_inference int8 features for `audio`
// (length samples) using the Audio Preprocessor model.
// Despite the name, these are not MFCCs: the preprocessor's ops are filter-bank/log based.
// If the features cannot be produced, the whole block is zero-filled so the classifier
// still gets defined input.
static void compute_mfcc(const int16_t* audio, int8_t* mfcc_output, int length)
{
    const bool have_features = (g_preproc_interpreter != nullptr) &&
                               generate_features_with_preproc(audio, length, mfcc_output);
    if (have_features) {
        return;
    }
    KWS_LOG_INFO("preproc features failed, fill zeros\n");
    memset(mfcc_output, 0, kNumMelBins * g_frames_per_inference);
}
// ======== Public interface: one-time init ========
extern "C" void init_tflm()
{
KWS_LOG_INFO("init_tflm\n");
if (g_inited) {
KWS_LOG_INFO("error: init_tflm already inited\n");
return;
}
KWS_LOG_INFO("init_tflm2\n");
LINE_INFO
// Redirect TFLM DebugLog()/MicroPrintf() to RTT
RegisterDebugLogCallback([](const char* s){ printf("%s", s); });
LINE_INFO
// 提前初始化音频预处理模型(不在跑数据时初始化)
if (!init_preprocessor()) {
KWS_LOG_INFO("init preprocessor failed\n");
return;
}
LINE_INFO
TfLiteStatus status;
// Model and interpreter
g_model = tflite::GetModel(micro_speech_quantized_tflite);
if (g_model == nullptr) {
KWS_LOG_INFO("Failed to load model\n");
return;
}
KWS_LOG_INFO("Model loaded, size: %d bytes\n", micro_speech_quantized_tflite_len);
// Check schema version
const int model_version = g_model->version();
if (model_version != TFLITE_SCHEMA_VERSION) {
KWS_LOG_INFO("Model schema version %d != supported %d\n", model_version, TFLITE_SCHEMA_VERSION);
return;
}
// Dump operators required by model (builtin codes)
if (g_model->operator_codes()) {
int num_ops = g_model->operator_codes()->size();
KWS_LOG_INFO("Model operator_codes: %d\n", num_ops);
for (int i = 0; i < num_ops; ++i) {
const auto* oc = g_model->operator_codes()->Get(i);
int bcode = static_cast<int>(oc->builtin_code());
int8_t ver = oc->version();
const auto* cname = oc->custom_code();
if (cname) {
KWS_LOG_INFO(" opcode[%d]: builtin=%d ver=%d custom=%s\n", i, bcode, ver, cname->c_str());
} else {
KWS_LOG_INFO(" opcode[%d]: builtin=%d ver=%d\n", i, bcode, ver);
}
}
}
LINE_INFO
auto add_ok = [&](TfLiteStatus s, const char* name){
if (s != kTfLiteOk) { KWS_LOG_INFO("Add %s failed\n", name); return false; }
log_info("%s added\n", name); return true;
};
if (!add_ok(g_resolver.AddConv2D(), "op Conv2D")) return;
if (!add_ok(g_resolver.AddDepthwiseConv2D(), "op DepthwiseConv2D")) return;
if (!add_ok(g_resolver.AddMaxPool2D(), "op MaxPool2D")) return;
if (!add_ok(g_resolver.AddAveragePool2D(), "op AvgPool2D")) return;
if (!add_ok(g_resolver.AddFullyConnected(), "op FullyConnected")) return;
if (!add_ok(g_resolver.AddReshape(), "op Reshape")) return;
if (!add_ok(g_resolver.AddSoftmax(), "op Softmax")) return;
if (!add_ok(g_resolver.AddPad(), "op Pad")) return;
if (!add_ok(g_resolver.AddMean(), "op Mean")) return;
if (!add_ok(g_resolver.AddLogistic(), "op Logistic")) return;
if (!add_ok(g_resolver.AddCast(), "op Cast")) return;
if (!add_ok(g_resolver.AddStridedSlice(), "op StridedSlice")) return;
if (!add_ok(g_resolver.AddConcatenation(), "op Concat")) return;
if (!add_ok(g_resolver.AddAdd(), "op Add")) return;
if (!add_ok(g_resolver.AddDiv(), "op Div")) return;
if (!add_ok(g_resolver.AddMul(), "op Mul")) return;
if (!add_ok(g_resolver.AddSub(), "op Sub")) return;
if (!add_ok(g_resolver.AddMinimum(), "op Minimum")) return;
if (!add_ok(g_resolver.AddMaximum(), "op Maximum")) return;
if (!add_ok(g_resolver.AddQuantize(), "op Quantize")) return;
if (!add_ok(g_resolver.AddDequantize(), "op Dequantize")) return;
// Register custom audio frontend operators (consistent with custom names in model)
if (!add_ok(g_resolver.AddCustom("SignalWindow", tflite::tflm_signal::Register_WINDOW()), "custom SignalWindow")) return;
if (!add_ok(g_resolver.AddCustom("SignalFftAutoScale", tflite::tflm_signal::Register_FFT_AUTO_SCALE()), "custom SignalFftAutoScale")) return;
if (!add_ok(g_resolver.AddCustom("SignalRfft", tflite::tflm_signal::Register_RFFT()), "custom SignalRfft")) return;
if (!add_ok(g_resolver.AddCustom("SignalEnergy", tflite::tflm_signal::Register_ENERGY()), "custom SignalEnergy")) return;
if (!add_ok(g_resolver.AddCustom("SignalFilterBank", tflite::tflm_signal::Register_FILTER_BANK()), "custom SignalFilterBank")) return;
if (!add_ok(g_resolver.AddCustom("SignalFilterBankSquareRoot", tflite::tflm_signal::Register_FILTER_BANK_SQUARE_ROOT()), "custom SignalFilterBankSquareRoot")) return;
if (!add_ok(g_resolver.AddCustom("SignalFilterBankSpectralSubtraction", tflite::tflm_signal::Register_FILTER_BANK_SPECTRAL_SUBTRACTION()), "custom SignalFilterBankSpectralSubtraction")) return;
if (!add_ok(g_resolver.AddCustom("SignalPCAN", tflite::tflm_signal::Register_PCAN()), "custom SignalPCAN")) return;
if (!add_ok(g_resolver.AddCustom("SignalFilterBankLog", tflite::tflm_signal::Register_FILTER_BANK_LOG()), "custom SignalFilterBankLog")) return;
LINE_INFO
// Error reporter must be obtained before constructing interpreter
g_error_reporter = tflite::GetMicroErrorReporter();
LINE_INFO
// Print memory info before creating interpreter
KWS_LOG_INFO("Creating interpreter with tensor arena size: %d KB\n", kTensorArenaSize / 1024);
g_interpreter = new (g_interpreter_buffer)
tflite::MicroInterpreter(g_model, g_resolver,
tensor_arena, kTensorArenaSize);
LINE_INFO
// Print more debug info
KWS_LOG_INFO("Allocating tensors...\n");
status = g_interpreter->AllocateTensors();
if (status != kTfLiteOk) {
KWS_LOG_INFO("AllocateTensors failed, status %d (kTfLiteError)\n", status);
KWS_LOG_INFO("This usually means insufficient memory or unsupported operations\n");
// Try to get more information
size_t used_bytes = g_interpreter->arena_used_bytes();
KWS_LOG_INFO("Arena used bytes: %d\n", used_bytes);
return;
}
// After successful tensor allocation, print memory usage
KWS_LOG_INFO("AllocateTensors succeeded!\n");
size_t used_bytes = g_interpreter->arena_used_bytes();
KWS_LOG_INFO("Arena used bytes: %d / %d (%.1f%%)\n",
used_bytes, kTensorArenaSize,
(100.0f * used_bytes) / kTensorArenaSize);
LINE_INFO
g_input = g_interpreter->input(0);
LINE_INFO
g_output = g_interpreter->output(0);
// Print input/output tensor info
if (g_input) {
KWS_LOG_INFO("Input tensor info:\n");
KWS_LOG_INFO(" Type: %d\n", g_input->type);
KWS_LOG_INFO(" Bytes: %d\n", g_input->bytes);
if (g_input->dims) {
KWS_LOG_INFO(" Dims: [");
for (int i = 0; i < g_input->dims->size; ++i) {
KWS_LOG_INFO("%d", g_input->dims->data[i]);
if (i < g_input->dims->size - 1) r_printf(", ");
}
KWS_LOG_INFO("]\n");
}
}
if (g_output) {
KWS_LOG_INFO("Output tensor info:\n");
KWS_LOG_INFO(" Type: %d\n", g_output->type);
KWS_LOG_INFO(" Bytes: %d\n", g_output->bytes);
if (g_output->dims) {
KWS_LOG_INFO(" Dims: [");
for (int i = 0; i < g_output->dims->size; ++i) {
KWS_LOG_INFO("%d", g_output->dims->data[i]);
if (i < g_output->dims->size - 1) KWS_LOG_INFO(", ");
}
KWS_LOG_INFO("]\n");
}
}
// Derive frames per inference from input tensor shape
if (g_input && g_input->dims) {
int elems = 1;
for (int i = 0; i < g_input->dims->size; ++i) {
elems *= g_input->dims->data[i];
}
if (elems > 0 && (elems % kNumMelBins) == 0) {
g_frames_per_inference = elems / kNumMelBins;
}
KWS_LOG_INFO("input elems=%d, mel=%d, frames=%d\n", elems, kNumMelBins, g_frames_per_inference);
}
LINE_INFO
g_inited = true;
}
// ======== Extended C-callable KWS interfaces ========
// Audio front-end parameters the caller must match when capturing PCM.
extern "C" int kws_get_required_sample_rate() { return kSampleRate; }

extern "C" int kws_get_frame_length_ms() { return kFrameLenMs; }

extern "C" int kws_get_frame_stride_ms() { return kFrameStrideMs; }

// Samples needed for one inference: one full window plus (frames - 1) hops
// (480 + 48 * 320 = 15840 with the 49-frame fallback). It follows g_frames_per_inference,
// which init_tflm() may update from the model's input shape.
extern "C" int kws_get_required_samples_per_inference()
{
    const int window = kSampleRate * kFrameLenMs / 1000;
    const int hop = kSampleRate * kFrameStrideMs / 1000;
    return window + hop * (g_frames_per_inference - 1);
}
extern "C" int kws_get_num_classes()
{
if (!g_inited) {
init_tflm();
}
if (!g_output || !g_output->dims) {
return 0;
}
int elems = 1;
for (int i = 0; i < g_output->dims->size; ++i) {
elems *= g_output->dims->data[i];
}
return elems;
}
extern "C" int kws_get_output_shape(int* out_rank, int* out_dims, int max_dims)
{
if (!g_inited) {
init_tflm();
}
if (out_rank) {
*out_rank = 0;
}
if (!g_output || !g_output->dims) {
return 0;
}
const int rank = g_output->dims->size;
if (out_rank) {
*out_rank = rank;
}
if (out_dims && max_dims > 0) {
const int n = (rank < max_dims) ? rank : max_dims;
for (int i = 0; i < n; ++i) {
out_dims[i] = g_output->dims->data[i];
}
}
return rank;
}
// One-shot inference interface: feed a block of PCM and get top-1 result
// Returns 0 on failure, 1 on success
extern "C" int kws_run_inference(const int16_t* pcm, int num_samples, int* out_index, int8_t* out_score)
{
if (!g_inited) {
KWS_LOG_INFO("error: not inited\n");
return 0;
}
if (!pcm || num_samples <= 0 || !out_index || !out_score) {
return 0;
}
// Compute features directly into input tensor
compute_mfcc(pcm, g_input->data.int8, num_samples);
TfLiteStatus status = g_interpreter->Invoke();
if (status != kTfLiteOk) {
KWS_LOG_INFO("Invoke failed, status %d \n", status);
return 0;
}
// Argmax over output tensor (int8 logits or probabilities)
const int8_t* out_data = g_output->data.int8;
int out_len = g_output->bytes;
int best_i = 0;
int8_t best_v = out_data[0];
for (int i = 1; i < out_len; ++i) {
if (out_data[i] > best_v) {
best_v = out_data[i];
best_i = i;
}
}
*out_index = best_i;
*out_score = best_v;
if (out_len >= 4) {
KWS_LOG_INFO("scores: s=%d u=%d y=%d n=%d | best=%d(%d)\n", out_data[0], out_data[1], out_data[2], out_data[3], best_i, best_v);
} else {
KWS_LOG_INFO("best_i %d, best_v %d\n", best_i, best_v);
}
return 1;
}
// Sliding PCM FIFO for streaming inference; 16000 samples = 1 s at the 16 kHz sample rate.
static int16_t window_buf[16000];
// Number of valid samples currently stored at the front of window_buf.
static int filled = 0;
// Forward declaration so on_audio_chunk_10ms() can forward to the generic entry point.
extern "C" void on_audio_chunk_samples(const int16_t* audio_data, int num_samples);
extern "C" void on_audio_chunk_10ms(const int16_t* in160)
{
// 10ms@16kHz = 160 samples, use common entry to avoid inconsistency with fixed stride
on_audio_chunk_samples(in160, 160);
}
extern "C" void on_audio_chunk_samples(const int16_t* audio_data, int num_samples)
{
if (!g_inited) {
KWS_LOG_INFO("error: not inited\n");
return;
}
const int required = kws_get_required_samples_per_inference();
const int stride = (kws_get_required_sample_rate() * kws_get_frame_stride_ms()) / 1000;
int capacity = 0;
// 1) Append new data to FIFO
if (num_samples > 0) {
// Simple protection: if append would exceed buffer, truncate to window tail available capacity
capacity = (int)(sizeof(window_buf) / sizeof(window_buf[0]));
if (filled + num_samples > capacity) {
// First try to consume existing data by stride to free space
while (filled >= required && (filled + num_samples > capacity)) {
int index; int8_t score;
(void)kws_run_inference(window_buf, required, &index, &score);
memmove(window_buf, window_buf + stride, (filled - stride) * sizeof(int16_t));
filled -= stride;
}
// If still exceeding limit, only keep required-1 data from window tail
if (filled + num_samples > capacity) {
if (required < capacity) {
// Move tail required-1 data to buffer start
if (filled > required) {
memmove(window_buf, window_buf + (filled - required), required * sizeof(int16_t));
filled = required;
}
} else {
// Extreme case: capacity insufficient for one window, truncate to capacity
if (filled > capacity) {
memmove(window_buf, window_buf + (filled - capacity), capacity * sizeof(int16_t));
filled = capacity;
}
}
}
}
int copy = num_samples;
memcpy(&window_buf[filled], audio_data, copy * sizeof(int16_t));
filled += copy;
}
// 2) Fixed stride advance with simple smoothing vote
static const int kVoteWindow = 8;
static int s_pred_buf[kVoteWindow];
static int s_vote_counts[4] = {0,0,0,0};
static int s_write = 0; static int s_filled_win = 0;
while (filled >= required) {
int index; int8_t score;
if (kws_run_inference(window_buf, required, &index, &score)) {
if (s_filled_win < kVoteWindow) {
s_pred_buf[s_write] = index;
if (index >= 0 && index < 4) s_vote_counts[index]++;
s_write = (s_write + 1) % kVoteWindow;
s_filled_win++;
} else {
int old = s_pred_buf[s_write];
if (old >= 0 && old < 4) s_vote_counts[old]--;
s_pred_buf[s_write] = index;
if (index >= 0 && index < 4) s_vote_counts[index]++;
s_write = (s_write + 1) % kVoteWindow;
}
int smooth_i = 0; int smooth_v = s_vote_counts[0];
for (int c = 1; c < 4; ++c) {
if (s_vote_counts[c] > smooth_v) { smooth_v = s_vote_counts[c]; smooth_i = c; }
}
r_printf("smooth_best=%d votes=[%d,%d,%d,%d]\n", smooth_i, s_vote_counts[0], s_vote_counts[1], s_vote_counts[2], s_vote_counts[3]);
}
memmove(window_buf, window_buf + stride, (filled - stride) * sizeof(int16_t));
filled -= stride;
}
}
// ======== Public interface: release resources ========
extern "C" void kws_close()
{
// 在 STM32 上:对象与张量内存均为静态存储,不进行释放/析构
// 仅复位运行状态并停止流式处理
g_inited = false;
filled = 0;
g_preproc_interpreter = nullptr;
}
12、传入yes/no的数据进行测试
init_tflm();
int frame_length_ms = kws_get_frame_length_ms();
int frame_stride_ms = kws_get_frame_stride_ms();
int required_samples_per_inference = kws_get_required_samples_per_inference();
log_info("SR=%d, frame_len=%dms, frame_stride=%dms, required_samples=%d, wav_samples=%d\n",
required_sr, frame_length_ms, frame_stride_ms, required_samples_per_inference, num_samples);
int16_t in160pf[320];
//uint16_t window_buf[1840];
int index = 0;
uint8_t score = 0;
int get_no_int16_data(uint16_t *s_cnt, int16_t *data, uint16_t points, uint8_t ch);
int get_yes_int16_data(uint16_t *s_cnt, int16_t *data, uint16_t points, uint8_t ch);
uint16_t yes_s_cnt = 0;
for (int i = 0; i < 300; i++) {
// get_yes_int16_data(&yes_s_cnt, in160pf, 160, 1);
get_no_int16_data(&yes_s_cnt, in160pf, 160, 1);
//on_audio_chunk_samples(in160pf, 160);
on_audio_chunk_10ms(in160pf);
}