xref: /XiangShan/src/main/scala/device/AXI4RAM.scala (revision ab3aa7eedc9c70d560572701ea30e863011452a8)
1package device
2
3import chipsalliance.rocketchip.config.Parameters
4import chisel3._
5import chisel3.util._
6import freechips.rocketchip.diplomacy.{AddressSet, LazyModule, LazyModuleImp, RegionType}
7import xiangshan.HasXSParameter
8import utils.{MaskExpand}
9
/**
  * Black-box memory helper: this class only declares the port list, the
  * behaviour is supplied by an external implementation.
  *
  * All index, data and mask ports share the same width, `DataBits`.
  * NOTE(review): `DataBits` is not defined in this file; presumably it is
  * provided by `HasXSParameter` -- confirm.
  *
  * @param memByte not referenced by anything in this declaration
  */
class RAMHelper(memByte: BigInt) extends BlackBox with HasXSParameter {
  // Every call yields a fresh, unbound hardware type, as each port needs its own.
  // Kept outside the Bundle so it can never be mistaken for a port.
  private def helperPortType = UInt(DataBits.W)

  val io = IO(new Bundle {
    val clk   = Input(Clock())
    // Read side.
    val rIdx  = Input(helperPortType)
    val rdata = Output(helperPortType)
    // Write side.
    val wIdx  = Input(helperPortType)
    val wdata = Input(helperPortType)
    val wmask = Input(helperPortType)
    val wen   = Input(Bool())
  })
}
21
/**
  * AXI4 slave that models a RAM of `memByte` bytes.
  *
  * Storage is either a Chisel `Mem` of byte vectors or, when `useBlackBox` is
  * set, `beatBytes / 8` [[RAMHelper]] black boxes that each serve one 64-bit
  * lane of a data beat.
  *
  * @param address     address ranges, forwarded to [[AXI4SlaveModule]]
  * @param memByte     capacity of the RAM in bytes
  * @param useBlackBox use [[RAMHelper]] black boxes instead of a `Mem`
  * @param executable  forwarded to [[AXI4SlaveModule]]
  * @param beatBytes   bytes per data beat; must be at least 8
  * @param burstLen    forwarded to [[AXI4SlaveModule]]
  */
class AXI4RAM
(
  address: Seq[AddressSet],
  memByte: Long,
  useBlackBox: Boolean = false,
  executable: Boolean = true,
  beatBytes: Int = 8,
  burstLen: Int = 16
)(implicit p: Parameters)
  extends AXI4SlaveModule(address, executable, beatBytes, burstLen)
{

  override lazy val module = new AXI4SlaveModuleImp(this){

    // Checked before any derived constant is computed: with beatBytes < 8,
    // `split` would be 0 and `memByte / split` would die with an
    // ArithmeticException instead of reporting this requirement.
    require(beatBytes >= 8)

    // Number of 64-bit lanes in one beat, and the share of memByte per lane.
    val split = beatBytes / 8
    val bankByte = memByte / split
    val offsetBits = log2Up(memByte)
    // Built as a BigInt on purpose: memByte is a Long, so offsetBits can be 32
    // or more, and an Int shift (`1 << offsetBits`) only uses the low five bits
    // of the shift count, which silently produces a wrong mask for large RAMs.
    val offsetMask = (BigInt(1) << offsetBits) - 1

    // Beat index of a byte address: keep the low offsetBits bits, then drop
    // the byte offset within a beat.
    def index(addr: UInt) = ((addr & offsetMask.U) >> log2Ceil(beatBytes)).asUInt()

    // True when a beat index lies inside the modelled memory.
    def inRange(idx: UInt) = idx < (memByte / beatBytes).U

    // NOTE(review): waddr, raddr, writeBeatCnt, readBeatCnt and `in` are not
    // defined in this class; presumably inherited from AXI4SlaveModuleImp.
    val wIdx = index(waddr) + writeBeatCnt
    val rIdx = index(raddr) + readBeatCnt
    // A write beat is only committed when its index is inside the memory.
    val wen = in.w.fire() && inRange(wIdx)

    val rdata = if (useBlackBox) {
      // One helper per 64-bit lane; lane i of beat b uses helper index b * split + i.
      val mems = (0 until split).map {_ => Module(new RAMHelper(bankByte))}
      mems.zipWithIndex.foreach { case (mem, i) =>
        mem.io.clk := clock
        mem.io.rIdx := (rIdx << log2Up(split)) + i.U
        mem.io.wIdx := (wIdx << log2Up(split)) + i.U
        mem.io.wdata := in.w.bits.data((i + 1) * 64 - 1, i * 64)
        // NOTE(review): MaskExpand presumably widens each strobe bit to a
        // full byte of mask -- confirm against utils.MaskExpand.
        mem.io.wmask := MaskExpand(in.w.bits.strb((i + 1) * 8 - 1, i * 8))
        mem.io.wen := wen
      }
      // Cat places its first argument in the most significant bits, so the
      // reversal puts lane 0 in the least significant 64 bits.
      Cat(mems.map(_.io.rdata).reverse)
    } else {
      // One entry per beat, stored as individual bytes so the write strobe
      // can be applied as a per-byte mask.
      val mem = Mem(memByte / beatBytes, Vec(beatBytes, UInt(8.W)))

      val wdata = VecInit.tabulate(beatBytes) { i => in.w.bits.data(8 * (i + 1) - 1, 8 * i) }
      when(wen) {
        mem.write(wIdx, wdata, in.w.bits.strb.asBools())
      }

      // Byte 0 ends up in the least significant bits of the read data.
      Cat(mem.read(rIdx).reverse)
    }
    in.r.bits.data := rdata
  }
}
76