xref: /XiangShan/src/main/scala/xiangshan/backend/fu/Alu.scala (revision a58e33519795596dc4f85fe66907cbc7dde2d66a)
1/***************************************************************************************
2* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3* Copyright (c) 2020-2021 Peng Cheng Laboratory
4*
5* XiangShan is licensed under Mulan PSL v2.
6* You can use this software according to the terms and conditions of the Mulan PSL v2.
7* You may obtain a copy of Mulan PSL v2 at:
8*          http://license.coscl.org.cn/MulanPSL2
9*
10* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13*
14* See the Mulan PSL v2 for more details.
15***************************************************************************************/
16
17package xiangshan.backend.fu
18
19import chipsalliance.rocketchip.config.Parameters
20import chisel3._
21import chisel3.util._
22import utils.{LookupTree, LookupTreeDefault, ParallelMux, SignExt, ZeroExt}
23import xiangshan._
24
/** Pair of adders used by the ALU data path.
  *
  * `add` is the XLEN-bit sum of the two `src` operands. `addw` is a separate
  * (XLEN/2)-bit sum of `srcw` and the low 32 bits of `src(1)`. Both sums use
  * the non-widening `+`, so a carry out of the top bit is discarded.
  */
class AddModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val srcw = Input(UInt((XLEN/2).W))
    val add = Output(UInt(XLEN.W))
    val addw = Output(UInt((XLEN/2).W))
  })

  val lhs = io.src(0)
  val rhs = io.src(1)

  // Word adder: only the low 32 bits of the shared second operand take part.
  // TODO: why this extra adder?
  val rhsWord = rhs(31, 0)
  io.addw := io.srcw + rhsWord

  // Full-width adder.
  io.add := lhs + rhs
}
36
/** Subtractor built from an adder: `sub = src(0) + ~src(1) + 1`.
  *
  * The result is XLEN+1 bits wide. The low XLEN bits are the two's-complement
  * difference `src(0) - src(1)`; bit XLEN is the carry-out of the addition,
  * which is set exactly when `src(0) >= src(1)` as unsigned numbers.
  */
class SubModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val sub = Output(UInt((XLEN+1).W))
  })

  val minuend = io.src(0)
  val invertedSubtrahend = (~io.src(1)).asUInt
  // `+&` widens by one bit so the carry-out is kept.
  val partialSum = minuend +& invertedSubtrahend
  // Adding the final 1 completes the two's-complement negation of src(1).
  io.sub := partialSum + 1.U
}
44
/** Two left shifters over the same XLEN-bit source.
  *
  * `sll` shifts by `shamt` and `revSll` shifts by `revShamt`. Both outputs are
  * XLEN bits wide, so bits shifted beyond the top are not visible.
  */
class LeftShiftModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val shamt = Input(UInt(6.W))
    val revShamt = Input(UInt(6.W))
    val sllSrc = Input(UInt(XLEN.W))
    val sll = Output(UInt(XLEN.W))
    val revSll = Output(UInt(XLEN.W))
  })

  /** Logical left shift of the shared source by `amount`. */
  private def shiftSourceLeft(amount: UInt): UInt = io.sllSrc << amount

  io.revSll := shiftSourceLeft(io.revShamt)
  io.sll := shiftSourceLeft(io.shamt)
}
56
/** Word (XLEN/2-bit) variant of the left shifter pair.
  *
  * `sllw` shifts the half-width source by `shamt`; `revSllw` shifts it by
  * `revShamt`. Shift amounts are 5 bits wide and outputs are XLEN/2 bits wide.
  */
class LeftShiftWordModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val shamt = Input(UInt(5.W))
    val revShamt = Input(UInt(5.W))
    val sllSrc = Input(UInt((XLEN/2).W))
    val sllw = Output(UInt((XLEN/2).W))
    val revSllw = Output(UInt((XLEN/2).W))
  })

  /** Logical left shift of the shared word source by `amount`. */
  private def shiftWordLeft(amount: UInt): UInt = io.sllSrc << amount

  io.revSllw := shiftWordLeft(io.revShamt)
  io.sllw := shiftWordLeft(io.shamt)
}
68
/** Right shifters over XLEN-bit sources.
  *
  *  - `srl`    : logical right shift of `srlSrc` by `shamt`
  *  - `revSrl` : logical right shift of `srlSrc` by `revShamt`
  *  - `sra`    : arithmetic right shift of `sraSrc` by `shamt` (the source is
  *               reinterpreted as signed so the sign bit is replicated)
  */
class RightShiftModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val shamt = Input(UInt(6.W))
    val revShamt = Input(UInt(6.W))
    val srlSrc = Input(UInt(XLEN.W))
    val sraSrc = Input(UInt(XLEN.W))
    val srl = Output(UInt(XLEN.W))
    val sra = Output(UInt(XLEN.W))
    val revSrl = Output(UInt(XLEN.W))
  })

  // Arithmetic shift: go through SInt, then back to the raw bit pattern.
  val signedSource = io.sraSrc.asSInt
  val arithmeticShifted = signedSource >> io.shamt
  io.sra := arithmeticShifted.asUInt

  // Logical shifts share one source but use different amounts.
  val logicalSource = io.srlSrc
  io.revSrl := logicalSource >> io.revShamt
  io.srl := logicalSource >> io.shamt
}
81
/** Word (XLEN/2-bit) variant of the right shifters.
  *
  *  - `srlw`    : logical right shift of `srlSrc` by `shamt`
  *  - `revSrlw` : logical right shift of `srlSrc` by `revShamt`
  *  - `sraw`    : arithmetic right shift of `sraSrc` by `shamt`
  *
  * Shift amounts are 5 bits wide; sources and results are XLEN/2 bits wide.
  */
class RightShiftWordModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val shamt = Input(UInt(5.W))
    val revShamt = Input(UInt(5.W))
    val srlSrc = Input(UInt((XLEN/2).W))
    val sraSrc = Input(UInt((XLEN/2).W))
    val srlw = Output(UInt((XLEN/2).W))
    val sraw = Output(UInt((XLEN/2).W))
    val revSrlw = Output(UInt((XLEN/2).W))
  })

  // Arithmetic shift: reinterpret as signed so the word's top bit is replicated.
  val signedWord = io.sraSrc.asSInt
  val arithmeticShifted = signedWord >> io.shamt
  io.sraw := arithmeticShifted.asUInt

  // Logical shifts of the same word by the two different amounts.
  val logicalWord = io.srlSrc
  io.revSrlw := logicalWord >> io.revShamt
  io.srlw := logicalWord >> io.shamt
}
95
96
/** Selects the "misc" ALU result from pre-computed logic / extension values.
  *
  * Two candidate groups are each decoded with the low three bits of `func`:
  *  - the logic group (andn, and, orn, or, xnor, xor, orh48, orc_b), which is
  *    additionally masked down to bit 0 when `func(3)` is set;
  *  - the extension group (sext_b, sext_h, zext_h, rev8, szewl1..3, byte2).
  *
  * The extension group is chosen when `func(3)` is set and `func(4)` is clear;
  * otherwise the (masked) logic group is chosen.
  */
