1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package errors_test
6
7import (
8	"errors"
9	"fmt"
10	"io/fs"
11	"os"
12	"reflect"
13	"testing"
14)
15
16func TestIs(t *testing.T) {
17	err1 := errors.New("1")
18	erra := wrapped{"wrap 2", err1}
19	errb := wrapped{"wrap 3", erra}
20
21	err3 := errors.New("3")
22
23	poser := &poser{"either 1 or 3", func(err error) bool {
24		return err == err1 || err == err3
25	}}
26
27	testCases := []struct {
28		err    error
29		target error
30		match  bool
31	}{
32		{nil, nil, true},
33		{nil, err1, false},
34		{err1, nil, false},
35		{err1, err1, true},
36		{erra, err1, true},
37		{errb, err1, true},
38		{err1, err3, false},
39		{erra, err3, false},
40		{errb, err3, false},
41		{poser, err1, true},
42		{poser, err3, true},
43		{poser, erra, false},
44		{poser, errb, false},
45		{errorUncomparable{}, errorUncomparable{}, true},
46		{errorUncomparable{}, &errorUncomparable{}, false},
47		{&errorUncomparable{}, errorUncomparable{}, true},
48		{&errorUncomparable{}, &errorUncomparable{}, false},
49		{errorUncomparable{}, err1, false},
50		{&errorUncomparable{}, err1, false},
51		{multiErr{}, err1, false},
52		{multiErr{err1, err3}, err1, true},
53		{multiErr{err3, err1}, err1, true},
54		{multiErr{err1, err3}, errors.New("x"), false},
55		{multiErr{err3, errb}, errb, true},
56		{multiErr{err3, errb}, erra, true},
57		{multiErr{err3, errb}, err1, true},
58		{multiErr{errb, err3}, err1, true},
59		{multiErr{poser}, err1, true},
60		{multiErr{poser}, err3, true},
61		{multiErr{nil}, nil, false},
62	}
63	for _, tc := range testCases {
64		t.Run("", func(t *testing.T) {
65			if got := errors.Is(tc.err, tc.target); got != tc.match {
66				t.Errorf("Is(%v, %v) = %v, want %v", tc.err, tc.target, got, tc.match)
67			}
68		})
69	}
70}
71
72type poser struct {
73	msg string
74	f   func(error) bool
75}
76
77var poserPathErr = &fs.PathError{Op: "poser"}
78
79func (p *poser) Error() string     { return p.msg }
80func (p *poser) Is(err error) bool { return p.f(err) }
81func (p *poser) As(err any) bool {
82	switch x := err.(type) {
83	case **poser:
84		*x = p
85	case *errorT:
86		*x = errorT{"poser"}
87	case **fs.PathError:
88		*x = poserPathErr
89	default:
90		return false
91	}
92	return true
93}
94
95func TestAs(t *testing.T) {
96	var errT errorT
97	var errP *fs.PathError
98	var timeout interface{ Timeout() bool }
99	var p *poser
100	_, errF := os.Open("non-existing")
101	poserErr := &poser{"oh no", nil}
102
103	testCases := []struct {
104		err    error
105		target any
106		match  bool
107		want   any // value of target on match
108	}{{
109		nil,
110		&errP,
111		false,
112		nil,
113	}, {
114		wrapped{"pitied the fool", errorT{"T"}},
115		&errT,
116		true,
117		errorT{"T"},
118	}, {
119		errF,
120		&errP,
121		true,
122		errF,
123	}, {
124		errorT{},
125		&errP,
126		false,
127		nil,
128	}, {
129		wrapped{"wrapped", nil},
130		&errT,
131		false,
132		nil,
133	}, {
134		&poser{"error", nil},
135		&errT,
136		true,
137		errorT{"poser"},
138	}, {
139		&poser{"path", nil},
140		&errP,
141		true,
142		poserPathErr,
143	}, {
144		poserErr,
145		&p,
146		true,
147		poserErr,
148	}, {
149		errors.New("err"),
150		&timeout,
151		false,
152		nil,
153	}, {
154		errF,
155		&timeout,
156		true,
157		errF,
158	}, {
159		wrapped{"path error", errF},
160		&timeout,
161		true,
162		errF,
163	}, {
164		multiErr{},
165		&errT,
166		false,
167		nil,
168	}, {
169		multiErr{errors.New("a"), errorT{"T"}},
170		&errT,
171		true,
172		errorT{"T"},
173	}, {
174		multiErr{errorT{"T"}, errors.New("a")},
175		&errT,
176		true,
177		errorT{"T"},
178	}, {
179		multiErr{errorT{"a"}, errorT{"b"}},
180		&errT,
181		true,
182		errorT{"a"},
183	}, {
184		multiErr{multiErr{errors.New("a"), errorT{"a"}}, errorT{"b"}},
185		&errT,
186		true,
187		errorT{"a"},
188	}, {
189		multiErr{wrapped{"path error", errF}},
190		&timeout,
191		true,
192		errF,
193	}, {
194		multiErr{nil},
195		&errT,
196		false,
197		nil,
198	}}
199	for i, tc := range testCases {
200		name := fmt.Sprintf("%d:As(Errorf(..., %v), %v)", i, tc.err, tc.target)
201		// Clear the target pointer, in case it was set in a previous test.
202		rtarget := reflect.ValueOf(tc.target)
203		rtarget.Elem().Set(reflect.Zero(reflect.TypeOf(tc.target).Elem()))
204		t.Run(name, func(t *testing.T) {
205			match := errors.As(tc.err, tc.target)
206			if match != tc.match {
207				t.Fatalf("match: got %v; want %v", match, tc.match)
208			}
209			if !match {
210				return
211			}
212			if got := rtarget.Elem().Interface(); got != tc.want {
213				t.Fatalf("got %#v, want %#v", got, tc.want)
214			}
215		})
216	}
217}
218
// TestAsValidation checks that errors.As panics for each unusable target in
// the list: untyped nil, a nil *int, a plain string, and a *string.
func TestAsValidation(t *testing.T) {
	var s string
	err := errors.New("error")
	targets := []any{
		nil,
		(*int)(nil),
		"error",
		&s,
	}
	for _, target := range targets {
		name := fmt.Sprintf("%T(%v)", target, target)
		t.Run(name, func(t *testing.T) {
			// A panic is the expected outcome; recover so the subtest
			// finishes without reaching the failure reports below.
			defer func() { recover() }()
			matched := errors.As(err, target)
			if matched {
				t.Errorf("As(err, %T(%v)) = true, want false", target, target)
				return
			}
			t.Errorf("As(err, %T(%v)) did not panic", target, target)
		})
	}
}
241
242func BenchmarkIs(b *testing.B) {
243	err1 := errors.New("1")
244	err2 := multiErr{multiErr{multiErr{err1, errorT{"a"}}, errorT{"b"}}}
245
246	for i := 0; i < b.N; i++ {
247		if !errors.Is(err2, err1) {
248			b.Fatal("Is failed")
249		}
250	}
251}
252
253func BenchmarkAs(b *testing.B) {
254	err := multiErr{multiErr{multiErr{errors.New("a"), errorT{"a"}}, errorT{"b"}}}
255	for i := 0; i < b.N; i++ {
256		var target errorT
257		if !errors.As(err, &target) {
258			b.Fatal("As failed")
259		}
260	}
261}
262
263func TestUnwrap(t *testing.T) {
264	err1 := errors.New("1")
265	erra := wrapped{"wrap 2", err1}
266
267	testCases := []struct {
268		err  error
269		want error
270	}{
271		{nil, nil},
272		{wrapped{"wrapped", nil}, nil},
273		{err1, nil},
274		{erra, err1},
275		{wrapped{"wrap 3", erra}, erra},
276	}
277	for _, tc := range testCases {
278		if got := errors.Unwrap(tc.err); got != tc.want {
279			t.Errorf("Unwrap(%v) = %v, want %v", tc.err, got, tc.want)
280		}
281	}
282}
283
// errorT is a small value-type error that carries a single string.
type errorT struct{ s string }

