我最近在做一个很有意思的项目:在FPGA上部署CNN,实现手写图片的识别。本篇文章是我迈出的第二步,具体代码已发布在GitHub上。
模块介绍
卷积神经网络(CNN)可以分为卷积层、池化层、激活层、全连接层等结构。本篇实现的,就是CNN卷积层中的卷积运算模块。
卷积运算的过程如下图所示:
在权重参数已经确定的情况下,我们可以将这一过程看成“数据滑窗”和“卷积运算”这两个步骤的重复执行。在前文中,我们已经实现了window模块,本文实现卷积运算模块。
运算过程如下:
$$
\begin{bmatrix}1&2&3\\4&5&6\\7&8&9\end{bmatrix} \ast \begin{bmatrix}1&0&1\\0&1&0\\1&1&2\end{bmatrix}
= 1\cdot1+2\cdot0+3\cdot1+4\cdot0+5\cdot1+6\cdot0+7\cdot1+8\cdot1+9\cdot2 = 42
$$
代码
- 模块可配置参数、输入和输出定义
为了支持多通道并行处理,输入为所有输入通道展平后的数据,如一维的窗口数据和权重参数
DATA_WIDTH和WEIGHT_WIDTH分开定义,因为后续工作中会对权重定点数量化
// mult_acc_comb: combinational multiply-accumulate core of a convolution layer.
//
// Takes one flattened multi-channel window and one flattened set of weights,
// multiplies them element by element, sums all products and drives the
// (saturated) result on conv_out. All arithmetic is unsigned.
//
// Parameters
//   DATA_WIDTH   - bits per window element
//   KERNEL_SIZE  - kernel edge length (a window holds KERNEL_SIZE*KERNEL_SIZE taps)
//   IN_CHANNEL   - number of input channels processed in parallel
//   WEIGHT_WIDTH - bits per weight element (kept separate from DATA_WIDTH so the
//                  weights can be quantised independently)
//   OUTPUT_WIDTH - bits of conv_out
//   ACC_WIDTH    - bits of the internal accumulator. A single product needs
//                  DATA_WIDTH + WEIGHT_WIDTH bits and summing N products needs
//                  at most $clog2(N) more, so the default is derived from BOTH
//                  operand widths (plus 4 guard bits). With equal data and
//                  weight widths this is the same value as 2*DATA_WIDTH + 4 + $clog2(N).
module mult_acc_comb #(
    parameter DATA_WIDTH   = 8,
    parameter KERNEL_SIZE  = 3,
    parameter IN_CHANNEL   = 3,
    parameter WEIGHT_WIDTH = 8,
    parameter OUTPUT_WIDTH = 20, // configurable output width
    parameter ACC_WIDTH    = DATA_WIDTH + WEIGHT_WIDTH + 4 + $clog2(KERNEL_SIZE*KERNEL_SIZE*IN_CHANNEL) // wide enough for any DATA_WIDTH / WEIGHT_WIDTH pair
)(
    // Input data interface
    input                                                       window_valid,            // window bus carries valid data
    input  [IN_CHANNEL*KERNEL_SIZE*KERNEL_SIZE*DATA_WIDTH-1:0]   multi_channel_window_in, // all channels' window taps, flattened
    input                                                       weight_valid,            // weight bus carries valid data
    input  [IN_CHANNEL*KERNEL_SIZE*KERNEL_SIZE*WEIGHT_WIDTH-1:0] multi_channel_weight_in, // all channels' weights, flattened
    // Output data interface
    output [OUTPUT_WIDTH-1:0]                                   conv_out,                // saturated sum of products
    output                                                      conv_valid               // high when both inputs are valid
);
- 定义内部相关信号
// Total number of weights consumed per evaluation: one per tap per input channel.
localparam WEIGHTS_PER_FILTER = IN_CHANNEL * KERNEL_SIZE * KERNEL_SIZE;
// Unpacked multi-channel window data and weight data (unsigned),
// indexed as [channel][tap within the KERNEL_SIZE*KERNEL_SIZE window].
wire [DATA_WIDTH-1:0] channel_window_data [0:IN_CHANNEL-1][0:KERNEL_SIZE*KERNEL_SIZE-1];
wire [WEIGHT_WIDTH-1:0] channel_weight_data [0:IN_CHANNEL-1][0:KERNEL_SIZE*KERNEL_SIZE-1];
// Product of every channel / tap pair (unsigned); DATA_WIDTH+WEIGHT_WIDTH bits
// hold the full-precision product of the two operands.
wire [DATA_WIDTH+WEIGHT_WIDTH-1:0] mult_results [0:IN_CHANNEL-1][0:KERNEL_SIZE*KERNEL_SIZE-1];
// Per-channel accumulated sum of products
wire [ACC_WIDTH-1:0] channel_sums [0:IN_CHANNEL-1];
// Final sum accumulated across all channels
wire [ACC_WIDTH-1:0] total_sum;
// Generate-loop variables shared by the generate blocks of this module
genvar ch, i_idx, k_idx, c_idx;
- 输入数据解包
// Slice the two flattened input buses into per-channel, per-tap elements.
generate
    for (ch = 0; ch < IN_CHANNEL; ch = ch + 1) begin : unpack_gen
        for (i_idx = 0; i_idx < KERNEL_SIZE*KERNEL_SIZE; i_idx = i_idx + 1) begin : element_gen
            // Position of this tap in the flattened ordering (channel first, then tap).
            localparam FLAT_POS   = ch*KERNEL_SIZE*KERNEL_SIZE + i_idx;
            // Window bus is read in ascending order: flat element 0 occupies the lowest bits.
            localparam WIN_LSB    = FLAT_POS*DATA_WIDTH;
            // Weight bus is read in the opposite order: flat element 0 occupies the highest bits.
            localparam WEIGHT_LSB = (WEIGHTS_PER_FILTER - 1 - FLAT_POS)*WEIGHT_WIDTH;

            assign channel_window_data[ch][i_idx] = multi_channel_window_in[WIN_LSB +: DATA_WIDTH];
            assign channel_weight_data[ch][i_idx] = multi_channel_weight_in[WEIGHT_LSB +: WEIGHT_WIDTH];
        end
    end
