1 //==--------------- llvm/CodeGen/SDPatternMatch.h ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Contains matchers for matching SelectionDAG nodes and values.
10 ///
11 //===----------------------------------------------------------------------===//
12
13 #ifndef LLVM_CODEGEN_SDPATTERNMATCH_H
14 #define LLVM_CODEGEN_SDPATTERNMATCH_H
15
16 #include "llvm/ADT/APInt.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/CodeGen/SelectionDAG.h"
19 #include "llvm/CodeGen/SelectionDAGNodes.h"
20 #include "llvm/CodeGen/TargetLowering.h"
21
22 namespace llvm {
23 namespace SDPatternMatch {
24
25 /// MatchContext can repurpose existing patterns to behave differently under
26 /// a certain context. For instance, `m_Opc(ISD::ADD)` matches plain ADD nodes
27 /// in normal circumstances, but matches VP_ADD nodes under a custom
28 /// VPMatchContext. This design is meant to facilitate code / pattern reusing.
29 class BasicMatchContext {
30 const SelectionDAG *DAG;
31 const TargetLowering *TLI;
32
33 public:
BasicMatchContext(const SelectionDAG * DAG)34 explicit BasicMatchContext(const SelectionDAG *DAG)
35 : DAG(DAG), TLI(DAG ? &DAG->getTargetLoweringInfo() : nullptr) {}
36
BasicMatchContext(const TargetLowering * TLI)37 explicit BasicMatchContext(const TargetLowering *TLI)
38 : DAG(nullptr), TLI(TLI) {}
39
40 // A valid MatchContext has to implement the following functions.
41
getDAG()42 const SelectionDAG *getDAG() const { return DAG; }
43
getTLI()44 const TargetLowering *getTLI() const { return TLI; }
45
46 /// Return true if N effectively has opcode Opcode.
match(SDValue N,unsigned Opcode)47 bool match(SDValue N, unsigned Opcode) const {
48 return N->getOpcode() == Opcode;
49 }
50 };
51
52 template <typename Pattern, typename MatchContext>
sd_context_match(SDValue N,const MatchContext & Ctx,Pattern && P)53 [[nodiscard]] bool sd_context_match(SDValue N, const MatchContext &Ctx,
54 Pattern &&P) {
55 return P.match(Ctx, N);
56 }
57
58 template <typename Pattern, typename MatchContext>
sd_context_match(SDNode * N,const MatchContext & Ctx,Pattern && P)59 [[nodiscard]] bool sd_context_match(SDNode *N, const MatchContext &Ctx,
60 Pattern &&P) {
61 return sd_context_match(SDValue(N, 0), Ctx, P);
62 }
63
64 template <typename Pattern>
sd_match(SDNode * N,const SelectionDAG * DAG,Pattern && P)65 [[nodiscard]] bool sd_match(SDNode *N, const SelectionDAG *DAG, Pattern &&P) {
66 return sd_context_match(N, BasicMatchContext(DAG), P);
67 }
68
69 template <typename Pattern>
sd_match(SDValue N,const SelectionDAG * DAG,Pattern && P)70 [[nodiscard]] bool sd_match(SDValue N, const SelectionDAG *DAG, Pattern &&P) {
71 return sd_context_match(N, BasicMatchContext(DAG), P);
72 }
73
74 template <typename Pattern>
sd_match(SDNode * N,Pattern && P)75 [[nodiscard]] bool sd_match(SDNode *N, Pattern &&P) {
76 return sd_match(N, nullptr, P);
77 }
78
79 template <typename Pattern>
sd_match(SDValue N,Pattern && P)80 [[nodiscard]] bool sd_match(SDValue N, Pattern &&P) {
81 return sd_match(N, nullptr, P);
82 }
83
84 // === Utilities ===
85 struct Value_match {
86 SDValue MatchVal;
87
88 Value_match() = default;
89
Value_matchValue_match90 explicit Value_match(SDValue Match) : MatchVal(Match) {}
91
matchValue_match92 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
93 if (MatchVal)
94 return MatchVal == N;
95 return N.getNode();
96 }
97 };
98
99 /// Match any valid SDValue.
m_Value()100 inline Value_match m_Value() { return Value_match(); }
101
m_Specific(SDValue N)102 inline Value_match m_Specific(SDValue N) {
103 assert(N);
104 return Value_match(N);
105 }
106
107 struct DeferredValue_match {
108 SDValue &MatchVal;
109
DeferredValue_matchDeferredValue_match110 explicit DeferredValue_match(SDValue &Match) : MatchVal(Match) {}
111
matchDeferredValue_match112 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
113 return N == MatchVal;
114 }
115 };
116
117 /// Similar to m_Specific, but the specific value to match is determined by
118 /// another sub-pattern in the same sd_match() expression. For instance,
119 /// We cannot match `(add V, V)` with `m_Add(m_Value(X), m_Specific(X))` since
120 /// `X` is not initialized at the time it got copied into `m_Specific`. Instead,
121 /// we should use `m_Add(m_Value(X), m_Deferred(X))`.
m_Deferred(SDValue & V)122 inline DeferredValue_match m_Deferred(SDValue &V) {
123 return DeferredValue_match(V);
124 }
125
126 struct Opcode_match {
127 unsigned Opcode;
128
Opcode_matchOpcode_match129 explicit Opcode_match(unsigned Opc) : Opcode(Opc) {}
130
131 template <typename MatchContext>
matchOpcode_match132 bool match(const MatchContext &Ctx, SDValue N) {
133 return Ctx.match(N, Opcode);
134 }
135 };
136
m_Opc(unsigned Opcode)137 inline Opcode_match m_Opc(unsigned Opcode) { return Opcode_match(Opcode); }
138
139 template <unsigned NumUses, typename Pattern> struct NUses_match {
140 Pattern P;
141
NUses_matchNUses_match142 explicit NUses_match(const Pattern &P) : P(P) {}
143
144 template <typename MatchContext>
matchNUses_match145 bool match(const MatchContext &Ctx, SDValue N) {
146 // SDNode::hasNUsesOfValue is pretty expensive when the SDNode produces
147 // multiple results, hence we check the subsequent pattern here before
148 // checking the number of value users.
149 return P.match(Ctx, N) && N->hasNUsesOfValue(NumUses, N.getResNo());
150 }
151 };
152
153 template <typename Pattern>
m_OneUse(const Pattern & P)154 inline NUses_match<1, Pattern> m_OneUse(const Pattern &P) {
155 return NUses_match<1, Pattern>(P);
156 }
157 template <unsigned N, typename Pattern>
m_NUses(const Pattern & P)158 inline NUses_match<N, Pattern> m_NUses(const Pattern &P) {
159 return NUses_match<N, Pattern>(P);
160 }
161
m_OneUse()162 inline NUses_match<1, Value_match> m_OneUse() {
163 return NUses_match<1, Value_match>(m_Value());
164 }
m_NUses()165 template <unsigned N> inline NUses_match<N, Value_match> m_NUses() {
166 return NUses_match<N, Value_match>(m_Value());
167 }
168
169 struct Value_bind {
170 SDValue &BindVal;
171
Value_bindValue_bind172 explicit Value_bind(SDValue &N) : BindVal(N) {}
173
matchValue_bind174 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
175 BindVal = N;
176 return true;
177 }
178 };
179
m_Value(SDValue & N)180 inline Value_bind m_Value(SDValue &N) { return Value_bind(N); }
181
182 template <typename Pattern, typename PredFuncT> struct TLI_pred_match {
183 Pattern P;
184 PredFuncT PredFunc;
185
TLI_pred_matchTLI_pred_match186 TLI_pred_match(const PredFuncT &Pred, const Pattern &P)
187 : P(P), PredFunc(Pred) {}
188
189 template <typename MatchContext>
matchTLI_pred_match190 bool match(const MatchContext &Ctx, SDValue N) {
191 assert(Ctx.getTLI() && "TargetLowering is required for this pattern.");
192 return PredFunc(*Ctx.getTLI(), N) && P.match(Ctx, N);
193 }
194 };
195
196 // Explicit deduction guide.
197 template <typename PredFuncT, typename Pattern>
198 TLI_pred_match(const PredFuncT &Pred, const Pattern &P)
199 -> TLI_pred_match<Pattern, PredFuncT>;
200
201 /// Match legal SDNodes based on the information provided by TargetLowering.
m_LegalOp(const Pattern & P)202 template <typename Pattern> inline auto m_LegalOp(const Pattern &P) {
203 return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) {
204 return TLI.isOperationLegal(N->getOpcode(),
205 N.getValueType());
206 },
207 P};
208 }
209
210 /// Switch to a different MatchContext for subsequent patterns.
211 template <typename NewMatchContext, typename Pattern> struct SwitchContext {
212 const NewMatchContext &Ctx;
213 Pattern P;
214
215 template <typename OrigMatchContext>
matchSwitchContext216 bool match(const OrigMatchContext &, SDValue N) {
217 return P.match(Ctx, N);
218 }
219 };
220
221 template <typename MatchContext, typename Pattern>
m_Context(const MatchContext & Ctx,Pattern && P)222 inline SwitchContext<MatchContext, Pattern> m_Context(const MatchContext &Ctx,
223 Pattern &&P) {
224 return SwitchContext<MatchContext, Pattern>{Ctx, std::move(P)};
225 }
226
227 // === Value type ===
228 struct ValueType_bind {
229 EVT &BindVT;
230
ValueType_bindValueType_bind231 explicit ValueType_bind(EVT &Bind) : BindVT(Bind) {}
232
matchValueType_bind233 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
234 BindVT = N.getValueType();
235 return true;
236 }
237 };
238
239 /// Retreive the ValueType of the current SDValue.
m_VT(EVT & VT)240 inline ValueType_bind m_VT(EVT &VT) { return ValueType_bind(VT); }
241
242 template <typename Pattern, typename PredFuncT> struct ValueType_match {
243 PredFuncT PredFunc;
244 Pattern P;
245
ValueType_matchValueType_match246 ValueType_match(const PredFuncT &Pred, const Pattern &P)
247 : PredFunc(Pred), P(P) {}
248
249 template <typename MatchContext>
matchValueType_match250 bool match(const MatchContext &Ctx, SDValue N) {
251 return PredFunc(N.getValueType()) && P.match(Ctx, N);
252 }
253 };
254
255 // Explicit deduction guide.
256 template <typename PredFuncT, typename Pattern>
257 ValueType_match(const PredFuncT &Pred, const Pattern &P)
258 -> ValueType_match<Pattern, PredFuncT>;
259
260 /// Match a specific ValueType.
261 template <typename Pattern>
m_SpecificVT(EVT RefVT,const Pattern & P)262 inline auto m_SpecificVT(EVT RefVT, const Pattern &P) {
263 return ValueType_match{[=](EVT VT) { return VT == RefVT; }, P};
264 }
m_SpecificVT(EVT RefVT)265 inline auto m_SpecificVT(EVT RefVT) {
266 return ValueType_match{[=](EVT VT) { return VT == RefVT; }, m_Value()};
267 }
268
m_Glue()269 inline auto m_Glue() { return m_SpecificVT(MVT::Glue); }
m_OtherVT()270 inline auto m_OtherVT() { return m_SpecificVT(MVT::Other); }
271
272 /// Match any integer ValueTypes.
m_IntegerVT(const Pattern & P)273 template <typename Pattern> inline auto m_IntegerVT(const Pattern &P) {
274 return ValueType_match{[](EVT VT) { return VT.isInteger(); }, P};
275 }
m_IntegerVT()276 inline auto m_IntegerVT() {
277 return ValueType_match{[](EVT VT) { return VT.isInteger(); }, m_Value()};
278 }
279
280 /// Match any floating point ValueTypes.
m_FloatingPointVT(const Pattern & P)281 template <typename Pattern> inline auto m_FloatingPointVT(const Pattern &P) {
282 return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); }, P};
283 }
m_FloatingPointVT()284 inline auto m_FloatingPointVT() {
285 return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); },
286 m_Value()};
287 }
288
289 /// Match any vector ValueTypes.
m_VectorVT(const Pattern & P)290 template <typename Pattern> inline auto m_VectorVT(const Pattern &P) {
291 return ValueType_match{[](EVT VT) { return VT.isVector(); }, P};
292 }
m_VectorVT()293 inline auto m_VectorVT() {
294 return ValueType_match{[](EVT VT) { return VT.isVector(); }, m_Value()};
295 }
296
297 /// Match fixed-length vector ValueTypes.
m_FixedVectorVT(const Pattern & P)298 template <typename Pattern> inline auto m_FixedVectorVT(const Pattern &P) {
299 return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); }, P};
300 }
m_FixedVectorVT()301 inline auto m_FixedVectorVT() {
302 return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); },
303 m_Value()};
304 }
305
306 /// Match scalable vector ValueTypes.
m_ScalableVectorVT(const Pattern & P)307 template <typename Pattern> inline auto m_ScalableVectorVT(const Pattern &P) {
308 return ValueType_match{[](EVT VT) { return VT.isScalableVector(); }, P};
309 }
m_ScalableVectorVT()310 inline auto m_ScalableVectorVT() {
311 return ValueType_match{[](EVT VT) { return VT.isScalableVector(); },
312 m_Value()};
313 }
314
315 /// Match legal ValueTypes based on the information provided by TargetLowering.
m_LegalType(const Pattern & P)316 template <typename Pattern> inline auto m_LegalType(const Pattern &P) {
317 return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) {
318 return TLI.isTypeLegal(N.getValueType());
319 },
320 P};
321 }
322
323 // === Patterns combinators ===
324 template <typename... Preds> struct And {
matchAnd325 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
326 return true;
327 }
328 };
329
330 template <typename Pred, typename... Preds>
331 struct And<Pred, Preds...> : And<Preds...> {
332 Pred P;
333 And(Pred &&p, Preds &&...preds)
334 : And<Preds...>(std::forward<Preds>(preds)...), P(std::forward<Pred>(p)) {
335 }
336
337 template <typename MatchContext>
338 bool match(const MatchContext &Ctx, SDValue N) {
339 return P.match(Ctx, N) && And<Preds...>::match(Ctx, N);
340 }
341 };
342
343 template <typename... Preds> struct Or {
344 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
345 return false;
346 }
347 };
348
349 template <typename Pred, typename... Preds>
350 struct Or<Pred, Preds...> : Or<Preds...> {
351 Pred P;
352 Or(Pred &&p, Preds &&...preds)
353 : Or<Preds...>(std::forward<Preds>(preds)...), P(std::forward<Pred>(p)) {}
354
355 template <typename MatchContext>
356 bool match(const MatchContext &Ctx, SDValue N) {
357 return P.match(Ctx, N) || Or<Preds...>::match(Ctx, N);
358 }
359 };
360
361 template <typename Pred> struct Not {
362 Pred P;
363
364 explicit Not(const Pred &P) : P(P) {}
365
366 template <typename MatchContext>
367 bool match(const MatchContext &Ctx, SDValue N) {
368 return !P.match(Ctx, N);
369 }
370 };
371 // Explicit deduction guide.
372 template <typename Pred> Not(const Pred &P) -> Not<Pred>;
373
374 /// Match if the inner pattern does NOT match.
375 template <typename Pred> inline Not<Pred> m_Unless(const Pred &P) {
376 return Not{P};
377 }
378
379 template <typename... Preds> And<Preds...> m_AllOf(Preds &&...preds) {
380 return And<Preds...>(std::forward<Preds>(preds)...);
381 }
382
383 template <typename... Preds> Or<Preds...> m_AnyOf(Preds &&...preds) {
384 return Or<Preds...>(std::forward<Preds>(preds)...);
385 }
386
387 template <typename... Preds> auto m_NoneOf(Preds &&...preds) {
388 return m_Unless(m_AnyOf(std::forward<Preds>(preds)...));
389 }
390
391 // === Generic node matching ===
392 template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
393 template <typename MatchContext>
394 bool match(const MatchContext &Ctx, SDValue N) {
395 // Returns false if there are more operands than predicates;
396 return N->getNumOperands() == OpIdx;
397 }
398 };
399
400 template <unsigned OpIdx, typename OpndPred, typename... OpndPreds>
401 struct Operands_match<OpIdx, OpndPred, OpndPreds...>
402 : Operands_match<OpIdx + 1, OpndPreds...> {
403 OpndPred P;
404
405 Operands_match(OpndPred &&p, OpndPreds &&...preds)
406 : Operands_match<OpIdx + 1, OpndPreds...>(
407 std::forward<OpndPreds>(preds)...),
408 P(std::forward<OpndPred>(p)) {}
409
410 template <typename MatchContext>
411 bool match(const MatchContext &Ctx, SDValue N) {
412 if (OpIdx < N->getNumOperands())
413 return P.match(Ctx, N->getOperand(OpIdx)) &&
414 Operands_match<OpIdx + 1, OpndPreds...>::match(Ctx, N);
415
416 // This is the case where there are more predicates than operands.
417 return false;
418 }
419 };
420
421 template <typename... OpndPreds>
422 auto m_Node(unsigned Opcode, OpndPreds &&...preds) {
423 return m_AllOf(m_Opc(Opcode), Operands_match<0, OpndPreds...>(
424 std::forward<OpndPreds>(preds)...));
425 }
426
427 /// Provide number of operands that are not chain or glue, as well as the first
428 /// index of such operand.
429 template <bool ExcludeChain> struct EffectiveOperands {
430 unsigned Size = 0;
431 unsigned FirstIndex = 0;
432
433 explicit EffectiveOperands(SDValue N) {
434 const unsigned TotalNumOps = N->getNumOperands();
435 FirstIndex = TotalNumOps;
436 for (unsigned I = 0; I < TotalNumOps; ++I) {
437 // Count the number of non-chain and non-glue nodes (we ignore chain
438 // and glue by default) and retreive the operand index offset.
439 EVT VT = N->getOperand(I).getValueType();
440 if (VT != MVT::Glue && VT != MVT::Other) {
441 ++Size;
442 if (FirstIndex == TotalNumOps)
443 FirstIndex = I;
444 }
445 }
446 }
447 };
448
449 template <> struct EffectiveOperands<false> {
450 unsigned Size = 0;
451 unsigned FirstIndex = 0;
452
453 explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {}
454 };
455
456 // === Binary operations ===
457 template <typename LHS_P, typename RHS_P, bool Commutable = false,
458 bool ExcludeChain = false>
459 struct BinaryOpc_match {
460 unsigned Opcode;
461 LHS_P LHS;
462 RHS_P RHS;
463
464 BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R)
465 : Opcode(Opc), LHS(L), RHS(R) {}
466
467 template <typename MatchContext>
468 bool match(const MatchContext &Ctx, SDValue N) {
469 if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
470 EffectiveOperands<ExcludeChain> EO(N);
471 assert(EO.Size == 2);
472 return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
473 RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
474 (Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
475 RHS.match(Ctx, N->getOperand(EO.FirstIndex)));
476 }
477
478 return false;
479 }
480 };
481
482 template <typename LHS, typename RHS>
483 inline BinaryOpc_match<LHS, RHS, false> m_BinOp(unsigned Opc, const LHS &L,
484 const RHS &R) {
485 return BinaryOpc_match<LHS, RHS, false>(Opc, L, R);
486 }
487 template <typename LHS, typename RHS>
488 inline BinaryOpc_match<LHS, RHS, true> m_c_BinOp(unsigned Opc, const LHS &L,
489 const RHS &R) {
490 return BinaryOpc_match<LHS, RHS, true>(Opc, L, R);
491 }
492
493 template <typename LHS, typename RHS>
494 inline BinaryOpc_match<LHS, RHS, false, true>
495 m_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) {
496 return BinaryOpc_match<LHS, RHS, false, true>(Opc, L, R);
497 }
498 template <typename LHS, typename RHS>
499 inline BinaryOpc_match<LHS, RHS, true, true>
500 m_c_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) {
501 return BinaryOpc_match<LHS, RHS, true, true>(Opc, L, R);
502 }
503
504 // Common binary operations
505 template <typename LHS, typename RHS>
506 inline BinaryOpc_match<LHS, RHS, true> m_Add(const LHS &L, const RHS &R) {
507 return BinaryOpc_match<LHS, RHS, true>(ISD::ADD, L, R);
508 }
509
510 template <typename LHS, typename RHS>
511 inline BinaryOpc_match<LHS, RHS, false> m_Sub(const LHS &L, const RHS &R) {
512 return BinaryOpc_match<LHS, RHS, false>(ISD::SUB, L, R);
513 }
514
515 template <typename LHS, typename RHS>
516 inline BinaryOpc_match<LHS, RHS, true> m_Mul(const LHS &L, const RHS &R) {
517 return BinaryOpc_match<LHS, RHS, true>(ISD::MUL, L, R);
518 }
519
520 template <typename LHS, typename RHS>
521 inline BinaryOpc_match<LHS, RHS, true> m_And(const LHS &L, const RHS &R) {
522 return BinaryOpc_match<LHS, RHS, true>(ISD::AND, L, R);
523 }
524
525 template <typename LHS, typename RHS>
526 inline BinaryOpc_match<LHS, RHS, true> m_Or(const LHS &L, const RHS &R) {
527 return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R);
528 }
529
530 template <typename LHS, typename RHS>
531 inline BinaryOpc_match<LHS, RHS, true> m_Xor(const LHS &L, const RHS &R) {
532 return BinaryOpc_match<LHS, RHS, true>(ISD::XOR, L, R);
533 }
534
535 template <typename LHS, typename RHS>
536 inline BinaryOpc_match<LHS, RHS, true> m_SMin(const LHS &L, const RHS &R) {
537 return BinaryOpc_match<LHS, RHS, true>(ISD::SMIN, L, R);
538 }
539
540 template <typename LHS, typename RHS>
541 inline BinaryOpc_match<LHS, RHS, true> m_SMax(const LHS &L, const RHS &R) {
542 return BinaryOpc_match<LHS, RHS, true>(ISD::SMAX, L, R);
543 }
544
545 template <typename LHS, typename RHS>
546 inline BinaryOpc_match<LHS, RHS, true> m_UMin(const LHS &L, const RHS &R) {
547 return BinaryOpc_match<LHS, RHS, true>(ISD::UMIN, L, R);
548 }
549
550 template <typename LHS, typename RHS>
551 inline BinaryOpc_match<LHS, RHS, true> m_UMax(const LHS &L, const RHS &R) {
552 return BinaryOpc_match<LHS, RHS, true>(ISD::UMAX, L, R);
553 }
554
555 template <typename LHS, typename RHS>
556 inline BinaryOpc_match<LHS, RHS, false> m_UDiv(const LHS &L, const RHS &R) {
557 return BinaryOpc_match<LHS, RHS, false>(ISD::UDIV, L, R);
558 }
559 template <typename LHS, typename RHS>
560 inline BinaryOpc_match<LHS, RHS, false> m_SDiv(const LHS &L, const RHS &R) {
561 return BinaryOpc_match<LHS, RHS, false>(ISD::SDIV, L, R);
562 }
563
564 template <typename LHS, typename RHS>
565 inline BinaryOpc_match<LHS, RHS, false> m_URem(const LHS &L, const RHS &R) {
566 return BinaryOpc_match<LHS, RHS, false>(ISD::UREM, L, R);
567 }
568 template <typename LHS, typename RHS>
569 inline BinaryOpc_match<LHS, RHS, false> m_SRem(const LHS &L, const RHS &R) {
570 return BinaryOpc_match<LHS, RHS, false>(ISD::SREM, L, R);
571 }
572
573 template <typename LHS, typename RHS>
574 inline BinaryOpc_match<LHS, RHS, false> m_Shl(const LHS &L, const RHS &R) {
575 return BinaryOpc_match<LHS, RHS, false>(ISD::SHL, L, R);
576 }
577
578 template <typename LHS, typename RHS>
579 inline BinaryOpc_match<LHS, RHS, false> m_Sra(const LHS &L, const RHS &R) {
580 return BinaryOpc_match<LHS, RHS, false>(ISD::SRA, L, R);
581 }
582 template <typename LHS, typename RHS>
583 inline BinaryOpc_match<LHS, RHS, false> m_Srl(const LHS &L, const RHS &R) {
584 return BinaryOpc_match<LHS, RHS, false>(ISD::SRL, L, R);
585 }
586
587 template <typename LHS, typename RHS>
588 inline BinaryOpc_match<LHS, RHS, true> m_FAdd(const LHS &L, const RHS &R) {
589 return BinaryOpc_match<LHS, RHS, true>(ISD::FADD, L, R);
590 }
591
592 template <typename LHS, typename RHS>
593 inline BinaryOpc_match<LHS, RHS, false> m_FSub(const LHS &L, const RHS &R) {
594 return BinaryOpc_match<LHS, RHS, false>(ISD::FSUB, L, R);
595 }
596
597 template <typename LHS, typename RHS>
598 inline BinaryOpc_match<LHS, RHS, true> m_FMul(const LHS &L, const RHS &R) {
599 return BinaryOpc_match<LHS, RHS, true>(ISD::FMUL, L, R);
600 }
601
602 template <typename LHS, typename RHS>
603 inline BinaryOpc_match<LHS, RHS, false> m_FDiv(const LHS &L, const RHS &R) {
604 return BinaryOpc_match<LHS, RHS, false>(ISD::FDIV, L, R);
605 }
606
607 template <typename LHS, typename RHS>
608 inline BinaryOpc_match<LHS, RHS, false> m_FRem(const LHS &L, const RHS &R) {
609 return BinaryOpc_match<LHS, RHS, false>(ISD::FREM, L, R);
610 }
611
612 // === Unary operations ===
613 template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
614 unsigned Opcode;
615 Opnd_P Opnd;
616
617 UnaryOpc_match(unsigned Opc, const Opnd_P &Op) : Opcode(Opc), Opnd(Op) {}
618
619 template <typename MatchContext>
620 bool match(const MatchContext &Ctx, SDValue N) {
621 if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
622 EffectiveOperands<ExcludeChain> EO(N);
623 assert(EO.Size == 1);
624 return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
625 }
626
627 return false;
628 }
629 };
630
631 template <typename Opnd>
632 inline UnaryOpc_match<Opnd> m_UnaryOp(unsigned Opc, const Opnd &Op) {
633 return UnaryOpc_match<Opnd>(Opc, Op);
634 }
635 template <typename Opnd>
636 inline UnaryOpc_match<Opnd, true> m_ChainedUnaryOp(unsigned Opc,
637 const Opnd &Op) {
638 return UnaryOpc_match<Opnd, true>(Opc, Op);
639 }
640
641 template <typename Opnd>
642 inline UnaryOpc_match<Opnd> m_BitReverse(const Opnd &Op) {
643 return UnaryOpc_match<Opnd>(ISD::BITREVERSE, Op);
644 }
645
646 template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
647 return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
648 }
649
650 template <typename Opnd> inline auto m_SExt(Opnd &&Op) {
651 return m_AnyOf(
652 UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op),
653 m_Node(ISD::SIGN_EXTEND_INREG, std::forward<Opnd>(Op), m_Value()));
654 }
655
656 template <typename Opnd> inline UnaryOpc_match<Opnd> m_AnyExt(const Opnd &Op) {
657 return UnaryOpc_match<Opnd>(ISD::ANY_EXTEND, Op);
658 }
659
660 template <typename Opnd> inline UnaryOpc_match<Opnd> m_Trunc(const Opnd &Op) {
661 return UnaryOpc_match<Opnd>(ISD::TRUNCATE, Op);
662 }
663
/// Match (zero_extend Op) or Op itself, which allows peeking through an
/// optional zero extension.
template <typename Opnd> inline auto m_ZExtOrSelf(Opnd &&Op) {
  // m_ZExt takes its operand by const reference and copies it, so Op is still
  // intact when it is forwarded as the identity alternative.
  auto ZExt = m_ZExt(Op);
  return m_AnyOf(std::move(ZExt), std::forward<Opnd>(Op));
}
669
/// Match a sign extension of Op (see m_SExt) or Op itself, which allows
/// peeking through an optional sign extension.
template <typename Opnd> inline auto m_SExtOrSelf(Opnd &&Op) {
  // m_SExt consumes its argument and may move from it. Give it its own copy
  // (or, for an lvalue argument, the same reference) so that Op is still
  // intact when it is forwarded as the identity alternative. Previously Op
  // was forwarded twice, leaving the identity alternative moved-from.
  return m_AnyOf(m_SExt(static_cast<Opnd>(Op)), std::forward<Opnd>(Op));
}
675
676 /// Match a aext or identity
677 /// Allows to peek through optional extensions
678 template <typename Opnd>
679 inline Or<UnaryOpc_match<Opnd>, Opnd> m_AExtOrSelf(Opnd &&Op) {
680 return Or<UnaryOpc_match<Opnd>, Opnd>(m_AnyExt(std::forward<Opnd>(Op)),
681 std::forward<Opnd>(Op));
682 }
683
684 /// Match a trunc or identity
685 /// Allows to peek through optional truncations
686 template <typename Opnd>
687 inline Or<UnaryOpc_match<Opnd>, Opnd> m_TruncOrSelf(Opnd &&Op) {
688 return Or<UnaryOpc_match<Opnd>, Opnd>(m_Trunc(std::forward<Opnd>(Op)),
689 std::forward<Opnd>(Op));
690 }
691
692 // === Constants ===
693 struct ConstantInt_match {
694 APInt *BindVal;
695
696 explicit ConstantInt_match(APInt *V) : BindVal(V) {}
697
698 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
699 // The logics here are similar to that in
700 // SelectionDAG::isConstantIntBuildVectorOrConstantInt, but the latter also
701 // treats GlobalAddressSDNode as a constant, which is difficult to turn into
702 // APInt.
703 if (auto *C = dyn_cast_or_null<ConstantSDNode>(N.getNode())) {
704 if (BindVal)
705 *BindVal = C->getAPIntValue();
706 return true;
707 }
708
709 APInt Discard;
710 return ISD::isConstantSplatVector(N.getNode(),
711 BindVal ? *BindVal : Discard);
712 }
713 };
714 /// Match any interger constants or splat of an integer constant.
715 inline ConstantInt_match m_ConstInt() { return ConstantInt_match(nullptr); }
716 /// Match any interger constants or splat of an integer constant; return the
717 /// specific constant or constant splat value.
718 inline ConstantInt_match m_ConstInt(APInt &V) { return ConstantInt_match(&V); }
719
720 struct SpecificInt_match {
721 APInt IntVal;
722
723 explicit SpecificInt_match(APInt APV) : IntVal(std::move(APV)) {}
724
725 template <typename MatchContext>
726 bool match(const MatchContext &Ctx, SDValue N) {
727 APInt ConstInt;
728 if (sd_context_match(N, Ctx, m_ConstInt(ConstInt)))
729 return APInt::isSameValue(IntVal, ConstInt);
730 return false;
731 }
732 };
733
734 /// Match a specific integer constant or constant splat value.
735 inline SpecificInt_match m_SpecificInt(APInt V) {
736 return SpecificInt_match(std::move(V));
737 }
738 inline SpecificInt_match m_SpecificInt(uint64_t V) {
739 return SpecificInt_match(APInt(64, V));
740 }
741
742 inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); }
743 inline SpecificInt_match m_One() { return m_SpecificInt(1U); }
744
745 struct AllOnes_match {
746
747 AllOnes_match() = default;
748
749 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
750 return isAllOnesOrAllOnesSplat(N);
751 }
752 };
753
754 inline AllOnes_match m_AllOnes() { return AllOnes_match(); }
755
756 /// Match true boolean value based on the information provided by
757 /// TargetLowering.
758 inline auto m_True() {
759 return TLI_pred_match{
760 [](const TargetLowering &TLI, SDValue N) {
761 APInt ConstVal;
762 if (sd_match(N, m_ConstInt(ConstVal)))
763 switch (TLI.getBooleanContents(N.getValueType())) {
764 case TargetLowering::ZeroOrOneBooleanContent:
765 return ConstVal.isOne();
766 case TargetLowering::ZeroOrNegativeOneBooleanContent:
767 return ConstVal.isAllOnes();
768 case TargetLowering::UndefinedBooleanContent:
769 return (ConstVal & 0x01) == 1;
770 }
771
772 return false;
773 },
774 m_Value()};
775 }
776 /// Match false boolean value based on the information provided by
777 /// TargetLowering.
778 inline auto m_False() {
779 return TLI_pred_match{
780 [](const TargetLowering &TLI, SDValue N) {
781 APInt ConstVal;
782 if (sd_match(N, m_ConstInt(ConstVal)))
783 switch (TLI.getBooleanContents(N.getValueType())) {
784 case TargetLowering::ZeroOrOneBooleanContent:
785 case TargetLowering::ZeroOrNegativeOneBooleanContent:
786 return ConstVal.isZero();
787 case TargetLowering::UndefinedBooleanContent:
788 return (ConstVal & 0x01) == 0;
789 }
790
791 return false;
792 },
793 m_Value()};
794 }
795
796 struct CondCode_match {
797 std::optional<ISD::CondCode> CCToMatch;
798 ISD::CondCode *BindCC = nullptr;
799
800 explicit CondCode_match(ISD::CondCode CC) : CCToMatch(CC) {}
801
802 explicit CondCode_match(ISD::CondCode *CC) : BindCC(CC) {}
803
804 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
805 if (auto *CC = dyn_cast<CondCodeSDNode>(N.getNode())) {
806 if (CCToMatch && *CCToMatch != CC->get())
807 return false;
808
809 if (BindCC)
810 *BindCC = CC->get();
811 return true;
812 }
813
814 return false;
815 }
816 };
817
818 /// Match any conditional code SDNode.
819 inline CondCode_match m_CondCode() { return CondCode_match(nullptr); }
820 /// Match any conditional code SDNode and return its ISD::CondCode value.
821 inline CondCode_match m_CondCode(ISD::CondCode &CC) {
822 return CondCode_match(&CC);
823 }
824 /// Match a conditional code SDNode with a specific ISD::CondCode.
825 inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
826 return CondCode_match(CC);
827 }
828
829 /// Match a negate as a sub(0, v)
830 template <typename ValTy>
831 inline BinaryOpc_match<SpecificInt_match, ValTy> m_Neg(const ValTy &V) {
832 return m_Sub(m_Zero(), V);
833 }
834
835 /// Match a Not as a xor(v, -1) or xor(-1, v)
836 template <typename ValTy>
837 inline BinaryOpc_match<ValTy, AllOnes_match, true> m_Not(const ValTy &V) {
838 return m_Xor(V, m_AllOnes());
839 }
840
841 } // namespace SDPatternMatch
842 } // namespace llvm
843 #endif
844