1# Owner(s): ["module: onnx"] 2 3""" 4Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data] 5 --no-onnx: no onnx python dependency 6 --produce-onnx-test-data: generate onnx test data 7 --accept: accept onnx updates and overwrite models 8""" 9 10import glob 11import inspect 12import io 13import itertools 14import operator 15import os 16import shutil 17import tempfile 18 19# Full diff for expect files 20import unittest 21 22from pytorch_test_common import ( 23 BATCH_SIZE, 24 flatten, 25 RNN_HIDDEN_SIZE, 26 RNN_INPUT_SIZE, 27 RNN_SEQUENCE_LENGTH, 28) 29 30import torch 31import torch.nn as nn 32import torch.nn.functional as F 33import torch.onnx 34from torch.autograd import Function, Variable 35from torch.nn import functional, Module 36from torch.onnx._internal import diagnostics 37from torch.onnx.symbolic_helper import ( 38 _get_tensor_dim_size, 39 _get_tensor_sizes, 40 parse_args, 41) 42from torch.testing._internal import common_utils 43from torch.testing._internal.common_utils import skipIfNoLapack 44 45 46unittest.TestCase.maxDiff = None 47 48_onnx_test = False # flag to produce onnx test cases. 49_onnx_dep = True # flag to import onnx package. 
def export_to_pbtxt(model, inputs, *args, **kwargs):
    """Export ``model`` to ONNX and return the human-readable text form.

    Thin wrapper over ``torch.onnx.export_to_pretty_string`` with
    ``google_printer=True``; extra positional/keyword arguments are forwarded.
    """
    return torch.onnx.export_to_pretty_string(
        model, inputs, *args, google_printer=True, **kwargs
    )


def export_to_pb(model, inputs, *args, **kwargs):
    """Export ``model`` to ONNX in memory and return the serialized bytes.

    The export runs under ``torch.no_grad()``; extra arguments are forwarded to
    ``torch.onnx.export``.
    """
    f = io.BytesIO()
    with torch.no_grad():
        torch.onnx.export(model, inputs, f, *args, **kwargs)
    return f.getvalue()


class FuncModule(Module):
    """Wrap a plain callable ``f`` as an ``nn.Module`` so it can be exported.

    ``params`` are registered in an ``nn.ParameterList`` and passed to ``f``
    after the positional inputs on every call.
    """

    def __init__(self, f, params=None):
        if params is None:
            params = ()
        super().__init__()
        self.f = f
        self.params = nn.ParameterList(list(params))

    def forward(self, *args):
        return self.f(*itertools.chain(args, self.params))


