1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package suffixarray
6
7import (
8	"bytes"
9	"fmt"
10	"io/fs"
11	"math/rand"
12	"os"
13	"path/filepath"
14	"regexp"
15	"slices"
16	"sort"
17	"strings"
18	"testing"
19)
20
// testCase describes one Index fixture: a text to index together with the
// patterns that are searched for in it.
type testCase struct {
	// name identifies the case in failure messages.
	name string
	// source is the text handed to New.
	source string
	// patterns are searched for in source with Lookup and, when they also
	// compile as regular expressions, with FindAllIndex.
	patterns []string
}
26
27var testCases = []testCase{
28	{
29		"empty string",
30		"",
31		[]string{
32			"",
33			"foo",
34			"(foo)",
35			".*",
36			"a*",
37		},
38	},
39
40	{
41		"all a's",
42		"aaaaaaaaaa", // 10 a's
43		[]string{
44			"",
45			"a",
46			"aa",
47			"aaa",
48			"aaaa",
49			"aaaaa",
50			"aaaaaa",
51			"aaaaaaa",
52			"aaaaaaaa",
53			"aaaaaaaaa",
54			"aaaaaaaaaa",
55			"aaaaaaaaaaa", // 11 a's
56			".",
57			".*",
58			"a+",
59			"aa+",
60			"aaaa[b]?",
61			"aaa*",
62		},
63	},
64
65	{
66		"abc",
67		"abc",
68		[]string{
69			"a",
70			"b",
71			"c",
72			"ab",
73			"bc",
74			"abc",
75			"a.c",
76			"a(b|c)",
77			"abc?",
78		},
79	},
80
81	{
82		"barbara*3",
83		"barbarabarbarabarbara",
84		[]string{
85			"a",
86			"bar",
87			"rab",
88			"arab",
89			"barbar",
90			"bara?bar",
91		},
92	},
93
94	{
95		"typing drill",
96		"Now is the time for all good men to come to the aid of their country.",
97		[]string{
98			"Now",
99			"the time",
100			"to come the aid",
101			"is the time for all good men to come to the aid of their",
102			"to (come|the)?",
103		},
104	},
105
106	{
107		"godoc simulation",
108		"package main\n\nimport(\n    \"rand\"\n    ",
109		[]string{},
110	},
111}
112
// find is a brute-force reference for Index.Lookup. It returns the start
// offsets of occurrences of s in src in increasing order, including
// overlapping ones. A negative n means no limit; otherwise at most n offsets
// are returned. An empty s or n == 0 yields nil.
func find(src, s string, n int) []int {
	var res []int
	if s == "" || n == 0 {
		return res
	}
	for off := 0; n < 0 || len(res) < n; {
		j := strings.Index(src[off:], s)
		if j < 0 {
			break
		}
		pos := off + j
		res = append(res, pos)
		// Resume one byte past the match start so overlapping
		// occurrences are found too.
		off = pos + 1
	}
	return res
}
129
130func testLookup(t *testing.T, tc *testCase, x *Index, s string, n int) {
131	res := x.Lookup([]byte(s), n)
132	exp := find(tc.source, s, n)
133
134	// check that the lengths match
135	if len(res) != len(exp) {
136		t.Errorf("test %q, lookup %q (n = %d): expected %d results; got %d", tc.name, s, n, len(exp), len(res))
137	}
138
139	// if n >= 0 the number of results is limited --- unless n >= all results,
140	// we may obtain different positions from the Index and from find (because
141	// Index may not find the results in the same order as find) => in general
142	// we cannot simply check that the res and exp lists are equal
143
144	// check that each result is in fact a correct match and there are no duplicates
145	slices.Sort(res)
146	for i, r := range res {
147		if r < 0 || len(tc.source) <= r {
148			t.Errorf("test %q, lookup %q, result %d (n = %d): index %d out of range [0, %d[", tc.name, s, i, n, r, len(tc.source))
149		} else if !strings.HasPrefix(tc.source[r:], s) {
150			t.Errorf("test %q, lookup %q, result %d (n = %d): index %d not a match", tc.name, s, i, n, r)
151		}
152		if i > 0 && res[i-1] == r {
153			t.Errorf("test %q, lookup %q, result %d (n = %d): found duplicate index %d", tc.name, s, i, n, r)
154		}
155	}
156
157	if n < 0 {
158		// all results computed - sorted res and exp must be equal
159		for i, r := range res {
160			e := exp[i]
161			if r != e {
162				t.Errorf("test %q, lookup %q, result %d: expected index %d; got %d", tc.name, s, i, e, r)
163			}
164		}
165	}
166}
167
168func testFindAllIndex(t *testing.T, tc *testCase, x *Index, rx *regexp.Regexp, n int) {
169	res := x.FindAllIndex(rx, n)
170	exp := rx.FindAllStringIndex(tc.source, n)
171
172	// check that the lengths match
173	if len(res) != len(exp) {
174		t.Errorf("test %q, FindAllIndex %q (n = %d): expected %d results; got %d", tc.name, rx, n, len(exp), len(res))
175	}
176
177	// if n >= 0 the number of results is limited --- unless n >= all results,
178	// we may obtain different positions from the Index and from regexp (because
179	// Index may not find the results in the same order as regexp) => in general
180	// we cannot simply check that the res and exp lists are equal
181
182	// check that each result is in fact a correct match and the result is sorted
183	for i, r := range res {
184		if r[0] < 0 || r[0] > r[1] || len(tc.source) < r[1] {
185			t.Errorf("test %q, FindAllIndex %q, result %d (n == %d): illegal match [%d, %d]", tc.name, rx, i, n, r[0], r[1])
186		} else if !rx.MatchString(tc.source[r[0]:r[1]]) {
187			t.Errorf("test %q, FindAllIndex %q, result %d (n = %d): [%d, %d] not a match", tc.name, rx, i, n, r[0], r[1])
188		}
189	}
190
191	if n < 0 {
192		// all results computed - sorted res and exp must be equal
193		for i, r := range res {
194			e := exp[i]
195			if r[0] != e[0] || r[1] != e[1] {
196				t.Errorf("test %q, FindAllIndex %q, result %d: expected match [%d, %d]; got [%d, %d]",
197					tc.name, rx, i, e[0], e[1], r[0], r[1])
198			}
199		}
200	}
201}
202
203func testLookups(t *testing.T, tc *testCase, x *Index, n int) {
204	for _, pat := range tc.patterns {
205		testLookup(t, tc, x, pat, n)
206		if rx, err := regexp.Compile(pat); err == nil {
207			testFindAllIndex(t, tc, x, rx, n)
208		}
209	}
210}
211
212// index is used to hide the sort.Interface
213type index Index
214
215func (x *index) Len() int           { return x.sa.len() }
216func (x *index) Less(i, j int) bool { return bytes.Compare(x.at(i), x.at(j)) < 0 }
217func (x *index) Swap(i, j int) {
218	if x.sa.int32 != nil {
219		x.sa.int32[i], x.sa.int32[j] = x.sa.int32[j], x.sa.int32[i]
220	} else {
221		x.sa.int64[i], x.sa.int64[j] = x.sa.int64[j], x.sa.int64[i]
222	}
223}
224
225func (x *index) at(i int) []byte {
226	return x.data[x.sa.get(i):]
227}
228
229func testConstruction(t *testing.T, tc *testCase, x *Index) {
230	if !sort.IsSorted((*index)(x)) {
231		t.Errorf("failed testConstruction %s", tc.name)
232	}
233}
234
235func equal(x, y *Index) bool {
236	if !bytes.Equal(x.data, y.data) {
237		return false
238	}
239	if x.sa.len() != y.sa.len() {
240		return false
241	}
242	n := x.sa.len()
243	for i := 0; i < n; i++ {
244		if x.sa.get(i) != y.sa.get(i) {
245			return false
246		}
247	}
248	return true
249}
250
251// returns the serialized index size
252func testSaveRestore(t *testing.T, tc *testCase, x *Index) int {
253	var buf bytes.Buffer
254	if err := x.Write(&buf); err != nil {
255		t.Errorf("failed writing index %s (%s)", tc.name, err)
256	}
257	size := buf.Len()
258	var y Index
259	if err := y.Read(bytes.NewReader(buf.Bytes())); err != nil {
260		t.Errorf("failed reading index %s (%s)", tc.name, err)
261	}
262	if !equal(x, &y) {
263		t.Errorf("restored index doesn't match saved index %s", tc.name)
264	}
265
266	old := maxData32
267	defer func() {
268		maxData32 = old
269	}()
270	// Reread as forced 32.
271	y = Index{}
272	maxData32 = realMaxData32
273	if err := y.Read(bytes.NewReader(buf.Bytes())); err != nil {
274		t.Errorf("failed reading index %s (%s)", tc.name, err)
275	}
276	if !equal(x, &y) {
277		t.Errorf("restored index doesn't match saved index %s", tc.name)
278	}
279
280	// Reread as forced 64.
281	y = Index{}
282	maxData32 = -1
283	if err := y.Read(bytes.NewReader(buf.Bytes())); err != nil {
284		t.Errorf("failed reading index %s (%s)", tc.name, err)
285	}
286	if !equal(x, &y) {
287		t.Errorf("restored index doesn't match saved index %s", tc.name)
288	}
289
290	return size
291}
292
293func testIndex(t *testing.T) {
294	for _, tc := range testCases {
295		x := New([]byte(tc.source))
296		testConstruction(t, &tc, x)
297		testSaveRestore(t, &tc, x)
298		testLookups(t, &tc, x, 0)
299		testLookups(t, &tc, x, 1)
300		testLookups(t, &tc, x, 10)
301		testLookups(t, &tc, x, 2e9)
302		testLookups(t, &tc, x, -1)
303	}
304}
305
306func TestIndex32(t *testing.T) {
307	testIndex(t)
308}
309
310func TestIndex64(t *testing.T) {
311	maxData32 = -1
312	defer func() {
313		maxData32 = realMaxData32
314	}()
315	testIndex(t)
316}
317
318func TestNew32(t *testing.T) {
319	test(t, func(x []byte) []int {
320		sa := make([]int32, len(x))
321		text_32(x, sa)
322		out := make([]int, len(sa))
323		for i, v := range sa {
324			out[i] = int(v)
325		}
326		return out
327	})
328}
329
330func TestNew64(t *testing.T) {
331	test(t, func(x []byte) []int {
332		sa := make([]int64, len(x))
333		text_64(x, sa)
334		out := make([]int, len(sa))
335		for i, v := range sa {
336			out[i] = int(v)
337		}
338		return out
339	})
340}
341
342// test tests an arbitrary suffix array construction function.
343// Generates many inputs, builds and checks suffix arrays.
344func test(t *testing.T, build func([]byte) []int) {
345	t.Run("ababab...", func(t *testing.T) {
346		// Very repetitive input has numLMS = len(x)/2-1
347		// at top level, the largest it can be.
348		// But maxID is only two (aba and ab$).
349		size := 100000
350		if testing.Short() {
351			size = 10000
352		}
353		x := make([]byte, size)
354		for i := range x {
355			x[i] = "ab"[i%2]
356		}
357		testSA(t, x, build)
358	})
359
360	t.Run("forcealloc", func(t *testing.T) {
361		// Construct a pathological input that forces
362		// recurse_32 to allocate a new temporary buffer.
363		// The input must have more than N/3 LMS-substrings,
364		// which we arrange by repeating an SLSLSLSLSLSL pattern
365		// like ababab... above, but then we must also arrange
366		// for a large number of distinct LMS-substrings.
367		// We use this pattern:
368		// 1 255 1 254 1 253 1 ... 1 2 1 255 2 254 2 253 2 252 2 ...
369		// This gives approximately 2¹⁵ distinct LMS-substrings.
370		// We need to repeat at least one substring, though,
371		// or else the recursion can be bypassed entirely.
372		x := make([]byte, 100000, 100001)
373		lo := byte(1)
374		hi := byte(255)
375		for i := range x {
376			if i%2 == 0 {
377				x[i] = lo
378			} else {
379				x[i] = hi
380				hi--
381				if hi <= lo {
382					lo++
383					if lo == 0 {
384						lo = 1
385					}
386					hi = 255
387				}
388			}
389		}
390		x[:cap(x)][len(x)] = 0 // for sais.New
391		testSA(t, x, build)
392	})
393
394	t.Run("exhaustive2", func(t *testing.T) {
395		// All inputs over {0,1} up to length 21.
396		// Runs in about 10 seconds on my laptop.
397		x := make([]byte, 30)
398		numFail := 0
399		for n := 0; n <= 21; n++ {
400			if n > 12 && testing.Short() {
401				break
402			}
403			x[n] = 0 // for sais.New
404			testRec(t, x[:n], 0, 2, &numFail, build)
405		}
406	})
407
408	t.Run("exhaustive3", func(t *testing.T) {
409		// All inputs over {0,1,2} up to length 14.
410		// Runs in about 10 seconds on my laptop.
411		x := make([]byte, 30)
412		numFail := 0
413		for n := 0; n <= 14; n++ {
414			if n > 8 && testing.Short() {
415				break
416			}
417			x[n] = 0 // for sais.New
418			testRec(t, x[:n], 0, 3, &numFail, build)
419		}
420	})
421}
422
423// testRec fills x[i:] with all possible combinations of values in [1,max]
424// and then calls testSA(t, x, build) for each one.
425func testRec(t *testing.T, x []byte, i, max int, numFail *int, build func([]byte) []int) {
426	if i < len(x) {
427		for x[i] = 1; x[i] <= byte(max); x[i]++ {
428			testRec(t, x, i+1, max, numFail, build)
429		}
430		return
431	}
432
433	if !testSA(t, x, build) {
434		*numFail++
435		if *numFail >= 10 {
436			t.Errorf("stopping after %d failures", *numFail)
437			t.FailNow()
438		}
439	}
440}
441
// testSA builds the suffix array of x with build and verifies it.
//
// It checks three things:
//   - the result has exactly one entry per byte of x;
//   - every entry is a valid offset into x;
//   - the suffixes x[sa[i]:] appear in strictly increasing order.
//
// Together these also prove that sa is a permutation of 0..len(x)-1. The first
// problem found is reported with t.Errorf and testSA returns false. Otherwise
// it returns true. If build panics, the input is logged and the panic is
// propagated.
func testSA(t *testing.T, x []byte, build func([]byte) []int) bool {
	defer func() {
		if e := recover(); e != nil {
			t.Logf("build %v", x)
			panic(e)
		}
	}()
	sa := build(x)
	if len(sa) != len(x) {
		t.Errorf("build %v: len(sa) = %d, want %d", x, len(sa), len(x))
		return false
	}
	// Range-check every entry up front. The pairwise loop below never runs
	// for a one-element array, so checking there would miss that case.
	for i, p := range sa {
		if p < 0 || p >= len(x) {
			t.Errorf("build %v: sa[%d] = %d out of range: %v", x, i, p, sa)
			return false
		}
	}
	for i := 0; i+1 < len(sa); i++ {
		if bytes.Compare(x[sa[i]:], x[sa[i+1]:]) >= 0 {
			t.Errorf("build %v -> %v\nsa[%d:] = %d,%d out of order", x, sa, i, sa[i], sa[i+1])
			return false
		}
	}
	return true
}
469
// benchdata is a 1e6-byte all-zero input for benchmarkNew: maximal
// substring repetition.
var benchdata = make([]byte, 1e6)

