/***************************************************************************************
* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
* Copyright (c) 2020-2021 Peng Cheng Laboratory
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/

// See LICENSE.SiFive for license details.

package xiangshan.backend.fu

import chipsalliance.rocketchip.config.Parameters
import chisel3._
import chisel3.internal.naming.chiselName
import chisel3.util._
import utils.MaskedRegMap.WritableMask
import xiangshan._
import xiangshan.backend.fu.util.HasCSRConst
import utils._
import xiangshan.cache.mmu.{TlbCmd, TlbExceptionBundle}

trait PMPConst {
  val PMPOffBits = 2 // minimum 4 bytes
}

abstract class PMPBundle(implicit p: Parameters) extends XSBundle with PMPConst {
  val CoarserGrain: Boolean = PlatformGrain > PMPOffBits
}

abstract class PMPModule(implicit p: Parameters) extends XSModule with PMPConst with HasCSRConst

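/** one 8-bit pmpcfg field, laid out as in the RISC-V privileged spec:
  *   bit 7 | bits 6:5 | bits 4:3 | bit 2 | bit 1 | bit 0
  *   L     | reserved | A        | X     | W     | R
  * where A selects the matching mode: 0 = OFF, 1 = TOR, 2 = NA4, 3 = NAPOT
  */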
@chiselName
class PMPConfig(implicit p: Parameters) extends PMPBundle {
  val l = Bool()
  val c = Bool() // res(1), unused in pmp
  val atomic = Bool() // res(0), unused in pmp
  val a = UInt(2.W)
  val x = Bool()
  val w = Bool()
  val r = Bool()

  def res: UInt = Cat(c, atomic) // in pmp, unused
  def off = a === 0.U
  def tor = a === 1.U
  def na4 = { if (CoarserGrain) false.B else a === 2.U }
  def napot = { if (CoarserGrain) a(1).asBool else a === 3.U }
  def off_tor = !a(1)
  def na4_napot = a(1)

  def locked = l
  def addr_locked: Bool = locked
  // pmpaddr(i) is additionally locked when entry i+1 is locked and in TOR mode,
  // because it then serves as entry i+1's lower bound
  def addr_locked(next: PMPConfig): Bool = locked || (next.locked && next.tor)
}

trait PMPReadWriteMethod extends PMPConst { this: PMPBase =>
  def write_cfg_vec(cfgs: UInt): UInt = {
    val cfgVec = Wire(Vec(cfgs.getWidth/8, new PMPConfig))
    for (i <- cfgVec.indices) {
      val cfg_w_tmp = cfgs((i+1)*8-1, i*8).asUInt.asTypeOf(new PMPConfig)
      cfgVec(i) := cfg_w_tmp
      cfgVec(i).w := cfg_w_tmp.w && cfg_w_tmp.r // W=1/R=0 is reserved, so force W to 0
      if (CoarserGrain) { cfgVec(i).a := Cat(cfg_w_tmp.a(1), cfg_w_tmp.a.orR) } // NA4 is promoted to NAPOT
    }
    cfgVec.asUInt
  }

  def write_cfg_vec(mask: Vec[UInt], addr: Vec[UInt], index: Int)(cfgs: UInt): UInt = {
    val cfgVec = Wire(Vec(cfgs.getWidth/8, new PMPConfig))
    for (i <- cfgVec.indices) {
      val cfg_w_m_tmp = cfgs((i+1)*8-1, i*8).asUInt.asTypeOf(new PMPConfig)
      cfgVec(i) := cfg_w_m_tmp
      cfgVec(i).w := cfg_w_m_tmp.w && cfg_w_m_tmp.r // W=1/R=0 is reserved, so force W to 0
      if (CoarserGrain) { cfgVec(i).a := Cat(cfg_w_m_tmp.a(1), cfg_w_m_tmp.a.orR) } // NA4 is promoted to NAPOT
      when (cfgVec(i).na4_napot) {
        // the napot match mask depends on both cfg and addr, so refresh it on cfg writes
        mask(index + i) := new PMPEntry().match_mask(cfgVec(i), addr(index + i))
      }
    }
    cfgVec.asUInt
  }

  /** In general, the PMP grain is 2**{G+2} bytes. When G >= 1, NA4 is not selectable.
   * When G >= 2 and cfg.a(1) is set (NAPOT mode), the bits addr(G-2, 0) read as all ones.
   * When G >= 1 and cfg.a(1) is clear (OFF or TOR mode), the bits addr(G-1, 0) read as all zeros.
   * The low OffBits are dropped.
   */
  def read_addr(): UInt = {
    read_addr(cfg)(addr)
  }

  def read_addr(cfg: PMPConfig)(addr: UInt): UInt = {
    val G = PlatformGrain - PMPOffBits
    require(G >= 0)
    if (G == 0) {
      addr
    } else if (G >= 2) {
      Mux(cfg.na4_napot, set_low_bits(addr, G-1), clear_low_bits(addr, G))
    } else { // G is 1
      Mux(cfg.off_tor, clear_low_bits(addr, G), addr)
    }
  }
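
  // A worked example (assuming a 4 KiB platform grain, i.e. PlatformGrain = 12):
  // G = PlatformGrain - PMPOffBits = 10, so a NAPOT/NA4 entry reads addr with
  // bits (G-2, 0) = (8, 0) forced to ones, while an OFF/TOR entry reads addr
  // with bits (G-1, 0) = (9, 0) forced to zeros.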
  /** addr: the stored (internal) address, with the low OffBits dropped.
   * compare_addr: the internal address expanded for comparison.
   * paddr: the external (physical) address.
   */
  def write_addr(next: PMPBase)(paddr: UInt) = {
    Mux(!cfg.addr_locked(next.cfg), paddr, addr)
  }
  def write_addr(paddr: UInt) = {
    Mux(!cfg.addr_locked, paddr, addr)
  }

  /** set the data's low num bits (lsb) to ones */
  def set_low_bits(data: UInt, num: Int): UInt = {
    require(num >= 0)
    data | ((1 << num)-1).U
  }

  /** clear the data's low num bits (lsb) */
  def clear_low_bits(data: UInt, num: Int): UInt = {
    require(num >= 0)
    // use Cat instead of & with a mask to avoid the "signal width" problem
    if (num == 0) { data }
    else { Cat(data(data.getWidth-1, num), 0.U(num.W)) }
  }
}

/** PMPBase for the CSR unit,
  * with only read and write logic
  */
@chiselName
class PMPBase(implicit p: Parameters) extends PMPBundle with PMPReadWriteMethod {
  val cfg = new PMPConfig
  val addr = UInt((PAddrBits - PMPOffBits).W)

  def gen(cfg: PMPConfig, addr: UInt) = {
    require(addr.getWidth == this.addr.getWidth)
    this.cfg := cfg
    this.addr := addr
  }
}

trait PMPMatchMethod extends PMPConst { this: PMPEntry =>
  /** compare_addr is the expanded address used to compare against the input addr */
  def compare_addr: UInt = ((addr << PMPOffBits) & ~(((1 << PlatformGrain) - 1).U(PAddrBits.W))).asUInt

  /** size and maxSize are both log2 sizes.
   * For the dtlb, the max access size is XLEN in bytes, i.e. 8.
   * For the itlb and ptw, the max size may be log2(512)?
   * But we may only need 64 bytes; how do we prevent bugs there?
   * TODO: handle the special case where itlb & ptw & dcache access a wider size than XLEN
   */
  def is_match(paddr: UInt, lgSize: UInt, lgMaxSize: Int, last_pmp: PMPEntry): Bool = {
    Mux(cfg.na4_napot, napotMatch(paddr, lgSize, lgMaxSize),
      Mux(cfg.tor, torMatch(paddr, lgSize, lgMaxSize, last_pmp), false.B))
  }

  /** generate the match mask that helps matching in napot mode */
  def match_mask(paddr: UInt) = {
    val match_mask_addr: UInt = Cat(paddr, cfg.a(0)).asUInt() | (((1 << PlatformGrain) - 1) >> PMPOffBits).U((paddr.getWidth + 1).W)
    Cat(match_mask_addr & ~(match_mask_addr + 1.U), ((1 << PMPOffBits) - 1).U(PMPOffBits.W))
  }

  def match_mask(cfg: PMPConfig, paddr: UInt) = {
    val match_mask_c_addr = Cat(paddr, cfg.a(0)) | (((1 << PlatformGrain) - 1) >> PMPOffBits).U((paddr.getWidth + 1).W)
    Cat(match_mask_c_addr & ~(match_mask_c_addr + 1.U), ((1 << PMPOffBits) - 1).U(PMPOffBits.W))
  }
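
  // A worked NAPOT example (grain effects elided): for a stored pmpaddr of
  // 0b yyy...y0111 (a 64-byte NAPOT region), Cat(paddr, a(0)) gives
  // 0b yyy...y01111, and x & ~(x + 1) keeps exactly the trailing ones,
  // 0b 01111. Appending PMPOffBits ones yields mask = 0b 111111 = 63, so
  // unmaskEqual ignores the low 6 address bits, matching a 64-byte region.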

  /** does the access (paddr, size 2**lgSize) end below compare_addr? */
  def boundMatch(paddr: UInt, lgSize: UInt, lgMaxSize: Int): Bool = {
    if (lgMaxSize <= PlatformGrain) {
      (paddr < compare_addr)
    } else {
      val highLess = (paddr >> lgMaxSize) < (compare_addr >> lgMaxSize)
      val highEqual = (paddr >> lgMaxSize) === (compare_addr >> lgMaxSize)
      val lowLess = (paddr(lgMaxSize-1, 0) | OneHot.UIntToOH1(lgSize, lgMaxSize)) < compare_addr(lgMaxSize-1, 0)
      highLess || (highEqual && lowLess)
    }
  }

  /** as the lower bound: the access's last byte is at or above this entry's compare_addr */
  def lowerBoundMatch(paddr: UInt, lgSize: UInt, lgMaxSize: Int): Bool = {
    !boundMatch(paddr, lgSize, lgMaxSize)
  }

  /** as the upper bound: the access's first byte is below this entry's compare_addr */
  def higherBoundMatch(paddr: UInt, lgMaxSize: Int) = {
    boundMatch(paddr, 0.U, lgMaxSize)
  }
194
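  /** TOR: matches when the access touches [ last.compare_addr, compare_addr ).
   * An access straddling a bound is caught separately by aligned() below.
   */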
  def torMatch(paddr: UInt, lgSize: UInt, lgMaxSize: Int, last_pmp: PMPEntry): Bool = {
    last_pmp.lowerBoundMatch(paddr, lgSize, lgMaxSize) && higherBoundMatch(paddr, lgMaxSize)
  }

  /** equality that ignores the bit positions set in m */
  def unmaskEqual(a: UInt, b: UInt, m: UInt) = {
    (a & ~m) === (b & ~m)
  }

  def napotMatch(paddr: UInt, lgSize: UInt, lgMaxSize: Int) = {
    if (lgMaxSize <= PlatformGrain) {
      unmaskEqual(paddr, compare_addr, mask)
    } else {
      // for accesses wider than the grain, also ignore the bits covered by the
      // access size, so that any byte of the access falling in the region matches
      val lowMask = mask | OneHot.UIntToOH1(lgSize, lgMaxSize)
      val highMatch = unmaskEqual(paddr >> lgMaxSize, compare_addr >> lgMaxSize, mask >> lgMaxSize)
      val lowMatch = unmaskEqual(paddr(lgMaxSize-1, 0), compare_addr(lgMaxSize-1, 0), lowMask(lgMaxSize-1, 0))
      highMatch && lowMatch
    }
  }
213
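  /** does the access fit within a single region, i.e. not straddle a bound?
   * NAPOT: the access-size mask must be contained in the region mask.
   * TOR: neither the lower nor the upper bound may split the access.
   */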
  def aligned(paddr: UInt, lgSize: UInt, lgMaxSize: Int, last: PMPEntry) = {
    if (lgMaxSize <= PlatformGrain) {
      true.B
    } else {
      val lowBitsMask = OneHot.UIntToOH1(lgSize, lgMaxSize)
      val lowerBound = ((paddr >> lgMaxSize) === (last.compare_addr >> lgMaxSize)) &&
        ((~paddr(lgMaxSize-1, 0) & last.compare_addr(lgMaxSize-1, 0)) =/= 0.U)
      val upperBound = ((paddr >> lgMaxSize) === (compare_addr >> lgMaxSize)) &&
        ((compare_addr(lgMaxSize-1, 0) & (paddr(lgMaxSize-1, 0) | lowBitsMask)) =/= 0.U)
      val torAligned = !(lowerBound || upperBound)
      val napotAligned = (lowBitsMask & ~mask(lgMaxSize-1, 0)) === 0.U
      Mux(cfg.na4_napot, napotAligned, torAligned)
    }
  }
}

/** PMPEntry for the copies outside the CSR unit,
  * with one more element, mask, to help napot matching
  * TODO: make mask a field, not a method, for timing optimization
  */
@chiselName
class PMPEntry(implicit p: Parameters) extends PMPBase with PMPMatchMethod {
  val mask = UInt(PAddrBits.W) // helps to match in napot mode

  def write_addr(next: PMPBase, mask: UInt)(paddr: UInt) = {
    mask := Mux(!cfg.addr_locked(next.cfg), match_mask(paddr), mask)
    Mux(!cfg.addr_locked(next.cfg), paddr, addr)
  }

  def write_addr(mask: UInt)(paddr: UInt) = {
    mask := Mux(!cfg.addr_locked, match_mask(paddr), mask)
    Mux(!cfg.addr_locked, paddr, addr)
  }

  def gen(cfg: PMPConfig, addr: UInt, mask: UInt) = {
    require(addr.getWidth == this.addr.getWidth)
    this.cfg := cfg
    this.addr := addr
    this.mask := mask
  }
}

trait PMPMethod extends HasXSParameter with PMPConst { this: XSModule =>
  def pmp_init(): (Vec[UInt], Vec[UInt], Vec[UInt]) = {
    val cfg = WireInit(0.U.asTypeOf(Vec(NumPMP/8, UInt(XLEN.W))))
    val addr = Wire(Vec(NumPMP, UInt((PAddrBits-PMPOffBits).W)))
    val mask = Wire(Vec(NumPMP, UInt(PAddrBits.W)))
    addr := DontCare
    mask := DontCare
    (cfg, addr, mask)
  }

  def pmp_gen_mapping
  (
    init: () => (Vec[UInt], Vec[UInt], Vec[UInt]),
    num: Int = 16,
    cfgBase: Int,
    addrBase: Int,
    entries: Vec[PMPEntry]
  ) = {
    val pmpCfgPerCSR = XLEN / new PMPConfig().getWidth
    def pmpCfgIndex(i: Int) = (XLEN / 32) * (i / pmpCfgPerCSR)
    val init_value = init()
    /** to fit MaskedRegMap's write method, declare the cfgs as merged CSRs and split them per pmp entry */
    val cfgMerged = RegInit(init_value._1) // Vec(num / pmpCfgPerCSR, UInt(XLEN.W))
    val cfgs = WireInit(cfgMerged).asTypeOf(Vec(num, new PMPConfig()))
    val addr = RegInit(init_value._2) // Vec(num, UInt((PAddrBits-PMPOffBits).W))
    val mask = RegInit(init_value._3) // Vec(num, UInt(PAddrBits.W))

    for (i <- entries.indices) {
      entries(i).gen(cfgs(i), addr(i), mask(i))
    }

    val cfg_mapping = (0 until num by pmpCfgPerCSR).map(i => {Map(
      MaskedRegMap(
        addr = cfgBase + pmpCfgIndex(i),
        reg = cfgMerged(i/pmpCfgPerCSR),
        wmask = WritableMask,
        wfn = new PMPBase().write_cfg_vec(mask, addr, i)
      ))
    }).fold(Map())((a, b) => a ++ b) // ugly code; ping me if you have a better way

    val addr_mapping = (0 until num).map(i => {Map(
      MaskedRegMap(
        addr = addrBase + i,
        reg = addr(i),
        wmask = WritableMask,
        wfn = { if (i != num-1) entries(i).write_addr(entries(i+1), mask(i)) else entries(i).write_addr(mask(i)) },
        rmask = WritableMask,
        rfn = new PMPBase().read_addr(entries(i).cfg)
      ))
    }).fold(Map())((a, b) => a ++ b) // ugly code; ping me if you have a better way

    cfg_mapping ++ addr_mapping
  }
}
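
// For example, with XLEN = 64 and num = 16, pmpCfgPerCSR = 8, so the mapping
// above exposes two packed cfg CSRs at cfgBase + 0 and cfgBase + 2 (matching
// the spec: on RV64 only even-numbered pmpcfg registers exist), plus sixteen
// address CSRs at addrBase + 0 .. addrBase + 15.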

@chiselName
class PMP(implicit p: Parameters) extends PMPModule with PMPMethod with PMAMethod {
  val io = IO(new Bundle {
    val distribute_csr = Flipped(new DistributedCSRIO())
    val pmp = Output(Vec(NumPMP, new PMPEntry()))
    val pma = Output(Vec(NumPMA, new PMPEntry()))
  })

  val w = io.distribute_csr.w

  val pmp = Wire(Vec(NumPMP, new PMPEntry()))
  val pma = Wire(Vec(NumPMA, new PMPEntry()))

  val pmpMapping = pmp_gen_mapping(pmp_init, NumPMP, PmpcfgBase, PmpaddrBase, pmp)
  val pmaMapping = pmp_gen_mapping(pma_init, NumPMA, PmacfgBase, PmaaddrBase, pma)
  val mapping = pmpMapping ++ pmaMapping

  val rdata = Wire(UInt(XLEN.W))
  MaskedRegMap.generate(mapping, w.bits.addr, rdata, w.valid, w.bits.data)

  io.pmp := pmp
  io.pma := pma
}

class PMPReqBundle(lgMaxSize: Int = 3)(implicit p: Parameters) extends PMPBundle {
  val addr = Output(UInt(PAddrBits.W))
  val size = Output(UInt(log2Ceil(lgMaxSize+1).W))
  val cmd = Output(TlbCmd())

  override def cloneType = (new PMPReqBundle(lgMaxSize)).asInstanceOf[this.type]
}

class PMPRespBundle(implicit p: Parameters) extends TlbExceptionBundle {
  val mmio = Output(Bool())

  def |(resp: PMPRespBundle): PMPRespBundle = {
    val res = Wire(new PMPRespBundle())
    res.ld := this.ld || resp.ld
    res.st := this.st || resp.st
    res.instr := this.instr || resp.instr
    res.mmio := this.mmio || resp.mmio
    res
  }
}

trait PMPCheckMethod extends HasXSParameter with HasCSRConst { this: PMPChecker =>
  def pmp_check(cmd: UInt, cfg: PMPConfig)(implicit p: Parameters) = {
    val resp = Wire(new PMPRespBundle)
    resp.ld := TlbCmd.isRead(cmd) && !TlbCmd.isAtom(cmd) && !cfg.r
    resp.st := (TlbCmd.isWrite(cmd) || TlbCmd.isAtom(cmd)) && !cfg.w
    resp.instr := TlbCmd.isExec(cmd) && !cfg.x
    resp.mmio := false.B
    resp
  }

  def pmp_match_res(leaveHitMux: Boolean = false, valid: Bool = true.B)(
    addr: UInt,
    size: UInt,
    pmpEntries: Vec[PMPEntry],
    mode: UInt,
    lgMaxSize: Int
  ) = {
    val num = pmpEntries.size
    require(num == NumPMP)

    // M-mode passes through unless an entry is locked
    val passThrough = if (pmpEntries.isEmpty) true.B else (mode > ModeS)
    val pmpDefault = WireInit(0.U.asTypeOf(new PMPEntry()))
    pmpDefault.cfg.r := passThrough
    pmpDefault.cfg.w := passThrough
    pmpDefault.cfg.x := passThrough

    val match_vec = Wire(Vec(num+1, Bool()))
    val cfg_vec = Wire(Vec(num+1, new PMPEntry()))

    pmpEntries.zip(pmpDefault +: pmpEntries.take(num-1)).zipWithIndex.foreach{ case ((pmp, last_pmp), i) =>
      val is_match = pmp.is_match(addr, size, lgMaxSize, last_pmp)
      val ignore = passThrough && !pmp.cfg.l
      val aligned = pmp.aligned(addr, size, lgMaxSize, last_pmp)

      val cur = WireInit(pmp)
      cur.cfg.r := aligned && (pmp.cfg.r || ignore)
      cur.cfg.w := aligned && (pmp.cfg.w || ignore)
      cur.cfg.x := aligned && (pmp.cfg.x || ignore)

      match_vec(i) := is_match
      cfg_vec(i) := cur
    }

    // default entry: if nothing matches, fall through with passThrough permissions
    match_vec(num) := true.B
    cfg_vec(num) := pmpDefault

    // the priority mux picks the lowest-numbered matching entry
    if (leaveHitMux) {
      ParallelPriorityMux(match_vec.map(RegEnable(_, init = false.B, valid)), RegEnable(cfg_vec, valid))
    } else {
      ParallelPriorityMux(match_vec, cfg_vec)
    }
  }
}

@chiselName
class PMPChecker
(
  lgMaxSize: Int = 3,
  sameCycle: Boolean = false,
  leaveHitMux: Boolean = false
)(implicit p: Parameters)
  extends PMPModule
  with PMPCheckMethod
  with PMACheckMethod
{
  val io = IO(new Bundle{
    val env = Input(new Bundle {
      val mode = Input(UInt(2.W))
      val pmp = Input(Vec(NumPMP, new PMPEntry()))
      val pma = Input(Vec(NumPMA, new PMPEntry()))
    })
    val req = Flipped(Valid(new PMPReqBundle(lgMaxSize))) // usage: connect valid to the request's fire signal
    val resp = new PMPRespBundle()
  })
  require(!(leaveHitMux && sameCycle))

  val req = io.req.bits

  val res_pmp = pmp_match_res(leaveHitMux, io.req.valid)(req.addr, req.size, io.env.pmp, io.env.mode, lgMaxSize)
  val res_pma = pma_match_res(leaveHitMux, io.req.valid)(req.addr, req.size, io.env.pma, io.env.mode, lgMaxSize)

  val resp_pmp = pmp_check(req.cmd, res_pmp.cfg)
  val resp_pma = pma_check(req.cmd, res_pma.cfg)
  val resp = resp_pmp | resp_pma

  if (sameCycle || leaveHitMux) {
    io.resp := resp
  } else {
    io.resp := RegEnable(resp, io.req.valid)
  }
}
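
/* A minimal wiring sketch (hypothetical names; the actual users are the TLBs/PTW):
 *
 *   val pmp = Module(new PMP())
 *   val checker = Module(new PMPChecker(lgMaxSize = 3))
 *   pmp.io.distribute_csr <> csrCtrl.distribute_csr    // CSR write broadcast
 *   checker.io.env.mode := priv_mode                   // effective privilege mode
 *   checker.io.env.pmp  := pmp.io.pmp
 *   checker.io.env.pma  := pmp.io.pma
 *   checker.io.req.valid := req_fire                   // fire of the checked access
 *   checker.io.req.bits.addr := req_paddr
 *   checker.io.req.bits.size := req_lgSize
 *   checker.io.req.bits.cmd  := req_cmd                // a TlbCmd value
 *
 * checker.io.resp.{ld, st, instr, mmio} arrive in the same cycle only when
 * sameCycle is set; otherwise they are registered, either at the output or
 * inside the hit mux when leaveHitMux is set.
 */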