Chisel 实践二维脉动阵列计算复数矩阵乘法

发布于:2022-12-18 ⋅ 阅读:(449) ⋅ 点赞:(0)

脉动阵列,其核心概念是让数据在运算单元的阵列中进行流动,减少访存的次数,并且使得结构更加规整,布线更加统一,提高频率。

基于二维脉动阵列的矩阵乘法计算一种实现思想如下引用的文章已经讲清楚,本文重在用Chisel去模拟实践二维脉动阵列的复数矩阵乘法,阅读本文前请阅读如下的前置文章:

基于二维脉动阵列的矩阵乘法

数据类型定义

首先定义我们的基本数据单位为定点数类型,定义数据的位宽如下:

trait HasDataConfig {
  /* 定点数的位宽定义 */
  val DataWidth = 64
  val BinaryPoint = 20
}

然后定义我们的复数计算类型输入输出

class Complex extends Bundle with HasDataConfig{
  /* 复数的 实部,虚部*/
  val re: FixedPoint = FixedPoint(DataWidth.W, BinaryPoint.BP)
  val im: FixedPoint = FixedPoint(DataWidth.W, BinaryPoint.BP)
}

此外增加复数的运算操作数

class ComplexOperationIO extends Bundle with HasDataConfig{
  val op1: Complex = Input(new Complex())
  val op2: Complex = Input(new Complex())
  val res: Complex = Output(new Complex())
}

复数运算定义

为了方便后续对复数进行运算,这里直接定义了复数的计算模块,包括加减乘这几个方便的操作

class ComplexAdd extends Module {
  val io = IO(new ComplexOperationIO)
  io.res.re := io.op1.re + io.op2.re
  io.res.im := io.op1.im + io.op2.im
}

class ComplexSub extends Module {
  val io = IO(new ComplexOperationIO)
  io.res.re := io.op1.re - io.op2.re
  io.res.im := io.op1.im - io.op2.im
}

class ComplexMul extends Module {
  val io  = IO(new ComplexOperationIO)
  val k1 = io.op2.re * (io.op1.re + io.op1.im)
  val k2 = io.op1.re * (io.op2.im - io.op2.re)
  val k3 = io.op1.im * (io.op2.re + io.op2.im)
  io.res.re := k1 - k3
  io.res.im := k1 + k2
}

为了方便Chisel中其他模块直接调用这个组合逻辑模块,我们为其都创建伴生对象,并定义一个工厂方法简化模块的例化和连线。

object ComplexAdd{
  def apply(op1: Complex, op2: Complex): Complex = {
    val add = Module(new ComplexAdd)
    add.io.op1 := op1
    add.io.op2 := op2
    add.io.res
  }
}

object ComplexSub{
  def apply(op1: Complex, op2: Complex): Complex = {
    val sub = Module(new ComplexSub)
    sub.io.op1 := op1
    sub.io.op2 := op2
    sub.io.res
  }
}

object ComplexMul{
  def apply(op1: Complex, op2: Complex): Complex = {
    val mul = Module(new ComplexMul)
    mul.io.op1 := op1
    mul.io.op2 := op2
    mul.io.res
  }
}

PE计算单元定义

如前置文章讲解矩阵乘法运算中的PE单元的设计所示
在这里插入图片描述

我们将定义一个乘累加单元,并向下向右继续传输同方向输入的数据:

  • 定义一个复数类型寄存器,初始化为定点数的0,用于存储乘累加后的U值
  • 定义两个复数类型的寄存器,用于存储输入的X,Y值,并接受时钟激励向外继续传递X,Y值
  • 通过复位的信号reset来将所有的寄存器初始化为定点数0的复数
