package xiangshan.backend.fu.wrapper

import org.chipsalliance.cde.config.Parameters
import chisel3.{VecInit, _}
import chisel3.util._
import chisel3.util.experimental.decode.{QMCMinimizer, TruthTable, decoder}
import utility.DelayN
import utils.XSError
import xiangshan.XSCoreParamsKey
import xiangshan.backend.fu.vector.Bundles.{VConfig, VSew, ma}
import xiangshan.backend.fu.vector.{Mgu, Mgtu, VecPipedFuncUnit}
import xiangshan.backend.fu.vector.Utils.VecDataToMaskDataVec
import xiangshan.backend.fu.vector.utils.VecDataSplitModule
import xiangshan.backend.fu.{FuConfig, FuType}
import xiangshan.ExceptionNO
import yunsuan.{OpType, VialuFixType}
import yunsuan.vector.alu.{VIntFixpAlu64b, VIntFixpDecode, VIntFixpTable}
import yunsuan.encoding.{VdType, Vs1IntType, Vs2IntType}
import yunsuan.encoding.Opcode.VialuOpcode
import yunsuan.vector.SewOH

/**
 * IO of [[VIAluSrcTypeModule]]: decoded operation info in, per-operand element types and
 * vext/illegal flags out.
 */
class VIAluSrcTypeIO extends Bundle {
  val in = Input(new Bundle {
    val fuOpType: UInt = OpType()
    val vsew: UInt = VSew()
    val isReverse: Bool = Bool() // vrsub, vrdiv
    val isExt: Bool = Bool()
    val isDstMask: Bool = Bool() // vvm, vvvm, mmm
    val isMove: Bool = Bool() // vmv.s.x, vmv.v.v, vmv.v.x, vmv.v.i
  })
  val out = Output(new Bundle {
    val vs1Type: UInt = Vs1IntType()
    val vs2Type: UInt = Vs2IntType()
    val vdType: UInt = VdType()
    val illegal: Bool = Bool()
    val isVextF2: Bool = Bool()
    val isVextF4: Bool = Bool()
    val isVextF8: Bool = Bool()
  })
}

/**
 * Derives the element types of vs2, vs1 and vd for a vector integer / fixed-point ALU op.
 *
 * The opcode and format fields of `fuOpType` select how each operand's SEW relates to `vsew`
 * (equal, doubled for widening/narrowing operands, or SEW/2, SEW/4, SEW/8 for vext), and the
 * signedness bit of `fuOpType` is prepended to form the full type encoding. Mask-destination
 * formats (VVM, VVMM, MMM) use the mask type encodings instead.
 *
 * `io.in.isReverse` and `io.in.isMove` are part of the IO but are not read by this module.
 */
class VIAluSrcTypeModule extends Module {
  val io: VIAluSrcTypeIO = IO(new VIAluSrcTypeIO)

  private val vsew = io.in.vsew
  private val isExt = io.in.isExt
  private val isDstMask = io.in.isDstMask

  private val opcode = VialuFixType.getOpcode(io.in.fuOpType)
  private val isSign = VialuFixType.isSigned(io.in.fuOpType)
  private val format = VialuFixType.getFormat(io.in.fuOpType)

  // Chisel `+` / `-` keep the operand width, so these wrap around at VSew's width.
  private val vsewX2 = vsew + 1.U
  private val vsewF2 = vsew - 1.U
  private val vsewF4 = vsew - 2.U
  private val vsewF8 = vsew - 3.U

  private val isAddSub = opcode === VialuOpcode.vadd || opcode === VialuOpcode.vsub
  // The right-shift opcodes must be matched against the opcode field (not the format field).
  private val isShiftRight = Seq(VialuOpcode.vsrl, VialuOpcode.vsra, VialuOpcode.vssrl, VialuOpcode.vssra)
    .map(op => opcode === op)
    .reduce(_ || _)
  private val isVext = opcode === VialuOpcode.vext

  private val isWiden = isAddSub && Seq(VialuFixType.FMT.VVW, VialuFixType.FMT.WVW).map(fmt => fmt === format).reduce(_ || _)
  private val isNarrow = isShiftRight && format === VialuFixType.FMT.WVV
  private val isVextF2 = isVext && format === VialuFixType.FMT.VF2
  private val isVextF4 = isVext && format === VialuFixType.FMT.VF4
  private val isVextF8 = isVext && format === VialuFixType.FMT.VF8

