1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssa
6
7import (
8	"cmd/compile/internal/base"
9	"cmd/compile/internal/types"
10	"fmt"
11)
12
// indVarFlags is a bit set that tells how the min/max bounds of an indVar
// are to be read and in which direction the variable moves.
type indVarFlags uint8

const (
	// indVarMinExc: the minimum bound is exclusive (without it, inclusive).
	indVarMinExc indVarFlags = 1 << 0
	// indVarMaxInc: the maximum bound is inclusive (without it, exclusive).
	indVarMaxInc indVarFlags = 1 << 1
	// indVarCountDown: iteration starts at max and moves towards min
	// (without it, iteration moves from min towards max).
	indVarCountDown indVarFlags = 1 << 2
)
20
// indVar records one induction variable found by findIndVar, together with
// the bounds its value is known to respect inside the loop.
type indVar struct {
	ind   *Value // induction variable (the loop-header Phi)
	nxt   *Value // the incremented variable
	min   *Value // minimum value, inclusive/exclusive depends on flags
	max   *Value // maximum value, inclusive/exclusive depends on flags
	entry *Block // entry block in the loop: the header successor taken when the loop condition holds.
	flags indVarFlags
	// Invariant: for all blocks strictly dominated by entry:
	//	min <= ind <  max    [if flags == 0]
	//	min <  ind <  max    [if flags == indVarMinExc]
	//	min <= ind <= max    [if flags == indVarMaxInc]
	//	min <  ind <= max    [if flags == indVarMinExc|indVarMaxInc]
	// The indVarCountDown bit only records the iteration direction; ignore it
	// when reading the table above.
}
34
35// parseIndVar checks whether the SSA value passed as argument is a valid induction
36// variable, and, if so, extracts:
37//   - the minimum bound
38//   - the increment value
39//   - the "next" value (SSA value that is Phi'd into the induction variable every loop)
40//
41// Currently, we detect induction variables that match (Phi min nxt),
42// with nxt being (Add inc ind).
43// If it can't parse the induction variable correctly, it returns (nil, nil, nil).
44func parseIndVar(ind *Value) (min, inc, nxt *Value) {
45	if ind.Op != OpPhi {
46		return
47	}
48
49	if n := ind.Args[0]; (n.Op == OpAdd64 || n.Op == OpAdd32 || n.Op == OpAdd16 || n.Op == OpAdd8) && (n.Args[0] == ind || n.Args[1] == ind) {
50		min, nxt = ind.Args[1], n
51	} else if n := ind.Args[1]; (n.Op == OpAdd64 || n.Op == OpAdd32 || n.Op == OpAdd16 || n.Op == OpAdd8) && (n.Args[0] == ind || n.Args[1] == ind) {
52		min, nxt = ind.Args[0], n
53	} else {
54		// Not a recognized induction variable.
55		return
56	}
57
58	if nxt.Args[0] == ind { // nxt = ind + inc
59		inc = nxt.Args[1]
60	} else if nxt.Args[1] == ind { // nxt = inc + ind
61		inc = nxt.Args[0]
62	} else {
63		panic("unreachable") // one of the cases must be true from the above.
64	}
65
66	return
67}
68
69// findIndVar finds induction variables in a function.
70//
71// Look for variables and blocks that satisfy the following
72//
73//	 loop:
74//	   ind = (Phi min nxt),
75//	   if ind < max
76//	     then goto enter_loop
77//	     else goto exit_loop
78//
79//	   enter_loop:
80//		do something
81//	      nxt = inc + ind
82//		goto loop
83//
84//	 exit_loop:
85func findIndVar(f *Func) []indVar {
86	var iv []indVar
87	sdom := f.Sdom()
88
89	for _, b := range f.Blocks {
90		if b.Kind != BlockIf || len(b.Preds) != 2 {
91			continue
92		}
93
94		var ind *Value   // induction variable
95		var init *Value  // starting value
96		var limit *Value // ending value
97
98		// Check that the control if it either ind </<= limit or limit </<= ind.
99		// TODO: Handle unsigned comparisons?
100		c := b.Controls[0]
101		inclusive := false
102		switch c.Op {
103		case OpLeq64, OpLeq32, OpLeq16, OpLeq8:
104			inclusive = true
105			fallthrough
106		case OpLess64, OpLess32, OpLess16, OpLess8:
107			ind, limit = c.Args[0], c.Args[1]
108		default:
109			continue
110		}
111
112		// See if this is really an induction variable
113		less := true
114		init, inc, nxt := parseIndVar(ind)
115		if init == nil {
116			// We failed to parse the induction variable. Before punting, we want to check
117			// whether the control op was written with the induction variable on the RHS
118			// instead of the LHS. This happens for the downwards case, like:
119			//     for i := len(n)-1; i >= 0; i--
120			init, inc, nxt = parseIndVar(limit)
121			if init == nil {
122				// No recognized induction variable on either operand
123				continue
124			}
125
126			// Ok, the arguments were reversed. Swap them, and remember that we're
127			// looking at an ind >/>= loop (so the induction must be decrementing).
128			ind, limit = limit, ind
129			less = false
130		}
131
132		if ind.Block != b {
133			// TODO: Could be extended to include disjointed loop headers.
134			// I don't think this is causing missed optimizations in real world code often.
135			// See https://go.dev/issue/63955
136			continue
137		}
138
139		// Expect the increment to be a nonzero constant.
140		if !inc.isGenericIntConst() {
141			continue
142		}
143		step := inc.AuxInt
144		if step == 0 {
145			continue
146		}
147
148		// Increment sign must match comparison direction.
149		// When incrementing, the termination comparison must be ind </<= limit.
150		// When decrementing, the termination comparison must be ind >/>= limit.
151		// See issue 26116.
152		if step > 0 && !less {
153			continue
154		}
155		if step < 0 && less {
156			continue
157		}
158
159		// Up to now we extracted the induction variable (ind),
160		// the increment delta (inc), the temporary sum (nxt),
161		// the initial value (init) and the limiting value (limit).
162		//
163		// We also know that ind has the form (Phi init nxt) where
164		// nxt is (Add inc nxt) which means: 1) inc dominates nxt
165		// and 2) there is a loop starting at inc and containing nxt.
166		//
167		// We need to prove that the induction variable is incremented
168		// only when it's smaller than the limiting value.
169		// Two conditions must happen listed below to accept ind
170		// as an induction variable.
171
172		// First condition: loop entry has a single predecessor, which
173		// is the header block.  This implies that b.Succs[0] is
174		// reached iff ind < limit.
175		if len(b.Succs[0].b.Preds) != 1 {
176			// b.Succs[1] must exit the loop.
177			continue
178		}
179
180		// Second condition: b.Succs[0] dominates nxt so that
181		// nxt is computed when inc < limit.
182		if !sdom.IsAncestorEq(b.Succs[0].b, nxt.Block) {
183			// inc+ind can only be reached through the branch that enters the loop.
184			continue
185		}
186
187		// Check for overflow/underflow. We need to make sure that inc never causes
188		// the induction variable to wrap around.
189		// We use a function wrapper here for easy return true / return false / keep going logic.
190		// This function returns true if the increment will never overflow/underflow.
191		ok := func() bool {
192			if step > 0 {
193				if limit.isGenericIntConst() {
194					// Figure out the actual largest value.
195					v := limit.AuxInt
196					if !inclusive {
197						if v == minSignedValue(limit.Type) {
198							return false // < minint is never satisfiable.
199						}
200						v--
201					}
202					if init.isGenericIntConst() {
203						// Use stride to compute a better lower limit.
204						if init.AuxInt > v {
205							return false
206						}
207						v = addU(init.AuxInt, diff(v, init.AuxInt)/uint64(step)*uint64(step))
208					}
209					if addWillOverflow(v, step) {
210						return false
211					}
212					if inclusive && v != limit.AuxInt || !inclusive && v+1 != limit.AuxInt {
213						// We know a better limit than the programmer did. Use our limit instead.
214						limit = f.constVal(limit.Op, limit.Type, v, true)
215						inclusive = true
216					}
217					return true
218				}
219				if step == 1 && !inclusive {
220					// Can't overflow because maxint is never a possible value.
221					return true
222				}
223				// If the limit is not a constant, check to see if it is a
224				// negative offset from a known non-negative value.
225				knn, k := findKNN(limit)
226				if knn == nil || k < 0 {
227					return false
228				}
229				// limit == (something nonnegative) - k. That subtraction can't underflow, so
230				// we can trust it.
231				if inclusive {
232					// ind <= knn - k cannot overflow if step is at most k
233					return step <= k
234				}
235				// ind < knn - k cannot overflow if step is at most k+1
236				return step <= k+1 && k != maxSignedValue(limit.Type)
237			} else { // step < 0
238				if limit.Op == OpConst64 {
239					// Figure out the actual smallest value.
240					v := limit.AuxInt
241					if !inclusive {
242						if v == maxSignedValue(limit.Type) {
243							return false // > maxint is never satisfiable.
244						}
245						v++
246					}
247					if init.isGenericIntConst() {
248						// Use stride to compute a better lower limit.
249						if init.AuxInt < v {
250							return false
251						}
252						v = subU(init.AuxInt, diff(init.AuxInt, v)/uint64(-step)*uint64(-step))
253					}
254					if subWillUnderflow(v, -step) {
255						return false
256					}
257					if inclusive && v != limit.AuxInt || !inclusive && v-1 != limit.AuxInt {
258						// We know a better limit than the programmer did. Use our limit instead.
259						limit = f.constVal(limit.Op, limit.Type, v, true)
260						inclusive = true
261					}
262					return true
263				}
264				if step == -1 && !inclusive {
265					// Can't underflow because minint is never a possible value.
266					return true
267				}
268			}
269			return false
270
271		}
272
273		if ok() {
274			flags := indVarFlags(0)
275			var min, max *Value
276			if step > 0 {
277				min = init
278				max = limit
279				if inclusive {
280					flags |= indVarMaxInc
281				}
282			} else {
283				min = limit
284				max = init
285				flags |= indVarMaxInc
286				if !inclusive {
287					flags |= indVarMinExc
288				}
289				flags |= indVarCountDown
290				step = -step
291			}
292			if f.pass.debug >= 1 {
293				printIndVar(b, ind, min, max, step, flags)
294			}
295
296			iv = append(iv, indVar{
297				ind:   ind,
298				nxt:   nxt,
299				min:   min,
300				max:   max,
301				entry: b.Succs[0].b,
302				flags: flags,
303			})
304			b.Logf("found induction variable %v (inc = %v, min = %v, max = %v)\n", ind, inc, min, max)
305		}
306
307		// TODO: other unrolling idioms
308		// for i := 0; i < KNN - KNN % k ; i += k
309		// for i := 0; i < KNN&^(k-1) ; i += k // k a power of 2
310		// for i := 0; i < KNN&(-k) ; i += k // k a power of 2
311	}
312
313	return iv
314}
315
// addWillOverflow reports whether x+y would result in a value more than maxint.
// A non-positive y can never push the sum above maxint, so only a positive y
// is checked: it overflows exactly when the wrapped (two's-complement) sum is
// smaller than x.
func addWillOverflow(x, y int64) bool {
	return y > 0 && x+y < x
}
320
// subWillUnderflow reports whether x-y would result in a value less than minint.
// It works on the wrapped two's-complement difference: for y > 0, x-y
// underflows exactly when the wrapped result ends up greater than x.
// The check is intended for non-negative y; for other negative y the
// comparison does not measure underflow.
func subWillUnderflow(x, y int64) bool {
	d := x - y
	return x < d
}
325
326// diff returns x-y as a uint64. Requires x>=y.
327func diff(x, y int64) uint64 {
328	if x < y {
329		base.Fatalf("diff %d - %d underflowed", x, y)
330	}
331	return uint64(x - y)
332}
333
334// addU returns x+y. Requires that x+y does not overflow an int64.
335func addU(x int64, y uint64) int64 {
336	if y >= 1<<63 {
337		if x >= 0 {
338			base.Fatalf("addU overflowed %d + %d", x, y)
339		}
340		x += 1<<63 - 1
341		x += 1
342		y -= 1 << 63
343	}
344	if addWillOverflow(x, int64(y)) {
345		base.Fatalf("addU overflowed %d + %d", x, y)
346	}
347	return x + int64(y)
348}
349
350// subU returns x-y. Requires that x-y does not underflow an int64.
351func subU(x int64, y uint64) int64 {
352	if y >= 1<<63 {
353		if x < 0 {
354			base.Fatalf("subU underflowed %d - %d", x, y)
355		}
356		x -= 1<<63 - 1
357		x -= 1
358		y -= 1 << 63
359	}
360	if subWillUnderflow(x, int64(y)) {
361		base.Fatalf("subU underflowed %d - %d", x, y)
362	}
363	return x - int64(y)
364}
365
366// if v is known to be x - c, where x is known to be nonnegative and c is a
367// constant, return x, c. Otherwise return nil, 0.
368func findKNN(v *Value) (*Value, int64) {
369	var x, y *Value
370	x = v
371	switch v.Op {
372	case OpSub64, OpSub32, OpSub16, OpSub8:
373		x = v.Args[0]
374		y = v.Args[1]
375
376	case OpAdd64, OpAdd32, OpAdd16, OpAdd8:
377		x = v.Args[0]
378		y = v.Args[1]
379		if x.isGenericIntConst() {
380			x, y = y, x
381		}
382	}
383	switch x.Op {
384	case OpSliceLen, OpStringLen, OpSliceCap:
385	default:
386		return nil, 0
387	}
388	if y == nil {
389		return x, 0
390	}
391	if !y.isGenericIntConst() {
392		return nil, 0
393	}
394	if v.Op == OpAdd64 || v.Op == OpAdd32 || v.Op == OpAdd16 || v.Op == OpAdd8 {
395		return x, -y.AuxInt
396	}
397	return x, y.AuxInt
398}
399
400func printIndVar(b *Block, i, min, max *Value, inc int64, flags indVarFlags) {
401	mb1, mb2 := "[", "]"
402	if flags&indVarMinExc != 0 {
403		mb1 = "("
404	}
405	if flags&indVarMaxInc == 0 {
406		mb2 = ")"
407	}
408
409	mlim1, mlim2 := fmt.Sprint(min.AuxInt), fmt.Sprint(max.AuxInt)
410	if !min.isGenericIntConst() {
411		if b.Func.pass.debug >= 2 {
412			mlim1 = fmt.Sprint(min)
413		} else {
414			mlim1 = "?"
415		}
416	}
417	if !max.isGenericIntConst() {
418		if b.Func.pass.debug >= 2 {
419			mlim2 = fmt.Sprint(max)
420		} else {
421			mlim2 = "?"
422		}
423	}
424	extra := ""
425	if b.Func.pass.debug >= 2 {
426		extra = fmt.Sprintf(" (%s)", i)
427	}
428	b.Func.Warnl(b.Pos, "Induction variable: limits %v%v,%v%v, increment %d%s", mb1, mlim1, mlim2, mb2, inc, extra)
429}
430
431func minSignedValue(t *types.Type) int64 {
432	return -1 << (t.Size()*8 - 1)
433}
434
435func maxSignedValue(t *types.Type) int64 {
436	return 1<<((t.Size()*8)-1) - 1
437}
438