// Source: XiangShan src/main/scala/xiangshan/frontend/RAS.scala (revision 10b9babd805e02359c6fee797c66e99cd4a296ec)
1package xiangshan.frontend
2
3import chisel3._
4import chisel3.util._
5import xiangshan._
6import xiangshan.backend.ALUOpType
7import utils._
8
/** Return address stack (RAS) predictor.
  *
  * Two stacks of return addresses are maintained:
  *  - `spec_ras`/`spec_sp` are updated speculatively from the prediction-time inputs
  *    (`io.pc`, `io.callIdx`, `io.is_ret`) and drive the prediction on `io.out`;
  *  - `commit_ras`/`commit_sp` are updated from `io.recover`, using its pre-decode
  *    `isCall`/`isRet` flags.
  * One cycle after `io.recover` reports a misprediction, the speculative stack is
  * overwritten with a copy of the commit stack.
  *
  * Each slot holds a return address and a repeat counter. A call whose return address
  * equals the current (valid) top entry increments that entry's counter instead of
  * allocating a new slot. A return only frees the slot once the counter is back at 1.
  *
  * `sp` is the number of occupied slots: the top of stack is `ras(sp - 1)` and a push
  * writes `ras(sp)`. `sp` is only `log2Up(RasSize)` bits wide, so the stack is treated
  * as full at `RasSize - 1` and slot `RasSize - 1` is never written.
  */
class RAS extends BasePredictor
{
    /** Width of the per-slot repeat counter (also used by `RASBranchInfo.rasTopCtr`). */
    private def rasCtrWidth: Int = 8

    class RASResp extends Resp
    {
        // Predicted return address: the top of the speculative stack.
        val target = UInt(VAddrBits.W)
        // True when the speculative stack is empty (io.out.valid is then false).
        val specEmpty = Bool()
    }

    class RASBranchInfo extends Meta
    {
        // Currently driven with DontCare below; kept as part of the predictor interface.
        // NOTE(review): `rasToqAddr` looks like a typo for `rasTopAddr`, but the field name
        // is part of the bundle interface, so it is left unchanged.
        val rasSp = UInt(log2Up(RasSize).W)
        val rasTopCtr = UInt(rasCtrWidth.W)
        val rasToqAddr = UInt(VAddrBits.W)
    }

    class RASIO extends DefaultBasePredictorIO
    {
        // The predicted fetch block ends in a return: pop the speculative stack.
        val is_ret = Input(Bool())
        // Valid when the fetch block contains a call; bits is the call's slot index,
        // in 2-byte units relative to io.pc.bits.
        val callIdx = Flipped(ValidIO(UInt(log2Ceil(PredictWidth).W)))
        // The call at callIdx is a compressed (2-byte) instruction.
        val isRVC = Input(Bool())
        // When set, the return address is also the slot address + 2.
        // NOTE(review): presumably callIdx then points at the last half of a 32-bit call —
        // confirm with the caller.
        val isLastHalfRVI = Input(Bool())
        // Resolved branch information: updates the commit stack, and a misprediction
        // restores the speculative stack from the commit stack.
        val recover =  Flipped(ValidIO(new BranchUpdateInfo))
        // Predicted return target; valid only for a return with a non-empty speculative stack.
        val out = ValidIO(new RASResp)
        // Per-prediction metadata; currently not populated (driven with DontCare).
        val branchInfo = Output(new RASBranchInfo)
    }

    // One stack slot: a return address and how many consecutive calls pushed it.
    def rasEntry() = new Bundle {
        val retAddr = UInt(VAddrBits.W)
        val ctr = UInt(rasCtrWidth.W) // layer of nested call functions
    }
    override val io = IO(new RASIO)

    // Largest value the repeat counter can hold; a push at this count allocates a new slot
    // rather than letting the counter wrap to 0.
    val rasCtrMax = ((1 << rasCtrWidth) - 1).U(rasCtrWidth.W)

    // Disabled alternative: double-buffered stacks selected by a toggle bit.
    // val ras_0 = Reg(Vec(RasSize, rasEntry()))  //RegInit(0.U)asTypeOf(Vec(RasSize,rasEntry)) cause comb loop
    // val ras_1 = Reg(Vec(RasSize, rasEntry()))
    // val sp_0 = RegInit(0.U(log2Up(RasSize).W))
    // val sp_1 = RegInit(0.U(log2Up(RasSize).W))
    // val choose_bit = RegInit(false.B)   //start with 0
    // val spec_ras = Mux(choose_bit, ras_1, ras_0)
    // val spec_sp = Mux(choose_bit,sp_1,sp_0)
    // val commit_ras = Mux(choose_bit, ras_0, ras_1)
    // val commit_sp = Mux(choose_bit,sp_0,sp_1)