  // check illegal
  // Widening: vd (and vs2 for .wv) is 2*SEW, so SEW = e64 is reserved; vsewX2 then wraps to e8.
  private val widenIllegal = isWiden && vsewX2 === VSew.e8
  // Narrowing: vs2 is 2*SEW and vd is SEW (see WVV below), so SEW = e64 is reserved.
  private val narrowIllegal = isNarrow && vsewX2 === VSew.e8
  // vext.vfN: the source EEW is SEW/N, which must be at least e8.
  private val vextIllegal = (isVextF2 && vsew < VSew.e16) ||
    (isVextF4 && vsew < VSew.e32) ||
    (isVextF8 && vsew < VSew.e64)
  // Todo: use it (VIAluFix does not consume io.out.illegal yet)
  private val illegal = widenIllegal || narrowIllegal || vextIllegal

  private val intType = Cat(0.U(1.W), isSign)

  private class Vs2Vs1VdSew extends Bundle {
    val vs2 = VSew()
    val vs1 = VSew()
    val vd = VSew()
  }

  private class Vs2Vs1VdType extends Bundle {
    val vs2 = Vs2IntType()
    val vs1 = Vs1IntType()
    val vd = VdType()
  }

  // Cat order (vs2, vs1, vd) matches the bundle field order: the first field lands in the MSBs.
  private val addSubSews = Mux1H(Seq(
    (format === VialuFixType.FMT.VVV) -> Cat(vsew, vsew, vsew),
    (format === VialuFixType.FMT.VVW) -> Cat(vsew, vsew, vsewX2),
    (format === VialuFixType.FMT.WVW) -> Cat(vsewX2, vsew, vsewX2),
    (format === VialuFixType.FMT.WVV) -> Cat(vsewX2, vsew, vsew),
  )).asTypeOf(new Vs2Vs1VdSew)

  private val vextSews = Mux1H(Seq(
    (format === VialuFixType.FMT.VF2) -> Cat(vsewF2, vsewF2, vsew),
    (format === VialuFixType.FMT.VF4) -> Cat(vsewF4, vsewF4, vsew),
    (format === VialuFixType.FMT.VF8) -> Cat(vsewF8, vsewF8, vsew),
  )).asTypeOf(new Vs2Vs1VdSew)

  private val maskTypes = Mux1H(Seq(
    (format === VialuFixType.FMT.VVM) -> Cat(Cat(intType, vsew), Cat(intType, vsew), VdType.mask),
    (format === VialuFixType.FMT.VVMM) -> Cat(Cat(intType, vsew), Cat(intType, vsew), VdType.mask),
    (format === VialuFixType.FMT.MMM) -> Cat(Vs2IntType.mask, Vs1IntType.mask, VdType.mask),
  )).asTypeOf(new Vs2Vs1VdType)

  private val vs2Type = Mux1H(Seq(
    isDstMask -> maskTypes.vs2,
    isExt -> Cat(intType, vextSews.vs2),
    (!isExt && !isDstMask) -> Cat(intType, addSubSews.vs2),
  ))
  private val vs1Type = Mux1H(Seq(
    isDstMask -> maskTypes.vs1,
    isExt -> Cat(intType, vextSews.vs1),
    (!isExt && !isDstMask) -> Cat(intType, addSubSews.vs1),
  ))
  private val vdType = Mux1H(Seq(
    isDstMask -> maskTypes.vd,
    isExt -> Cat(intType, vextSews.vd),
    (!isExt && !isDstMask) -> Cat(intType, addSubSews.vd),
  ))

  io.out.vs2Type := vs2Type
  io.out.vs1Type := vs1Type
  io.out.vdType := vdType
  io.out.illegal := illegal
  io.out.isVextF2 := isVextF2
  io.out.isVextF4 := isVextF4
  io.out.isVextF8 := isVextF8
}

/**
 * Vector integer / fixed-point ALU functional unit.
 *
 * The `cfg.dataBits`-wide sources are split into 64-bit lanes, each handled by one
 * [[VIntFixpAlu64b]]. The lane results (plain, narrowing, or compare/mask) are concatenated and
 * post-processed by [[Mgu]] and, for mask destinations, [[Mgtu]] before being written to `io.out`.
 * Control/data signals such as `vs1`, `vs2`, `oldVd`, `vuopIdx`, `isNarrow`, `outVl`, ... are not
 * defined in this file (presumably inherited from [[VecPipedFuncUnit]]).
 */
class VIAluFix(cfg: FuConfig)(implicit p: Parameters) extends VecPipedFuncUnit(cfg) {
  XSError(io.in.valid && io.in.bits.ctrl.fuOpType === VialuFixType.dummy, "VialuF OpType not supported")