// benchrand is a 1e6-byte buffer that benchmarkNew fills with random bytes
// the first time it is needed.
var benchrand = make([]byte, 1e6)
474
475// Of all possible inputs, the random bytes have the least amount of substring
476// repetition, and the repeated bytes have the most. For most algorithms,
477// the running time of every input will be between these two.
478func benchmarkNew(b *testing.B, random bool) {
479	b.ReportAllocs()
480	b.StopTimer()
481	data := benchdata
482	if random {
483		data = benchrand
484		if data[0] == 0 {
485			for i := range data {
486				data[i] = byte(rand.Intn(256))
487			}
488		}
489	}
490	b.StartTimer()
491	b.SetBytes(int64(len(data)))
492	for i := 0; i < b.N; i++ {
493		New(data)
494	}
495}
496
// makeText returns the benchmark corpus with the given name:
//
//   - "opticks": the contents of ../../testdata/Isaac.Newton-Opticks.txt
//   - "go": the concatenation of every .go file found under ../..
//   - "zero": 50e6 zero bytes
//   - "rand": 50e6 bytes drawn from math/rand's global source
//
// Any other name is an error, so a misspelled corpus fails loudly instead of
// silently producing an empty benchmark input.
func makeText(name string) ([]byte, error) {
	switch name {
	case "opticks":
		data, err := os.ReadFile("../../testdata/Isaac.Newton-Opticks.txt")
		if err != nil {
			return nil, err
		}
		return data, nil
	case "go":
		var data []byte
		err := filepath.WalkDir("../..", func(path string, d fs.DirEntry, err error) error {
			// Entries reported with a walk error are skipped rather than
			// aborting the walk; d may be nil in that case, so test err first.
			if err != nil || d.IsDir() || !strings.HasSuffix(path, ".go") {
				return nil
			}
			file, err := os.ReadFile(path)
			if err != nil {
				return err
			}
			data = append(data, file...)
			return nil
		})
		if err != nil {
			return nil, err
		}
		return data, nil
	case "zero":
		return make([]byte, 50e6), nil
	case "rand":
		data := make([]byte, 50e6)
		for i := range data {
			data[i] = byte(rand.Intn(256))
		}
		return data, nil
	default:
		return nil, fmt.Errorf("unknown text %q", name)
	}
}
530
531func setBits(bits int) (cleanup func()) {
532	if bits == 32 {
533		maxData32 = realMaxData32
534	} else {
535		maxData32 = -1 // force use of 64-bit code
536	}
537	return func() {
538		maxData32 = realMaxData32
539	}
540}
541
542func BenchmarkNew(b *testing.B) {
543	for _, text := range []string{"opticks", "go", "zero", "rand"} {
544		b.Run("text="+text, func(b *testing.B) {
545			data, err := makeText(text)
546			if err != nil {
547				b.Fatal(err)
548			}
549			if testing.Short() && len(data) > 5e6 {
550				data = data[:5e6]
551			}
552			for _, size := range []int{100e3, 500e3, 1e6, 5e6, 10e6, 50e6} {
553				if len(data) < size {
554					continue
555				}
556				data := data[:size]
557				name := fmt.Sprintf("%dK", size/1e3)
558				if size >= 1e6 {
559					name = fmt.Sprintf("%dM", size/1e6)
560				}
561				b.Run("size="+name, func(b *testing.B) {
562					for _, bits := range []int{32, 64} {
563						if ^uint(0) == 0xffffffff && bits == 64 {
564							continue
565						}
566						b.Run(fmt.Sprintf("bits=%d", bits), func(b *testing.B) {
567							cleanup := setBits(bits)
568							defer cleanup()
569
570							b.SetBytes(int64(len(data)))
571							b.ReportAllocs()
572							for i := 0; i < b.N; i++ {
573								New(data)
574							}
575						})
576					}
577				})
578			}
579		})
580	}
581}
582
583func BenchmarkSaveRestore(b *testing.B) {
584	r := rand.New(rand.NewSource(0x5a77a1)) // guarantee always same sequence
585	data := make([]byte, 1<<20)             // 1MB of data to index
586	for i := range data {
587		data[i] = byte(r.Intn(256))
588	}
589	for _, bits := range []int{32, 64} {
590		if ^uint(0) == 0xffffffff && bits == 64 {
591			continue
592		}
593		b.Run(fmt.Sprintf("bits=%d", bits), func(b *testing.B) {
594			cleanup := setBits(bits)
595			defer cleanup()
596
597			b.StopTimer()
598			x := New(data)
599			size := testSaveRestore(nil, nil, x)       // verify correctness
600			buf := bytes.NewBuffer(make([]byte, size)) // avoid growing
601			b.SetBytes(int64(size))
602			b.StartTimer()
603			b.ReportAllocs()
604			for i := 0; i < b.N; i++ {
605				buf.Reset()
606				if err := x.Write(buf); err != nil {
607					b.Fatal(err)
608				}
609				var y Index
610				if err := y.Read(buf); err != nil {
611					b.Fatal(err)
612				}
613			}
614		})
615	}
616}
617