endgenerate
a[ b +: c ] 的含义是:从a的第b位开始,向高位方向共提取c位,也就是a[b+c-1 : b]。
输入的window和weight的数据结构变化如下
- 并行卷积运算
所有通道同时进行卷积运算
// Parallel multipliers: one per channel per tap, all evaluated at the same time.
// Each product is DATA_WIDTH+WEIGHT_WIDTH bits wide, so nothing is truncated here.
generate
    for (ch = 0; ch < IN_CHANNEL; ch = ch + 1) begin : mult_ch_gen
        for (i_idx = 0; i_idx < KERNEL_SIZE*KERNEL_SIZE; i_idx = i_idx + 1) begin : mult_elem_gen
            assign mult_results[ch][i_idx] =
                channel_weight_data[ch][i_idx] * channel_window_data[ch][i_idx];
        end
    end
endgenerate
// Per-channel accumulation of the tap products (combinational adders).
generate
    for (ch = 0; ch < IN_CHANNEL; ch = ch + 1) begin : sum_ch_gen
        if (KERNEL_SIZE != 3) begin : general_sum
            // Generic kernel size: running sum, one adder per additional tap.
            localparam LAST_TAP = KERNEL_SIZE*KERNEL_SIZE - 1;
            wire [ACC_WIDTH-1:0] partial_sums [0:LAST_TAP];

            assign partial_sums[0] = mult_results[ch][0];
            for (k_idx = 1; k_idx <= LAST_TAP; k_idx = k_idx + 1) begin : acc_gen
                assign partial_sums[k_idx] = mult_results[ch][k_idx] + partial_sums[k_idx-1];
            end
            assign channel_sums[ch] = partial_sums[LAST_TAP];
        end else begin : kernel3_sum
            // 3x3 kernel: add the nine products in a single expression, grouped by row.
            // The ACC_WIDTH-wide target sets the width of every addition.
            assign channel_sums[ch] =
                (mult_results[ch][0] + mult_results[ch][1] + mult_results[ch][2]) +
                (mult_results[ch][3] + mult_results[ch][4] + mult_results[ch][5]) +
                (mult_results[ch][6] + mult_results[ch][7] + mult_results[ch][8]);
        end
    end