  // config params
  private val dataWidth = cfg.dataBits
  private val dataWidthOfDataModule = 64
  private val numVecModule = dataWidth / dataWidthOfDataModule

  // modules
  private val typeMod = Module(new VIAluSrcTypeModule)
  private val vs2Split = Module(new VecDataSplitModule(dataWidth, dataWidthOfDataModule))
  private val vs1Split = Module(new VecDataSplitModule(dataWidth, dataWidthOfDataModule))
  private val oldVdSplit = Module(new VecDataSplitModule(dataWidth, dataWidthOfDataModule))
  private val vIntFixpAlus = Seq.fill(numVecModule)(Module(new VIntFixpAlu64b))
  private val mgu = Module(new Mgu(dataWidth))
  private val mgtu = Module(new Mgtu(dataWidth))

  /**
   * [[typeMod]]'s in connection
   */
  typeMod.io.in.fuOpType := fuOpType
  typeMod.io.in.vsew := vsew
  typeMod.io.in.isReverse := isReverse
  typeMod.io.in.isExt := isExt
  typeMod.io.in.isDstMask := vecCtrl.isDstMask
  typeMod.io.in.isMove := isMove

  /**
   * Interleaves narrow chunks across the 64-bit lanes: lane `k` receives every chunk whose index
   * satisfies `idx % numVecModule == k`, concatenated with the lowest-indexed chunk in the LSBs.
   * The lane order is explicit, so it does not depend on hash-map iteration order.
   *
   * @param chunks the split sub-words of one source register, lowest index first
   * @return one concatenated word per ALU lane
   */
  private def interleaveToLanes(chunks: Seq[UInt]): Vec[UInt] =
    VecInit(Seq.tabulate(numVecModule) { lane =>
      Cat(chunks.zipWithIndex.collect { case (chunk, idx) if idx % numVecModule == lane => chunk }.reverse)
    })

  private val vs2GroupedVec32b: Vec[UInt] = interleaveToLanes(vs2Split.io.outVec32b)
  private val vs2GroupedVec16b: Vec[UInt] = interleaveToLanes(vs2Split.io.outVec16b)
  private val vs2GroupedVec8b: Vec[UInt] = interleaveToLanes(vs2Split.io.outVec8b)
  private val vs1GroupedVec: Vec[UInt] = interleaveToLanes(vs1Split.io.outVec32b)

  /**
   * In connection of [[vs2Split]], [[vs1Split]] and [[oldVdSplit]]
   */
  vs2Split.io.inVecData := vs2
  vs1Split.io.inVecData := vs1
  oldVdSplit.io.inVecData := oldVd

  /**
   * [[vIntFixpAlus]]'s in connection
   */
  private val opcode = VialuFixType.getOpcode(inCtrl.fuOpType).asTypeOf(vIntFixpAlus.head.io.opcode)
  private val vs1Type = typeMod.io.out.vs1Type
  private val vs2Type = typeMod.io.out.vs2Type
  private val vdType = typeMod.io.out.vdType
  private val isVextF2 = typeMod.io.out.isVextF2
  private val isVextF4 = typeMod.io.out.isVextF4
  private val isVextF8 = typeMod.io.out.isVextF8

  private val truthTable = TruthTable(VIntFixpTable.table, VIntFixpTable.default)
  private val decoderOut = decoder(QMCMinimizer, Cat(opcode.op), truthTable)
  private val vIntFixpDecode = decoderOut.asTypeOf(new VIntFixpDecode)
  private val isFixp = Mux(vIntFixpDecode.misc, opcode.isScalingShift, opcode.isSatAdd || opcode.isAvgAdd)
  // widening add/sub: vs1's SEW differs from vd's; widen_vs2 additionally when vs2's SEW differs too (.vv form)
  private val widen = opcode.isAddSub && vs1Type(1, 0) =/= vdType(1, 0)
  private val widen_vs2 = widen && vs2Type(1, 0) =/= vdType(1, 0)
  private val eewVs1 = SewOH(vs1Type(1, 0))
  private val eewVd = SewOH(vdType(1, 0))

  // Extension instructions
  private val vf2 = isVextF2
  private val vf4 = isVextF4
  private val vf8 = isVextF8

  private val vs1VecUsed: Vec[UInt] = Mux(widen || isNarrow, vs1GroupedVec, vs1Split.io.outVec64b)
  private val vs2VecUsed = Wire(Vec(numVecModule, UInt(64.W)))
  when(vf2) {
    vs2VecUsed := vs2GroupedVec32b
  }.elsewhen(vf4) {
    vs2VecUsed := vs2GroupedVec16b
  }.elsewhen(vf8) {
    vs2VecUsed := vs2GroupedVec8b
  }.otherwise {
    vs2VecUsed := vs2Split.io.outVec64b
  }