class MiscResultSelect(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val func = Input(UInt(5.W))
    val andn, orn, xnor, and, or, xor, orh48, sextb, sexth, zexth, rev8, orcb = Input(UInt(XLEN.W))
    val src = Input(UInt(XLEN.W))
    val miscRes = Output(UInt(XLEN.W))
  })

  // Both groups are decoded on the same three function bits.
  val opLow = io.func(2, 0)
  val srcWord = io.src(31, 0)

  // Logic group.
  val logicTable = List(
    ALUOpType.andn  -> io.andn,
    ALUOpType.and   -> io.and,
    ALUOpType.orn   -> io.orn,
    ALUOpType.or    -> io.or,
    ALUOpType.xnor  -> io.xnor,
    ALUOpType.xor   -> io.xor,
    ALUOpType.orh48 -> io.orh48,
    ALUOpType.orc_b -> io.orcb
  )
  val logicResSel = ParallelMux(logicTable.map { case (op, data) => (op(2, 0) === opLow, data) })
  // With func(3) set, bits 63..1 of the logic result are cleared and only bit 0 survives.
  val logicMask = Cat(Fill(63, ~io.func(3)), 1.U(1.W))
  val maskedLogicRes = logicMask & logicResSel

  // Extension group. szewlN is the zero-extended low word of src shifted left
  // by N; byte2 is src(15, 8) zero-extended.
  val extTable = List(
    ALUOpType.sext_b -> io.sextb,
    ALUOpType.sext_h -> io.sexth,
    ALUOpType.zext_h -> io.zexth,
    ALUOpType.rev8   -> io.rev8,
    ALUOpType.szewl1 -> Cat(0.U(31.W), srcWord, 0.U(1.W)),
    ALUOpType.szewl2 -> Cat(0.U(30.W), srcWord, 0.U(2.W)),
    ALUOpType.szewl3 -> Cat(0.U(29.W), srcWord, 0.U(3.W)),
    ALUOpType.byte2  -> Cat(0.U(56.W), io.src(15, 8))
  )
  val miscRes = ParallelMux(extTable.map { case (op, data) => (op(2, 0) === opLow, data) })

  val useExtGroup = io.func(3) && !io.func(4)
  io.miscRes := Mux(useExtGroup, miscRes, maskedLogicRes)
}
130
/** Selects the shift-class ALU result from pre-computed candidates.
  *
  * Decoding of `func`:
  *  - `func(4)` set   : rotate; `func(3)` picks `ror` (1) or `rol` (0)
  *  - `func(4)` clear : `func(3)` picks the right group (1) or left group (0)
  *    - left group  : `func(2)` clear gives `sll`; otherwise a single-bit op
  *                    chosen by `func(1)` / `func(0)` (binv, bset, bclr)
  *    - right group : `func(2)` set gives `sra`; otherwise `func(1)` picks
  *                    `bext` (1) or `srl` (0)
  */
class ShiftResultSelect(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val func = Input(UInt())
    val sll, srl, sra, rol, ror, bclr, bset, binv, bext = Input(UInt(XLEN.W))
    val shiftRes = Output(UInt(XLEN.W))
  })

  val f = io.func

  // Left-hand group: plain shift or one of the single-bit set/clear/invert ops.
  val singleBitRes = Mux(f(1), io.binv, Mux(f(0), io.bset, io.bclr))
  val leftGroup = Mux(f(2), singleBitRes, io.sll)

  // Right-hand group: arithmetic shift, bit extract, or logical shift.
  val rightGroup = Mux(f(2), io.sra, Mux(f(1), io.bext, io.srl))

  val rotateRes = Mux(f(3), io.ror, io.rol)
  val plainShiftRes = Mux(f(3), rightGroup, leftGroup)

  io.shiftRes := Mux(f(4), rotateRes, plainShiftRes)
}
144
/** Selects the word (XLEN/2-bit) ALU result and sign-extends it to XLEN.
  *
  * Decoding of `func`:
  *  - `func(6,5) === 2` : shift/rotate group
  *    - `func(4)` set   : rotate, `func(3)` picks `rorw` (1) or `rolw` (0)
  *    - `func(4)` clear : `func(3)` clear gives `sllw`; otherwise `func(2)`
  *                        picks `sraw` (1) or `srlw` (0)
  *  - otherwise         : `func(6)` picks `subw` (1) or `addw` (0)
  */
class WordResultSelect(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val func = Input(UInt())
    val sllw, srlw, sraw, rolw, rorw, addw, subw = Input(UInt((XLEN/2).W))
    val wordRes = Output(UInt(XLEN.W))
  })

  val f = io.func

  // Shift / rotate candidates, built bottom-up.
  val rightShiftRes = Mux(f(2), io.sraw, io.srlw)
  val plainShiftRes = Mux(f(3), rightShiftRes, io.sllw)
  val rotateRes = Mux(f(3), io.rorw, io.rolw)
  val shiftGroup = Mux(f(4), rotateRes, plainShiftRes)

  // Add / subtract candidate.
  val arithGroup = Mux(f(6), io.subw, io.addw)

  val isShiftGroup = f(6, 5) === 2.U
  val selectedWord = Mux(isShiftGroup, shiftGroup, arithGroup)

  // Word results are sign-extended to the full register width.
  io.wordRes := SignExt(selectedWord, XLEN)
}
161
162
/** Final ALU result mux across the five result groups.
  *
  * Decoding of `func`:
  *  - `func(7)` set : `wordRes`
  *  - otherwise `func(6,5)` selects:
  *    - 3 : `compareRes`
  *    - 2 : `shiftRes`
  *    - 1 : `addRes`
  *    - 0 : `miscRes`
  */
