1*da0073e9SAndroid Build Coastguard Workerimport functools 2*da0073e9SAndroid Build Coastguard Workerimport unittest 3*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerimport torch 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workeraten = torch.ops.aten 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker# This list is not meant to be comprehensive 11*da0073e9SAndroid Build Coastguard Worker_COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY = [ 12*da0073e9SAndroid Build Coastguard Worker aten.arctan2.default, 13*da0073e9SAndroid Build Coastguard Worker aten.divide.Tensor, 14*da0073e9SAndroid Build Coastguard Worker aten.divide.Scalar, 15*da0073e9SAndroid Build Coastguard Worker aten.divide.Tensor_mode, 16*da0073e9SAndroid Build Coastguard Worker aten.divide.Scalar_mode, 17*da0073e9SAndroid Build Coastguard Worker aten.multiply.Tensor, 18*da0073e9SAndroid Build Coastguard Worker aten.multiply.Scalar, 19*da0073e9SAndroid Build Coastguard Worker aten.subtract.Tensor, 20*da0073e9SAndroid Build Coastguard Worker aten.subtract.Scalar, 21*da0073e9SAndroid Build Coastguard Worker aten.true_divide.Tensor, 22*da0073e9SAndroid Build Coastguard Worker aten.true_divide.Scalar, 23*da0073e9SAndroid Build Coastguard Worker aten.greater.Tensor, 24*da0073e9SAndroid Build Coastguard Worker aten.greater.Scalar, 25*da0073e9SAndroid Build Coastguard Worker aten.greater_equal.Tensor, 26*da0073e9SAndroid Build Coastguard Worker aten.greater_equal.Scalar, 27*da0073e9SAndroid Build Coastguard Worker aten.less_equal.Tensor, 28*da0073e9SAndroid Build Coastguard Worker aten.less_equal.Scalar, 29*da0073e9SAndroid Build Coastguard Worker aten.less.Tensor, 30*da0073e9SAndroid Build Coastguard Worker aten.less.Scalar, 31*da0073e9SAndroid Build Coastguard Worker aten.not_equal.Tensor, 32*da0073e9SAndroid Build Coastguard Worker aten.not_equal.Scalar, 33*da0073e9SAndroid Build Coastguard Worker aten.cat.names, 34*da0073e9SAndroid Build Coastguard Worker aten.sum.dim_DimnameList, 35*da0073e9SAndroid Build Coastguard Worker aten.mean.names_dim, 36*da0073e9SAndroid Build Coastguard Worker aten.prod.dim_Dimname, 37*da0073e9SAndroid Build Coastguard Worker aten.all.dimname, 38*da0073e9SAndroid Build Coastguard Worker aten.norm.names_ScalarOpt_dim, 39*da0073e9SAndroid Build Coastguard Worker aten.norm.names_ScalarOpt_dim_dtype, 40*da0073e9SAndroid Build Coastguard Worker aten.var.default, 41*da0073e9SAndroid Build Coastguard Worker aten.var.dim, 42*da0073e9SAndroid Build Coastguard Worker aten.var.names_dim, 43*da0073e9SAndroid Build Coastguard Worker aten.var.correction_names, 44*da0073e9SAndroid Build Coastguard Worker aten.std.default, 45*da0073e9SAndroid Build Coastguard Worker aten.std.dim, 46*da0073e9SAndroid Build Coastguard Worker aten.std.names_dim, 47*da0073e9SAndroid Build Coastguard Worker aten.std.correction_names, 48*da0073e9SAndroid Build Coastguard Worker aten.absolute.default, 49*da0073e9SAndroid Build Coastguard Worker aten.arccos.default, 50*da0073e9SAndroid Build Coastguard Worker aten.arccosh.default, 51*da0073e9SAndroid Build Coastguard Worker aten.arcsin.default, 52*da0073e9SAndroid Build Coastguard Worker aten.arcsinh.default, 53*da0073e9SAndroid Build Coastguard Worker aten.arctan.default, 54*da0073e9SAndroid Build Coastguard Worker aten.arctanh.default, 55*da0073e9SAndroid Build Coastguard Worker aten.clip.default, 56*da0073e9SAndroid Build Coastguard Worker aten.clip.Tensor, 57*da0073e9SAndroid Build Coastguard Worker aten.fix.default, 58*da0073e9SAndroid Build Coastguard Worker aten.negative.default, 59*da0073e9SAndroid Build Coastguard Worker aten.square.default, 60*da0073e9SAndroid Build Coastguard Worker aten.size.int, 61*da0073e9SAndroid Build Coastguard Worker aten.size.Dimname, 62*da0073e9SAndroid Build Coastguard Worker aten.stride.int, 63*da0073e9SAndroid Build Coastguard Worker aten.stride.Dimname, 64*da0073e9SAndroid Build Coastguard Worker aten.repeat_interleave.self_Tensor, 65*da0073e9SAndroid Build Coastguard Worker aten.repeat_interleave.self_int, 66*da0073e9SAndroid Build Coastguard Worker aten.sym_size.int, 67*da0073e9SAndroid Build Coastguard Worker aten.sym_stride.int, 68*da0073e9SAndroid Build Coastguard Worker aten.atleast_1d.Sequence, 69*da0073e9SAndroid Build Coastguard Worker aten.atleast_2d.Sequence, 70*da0073e9SAndroid Build Coastguard Worker aten.atleast_3d.Sequence, 71*da0073e9SAndroid Build Coastguard Worker aten.linear.default, 72*da0073e9SAndroid Build Coastguard Worker aten.conv2d.default, 73*da0073e9SAndroid Build Coastguard Worker aten.conv2d.padding, 74*da0073e9SAndroid Build Coastguard Worker aten.mish_backward.default, 75*da0073e9SAndroid Build Coastguard Worker aten.silu_backward.default, 76*da0073e9SAndroid Build Coastguard Worker aten.index_add.dimname, 77*da0073e9SAndroid Build Coastguard Worker aten.pad_sequence.default, 78*da0073e9SAndroid Build Coastguard Worker aten.index_copy.dimname, 79*da0073e9SAndroid Build Coastguard Worker aten.upsample_nearest1d.vec, 80*da0073e9SAndroid Build Coastguard Worker aten.upsample_nearest2d.vec, 81*da0073e9SAndroid Build Coastguard Worker aten.upsample_nearest3d.vec, 82*da0073e9SAndroid Build Coastguard Worker aten._upsample_nearest_exact1d.vec, 83*da0073e9SAndroid Build Coastguard Worker aten._upsample_nearest_exact2d.vec, 84*da0073e9SAndroid Build Coastguard Worker aten._upsample_nearest_exact3d.vec, 85*da0073e9SAndroid Build Coastguard Worker aten.rnn_tanh.input, 86*da0073e9SAndroid Build Coastguard Worker aten.rnn_tanh.data, 87*da0073e9SAndroid Build Coastguard Worker aten.rnn_relu.input, 88*da0073e9SAndroid Build Coastguard Worker aten.rnn_relu.data, 89*da0073e9SAndroid Build Coastguard Worker aten.lstm.input, 90*da0073e9SAndroid Build Coastguard Worker aten.lstm.data, 91*da0073e9SAndroid Build Coastguard Worker aten.gru.input, 92*da0073e9SAndroid Build Coastguard Worker aten.gru.data, 93*da0073e9SAndroid Build Coastguard Worker aten._upsample_bilinear2d_aa.vec, 94*da0073e9SAndroid Build Coastguard Worker aten._upsample_bicubic2d_aa.vec, 95*da0073e9SAndroid Build Coastguard Worker aten.upsample_bilinear2d.vec, 96*da0073e9SAndroid Build Coastguard Worker aten.upsample_trilinear3d.vec, 97*da0073e9SAndroid Build Coastguard Worker aten.upsample_linear1d.vec, 98*da0073e9SAndroid Build Coastguard Worker aten.matmul.default, 99*da0073e9SAndroid Build Coastguard Worker aten.upsample_bicubic2d.vec, 100*da0073e9SAndroid Build Coastguard Worker aten.__and__.Scalar, 101*da0073e9SAndroid Build Coastguard Worker aten.__and__.Tensor, 102*da0073e9SAndroid Build Coastguard Worker aten.__or__.Tensor, 103*da0073e9SAndroid Build Coastguard Worker aten.__or__.Scalar, 104*da0073e9SAndroid Build Coastguard Worker aten.__xor__.Tensor, 105*da0073e9SAndroid Build Coastguard Worker aten.__xor__.Scalar, 106*da0073e9SAndroid Build Coastguard Worker aten.scatter.dimname_src, 107*da0073e9SAndroid Build Coastguard Worker aten.scatter.dimname_value, 108*da0073e9SAndroid Build Coastguard Worker aten.scatter_add.dimname, 109*da0073e9SAndroid Build Coastguard Worker aten.is_complex.default, 110*da0073e9SAndroid Build Coastguard Worker aten.logsumexp.names, 111*da0073e9SAndroid Build Coastguard Worker aten.where.ScalarOther, 112*da0073e9SAndroid Build Coastguard Worker aten.where.ScalarSelf, 113*da0073e9SAndroid Build Coastguard Worker aten.where.Scalar, 114*da0073e9SAndroid Build Coastguard Worker aten.where.default, 115*da0073e9SAndroid Build Coastguard Worker aten.item.default, 116*da0073e9SAndroid Build Coastguard Worker aten.any.dimname, 117*da0073e9SAndroid Build Coastguard Worker aten.std_mean.default, 118*da0073e9SAndroid Build Coastguard Worker aten.std_mean.dim, 119*da0073e9SAndroid Build Coastguard Worker aten.std_mean.names_dim, 120*da0073e9SAndroid Build Coastguard Worker aten.std_mean.correction_names, 121*da0073e9SAndroid Build Coastguard Worker aten.var_mean.default, 122*da0073e9SAndroid Build Coastguard Worker aten.var_mean.dim, 123*da0073e9SAndroid Build Coastguard Worker aten.var_mean.names_dim, 124*da0073e9SAndroid Build Coastguard Worker aten.var_mean.correction_names, 125*da0073e9SAndroid Build Coastguard Worker aten.broadcast_tensors.default, 126*da0073e9SAndroid Build Coastguard Worker aten.stft.default, 127*da0073e9SAndroid Build Coastguard Worker aten.stft.center, 128*da0073e9SAndroid Build Coastguard Worker aten.istft.default, 129*da0073e9SAndroid Build Coastguard Worker aten.index_fill.Dimname_Scalar, 130*da0073e9SAndroid Build Coastguard Worker aten.index_fill.Dimname_Tensor, 131*da0073e9SAndroid Build Coastguard Worker aten.index_select.dimname, 132*da0073e9SAndroid Build Coastguard Worker aten.diag.default, 133*da0073e9SAndroid Build Coastguard Worker aten.cumsum.dimname, 134*da0073e9SAndroid Build Coastguard Worker aten.cumprod.dimname, 135*da0073e9SAndroid Build Coastguard Worker aten.meshgrid.default, 136*da0073e9SAndroid Build Coastguard Worker aten.meshgrid.indexing, 137*da0073e9SAndroid Build Coastguard Worker aten.fft_fft.default, 138*da0073e9SAndroid Build Coastguard Worker aten.fft_ifft.default, 139*da0073e9SAndroid Build Coastguard Worker aten.fft_rfft.default, 140*da0073e9SAndroid Build Coastguard Worker aten.fft_irfft.default, 141*da0073e9SAndroid Build Coastguard Worker aten.fft_hfft.default, 142*da0073e9SAndroid Build Coastguard Worker aten.fft_ihfft.default, 143*da0073e9SAndroid Build Coastguard Worker aten.fft_fftn.default, 144*da0073e9SAndroid Build Coastguard Worker aten.fft_ifftn.default, 145*da0073e9SAndroid Build Coastguard Worker aten.fft_rfftn.default, 146*da0073e9SAndroid Build Coastguard Worker aten.fft_ihfftn.default, 147*da0073e9SAndroid Build Coastguard Worker aten.fft_irfftn.default, 148*da0073e9SAndroid Build Coastguard Worker aten.fft_hfftn.default, 149*da0073e9SAndroid Build Coastguard Worker aten.fft_fft2.default, 150*da0073e9SAndroid Build Coastguard Worker aten.fft_ifft2.default, 151*da0073e9SAndroid Build Coastguard Worker aten.fft_rfft2.default, 152*da0073e9SAndroid Build Coastguard Worker aten.fft_irfft2.default, 153*da0073e9SAndroid Build Coastguard Worker aten.fft_hfft2.default, 154*da0073e9SAndroid Build Coastguard Worker aten.fft_ihfft2.default, 155*da0073e9SAndroid Build Coastguard Worker aten.fft_fftshift.default, 156*da0073e9SAndroid Build Coastguard Worker aten.fft_ifftshift.default, 157*da0073e9SAndroid Build Coastguard Worker aten.selu.default, 158*da0073e9SAndroid Build Coastguard Worker aten.margin_ranking_loss.default, 159*da0073e9SAndroid Build Coastguard Worker aten.hinge_embedding_loss.default, 160*da0073e9SAndroid Build Coastguard Worker aten.nll_loss.default, 161*da0073e9SAndroid Build Coastguard Worker aten.prelu.default, 162*da0073e9SAndroid Build Coastguard Worker aten.relu6.default, 163*da0073e9SAndroid Build Coastguard Worker aten.pairwise_distance.default, 164*da0073e9SAndroid Build Coastguard Worker aten.pdist.default, 165*da0073e9SAndroid Build Coastguard Worker aten.special_ndtr.default, 166*da0073e9SAndroid Build Coastguard Worker aten.cummax.dimname, 167*da0073e9SAndroid Build Coastguard Worker aten.cummin.dimname, 168*da0073e9SAndroid Build Coastguard Worker aten.logcumsumexp.dimname, 169*da0073e9SAndroid Build Coastguard Worker aten.max.other, 170*da0073e9SAndroid Build Coastguard Worker aten.max.names_dim, 171*da0073e9SAndroid Build Coastguard Worker aten.min.other, 172*da0073e9SAndroid Build Coastguard Worker aten.min.names_dim, 173*da0073e9SAndroid Build Coastguard Worker aten.linalg_eigvals.default, 174*da0073e9SAndroid Build Coastguard Worker aten.median.names_dim, 175*da0073e9SAndroid Build Coastguard Worker aten.nanmedian.names_dim, 176*da0073e9SAndroid Build Coastguard Worker aten.mode.dimname, 177*da0073e9SAndroid Build Coastguard Worker aten.gather.dimname, 178*da0073e9SAndroid Build Coastguard Worker aten.sort.dimname, 179*da0073e9SAndroid Build Coastguard Worker aten.sort.dimname_stable, 180*da0073e9SAndroid Build Coastguard Worker aten.argsort.default, 181*da0073e9SAndroid Build Coastguard Worker aten.argsort.dimname, 182*da0073e9SAndroid Build Coastguard Worker aten.rrelu.default, 183*da0073e9SAndroid Build Coastguard Worker aten.conv_transpose1d.default, 184*da0073e9SAndroid Build Coastguard Worker aten.conv_transpose2d.input, 185*da0073e9SAndroid Build Coastguard Worker aten.conv_transpose3d.input, 186*da0073e9SAndroid Build Coastguard Worker aten.conv1d.default, 187*da0073e9SAndroid Build Coastguard Worker aten.conv1d.padding, 188*da0073e9SAndroid Build Coastguard Worker aten.conv3d.default, 189*da0073e9SAndroid Build Coastguard Worker aten.conv3d.padding, 190*da0073e9SAndroid Build Coastguard Worker aten.float_power.Tensor_Tensor, 191*da0073e9SAndroid Build Coastguard Worker aten.float_power.Tensor_Scalar, 192*da0073e9SAndroid Build Coastguard Worker aten.float_power.Scalar, 193*da0073e9SAndroid Build Coastguard Worker aten.ldexp.Tensor, 194*da0073e9SAndroid Build Coastguard Worker aten._version.default, 195*da0073e9SAndroid Build Coastguard Worker] 196*da0073e9SAndroid Build Coastguard Worker 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard Workerdef make_test_cls_with_mocked_export( 199*da0073e9SAndroid Build Coastguard Worker cls, cls_prefix, fn_suffix, mocked_export_fn, xfail_prop=None 200*da0073e9SAndroid Build Coastguard Worker): 201*da0073e9SAndroid Build Coastguard Worker MockedTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {}) 202*da0073e9SAndroid Build Coastguard Worker MockedTestClass.__qualname__ = MockedTestClass.__name__ 203*da0073e9SAndroid Build Coastguard Worker 204*da0073e9SAndroid Build Coastguard Worker for name in dir(cls): 205*da0073e9SAndroid Build Coastguard Worker if name.startswith("test_"): 206*da0073e9SAndroid Build Coastguard Worker fn = getattr(cls, name) 207*da0073e9SAndroid Build Coastguard Worker if not callable(fn): 208*da0073e9SAndroid Build Coastguard Worker setattr(MockedTestClass, name, getattr(cls, name)) 209*da0073e9SAndroid Build Coastguard Worker continue 210*da0073e9SAndroid Build Coastguard Worker new_name = f"{name}{fn_suffix}" 211*da0073e9SAndroid Build Coastguard Worker new_fn = _make_fn_with_mocked_export(fn, mocked_export_fn) 212*da0073e9SAndroid Build Coastguard Worker new_fn.__name__ = new_name 213*da0073e9SAndroid Build Coastguard Worker if xfail_prop is not None and hasattr(fn, xfail_prop): 214*da0073e9SAndroid Build Coastguard Worker new_fn = unittest.expectedFailure(new_fn) 215*da0073e9SAndroid Build Coastguard Worker setattr(MockedTestClass, new_name, new_fn) 216*da0073e9SAndroid Build Coastguard Worker # NB: Doesn't handle slots correctly, but whatever 217*da0073e9SAndroid Build Coastguard Worker elif not hasattr(MockedTestClass, name): 218*da0073e9SAndroid Build Coastguard Worker setattr(MockedTestClass, name, getattr(cls, name)) 219*da0073e9SAndroid Build Coastguard Worker 220*da0073e9SAndroid Build Coastguard Worker return MockedTestClass 221*da0073e9SAndroid Build Coastguard Worker 222*da0073e9SAndroid Build Coastguard Worker 223*da0073e9SAndroid Build Coastguard Workerdef _make_fn_with_mocked_export(fn, mocked_export_fn): 224*da0073e9SAndroid Build Coastguard Worker @functools.wraps(fn) 225*da0073e9SAndroid Build Coastguard Worker def _fn(*args, **kwargs): 226*da0073e9SAndroid Build Coastguard Worker try: 227*da0073e9SAndroid Build Coastguard Worker from . import test_export 228*da0073e9SAndroid Build Coastguard Worker except ImportError: 229*da0073e9SAndroid Build Coastguard Worker import test_export 230*da0073e9SAndroid Build Coastguard Worker 231*da0073e9SAndroid Build Coastguard Worker with patch(f"{test_export.__name__}.export", mocked_export_fn): 232*da0073e9SAndroid Build Coastguard Worker return fn(*args, **kwargs) 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker return _fn 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker 237*da0073e9SAndroid Build Coastguard Worker# Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py 238*da0073e9SAndroid Build Coastguard Workerdef expectedFailureTrainingIRToRunDecomp(fn): 239*da0073e9SAndroid Build Coastguard Worker fn._expected_failure_training_ir_to_run_decomp = True 240*da0073e9SAndroid Build Coastguard Worker return fn 241*da0073e9SAndroid Build Coastguard Worker 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker# Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py 244*da0073e9SAndroid Build Coastguard Workerdef expectedFailureTrainingIRToRunDecompNonStrict(fn): 245*da0073e9SAndroid Build Coastguard Worker fn._expected_failure_training_ir_to_run_decomp_non_strict = True 246*da0073e9SAndroid Build Coastguard Worker return fn 247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Worker 249*da0073e9SAndroid Build Coastguard Worker# Controls tests generated in test/export/test_export_nonstrict.py 250*da0073e9SAndroid Build Coastguard Workerdef expectedFailureNonStrict(fn): 251*da0073e9SAndroid Build Coastguard Worker fn._expected_failure_non_strict = True 252*da0073e9SAndroid Build Coastguard Worker return fn 253*da0073e9SAndroid Build Coastguard Worker 254*da0073e9SAndroid Build Coastguard Worker 255*da0073e9SAndroid Build Coastguard Worker# Controls tests generated in test/export/test_retraceability.py 256*da0073e9SAndroid Build Coastguard Workerdef expectedFailureRetraceability(fn): 257*da0073e9SAndroid Build Coastguard Worker fn._expected_failure_retrace = True 258*da0073e9SAndroid Build Coastguard Worker return fn 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker 261*da0073e9SAndroid Build Coastguard Worker# Controls tests generated in test/export/test_serdes.py 262*da0073e9SAndroid Build Coastguard Workerdef expectedFailureSerDer(fn): 263*da0073e9SAndroid Build Coastguard Worker fn._expected_failure_serdes = True 264*da0073e9SAndroid Build Coastguard Worker return fn 265*da0073e9SAndroid Build Coastguard Worker 266*da0073e9SAndroid Build Coastguard Worker 267*da0073e9SAndroid Build Coastguard Workerdef expectedFailureSerDerPreDispatch(fn): 268*da0073e9SAndroid Build Coastguard Worker fn._expected_failure_serdes_pre_dispatch = True 269*da0073e9SAndroid Build Coastguard Worker return fn 270*da0073e9SAndroid Build Coastguard Worker 271*da0073e9SAndroid Build Coastguard Worker 272*da0073e9SAndroid Build Coastguard Workerdef expectedFailurePreDispatchRunDecomp(fn): 273*da0073e9SAndroid Build Coastguard Worker fn._expected_failure_pre_dispatch = True 274*da0073e9SAndroid Build Coastguard Worker return fn 275