    val spec_ras = Reg(Vec(RasSize, rasEntry()))
    val spec_sp = RegInit(0.U(log2Up(RasSize).W))
    val commit_ras = Reg(Vec(RasSize, rasEntry()))
    val commit_sp = RegInit(0.U(log2Up(RasSize).W))

    // ---------------- speculative stack ----------------

    val spec_is_empty = spec_sp === 0.U
    // sp cannot represent RasSize, so one slot is sacrificed.
    val spec_is_full = spec_sp === (RasSize - 1).U

    // When the stack is empty, sp - 1 wraps and this reads a slot holding no valid entry;
    // every use below is qualified by spec_is_empty.
    val spec_ras_top_entry = spec_ras(spec_sp - 1.U)
    val spec_ras_top_addr = spec_ras_top_entry.retAddr
    val spec_ras_top_ctr = spec_ras_top_entry.ctr

    //no need to pass the ras branchInfo
    io.branchInfo.rasSp := DontCare
    io.branchInfo.rasTopCtr := DontCare
    io.branchInfo.rasToqAddr := DontCare

    io.out.valid := !spec_is_empty && io.is_ret
    io.out.bits.specEmpty := spec_is_empty
    io.out.bits.target := spec_ras_top_addr

    val spec_push = !spec_is_full && io.callIdx.valid && io.pc.valid
    val spec_pop = !spec_is_empty && io.is_ret && io.pc.valid
    // Return address: call slot address (pc + 2 * callIdx) plus 2 for RVC / last-half RVI,
    // otherwise plus 4.
    val spec_new_addr = io.pc.bits + (io.callIdx.bits << 1.U) + Mux(io.isRVC, 2.U, Mux(io.isLastHalfRVI, 2.U, 4.U))
    val spec_ras_write = WireInit(0.U.asTypeOf(rasEntry()))
    // Reuse the top slot only if it is a valid entry with the same address whose counter
    // can still be incremented; an empty stack always allocates.
    val spec_alloc_new = spec_is_empty || spec_new_addr =/= spec_ras_top_addr || spec_ras_top_ctr === rasCtrMax

    when (spec_push) {
        //push
        spec_ras_write.ctr := 1.U
        spec_ras_write.retAddr := spec_new_addr
        when (spec_alloc_new) {
            spec_sp := spec_sp + 1.U
            spec_ras(spec_sp) := spec_ras_write
        }.otherwise {
            spec_ras_top_ctr := spec_ras_top_ctr + 1.U
        }
    }

    // Declared after the push logic: if both fire in one cycle, the pop's writes win
    // (Chisel last-connect semantics).
    when (spec_pop) {
        //pop (spec_pop implies spec_sp =/= 0, so sp - 1 cannot underflow)
        when (spec_ras_top_ctr === 1.U) {
            spec_sp := spec_sp - 1.U
        }.otherwise {
            spec_ras_top_ctr := spec_ras_top_ctr - 1.U
        }
    }
    // TODO: back-up stack for ras
    // use checkpoint to recover RAS

    // ---------------- commit stack ----------------

    val commit_is_empty = commit_sp === 0.U
    val commit_is_full = commit_sp === (RasSize - 1).U
    // Same caveat as the speculative top: meaningless when commit_is_empty.
    val commit_ras_top_entry = commit_ras(commit_sp - 1.U)
    val commit_ras_top_addr = commit_ras_top_entry.retAddr
    val commit_ras_top_ctr = commit_ras_top_entry.ctr

    val commit_push = !commit_is_full && io.recover.valid && io.recover.bits.pd.isCall
    val commit_pop = !commit_is_empty && io.recover.valid && io.recover.bits.pd.isRet
    val commit_new_addr = Mux(io.recover.bits.pd.isRVC, io.recover.bits.pc + 2.U, io.recover.bits.pc + 4.U)
    val commit_ras_write = WireInit(0.U.asTypeOf(rasEntry()))
    val commit_alloc_new = commit_is_empty || commit_new_addr =/= commit_ras_top_addr || commit_ras_top_ctr === rasCtrMax

    when (commit_push) {
        //push
        commit_ras_write.ctr := 1.U
        commit_ras_write.retAddr := commit_new_addr
        when (commit_alloc_new) {
            commit_sp := commit_sp + 1.U
            commit_ras(commit_sp) := commit_ras_write
        }.otherwise {
            commit_ras_top_ctr := commit_ras_top_ctr + 1.U
        }
    }

