xref: /aosp_15_r20/external/pytorch/test/export/testing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport functools
2*da0073e9SAndroid Build Coastguard Workerimport unittest
3*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker
# Shorthand for the aten operator namespace used by the list below.
aten = torch.ops.aten

# Aten operator overloads, each referenced as ``aten.<op>.<overload>``.
# As the name says, this list is for testing only, and it is not meant to be
# comprehensive.
_COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY = [
    aten.arctan2.default,
    aten.divide.Tensor,
    aten.divide.Scalar,
    aten.divide.Tensor_mode,
    aten.divide.Scalar_mode,
    aten.multiply.Tensor,
    aten.multiply.Scalar,
    aten.subtract.Tensor,
    aten.subtract.Scalar,
    aten.true_divide.Tensor,
    aten.true_divide.Scalar,
    aten.greater.Tensor,
    aten.greater.Scalar,
    aten.greater_equal.Tensor,
    aten.greater_equal.Scalar,
    aten.less_equal.Tensor,
    aten.less_equal.Scalar,
    aten.less.Tensor,
    aten.less.Scalar,
    aten.not_equal.Tensor,
    aten.not_equal.Scalar,
    aten.cat.names,
    aten.sum.dim_DimnameList,
    aten.mean.names_dim,
    aten.prod.dim_Dimname,
    aten.all.dimname,
    aten.norm.names_ScalarOpt_dim,
    aten.norm.names_ScalarOpt_dim_dtype,
    aten.var.default,
    aten.var.dim,
    aten.var.names_dim,
    aten.var.correction_names,
    aten.std.default,
    aten.std.dim,
    aten.std.names_dim,
    aten.std.correction_names,
    aten.absolute.default,
    aten.arccos.default,
    aten.arccosh.default,
    aten.arcsin.default,
    aten.arcsinh.default,
    aten.arctan.default,
    aten.arctanh.default,
    aten.clip.default,
    aten.clip.Tensor,
    aten.fix.default,
    aten.negative.default,
    aten.square.default,
    aten.size.int,
    aten.size.Dimname,
    aten.stride.int,
    aten.stride.Dimname,
    aten.repeat_interleave.self_Tensor,
    aten.repeat_interleave.self_int,
    aten.sym_size.int,
    aten.sym_stride.int,
    aten.atleast_1d.Sequence,
    aten.atleast_2d.Sequence,
    aten.atleast_3d.Sequence,
    aten.linear.default,
    aten.conv2d.default,
    aten.conv2d.padding,
    aten.mish_backward.default,
    aten.silu_backward.default,
    aten.index_add.dimname,
    aten.pad_sequence.default,
    aten.index_copy.dimname,
    aten.upsample_nearest1d.vec,
    aten.upsample_nearest2d.vec,
    aten.upsample_nearest3d.vec,
    aten._upsample_nearest_exact1d.vec,
    aten._upsample_nearest_exact2d.vec,
    aten._upsample_nearest_exact3d.vec,
    aten.rnn_tanh.input,
    aten.rnn_tanh.data,
    aten.rnn_relu.input,
    aten.rnn_relu.data,
    aten.lstm.input,
    aten.lstm.data,
    aten.gru.input,
    aten.gru.data,
    aten._upsample_bilinear2d_aa.vec,
    aten._upsample_bicubic2d_aa.vec,
    aten.upsample_bilinear2d.vec,
    aten.upsample_trilinear3d.vec,
    aten.upsample_linear1d.vec,
    aten.matmul.default,
    aten.upsample_bicubic2d.vec,
    aten.__and__.Scalar,
    aten.__and__.Tensor,
    aten.__or__.Tensor,
    aten.__or__.Scalar,
    aten.__xor__.Tensor,
    aten.__xor__.Scalar,
    aten.scatter.dimname_src,
    aten.scatter.dimname_value,
    aten.scatter_add.dimname,
    aten.is_complex.default,
    aten.logsumexp.names,
    aten.where.ScalarOther,
    aten.where.ScalarSelf,
    aten.where.Scalar,
    aten.where.default,
    aten.item.default,
    aten.any.dimname,
    aten.std_mean.default,
    aten.std_mean.dim,
    aten.std_mean.names_dim,
    aten.std_mean.correction_names,
    aten.var_mean.default,
    aten.var_mean.dim,
    aten.var_mean.names_dim,
    aten.var_mean.correction_names,
    aten.broadcast_tensors.default,
    aten.stft.default,
    aten.stft.center,
    aten.istft.default,
    aten.index_fill.Dimname_Scalar,
    aten.index_fill.Dimname_Tensor,
    aten.index_select.dimname,
    aten.diag.default,
    aten.cumsum.dimname,
    aten.cumprod.dimname,
    aten.meshgrid.default,
    aten.meshgrid.indexing,
    aten.fft_fft.default,
    aten.fft_ifft.default,
    aten.fft_rfft.default,
    aten.fft_irfft.default,
    aten.fft_hfft.default,
    aten.fft_ihfft.default,
    aten.fft_fftn.default,
    aten.fft_ifftn.default,
    aten.fft_rfftn.default,
    aten.fft_ihfftn.default,
    aten.fft_irfftn.default,
    aten.fft_hfftn.default,
    aten.fft_fft2.default,
    aten.fft_ifft2.default,
    aten.fft_rfft2.default,
    aten.fft_irfft2.default,
    aten.fft_hfft2.default,
    aten.fft_ihfft2.default,
    aten.fft_fftshift.default,
    aten.fft_ifftshift.default,
    aten.selu.default,
    aten.margin_ranking_loss.default,
    aten.hinge_embedding_loss.default,
    aten.nll_loss.default,
    aten.prelu.default,
    aten.relu6.default,
    aten.pairwise_distance.default,
    aten.pdist.default,
    aten.special_ndtr.default,
    aten.cummax.dimname,
    aten.cummin.dimname,
    aten.logcumsumexp.dimname,
    aten.max.other,
    aten.max.names_dim,
    aten.min.other,
    aten.min.names_dim,
    aten.linalg_eigvals.default,
    aten.median.names_dim,
    aten.nanmedian.names_dim,
    aten.mode.dimname,
    aten.gather.dimname,
    aten.sort.dimname,
    aten.sort.dimname_stable,
    aten.argsort.default,
    aten.argsort.dimname,
    aten.rrelu.default,
    aten.conv_transpose1d.default,
    aten.conv_transpose2d.input,
    aten.conv_transpose3d.input,
    aten.conv1d.default,
    aten.conv1d.padding,
    aten.conv3d.default,
    aten.conv3d.padding,
    aten.float_power.Tensor_Tensor,
    aten.float_power.Tensor_Scalar,
    aten.float_power.Scalar,
    aten.ldexp.Tensor,
    aten._version.default,
]
196*da0073e9SAndroid Build Coastguard Worker
197*da0073e9SAndroid Build Coastguard Worker
def make_test_cls_with_mocked_export(
    cls, cls_prefix, fn_suffix, mocked_export_fn, xfail_prop=None
):
    """Create a copy of ``cls`` whose tests run through a mocked ``export``.

    A new class named ``f"{cls_prefix}{cls.__name__}"`` is created with the
    same bases as ``cls`` (it does not derive from ``cls`` itself). Every name
    reported by ``dir(cls)`` is then handled as follows:

    * a callable ``test_*`` attribute is wrapped with
      ``_make_fn_with_mocked_export`` and stored as ``name + fn_suffix``;
    * a non-callable ``test_*`` attribute is copied under its own name;
    * anything else is copied only when the new class does not already
      expose an attribute of that name.

    Args:
        cls: Test class to copy from.
        cls_prefix: String prepended to ``cls.__name__`` to name the new class.
        fn_suffix: String appended to the name of each wrapped test.
        mocked_export_fn: Replacement passed on to
            ``_make_fn_with_mocked_export``.
        xfail_prop: Optional attribute name. A test function that has this
            attribute is additionally wrapped in ``unittest.expectedFailure``.

    Returns:
        The newly created class.
    """
    mocked_cls_name = f"{cls_prefix}{cls.__name__}"
    MockedTestClass = type(mocked_cls_name, cls.__bases__, {})
    MockedTestClass.__qualname__ = MockedTestClass.__name__

    for attr_name in dir(cls):
        if not attr_name.startswith("test_"):
            # NB: Doesn't handle slots correctly, but whatever
            if not hasattr(MockedTestClass, attr_name):
                setattr(MockedTestClass, attr_name, getattr(cls, attr_name))
            continue

        member = getattr(cls, attr_name)
        if not callable(member):
            # Data attributes that merely look like tests keep their name.
            setattr(MockedTestClass, attr_name, member)
            continue

        mocked_name = f"{attr_name}{fn_suffix}"
        mocked_test = _make_fn_with_mocked_export(member, mocked_export_fn)
        mocked_test.__name__ = mocked_name
        # The marker is looked up on the original function, not the wrapper.
        should_xfail = xfail_prop is not None and hasattr(member, xfail_prop)
        if should_xfail:
            mocked_test = unittest.expectedFailure(mocked_test)
        setattr(MockedTestClass, mocked_name, mocked_test)

    return MockedTestClass
