package device

import chipsalliance.rocketchip.config.Parameters
import chisel3._
import chisel3.util._
import freechips.rocketchip.diplomacy.{AddressSet, LazyModule, LazyModuleImp, RegionType}
import xiangshan.HasXSParameter
import utils.{MaskExpand}

/** Black-box RAM; being a Chisel `BlackBox`, its implementation is supplied outside Chisel
  * by a module named `RAMHelper`.
  *
  * All index/data/mask ports are `DataBits` wide. [[AXI4RAM]] drives one instance per
  * 64-bit lane of a beat, presumably with `rIdx`/`wIdx` counting 64-bit words and `wmask`
  * being a bit-level mask expanded from byte strobes (NOTE(review): confirm against the
  * external implementation).
  *
  * @param memByte bank size in bytes; not referenced by the Chisel side of this black box
  *                (NOTE(review): confirm whether the external implementation relies on it).
  */
class RAMHelper(memByte: BigInt) extends BlackBox with HasXSParameter {
  val io = IO(new Bundle {
    val clk = Input(Clock())
    val rIdx = Input(UInt(DataBits.W))
    val rdata = Output(UInt(DataBits.W))
    val wIdx = Input(UInt(DataBits.W))
    val wdata = Input(UInt(DataBits.W))
    val wmask = Input(UInt(DataBits.W))
    val wen = Input(Bool())
  })
}

/** AXI4 slave memory, backed either by a Chisel `Mem` or by [[RAMHelper]] black boxes.
  *
  * @param address     address ranges claimed by this slave (forwarded to [[AXI4SlaveModule]])
  * @param memByte     memory capacity in bytes; incoming addresses are masked to the low
  *                    `log2Up(memByte)` bits before indexing
  * @param useBlackBox if true, one [[RAMHelper]] per 64-bit lane of a beat; otherwise a `Mem`
  * @param executable  forwarded to [[AXI4SlaveModule]]
  * @param beatBytes   data-bus width in bytes; must be at least 8 and is assumed to be a
  *                    power of two (`index` divides by it with a shift)
  * @param burstLen    forwarded to [[AXI4SlaveModule]]
  */
class AXI4RAM
(
  address: Seq[AddressSet],
  memByte: Long,
  useBlackBox: Boolean = false,
  executable: Boolean = true,
  beatBytes: Int = 8,
  burstLen: Int = 16
)(implicit p: Parameters)
  extends AXI4SlaveModule(address, executable, beatBytes, burstLen)
{

  override lazy val module = new AXI4SlaveModuleImp(this) {
    // Must run before anything divides by `split`: with beatBytes < 8, split == 0 and
    // `memByte / split` would throw ArithmeticException before this check could report.
    require(beatBytes >= 8, s"AXI4RAM: beatBytes must be >= 8, got $beatBytes")

    // Number of 64-bit lanes per beat (one RAMHelper per lane in black-box mode).
    val split = beatBytes / 8
    val bankByte = memByte / split
    val offsetBits = log2Up(memByte)
    // BigInt, not Int: Int shifts use only the low 5 bits of the distance, so `1 << 32`
    // is 1 and the mask would collapse for memories larger than 2 GiB.
    val offsetMask = (BigInt(1) << offsetBits) - 1

    /** Beat index of `addr`: keep the low `offsetBits` bits, then divide by the beat size. */
    def index(addr: UInt): UInt = ((addr & offsetMask.U) >> log2Ceil(beatBytes)).asUInt()

    /** True when beat index `idx` falls inside the memory. */
    def inRange(idx: UInt): Bool = idx < (memByte / beatBytes).U

    // Burst start address plus the beat counters maintained by AXI4SlaveModuleImp.
    val wIdx = index(waddr) + writeBeatCnt
    val rIdx = index(raddr) + readBeatCnt
    // Writes outside the memory are dropped.
    val wen = in.w.fire() && inRange(wIdx)

    val rdata = if (useBlackBox) {
      // log2Ceil rather than log2Up: log2Up(1) == 1, which doubled every lane index when
      // split == 1 (beatBytes == 8). Relies on split being a power of two.
      val laneShift = log2Ceil(split)
      val mems = (0 until split).map { _ => Module(new RAMHelper(bankByte)) }
      mems.zipWithIndex.foreach { case (mem, i) =>
        mem.io.clk := clock
        mem.io.rIdx := (rIdx << laneShift) + i.U
        mem.io.wIdx := (wIdx << laneShift) + i.U
        mem.io.wdata := in.w.bits.data((i + 1) * 64 - 1, i * 64)
        mem.io.wmask := MaskExpand(in.w.bits.strb((i + 1) * 8 - 1, i * 8))
        mem.io.wen := wen
      }
      // Lane 0 carries the least-significant 64 bits, so it must be last in Cat.
      Cat(mems.map(_.io.rdata).reverse)
    } else {
      // One entry per beat, stored as bytes so the write strobe can mask individual bytes.
      val mem = Mem(memByte / beatBytes, Vec(beatBytes, UInt(8.W)))

      val wdata = VecInit.tabulate(beatBytes) { i => in.w.bits.data(8 * (i + 1) - 1, 8 * i) }
      when(wen) {
        mem.write(wIdx, wdata, in.w.bits.strb.asBools())
      }

      // Byte 0 is least significant, so reverse before concatenating.
      Cat(mem.read(rIdx).reverse)
    }
    in.r.bits.data := rdata
  }
}