    when (commit_pop) {
        //pop (commit_pop implies commit_sp =/= 0)
        when (commit_ras_top_ctr === 1.U) {
            commit_sp := commit_sp - 1.U
        }.otherwise {
            commit_ras_top_ctr := commit_ras_top_ctr - 1.U
        }
    }

    // ---------------- misprediction recovery ----------------

    // The copy fires one cycle after the mispredict, so it sees the commit stack after it
    // has absorbed the same io.recover update. It is declared after the speculative
    // push/pop, so by last-connect semantics it overrides any speculative update that cycle.
    val copy_valid = io.recover.valid && io.recover.bits.isMisPred
    val copy_next = RegNext(copy_valid, false.B)
    when (copy_next) {
        spec_ras := commit_ras
        spec_sp := commit_sp
    }

    if (BPUDebug && debug) {
        XSDebug("----------------RAS(spec)----------------\n")
        XSDebug("  index       addr           ctr \n")
        for (i <- 0 until RasSize) {
            XSDebug("  (%d)   0x%x      %d", i.U, spec_ras(i).retAddr, spec_ras(i).ctr)
            when (i.U === spec_sp) { XSDebug(false, true.B, "   <----sp") }
            XSDebug(false, true.B, "\n")
        }
        XSDebug("----------------RAS(commit)----------------\n")
        XSDebug("  index       addr           ctr \n")
        for (i <- 0 until RasSize) {
            XSDebug("  (%d)   0x%x      %d", i.U, commit_ras(i).retAddr, commit_ras(i).ctr)
            when (i.U === commit_sp) { XSDebug(false, true.B, "   <----sp") }
            XSDebug(false, true.B, "\n")
        }

        XSDebug(spec_push, "(spec_ras)push  inAddr: 0x%x  inCtr: %d |  allocNewEntry:%d |   sp:%d \n", spec_ras_write.retAddr, spec_ras_write.ctr, spec_alloc_new, spec_sp)
        XSDebug(spec_pop, "(spec_ras)pop outValid:%d  outAddr: 0x%x \n", io.out.valid, io.out.bits.target)
        // Commit-stack logs report the commit stack's own signals, not the speculative ones.
        XSDebug(commit_push, "(commit_ras)push  inAddr: 0x%x  inCtr: %d |  allocNewEntry:%d |   sp:%d \n", commit_ras_write.retAddr, commit_ras_write.ctr, commit_alloc_new, commit_sp)
        XSDebug(commit_pop, "(commit_ras)pop outValid:%d  outAddr: 0x%x \n", !commit_is_empty, commit_ras_top_addr)
        XSDebug("copyValid:%d copyNext:%d \n", copy_valid, copy_next)
    }

    // Disabled checkpoint-based recovery sketch. It refers to names (ras, sp, push, pop,
    // new_addr, recover_valid) that do not exist in this module, so it would need rework
    // before being re-enabled.
    // val recoverSp = io.recover.bits.brInfo.rasSp
    // val recoverCtr = io.recover.bits.brInfo.rasTopCtr
    // val recoverAddr = io.recover.bits.brInfo.rasToqAddr
    // val recover_top = ras(recoverSp - 1.U)
    // when (recover_valid) {
    //     sp := recoverSp
    //     recover_top.ctr := recoverCtr
    //     recover_top.retAddr := recoverAddr
    //     XSDebug("RAS update: SP:%d , Ctr:%d \n",recoverSp,recoverCtr)
    // }
    // val recover_and_push = recover_valid && push
    // val recover_and_pop = recover_valid && pop
    // val recover_alloc_new = new_addr =/= recoverAddr
    // when(recover_and_push)
    // {
    //     when(recover_alloc_new){
    //         sp := recoverSp + 1.U
    //         ras(recoverSp).retAddr := new_addr
    //         ras(recoverSp).ctr := 1.U
    //         recover_top.retAddr := recoverAddr
    //         recover_top.ctr := recoverCtr
    //     } .otherwise{
    //         sp := recoverSp
    //         recover_top.ctr := recoverCtr + 1.U
    //         recover_top.retAddr := recoverAddr
    //     }
    // } .elsewhen(recover_and_pop)
    // {
    //     io.out.bits.target := recoverAddr
    //     when ( recover_top.ctr === 1.U) {
    //         sp := recoverSp - 1.U
    //     }.otherwise {
    //         sp := recoverSp
    //        recover_top.ctr := recoverCtr - 1.U
    //     }
    // }
}
207