class PE extends Module with HasDataConfig {
  /*
   * @reset: 复位信号
   * @in_y: 输入PE的纵向数据
   * @in_y: 输入PE的纵向数据
   * @out_pe: 输出PE单元的累加结果
   * @out_x: 输出PE的横向数据
   * @out_y: 输出PE的纵向数据
   * details: PE单元,将横向纵向的数据输入后与当前的寄存数据进行累加保存,
              然后分别输出横向纵向的数据
  **/
  val io = IO(new Bundle {
    val reset: Bool = Input(new Bool)
    val in_x: Complex = Input(new Complex)
    val in_y: Complex = Input(new Complex)
    val out_pe: Complex = Output(new Complex)
    val out_x: Complex = Output(new Complex)
    val out_y: Complex = Output(new Complex)
  })

  /* 内部存储单元 */
  val pe_reg: Complex = Reg(new Complex)
  val x_reg: Complex = Reg(new Complex)
  val y_reg: Complex = Reg(new Complex)

  when(io.reset) {
    /* 复位 */
    pe_reg.re := 0.F(DataWidth.W, BinaryPoint.BP)
    pe_reg.im := 0.F(DataWidth.W, BinaryPoint.BP)
    x_reg.re := 0.F(DataWidth.W, BinaryPoint.BP)
    x_reg.im := 0.F(DataWidth.W, BinaryPoint.BP)
    y_reg.re := 0.F(DataWidth.W, BinaryPoint.BP)
    y_reg.im := 0.F(DataWidth.W, BinaryPoint.BP)
  }.otherwise {
    /* 累加 */
    pe_reg := ComplexAdd(pe_reg, ComplexMul(io.in_x, io.in_y))
    /* 传递数值 */
    x_reg := io.in_x
    y_reg := io.in_y
  }

  /* 输出 */
  io.out_pe := pe_reg
  io.out_x := x_reg
  io.out_y := y_reg
}

复数矩阵乘法

定义输入

  • 通过类参数控制矩阵的大小
  • 定义reset来初始化和复位所有的PE阵列单元
  • 通过握手信号ready控制输入,valid表示计算完毕