221*da0073e9SAndroid Build Coastguard Worker
222*da0073e9SAndroid Build Coastguard Worker
223*da0073e9SAndroid Build Coastguard Workerdef _make_fn_with_mocked_export(fn, mocked_export_fn):
224*da0073e9SAndroid Build Coastguard Worker    @functools.wraps(fn)
225*da0073e9SAndroid Build Coastguard Worker    def _fn(*args, **kwargs):
226*da0073e9SAndroid Build Coastguard Worker        try:
227*da0073e9SAndroid Build Coastguard Worker            from . import test_export
228*da0073e9SAndroid Build Coastguard Worker        except ImportError:
229*da0073e9SAndroid Build Coastguard Worker            import test_export
230*da0073e9SAndroid Build Coastguard Worker
231*da0073e9SAndroid Build Coastguard Worker        with patch(f"{test_export.__name__}.export", mocked_export_fn):
232*da0073e9SAndroid Build Coastguard Worker            return fn(*args, **kwargs)
233*da0073e9SAndroid Build Coastguard Worker
234*da0073e9SAndroid Build Coastguard Worker    return _fn
235*da0073e9SAndroid Build Coastguard Worker
236*da0073e9SAndroid Build Coastguard Worker
# Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py
def expectedFailureTrainingIRToRunDecomp(fn):
    """Tag ``fn`` with ``_expected_failure_training_ir_to_run_decomp`` and return it.

    The attribute is only a marker set to ``True``; ``fn`` itself is not
    wrapped. ``make_test_cls_with_mocked_export`` turns a test carrying the
    attribute named by its ``xfail_prop`` argument into an expected failure.
    """
    setattr(fn, "_expected_failure_training_ir_to_run_decomp", True)
    return fn
