package xiangshan.frontend

import chisel3._
import chisel3.util._
import xiangshan._
import utils._

import scala.math.min

/** Read request of an SC table; it carries exactly the fields of [[TageReq]]. */
class SCReq extends TageReq

/** Read response of one bank: a pair of signed counters.
  *
  * [[SCTable]] keeps two counter arrays per bank and writes array `i` only when
  * the update's `tagePred` equals `i`, so `ctr(0)` / `ctr(1)` are the counters
  * associated with `tagePred` = 0 / 1.
  *
  * @param ctrBits width of each signed counter
  */
class SCResp(val ctrBits: Int = 6) extends TageBundle {
  val ctr = Vec(2, SInt(ctrBits.W))
}

/** Update (write) request of an SC table.
  *
  * @param ctrBits width of the signed counter being updated
  */
class SCUpdate(val ctrBits: Int = 6) extends TageBundle {
  // `pc - (fetchIdx << 1)` is the address used to index the table on update.
  val pc = UInt(VAddrBits.W)
  val fetchIdx = UInt(log2Up(TageBanks).W)
  // History value folded into the update index.
  val hist = UInt(HistoryLength.W)
  // Per-bank write enable.
  val mask = Vec(TageBanks, Bool())
  // Counter value the saturating update starts from.
  val oldCtr = SInt(ctrBits.W)
  // Selects which of the two counter arrays of a bank is written.
  val tagePred = Bool()
  // Direction passed to the saturating counter update.
  val taken = Bool()
}

/** Port bundle shared by every SC table implementation. */
class SCTableIO extends TageBundle {
  val req = Input(Valid(new SCReq))
  val resp = Output(Vec(TageBanks, new SCResp))
  val update = Input(new SCUpdate)
}

/** Common base of the SC tables: declares the IO and the counter centering helper.
  *
  * The constructor parameters are only recorded here; [[SCTable]] forwards its
  * `nRows`, `ctrBits` and `histLen` to them in that order.
  */
abstract class BaseSCTable(val r: Int = 1024, val cb: Int = 6, val h: Int = 0) extends TageModule {
  val io = IO(new SCTableIO)

  /** Maps a counter value `c` to `2 * c + 1`. */
  def getCenteredValue(ctr: SInt): SInt = (ctr << 1) + 1.S
}

/** Placeholder table whose response is constantly zero. */
class FakeSCTable extends BaseSCTable {
  // `asTypeOf` already yields a hardware value; wrapping it in `Wire(...)` is
  // rejected at elaboration because `Wire` expects a Chisel type, not hardware.
  io.resp := 0.U.asTypeOf(Vec(TageBanks, new SCResp))
}

/** SC table backed by two SRAMs of signed counters per bank.
  *
  * @param nRows   number of sets in each SRAM
  * @param ctrBits width of each signed counter
  * @param histLen number of low-order history bits folded into the index
  */
class SCTable(val nRows: Int, val ctrBits: Int, val histLen: Int) extends BaseSCTable(nRows, ctrBits, histLen) {

  // table(bank)(i): counter array `i` of `bank`; `i` is matched against tagePred on writes.
  val table = List.fill(TageBanks) {
    List.fill(2) {
      Module(new SRAMTemplate(SInt(ctrBits.W), set=nRows, shouldReset=false, holdRead=true, singlePort=false))
    }
  }

  /** XOR-folds the low `histLen` bits of `hist` into chunks of `l` bits.
    *
    * The last chunk is shorter when `histLen` is not a multiple of `l`.
    * Returns `0.U` when `histLen` is 0.
    */
  def compute_folded_hist(hist: UInt, l: Int) = {
    if (histLen > 0) {
      val nChunks = (histLen + l - 1) / l
      val hist_chunks = (0 until nChunks) map { i =>
        hist(min((i + 1) * l, histLen) - 1, i * l)
      }
      hist_chunks.reduce(_ ^ _)
    }
    else 0.U
  }

  /** Row index: folded history XOR `pc >> 1`, truncated to `log2Ceil(nRows)` bits. */
  def getIdx(hist: UInt, pc: UInt) = {
    (compute_folded_hist(hist, log2Ceil(nRows)) ^ (pc >> 1.U))(log2Ceil(nRows)-1, 0)
  }