// Error renders the stored string as "errorT(<s>)".
func (e errorT) Error() string {
	const layout = "errorT(%s)"
	return fmt.Sprintf(layout, e.s)
}
287
// wrapped is an error that pairs a message with a single underlying error.
type wrapped struct {
	msg string // returned by Error
	err error  // returned by Unwrap
}

// Error returns the wrapper's own message; the underlying error's text is
// not included.
func (w wrapped) Error() string {
	return w.msg
}

// Unwrap returns the underlying error, which may be nil.
func (w wrapped) Unwrap() error {
	return w.err
}
295
// multiErr is a slice of errors that exposes its elements through the
// multi-error form of Unwrap.
type multiErr []error

// Error returns a fixed string regardless of the slice's contents.
func (m multiErr) Error() string {
	return "multiError"
}

// Unwrap returns the receiver as a plain []error; no copy is made.
func (m multiErr) Unwrap() []error {
	return m
}
300
// errorUncomparable is an error type that cannot be compared with ==
// because it contains a slice field.
type errorUncomparable struct {
	f []string
}

// Error returns a fixed description.
func (errorUncomparable) Error() string {
	const text = "uncomparable error"
	return text
}

// Is reports whether target holds an errorUncomparable value. A pointer to
// errorUncomparable does not count.
func (errorUncomparable) Is(target error) bool {
	switch target.(type) {
	case errorUncomparable:
		return true
	}
	return false
}
313