241*da0073e9SAndroid Build Coastguard Worker
242*da0073e9SAndroid Build Coastguard Worker
# Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py
def expectedFailureTrainingIRToRunDecompNonStrict(fn):
    """Tag ``fn`` with ``_expected_failure_training_ir_to_run_decomp_non_strict``.

    The attribute is only a marker set to ``True``; ``fn`` itself is not
    wrapped and is returned unchanged otherwise.
    ``make_test_cls_with_mocked_export`` turns a test carrying the attribute
    named by its ``xfail_prop`` argument into an expected failure.
    """
    setattr(fn, "_expected_failure_training_ir_to_run_decomp_non_strict", True)
    return fn
247*da0073e9SAndroid Build Coastguard Worker
248*da0073e9SAndroid Build Coastguard Worker
# Controls tests generated in test/export/test_export_nonstrict.py
def expectedFailureNonStrict(fn):
    """Tag ``fn`` with ``_expected_failure_non_strict = True`` and return it.

    The attribute is only a marker; ``fn`` itself is not wrapped.
    ``make_test_cls_with_mocked_export`` turns a test carrying the attribute
    named by its ``xfail_prop`` argument into an expected failure.
    """
    setattr(fn, "_expected_failure_non_strict", True)
    return fn
253*da0073e9SAndroid Build Coastguard Worker
254*da0073e9SAndroid Build Coastguard Worker
# Controls tests generated in test/export/test_retraceability.py
def expectedFailureRetraceability(fn):
    """Tag ``fn`` with ``_expected_failure_retrace = True`` and return it.

    The attribute is only a marker; ``fn`` itself is not wrapped.
    ``make_test_cls_with_mocked_export`` turns a test carrying the attribute
    named by its ``xfail_prop`` argument into an expected failure.
    """
    setattr(fn, "_expected_failure_retrace", True)
    return fn
259*da0073e9SAndroid Build Coastguard Worker
260*da0073e9SAndroid Build Coastguard Worker
# Controls tests generated in test/export/test_serdes.py
def expectedFailureSerDer(fn):
    """Tag ``fn`` with ``_expected_failure_serdes = True`` and return it.

    The attribute is only a marker; ``fn`` itself is not wrapped.
    ``make_test_cls_with_mocked_export`` turns a test carrying the attribute
    named by its ``xfail_prop`` argument into an expected failure.
    """
    setattr(fn, "_expected_failure_serdes", True)
    return fn
265*da0073e9SAndroid Build Coastguard Worker
266*da0073e9SAndroid Build Coastguard Worker
def expectedFailureSerDerPreDispatch(fn):
    """Tag ``fn`` with ``_expected_failure_serdes_pre_dispatch = True`` and return it.

    The attribute is only a marker; ``fn`` itself is not wrapped.
    ``make_test_cls_with_mocked_export`` turns a test carrying the attribute
    named by its ``xfail_prop`` argument into an expected failure.
    """
    setattr(fn, "_expected_failure_serdes_pre_dispatch", True)
    return fn
270*da0073e9SAndroid Build Coastguard Worker
271*da0073e9SAndroid Build Coastguard Worker
def expectedFailurePreDispatchRunDecomp(fn):
    """Tag ``fn`` with ``_expected_failure_pre_dispatch = True`` and return it.

    The attribute is only a marker; ``fn`` itself is not wrapped.
    ``make_test_cls_with_mocked_export`` turns a test carrying the attribute
    named by its ``xfail_prop`` argument into an expected failure.
    """
    setattr(fn, "_expected_failure_pre_dispatch", True)
    return fn
275