  private val vs2Adder = Mux(widen_vs2, vs2GroupedVec32b, vs2Split.io.outVec64b)

  // mask
  private val maskDataVec: Vec[UInt] = VecDataToMaskDataVec(srcMask, vsew)
  // a narrowing uop pair shares one mask slice, hence uopIdx / 2
  private val maskIdx = Mux(isNarrow, (vuopIdx >> 1.U).asUInt, vuopIdx)
  private val eewVd_is_1b = vdType === VdType.mask
  private val maskUsed = splitMask(maskDataVec(maskIdx), Mux(eewVd_is_1b, eewVs1, eewVd))

  private val oldVdUsed = splitMask(VecDataToMaskDataVec(oldVd, vs1Type(1, 0))(vuopIdx), eewVs1)

  vIntFixpAlus.zipWithIndex.foreach {
    case (mod, i) =>
      mod.io.fire := io.in.valid
      mod.io.opcode := opcode

      mod.io.info.vm := vm
      mod.io.info.ma := vma
      mod.io.info.ta := vta
      mod.io.info.vlmul := vlmul
      mod.io.info.vl := vl
      mod.io.info.vstart := vstart
      mod.io.info.uopIdx := vuopIdx
      mod.io.info.vxrm := vxrm

      mod.io.srcType(0) := vs2Type
      mod.io.srcType(1) := vs1Type
      mod.io.vdType := vdType
      mod.io.narrow := isNarrow
      mod.io.isSub := vIntFixpDecode.sub
      mod.io.isMisc := vIntFixpDecode.misc
      mod.io.isFixp := isFixp
      mod.io.widen := widen
      mod.io.widen_vs2 := widen_vs2
      mod.io.vs1 := vs1VecUsed(i)
      mod.io.vs2_adder := vs2Adder(i)
      mod.io.vs2_misc := vs2VecUsed(i)
      mod.io.vmask := maskUsed(i)
      mod.io.oldVd := oldVdUsed(i)
  }

  /**
   * [[mgu]]'s in connection
   */
  private val outEewVs1 = SNReg(eewVs1, latency)

  private val outVd = Cat(vIntFixpAlus.reverse.map(_.io.vd))
  // compare results: 8/4/2/1 valid bits per 64-bit lane for e8/e16/e32/e64
  private val outCmp = Mux1H(outEewVs1.oneHot, Seq(8, 4, 2, 1).map(
    k => Cat(vIntFixpAlus.reverse.map(_.io.cmpOut(k - 1, 0)))))
  private val outNarrow = Cat(vIntFixpAlus.reverse.map(_.io.narrowVd))
  private val outOpcode = VialuFixType.getOpcode(outCtrl.fuOpType).asTypeOf(vIntFixpAlus.head.io.opcode)

  private val numBytes = dataWidth / 8
  private val maxMaskIdx = numBytes
  private val maxVdIdx = 8
  private val elementsInOneUop = Mux1H(outEewVs1.oneHot, Seq(1, 2, 4, 8).map(k => (numBytes / k).U(5.W)))
  private val vdIdx = outVecCtrl.vuopIdx(2, 0)
  // index of this uop's first element within the whole register group
  private val elementsComputed = Mux1H(Seq.tabulate(maxVdIdx)(i => (vdIdx === i.U) -> (elementsInOneUop * i.U)))
  val outCmpWithTail = Wire(Vec(maxMaskIdx, UInt(1.W)))
  // Compare bits whose element index is >= vl are tail elements and are forced to 1.
  // This is done regardless of vta: the tail is always handled as tail-agnostic here.
  for (i <- 0 until maxMaskIdx) {
    when(elementsComputed +& i.U >= outVl) {
      outCmpWithTail(i) := 1.U
    }.otherwise {
      outCmpWithTail(i) := outCmp(i)
    }
  }

  /* insts whose mask is not used to generate 'agnosticEn' and 'activeEn' in mgu:
   * vadc, vmadc...
   * vmerge
   */
  private val needNoMask = VialuFixType.needNoMask(outCtrl.fuOpType)
  private val maskToMgu = Mux(needNoMask, allMaskTrue, outSrcMask)