class TestOperators(common_utils.TestCase):
    """Export small ops/models with torch.onnx and compare the text graph to expect files."""

    def setUp(self):
        super().setUp()
        # Clear diagnostics accumulated by earlier tests.
        diagnostics.engine.clear()

    @staticmethod
    def _as_module(f, params=None):
        """Return ``f`` unchanged if it is an ``nn.Module``; otherwise wrap it in ``FuncModule``.

        ``params`` is only used for the wrapper and is ignored when ``f`` is
        already a module.
        """
        if isinstance(f, nn.Module):
            return f
        return FuncModule(f, params)

    def assertONNX(self, f, args, params=None, **kwargs):
        """Export ``f`` to ONNX and check the text form via ``assertExpected``.

        Args:
            f: an ``nn.Module``, or a callable that gets wrapped in ``FuncModule``.
            args: model inputs forwarded to the exporter (a tensor or a tuple).
            params: extra parameters appended to the inputs of a wrapped callable.
            **kwargs: forwarded to the ``torch.onnx`` export calls, except
                ``subname``, which is passed to ``assertExpected`` instead.

        When ``_onnx_dep`` is set, the binary export is also validated with
        ``onnx.checker``; when ``_onnx_test`` is set as well, the model and its
        input/output tensors are written out as ONNX test data.
        """
        m = self._as_module(f, params)
        m.eval()
        # ``subname`` is consumed by the harness; remove it *before* exporting so
        # it is never forwarded to torch.onnx as an unknown keyword argument.
        subname = kwargs.pop("subname", None)
        onnx_model_pbtxt = export_to_pbtxt(m, args, **kwargs)
        self.assertExpected(onnx_model_pbtxt, subname)
        if not _onnx_dep:
            return

        onnx_model_pb = export_to_pb(m, args, **kwargs)
        import onnx
        import onnx.checker

        model_def = onnx.ModelProto.FromString(onnx_model_pb)
        onnx.checker.check_model(model_def)
        if _onnx_test:
            # Name of the calling test method (stack frame 1). context=0 avoids
            # reading source lines for every frame on the stack.
            test_function = inspect.stack(0)[1][0].f_code.co_name
            self._dump_onnx_test_case(test_function, model_def, m, args)

    def _dump_onnx_test_case(self, test_function, model_def, m, args):
        """Write ``model_def`` plus its input/output tensors as ONNX test data.

        For a calling test named ``test_<op>``, files go under
        ``onnx_test_common.pytorch_operator_dir/test_operator_<op>/``:
        ``model.onnx`` and ``test_data_set_0/{input,output}_<i>.pb``.
        """
        import onnx_test_common

        test_name = test_function[0:4] + "_operator" + test_function[4:]
        output_dir = os.path.join(onnx_test_common.pytorch_operator_dir, test_name)
        # Assume:
        #     1) old test data must be deleted before the test runs.
        #     2) each test makes only one assertONNX call; otherwise the data
        #        would be overwritten.
        assert not os.path.exists(output_dir), f"{output_dir} should not exist!"
        os.makedirs(output_dir)
        with open(os.path.join(output_dir, "model.onnx"), "wb") as file:
            file.write(model_def.SerializeToString())
        data_dir = os.path.join(output_dir, "test_data_set_0")
        os.makedirs(data_dir)
        if isinstance(args, Variable):
            args = (args,)
        self._write_tensors(data_dir, "input", flatten(args))
        outputs = m(*args)
        if isinstance(outputs, Variable):
            outputs = (outputs,)
        self._write_tensors(data_dir, "output", flatten(outputs))

    @staticmethod
    def _write_tensors(data_dir, prefix, tensors):
        """Serialize each tensor to ``<data_dir>/<prefix>_<index>.pb`` as an ONNX TensorProto."""
        import onnx.numpy_helper

        for index, var in enumerate(tensors):
            tensor = onnx.numpy_helper.from_array(var.data.numpy())
            with open(os.path.join(data_dir, f"{prefix}_{index}.pb"), "wb") as file:
                file.write(tensor.SerializeToString())

    def assertONNXRaises(self, err, f, args, params=None, **kwargs):
        """Assert that the text export of ``f`` raises ``err`` (checked by ``assertExpectedRaises``)."""
        m = self._as_module(f, params)
        self.assertExpectedRaises(err, lambda: export_to_pbtxt(m, args, **kwargs))

    def assertONNXRaisesRegex(self, err, reg, f, args, params=None, **kwargs):
        """Assert that the text export of ``f`` raises ``err`` with a message matching ``reg``."""
        m = self._as_module(f, params)
        with self.assertRaisesRegex(err, reg):
            export_to_pbtxt(m, args, **kwargs)

    def test_basic(self):
        x = torch.tensor([0.4], requires_grad=True)
        y = torch.tensor([0.7], requires_grad=True)
        self.assertONNX(lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))), (x, y))

    def test_view(self):
        x = torch.tensor([0.0], requires_grad=True)
        self.assertONNX(lambda x: x.view(1, 1), x)

    def test_index(self):
        x = torch.tensor([[0.0]], requires_grad=True)
        self.assertONNX(lambda x: x[0], x)

    def test_type_as(self):
        x = torch.tensor([0.0], requires_grad=True)
        self.assertONNX(lambda x: x.type_as(x), x)

    def test_addconstant(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        self.assertONNX(lambda x: x + 1, x)

    def test_add_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(3, requires_grad=True).double()
        self.assertONNX(operator.add, (x, y))

    def test_add_left_broadcast(self):
        x = torch.randn(3, requires_grad=True).double()
        y = torch.randn(2, 3, requires_grad=True).double()
        self.assertONNX(operator.add, (x, y))

    def test_add_size1_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(2, 1, requires_grad=True).double()
        self.assertONNX(operator.add, (x, y))

    def test_add_size1_right_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(3, requires_grad=True).double()
        self.assertONNX(operator.add, (x, y))

    def test_add_size1_singleton_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(1, 3, requires_grad=True).double()
        self.assertONNX(operator.add, (x, y))

    def test_rsub(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        self.assertONNX(lambda x: 1 - x, (x,))

    def test_mul_bool(self):
        x = torch.tensor([True, False, True, False])
        y = torch.tensor([True, True, False, False])
        self.assertONNX(lambda x, y: torch.mul(x, y), (x, y))

    def test_mul_fp_bool(self):
        x = torch.tensor([9.4, 1.7, 3.6])
        y = torch.tensor([True, True, False])
        self.assertONNX(lambda x, y: torch.mul(x, y), (x, y))

    def test_transpose(self):
        x = torch.tensor([[0.0, 1.0], [2.0, 3.0]], requires_grad=True)
        self.assertONNX(lambda x: x.transpose(0, 1).transpose(1, 0), x)

    def test_chunk(self):
        x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
        self.assertONNX(lambda x: x.chunk(2), x)

    def test_split(self):
        x = torch.tensor(
            [[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]]
        )
        self.assertONNX(lambda x: torch.split(x, 2, 1), x)

    def test_split_with_sizes(self):
        x = torch.tensor(
            [[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]]
        )
        self.assertONNX(lambda x: torch.split(x, [2, 1, 3], 1), x)

    def test_concat2(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        self.assertONNX(lambda inputs: torch.cat(inputs, 1), ((x, y),))

    def test_mm(self):
        m1 = torch.randn(2, 3, requires_grad=True)
        m2 = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(torch.mm, (m1, m2))

    def test_addmm(self):
        m1 = torch.randn(2, 3, requires_grad=True)
        m2 = torch.randn(3, 4, requires_grad=True)
        m3 = torch.randn(4, requires_grad=True)
        self.assertONNX(
            lambda x, y, z: torch.addmm(torch.addmm(z, x, y), x, y), (m1, m2, m3)
        )

    def test_permute2(self):
        x = torch.tensor([[[[[[0.0]]]]]], requires_grad=True)
        self.assertONNX(lambda x: x.permute(0, 1, 4, 2, 5, 3), x)

    def test_pad(self):
        x = torch.tensor(
            [[[[0.0, 1.0, 1.0, 1.0], [2.0, 3.0, 7.0, 7.0]]]], requires_grad=True
        )
        self.assertONNX(nn.ReflectionPad2d((2, 3, 0, 1)), x)

    def test_params(self):
        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
        y = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True))
        self.assertONNX(
            lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))),
            x,
            params=(y,),
            keep_initializers_as_inputs=True,
        )

    def test_params_onnx_irv4(self):
        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
        y = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True))
        self.assertONNX(
            lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))),
            x,
            params=(y,),
            keep_initializers_as_inputs=False,
        )

    def test_symbolic_mismatch(self):
        class MyFun(Function):
            @staticmethod
            def symbolic(g, x):
                # The inside of this function should never be invoked, because
                # we will fail due to an argument mismatch first.
                raise AssertionError

            @staticmethod
            def forward(ctx, x, y):
                return x + y

        x = torch.ones(2, 2)
        y = torch.ones(2, 2)
        # NB: Don't use expect test here, the type error wobbles depending
        # on Python version
        with self.assertRaisesRegex(TypeError, "occurred when translating MyFun"):
            # Call ``apply`` on the class: instantiating an autograd Function is
            # deprecated (see test_at_op for the same pattern).
            export_to_pbtxt(FuncModule(MyFun.apply), (x, y))

    # TODO: Do an nn style test for these
    def test_batchnorm(self):
        x = torch.ones(2, 2, 2, 2, requires_grad=True)
        self.assertONNX(nn.BatchNorm2d(2), x, keep_initializers_as_inputs=True)

    def test_batchnorm_onnx_irv4(self):
        x = torch.ones(2, 2, 2, 2, requires_grad=True)
        self.assertONNX(nn.BatchNorm2d(2), x)

    def test_batchnorm_1d(self):
        x = torch.ones(2, 2, requires_grad=True)
        self.assertONNX(nn.BatchNorm1d(2), x, keep_initializers_as_inputs=True)

    def test_batchnorm_training(self):
        x = torch.ones(2, 2, 2, 2, requires_grad=True)
        self.assertONNX(
            nn.BatchNorm2d(2),
            x,
            training=torch.onnx.TrainingMode.TRAINING,
            keep_initializers_as_inputs=True,
        )

    def test_conv(self):
        x = torch.ones(20, 16, 50, 40, requires_grad=True)
        self.assertONNX(
            nn.Conv2d(16, 13, 3, bias=False), x, keep_initializers_as_inputs=True
        )

    def test_conv_onnx_irv4(self):
        x = torch.ones(20, 16, 50, 40, requires_grad=True)
        self.assertONNX(nn.Conv2d(16, 13, 3, bias=False), x)

    def test_conv_onnx_irv4_opset8(self):
        # This test point checks that for opset 8 (or lower), even if
        # keep_initializers_as_inputs is set to False, it is ignored,
        # and initializers are listed as ONNX graph input, in accordance
        # with ONNX IR v3 semantics (which apply to opset version <= 8).
        x = torch.ones(1, 2, 5, 7, requires_grad=True)
        conv_node = nn.Conv2d(2, 4, 3, bias=False)
        conv_node.weight.data.fill_(1.0)
        self.assertONNX(
            conv_node, x, opset_version=8, keep_initializers_as_inputs=False
        )

    def test_conv_variable_length(self):
        x = torch.ones(5, 3, 6, 6, requires_grad=True)
        model = torch.nn.Conv2d(3, 2, 3)

        dynamic_axes = {
            "input_1": [0, 2, 3],
            "output_1": {0: "output_1_variable_dim_0", 1: "output_1_variable_dim_1"},
        }
        import onnx

        # Export into a private temporary directory that is always cleaned up.
        # A bare NamedTemporaryFile was never closed, and on Windows it cannot
        # be reopened by name while it is still open.
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = os.path.join(tmp_dir, "model.onnx")
            torch.onnx.export(
                model,
                x,
                model_path,
                verbose=True,
                input_names=["input_1"],
                output_names=["output_1"],
                dynamic_axes=dynamic_axes,
            )
            onnx_model = onnx.load(model_path)
        onnx.checker.check_model(onnx_model)

        input_dims = onnx_model.graph.input[0].type.tensor_type.shape.dim
        output_dims = onnx_model.graph.output[0].type.tensor_type.shape.dim

        # Axes given only by index get default names "<input>_dynamic_axes_<k>",
        # where k counts the listed axes from 1 (it is not the axis index).
        self.assertEqual(input_dims[0].dim_param, "input_1_dynamic_axes_1")
        self.assertEqual(input_dims[2].dim_param, "input_1_dynamic_axes_2")
        self.assertEqual(input_dims[3].dim_param, "input_1_dynamic_axes_3")

        # Custom names from the dict form of dynamic_axes are applied verbatim.
        self.assertEqual(output_dims[0].dim_param, "output_1_variable_dim_0")
        self.assertEqual(output_dims[1].dim_param, "output_1_variable_dim_1")

    def test_convtranspose(self):
        x = torch.ones(2, 3, 4, 5, requires_grad=True)
        self.assertONNX(
            nn.ConvTranspose2d(
                3, 3, 3, stride=3, bias=False, padding=1, output_padding=2
            ),
            x,
            keep_initializers_as_inputs=True,
        )

    def test_maxpool(self):
        x = torch.randn(20, 16, 50)
        self.assertONNX(nn.MaxPool1d(3, stride=2), x)

    def test_maxpool_dilations(self):
        x = torch.randn(20, 16, 50)
        self.assertONNX(nn.MaxPool1d(2, stride=1, dilation=2), x, opset_version=10)

    def test_avg_pool2d(self):
        x = torch.randn(20, 16, 50, 32)
        self.assertONNX(nn.AvgPool2d(3, stride=2), x)

    def test_maxpool_indices(self):
        x = torch.randn(20, 16, 50)
        self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x)

    def test_at_op(self):
        x = torch.randn(3, 4)

        class MyFun(Function):
            @staticmethod
            def symbolic(g, x):
                # Emit an ATen "add" node instead of a native ONNX op.
                return g.at("add", x, x)

            @staticmethod
            def forward(ctx, x):
                return x + x

        class MyModule(Module):
            def forward(self, x):
                return MyFun.apply(x)

        self.assertONNX(
            MyModule(),
            x,
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )

    def test_clip(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.clamp(x, min=-0.5, max=0.5), x)

    def test_clip_min(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.clamp(min=-0.1), x)

    def test_clip_max(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.clamp(max=0.1), x)

    def test_hardtanh(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.nn.Hardtanh(-0.5, 0.5)(x), x)

    def test_full(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.full(x.shape, 2.0), x)

    def test_full_like(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.full_like(x, 2), x)

    def test_max(self):
        x = torch.randn(3, 4, requires_grad=True)
        y = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x, y: torch.max(x, y), (x, y))

    def test_min(self):
        x = torch.randn(3, 4, requires_grad=True)
        y = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x, y: torch.min(x, y), (x, y))

    # Reductions: full, per-dim, keepdim, and explicit dtype variants.
    def test_mean(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x), x)

    def test_reduced_mean(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x, dim=2), x)

    def test_reduced_mean_keepdim(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x, dim=(2, 3), keepdim=True), x)

    def test_mean_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x, dtype=torch.double), x)

    def test_reduced_mean_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x, dim=0, dtype=torch.double), x)

    def test_sum(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x), x)

    def test_sum_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x, dtype=torch.double), x)

    def test_reduced_sum_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x, dim=0, dtype=torch.double), x)

    def test_reduced_sum(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x, dim=(1, 2)), x)

    def test_reduced_sum_keepdim(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x, dim=2, keepdim=True), x)

    def test_prod(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x), x)

    def test_reduced_prod(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x, dim=2), x)

    def test_reduced_prod_keepdim(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x, dim=2, keepdim=True), x)

    def test_prod_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x, dtype=torch.double), x)

    def test_reduced_prod_dtype(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x, dim=0, dtype=torch.double), x)

    def test_sqrt(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sqrt(x), x)

    def test_rsqrt(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.rsqrt(x), x)

    # Comparisons on int tensors (several with broadcasting shapes).
    def test_equal(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(operator.eq, (x, y))

    def test_lt(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(operator.lt, (x, y))

    def test_gt(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(operator.gt, (x, y))

    def test_le(self):
        x = torch.randn(3, 4, requires_grad=False).int()
        y = torch.randn(3, 4, requires_grad=False).int()
        self.assertONNX(operator.le, (x, y))

    def test_ge(self):
        x = torch.randn(3, 4, requires_grad=False).int()
        y = torch.randn(3, 4, requires_grad=False).int()
        self.assertONNX(operator.ge, (x, y))

    def test_exp(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.exp(), x)

    def test_sin(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.sin(), x)

    def test_cos(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.cos(), x)

    def test_tan(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.tan(), x)

    def test_asin(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.asin(), x)

    def test_acos(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.acos(), x)

    def test_slice(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x[:, 1:2], x)

    def test_slice_dynamic(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x[x.size(0) :, x.size(1) - 3], x, opset_version=10)

    def test_sign(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.sign(), x)

    def test_narrow(self):
        x = torch.randn(3, 3, requires_grad=True)
        self.assertONNX(lambda x: torch.narrow(x, 0, 0, 2), x)

    def test_atan(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.atan(), x)

    def test_view_flatten(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.view(x.size()[0], x.numel() // x.size()[0]), x)

    def test_flatten(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.flatten(x), x)

    def test_flatten2D(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.flatten(x, 1), x)

    def test_isnan(self):
        x = torch.tensor([1, float("nan"), 2])
        self.assertONNX(lambda x: torch.isnan(x), x)

    def test_argmax(self):
        x = torch.randn(4, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.argmax(x, dim=1), x)

    def test_logsoftmax(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(nn.LogSoftmax(dim=3), x)

    def test_pow(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        y = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x, y: x.pow(y), (x, y))

    def test_elu(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(nn.ELU(), x)

    def test_selu(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(nn.SELU(), x)

    def test_repeat(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.repeat(1, 2, 3, 4), x)

    def test_repeat_dim_overflow(self):
        # More repeat factors than input dims.
        x = torch.randn(1, 2, requires_grad=True)
        self.assertONNX(lambda x: x.repeat(1, 2, 3, 4), x)

    def test_norm_p1(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.norm(p=1, dim=2), (x))

    def test_norm_p2(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.norm(p=2, dim=2), (x))

    def test_upsample_nearest_scale(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: nn.functional.interpolate(
                x, scale_factor=2.0, mode="nearest", recompute_scale_factor=False
            ),
            x,
        )

    def test_upsample_nearest_scale_default_scale_factor(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: nn.functional.interpolate(x, scale_factor=2.0, mode="nearest"), x
        )

    def test_upsample_nearest_size(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: nn.functional.interpolate(x, size=16, mode="nearest"), x
        )

    def test_unsqueeze(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.unsqueeze(len(x.shape)), x)

    def test_batchnorm_noaffine(self):
        x = torch.randn(128, 128, 1, 1, requires_grad=True)
        self.assertONNX(
            nn.BatchNorm2d(128, affine=False, momentum=0.3),
            x,
            keep_initializers_as_inputs=True,
        )

    def test_embedding_bags(self):
        emb_bag = nn.EmbeddingBag(10, 8)
        input = torch.tensor([1, 2, 3, 4]).long()
        offset = torch.tensor([0]).long()
        self.assertONNX(
            emb_bag,
            (input, offset),
            keep_initializers_as_inputs=True,
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )

    def test_implicit_expand(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x + 1, x)

    def test_reduce_sum_negative_indices(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.sum(-1), x)

    def test_randn(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.randn(1, 2, 3, 4) + x, x)

    def test_rand(self):
        x = torch.rand(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.rand(1, 2, 3, 4) + x, x)

    def test_rrelu(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(torch.nn.RReLU(), x)

    def test_prelu(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(torch.nn.PReLU(2), x, keep_initializers_as_inputs=True)

    def test_log_sigmoid(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(torch.nn.LogSigmoid(), x)

    def test_linear(self):
        x = torch.randn(3, 4)
        self.assertONNX(
            torch.nn.Linear(4, 5, bias=True), x, keep_initializers_as_inputs=True
        )

    def test_empty_like(self):
        x = torch.randn(5, 8, requires_grad=True)
        self.assertONNX(lambda x: torch.empty_like(x), x)

    def test_zeros_like(self):
        x = torch.randn(5, 8, requires_grad=True)
        self.assertONNX(lambda x: torch.zeros_like(x), x)

    def test_ones_like(self):
        x = torch.randn(6, 10, requires_grad=True)
        self.assertONNX(lambda x: torch.ones_like(x), x)

    def test_expand(self):
        x = torch.randn(6, 1, requires_grad=True)
        self.assertONNX(lambda x: x.expand(4, 6, 2), x)

    def test_ne(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(lambda x, y: torch.ne(x, y), (x, y))

    def test_reducemax(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.max(x), x)

    def test_reducemin(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.min(x), x)

    def test_erf(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: x.erf(), x)

    # Dropout in inference vs. training export modes, default and opset 12.
    def test_dropout(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.max(functional.dropout(x, training=False)), x)

    def test_dropout_default(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.max(functional.dropout(x)), x)

    def test_dropout_training(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: torch.max(functional.dropout(x)),
            x,
            training=torch.onnx.TrainingMode.TRAINING,
        )

    def test_dropout_opset12(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: torch.max(functional.dropout(x, training=False)),
            x,
            opset_version=12,
        )

    def test_dropout_training_opset12(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(
            lambda x: torch.max(functional.dropout(x)),
            x,
            opset_version=12,
            training=torch.onnx.TrainingMode.TRAINING,
        )

    def test_nonzero(self):
        x = torch.tensor(
            [[[2.0, 2.0], [1.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]]], requires_grad=True
        )
        self.assertONNX(lambda x: torch.nonzero(x), x)

    def test_gather(self):
        data = torch.randn(3, 4, 3, requires_grad=True)
        index = torch.tensor([2, 0]).view(1, 2, 1).expand(3, 2, 3)
        self.assertONNX(lambda data, index: data.gather(1, index), (data, index))

    def test_gather_opset11(self):
        data = torch.randn(3, 4, 3, requires_grad=True)
        index = torch.tensor([2, 0]).view(1, 2, 1).expand(3, 2, 3)
        self.assertONNX(
            lambda data, index: data.gather(1, index), (data, index), opset_version=11
        )

    # NOTE(review): the scatter_add lambdas below ignore their second argument
    # and close over ``indices``/``values`` from the test body, so those tensors
    # are presumably exported as constants rather than as graph inputs. Left
    # as-is because changing it would change the recorded expect output.
    def test_scatter_add(self):
        data = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
        self.assertONNX(
            lambda data, index: data.scatter_add(1, indices, values),
            (data, (indices, values)),
        )

    def test_scatter_add_opset11(self):
        data = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
        self.assertONNX(
            lambda data, index: data.scatter_add(1, indices, values),
            (data, (indices, values)),
            opset_version=11,
        )

    def test_scatter_add_opset16(self):
        data = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        indices = torch.tensor([[0, 0], [1, 1], [0, 1]], dtype=torch.int64)
        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
        self.assertONNX(
            lambda data, index: data.scatter_add(1, indices, values),
            (data, (indices, values)),
            opset_version=16,
        )

    def test_master_opset(self):
        x = torch.randn(2, 3).float()
        y = torch.randn(2, 3).float()
        self.assertONNX(operator.add, (x, y), opset_version=10)

    def test_std(self):
        x = torch.randn(2, 3, 4).float()
        self.assertONNX(
            lambda x: torch.std(x, dim=(0, 1), unbiased=True, keepdim=True), x
        )

    def test_cumsum(self):
        x = torch.randn(2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.cumsum(x, dim=1), x, opset_version=11)

    def test_dict(self):
        class MyModel(torch.nn.Module):
            def forward(self, x_in):
                x_out = {}
                x_out["test_key_out"] = torch.add(
                    x_in[list(x_in.keys())[0]],  # noqa: RUF015
                    list(x_in.keys())[0],  # noqa: RUF015
                )
                return x_out

        # Tensor-keyed input dict; the trailing {} is the (empty) kwargs dict.
        x = {torch.tensor(1.0): torch.randn(1, 2, 3)}
        self.assertONNX(MyModel(), (x, {}))

    def test_dict_str(self):
        class MyModel(torch.nn.Module):
            def forward(self, x_in):
                x_out = {}
                x_out["test_key_out"] = torch.add(x_in["test_key_in"], 2.0)
                return x_out

        x = {"test_key_in": torch.randn(1, 2, 3)}
        self.assertONNX(MyModel(), (x, {}))

    def test_arange_dynamic(self):
        class TestModel(torch.nn.Module):
            def forward(self, input):
                return torch.arange(input.shape[0], input.shape[0] + 5, 0.5)

        input = torch.randn(5, 3, 2)
        self.assertONNX(TestModel(), input, opset_version=11)

    def test_bitshift(self):
        class BitshiftModel(torch.nn.Module):
            def forward(self, input):
                return input >> 1, input >> 2

        input = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
        self.assertONNX(BitshiftModel(), input, opset_version=11)

    def test_bitwise_and(self):
        class BiwiseAndModel(torch.nn.Module):
            def forward(self, input, other):
                return torch.bitwise_and(input, other), input & 2

        input = torch.randint(0, 100, (2, 3, 4), dtype=torch.uint8)
        other = torch.randint(-50, 50, (2, 3, 4), dtype=torch.int8)
        self.assertONNX(BiwiseAndModel(), (input, other), opset_version=18)

    def test_layer_norm_aten(self):
        model = torch.nn.LayerNorm([10, 10])
        x = torch.randn(20, 5, 10, 10)
        self.assertONNX(
            model,
            x,
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )

    def test_pixel_shuffle(self):
        x = torch.randn(2, 8, 3, 4).float()
        self.assertONNX(
            lambda x: torch.pixel_shuffle(x, upscale_factor=2), x, opset_version=11
        )

    def test_frobenius_norm(self):
        x = torch.randn(2, 3, 4).float()
        self.assertONNX(lambda x: torch.norm(x, p="fro", dim=(0, 1), keepdim=True), x)

    def test_unfold(self):
        x = torch.randn(2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.unfold(dimension=2, size=2, step=2), x)

    def test_remainder(self):
        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 1, 4)
        self.assertONNX(lambda x, y: torch.remainder(x, y), (x, y))

    def test_fmod(self):
        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 1, 4)
        self.assertONNX(lambda x, y: torch.fmod(x, y), (x, y), opset_version=10)

    def test_gelu(self):
        x = torch.randn(2, 3, 4, 5, requires_grad=True)
        self.assertONNX(lambda x: torch.nn.functional.gelu(x), x)

    def test_unique(self):
        x = torch.randint(3, (2, 3, 4, 5)).float()
        self.assertONNX(
            lambda x: torch.unique(
                x, dim=0, sorted=True, return_inverse=False, return_counts=True
            ),
            x,
            opset_version=11,
        )

    def test_meshgrid(self):
        x = torch.ones(3, requires_grad=True)
        y = torch.zeros(4, requires_grad=True)
        z = torch.ones(5, requires_grad=True)
        self.assertONNX(lambda x, y, z: torch.meshgrid(x, y, z), (x, y, z))

    def test_meshgrid_indexing(self):
        x = torch.ones(3, requires_grad=True)
        y = torch.zeros(4, requires_grad=True)
        z = torch.ones(5, requires_grad=True)
        self.assertONNX(
            lambda x, y, z: torch.meshgrid(x, y, z, indexing="xy"),
            (x, y, z),
            opset_version=9,
        )

    def test_topk(self):
        x = torch.arange(1.0, 6.0, requires_grad=True)
        k = torch.tensor(3)
        self.assertONNX(lambda x, k: torch.topk(x, k), (x, k), opset_version=10)

    def test_topk_smallest_unsorted(self):
        x = torch.arange(1.0, 6.0, requires_grad=True)
        k = torch.tensor(3)
        self.assertONNX(
            lambda x, k: torch.topk(x, k, largest=False, sorted=False),
            (x, k),
            opset_version=11,
        )

    def test_baddbmm(self):
        x = torch.randn(10, 3, 5)
        b1 = torch.randn(10, 3, 4)
        b2 = torch.randn(10, 4, 5)
        self.assertONNX(lambda x, b1, b2: torch.baddbmm(x, b1, b2), (x, b1, b2))

    def test_round(self):
        x = torch.tensor([0.9920, -1.0362, -1.5000, 2.5000], requires_grad=True)
        self.assertONNX(lambda x: torch.round(x), x, opset_version=11)

    def test_dim(self):
        x = torch.ones((2, 2), requires_grad=True)
        self.assertONNX(lambda x: torch.scalar_tensor(x.dim()), x)
@skipIfNoLapack 1018 def test_det(self): 1019 x = torch.randn(2, 3, 5, 5, device=torch.device("cpu")) 1020 self.assertONNX(lambda x: torch.det(x), x, opset_version=11) 1021 self.assertONNX(lambda x: torch.linalg.det(x), x, opset_version=11) 1022 1023 def test_softmaxcrossentropy(self): 1024 x = torch.randn(3, 5) 1025 y = torch.empty(3, dtype=torch.long).random_(5) 1026 self.assertONNX(torch.nn.CrossEntropyLoss(), (x, y), opset_version=12) 1027 1028 def test_softmaxcrossentropy_ignore_index(self): 1029 x = torch.randn(3, 5) 1030 y = torch.empty(3, dtype=torch.long).random_(5) 1031 self.assertONNX( 1032 torch.nn.CrossEntropyLoss(ignore_index=1), (x, y), opset_version=12 1033 ) 1034 1035 def test_softmaxcrossentropy_weights(self): 1036 x = torch.randn(3, 5) 1037 y = torch.empty(3, dtype=torch.long).random_(5) 1038 self.assertONNX( 1039 torch.nn.CrossEntropyLoss(weight=torch.randn(5)), (x, y), opset_version=12 1040 ) 1041 1042 def test_softmaxcrossentropy_3d(self): 1043 x = torch.randn(3, 5, 2) 1044 y = torch.empty(3, 2, dtype=torch.long).random_(5) 1045 self.assertONNX(torch.nn.CrossEntropyLoss(), (x, y), opset_version=12) 1046 1047 def test_softmaxcrossentropy_3d_none(self): 1048 x = torch.randn(3, 5, 2) 1049 y = torch.empty(3, 2, dtype=torch.long).random_(5) 1050 self.assertONNX( 1051 torch.nn.CrossEntropyLoss(reduction="none"), (x, y), opset_version=12 1052 ) 1053 1054 def test_softmaxcrossentropy_4d(self): 1055 x = torch.randn(3, 5, 2, 1) 1056 y = torch.empty(3, 2, 1, dtype=torch.long).random_(5) 1057 self.assertONNX(torch.nn.CrossEntropyLoss(), (x, y), opset_version=12) 1058 1059 def test_lstm_none_sequence_lens(self): 1060 """Test symbolic shape inference for LSTM when the input sequence_lens = None.""" 1061 input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE) 1062 h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE) 1063 c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE) 1064 1065 class LSTMModel(torch.nn.Module): 1066 def __init__(self) -> None: 1067 
super().__init__() 1068 self.rnn = torch.nn.LSTM( 1069 RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False 1070 ) 1071 1072 def forward(self, x, h0, c0): 1073 a, b = self.rnn(x, (h0, c0)) 1074 return torch.ones(b[0].shape) 1075 1076 self.assertONNX( 1077 LSTMModel(), 1078 (input, h0, c0), 1079 input_names=["x", "y"], 1080 dynamic_axes={"x": {0: "batch"}}, 1081 opset_version=12, 1082 ) 1083 1084 def test_dynamic_axes_add(self): 1085 m1 = torch.randn(2, 3, requires_grad=True) 1086 m2 = torch.randn(2, 1, requires_grad=True) 1087 self.assertONNX( 1088 lambda x, y: torch.add(x, y), 1089 (m1, m2), 1090 input_names=["input_1", "input_2"], 1091 dynamic_axes={"input_1": {1: "dim_1"}, "input_2": {1: "dim_2"}}, 1092 opset_version=12, 1093 ) 1094 1095 def test_dynamic_axes_add_inputs_same_symbolic_shape(self): 1096 m1 = torch.randn(2, 3, requires_grad=True) 1097 self.assertONNX( 1098 lambda x: torch.add(x, x), 1099 (m1,), 1100 input_names=["input_1"], 1101 dynamic_axes={"input_1": {1: "dim_1"}}, 1102 opset_version=12, 1103 ) 1104 1105 def test_dynamic_axes_matmul(self): 1106 m1 = torch.randn(2, 2, 4, requires_grad=True) 1107 m2 = torch.randn(2, 4, 3, requires_grad=True) 1108 self.assertONNX( 1109 lambda x, y: torch.matmul(x, y), 1110 (m1, m2), 1111 input_names=["input_1", "input_2"], 1112 dynamic_axes={"input_1": {1: "dim_0"}, "input_2": {2: "dim_1"}}, 1113 opset_version=12, 1114 ) 1115 1116 def test_dynamic_axes_reduce_mean(self): 1117 m1 = torch.randn(2, 3, 4, requires_grad=True) 1118 self.assertONNX( 1119 lambda x: torch.mean(x, dim=1), 1120 (m1), 1121 input_names=["input"], 1122 dynamic_axes={"input": {1: "dim_1", 2: "dim_2"}}, 1123 opset_version=12, 1124 ) 1125 1126 def test_dynamic_axes_unchange(self): 1127 """Test ProcessUnchangeNode in symbolic shape inference.""" 1128 m1 = torch.randn(2, 3, requires_grad=True) 1129 self.assertONNX( 1130 lambda x: torch.softmax(x, dim=0), 1131 (m1,), 1132 input_names=["input"], 1133 dynamic_axes={"input": {1: "dim_1"}}, 1134 
opset_version=12, 1135 ) 1136 1137 def test_aten_embedding_1(self): 1138 _onnx_opset_version = 12 1139 1140 @parse_args("v", "v", "i", "b", "b") 1141 def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): 1142 custom_attributes_json = ( 1143 "{" 1144 f'"padding_idx":{str(padding_idx)},' 1145 f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},' 1146 f'"sparse":{str(sparse).lower()}' 1147 "}" 1148 ) 1149 output = g.at( 1150 "embedding", 1151 weight, 1152 indices, 1153 custom_attributes_json_s=custom_attributes_json, 1154 ) 1155 return output 1156 1157 torch.onnx.register_custom_op_symbolic( 1158 "::embedding", embedding, _onnx_opset_version 1159 ) 1160 1161 class Model(torch.nn.Module): 1162 def __init__(self) -> None: 1163 super().__init__() 1164 self.emb = torch.nn.Embedding(4, 8) 1165 1166 def forward(self, x, y): 1167 res = self.emb(x) 1168 res = res + y 1169 return torch.ones(res.shape[0]) 1170 1171 model = Model() 1172 x = torch.ones(32, dtype=torch.long) 1173 y = torch.randn(1, 8) 1174 self.assertONNX(model, (x, y), opset_version=_onnx_opset_version) 1175 1176 torch.onnx.unregister_custom_op_symbolic("::embedding", _onnx_opset_version) 1177 1178 # This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding. 
1179 def test_aten_embedding_2(self): 1180 _onnx_opset_version = 12 1181 1182 @parse_args("v", "v", "i", "b", "b") 1183 def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): 1184 custom_attributes_json = ( 1185 "{" 1186 f'"padding_idx":{str(padding_idx)},' 1187 f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},' 1188 f'"sparse":{str(sparse).lower()}' 1189 "}" 1190 ) 1191 output = g.at( 1192 "embedding", 1193 weight, 1194 indices, 1195 custom_attributes_json_s=custom_attributes_json, 1196 ) 1197 1198 # do shape inference and set it via setType 1199 indices_shape = _get_tensor_sizes(indices) 1200 if indices_shape is not None and hasattr(weight.type(), "with_sizes"): 1201 output_type = weight.type().with_sizes( 1202 indices_shape + [_get_tensor_dim_size(weight, 1)] 1203 ) 1204 output.setType(output_type) 1205 return output 1206 1207 torch.onnx.register_custom_op_symbolic( 1208 "::embedding", embedding, _onnx_opset_version 1209 ) 1210 1211 class Model(torch.nn.Module): 1212 def __init__(self) -> None: 1213 super().__init__() 1214 self.emb = torch.nn.Embedding(4, 8) 1215 1216 def forward(self, x, y): 1217 res = self.emb(x) 1218 res = res + y 1219 return torch.ones(res.shape[0]) 1220 1221 model = Model() 1222 x = torch.ones(32, dtype=torch.long) 1223 y = torch.randn(1, 8) 1224 self.assertONNX( 1225 model, 1226 (x, y), 1227 opset_version=_onnx_opset_version, 1228 input_names=["input_1", "input_2"], 1229 dynamic_axes={"input_1": {0: "dim_0"}, "input_2": {0: "dim_1", 1: "dim_2"}}, 1230 keep_initializers_as_inputs=False, 1231 operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, 1232 ) 1233 1234 torch.onnx.unregister_custom_op_symbolic("::embedding", _onnx_opset_version) 1235 1236 # Without shapeValueMap, the onnx graph looks like: 1237 # graph(%0 : Float(*, 1, 128, 1, strides=[128, 128, 1, 1], requires_grad=0, device=cpu)): 1238 # %2 : Long(4, strides=[1], device=cpu) = onnx::Shape(%0) 1239 # %4 : Long(device=cpu) = 
onnx::Constant[value={0}]() 1240 # %5 : Long(device=cpu) = onnx::Gather[axis=0](%2, %4) 1241 # %6 : Long(device=cpu) = onnx::Constant[value={1}]() 1242 # %7 : Long(device=cpu) = onnx::Constant[value={2}]() 1243 # %8 : Long(device=cpu) = onnx::Constant[value={-1}]() 1244 # %9 : int[] = prim::ListConstruct(%5, %6, %7, %8) 1245 # %10 : Float(*, *, *, *, strides=[128, 128, 64, 1], requires_grad=0, device=cpu) = onnx::Reshape(%0, %9) 1246 # ... 1247 # With shapeValueMap, it becomes: 1248 # ... 1249 # %10 : Float(*, 1, 2, 64, strides=[128, 128, 64, 1], requires_grad=0, device=cpu) = onnx::Reshape(%0, %9) 1250 # ... 1251 def test_shape_value_map(self): 1252 class RSoftMax(torch.nn.Module): 1253 def __init__(self, radix, cardinality): 1254 super().__init__() 1255 self.radix = radix 1256 self.cardinality = cardinality 1257 1258 def forward(self, x): 1259 batch = x.size(0) 1260 x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 1261 x = F.softmax(x, dim=1) 1262 x = x.reshape(batch, -1) 1263 return x 1264 1265 radix = 2 1266 cardinality = 1 1267 x = torch.randn(10, 1, 128, 1) 1268 self.assertONNX( 1269 RSoftMax(radix, cardinality), 1270 (x,), 1271 input_names=["x"], 1272 dynamic_axes={"x": {0: "dim_0"}}, 1273 ) 1274 1275 1276if __name__ == "__main__": 1277 no_onnx_dep_flag = "--no-onnx" 1278 _onnx_dep = no_onnx_dep_flag not in common_utils.UNITTEST_ARGS 1279 if no_onnx_dep_flag in common_utils.UNITTEST_ARGS: 1280 common_utils.UNITTEST_ARGS.remove(no_onnx_dep_flag) 1281 onnx_test_flag = "--produce-onnx-test-data" 1282 _onnx_test = onnx_test_flag in common_utils.UNITTEST_ARGS 1283 if onnx_test_flag in common_utils.UNITTEST_ARGS: 1284 common_utils.UNITTEST_ARGS.remove(onnx_test_flag) 1285 if _onnx_test: 1286 _onnx_dep = True 1287 import onnx_test_common 1288 1289 for d in glob.glob( 1290 os.path.join(onnx_test_common.pytorch_operator_dir, "test_operator_*") 1291 ): 1292 shutil.rmtree(d) 1293 common_utils.run_tests() 1294