/***************************************************************************************
  * Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
  * Copyright (c) 2020-2021 Peng Cheng Laboratory
  *
  * XiangShan is licensed under Mulan PSL v2.
  * You can use this software according to the terms and conditions of the Mulan PSL v2.
  * You may obtain a copy of Mulan PSL v2 at:
  *          http://license.coscl.org.cn/MulanPSL2
  *
  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
  *
  * See the Mulan PSL v2 for more details.
  ***************************************************************************************/

package xiangshan.cache.mmu

import chipsalliance.rocketchip.config.Parameters
import chisel3._
import chisel3.experimental.chiselName
import chisel3.util._
import utils._
import utility._
import freechips.rocketchip.formal.PropertyClass

import scala.math.min

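/** Asynchronously-read data module split into `numBanks` banks. Each of the `numRead`
  * ports returns `numDup` duplicated copies of the read data, each with its own
  * registers, to cut fan-out. The bank index and bank data are registered, so read
  * results arrive one cycle after the request; a same-cycle write to the read address
  * is bypassed into the read result.
  */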
class BankedAsyncDataModuleTemplateWithDup[T <: Data](
  gen: T,
  numEntries: Int,
  numRead: Int,
  numDup: Int,
  numBanks: Int
) extends Module {
  val io = IO(new Bundle {
    val raddr = Vec(numRead, Input(UInt(log2Ceil(numEntries).W)))
    val rdata = Vec(numRead, Vec(numDup, Output(gen)))
    val wen   = Input(Bool())
    val waddr = Input(UInt(log2Ceil(numEntries).W))
    val wdata = Input(gen)
  })
  require(numBanks > 1)
  require(numEntries > numBanks)

  val numBankEntries = numEntries / numBanks
  def bankOffset(address: UInt): UInt = {
    address(log2Ceil(numBankEntries) - 1, 0)
  }

  def bankIndex(address: UInt): UInt = {
    address(log2Ceil(numEntries) - 1, log2Ceil(numBankEntries))
  }
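  // For example, with numEntries = 128 and numBanks = 8, numBankEntries = 16:
  // address(3, 0) selects the entry within a bank and address(6, 4) selects the bank.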

  val dataBanks = Seq.tabulate(numBanks)(i => {
    val bankEntries = if (i < numBanks - 1) numBankEntries else (numEntries - (i * numBankEntries))
    Mem(bankEntries, gen)
  })

  // asynchronous read; the output is registered, so read data arrives one cycle after the request
  for (i <- 0 until numRead) {
    val data_read = Reg(Vec(numDup, Vec(numBanks, gen)))
    val bank_index = Reg(Vec(numDup, UInt(numBanks.W)))
    for (j <- 0 until numDup) {
      bank_index(j) := UIntToOH(bankIndex(io.raddr(i)))
      for (k <- 0 until numBanks) {
        // bypass a same-cycle write to the same address
        data_read(j)(k) := Mux(io.wen && (io.waddr === io.raddr(i)),
          io.wdata, dataBanks(k)(bankOffset(io.raddr(i))))
      }
    }
    // one cycle later: select the registered data of the requested bank
    for (j <- 0 until numDup) {
      io.rdata(i)(j) := Mux1H(bank_index(j), data_read(j))
    }
  }

  // write
  for (i <- 0 until numBanks) {
    when (io.wen && (bankIndex(io.waddr) === i.U)) {
      dataBanks(i)(bankOffset(io.waddr)) := io.wdata
    }
  }
}
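
// A minimal instantiation sketch (hypothetical parameters, for illustration only):
//   val dataModule = Module(new BankedAsyncDataModuleTemplateWithDup(
//     gen = new TlbEntry(pageNormal = true, pageSuper = false),
//     numEntries = 128, numRead = 2, numDup = 4, numBanks = 8))
//   dataModule.io.raddr(0) := readIndex  // readIndex is assumed to be driven elsewhere
//   val copies = dataModule.io.rdata(0)  // numDup registered copies, valid one cycle later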
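/** Fully-associative TLB storage: every read request is compared against all `nWays`
  * entries in parallel. Read requests are always accepted (`ready` is tied high) and
  * responses arrive one cycle after the request.
  */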
@chiselName
class TLBFA(
  parentName: String,
  ports: Int,
  nSets: Int,
  nWays: Int,
  saveLevel: Boolean = false,
  normalPage: Boolean,
  superPage: Boolean
)(implicit p: Parameters) extends TlbModule with HasPerfEvents {

  val io = IO(new TlbStorageIO(nSets, nWays, ports))
  io.r.req.map(_.ready := true.B)

  val v = RegInit(VecInit(Seq.fill(nWays)(false.B)))
  val entries = Reg(Vec(nWays, new TlbEntry(normalPage, superPage)))
  val g = entries.map(_.perm.g)

  for (i <- 0 until ports) {
    val req = io.r.req(i)
    val resp = io.r.resp(i)
    val access = io.access(i)

    val vpn = req.bits.vpn
    val vpn_reg = RegEnable(vpn, req.fire())
    val vpn_gen_ppn = if (saveLevel) vpn else vpn_reg

    // entries being refilled this cycle are excluded from the hit check
    val refill_mask = Mux(io.w.valid, UIntToOH(io.w.bits.wayIdx), 0.U(nWays.W))
    val hitVec = VecInit((entries.zipWithIndex).zip(v zip refill_mask.asBools).map{ case (e, m) => e._1.hit(vpn, io.csr.satp.asid) && m._1 && !m._2 })

    hitVec.suggestName("hitVec")

    val hitVecReg = RegEnable(hitVec, req.fire())
    assert(!resp.valid || (PopCount(hitVecReg) === 0.U || PopCount(hitVecReg) === 1.U), s"${parentName} fa port${i} multi-hit")

    resp.valid := RegNext(req.valid)
    resp.bits.hit := Cat(hitVecReg).orR
    if (nWays == 1) {
      resp.bits.ppn(0) := entries(0).genPPN(saveLevel, req.valid)(vpn_gen_ppn)
      resp.bits.perm(0) := entries(0).perm
    } else {
      resp.bits.ppn(0) := ParallelMux(hitVecReg zip entries.map(_.genPPN(saveLevel, req.valid)(vpn_gen_ppn)))
      resp.bits.perm(0) := ParallelMux(hitVecReg zip entries.map(_.perm))
    }

    access.sets := get_set_idx(vpn_reg, nSets) // unused; kept for the common interface
    access.touch_ways.valid := resp.valid && Cat(hitVecReg).orR
    access.touch_ways.bits := OHToUInt(hitVecReg)

    resp.bits.hit.suggestName("hit")
    resp.bits.ppn.suggestName("ppn")
    resp.bits.perm.suggestName("perm")
  }

  when (io.w.valid) {
    v(io.w.bits.wayIdx) := true.B
    entries(io.w.bits.wayIdx).apply(io.w.bits.data, io.csr.satp.asid, io.w.bits.data_replenish)
  }
  // write assertion: the refilled entry should not duplicate an existing entry
  val w_hit_vec = VecInit(entries.zip(v).map{ case (e, vi) => e.hit(io.w.bits.data.entry.tag, io.csr.satp.asid) && vi })
  XSError(io.w.valid && Cat(w_hit_vec).orR, s"${parentName} refill, duplicate with existing entries")

  // a refill also counts as a touch for the replacement policy
  val refill_vpn_reg = RegNext(io.w.bits.data.entry.tag)
  val refill_wayIdx_reg = RegNext(io.w.bits.wayIdx)
  when (RegNext(io.w.valid)) {
    io.access.map { access =>
      access.sets := get_set_idx(refill_vpn_reg, nSets)
      access.touch_ways.valid := true.B
      access.touch_ways.bits := refill_wayIdx_reg
    }
  }

  val sfence = io.sfence
  val sfence_vpn = sfence.bits.addr.asTypeOf(new VaBundle().cloneType).vpn
  val sfenceHit = entries.map(_.hit(sfence_vpn, sfence.bits.asid))
  val sfenceHit_noasid = entries.map(_.hit(sfence_vpn, sfence.bits.asid, ignoreAsid = true))
  when (io.sfence.valid) {
    when (sfence.bits.rs1) { // rs1 is x0: flush regardless of virtual address
      when (sfence.bits.rs2) { // rs2 is x0: flush regardless of ASID
        // all addresses and all ASIDs
        v.map(_ := false.B)
      }.otherwise {
        // all addresses, but only the given ASID (global entries survive)
        v.zipWithIndex.map{ case (a, i) => a := a & (g(i) | !(entries(i).asid === sfence.bits.asid)) }
      }
    }.otherwise {
      when (sfence.bits.rs2) {
        // specific address, all ASIDs
        v.zipWithIndex.map{ case (a, i) => a := a & !sfenceHit_noasid(i) }
      }.otherwise {
        // specific address and specific ASID (global entries survive)
        v.zipWithIndex.map{ case (a, i) => a := a & !(sfenceHit(i) && !g(i)) }
      }
    }
  }

  // on refill, export the entry about to be overwritten: if it is a normal-page entry,
  // the other storage can capture it through the victim channel
  val victim_idx = io.w.bits.wayIdx
  io.victim.out.valid := v(victim_idx) && io.w.valid && entries(victim_idx).is_normalentry()
  io.victim.out.bits.entry := ns_to_n(entries(victim_idx))

  // shrink a (possibly super-page-capable) entry into a normal-page-only entry
  def ns_to_n(ns: TlbEntry): TlbEntry = {
    val n = Wire(new TlbEntry(pageNormal = true, pageSuper = false))
    n.perm := ns.perm
    n.ppn := ns.ppn
    n.tag := ns.tag
    n.asid := ns.asid
    n
  }

  XSPerfAccumulate(s"access", io.r.resp.map(_.valid.asUInt()).fold(0.U)(_ + _))
  XSPerfAccumulate(s"hit", io.r.resp.map(a => a.valid && a.bits.hit).fold(0.U)(_.asUInt() + _.asUInt()))

  for (i <- 0 until nWays) {
    XSPerfAccumulate(s"access${i}", io.r.resp.zip(io.access.map(acc => UIntToOH(acc.touch_ways.bits))).map{ case (a, b) =>
      a.valid && a.bits.hit && b(i)}.fold(0.U)(_.asUInt() + _.asUInt()))
  }
  for (i <- 0 until nWays) {
    XSPerfAccumulate(s"refill${i}", io.w.valid && io.w.bits.wayIdx === i.U)
  }

  val perfEvents = Seq(
    ("tlbstore_access", io.r.resp.map(_.valid.asUInt()).fold(0.U)(_ + _)                            ),
    ("tlbstore_hit   ", io.r.resp.map(a => a.valid && a.bits.hit).fold(0.U)(_.asUInt() + _.asUInt())),
  )
  generatePerfEvent()

  println(s"${parentName} tlb_fa: nSets:${nSets} nWays:${nWays}")
}

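/** Set-associative (in practice direct-mapped, since nWays must be 1) TLB storage for
  * normal pages. Entries live in a banked, asynchronously-read data module, and the
  * valid bits are selected over two cycles for timing (see VPRE_SELECT/VPOST_SELECT).
  */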
@chiselName
class TLBSA(
  parentName: String,
  ports: Int,
  nDups: Int,
  nSets: Int,
  nWays: Int,
  normalPage: Boolean,
  superPage: Boolean
)(implicit p: Parameters) extends TlbModule {
  require(!superPage, "super page should use reg/fa")
  require(nWays == 1, "nWays larger than 1 causes bad timing")

  // timing optimization: split the valid-bit selection across two cycles
  val VPRE_SELECT = min(8, nSets)
  val VPOST_SELECT = nSets / VPRE_SELECT
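  // For example, with nSets = 64: VPRE_SELECT = 8 and VPOST_SELECT = 8. In the first
  // cycle, one VPOST_SELECT-entry slice of v is selected (by the upper set-index bits)
  // and registered; in the second cycle, the valid bit is picked from that slice
  // (by the lower set-index bits).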
  val nBanks = 8

  val io = IO(new TlbStorageIO(nSets, nWays, ports, nDups))

  io.r.req.map(_.ready := true.B)
  val v = RegInit(VecInit(Seq.fill(nSets)(VecInit(Seq.fill(nWays)(false.B)))))
  val entries = Module(new BankedAsyncDataModuleTemplateWithDup(new TlbEntry(normalPage, superPage), nSets, ports, nDups, nBanks))

  for (i <- 0 until ports) { // read ports; the read data is duplicated nDups times per port
    val req = io.r.req(i)
    val resp = io.r.resp(i)
    val access = io.access(i)

    val vpn = req.bits.vpn
    val vpn_reg = RegEnable(vpn, req.fire())

    val ridx = get_set_idx(vpn, nSets)
    val v_resize = v.asTypeOf(Vec(VPRE_SELECT, Vec(VPOST_SELECT, UInt(nWays.W))))
    val vidx_resize = RegNext(v_resize(get_set_idx(drop_set_idx(vpn, VPOST_SELECT), VPRE_SELECT)))
    val vidx = vidx_resize(get_set_idx(vpn_reg, VPOST_SELECT)).asBools.map(_ && RegNext(req.fire()))
    val vidx_bypass = RegNext((entries.io.waddr === ridx) && entries.io.wen)
    entries.io.raddr(i) := ridx

    val data = entries.io.rdata(i)
    val hit = data(0).hit(vpn_reg, io.csr.satp.asid, nSets) && (vidx(0) || vidx_bypass)
    resp.bits.hit := hit
    for (d <- 0 until nDups) {
      resp.bits.ppn(d) := data(d).genPPN()(vpn_reg)
      resp.bits.perm(d) := data(d).perm
    }

    resp.valid := RegNext(req.valid)
    resp.bits.hit.suggestName("hit")
    resp.bits.ppn.suggestName("ppn")
    resp.bits.perm.suggestName("perm")

    access.sets := get_set_idx(vpn_reg, nSets) // unused; kept for the common interface
    access.touch_ways.valid := resp.valid && hit
    access.touch_ways.bits := 1.U // TODO: no replacer is needed when nWays is 1
  }

  // there must be exactly one write port; otherwise the bypass check above would be wrong
  entries.io.wen := io.w.valid || io.victim.in.valid
  entries.io.waddr := Mux(io.w.valid,
    get_set_idx(io.w.bits.data.entry.tag, nSets),
    get_set_idx(io.victim.in.bits.entry.tag, nSets))
  entries.io.wdata := Mux(io.w.valid,
    (Wire(new TlbEntry(normalPage, superPage)).apply(io.w.bits.data, io.csr.satp.asid, io.w.bits.data_replenish)),
    io.victim.in.bits.entry)

  when (io.victim.in.valid) {
    v(get_set_idx(io.victim.in.bits.entry.tag, nSets))(io.w.bits.wayIdx) := true.B // nWays == 1, so wayIdx is effectively 0
  }
  // w has higher priority than victim: the later assignment wins
  when (io.w.valid) {
    v(get_set_idx(io.w.bits.data.entry.tag, nSets))(io.w.bits.wayIdx) := true.B
  }

  // a refill (from w or victim) also counts as a touch for the replacement policy
  val refill_vpn_reg = RegNext(Mux(io.victim.in.valid, io.victim.in.bits.entry.tag, io.w.bits.data.entry.tag))
  val refill_wayIdx_reg = RegNext(io.w.bits.wayIdx)
  when (RegNext(io.w.valid || io.victim.in.valid)) {
    io.access.map { access =>
      access.sets := get_set_idx(refill_vpn_reg, nSets)
      access.touch_ways.valid := true.B
      access.touch_ways.bits := refill_wayIdx_reg
    }
  }

  val sfence = io.sfence
  val sfence_vpn = sfence.bits.addr.asTypeOf(new VaBundle().cloneType).vpn
  when (io.sfence.valid) {
    when (sfence.bits.rs1) { // rs1 is x0: flush regardless of virtual address
      v.map(a => a.map(b => b := false.B))
    }.otherwise {
      // specific address, any ASID
      v(get_set_idx(sfence_vpn, nSets)).map(_ := false.B)
    }
  }

  // the set-associative storage never exports victims
  io.victim.out := DontCare
  io.victim.out.valid := false.B

  XSPerfAccumulate(s"access", io.r.req.map(_.valid.asUInt()).fold(0.U)(_ + _))
  XSPerfAccumulate(s"hit", io.r.resp.map(a => a.valid && a.bits.hit).fold(0.U)(_.asUInt() + _.asUInt()))

  for (i <- 0 until nSets) {
    XSPerfAccumulate(s"refill${i}", (io.w.valid || io.victim.in.valid) &&
        (Mux(io.w.valid, get_set_idx(io.w.bits.data.entry.tag, nSets), get_set_idx(io.victim.in.bits.entry.tag, nSets)) === i.U)
      )
  }

  for (i <- 0 until nSets) {
    XSPerfAccumulate(s"hit${i}", io.r.resp.map(a => a.valid & a.bits.hit)
      .zip(io.r.req.map(a => RegNext(get_set_idx(a.bits.vpn, nSets)) === i.U))
      .map{ a => (a._1 && a._2).asUInt() }
      .fold(0.U)(_ + _)
    )
  }

  for (i <- 0 until nSets) {
    XSPerfAccumulate(s"access${i}", io.r.resp.map(_.valid)
      .zip(io.r.req.map(a => RegNext(get_set_idx(a.bits.vpn, nSets)) === i.U))
      .map{ a => (a._1 && a._2).asUInt() }
      .fold(0.U)(_ + _)
    )
  }

  println(s"${parentName} tlb_sa: nSets:${nSets} nWays:${nWays}")
}

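/** Factory that builds either a fully-associative (TLBFA) or a set-associative (TLBSA)
  * storage, depending on `associative`, and returns the storage's io bundle.
  */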
object TlbStorage {
  def apply
  (
    parentName: String,
    associative: String,
    ports: Int,
    nDups: Int = 1,
    nSets: Int,
    nWays: Int,
    saveLevel: Boolean = false,
    normalPage: Boolean,
    superPage: Boolean
  )(implicit p: Parameters) = {
    if (associative == "fa") {
      val storage = Module(new TLBFA(parentName, ports, nSets, nWays, saveLevel, normalPage, superPage))
      storage.suggestName(s"${parentName}_fa")
      storage.io
    } else {
      val storage = Module(new TLBSA(parentName, ports, nDups, nSets, nWays, normalPage, superPage))
      storage.suggestName(s"${parentName}_sa")
      storage.io
    }
  }
}

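/** Combines a normal-page storage and a super-page storage behind one read/write
  * interface. Reads are issued to both and merged (a super-page hit wins); refills are
  * steered by page level, and the two victim channels are cross-connected so a
  * normal-page entry evicted from the super-page storage can be recaptured.
  */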
class TlbStorageWrapper(ports: Int, q: TLBParameters, nDups: Int = 1)(implicit p: Parameters) extends TlbModule {
  val io = IO(new TlbStorageWrapperIO(ports, q, nDups))

  // TODO: wrap normal-page and super-page handling together; clean up the duplicated declaration and refill code
  val normalPage = TlbStorage(
    parentName = q.name + "_np_storage",
    associative = q.normalAssociative,
    ports = ports,
    nDups = nDups,
    nSets = q.normalNSets,
    nWays = q.normalNWays,
    saveLevel = q.saveLevel,
    normalPage = true,
    superPage = false
  )
  val superPage = TlbStorage(
    parentName = q.name + "_sp_storage",
    associative = q.superAssociative,
    ports = ports,
    nSets = q.superNSets,
    nWays = q.superNWays,
    normalPage = q.normalAsVictim,
    superPage = true,
  )

  for (i <- 0 until ports) {
    normalPage.r_req_apply(
      valid = io.r.req(i).valid,
      vpn = io.r.req(i).bits.vpn,
      i = i
    )
    superPage.r_req_apply(
      valid = io.r.req(i).valid,
      vpn = io.r.req(i).bits.vpn,
      i = i
    )
  }

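  // merge the two responses per port: a super-page hit takes priority over a
  // normal-page hit, and the assertion below checks that both never hit at once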
  for (i <- 0 until ports) {
    val nq = normalPage.r.req(i)
    val np = normalPage.r.resp(i)
    val sq = superPage.r.req(i)
    val sp = superPage.r.resp(i)
    val rq = io.r.req(i)
    val rp = io.r.resp(i)
    rq.ready := nq.ready && sq.ready // actually, not used
    rp.valid := np.valid && sp.valid // actually, not used
    rp.bits.hit := np.bits.hit || sp.bits.hit
    for (d <- 0 until nDups) {
      rp.bits.ppn(d) := Mux(sp.bits.hit, sp.bits.ppn(0), np.bits.ppn(d))
      rp.bits.perm(d) := Mux(sp.bits.hit, sp.bits.perm(0), np.bits.perm(d))
    }
    rp.bits.super_hit := sp.bits.hit
    rp.bits.super_ppn := sp.bits.ppn(0)
    rp.bits.spm := np.bits.perm(0).pm
    assert(!np.bits.hit || !sp.bits.hit || !rp.valid, s"${q.name} storage ports${i} normal and super multi-hit")
  }

  // cross-connect the victim channels: an entry evicted from one storage is offered to the other
  normalPage.victim.in <> superPage.victim.out
  normalPage.victim.out <> superPage.victim.in
  normalPage.sfence <> io.sfence
  superPage.sfence <> io.sfence
  normalPage.csr <> io.csr
  superPage.csr <> io.csr

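  // choose the refill way: either an external replacer supplies it (q.outReplace), or a
  // local ReplacementPolicy is instantiated here and fed the per-port touch information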
  val normal_refill_idx = if (q.outReplace) {
    io.replace.normalPage.access <> normalPage.access
    io.replace.normalPage.chosen_set := get_set_idx(io.w.bits.data.entry.tag, q.normalNSets)
    io.replace.normalPage.refillIdx
  } else if (q.normalAssociative == "fa") {
    val re = ReplacementPolicy.fromString(q.normalReplacer, q.normalNWays)
    re.access(normalPage.access.map(_.touch_ways)) // normalhitVecVec.zipWithIndex.map{ case (hv, i) => get_access(hv, validRegVec(i))})
    re.way
  } else { // set-associative && plru
    val re = ReplacementPolicy.fromString(q.normalReplacer, q.normalNSets, q.normalNWays)
    re.access(normalPage.access.map(_.sets), normalPage.access.map(_.touch_ways))
    re.way(get_set_idx(io.w.bits.data.entry.tag, q.normalNSets))
  }

  val super_refill_idx = if (q.outReplace) {
    io.replace.superPage.access <> superPage.access
    io.replace.superPage.chosen_set := DontCare
    io.replace.superPage.refillIdx
  } else {
    val re = ReplacementPolicy.fromString(q.superReplacer, q.superNWays)
    re.access(superPage.access.map(_.touch_ways))
    re.way
  }

  // steer the refill by page level: level 2 is a leaf (normal) page, anything else is a super page
  normalPage.w_apply(
    valid = { if (q.normalAsVictim) false.B
      else io.w.valid && io.w.bits.data.entry.level.get === 2.U },
    wayIdx = normal_refill_idx,
    data = io.w.bits.data,
    data_replenish = io.w.bits.data_replenish
  )
  superPage.w_apply(
    valid = { if (q.normalAsVictim) io.w.valid
      else io.w.valid && io.w.bits.data.entry.level.get =/= 2.U },
    wayIdx = super_refill_idx,
    data = io.w.bits.data,
    data_replenish = io.w.bits.data_replenish
  )

  // replacement helper: convert a one-hot touch vector into a Valid way index
  def get_access(one_hot: UInt, valid: Bool): Valid[UInt] = {
    val res = Wire(Valid(UInt(log2Up(one_hot.getWidth).W)))
    res.valid := Cat(one_hot).orR && valid
    res.bits := OHToUInt(one_hot)
    res
  }
}
477