  private val outFormat = VialuFixType.getFormat(outCtrl.fuOpType)
  private val outWiden = (outFormat === VialuFixType.FMT.VVW | outFormat === VialuFixType.FMT.WVW) & !outVecCtrl.isExt & !outVecCtrl.isDstMask
  private val narrow = outVecCtrl.isNarrow
  private val dstMask = outVecCtrl.isDstMask
  // a narrowing uop produces half a register, so only the low 4 vxsat bits of each lane are kept
  private val outVxsat = Mux(narrow, Cat(vIntFixpAlus.reverse.map(_.io.vxsat(3, 0))), Cat(vIntFixpAlus.reverse.map(_.io.vxsat)))

  // the result of narrow inst which needs concat:
  // an odd narrowing uop fills the upper half of vd and keeps the lower half from oldVd
  private val narrowNeedCat = outVecCtrl.vuopIdx(0).asBool && narrow
  private val outNarrowVd = Mux(narrowNeedCat, Cat(outNarrow, outOldVd(dataWidth / 2 - 1, 0)), outNarrow)
  private val outVxsatReal = Mux(narrowNeedCat, Cat(outVxsat(numBytes / 2 - 1, 0), 0.U((numBytes / 2).W)), outVxsat)

  private val outEew = Mux(outWiden, outVecCtrl.vsew + 1.U, outVecCtrl.vsew)

  /*
   * vmv.s.x (isVmvsx) writes only element 0, so it is processed with vl = 1
   */
  private val outIsVmvsx = outOpcode.isVmvsx

  /*
   * when vstart >= vl, no need to update vd, the old value should be kept
   */
  private val outVstartGeVl = outVstart >= outVl

  mgu.io.in.vd := MuxCase(outVd, Seq(
    narrow -> outNarrowVd,
    dstMask -> outCmpWithTail.asUInt,
  ))
  mgu.io.in.oldVd := outOldVd
  mgu.io.in.mask := maskToMgu
  mgu.io.in.info.ta := outVecCtrl.vta
  mgu.io.in.info.ma := outVecCtrl.vma
  mgu.io.in.info.vl := Mux(outIsVmvsx, 1.U, outVl)
  mgu.io.in.info.vlmul := outVecCtrl.vlmul
  mgu.io.in.info.valid := validVec.last
  mgu.io.in.info.vstart := outVecCtrl.vstart
  mgu.io.in.info.eew := outEew
  mgu.io.in.info.vsew := outVecCtrl.vsew
  mgu.io.in.info.vdIdx := outVecCtrl.vuopIdx
  mgu.io.in.info.narrow := narrow
  mgu.io.in.info.dstMask := dstMask
  mgu.io.in.isIndexedVls := false.B

  /**
   * [[mgtu]]'s in connection, for vmask instructions
   */
  mgtu.io.in.vd := Mux(dstMask && !outVecCtrl.isOpMask, mgu.io.out.vd, outVd)
  mgtu.io.in.vl := outVl

  io.out.bits.res.data := Mux(outVstartGeVl, outOldVd, Mux(dstMask, mgtu.io.out.vd, mgu.io.out.vd))
  io.out.bits.res.vxsat.get := Mux(outVstartGeVl, false.B, (outVxsatReal & mgu.io.out.active).orR)
  io.out.bits.ctrl.exceptionVec.get(ExceptionNO.illegalInstr) := mgu.io.out.illegal && !outVstartGeVl

  // util function
  /**
   * Splits a mask slice into one byte per 64-bit lane.
   *
   * Entry `i` holds the mask bits of the elements that lane `i` processes: 8 bits for e8,
   * 4 for e16, 2 for e32 and 1 for e64, zero-extended to 8 bits.
   *
   * @param maskIn mask bits, lowest element in the LSB
   * @param sew    one-hot element width selecting how many mask bits each lane consumes
   * @return `maskIn.getWidth / 8` bytes, one per lane
   */
  def splitMask(maskIn: UInt, sew: SewOH): Vec[UInt] = {
    val maskWidth = maskIn.getWidth
    val result = Wire(Vec(maskWidth / 8, UInt(8.W)))
    for ((resultData, i) <- result.zipWithIndex) {
      resultData := Mux1H(Seq(
        sew.is8 -> maskIn(i * 8 + 7, i * 8),
        sew.is16 -> Cat(0.U((8 - 4).W), maskIn(i * 4 + 3, i * 4)),
        sew.is32 -> Cat(0.U((8 - 2).W), maskIn(i * 2 + 1, i * 2)),
        sew.is64 -> Cat(0.U((8 - 1).W), maskIn(i)),
      ))
    }
    result
  }

}