class matrix_mul_v1(n: Int, m: Int, p: Int) extends Module with HasDataConfig {
  /*
   * @ready: 输入使能,使能上升沿则开始计算
   * @matrixA: 输入矩阵A n行*m列 复数类型的二维矩阵
   * @matrixB: 输入矩阵B m行*p列 复数类型的二维矩阵
   * @matrixC: 输出矩阵C n行*p列 复数类型的二维矩阵
   * @valid: 输出结果有效
   * details: 通过二维脉动阵列(n,p)计算矩阵A(n*m)×B(m*p)=Z(n*p)
              首先将数据全部存储到寄存器堆,再流入二维脉动阵列计算。
              通过使能激发运算,运算完成后输出valid有效信号
  **/
  val io = IO(new Bundle {
    val reset: Bool = Input(Bool())
    val ready: Bool = Input(Bool())
    val matrixA: Vec[Complex] = Input(Vec(n * m, new Complex))
    val matrixB: Vec[Complex] = Input(Vec(m * p, new Complex))
    val matrixC: Vec[Complex] = Output(Vec(n * p, new Complex))
    val valid: Bool = Output(Bool())

定义内部的重要变量

  • 定义两个寄存器堆分别存储输入的矩阵AB,目的是输入一次数据后长期有效
  • 定义一个input_point指向最后一个数据的位置,如matrixA(n-1,m-1)这个数据流动到了PE阵列的最后一个单元U(n-1,p-1)完成计算,即input_point移动范围是[0, ((n-1+m)+p))
    • 注意这里 ((n-1+m)+p) 中 (n-1+m) 是最后一个数据在输入PE阵列单元前移动的次数,p为阵列单元中移动的次数
  • 定义PE阵列,即实例化 N*P 个 PE模块
/* 定义寄存器暂存需要计算的数据 */
val regsA: Vec[Complex] = Reg(Vec(n * m, new Complex))
val regsB: Vec[Complex] = Reg(Vec(m * p, new Complex))
val input_point: UInt = Reg(UInt((log2Up(m + n + p) + 1).W)) // 指向输入数据流动的位置,范围[0,3*m),最终最后一个数据流动到最后一个PE

/* 二维脉动阵列计算,每次流水传递计算长度小的一个矩阵 */
val PEs = Seq.fill(n * p)(Module(new PE).io)

状态机设计

由于我们设计了握手信号ready,valid,所以需要一个简单的状态机

  • **输入状态。**当ready信号出现,将矩阵输入寄存器堆,并将input_point置为0,同时复位PE阵列
  • **计算状态。**当ready信号消失计算完成,即开始计算,每个时钟周期将input_point加1以标识数据移动的位置,同时PE阵列的复位信号为0.
  • **输出状态。**当input_point >= (m + p + n - 2)时,即数据已经完全移动到了PE阵列的最后一个PE,输出结果矩阵。
when(io.reset) {
  /* 复位 */
  for (i <- 0 until n * m) {
    regsA(i).re := 0.F(DataWidth.W, BinaryPoint.BP)
    regsA(i).im := 0.F(DataWidth.W, BinaryPoint.BP)
  }
  for (i <- 0 until m * p) {
    regsB(i).re := 0.F(DataWidth.W, BinaryPoint.BP)
    regsB(i).im := 0.F(DataWidth.W, BinaryPoint.BP)
  }
  input_point := 0.U
  for (i <- 0 until n * p) {
    PEs(i).reset := 1.B
  }
}.elsewhen(io.ready) {
  /* 当输入信号出现,开始输入,存储数据 */
  regsA := io.matrixA
  regsB := io.matrixB
  /* 初始化二维脉动阵列 */
  input_point := 0.U
  for (i <- 0 until n * p) {
    PEs(i).reset := 1.B
  }
}.otherwise {
  for (i <- 0 until n * p) {
    PEs(i).reset := 0.B
  }
  input_point := input_point + 1.U // 每个时钟前进一个数据
}

/* 输出 */
when(input_point >= (m + p + n - 2).U) {
  /* 当最后一个数据的指针指到(从0开始计算的) m + p + n - 2位置时,即PEs的最右下角的计算完成位置 */
  io.valid := 1.B
}.otherwise {
  io.valid := 0.B
}
for (i <- 0 until n * p) {
  io.matrixC(i) := PEs(i).out_pe
}

排布输入PE阵列的数据

由于输入的数据序列比较特殊,如下所示,不断向右向下传递:

在这里插入图片描述

则我们的数据需要提前规整成如上的格式。一种实现方法是所有的外部数据也存储为寄存器,向右向下移动进入PE阵列即可,但这样对寄存器的浪费较多。我们采用其他方法。

由于我们已经存储好了两个矩阵,我们可以把输入的矩阵组织一个连线构成输入格式的新矩阵,用input_point的值来识别是哪一列X哪一行Y输入PE阵列。

  • 创建如下的两个wire_matB, wire_matA的矩阵

在这里插入图片描述

/* 创建连线矩阵寄存器到PEs */
val wire_matA: Vec[Complex] = Wire(Vec(n * (m + n - 1), new Complex)) // n行pe要输入m+n次数据
val wire_matB: Vec[Complex] = Wire(Vec(p * (m + p - 1), new Complex)) // n行pe要输入m+p次数据
for (i <- 0 until n) {
  for (j <- 0 until (m + n - 1)) {
    wire_matA(i * (m + n - 1) + j).re := 0.F(DataWidth.W, BinaryPoint.BP)
    wire_matA(i * (m + n - 1) + j).im := 0.F(DataWidth.W, BinaryPoint.BP)
  }
  for (j <- i until i + m) {
    wire_matA(i * (m + n - 1) + j) := regsA(i * m + (j - i))
  }
}
for (i <- 0 until p) {
  for (j <- 0 until (m + p - 1)) {
    wire_matB(i * (m + p - 1) + j).re := 0.F(DataWidth.W, BinaryPoint.BP)
    wire_matB(i * (m + p - 1) + j).im := 0.F(DataWidth.W, BinaryPoint.BP)
  }
  for (j <- i until i + m) {
    wire_matB(i * (m + p - 1) + j) := regsB((j - i) * m + i)
  }
}

PE阵列的计算

PE阵列中的计算这里分为

  • 输入侧的处理,即二维阵列中的X输入(i,0)单元,Y输入(0,j)单元;
  • 内部PE的流动,列式输入侧(i,0)单元向下流动数据,行输入侧(0,j)单元向右流动数据,其他内部PE单元接受左上的输入,向下向右流动数据
  /* PEs阵列计算 */
  /* 处理PE阵列的输入,将wire连线连接到PEs的输入端 */
  when(input_point < (m + n - 1).U) {
    /* x方向的输入 */
    for (i <- 0 until n) {
      PEs(i * p + 0).in_x := wire_matA(i.U * (m + n - 1).U + input_point)
    }
  }.otherwise {
    for (i <- 0 until n) {
      PEs(i * p + 0).in_x.re := 0.F(DataWidth.W, BinaryPoint.BP)
      PEs(i * p + 0).in_x.im := 0.F(DataWidth.W, BinaryPoint.BP)
    }
  }
  when(input_point < (m + p - 1).U) {
    /* y方向的输入 */
    for (i <- 0 until p) {
      PEs(i).in_y := wire_matB(i.U * (m + p - 1).U + input_point)
    }
  }.otherwise {
    for (i <- 0 until p) {
      PEs(i).in_y.re := 0.F(DataWidth.W, BinaryPoint.BP)
      PEs(i).in_y.im := 0.F(DataWidth.W, BinaryPoint.BP)
    }
  }
  /* PEs内部数据流动与计算 */
  for (i <- 1 until n) {
    PEs(i * p + 0).in_y := PEs((i - 1) * p + 0).out_y
  }
  for (i <- 1 until p) {
    PEs(i).in_x := PEs(i - 1).out_x
  }
  for (i <- 1 until n) {
    for (j <- 1 until p) {
      PEs(i * p + j).in_y := PEs((i - 1) * p + j).out_y
      PEs(i * p + j).in_x := PEs(i * p + j - 1).out_x
    }
  }

Chisel测试

  • 模拟(43)矩阵与(33)矩阵相乘
  • 模拟两个输入矩阵,随机数生成矩阵,并做矩阵乘法最后和模块的结果做对比
  • 第0周期,复位模块进行初始化
  • 第1周期,输入数据ready,输入数据,
  • 第2~12周期,输入数据ready=0,进行计算每次打印结果矩阵可以看到数据传递的情况
import org.scalatest._
import chisel3._
import chiseltest._
import chisel3.experimental._
import scala.math._

class matrix_mul_Tester extends FlatSpec with ChiselScalatestTester with Matchers {
  behavior of "mytest2"
  it should "do something" in{
    val n = 4
    val m = 3
    val p = 3
    test(new matrix_mul_v1(n,m,p)){ c =>
      // 定义待计算的矩阵
      var matirxA_re = new Array[Double](n*m)
      var matirxA_im = new Array[Double](n*m)
      var matirxB_re = new Array[Double](p*m)
      var matirxB_im = new Array[Double](p*m)

      println(s"复位\n")
      c.io.reset.poke(1.B) // 复位
      c.clock.step(1)

      println(s"开始输入数据")
      c.io.reset.poke(0.B) // 复位
      c.io.ready.poke(1.B) // 输入数据
      for(i <- 0 until n){
        for(j <- 0 until m) {
          matirxA_re(i*m+j) = scala.util.Random.nextDouble()*100 // 生成随机数据
          matirxA_im(i*m+j) = scala.util.Random.nextDouble()*100
          c.io.matrixA(i*m+j).re.poke(FixedPoint.fromDouble(matirxA_re(i*m+j),64.W, 20.BP))
          c.io.matrixA(i*m+j).im.poke(FixedPoint.fromDouble(matirxA_im(i*m+j),64.W, 20.BP))
          println(s"matrixA(${i},${j})  ${c.io.matrixA(i*m+j).re.peek}+i${c.io.matrixA(i*m+j).im.peek}")
        }
      }
      println("")
      for(i <- 0 until m){
        for(j <- 0 until p) {
          matirxB_re(i*p+j) = scala.util.Random.nextDouble()*100
          matirxB_im(i*p+j) = scala.util.Random.nextDouble()*100
          c.io.matrixB(i*p+j).re.poke(FixedPoint.fromDouble(matirxB_re(i*p+j),64.W, 20.BP))
          c.io.matrixB(i*p+j).im.poke(FixedPoint.fromDouble(matirxB_im(i*p+j),64.W, 20.BP))
          println(s"matrixB(${i},${j})  ${c.io.matrixB(i*p+j).re.peek}+i${c.io.matrixB(i*p+j).im.peek}")
        }
      }
      c.clock.step(1)

      // 开始计算
      c.io.ready.poke(0.B) // 输入数据
      for(time <- 1 to n+m+p ){
        println(s"\n****************周期${time}****************")
        println(s"valid  ${c.io.valid.peek}")
        for(i <- 0 until n){
          for(j <- 0 until p){
            println(s"matrixC(${i},${j})  ${c.io.matrixC(i*p+j).re.peek}+i${c.io.matrixC(i*p+j).im.peek}")
          }
        }
        c.clock.step(1)
      }
  }
}

测试结果如下:

****************周期10****************
valid  Bool(true)
matrixC(0,0)  FixedPoint<31><<20>>(696.8377819061279)+iFixedPoint<35><<20>>(14813.185137748718)
matrixC(0,1)  FixedPoint<30><<20>>(376.97445487976074)+iFixedPoint<35><<20>>(12928.625484466553)
matrixC(0,2)  FixedPoint<33><<20>>(2071.8445892333984)+iFixedPoint<35><<20>>(11199.166589736938)
matrixC(1,0)  FixedPoint<35><<20>>(9670.568314552307)+iFixedPoint<35><<20>>(10087.168369293213)
matrixC(1,1)  FixedPoint<34><<20>>(5600.734077453613)+iFixedPoint<35><<20>>(10833.343472480774)
matrixC(1,2)  FixedPoint<34><<20>>(7653.225340843201)+iFixedPoint<35><<20>>(9973.20855140686)
matrixC(2,0)  FixedPoint<34><<20>>(5903.603125572205)+iFixedPoint<36><<20>>(21499.827312469482)
matrixC(2,1)  FixedPoint<31><<20>>(-890.9886846542358)+iFixedPoint<35><<20>>(11743.788872718811)
matrixC(2,2)  FixedPoint<32><<20>>(1502.2475490570068)+iFixedPoint<36><<20>>(18703.109066963196)
matrixC(3,0)  FixedPoint<33><<20>>(3016.4313774108887)+iFixedPoint<35><<20>>(10400.436962127686)
matrixC(3,1)  FixedPoint<30><<20>>(283.43966579437256)+iFixedPoint<35><<20>>(11639.99439907074)
matrixC(3,2)  FixedPoint<31><<20>>(802.3278121948242)+iFixedPoint<34><<20>>(6876.339378356934)

标准结果
opposite_matrixC(0,0)  696.8378540282138 + i(14813.185136490176)
opposite_matrixC(0,1)  376.9745304380208 + i(12928.625522200751)
opposite_matrixC(0,2)  2071.8446125733003 + i(11199.16654025088)
opposite_matrixC(1,0)  9670.568324498112 + i(10087.168354827681)
opposite_matrixC(1,1)  5600.73412919027 + i(10833.343517194138)
opposite_matrixC(1,2)  7653.22524402835 + i(9973.20855644613)
opposite_matrixC(2,0)  5903.603131971 + i(21499.82731039136)
opposite_matrixC(2,1)  -890.9886494960549 + i(11743.788918002465)
opposite_matrixC(2,2)  1502.2474641165672 + i(18703.108998519714)
opposite_matrixC(3,0)  3016.431469134278 + i(10400.436974575334)
opposite_matrixC(3,1)  283.43969340100386 + i(11639.99445848738)
opposite_matrixC(3,2)  802.3277719330887 + i(6876.339323880517)
本文含有隐藏内容,请 开通VIP 后查看

网站公告

今日签到

点亮在社区的每一天
去签到