endgenerate
- 跨通道累加并输出
对所有通道结果进行相加,进行饱和处理,然后输出
// Cross-channel accumulation (combinational): add the per-channel sums together.
generate
    if (IN_CHANNEL != 3) begin : general_channel_sum
        // Generic channel count: running sum over the channels.
        localparam LAST_CH = IN_CHANNEL - 1;
        wire [ACC_WIDTH-1:0] channel_partial_sums [0:LAST_CH];

        assign channel_partial_sums[0] = channel_sums[0];
        for (c_idx = 1; c_idx <= LAST_CH; c_idx = c_idx + 1) begin : ch_acc_gen
            assign channel_partial_sums[c_idx] = channel_sums[c_idx] + channel_partial_sums[c_idx-1];
        end
        assign total_sum = channel_partial_sums[LAST_CH];
    end else begin : channel3_sum
        // Three channels: one flat three-operand addition.
        assign total_sum = channel_sums[0] + channel_sums[1] + channel_sums[2];
    end
endgenerate
// Output stage (combinational).
// The result is flagged valid only while both the window and the weights are valid.
assign conv_valid = window_valid & weight_valid;
// Mask the saturated sum with conv_valid so conv_out reads as all zeros
// whenever the inputs are not both valid.
assign conv_out = {OUTPUT_WIDTH{conv_valid}} & saturate(total_sum);
// saturate: clamp an unsigned ACC_WIDTH-bit accumulator value to OUTPUT_WIDTH bits.
//
//   value   - unsigned accumulator value
//   returns - value itself when it fits in OUTPUT_WIDTH bits, otherwise the
//             largest OUTPUT_WIDTH-bit value (all ones). The lower bound is 0,
//             so no clamping is needed on that side.
//
// The overflow test looks at the bits that would be dropped (value >> OUTPUT_WIDTH)
// instead of slicing value[OUTPUT_WIDTH-1:0]. This keeps the function well defined
// for every width combination: when OUTPUT_WIDTH >= ACC_WIDTH the shift yields 0,
// nothing saturates, and the assignment zero-extends value rather than reading
// bits outside its declared range.
function [OUTPUT_WIDTH-1:0] saturate;
    input [ACC_WIDTH-1:0] value; // UNSIGNED
    begin
        if ((value >> OUTPUT_WIDTH) != {ACC_WIDTH{1'b0}})
            saturate = {OUTPUT_WIDTH{1'b1}}; // does not fit: clamp to the maximum
        else
            saturate = value;                // fits: resized to OUTPUT_WIDTH by the assignment
    end
endfunction
测试
mult_acc_comb_tb.v
为验证其功能,使用多个测试用例(case)进行测试,并对比期望结果与实际结果
`timescale 1ns / 1ps
// Self-checking testbench for mult_acc_comb (unsigned, purely combinational).
// It drives the DUT with a fixed sequence of stimuli, waits 1 time unit for the
// outputs to settle, and compares conv_valid / conv_out against expected values.
// NOTE: every stimulus fills the buses with 27 identical 8-bit elements, so the
// literals below match only the default 3x3x3 / 8-bit configuration, and the
// ordering of elements inside the buses is not exercised by these tests.
module mult_acc_comb_tb;
// DUT configuration (mirrors the defaults of mult_acc_comb)
parameter DATA_WIDTH = 8;
parameter KERNEL_SIZE = 3;
parameter IN_CHANNEL = 3;
parameter WEIGHT_WIDTH = 8;
parameter OUTPUT_WIDTH = 20;
parameter ACC_WIDTH = 2*DATA_WIDTH + 4 + $clog2(KERNEL_SIZE*KERNEL_SIZE*IN_CHANNEL);
// DUT stimulus registers and observed outputs
reg window_valid;
reg [IN_CHANNEL*KERNEL_SIZE*KERNEL_SIZE*DATA_WIDTH-1:0] multi_channel_window_in;
reg weight_valid;
reg [IN_CHANNEL*KERNEL_SIZE*KERNEL_SIZE*WEIGHT_WIDTH-1:0] multi_channel_weight_in;
wire [OUTPUT_WIDTH-1:0] conv_out;
wire conv_valid;
// Largest value representable on conv_out; the expected result of a saturated sum
localparam MAX_UNSIGNED_OUT_VAL = (1 << OUTPUT_WIDTH) - 1;
// Example: Test 2 raw sum for unsigned context
localparam EXPECTED_SUM_TEST2_UNSIGNED_RAW = 3 * 9 * 2 * 3; // 162
// Expected Test 2 output after clamping the raw sum to the output range
localparam EXPECTED_CONV_OUT_TEST2_UNSIGNED_SAT = (EXPECTED_SUM_TEST2_UNSIGNED_RAW > MAX_UNSIGNED_OUT_VAL) ? MAX_UNSIGNED_OUT_VAL : EXPECTED_SUM_TEST2_UNSIGNED_RAW;
// Largest window / weight element values (used only in the Test 8 message)
localparam MAX_ELEMENT_VAL_TB = (1 << DATA_WIDTH) -1;
localparam MAX_WEIGHT_ELEMENT_VAL_TB = (1 << WEIGHT_WIDTH) -1;
// Device under test
mult_acc_comb #(
.DATA_WIDTH(DATA_WIDTH),
.KERNEL_SIZE(KERNEL_SIZE),
.IN_CHANNEL(IN_CHANNEL),
.WEIGHT_WIDTH(WEIGHT_WIDTH),
.OUTPUT_WIDTH(OUTPUT_WIDTH),
.ACC_WIDTH(ACC_WIDTH)
) dut (
.window_valid(window_valid),
.multi_channel_window_in(multi_channel_window_in),
.weight_valid(weight_valid),
.multi_channel_weight_in(multi_channel_weight_in),
.conv_out(conv_out),
.conv_valid(conv_valid)
);
// Bookkeeping for the final summary
reg all_tests_passed_flag;
integer test_id_counter;
integer num_errors;
// Task to check results and display Expected/Actual for all
// expected_out_val   - expected conv_out (only compared when expected_valid_val is 1)
// expected_valid_val - expected conv_valid; when 0, conv_out must be all zeros
task check_and_report;
input [OUTPUT_WIDTH-1:0] expected_out_val;
input expected_valid_val;
// Test description is displayed before calling this task
begin
test_id_counter = test_id_counter + 1;
// Always display Expected and Actual
$display(" Expected: conv_valid=%b, conv_out=%d", expected_valid_val, expected_out_val);
$display(" Actual: conv_valid=%b, conv_out=%d", conv_valid, conv_out);
// Case equality (===) is used so that X/Z on the DUT outputs counts as a failure
if (conv_valid === expected_valid_val &&
( (expected_valid_val === 1'b0) ? (conv_out === {OUTPUT_WIDTH{1'b0}}) : (conv_out === expected_out_val) ) ) begin
$display(" Test ID %0d: Status: PASSED", test_id_counter);
end else begin
$display(" Test ID %0d: Status: FAILED", test_id_counter);
all_tests_passed_flag = 1'b0;
num_errors = num_errors + 1;
end
$display("--------------------------------------------------");
end
endtask
// Stimulus sequence. Each test sets the inputs, waits #1 for the combinational
// outputs to settle, checks them, then idles #10 before the next test.
initial begin
$display("=== Comprehensive UNSIGNED Combinational MultAcc Test (OUTPUT_WIDTH=%0d) ===", OUTPUT_WIDTH);
all_tests_passed_flag = 1'b1;
test_id_counter = 0;
num_errors = 0;
// Initialize
window_valid = 0;
weight_valid = 0;
multi_channel_window_in = 0;
multi_channel_weight_in = 0;
#10;
// Test 1
$display("Test Description: Simple Positive Values (1*1, sum 27)");
multi_channel_window_in = {27{8'd1}};
multi_channel_weight_in = {27{8'd1}};
window_valid = 1;
weight_valid = 1;
#1;
check_and_report(27, 1'b1);
#10;
// Test 2
// NOTE: the raw sum (162) is far below MAX_UNSIGNED_OUT_VAL for the 20-bit
// output, so despite the description no saturation actually occurs here.
$display("Test Description: Positive Values with Saturation (2*3, raw %0d, sat %0d)", EXPECTED_SUM_TEST2_UNSIGNED_RAW, EXPECTED_CONV_OUT_TEST2_UNSIGNED_SAT);
multi_channel_window_in = {27{8'd2}};
multi_channel_weight_in = {27{8'd3}};
#1;
check_and_report(EXPECTED_CONV_OUT_TEST2_UNSIGNED_SAT, 1'b1);
#10;
// Test 3
// Data from Test 2 is still on the buses; only the valid flags are dropped.
$display("Test Description: Invalid Inputs (both valid_n low)");
window_valid = 0;
weight_valid = 0;
#1;
check_and_report(0, 1'b0);
#10;
// Test 4
$display("Test Description: Zero Window Data, Non-zero Weights");
window_valid = 1;
weight_valid = 1;
multi_channel_window_in = {27{8'd0}};
multi_channel_weight_in = {27{8'd5}};
#1;
check_and_report(0, 1'b1);
#10;
// Test 5
$display("Test Description: Non-zero Window, Zero Weight Data");
multi_channel_window_in = {27{8'd5}};
multi_channel_weight_in = {27{8'd0}};
#1;
check_and_report(0, 1'b1);
#10;
// Test 6
$display("Test Description: All Zero Inputs");
multi_channel_window_in = {27{8'd0}};
multi_channel_weight_in = {27{8'd0}};
#1;
check_and_report(0, 1'b1);
#10;
// Test 7
$display("Test Description: Large values (no saturation with 20-bit output)");
multi_channel_window_in = {27{8'd5}};
multi_channel_weight_in = {27{8'd5}};
#1;
check_and_report(27*5*5, 1'b1); // 27*25 = 675, well within 20-bit range
#10;
// Test 8
$display("Test Description: Max Val Inputs (Win=%d, Wgt=%d), should saturate to %d", MAX_ELEMENT_VAL_TB, MAX_WEIGHT_ELEMENT_VAL_TB, MAX_UNSIGNED_OUT_VAL);
multi_channel_window_in = {27{{DATA_WIDTH{1'b1}}}};
multi_channel_weight_in = {27{{WEIGHT_WIDTH{1'b1}}}};
#1;
// 27 * 255 * 255 = 1,755,675, which exceeds 20-bit max (1,048,575), so should saturate
check_and_report(MAX_UNSIGNED_OUT_VAL, 1'b1);
#10;
// Test 8.5: Test 20-bit range capability
$display("Test Description: Medium values to test 20-bit range (100*100, sum 270000)");
multi_channel_window_in = {27{8'd100}};
multi_channel_weight_in = {27{8'd100}};
#1;
check_and_report(27*100*100, 1'b1); // 27*10000 = 270000, well within 20-bit range
#10;
// Test 9: Window valid toggles
// conv_out must follow window_valid immediately: 27 -> 0 -> 27.
$display("--- Test Sequence 9: Window Valid Toggles (base inputs 1*1, sum 27) ---");
multi_channel_window_in = {27{8'd1}};
multi_channel_weight_in = {27{8'd1}};
weight_valid = 1;
$display(" Sub-Test Description: WinValid=1 (Start)");
window_valid = 1; #1; check_and_report(27, 1'b1);
$display(" Sub-Test Description: WinValid=0");
window_valid = 0; #1; check_and_report(0, 1'b0);
$display(" Sub-Test Description: WinValid=1 (End)");
window_valid = 1; #1; check_and_report(27, 1'b1);
#10;
// Test 10: Weight valid toggles
// Same check as Test 9, toggling weight_valid instead.
$display("--- Test Sequence 10: Weight Valid Toggles (base inputs 1*1, sum 27) ---");
window_valid = 1;
// inputs are still 1s
$display(" Sub-Test Description: WeightValid=1 (Start)");
weight_valid = 1; #1; check_and_report(27, 1'b1);
$display(" Sub-Test Description: WeightValid=0");
weight_valid = 0; #1; check_and_report(0, 1'b0);
$display(" Sub-Test Description: WeightValid=1 (End)");
weight_valid = 1; #1; check_and_report(27, 1'b1);
#10;
// Final Summary
$display("==================================================");
if (all_tests_passed_flag) begin
$display("FINAL STATUS: SUCCESS! All %0d UNSIGNED Combinational MultAcc tests passed!", test_id_counter);
end else begin
$display("FINAL STATUS: FAILED. %0d out of %0d UNSIGNED Combinational MultAcc tests did not pass.", num_errors, test_id_counter);
end
$display("==================================================");
$finish;
end
endmodule
结果
window模块每个周期都会传递数据,因而采用组合逻辑实现卷积运算。当输入数据同时有效,也就是window_valid和weight_valid同时为高时,mult_acc_comb进行运算,conv_valid拉高,如下图所示:
输出打印结果:
=Comprehensive UNSIGNED Combinational MultAcc Test (OUTPUT_WIDTH=20) =
Test Description: Simple Positive Values (1*1, sum 27)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27Test ID 1: Status: PASSED
Test Description: Positive Values with Saturation (2*3, raw 162, sat 162)
Expected: conv_valid=1, conv_out= 162
Actual: conv_valid=1, conv_out= 162Test ID 2: Status: PASSED
Test Description: Invalid Inputs (both valid_n low)
Expected: conv_valid=0, conv_out= 0
Actual: conv_valid=0, conv_out= 0Test ID 3: Status: PASSED
Test Description: Zero Window Data, Non-zero Weights
Expected: conv_valid=1, conv_out= 0
Actual: conv_valid=1, conv_out= 0Test ID 4: Status: PASSED
Test Description: Non-zero Window, Zero Weight Data
Expected: conv_valid=1, conv_out= 0
Actual: conv_valid=1, conv_out= 0Test ID 5: Status: PASSED
Test Description: All Zero Inputs
Expected: conv_valid=1, conv_out= 0
Actual: conv_valid=1, conv_out= 0Test ID 6: Status: PASSED
Test Description: Large values (no saturation with 20-bit output)
Expected: conv_valid=1, conv_out= 675
Actual: conv_valid=1, conv_out= 675Test ID 7: Status: PASSED
Test Description: Max Val Inputs (Win= 255, Wgt= 255), should saturate to 1048575
Expected: conv_valid=1, conv_out=1048575
Actual: conv_valid=1, conv_out=1048575Test ID 8: Status: PASSED
Test Description: Medium values to test 20-bit range (100*100, sum 270000)
Expected: conv_valid=1, conv_out= 270000
Actual: conv_valid=1, conv_out= 270000Test ID 9: Status: PASSED
— Test Sequence 9: Window Valid Toggles (base inputs 1*1, sum 27) —
Sub-Test Description: WinValid=1 (Start)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27Test ID 10: Status: PASSED
Sub-Test Description: WinValid=0
Expected: conv_valid=0, conv_out= 0
Actual: conv_valid=0, conv_out= 0Test ID 11: Status: PASSED
Sub-Test Description: WinValid=1 (End)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27Test ID 12: Status: PASSED
— Test Sequence 10: Weight Valid Toggles (base inputs 1*1, sum 27) —
Sub-Test Description: WeightValid=1 (Start)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27Test ID 13: Status: PASSED
Sub-Test Description: WeightValid=0
Expected: conv_valid=0, conv_out= 0
Actual: conv_valid=0, conv_out= 0Test ID 14: Status: PASSED
Sub-Test Description: WeightValid=1 (End)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27Test ID 15: Status: PASSED