xref: /XiangShan/src/main/scala/device/AXI4RAM.scala (revision 4b3d9f67355a9945cd5eca46929b89c130c43c26)
1package device
2
3import chipsalliance.rocketchip.config.Parameters
4import chisel3._
5import chisel3.util._
6import freechips.rocketchip.diplomacy.{AddressSet, LazyModule, LazyModuleImp, RegionType}
7import xiangshan.HasXSParameter
8import utils.{MaskExpand}
9
/** Black-box RAM port pair (one read index/data, one masked write) whose behaviour is
  * supplied by an external model outside this file.
  *
  * Index, data and mask ports are all `DataBits` wide, as provided by [[HasXSParameter]].
  *
  * @param memByte capacity in bytes that the instantiating module assigns to this helper.
  *                NOTE(review): it is not forwarded to the black box (no Verilog parameter
  *                is emitted), so it only records intent at the instantiation site.
  */
class RAMHelper(memByte: BigInt) extends BlackBox with HasXSParameter {
  // Shared port type for every index/data/mask signal; builds a fresh UInt each call.
  private def word: UInt = UInt(DataBits.W)

  val io = IO(new Bundle {
    val clk   = Input(Clock())
    val en    = Input(Bool())
    val rIdx  = Input(word)
    val rdata = Output(word)
    val wIdx  = Input(word)
    val wdata = Input(word)
    val wmask = Input(word)
    val wen   = Input(Bool())
  })
}
22
/** AXI4 slave backed by a RAM of `memByte` bytes.
  *
  * The storage is either a Chisel `Mem` of `beatBytes`-wide words (default) or, when
  * `useBlackBox` is set, one [[RAMHelper]] per 64-bit lane of the data beat. Writes whose
  * beat index falls outside the RAM capacity are dropped; reads are not range-checked.
  *
  * @param address     address ranges this slave responds to
  * @param memByte     RAM capacity in bytes
  * @param useBlackBox use external [[RAMHelper]] models instead of a Chisel `Mem`
  * @param executable  forwarded to [[AXI4SlaveModule]]
  * @param beatBytes   AXI data-bus width in bytes; must be at least 8
  * @param burstLen    forwarded to [[AXI4SlaveModule]]
  */
class AXI4RAM
(
  address: Seq[AddressSet],
  memByte: Long,
  useBlackBox: Boolean = false,
  executable: Boolean = true,
  beatBytes: Int = 8,
  burstLen: Int = 16
)(implicit p: Parameters)
  extends AXI4SlaveModule(address, executable, beatBytes, burstLen)
{

  override lazy val module = new AXI4SlaveModuleImp(this){

    // Must run before `split` is used: with beatBytes < 8, split would be 0 and
    // `memByte / split` would abort elaboration with a bare ArithmeticException.
    require(beatBytes >= 8, s"AXI4RAM requires beatBytes >= 8, got $beatBytes")

    // Number of 64-bit lanes in one data beat (one RAMHelper per lane in black-box mode).
    val split = beatBytes / 8
    val bankByte = memByte / split
    val offsetBits = log2Up(memByte)
    // BigInt, not Int: `1 << offsetBits` wraps (shift distance is taken mod 32) once the
    // RAM reaches 2 GiB, which would corrupt the mask and alias addresses together.
    val offsetMask = (BigInt(1) << offsetBits) - 1

    // Beat index of `addr` inside the RAM; address bits above the RAM size are dropped.
    def index(addr: UInt) = ((addr & offsetMask.U) >> log2Ceil(beatBytes)).asUInt()

    def inRange(idx: UInt) = idx < (memByte / beatBytes).U

    val wIdx = index(waddr) + writeBeatCnt
    val rIdx = index(raddr) + readBeatCnt
    val wen = in.w.fire() && inRange(wIdx)

    val rdata = if (useBlackBox) {
      // log2Ceil rather than log2Up: log2Up(1) == 1, which for a single lane (beatBytes == 8)
      // would double every helper index. For split >= 2 both functions agree.
      val laneShift = log2Ceil(split)
      val mems = (0 until split).map { _ => Module(new RAMHelper(bankByte)) }
      mems.zipWithIndex.foreach { case (mem, i) =>
        mem.io.clk   := clock
        // Helpers are held disabled while the module is in reset.
        mem.io.en    := !reset.asBool()
        // Lanes are interleaved: beat `idx`, lane `i` lives at helper word idx * split + i.
        mem.io.rIdx  := (rIdx << laneShift) + i.U
        mem.io.wIdx  := (wIdx << laneShift) + i.U
        mem.io.wdata := in.w.bits.data((i + 1) * 64 - 1, i * 64)
        // Expand the 8 byte strobes of this lane into a per-bit write mask.
        mem.io.wmask := MaskExpand(in.w.bits.strb((i + 1) * 8 - 1, i * 8))
        mem.io.wen   := wen
      }
      // Lane 0 supplies the least-significant 64 bits of the beat.
      Cat(mems.map(_.io.rdata).reverse)
    } else {
      val mem = Mem(memByte / beatBytes, Vec(beatBytes, UInt(8.W)))

      // Split the write beat into bytes so the AXI strobes can act as a per-byte mask.
      val wdata = VecInit.tabulate(beatBytes) { i => in.w.bits.data(8 * (i + 1) - 1, 8 * i) }
      when(wen) {
        mem.write(wIdx, wdata, in.w.bits.strb.asBools())
      }

      // Byte 0 ends up in the least-significant bits of the returned beat.
      Cat(mem.read(rIdx).reverse)
    }
    in.r.bits.data := rdata
  }
}
78