xref: /XiangShan/src/main/scala/xiangshan/cache/dcache/mainpipe/AMOALU.scala (revision 708ceed4afe43fb0ea3a52407e46b2794c573634)
1/***************************************************************************************
2* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3* Copyright (c) 2020-2021 Peng Cheng Laboratory
4*
5* XiangShan is licensed under Mulan PSL v2.
6* You can use this software according to the terms and conditions of the Mulan PSL v2.
7* You may obtain a copy of Mulan PSL v2 at:
8*          http://license.coscl.org.cn/MulanPSL2
9*
10* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13*
14* See the Mulan PSL v2 for more details.
15***************************************************************************************/
16
17package xiangshan.cache
18
19import chisel3._
20import chisel3.util._
21
/** Formats store data and the byte-enable mask for an access into a `maxSize`-byte data path.
  *
  * @param typ     access-size encoding; its low bits are log2 of the access size in bytes
  * @param addr    access address; only its low log2Up(maxSize) bits are consulted
  * @param dat     store data, right-aligned (the access uses its low 2^size bytes)
  * @param maxSize data-path width in bytes (assumed to be a power of two — TODO confirm)
  */
class StoreGen(typ: UInt, addr: UInt, dat: UInt, maxSize: Int) {
  // Number of low address bits that pick a byte lane inside the data path.
  private val logBytes = log2Up(maxSize)

  /** log2 of the access size in bytes, sliced from the low bits of `typ`. */
  val size = typ(log2Up(logBytes + 1) - 1, 0)

  /** True when the low address bits are not a multiple of the access size. */
  def misaligned = {
    val offsetMask = ((1.U << size) - 1.U)(logBytes - 1, 0)
    (addr & offsetMask).orR
  }

  /** Byte-enable vector, one bit per byte lane (`2^logBytes` bits wide).
    *
    * Built bottom-up: at each address bit the current pattern is placed in the upper or lower
    * half, and the upper half is filled completely when the access is wider than that half.
    * For a naturally aligned access this enables exactly the `2^size` bytes at `addr`.
    */
  def mask = (0 until logBytes).foldLeft(1.U) { (acc, bit) =>
    val fullHalf = ((BigInt(1) << (1 << bit)) - 1).U
    val hi = Mux(addr(bit), acc, 0.U) | Mux(size >= (bit + 1).U, fullHalf, 0.U)
    val lo = Mux(addr(bit), 0.U, acc)
    Cat(hi, lo)
  }

  /** Replicates the low `2^size` bytes of `dat` across the data path.
    *
    * Only sizes in `[i, logBytes)` are matched; any other size yields `dat` unchanged.
    */
  protected def genData(i: Int): UInt =
    (i until logBytes).foldRight(dat) { (lg, rest) =>
      Mux(size === lg.U, Fill(1 << (logBytes - lg), dat((8 << lg) - 1, 0)), rest)
    }

  /** Store data replicated for every access size. */
  def data = genData(0)
  /** Store data replicated only for word-sized (and wider) accesses. */
  def wordData = genData(2)
}
44
/** Extracts the addressed field from a `maxSize`-byte load word and extends it to full width.
  *
  * @param typ     access-size encoding, decoded exactly as in [[StoreGen]]
  * @param signed  sign-extend when true, zero-extend when false
  * @param addr    access address; bit i steers the selection at granularity 2^i bytes
  * @param dat     raw data word, 8 * maxSize bits wide (assumed — TODO confirm at callers)
  * @param zero    when asserted, `data` is forced to all zeros (ignored by `wordData`)
  * @param maxSize data-path width in bytes
  */
class LoadGen(typ: UInt, signed: Bool, addr: UInt, dat: UInt, zero: Bool, maxSize: Int) {
  private val size = new StoreGen(typ, addr, dat, maxSize).size

  /** Narrows `dat` from the widest granularity down to `2^logMinSize` bytes.
    *
    * Each step keeps the half selected by `addr` in the low bits. At the step matching `size`,
    * it replaces everything above with the sign (or zero) extension of the kept field.
    */
  private def genData(logMinSize: Int): UInt = {
    val totalBits = 8 * maxSize
    (log2Up(maxSize) - 1 to logMinSize by -1).foldLeft(dat) { (acc, lg) =>
      val chunkBits = 8 << lg
      val selected = Mux(addr(lg), acc(2 * chunkBits - 1, chunkBits), acc(chunkBits - 1, 0))
      // Zeroing only applies at byte granularity.
      val forceZero = (lg == 0).B && zero
      val low = Mux(forceZero, 0.U, selected)
      val extendHere = size === lg.U || forceZero
      val extension = Fill(totalBits - chunkBits, signed && low(chunkBits - 1))
      val high = Mux(extendHere, extension, acc(totalBits - 1, chunkBits))
      Cat(high, low)
    }
  }

