Searched refs:is_bmm (Results 1 – 4 of 4) sorted by relevance
/aosp_15_r20/external/pytorch/torch/_inductor/fx_passes/ |
H A D | pad_mm.py | 455 is_bmm = op is torch.ops.aten.bmm 464 is_bmm=is_bmm, 468 if is_bmm: 481 is_bmm=is_bmm, 485 if is_bmm: 685 def pad_mat1(mat1, *, m_padded_length, k_padded_length, is_bmm=False): argument 691 if is_bmm: 695 return pad_dim(mat1, m_padded_length, 0 if not is_bmm else 1) 698 return pad_dim(mat1, k_padded_length, 1 if not is_bmm else 2) 701 def pad_mat2(mat2, *, k_padded_length, n_padded_length, is_bmm=False): argument [all …]
|
/aosp_15_r20/external/pytorch/test/inductor/ |
H A D | test_unbacked_symints.py | 235 def fn(x, w, repeats, is_bmm): argument 241 if is_bmm:
|
/aosp_15_r20/external/pytorch/aten/src/ATen/native/ |
H A D | LinearAlgebra.cpp | 284 …ch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, bool is_bmm, const std::option… in common_checks_baddbmm_bmm() argument 310 if (!is_bmm) { in common_checks_baddbmm_bmm() 1637 template <typename scalar_t, bool is_bmm> 1668 if (is_bmm) { in baddbmm_cpu_kernel()
|
/aosp_15_r20/external/pytorch/torch/ |
H A D | _meta_registrations.py | 3686 def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None): argument 3709 if not is_bmm and self_baddbmm is not None:
|