xref: /XiangShan/src/main/scala/device/AXI4RAM.scala (revision 367512b707c976b7ff3fa2e0a4cf1b35a5c1d3c2)
1package device
2
3import chipsalliance.rocketchip.config.Parameters
4import chisel3._
5import chisel3.util._
6import freechips.rocketchip.diplomacy.{AddressSet, LazyModule, LazyModuleImp, RegionType}
7import xiangshan.HasXSParameter
8import utils.{MaskExpand}
9
/** RAM model instantiated as a Chisel `BlackBox`; its implementation lives outside this
  * Scala source, so the port names below must match that external module exactly.
  *
  * Interface: one read port (`rIdx` in, `rdata` out) and one write port (`wIdx`, `wdata`,
  * `wmask`, `wen`), plus a clock `clk` and an enable `en`. Index, data and mask ports are
  * `DataBits` wide.
  *
  * @param memByte not referenced anywhere in this class body; AXI4RAM passes its per-bank
  *                capacity in bytes here. NOTE(review): confirm whether it is still needed.
  */
class RAMHelper(memByte: BigInt) extends BlackBox with HasXSParameter {
  val io = IO(new Bundle {
    val clk   = Input(Clock())
    // Enable; AXI4RAM drives it low during reset and outside its data states.
    val en    = Input(Bool())
    // Read port. NOTE(review): AXI4RAM appears to address this in 64-bit-word units — confirm against the external model.
    val rIdx  = Input(UInt(DataBits.W))
    val rdata = Output(UInt(DataBits.W))
    // Write port. wmask is a bit-level mask: AXI4RAM builds it with MaskExpand from byte strobes.
    val wIdx  = Input(UInt(DataBits.W))
    val wdata = Input(UInt(DataBits.W))
    val wmask = Input(UInt(DataBits.W))
    val wen   = Input(Bool())
  })
}
22
/** AXI4 slave backed by `memByte` bytes of simulation memory.
  *
  * Storage is a Chisel `Mem` by default. With `useBlackBox`, it is one [[RAMHelper]]
  * instance per 64-bit lane of the data bus.
  *
  * @param address     address sets served by this slave; `address(0).base` maps to the
  *                    first byte of the memory
  * @param memByte     memory capacity in bytes
  * @param useBlackBox use [[RAMHelper]] black boxes instead of a Chisel `Mem`
  * @param executable  forwarded to [[AXI4SlaveModule]]
  * @param beatBytes   data bus width in bytes; must be at least 8
  * @param burstLen    forwarded to [[AXI4SlaveModule]]
  */
class AXI4RAM
(
  address: Seq[AddressSet],
  memByte: Long,
  useBlackBox: Boolean = false,
  executable: Boolean = true,
  beatBytes: Int = 8,
  burstLen: Int = 16
)(implicit p: Parameters)
  extends AXI4SlaveModule(address, executable, beatBytes, burstLen)
{

  override lazy val module = new AXI4SlaveModuleImp(this) {

    // Validate parameters before deriving anything from them: with beatBytes < 8,
    // `split` is 0 and `memByte / split` would throw ArithmeticException before a
    // later check could report the real problem.
    require(address.length >= 1, "AXI4RAM requires at least one AddressSet")
    require(beatBytes >= 8, s"AXI4RAM requires beatBytes >= 8, got $beatBytes")

    // Number of 64-bit lanes in one data beat.
    val split = beatBytes / 8
    val bankByte = memByte / split
    val offsetBits = log2Up(memByte)

    val baseAddress = address(0).base

    /** Beat index of `addr`: the low `offsetBits` bits of its offset from `baseAddress`,
      * divided by `beatBytes`.
      */
    def index(addr: UInt) = ((addr - baseAddress.U)(offsetBits - 1, 0) >> log2Ceil(beatBytes)).asUInt()

    /** Whether beat index `idx` lies within the `memByte / beatBytes` beats of storage. */
    def inRange(idx: UInt) = idx < (memByte / beatBytes).U

    // Start-of-burst beat index plus the base module's beat counter.
    val wIdx = index(waddr) + writeBeatCnt
    val rIdx = index(raddr) + readBeatCnt
    // Writes that fall outside the memory are dropped.
    val wen = in.w.fire() && inRange(wIdx)

    val rdata = if (useBlackBox) {
      // Lane i of beat n is stored at RAMHelper index (n << laneBits) + i.
      // log2Ceil, not log2Up: log2Up(1) == 1, which doubled every index when
      // beatBytes == 8 (split == 1). Both agree for split >= 2.
      val laneBits = log2Ceil(split)
      val mems = (0 until split).map { _ => Module(new RAMHelper(bankByte)) }
      mems.zipWithIndex.foreach { case (mem, i) =>
        mem.io.clk   := clock
        mem.io.en    := !reset.asBool() && ((state === s_rdata) || (state === s_wdata))
        mem.io.rIdx  := (rIdx << laneBits) + i.U
        mem.io.wIdx  := (wIdx << laneBits) + i.U
        mem.io.wdata := in.w.bits.data((i + 1) * 64 - 1, i * 64)
        mem.io.wmask := MaskExpand(in.w.bits.strb((i + 1) * 8 - 1, i * 8))
        mem.io.wen   := wen
      }
      // Cat puts its first argument in the MSBs; reversing places lane 0 in bits 63..0.
      Cat(mems.map(_.io.rdata).reverse)
    } else {
      val mem = Mem(memByte / beatBytes, Vec(beatBytes, UInt(8.W)))

      // Byte i of the write data goes to element i, enabled by strobe bit i.
      val wdata = VecInit.tabulate(beatBytes) { i => in.w.bits.data(8 * (i + 1) - 1, 8 * i) }
      when(wen) {
        mem.write(wIdx, wdata, in.w.bits.strb.asBools())
      }

      Cat(mem.read(rIdx).reverse)
    }
    in.r.bits.data := rdata
  }
}
79