

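The PE (processing element) is the smallest unit of the array and performs a scalar (one-dimensional) multiply-accumulate. PEs are arranged in a regular grid of rows and columns to form a two-dimensional Tile. Note that the PEs inside a Tile are connected purely combinationally: there are no registers between PEs to hold intermediate results. Tiles are in turn arranged as a pipeline to form the complete systolic array, with registers inserted between Tiles. This will become clearer from the code analysis that follows. The data to be computed is laid out in advance according to a fixed pattern, either stored inside the systolic array or placed in the memory banks surrounding it.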
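As a rough illustration of the PE's role, the sketch below models a single PE in plain Scala (no Chisel) as a function `acc + a * b`, and drives it with one operand pair per cycle, which accumulates a dot product. The object and method names are hypothetical; this is a behavioral model of a multiply-accumulate, not Gemmini's PE.

```scala
// Plain-Scala behavioral sketch (not Chisel, not Gemmini's code):
// one PE performs one scalar multiply-accumulate per step.
object PeMacSketch {
  // One PE step: acc + a * b.
  def mac(acc: Long, a: Long, b: Long): Long = acc + a * b

  // Feed one (a, b) pair per "cycle" into a single PE, starting from an
  // initial accumulator value (for example a bias term).
  def run(xs: Seq[Long], ys: Seq[Long], init: Long = 0L): Long = {
    require(xs.length == ys.length, "operand streams must have equal length")
    xs.zip(ys).foldLeft(init) { case (acc, (a, b)) => mac(acc, a, b) }
  }
}
```

For example, `PeMacSketch.run(Seq(1L, 2L, 3L), Seq(4L, 5L, 6L))` evaluates 1*4 + 2*5 + 3*6.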


```scala
import chisel3._

object Constant {
  def inputType = UInt(32.W)
  def outputType = UInt(32.W)
  def accType = UInt(32.W)
  def OS = false.B
  def WS = true.B
  def df = OS
  def latency = 0
  def tileRows = 16
  def tileColumns = 16
  def meshRows = 16
  def meshColumns = 16
}
```


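The core of the systolic array is the Mesh module, which is built from the Tile and PE submodules and the PEControl bundle; the most basic unit is the PE. Because Gemmini's systolic array supports two dataflows, OS (output stationary) and WS (weight stationary), the PE is fairly complex and has many control inputs. Gemmini even defines a custom IO bundle, PEControl, to carry them. The source is as follows: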

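```scala
import chisel3._

class PEControl extends Bundle {
  // define two control signals
  val propagate = UInt(1.W)
  val dataflow  = UInt(1.W)
}
```

Note that this simplified bundle has only two fields and no constructor parameter, while the snippets below call `new PEControl(accType)` and read `io.in_control.shift`; the bundle they are written against is therefore parameterized by `accType` and also has a `shift` field.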
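To make the OS/WS distinction concrete, here is a simplified, textbook-style model in plain Scala of what each dataflow keeps inside a PE. It is a sketch under stated assumptions, not Gemmini's PE: the double buffering (`c1`/`c2` with `propagate`), `shift`, and all timing are deliberately left out, and every name here is hypothetical.

```scala
// Simplified behavioral sketch of the two dataflows (hypothetical names).
// Not Gemmini's PE: double buffering, propagate, shift and timing are omitted.
object DataflowSketch {
  // Output stationary (OS): the partial sum c stays in the PE and grows by
  // a * b each step; b is forwarded unchanged to the PE below.
  final case class OsState(c: Long)
  def osStep(s: OsState, a: Long, b: Long): (OsState, Long) =
    (OsState(s.c + a * b), b) // (new state, value sent downward)

  // Weight stationary (WS): a preloaded weight w stays in the PE; the value
  // arriving from above is a partial sum, and a * w is added on the way down.
  final case class WsState(w: Long)
  def wsStep(s: WsState, a: Long, partialIn: Long): (WsState, Long) =
    (s, partialIn + a * s.w) // weight unchanged, updated partial sum sent downward
}
```

In OS mode the result ends up inside the PEs and has to be read out afterwards; in WS mode the weights are loaded first and results flow out of the bottom of each column.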


  • Port definitions:
```scala
val io = IO(new Bundle {
  val in_a = Input(inputType)
  val in_b = Input(outputType)
  val in_d = Input(outputType)
  val out_a = Output(inputType)
  val out_b = Output(outputType)
  val out_c = Output(outputType)
  val in_control = Input(new PEControl(accType))
  val out_control = Output(new PEControl(accType))
  val in_valid = Input(Bool())
  val out_valid = Output(Bool())
})
```

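Matrix multiplication: C = A * B + D, where D is the bias matrix. The inputs are `in_a`, `in_b`, and `in_d`, plus the control signal `in_control` and the valid signal `in_valid`. The outputs are `out_a` (passed straight through), `out_b`, `out_c`, the control signal `out_control` forwarded to the next PE, and the valid signal `out_valid`.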
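A plain-Scala reference (golden) model of C = A * B + D can be handy for checking what the array should produce. The sketch below is a straightforward triple loop with hypothetical names; it says nothing about how Gemmini schedules the computation.

```scala
// Reference model of C = A * B + D, with D as the bias matrix.
// A is n x k, B is k x m, D and C are n x m.
object MatMulRef {
  type Mat = Vector[Vector[Long]]

  def matmulAdd(a: Mat, b: Mat, d: Mat): Mat = {
    require(a.nonEmpty && b.nonEmpty, "A and B must be non-empty")
    val n = a.length
    val k = b.length
    val m = b.head.length
    require(a.forall(_.length == k), "A must be n x k")
    require(b.forall(_.length == m), "B must be k x m")
    require(d.length == n && d.forall(_.length == m), "D must be n x m")
    // Each C(i)(j) starts from the bias D(i)(j) and accumulates A(i)(p) * B(p)(j).
    Vector.tabulate(n, m) { (i, j) =>
      (0 until k).foldLeft(d(i)(j)) { (acc, p) => acc + a(i)(p) * b(p)(j) }
    }
  }
}
```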

  • Source code analysis


```scala
// c1/c2 hold weights in WS mode (inputType is enough) and partial sums otherwise (accType)
val cType = if (df == Dataflow.WS) inputType else accType

val a  = ShiftRegister(io.in_a, latency)
val b  = ShiftRegister(io.in_b, latency)
val d  = ShiftRegister(io.in_d, latency)
val c1 = Reg(cType)
val c2 = Reg(cType)
val dataflow = ShiftRegister(io.in_control.dataflow, latency)
val prop  = ShiftRegister(io.in_control.propagate, latency)
val shift = ShiftRegister(io.in_control.shift, latency)
val valid = ShiftRegister(io.in_valid, latency) // TODO should we clockgate the rest of the ShiftRegisters based on the values in this ShiftRegisters

// After `latency` cycles, the input signals are passed through to the outputs.
// Why forward a? In both dataflows, a always propagates to the next PE in the row.
io.out_a := a
io.out_control.dataflow := dataflow
io.out_control.propagate := prop
io.out_control.shift := shift
io.out_valid := valid
```
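To see what `ShiftRegister(x, latency)` does to timing, the plain-Scala sketch below models it as a delay line: the output at cycle t is the input from cycle t - n. With `latency = 0`, as in the Constant object above, it degenerates to a wire, which is consistent with the PEs inside a Tile being purely combinational. The `init` value is an assumption of the model; registers created without a reset value have no defined initial contents in hardware.

```scala
// Cycle-level sketch of Chisel's ShiftRegister(in, n): out(t) = in(t - n).
// With n = 0 the output equals the input in the same cycle (a plain wire).
// `init` stands in for the unknown initial register contents (model assumption).
object ShiftRegisterSketch {
  def simulate[A](inputs: Seq[A], n: Int, init: A): Seq[A] = {
    require(n >= 0, "latency must be non-negative")
    // Prepend n "not yet valid" slots, then keep one output per input cycle.
    (Seq.fill(n)(init) ++ inputs).take(inputs.length)
  }
}
```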



  • Parameters
```scala
/**
 * A Tile is a purely combinational 2D array of passThrough PEs.
 * a, b, s, and in_propag are broadcast across the entire array
 * and are passed through to the Tile's outputs
 * @param width The data width of each PE in bits
 * @param rows Number of PEs on each row
 * @param columns Number of PEs on each column
 */
// Arithmetic: a type class on T, used as a context bound (see the note after this block).
// pe_latency: forwarded to every PE as its `latency`, i.e. the depth of the PE's input ShiftRegisters.
class Tile[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, df: Dataflow.Value, pe_latency: Int, val rows: Int, val columns: Int) extends Module {
  // ... code skipped
}
```

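In `T <: Data : Arithmetic`, `T` is the element type parameter. `T <: Data` is an upper bound: `T` must be a subtype of Chisel's `Data`. `: Arithmetic` is a context bound: an implicit `Arithmetic[T]` instance (a type class supplying the arithmetic operations) must be available, rather than `T` being a subclass of `Arithmetic`. For details of the `Arithmetic` class, see Arithmetic.scala.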
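The pattern can be tried outside Chisel. In the sketch below, `Data`, `UIntLike`, and this `Arithmetic` trait are hypothetical stand-ins, not Chisel's `Data` or Gemmini's `Arithmetic`; the point is only how an upper bound combined with a context bound behaves, and that the context bound is sugar for an extra implicit parameter.

```scala
// Plain-Scala sketch of the `T <: Data : Arithmetic` pattern.
// Data, UIntLike and Arithmetic are hypothetical stand-ins, not Chisel/Gemmini types.
object ContextBoundSketch {
  trait Data
  final case class UIntLike(value: Long) extends Data

  // Type class: "values of T support multiply-accumulate".
  trait Arithmetic[T] {
    def mac(acc: T, a: T, b: T): T
  }

  // Instance that makes UIntLike usable wherever Arithmetic[UIntLike] is required.
  implicit val uintArithmetic: Arithmetic[UIntLike] = new Arithmetic[UIntLike] {
    def mac(acc: UIntLike, a: UIntLike, b: UIntLike): UIntLike =
      UIntLike(acc.value + a.value * b.value)
  }

  // `T <: Data` is an upper bound; `: Arithmetic` is a context bound that the
  // compiler turns into an implicit parameter of type Arithmetic[T].
  def step[T <: Data : Arithmetic](acc: T, a: T, b: T): T =
    implicitly[Arithmetic[T]].mac(acc, a, b)

  // The same signature written with the implicit parameter spelled out.
  def stepExplicit[T <: Data](acc: T, a: T, b: T)(implicit ar: Arithmetic[T]): T =
    ar.mac(acc, a, b)
}
```

Calling `step` with a `Data` subtype that has no `Arithmetic` instance in scope is rejected at compile time, which is how such a bound restricts the types a generator like `Tile` can be instantiated with.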

  • Module ports
```scala
val io = IO(new Bundle {
  val in_a        = Input(Vec(rows, inputType))
  val in_b        = Input(Vec(columns, outputType)) // This is the output of the tile next to it
  val in_d        = Input(Vec(columns, outputType))
  val in_control  = Input(Vec(columns, new PEControl(accType)))
  val out_a       = Output(Vec(rows, inputType))
  val out_b       = Output(Vec(columns, outputType))
  val out_c       = Output(Vec(columns, outputType))
  val out_control = Output(Vec(columns, new PEControl(accType)))
  val in_valid    = Input(Vec(columns, Bool()))
  val out_valid   = Output(Vec(columns, Bool()))
})
```


  • Source code analysis


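```scala
// rows x columns grid of PE instances, indexed as tile(row)(column)
val tile = Seq.fill(rows, columns)(Module(new PE(inputType, outputType, accType, df, pe_latency)))
// The same grid indexed by column first: tileT(column)(row)
val tileT = tile.transpose
```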


```scala
// Broadcast 'a' horizontally across the Tile
for (r <- 0 until rows) {
  tile(r).foldLeft(io.in_a(r)) {
    case (in_a, pe) =>
      pe.io.in_a := in_a
      pe.io.out_a
  }
}
```

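For each row of PEs (`tile(r)`), `io.in_a(r)` is first connected to the leftmost PE, and then each PE's output is connected to the next PE's input, much like a linked list. The key is to understand exactly what `foldLeft` does. Evaluation proceeds as follows: `tile(r).foldLeft(io.in_a(r)) {…}` takes `io.in_a(r)` as the initial value, which is bound to `in_a` in the `case`, and the first element of `tile(r)` is bound to `pe`. The body `pe.io.in_a := in_a; pe.io.out_a` connects `in_a` to that PE's input and evaluates to `pe.io.out_a`, which becomes the value bound to `in_a` in the next step. Repeating this over every element cascades all PEs in the row. Because the PEs are combinational when `pe_latency` is 0, this chain effectively broadcasts `a` to the whole row in the same cycle, which is why the code comment says "Broadcast".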
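The `foldLeft` wiring idiom can be reproduced in plain Scala to watch the connections being made. Here `FakePE` is a hypothetical stand-in for a PE module: "connecting" its input just records the driving signal's name, and its output is a name derived from the PE.

```scala
// Plain-Scala sketch of the foldLeft chaining idiom (FakePE is hypothetical).
object FoldChainSketch {
  final class FakePE(val name: String) {
    var inA: Option[String] = None    // records what drives this PE's in_a
    def outA: String = s"$name.out_a" // handle standing in for pe.io.out_a
  }

  // Mirrors: tile(r).foldLeft(io.in_a(r)) { case (in_a, pe) => pe.io.in_a := in_a; pe.io.out_a }
  // Returns the last PE's out_a, i.e. the row's right-hand output.
  def chainRow(rowInput: String, row: Seq[FakePE]): String =
    row.foldLeft(rowInput) { case (inA, pe) =>
      pe.inA = Some(inA) // "connect" the running signal to this PE's input
      pe.outA            // becomes the accumulator for the next PE
    }
}
```

After chaining a row of three PEs from `io.in_a(0)`, the first PE is driven by `io.in_a(0)`, the second by `pe0.out_a`, the third by `pe1.out_a`, and the fold returns `pe2.out_a`.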

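`b` is handled the same way, but over the transposed grid `tileT`, so it is chained from the top PE to the bottom PE of each column:

```scala
// Broadcast 'b' vertically across the Tile
for (c <- 0 until columns) {
  tileT(c).foldLeft(io.in_b(c)) {
    case (in_b, pe) =>
      pe.io.in_b := in_b
      pe.io.out_b
  }
}
```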
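The role of `transpose` is easiest to see with labels instead of modules. This hypothetical sketch builds a grid of PE names the same shape as `Seq.fill(rows, columns)(...)`, transposes it, and applies the same `foldLeft` idiom per column to list which signal drives each PE's `in_b`.

```scala
// Plain-Scala sketch: PE names instead of modules (hypothetical helper names).
object TransposeSketch {
  // Same shape as Seq.fill(rows, columns)(...), indexed as grid(row)(column).
  def grid(rows: Int, columns: Int): Seq[Seq[String]] =
    Seq.tabulate(rows, columns)((r, c) => s"pe($r)($c)")

  // For every column, the (driver -> PE) connections made top to bottom.
  def chainColumns(g: Seq[Seq[String]]): Seq[Seq[(String, String)]] =
    g.transpose.zipWithIndex.map { case (col, c) =>
      col.foldLeft((s"io.in_b($c)", Vector.empty[(String, String)])) {
        case ((inB, acc), pe) => (s"$pe.out_b", acc :+ (inB -> pe))
      }._2
    }
}
```

With a 2 x 3 grid, `grid(2, 3).transpose(1)` is `Seq("pe(0)(1)", "pe(1)(1)")`: column 1 read from top to bottom.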


```scala
// Drive the Tile's bottom IO
for (c <- 0 until columns) {
  io.out_c(c) := tile(rows-1)(c).io.out_c
  io.out_b(c) := tile(rows-1)(c).io.out_b
  io.out_control(c) := tile(rows-1)(c).io.out_control
  io.out_valid(c) := tile(rows-1)(c).io.out_valid
}

// Drive the Tile's right IO
for (r <- 0 until rows) {
  io.out_a(r) := tile(r)(columns-1).io.out_a
}
```




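At the Mesh level, `a` is chained across the Tiles of each row with the same `foldLeft` idiom, but each Tile's input now passes through a `RegNext`. These are the registers between Tiles mentioned at the beginning, and they turn each row of Tiles into a pipeline:

```scala
// Chain tile_a_out -> tile_a_in (pipeline a across each row)
// TODO clock-gate A signals with in_garbage
for (r <- 0 until meshRows) {
  mesh(r).foldLeft(io.in_a(r)) {
    case (in_a, tile) =>
      tile.io.in_a := RegNext(in_a)
      tile.io.out_a
  }
}
```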
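The effect of the `RegNext` chain on timing can be sketched in plain Scala. Assuming `pe_latency = 0` (combinational PEs inside a Tile), a value placed on `io.in_a(r)` in cycle t reaches every PE of Tile k in that row in cycle t + k + 1, since the first Tile also registers its input. The helper names below are hypothetical; `simulate` steps the per-Tile registers directly to cross-check that formula.

```scala
// Cycle-level sketch of the Mesh row pipeline for `a` (hypothetical names).
// Assumes pe_latency = 0, so all PEs inside one Tile see a value in the same cycle.
object MeshPipelineSketch {
  // Cycle in which a value presented on io.in_a(r) in cycle `t` is visible at
  // global PE column `col`, for Tiles that are `tileColumns` PEs wide.
  def arrivalCycle(t: Int, col: Int, tileColumns: Int): Int = {
    require(t >= 0 && col >= 0 && tileColumns > 0)
    val tileIndex = col / tileColumns // Tile that contains this PE column
    t + tileIndex + 1                 // one RegNext per Tile, including the first
  }

  // Step the per-Tile input registers: on each clock edge Tile 0 captures the
  // row input and Tile k captures Tile k-1's value. Element i of the result is
  // the register contents after the edge ending cycle i (seen during cycle i + 1).
  def simulate(inputs: Seq[Int], meshColumns: Int, init: Int = 0): Seq[Vector[Int]] = {
    require(meshColumns > 0)
    val start = Vector.fill(meshColumns)(init)
    inputs.scanLeft(start) { (regs, x) => x +: regs.init }.tail
  }
}
```

With `tileColumns = 16`, PE columns 0 to 15 receive a value one cycle after it is presented and columns 16 to 31 two cycles after, so the skew grows by one cycle per Tile rather than per PE.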