class AluResSel(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val func = Input(UInt())
    val addRes, shiftRes, miscRes, compareRes, wordRes = Input(UInt(XLEN.W))
    val aluRes = Output(UInt(XLEN.W))
  })

  val f = io.func

  val compareOrShift = Mux(f(5), io.compareRes, io.shiftRes)
  val addOrMisc = Mux(f(5), io.addRes, io.miscRes)
  val fullWidthRes = Mux(f(6), compareOrShift, addOrMisc)

  io.aluRes := Mux(f(7), io.wordRes, fullWidthRes)
}
176
/** Combinational data path of the ALU.
  *
  * Computes every candidate result (add, sub, shifts, rotates, single-bit ops,
  * logic ops, comparisons, min/max, extensions) from the two source operands
  * and picks one with the result-select sub-modules according to `func`.
  * It also evaluates the branch condition.
  *
  * IO:
  *  - `src`        : the two XLEN-bit source operands
  *  - `func`       : operation encoding; individual bits steer the muxes below
  *  - `pred_taken` : predicted branch direction
  *  - `isBranch`   : whether the current operation is a branch
  *  - `result`     : selected ALU result
  *  - `taken`      : computed branch direction
  *  - `mispredict` : `taken` differs from `pred_taken` and `isBranch` is set
  */
class AluDataModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(FuOpType())
    val pred_taken, isBranch = Input(Bool())
    val result = Output(UInt(XLEN.W))
    val taken, mispredict = Output(Bool())
  })
  val (src1, src2, func) = (io.src(0), io.src(1), io.func)

  val addModule = Module(new AddModule)
  // For 64-bit adder:
  // BITS(2, 1): shamt (0, 1, 2, 3)
  // BITS(3   ): different fused cases
  // When func(0) is clear the upper 32 bits of src1 are masked to zero.
  val wordMaskAddSource = Cat(Fill(32, func(0)), Fill(32, 1.U)) & src1
  // (Optionally word-masked) src1 shifted left by 1..4.
  val shaddSource = VecInit(Seq(
    Cat(wordMaskAddSource(62, 0), 0.U(1.W)),
    Cat(wordMaskAddSource(61, 0), 0.U(2.W)),
    Cat(wordMaskAddSource(60, 0), 0.U(3.W)),
    Cat(wordMaskAddSource(59, 0), 0.U(4.W))
  ))
  // src1 logically shifted right by 29..32.
  val sraddSource = VecInit(Seq(
    ZeroExt(src1(63, 29), XLEN),
    ZeroExt(src1(63, 30), XLEN),
    ZeroExt(src1(63, 31), XLEN),
    ZeroExt(src1(63, 32), XLEN)
  ))
  // TODO: use decoder or other libraries to optimize timing
  // Now we assume shadd has the worst timing.
  addModule.io.src(0) := Mux(ALUOpType.isShAdd(func), shaddSource(func(2, 1)),
    Mux(ALUOpType.isSrAdd(func), sraddSource(func(2, 1)),
    Mux(ALUOpType.isAddOddBit(func), ZeroExt(src1(0), XLEN), wordMaskAddSource))
  )
  addModule.io.src(1) := src2
  val add = addModule.io.add
  // For 32-bit adder: its source comes from lower 32bits or lowest bit.
  addModule.io.srcw := Mux(ALUOpType.isAddOddBit(func), ZeroExt(src1(0), XLEN), src1(31,0))
  // func(1) set keeps only the low byte of the word sum; func(2) set keeps only bit 0.
  val byteMask = Cat(Fill(56, ~func(1)), 0xff.U(8.W))
  val bitMask = Cat(Fill(63, ~func(2)), 0x1.U(1.W))
  val addw = addModule.io.addw & byteMask & bitMask

  val subModule = Module(new SubModule)
  val sub  = subModule.io.sub
  // The word-select port is XLEN/2 bits wide, so only the low word of the
  // difference is consumed there.
  val subw = subModule.io.sub
  subModule.io.src(0) := src1
  subModule.io.src(1) := src2

  val shamt = src2(5, 0)
  // Two's-complement negation of the shift amount, used to build rotates.
  val revShamt = ~src2(5,0) + 1.U

  val leftShiftModule = Module(new LeftShiftModule)
  val sll = leftShiftModule.io.sll
  val revSll = leftShiftModule.io.revSll
  // Same word-masked source as the adder (upper half cleared when func(0) is 0).
  leftShiftModule.io.sllSrc := wordMaskAddSource
  leftShiftModule.io.shamt := shamt
  leftShiftModule.io.revShamt := revShamt

  val leftShiftWordModule = Module(new LeftShiftWordModule)
  val sllw = leftShiftWordModule.io.sllw
  val revSllw = leftShiftWordModule.io.revSllw
  leftShiftWordModule.io.sllSrc := src1
  leftShiftWordModule.io.shamt := shamt
  leftShiftWordModule.io.revShamt := revShamt

  val rightShiftModule = Module(new RightShiftModule)
  val srl = rightShiftModule.io.srl
  val revSrl = rightShiftModule.io.revSrl
  val sra = rightShiftModule.io.sra
  rightShiftModule.io.shamt := shamt
  rightShiftModule.io.revShamt := revShamt
  rightShiftModule.io.srlSrc := src1
  rightShiftModule.io.sraSrc := src1

  val rightShiftWordModule = Module(new RightShiftWordModule)
  val srlw = rightShiftWordModule.io.srlw
  val revSrlw = rightShiftWordModule.io.revSrlw
  val sraw = rightShiftWordModule.io.sraw
  rightShiftWordModule.io.shamt := shamt
  rightShiftWordModule.io.revShamt := revShamt
  rightShiftWordModule.io.srlSrc := src1
  rightShiftWordModule.io.sraSrc := src1

  // Rotates are the OR of a shift one way and the complementary shift the other way.
  val rol = revSrl | sll
  val ror = srl | revSll
  val rolw = revSrlw | sllw
  val rorw = srlw | revSllw

  // One-hot mask for the single-bit operations.
  val bitShift = 1.U << src2(5, 0)
  val bset = src1 | bitShift
  val bclr = src1 & ~bitShift
  val binv = src1 ^ bitShift
  val bext = srl(0)

  val andn    = src1 & ~src2
  val orn     = src1 | ~src2
  val xnor    = src1 ^ ~src2
  val and     = src1 & src2
  val or      = src1 | src2
  val xor     = src1 ^ src2
  val orh48   = Cat(src1(63, 8), 0.U(8.W)) | src2
  // Carry-out of src1 + ~src2 + 1: set when src1 >= src2 (unsigned).
  val sgtu    = sub(XLEN)
  val sltu    = !sgtu
  // Signed less-than: the unsigned answer is flipped when the operand signs differ.
  val slt     = xor(XLEN-1) ^ sltu
  // The select is "src1 < src2" XOR func(0), so func(0) swaps max and min.
  val maxMin  = Mux(slt ^ func(0), src2, src1)
  // Unsigned counterpart of maxMin. The previous select expression was of the
  // form (a || ~a), which is always true and therefore always returned src2.
  val maxMinU = Mux(sltu ^ func(0), src2, src1)
  val sextb   = SignExt(src1(7, 0), XLEN)
  val sexth   = SignExt(src1(15, 0), XLEN)
  val zexth   = ZeroExt(src1(15, 0), XLEN)
  val rev8    = Cat(src1(7,0), src1(15,8), src1(23,16), src1(31,24),
                    src1(39,32), src1(47,40), src1(55,48), src1(63,56))
  // orc.b: each result byte is all ones if any bit of the matching source byte
  // is set, and all zeros otherwise. (The previous code bit-reversed each byte,
  // which is a different operation.) Cat puts its first element in the most
  // significant position, hence the descending byte index.
  val orcb    = Cat((0 until XLEN / 8).reverse.map(i => Fill(8, src1(8 * i + 7, 8 * i).orR)))

  val branchOpTable = List(
    ALUOpType.getBranchType(ALUOpType.beq)  -> !xor.orR,
    ALUOpType.getBranchType(ALUOpType.blt)  -> slt,
    ALUOpType.getBranchType(ALUOpType.bltu) -> sltu
  )
  // The base condition is looked up by branch type and optionally inverted.
  val taken = LookupTree(ALUOpType.getBranchType(func), branchOpTable) ^ ALUOpType.isBranchInvert(func)


  // Result Select

  // func(2): min/max group (func(1) picks signed vs unsigned);
  // otherwise func(1) -> slt, func(0) -> sltu, else the raw difference.
  val compareRes = Mux(func(2), Mux(func(1), maxMin, maxMinU), Mux(func(1), slt, Mux(func(0), sltu, sub)))

  val shiftResSel = Module(new ShiftResultSelect)
  shiftResSel.io.func := func(4,0)
  shiftResSel.io.sll  := sll
  shiftResSel.io.srl  := srl
  shiftResSel.io.sra  := sra
  shiftResSel.io.rol  := rol
  shiftResSel.io.ror  := ror
  shiftResSel.io.bclr := bclr
  shiftResSel.io.binv := binv
  shiftResSel.io.bset := bset
  shiftResSel.io.bext := bext
  val shiftRes = shiftResSel.io.shiftRes

  val miscResSel = Module(new MiscResultSelect)
  miscResSel.io.func    := func(4, 0)
  miscResSel.io.andn    := andn
  miscResSel.io.orn     := orn
  miscResSel.io.xnor    := xnor
  miscResSel.io.and     := and
  miscResSel.io.or      := or
  miscResSel.io.xor     := xor
  miscResSel.io.orh48   := orh48
  miscResSel.io.sextb   := sextb
  miscResSel.io.sexth   := sexth
  miscResSel.io.zexth   := zexth
  miscResSel.io.rev8    := rev8
  miscResSel.io.orcb    := orcb
  miscResSel.io.src     := src1
  val miscRes = miscResSel.io.miscRes

  val wordResSel = Module(new WordResultSelect)
  wordResSel.io.func := func
  wordResSel.io.addw := addw
  wordResSel.io.subw := subw
  wordResSel.io.sllw := sllw
  wordResSel.io.srlw := srlw
  wordResSel.io.sraw := sraw
  wordResSel.io.rolw := rolw
  wordResSel.io.rorw := rorw
  val wordRes = wordResSel.io.wordRes

  val aluResSel = Module(new AluResSel)
  aluResSel.io.func := func
  aluResSel.io.addRes := add
  aluResSel.io.compareRes := compareRes
  aluResSel.io.shiftRes := shiftRes
  aluResSel.io.miscRes := miscRes
  aluResSel.io.wordRes := wordRes
  val aluRes = aluResSel.io.aluRes

  io.result := aluRes
  io.taken := taken
  // A misprediction is only reported for branch operations.
  io.mispredict := (io.pred_taken ^ taken) && io.isBranch
}
358
/** ALU function unit.
  *
  * Feeds the incoming micro-op's operands and operation type into an
  * [[AluDataModule]], forwards the computed data on `io.out`, and fills in the
  * redirect bundle for branch operations. Input readiness and output validity
  * are wired straight through (`in.ready := out.ready`, `out.valid := in.valid`).
  */