  /** Saturating signed counter update (delegates to `signedSatUpdate`). */
  def ctrUpdate(ctr: SInt, cond: Bool): SInt = signedSatUpdate(ctr, ctrBits, cond)

  // After reset, walk every row once and write zero into it through the write port.
  // `reset_idx` only advances while `doing_reset` is set (it adds the Bool itself).
  val doing_reset = RegInit(true.B)
  val reset_idx = RegInit(0.U(log2Ceil(nRows).W))
  reset_idx := reset_idx + doing_reset
  when (reset_idx === (nRows-1).U) { doing_reset := false.B }

  val idx = getIdx(io.req.bits.hist, io.req.bits.pc)
  // NOTE(review): idxLatch is not read anywhere in this class.
  val idxLatch = RegEnable(idx, enable=io.req.valid)

  val table_r = WireInit(0.U.asTypeOf(Vec(TageBanks, Vec(2, SInt(ctrBits.W)))))

  // Bank addressed by the request pc; latched for reordering the response.
  val baseBank = io.req.bits.pc(log2Up(TageBanks), 1)
  val baseBankLatch = RegEnable(baseBank, enable=io.req.valid)

  // Response slot b is served by physical bank (baseBankLatch + b) mod TageBanks.
  val bankIdxInOrder = VecInit((0 until TageBanks).map(b => (baseBankLatch +& b.U)(log2Up(TageBanks)-1, 0)))
  // Request mask rotated by baseBank so it lines up with physical banks
  // (presumably; see circularShiftLeft for the exact rotation).
  val realMask = circularShiftLeft(io.req.bits.mask, TageBanks, baseBank)
  // NOTE(review): maskLatch is not read anywhere in this class.
  val maskLatch = RegEnable(io.req.bits.mask, enable=io.req.valid)

  val update_idx = getIdx(io.update.hist, io.update.pc - (io.update.fetchIdx << 1))
  val update_wdata = ctrUpdate(io.update.oldCtr, io.update.taken)

  for (b <- 0 until TageBanks) {
    for (i <- 0 to 1) {
      table(b)(i).reset := reset.asBool
      table(b)(i).io.r.req.valid := io.req.valid && realMask(b)
      table(b)(i).io.r.req.bits.setIdx := idx

      table_r(b)(i) := table(b)(i).io.r.resp.data(0)

      // The initial zeroing sweep takes over the write port while doing_reset is set;
      // otherwise only the array selected by tagePred in a masked bank is written.
      table(b)(i).io.w.req.valid := (io.update.mask(b) && i.U === io.update.tagePred.asUInt) || doing_reset
      table(b)(i).io.w.req.bits.setIdx := Mux(doing_reset, reset_idx, update_idx)
      table(b)(i).io.w.req.bits.data := Mux(doing_reset, 0.S, update_wdata)
    }
  }

  // Reorder the per-bank read data into request order.
  for (b <- 0 until TageBanks) {
    io.resp(b).ctr := table_r(bankIdxInOrder(b))
  }

}

/** Threshold held as an unsigned saturating counter.
  *
  * @param ctrBits width of the threshold counter
  * @param initVal default value driven onto `thres` by the companion `apply`
  */
class SCThreshold(val ctrBits: Int = 5, val initVal: Int = 5) extends TageBundle {
  val thres = UInt(ctrBits.W)

  /** Returns a new threshold whose `thres` is `satUpdate(thres, ctrBits, cause)`. */
  def update(cause: Bool): SCThreshold = {
    // Carry initVal along so the derived bundle keeps the same parameters.
    val res = Wire(new SCThreshold(this.ctrBits, this.initVal))
    res.thres := satUpdate(this.thres, this.ctrBits, cause)
    res
  }
}

object SCThreshold {
  /** Builds a threshold wire of `bits` bits, driven with the bundle's `initVal`.
    *
    * The default drive keeps the wire fully initialized; a later connection by
    * the caller still wins under Chisel's last-connect rule.
    */
  def apply(bits: Int): SCThreshold = {
    val t = Wire(new SCThreshold(ctrBits=bits))
    t.thres := t.initVal.U
    t
  }
}