xref: /XiangShan/src/main/scala/xiangshan/backend/fu/Bku.scala (revision a273862e37f1d43bee748f2a6353320a2f52f6f4)
1/***************************************************************************************
2* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3* Copyright (c) 2020-2021 Peng Cheng Laboratory
4*
5* XiangShan is licensed under Mulan PSL v2.
6* You can use this software according to the terms and conditions of the Mulan PSL v2.
7* You may obtain a copy of Mulan PSL v2 at:
8*          http://license.coscl.org.cn/MulanPSL2
9*
10* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13*
14* See the Mulan PSL v2 for more details.
15***************************************************************************************/
16
17package xiangshan.backend.fu
18
19import chipsalliance.rocketchip.config.Parameters
20import chisel3._
21import chisel3.util._
22import utils.{LookupTreeDefault, ParallelMux, ParallelXOR, SignExt, XSDebug, ZeroExt}
23import xiangshan._
24import xiangshan.backend.fu.util._
25
26
27
28
/** Bit-counting unit for clz/ctz, clzw/ctzw, cpop and cpopw.
  *
  * The datapath has one register stage: the zero-count tree is registered at
  * its third level (c2) and the population counts are registered per 16-bit
  * chunk (cpopTmp). The final select therefore reads funcReg, the op code
  * delayed by one cycle.
  *
  * Op-code bits used here: bit 1 bit-reverses the operand (so the
  * leading-zero tree counts trailing zeros) and picks which half the word
  * form reads; bit 0 selects the 32-bit word form; bit 2 selects popcount
  * over zero counting. The concrete BKUOpType values are defined elsewhere.
  * The hard-coded sizes (32 two-bit leaves, four 16-bit popcount chunks)
  * assume XLEN == 64.
  */
class CountModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Input(UInt(XLEN.W))
    val func = Input(UInt())
    val out = Output(UInt(XLEN.W))
  })

  val funcReg = RegNext(io.func)

  /** Leading-zero count of a 2-bit slice: 0b00 -> 2, 0b01 -> 1, anything else -> 0. */
  def encode(bits: UInt): UInt =
    LookupTreeDefault(bits, 0.U, List(0.U -> 2.U(2.W), 1.U -> 1.U(2.W)))

  /** Merge the zero counts of two adjacent halves, each 2^msb bits wide.
    *
    * `left` is the more-significant half. Bit `msb` of a count is set exactly
    * when its half is all zeros; only then does `right` contribute, giving
    * left + right. Otherwise the answer is `left` alone.
    */
  def clzi(msb: Int, left: UInt, right: UInt): UInt = {
    val rightLow = if (msb == 1) right(0) else right(msb - 1, 0)
    val sum = Cat(left(msb) && right(msb), !right(msb), rightLow)
    Mux(left(msb), sum, left)
  }

  // Reversing the operand turns the leading-zero tree into a trailing-zero tree.
  val countSrc = Mux(io.func(1), Reverse(io.src), io.src)

  val c0 = Wire(Vec(32, UInt(2.W)))
  val c1 = Wire(Vec(16, UInt(3.W)))
  val c2 = Reg(Vec(8, UInt(4.W))) // pipeline register inside the zero-count tree
  val c3 = Wire(Vec(4, UInt(5.W)))
  val c4 = Wire(Vec(2, UInt(6.W)))

  c0.zipWithIndex.foreach { case (leaf, i) => leaf := encode(countSrc(2 * i + 1, 2 * i)) }

  // Each tree level pairs up entries of the previous one: (this level, previous level, msb).
  for {
    (lvl, prev, msb) <- Seq((c1, c0, 1), (c2, c1, 2), (c3, c2, 3), (c4, c3, 4))
    i <- lvl.indices
  } lvl(i) := clzi(msb, prev(2 * i + 1), prev(2 * i))

  val zeroRes = clzi(5, c4(1), c4(0))
  // Word form: low word of the operand, or high half once the operand was reversed.
  val zeroWRes = Mux(funcReg(1), c4(1), c4(0))

  val cpopTmp = Reg(Vec(4, UInt(5.W)))
  cpopTmp.zipWithIndex.foreach { case (chunk, i) => chunk := PopCount(io.src(16 * i + 15, 16 * i)) }

  val cpopLo32 = cpopTmp(0) +& cpopTmp(1)
  val cpopHi32 = cpopTmp(2) +& cpopTmp(3)

  val cpopRes = cpopLo32 +& cpopHi32
  val cpopWRes = cpopLo32

  val cpopSel = Mux(funcReg(0), cpopWRes, cpopRes)
  val zeroSel = Mux(funcReg(0), zeroWRes, zeroRes)
  io.out := Mux(funcReg(2), cpopSel, zeroSel)
}
77
/** Carry-less multiply unit for clmul, clmulh and clmulr.
  *
  * Every set bit i of src1 contributes src2 << i to a 128-bit partial
  * product; the partials are XOR-reduced pairwise, with a register after the
  * third reduction level (mul3). The result select therefore uses funcReg,
  * the op code delayed by one cycle. The fixed sizes assume XLEN == 64.
  */
class ClmulModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(UInt())
    val out = Output(UInt(XLEN.W))
  })

  val funcReg = RegNext(io.func)

  val src1 = io.src(0)
  val src2 = io.src(1)

  val mul0 = Wire(Vec(64, UInt(128.W)))
  val mul1 = Wire(Vec(32, UInt(128.W)))
  val mul2 = Wire(Vec(16, UInt(128.W)))
  val mul3 = Reg(Vec(8, UInt(128.W))) // pipeline register in the XOR tree

  // Partial products: src2 shifted left by i when bit i of src1 is set.
  for (i <- 0 until XLEN) {
    val shifted = if (i == 0) src2 else Cat(src2, 0.U(i.W))
    mul0(i) := Mux(src1(i), shifted, 0.U)
  }

  // Pairwise XOR reduction, halving the number of partials at each level.
  for (i <- mul1.indices) { mul1(i) := mul0(2 * i) ^ mul0(2 * i + 1) }
  for (i <- mul2.indices) { mul2(i) := mul1(2 * i) ^ mul1(2 * i + 1) }
  for (i <- mul3.indices) { mul3(i) := mul2(2 * i) ^ mul2(2 * i + 1) }

  val res = ParallelXOR(mul3)

  val clmul  = res(63, 0)   // low half of the carry-less product
  val clmulh = res(127, 64) // high half
  val clmulr = res(126, 63) // product bits 126..63

  io.out := LookupTreeDefault(funcReg, clmul, List(
    BKUOpType.clmul  -> clmul,
    BKUOpType.clmulh -> clmulh,
    BKUOpType.clmulr -> clmulr
  ))
}
114
/** Crossbar-permutation unit: xperm.n (nibble lookup) and xperm.b (byte lookup).
  *
  * Each element of src2 is used as an index into the table formed by the
  * same-sized elements of src1. Byte indices of 8 or more produce zero. The
  * result is registered (one cycle of latency); io.func(0), sampled together
  * with the operands, selects xperm.b over xperm.n.
  */
class MiscModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(UInt())
    val out = Output(UInt(XLEN.W))
  })

  val src1 = io.src(0)
  val src2 = io.src(1)

  /** Return the `idx`-th `width`-bit element of `table`.
    *
    * Unmatched indices fall back to the LookupTreeDefault default of 0.
    */
  def xpermLUT(table: UInt, idx: UInt, width: Int): UInt = {
    val entries = (0 until XLEN / width).map { i =>
      i.U -> table(i * width + width - 1, i * width)
    }
    LookupTreeDefault(idx, 0.U(width.W), entries)
  }

  val xpermnVec = Wire(Vec(16, UInt(4.W)))
  xpermnVec.zipWithIndex.foreach { case (lane, i) =>
    lane := xpermLUT(src1, src2(i * 4 + 3, i * 4), 4)
  }
  val xpermn = Cat(xpermnVec.reverse)

  val xpermbVec = Wire(Vec(8, UInt(8.W)))
  xpermbVec.zipWithIndex.foreach { case (lane, i) =>
    val idx = src2(i * 8 + 7, i * 8)
    // Only eight bytes exist in the table: any index >= 8 yields zero.
    lane := Mux(idx(7, 3).orR, 0.U, xpermLUT(src1, idx(2, 0), 8))
  }
  val xpermb = Cat(xpermbVec.reverse)

  io.out := RegNext(Mux(io.func(0), xpermb, xpermn))
}
139
/** Hash-function helpers: SHA-256 / SHA-512 sigma functions and SM3 P0/P1.
  *
  * io.func(3) selects SM3 over SHA; io.func(2,0) selects one of the eight
  * SHA functions, and io.func(0) picks P1 over P0 for SM3. The 32-bit
  * results are sign-extended to XLEN. The output is registered, giving one
  * cycle of latency.
  */
class HashModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Input(UInt(XLEN.W))
    val func = Input(UInt())
    val out = Output(UInt(XLEN.W))
  })

  val src1 = io.src

  /** Sign-extend the low 32 bits of `x` to XLEN. */
  private def sext32(x: UInt): UInt = SignExt(x(31, 0), XLEN)

  // SHA-256 (32-bit word) functions.
  val sha256sum0 = ROR32(src1, 2)  ^ ROR32(src1, 13) ^ ROR32(src1, 22)
  val sha256sum1 = ROR32(src1, 6)  ^ ROR32(src1, 11) ^ ROR32(src1, 25)
  val sha256sig0 = ROR32(src1, 7)  ^ ROR32(src1, 18) ^ SHR32(src1, 3)
  val sha256sig1 = ROR32(src1, 17) ^ ROR32(src1, 19) ^ SHR32(src1, 10)

  // SHA-512 (64-bit word) functions.
  val sha512sum0 = ROR64(src1, 28) ^ ROR64(src1, 34) ^ ROR64(src1, 39)
  val sha512sum1 = ROR64(src1, 14) ^ ROR64(src1, 18) ^ ROR64(src1, 41)
  val sha512sig0 = ROR64(src1, 1)  ^ ROR64(src1, 8)  ^ SHR64(src1, 7)
  val sha512sig1 = ROR64(src1, 19) ^ ROR64(src1, 61) ^ SHR64(src1, 6)

  // SM3 permutations; a right rotation by 32-k equals a left rotation by k.
  val sm3p0      = ROR32(src1, 23) ^ ROR32(src1, 15) ^ src1
  val sm3p1      = ROR32(src1, 9)  ^ ROR32(src1, 17) ^ src1

  // Indexed by io.func(2,0).
  val shaSource = VecInit(Seq(
    sext32(sha256sum0),
    sext32(sha256sum1),
    sext32(sha256sig0),
    sext32(sha256sig1),
    sha512sum0,
    sha512sum1,
    sha512sig0,
    sha512sig1
  ))
  val sha = shaSource(io.func(2, 0))
  val sm3 = sext32(Mux(io.func(0), sm3p1, sm3p0))

  io.out := RegNext(Mux(io.func(3), sm3, sha))
}
175
/** Block-cipher helper unit: AES round / key-schedule steps and SM4 ed/ks.
  *
  * Pipelining: every S-box is split by a register (aesSboxMid, iaesSboxMid,
  * ksSboxTop, sm4SboxTop). Values that bypass the S-boxes (imMinIn, ks1Idx,
  * aes64ks2 and the SM4 xor operand) are delayed with RegNext. The final
  * select therefore uses funcReg, the op code delayed by one cycle;
  * funcReg(3) picks SM4 over AES.
  */
class BlockCipherModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(UInt())
    val out = Output(UInt(XLEN.W))
  })

  val (src1, src2, func, funcReg) = (io.src(0), io.src(1), io.func, RegNext(io.func))

  // Operands split into bytes; element 0 is the least-significant byte.
  val src1Bytes = VecInit((0 until 8).map(i => src1(i*8+7, i*8)))
  val src2Bytes = VecInit((0 until 8).map(i => src2(i*8+7, i*8)))

  // ------------------------------------------------------------------ AES
  val aesSboxIn  = ForwardShiftRows(src1Bytes, src2Bytes)
  val aesSboxMid = Reg(Vec(8, Vec(18, Bool())))
  val aesSboxOut = Wire(Vec(8, UInt(8.W)))

  val iaesSboxIn  = InverseShiftRows(src1Bytes, src2Bytes)
  val iaesSboxMid = Reg(Vec(8, Vec(18, Bool())))
  val iaesSboxOut = Wire(Vec(8, UInt(8.W)))

  // Forward S-boxes: top + inversion stages before the register, output stage after.
  aesSboxOut.zip(aesSboxMid).zip(aesSboxIn).foreach { case ((out, mid), in) =>
    mid := SboxInv(SboxAesTop(in))
    out := SboxAesOut(mid)
  }

  // Inverse S-boxes, split the same way.
  iaesSboxOut.zip(iaesSboxMid).zip(iaesSboxIn).foreach { case ((out, mid), in) =>
    mid := SboxInv(SboxIaesTop(in))
    out := SboxIaesOut(mid)
  }

  val aes64es = aesSboxOut.asUInt
  val aes64ds = iaesSboxOut.asUInt

  // aes64im reads only src1; delay it to line up with the S-box results.
  val imMinIn = RegNext(src1Bytes)

  // (Inverse) MixColumns applied independently to each 32-bit column.
  val aes64esm = Cat(MixFwd(Seq(aesSboxOut(4), aesSboxOut(5), aesSboxOut(6), aesSboxOut(7))),
                     MixFwd(Seq(aesSboxOut(0), aesSboxOut(1), aesSboxOut(2), aesSboxOut(3))))
  val aes64dsm = Cat(MixInv(Seq(iaesSboxOut(4), iaesSboxOut(5), iaesSboxOut(6), iaesSboxOut(7))),
                     MixInv(Seq(iaesSboxOut(0), iaesSboxOut(1), iaesSboxOut(2), iaesSboxOut(3))))
  val aes64im  = Cat(MixInv(Seq(imMinIn(4), imMinIn(5), imMinIn(6), imMinIn(7))),
                     MixInv(Seq(imMinIn(0), imMinIn(1), imMinIn(2), imMinIn(3))))

  // Round constants indexed by src2(3,0); index 0xA yields 0.
  // NOTE(review): only 11 entries exist, so indices 0xB-0xF have no defined constant here.
  val rcon = WireInit(VecInit(Seq("h01".U, "h02".U, "h04".U, "h08".U,
                                  "h10".U, "h20".U, "h40".U, "h80".U,
                                  "h1b".U, "h36".U, "h00".U)))

  // Key schedule: S-box the upper word of src1, rotated right by one byte
  // unless src2(3,0) == 0xA (in which case the bytes are used in place).
  val ksNoRotate = src2(3,0) === "ha".U
  val ksSboxIn  = Wire(Vec(4, UInt(8.W)))
  val ksSboxTop = Reg(Vec(4, Vec(21, Bool())))
  val ksSboxOut = Wire(Vec(4, UInt(8.W)))
  ksSboxIn(0) := Mux(ksNoRotate, src1Bytes(4), src1Bytes(5))
  ksSboxIn(1) := Mux(ksNoRotate, src1Bytes(5), src1Bytes(6))
  ksSboxIn(2) := Mux(ksNoRotate, src1Bytes(6), src1Bytes(7))
  ksSboxIn(3) := Mux(ksNoRotate, src1Bytes(7), src1Bytes(4))
  // These S-boxes are split right after the top stage.
  ksSboxOut.zip(ksSboxTop).zip(ksSboxIn).foreach { case ((out, top), in) =>
    top := SboxAesTop(in)
    out := SboxAesOut(SboxInv(top))
  }

  val ks1Idx = RegNext(src2(3,0))
  // The same 32-bit word (S-boxed word ^ round constant) is replicated into both halves.
  val aes64ks1i = Cat(ksSboxOut.asUInt ^ rcon(ks1Idx), ksSboxOut.asUInt ^ rcon(ks1Idx))

  val aes64ks2Temp = src1(63,32) ^ src2(31,0)
  val aes64ks2 = RegNext(Cat(aes64ks2Temp ^ src2(63,32), aes64ks2Temp))

  val aesResult = LookupTreeDefault(funcReg, aes64es, List(
    BKUOpType.aes64es   -> aes64es,
    BKUOpType.aes64esm  -> aes64esm,
    BKUOpType.aes64ds   -> aes64ds,
    BKUOpType.aes64dsm  -> aes64dsm,
    BKUOpType.aes64im   -> aes64im,
    BKUOpType.aes64ks1i -> aes64ks1i,
    BKUOpType.aes64ks2  -> aes64ks2
  ))

  // ------------------------------------------------------------------ SM4
  // func(1,0) is the byte select `bs`: byte bs of src2 feeds the S-box.
  val sm4SboxIn  = src2Bytes(func(1,0))
  val sm4SboxTop = Reg(Vec(21, Bool()))
  sm4SboxTop := SboxSm4Top(sm4SboxIn)
  val sm4SboxOut = SboxSm4Out(SboxInv(sm4SboxTop))

  // Linear transforms of the S-box output x (sm4ed: L, sm4ks: L').
  // Masks must be applied BEFORE shifting, hence the explicit parentheses:
  // in Scala `<<` binds tighter than `&`, so `x & m << n` parses as
  // `x & (m << n)`. That form discards these terms (the S-box output is
  // presumably one byte wide, so it has no bits that high).
  val sm4ed = sm4SboxOut ^ (sm4SboxOut << 8) ^ (sm4SboxOut << 2) ^ (sm4SboxOut << 18) ^
    ((sm4SboxOut & "h3f".U) << 26) ^ ((sm4SboxOut & "hc0".U) << 10)
  val sm4ks = sm4SboxOut ^ ((sm4SboxOut & "h07".U) << 29) ^ ((sm4SboxOut & "hfe".U) << 7) ^
    ((sm4SboxOut & "h01".U) << 23) ^ ((sm4SboxOut & "hf8".U) << 13)

  // Indexed by funcReg(2,0): bit 2 picks ks over ed; bits 1..0 (bs) rotate
  // the 32-bit result left by 8*bs.
  val sm4Source = VecInit(Seq(
    sm4ed(31,0),
    Cat(sm4ed(23,0), sm4ed(31,24)),
    Cat(sm4ed(15,0), sm4ed(31,16)),
    Cat(sm4ed( 7,0), sm4ed(31,8)),
    sm4ks(31,0),
    Cat(sm4ks(23,0), sm4ks(31,24)),
    Cat(sm4ks(15,0), sm4ks(31,16)),
    Cat(sm4ks( 7,0), sm4ks(31,8))
  ))
  val sm4Result = SignExt((sm4Source(funcReg(2,0)) ^ RegNext(src1(31,0)))(31,0), XLEN)

  io.out := Mux(funcReg(3), sm4Result, aesResult)
}
274
/** Crypto wrapper combining the hash unit and the block-cipher unit.
  *
  * Both submodules receive the op code in the same cycle as the operands.
  * The output select uses funcReg (the op code delayed by one cycle), so it
  * lines up with the submodules' registered outputs; funcReg(4) picks the
  * hash result over the block-cipher result.
  */
class CryptoModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(UInt())
    val out = Output(UInt(XLEN.W))
  })

  val src1 = io.src(0)
  val src2 = io.src(1)
  val func = io.func
  val funcReg = RegNext(func)

  // Hash functions only consume the first operand.
  val hashModule = Module(new HashModule)
  hashModule.io.func := func
  hashModule.io.src := src1

  val blockCipherModule = Module(new BlockCipherModule)
  blockCipherModule.io.func := func
  blockCipherModule.io.src(0) := src1
  blockCipherModule.io.src(1) := src2

  val useHash = funcReg(4)
  io.out := Mux(useHash, hashModule.io.out, blockCipherModule.io.out)
}
296
/** Bit-manipulation / crypto function unit with a single pipeline stage.
  *
  * Operands and the op code are fed to four sub-units in the issue cycle.
  * The result is chosen with the op code taken from the pipeline stage at
  * `latency` (uopVec(latency)). Selection priority, highest first: op-code
  * bit 5 -> crypto, bit 3 -> count, bit 2 -> misc (xperm), otherwise clmul.
  */
class Bku(implicit p: Parameters) extends FunctionUnit with HasPipelineReg {

  override def latency = 1

  val src1 = io.in.bits.src(0)
  val src2 = io.in.bits.src(1)
  val func = io.in.bits.uop.ctrl.fuOpType
  val funcReg = uopVec(latency).ctrl.fuOpType

  val countModule = Module(new CountModule)
  countModule.io.func := func
  countModule.io.src := src1

  val clmulModule = Module(new ClmulModule)
  clmulModule.io.func := func
  clmulModule.io.src(0) := src1
  clmulModule.io.src(1) := src2

  val miscModule = Module(new MiscModule)
  miscModule.io.func := func
  miscModule.io.src(0) := src1
  miscModule.io.src(1) := src2

  val cryptoModule = Module(new CryptoModule)
  cryptoModule.io.func := func
  cryptoModule.io.src(0) := src1
  cryptoModule.io.src(1) := src2

  // MuxCase gives priority to the earliest matching entry.
  val result = MuxCase(clmulModule.io.out, Seq(
    funcReg(5) -> cryptoModule.io.out,
    funcReg(3) -> countModule.io.out,
    funcReg(2) -> miscModule.io.out
  ))

  io.out.bits.data := result
}
334