class Alu(implicit p: Parameters) extends FUWithRedirect {

  val uop = io.in.bits.uop
  val src1 = io.in.bits.src(0)
  val src2 = io.in.bits.src(1)
  val func = io.in.bits.uop.ctrl.fuOpType
  val pc = SignExt(io.in.bits.uop.cf.pc, AddrBits)

  val valid = io.in.valid
  val isBranch = ALUOpType.isBranch(func)
  val dataModule = Module(new AluDataModule)

  // Drive the data path.
  dataModule.io.func := func
  dataModule.io.src(0) := src1
  dataModule.io.src(1) := src2
  dataModule.io.isBranch := isBranch
  dataModule.io.pred_taken := uop.cf.pred_taken

  // A redirect is only produced for branch operations.
  redirectOutValid := io.out.valid && isBranch
  // Default every redirect field first; the connections below override the
  // fields this unit actually drives, so this line must stay ahead of them.
  redirectOut := DontCare
  redirectOut.level := RedirectLevel.flushAfter
  redirectOut.roqIdx := uop.roqIdx
  redirectOut.ftqIdx := uop.cf.ftqPtr
  redirectOut.ftqOffset := uop.cf.ftqOffset
  redirectOut.cfiUpdate.predTaken := uop.cf.pred_taken
  redirectOut.cfiUpdate.taken := dataModule.io.taken
  redirectOut.cfiUpdate.isMisPred := dataModule.io.mispredict

  // Handshake and result pass-through.
  io.in.ready := io.out.ready
  io.out.valid := valid
  io.out.bits.uop <> io.in.bits.uop
  io.out.bits.data := dataModule.io.result
}
394