  /** Load data narrowed only to word granularity. */
  def wordData = genData(2)
  /** Load data narrowed to byte granularity. */
  def data = genData(0)
}
63
/** ALU computing the new memory value for an atomic memory operation (AMO).
  *
  * The result is computed across the full `operandBits`-wide inputs. The byte mask `io.mask`
  * then selects which bytes of `io.out` take the result; all other bytes of `io.out` keep
  * `io.lhs`. `io.out_unmasked` is the raw result before that merge.
  *
  * Commands decoded here: M_XA_ADD, M_XA_AND, M_XA_OR, M_XA_XOR, M_XA_MIN, M_XA_MINU,
  * M_XA_MAX, M_XA_MAXU. Any other command reaches `minmax` with both `min` and `max` false,
  * which selects `io.rhs` (presumably how swap is implemented — confirm against the callers).
  *
  * @param operandBits datapath width in bits; only 32 and 64 are supported (see `require`)
  */
class AMOALU(operandBits: Int) extends Module
  with MemoryOpConstants {
  val minXLen = 32

  // The adder's carry partitioning and the comparator-width selection below are only correct
  // for these widths. With operandBits > 64, a 32-bit op in the third word would be compared
  // using bit (operandBits-1) as its sign bit, and could receive a carry out of bit 95.
  // Below 32 bits, log2Ceil in `widths` would fail with an unhelpful message.
  require(operandBits == 32 || operandBits == 64,
    s"AMOALU: operandBits must be 32 or 64, got $operandBits")

  // Supported operation widths, e.g. Seq(32, 64) when operandBits = 64.
  val widths = (0 to log2Ceil(operandBits / minXLen)).map(minXLen << _)

  val io = IO(new Bundle {
    val mask = Input(UInt((operandBits/8).W))       // byte enables of the operation
    val cmd = Input(Bits(M_SZ.W))                    // memory command (M_XA_* encodings)
    val lhs = Input(Bits(operandBits.W))             // old value (presumably read from memory)
    val rhs = Input(Bits(operandBits.W))             // AMO operand
    val out = Output(Bits(operandBits.W))            // result in masked bytes, lhs elsewhere
    val out_unmasked = Output(Bits(operandBits.W))   // raw result, no byte merge
  })

  val max = io.cmd === M_XA_MAX || io.cmd === M_XA_MAXU
  val min = io.cmd === M_XA_MIN || io.cmd === M_XA_MINU
  val add = io.cmd === M_XA_ADD
  // OR is produced as (lhs & rhs) | (lhs ^ rhs), so it enables both logic terms.
  val logic_and = io.cmd === M_XA_OR || io.cmd === M_XA_AND
  val logic_xor = io.cmd === M_XA_XOR || io.cmd === M_XA_OR

  val adder_out = {
    // Partition the carry chain to support sub-xLen addition. When byte (w/8-1) is not
    // enabled, the operation lies entirely above bit w-1. Bit w-1 is then cleared in both
    // operands, so no carry can propagate into the active part from below.
    val mask = ~(0.U(operandBits.W) +: widths.init.map(w => !io.mask(w/8-1) << (w-1))).reduce(_|_)
    (io.lhs & mask) + (io.rhs & mask)
  }

  val less = {
    // Break up the comparator so the lower parts will be CSE'd across widths.
    def isLessUnsigned(x: UInt, y: UInt, n: Int): Bool = {
      if (n == minXLen) x(n-1, 0) < y(n-1, 0)
      else x(n-1, n/2) < y(n-1, n/2) || x(n-1, n/2) === y(n-1, n/2) && isLessUnsigned(x, y, n/2)
    }

    def isLess(x: UInt, y: UInt, n: Int): Bool = {
      // Signed when cmd matches M_XA_MIN on the bits where MIN and MINU differ
      // (assumes MAX/MAXU use the same distinguishing bit — confirm in MemoryOpConstants).
      val signed = {
        val mask = M_XA_MIN ^ M_XA_MINU
        (io.cmd & mask) === (M_XA_MIN & mask)
      }
      // Equal sign bits: unsigned compare decides. Otherwise the sign bit decides directly.
      Mux(x(n-1) === y(n-1), isLessUnsigned(x, y, n), Mux(signed, x(n-1), y(n-1)))
    }

    // Use the widest comparator whose upper half has its first byte (w/16) enabled. Operations
    // confined to the low half fall through to the narrower comparator (the last entry is the
    // default).
    PriorityMux(widths.reverse.map(w => (io.mask(w/8/2), isLess(io.lhs, io.rhs, w))))
  }

  // For min/max this picks the smaller/larger operand; for undecoded commands it picks rhs.
  val minmax = Mux(Mux(less, min, max), io.lhs, io.rhs)
  val logic =
    Mux(logic_and, io.lhs & io.rhs, 0.U) |
    Mux(logic_xor, io.lhs ^ io.rhs, 0.U)
  val out =
    Mux(add,                    adder_out,
    Mux(logic_and || logic_xor, logic,
                                minmax))

  // Expand the byte mask to a bit mask and merge the result into lhs.
  val wmask = FillInterleaved(8, io.mask)
  io.out := wmask & out | ~wmask & io.lhs
  io.out_unmasked := out
}
121