xref: /XiangShan/src/main/scala/xiangshan/backend/fu/PMP.scala (revision a273862e37f1d43bee748f2a6353320a2f52f6f4)
/***************************************************************************************
* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
* Copyright (c) 2020-2021 Peng Cheng Laboratory
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/

package xiangshan.backend.fu

import chipsalliance.rocketchip.config.Parameters
import chisel3._
import chisel3.internal.naming.chiselName
import chisel3.util._
import utils.MaskedRegMap.WritableMask
import xiangshan._
import xiangshan.backend.fu.util.HasCSRConst
import utils._
import xiangshan.cache.mmu.{TlbCmd, TlbExceptionBundle}

trait PMPConst {
  val PMPOffBits = 2 // minimum granule is 4 bytes
}

abstract class PMPBundle(implicit p: Parameters) extends XSBundle with PMPConst {
  val CoarserGrain: Boolean = PlatformGrain > PMPOffBits
}

abstract class PMPModule(implicit p: Parameters) extends XSModule with PMPConst with HasCSRConst

@chiselName
class PMPConfig(implicit p: Parameters) extends PMPBundle {
  val l = Bool()
  val c = Bool() // res(1), unused in PMP
  val atomic = Bool() // res(0), unused in PMP
  val a = UInt(2.W)
  val x = Bool()
  val w = Bool()
  val r = Bool()

  def res: UInt = Cat(c, atomic) // reserved bits, unused in PMP
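  /** Address-matching mode encoded in `a` (RISC-V privileged spec):
   *  0 = OFF (entry disabled), 1 = TOR (top of range),
   *  2 = NA4 (naturally aligned 4-byte region), 3 = NAPOT (naturally aligned power-of-two region).
   *  With a grain coarser than 4 bytes (CoarserGrain), NA4 is not selectable, so a(1) alone
   *  distinguishes NAPOT from OFF/TOR.
   */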
  def off = a === 0.U
  def tor = a === 1.U
  def na4 = { if (CoarserGrain) false.B else a === 2.U }
  def napot = { if (CoarserGrain) a(1).asBool else a === 3.U }
  def off_tor = !a(1)
  def na4_napot = a(1)

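  /** A locked entry (l set) ignores writes to its cfg and addr until reset.
   *  In addition, pmpaddr(i) is not writable when entry i+1 is locked and in TOR mode,
   *  because pmpaddr(i) then serves as entry i+1's lower bound (see addr_locked(next)).
   */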
  def locked = l
  def addr_locked: Bool = locked
  def addr_locked(next: PMPConfig): Bool = locked || (next.locked && next.tor)
}

trait PMPReadWriteMethod extends PMPConst { this: PMPBase =>
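  /** WARL write of a packed pmpcfg CSR (one 8-bit PMPConfig per byte).
   *  The reserved combination R=0/W=1 is not allowed, so w is forced to 0 unless r is also set.
   *  When the grain is coarser than 4 bytes, NA4 is not selectable: Cat(a(1), a.orR) maps a
   *  written NA4 (2) to NAPOT (3) while keeping OFF (0) and TOR (1) unchanged.
   */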
  def write_cfg_vec(cfgs: UInt): UInt = {
    val cfgVec = Wire(Vec(cfgs.getWidth/8, new PMPConfig))
    for (i <- cfgVec.indices) {
      val cfg_w_tmp = cfgs((i+1)*8-1, i*8).asUInt.asTypeOf(new PMPConfig)
      cfgVec(i) := cfg_w_tmp
      cfgVec(i).w := cfg_w_tmp.w && cfg_w_tmp.r
      if (CoarserGrain) { cfgVec(i).a := Cat(cfg_w_tmp.a(1), cfg_w_tmp.a.orR) }
    }
    cfgVec.asUInt
  }

  def write_cfg_vec(mask: Vec[UInt], addr: Vec[UInt], index: Int)(cfgs: UInt): UInt = {
    val cfgVec = Wire(Vec(cfgs.getWidth/8, new PMPConfig))
    for (i <- cfgVec.indices) {
      val cfg_w_m_tmp = cfgs((i+1)*8-1, i*8).asUInt.asTypeOf(new PMPConfig)
      cfgVec(i) := cfg_w_m_tmp
      cfgVec(i).w := cfg_w_m_tmp.w && cfg_w_m_tmp.r
      if (CoarserGrain) { cfgVec(i).a := Cat(cfg_w_m_tmp.a(1), cfg_w_m_tmp.a.orR) }
      when (cfgVec(i).na4_napot) {
        mask(index + i) := new PMPEntry().match_mask(cfgVec(i), addr(index + i))
      }
    }
    cfgVec.asUInt
  }

  /** In general, the PMP grain is 2**{G+2} bytes. When G >= 1, NA4 is not selectable.
   * When G >= 2 and cfg.a(1) is set (NAPOT mode), the bits addr(G-2, 0) read as all ones.
   * When G >= 1 and cfg.a(1) is clear (OFF or TOR mode), the bits addr(G-1, 0) read as all zeros.
   * The low PMPOffBits are already dropped from addr.
   */
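  // For example, assuming PlatformGrain = 12 (4 KiB grain, so G = 10):
  //   a NAPOT entry reads back with addr(8, 0) forced to 1 (set_low_bits(addr, G-1)),
  //   an OFF/TOR entry reads back with addr(9, 0) forced to 0 (clear_low_bits(addr, G)).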
  def read_addr(): UInt = {
    read_addr(cfg)(addr)
  }

  def read_addr(cfg: PMPConfig)(addr: UInt): UInt = {
    val G = PlatformGrain - PMPOffBits
    require(G >= 0)
    if (G == 0) {
      addr
    } else if (G >= 2) {
      Mux(cfg.na4_napot, set_low_bits(addr, G-1), clear_low_bits(addr, G))
    } else { // G is 1
      Mux(cfg.off_tor, clear_low_bits(addr, G), addr)
    }
  }
  /** addr holds the internal address with the low PMPOffBits already dropped.
   * compare_addr is the internal address expanded back for comparison.
   * paddr is the external (physical) address as written by software.
   */
  def write_addr(next: PMPBase)(paddr: UInt) = {
    Mux(!cfg.addr_locked(next.cfg), paddr, addr)
  }
  def write_addr(paddr: UInt) = {
    Mux(!cfg.addr_locked, paddr, addr)
  }

  /** set the data's low num bits (lsb) to 1 */
  def set_low_bits(data: UInt, num: Int): UInt = {
    require(num >= 0)
    data | ((1 << num)-1).U
  }

  /** clear the data's low num bits (lsb) */
  def clear_low_bits(data: UInt, num: Int): UInt = {
    require(num >= 0)
    // use Cat instead of & with mask to avoid "Signal Width" problem
    if (num == 0) { data }
    else { Cat(data(data.getWidth-1, num), 0.U(num.W)) }
  }
}

/** PMPBase for CSR unit
  * with only read and write logic
  */
@chiselName
class PMPBase(implicit p: Parameters) extends PMPBundle with PMPReadWriteMethod {
  val cfg = new PMPConfig
  val addr = UInt((PAddrBits - PMPOffBits).W)

  def gen(cfg: PMPConfig, addr: UInt) = {
    require(addr.getWidth == this.addr.getWidth)
    this.cfg := cfg
    this.addr := addr
  }
}

trait PMPMatchMethod extends PMPConst { this: PMPEntry =>
  /** compare_addr is used to compare with the input addr */
  def compare_addr: UInt = ((addr << PMPOffBits) & ~(((1 << PlatformGrain) - 1).U(PAddrBits.W))).asUInt

  /** lgSize and lgMaxSize are both log2 of a size.
   * For the dtlb, the max size is XLEN in bytes, i.e. 8 bytes.
   * For the itlb and ptw, is the max size log2(512)? We may only need 64 bytes; how do we prevent bugs here?
   * TODO: handle the special case that itlb & ptw & dcache access a wider size than XLEN
   */
  def is_match(paddr: UInt, lgSize: UInt, lgMaxSize: Int, last_pmp: PMPEntry): Bool = {
    Mux(cfg.na4_napot, napotMatch(paddr, lgSize, lgMaxSize),
      Mux(cfg.tor, torMatch(paddr, lgSize, lgMaxSize, last_pmp), false.B))
  }

  /** generate the match mask to help matching in NAPOT mode */
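  // For example, with the minimum 4-byte grain, a NAPOT pmpaddr of the form y...y0111
  // (three trailing ones) encodes a 64-byte region (2^(3+3) bytes): Cat(paddr, a(0)) then
  // has four trailing ones, and the trailing-ones extraction plus the PMPOffBits of ones
  // yields a mask with the low 6 bits set, marking the "don't care" bits of the region.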
  def match_mask(paddr: UInt) = {
    val match_mask_addr: UInt = Cat(paddr, cfg.a(0)).asUInt() | (((1 << PlatformGrain) - 1) >> PMPOffBits).U((paddr.getWidth + 1).W)
    Cat(match_mask_addr & ~(match_mask_addr + 1.U), ((1 << PMPOffBits) - 1).U(PMPOffBits.W))
  }

  def match_mask(cfg: PMPConfig, paddr: UInt) = {
    val match_mask_c_addr = Cat(paddr, cfg.a(0)) | (((1 << PlatformGrain) - 1) >> PMPOffBits).U((paddr.getWidth + 1).W)
    Cat(match_mask_c_addr & ~(match_mask_c_addr + 1.U), ((1 << PMPOffBits) - 1).U(PMPOffBits.W))
  }

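  /** boundMatch checks whether the access (paddr, possibly extended by lgSize up to lgMaxSize)
   *  lies below this entry's compare_addr; it is used to build the lower and upper bound checks
   *  of TOR matching. When the access cannot be wider than the PMP grain, a plain compare of
   *  paddr suffices; otherwise the high and low parts are compared separately.
   */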
  def boundMatch(paddr: UInt, lgSize: UInt, lgMaxSize: Int): Bool = {
    if (lgMaxSize <= PlatformGrain) {
      (paddr < compare_addr)
    } else {
      val highLess = (paddr >> lgMaxSize) < (compare_addr >> lgMaxSize)
      val highEqual = (paddr >> lgMaxSize) === (compare_addr >> lgMaxSize)
      val lowLess = (paddr(lgMaxSize-1, 0) | OneHot.UIntToOH1(lgSize, lgMaxSize)) < compare_addr(lgMaxSize-1, 0)
      highLess || (highEqual && lowLess)
    }
  }

  def lowerBoundMatch(paddr: UInt, lgSize: UInt, lgMaxSize: Int): Bool = {
    !boundMatch(paddr, lgSize, lgMaxSize)
  }

  def higherBoundMatch(paddr: UInt, lgMaxSize: Int) = {
    boundMatch(paddr, 0.U, lgMaxSize)
  }

  def torMatch(paddr: UInt, lgSize: UInt, lgMaxSize: Int, last_pmp: PMPEntry): Bool = {
    last_pmp.lowerBoundMatch(paddr, lgSize, lgMaxSize) && higherBoundMatch(paddr, lgMaxSize)
  }

  def unmaskEqual(a: UInt, b: UInt, m: UInt) = {
    (a & ~m) === (b & ~m)
  }

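  /** NAPOT matching: compare paddr against compare_addr while ignoring the bits set in mask
   *  (the "don't care" bits of the naturally aligned power-of-two region). When the access may
   *  be wider than the grain, the low bits covered by lgSize are also ignored.
   */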
  def napotMatch(paddr: UInt, lgSize: UInt, lgMaxSize: Int) = {
    if (lgMaxSize <= PlatformGrain) {
      unmaskEqual(paddr, compare_addr, mask)
    } else {
      val lowMask = mask | OneHot.UIntToOH1(lgSize, lgMaxSize)
      val highMatch = unmaskEqual(paddr >> lgMaxSize, compare_addr >> lgMaxSize, mask >> lgMaxSize)
      val lowMatch = unmaskEqual(paddr(lgMaxSize-1, 0), compare_addr(lgMaxSize-1, 0), lowMask(lgMaxSize-1, 0))
      highMatch && lowMatch
    }
  }

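  /** aligned checks that the access does not straddle a region boundary when it may be wider
   *  than the PMP grain: for TOR, the access must not cross the lower or upper bound; for
   *  NA4/NAPOT, the access size must not exceed the region size. If the access can never be
   *  wider than the grain, this always holds.
   */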
  def aligned(paddr: UInt, lgSize: UInt, lgMaxSize: Int, last: PMPEntry) = {
    if (lgMaxSize <= PlatformGrain) {
      true.B
    } else {
      val lowBitsMask = OneHot.UIntToOH1(lgSize, lgMaxSize)
      val lowerBound = ((paddr >> lgMaxSize) === (last.compare_addr >> lgMaxSize)) &&
        ((~paddr(lgMaxSize-1, 0) & last.compare_addr(lgMaxSize-1, 0)) =/= 0.U)
      val upperBound = ((paddr >> lgMaxSize) === (compare_addr >> lgMaxSize)) &&
        ((compare_addr(lgMaxSize-1, 0) & (paddr(lgMaxSize-1, 0) | lowBitsMask)) =/= 0.U)
      val torAligned = !(lowerBound || upperBound)
      val napotAligned = (lowBitsMask & ~mask(lgMaxSize-1, 0)) === 0.U
      Mux(cfg.na4_napot, napotAligned, torAligned)
    }
  }
}

/** PMPEntry for the PMP copies outside the CSR unit,
  * with one more element, mask, to help NAPOT matching.
  * TODO: make mask an element, not a method, for timing optimization
  */
@chiselName
class PMPEntry(implicit p: Parameters) extends PMPBase with PMPMatchMethod {
  val mask = UInt(PAddrBits.W) // help to match in napot

  def write_addr(next: PMPBase, mask: UInt)(paddr: UInt) = {
    mask := Mux(!cfg.addr_locked(next.cfg), match_mask(paddr), mask)
    Mux(!cfg.addr_locked(next.cfg), paddr, addr)
  }

  def write_addr(mask: UInt)(paddr: UInt) = {
    mask := Mux(!cfg.addr_locked, match_mask(paddr), mask)
    Mux(!cfg.addr_locked, paddr, addr)
  }

  def gen(cfg: PMPConfig, addr: UInt, mask: UInt) = {
    require(addr.getWidth == this.addr.getWidth)
    this.cfg := cfg
    this.addr := addr
    this.mask := mask
  }
}

trait PMPMethod extends HasXSParameter with PMPConst { this: XSModule =>
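  /** Reset state of the PMP CSRs: all cfg fields are zero (A = OFF, no permissions, unlocked),
   *  so no entry matches after reset; addr and mask are left uninitialized (DontCare).
   */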
  def pmp_init(): (Vec[UInt], Vec[UInt], Vec[UInt]) = {
    val cfg = WireInit(0.U.asTypeOf(Vec(NumPMP/8, UInt(XLEN.W))))
    val addr = Wire(Vec(NumPMP, UInt((PAddrBits-PMPOffBits).W)))
    val mask = Wire(Vec(NumPMP, UInt(PAddrBits.W)))
    addr := DontCare
    mask := DontCare
    (cfg, addr, mask)
  }

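  /** Generate the MaskedRegMap entries for a group of PMP (or PMA) CSRs.
   *  Each pmpcfg CSR packs XLEN/8 8-bit configs; pmpCfgIndex spaces the CSR addresses so that
   *  on RV64 only even pmpcfg indices are used (pmpcfg0, pmpcfg2, ...), as the spec requires.
   *  Address CSRs are read through read_addr (grain-dependent WARL read) and written through
   *  write_addr, which also refreshes the NAPOT match mask and honors locking.
   */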
  def pmp_gen_mapping
  (
    init: () => (Vec[UInt], Vec[UInt], Vec[UInt]),
    num: Int = 16,
    cfgBase: Int,
    addrBase: Int,
    entries: Vec[PMPEntry]
  ) = {
    val pmpCfgPerCSR = XLEN / new PMPConfig().getWidth
    def pmpCfgIndex(i: Int) = (XLEN / 32) * (i / pmpCfgPerCSR)
    val init_value = init()
    /** to fit MaskedRegMap's write, declare the cfgs as merged CSRs and split them per PMP entry */
    val cfgMerged = RegInit(init_value._1) // Vec(num / pmpCfgPerCSR, UInt(XLEN.W))
    val cfgs = WireInit(cfgMerged).asTypeOf(Vec(num, new PMPConfig()))
    val addr = RegInit(init_value._2) // Vec(num, UInt((PAddrBits-PMPOffBits).W))
    val mask = RegInit(init_value._3) // Vec(num, UInt(PAddrBits.W))

    for (i <- entries.indices) {
      entries(i).gen(cfgs(i), addr(i), mask(i))
    }

    val cfg_mapping = (0 until num by pmpCfgPerCSR).map(i => {Map(
      MaskedRegMap(
        addr = cfgBase + pmpCfgIndex(i),
        reg = cfgMerged(i/pmpCfgPerCSR),
        wmask = WritableMask,
        wfn = new PMPBase().write_cfg_vec(mask, addr, i)
      ))
    }).fold(Map())((a, b) => a ++ b) // ugly code; improvements welcome

    val addr_mapping = (0 until num).map(i => {Map(
      MaskedRegMap(
        addr = addrBase + i,
        reg = addr(i),
        wmask = WritableMask,
        wfn = { if (i != num-1) entries(i).write_addr(entries(i+1), mask(i)) else entries(i).write_addr(mask(i)) },
        rmask = WritableMask,
        rfn = new PMPBase().read_addr(entries(i).cfg)
      ))
    }).fold(Map())((a, b) => a ++ b) // ugly code; improvements welcome

    cfg_mapping ++ addr_mapping
  }
}


@chiselName
class PMP(implicit p: Parameters) extends PMPModule with PMPMethod with PMAMethod {
  val io = IO(new Bundle {
    val distribute_csr = Flipped(new DistributedCSRIO())
    val pmp = Output(Vec(NumPMP, new PMPEntry()))
    val pma = Output(Vec(NumPMA, new PMPEntry()))
  })

  val w = io.distribute_csr.w

  val pmp = Wire(Vec(NumPMP, new PMPEntry()))
  val pma = Wire(Vec(NumPMA, new PMPEntry()))

  val pmpMapping = pmp_gen_mapping(pmp_init, NumPMP, PmpcfgBase, PmpaddrBase, pmp)
  val pmaMapping = pmp_gen_mapping(pma_init, NumPMA, PmacfgBase, PmaaddrBase, pma)
  val mapping = pmpMapping ++ pmaMapping

  val rdata = Wire(UInt(XLEN.W))
  MaskedRegMap.generate(mapping, w.bits.addr, rdata, w.valid, w.bits.data)

  io.pmp := pmp
  io.pma := pma
}

class PMPReqBundle(lgMaxSize: Int = 3)(implicit p: Parameters) extends PMPBundle {
  val addr = Output(UInt(PAddrBits.W))
  val size = Output(UInt(log2Ceil(lgMaxSize+1).W))
  val cmd = Output(TlbCmd())

  override def cloneType = (new PMPReqBundle(lgMaxSize)).asInstanceOf[this.type]
}

class PMPRespBundle(implicit p: Parameters) extends TlbExceptionBundle {
  val mmio = Output(Bool())

  def |(resp: PMPRespBundle): PMPRespBundle = {
    val res = Wire(new PMPRespBundle())
    res.ld := this.ld || resp.ld
    res.st := this.st || resp.st
    res.instr := this.instr || resp.instr
    res.mmio := this.mmio || resp.mmio
    res
  }
}

trait PMPCheckMethod extends HasXSParameter with HasCSRConst { this: PMPChecker =>
  def pmp_check(cmd: UInt, cfg: PMPConfig)(implicit p: Parameters) = {
    val resp = Wire(new PMPRespBundle)
    resp.ld := TlbCmd.isRead(cmd) && !TlbCmd.isAtom(cmd) && !cfg.r
    resp.st := (TlbCmd.isWrite(cmd) || TlbCmd.isAtom(cmd)) && !cfg.w
    resp.instr := TlbCmd.isExec(cmd) && !cfg.x
    resp.mmio := false.B
    resp
  }

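  /** Priority resolution across the PMP entries: the lowest-numbered matching entry applies,
   *  which the reverse foldLeft implements by letting earlier entries override later ones.
   *  If no entry matches, the default "pmpMinusOne" entry applies: M-mode (mode > ModeS) gets
   *  full permissions, while S/U-mode accesses fail when PMP entries are implemented.
   *  A matching entry whose lock bit is clear is ignored for M-mode (the `ignore` term).
   */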
  def pmp_match_res(addr: UInt, size: UInt, pmpEntries: Vec[PMPEntry], mode: UInt, lgMaxSize: Int) = {
    val num = pmpEntries.size
    require(num == NumPMP)

    val passThrough = if (pmpEntries.isEmpty) true.B else (mode > ModeS)
    val pmpMinusOne = WireInit(0.U.asTypeOf(new PMPEntry()))
    pmpMinusOne.cfg.r := passThrough
    pmpMinusOne.cfg.w := passThrough
    pmpMinusOne.cfg.x := passThrough

    val res = pmpEntries.zip(pmpMinusOne +: pmpEntries.take(num-1)).zipWithIndex
      .reverse.foldLeft(pmpMinusOne) { case (prev, ((pmp, last_pmp), i)) =>
      val is_match = pmp.is_match(addr, size, lgMaxSize, last_pmp)
      val ignore = passThrough && !pmp.cfg.l
      val aligned = pmp.aligned(addr, size, lgMaxSize, last_pmp)

      val cur = WireInit(pmp)
      cur.cfg.r := aligned && (pmp.cfg.r || ignore)
      cur.cfg.w := aligned && (pmp.cfg.w || ignore)
      cur.cfg.x := aligned && (pmp.cfg.x || ignore)

      Mux(is_match, cur, prev)
    }
    res
  }
}


@chiselName
class PMPChecker
(
  lgMaxSize: Int = 3,
  sameCycle: Boolean = false
)(implicit p: Parameters)
  extends PMPModule
  with PMPCheckMethod
  with PMACheckMethod
{
  val io = IO(new Bundle {
    val env = Input(new Bundle {
      val mode = Input(UInt(2.W))
      val pmp = Input(Vec(NumPMP, new PMPEntry()))
      val pma = Input(Vec(NumPMA, new PMPEntry()))
    })
    val req = Flipped(Valid(new PMPReqBundle(lgMaxSize))) // usage: assign the valid to the fire signal
    val resp = new PMPRespBundle()
  })

  val req = io.req.bits

  val res_pmp = pmp_match_res(req.addr, req.size, io.env.pmp, io.env.mode, lgMaxSize)
  val res_pma = pma_match_res(req.addr, req.size, io.env.pma, io.env.mode, lgMaxSize)

  val resp_pmp = pmp_check(req.cmd, res_pmp.cfg)
  val resp_pma = pma_check(req.cmd, res_pma.cfg)
  val resp = resp_pmp | resp_pma

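  // When sameCycle is set, the check result is returned combinationally in the same cycle as
  // the request; otherwise it is registered and becomes valid one cycle after io.req fires.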
  if (sameCycle) {
    io.resp := resp
  } else {
    io.resp := RegEnable(resp, io.req.valid)
  }
}
429}