from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence, TYPE_CHECKING

import torchgen.api.ufunc as ufunc
from torchgen.api.translate import translate
from torchgen.api.types import (
    BaseCType,
    Binding,
    CType,
    Expr,
    NamedCType,
    opmath_t,
    scalar_t,
    StructuredImplSignature,
    VectorizedCType,
)
from torchgen.context import with_native_function
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    DispatchKey,
    NativeFunctionsGroup,
    ScalarType,
    UfuncKey,
)
from torchgen.utils import OrderedSet


if TYPE_CHECKING:
    from torchgen.api.ufunc import UfunctorBindings


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                                  CUDA STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# NB: not bothering to generate dispatch stub forward declaration in header,
# we can just paste it wherever necessary

# TODO: use BackendIndex
# dispatch_key: DispatchKey  # only CPU/CUDA right now


# Represents functors for implementing CUDA ufuncs.
# Functors are templated by scalar_t because users instantiate them at a
# specific dtype.  A functor looks something like this:
#
#   template <typename scalar_t>
#   struct CUDAFunctorOnSelf_add {
#     using opmath_t = at::opmath_type<scalar_t>;
#     opmath_t other_;
#     opmath_t alpha_;
#     CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
#         : other_(other), alpha_(alpha) {}
#     __device__ scalar_t operator()(scalar_t self) {
#       return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
#     }
#   };
#
@dataclass(frozen=True)
class UfunctorSignature:
    g: NativeFunctionsGroup
    scalar_tensor_idx: int | None
    name: str

    def arguments(self) -> UfunctorBindings:
        return ufunc.ufunctor_arguments(
            self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
        )

    def fields(self) -> list[Binding]:
        # fields are renamed to have a trailing underscore, as is conventional
        return [b.rename(f"{b.name}_") for b in self.arguments().ctor]

    def returns_type(self) -> CType:
        # TODO: don't hardcode; return type will be inferred based on tags on
        # the native function
        return BaseCType(scalar_t)

    def decl_fields(self) -> str:
        return "\n".join(f"{f.type} {f.name};" for f in self.fields())

    def inline_defn_ctor(self) -> str:
        args_str = ", ".join(a.decl() for a in self.arguments().ctor)
        # NB: hypothetically could do this with translate but the
        # transition here is very regular
        init_str = ", ".join(f"{a.name}_({a.name})" for a in self.arguments().ctor)
        return f"{self.name}({args_str}) : {init_str} {{}}"

    def decl_apply(self) -> str:
        args_str = ", ".join(a.decl() for a in self.arguments().apply)
        return f"{self.returns_type().cpp_type()} operator()({args_str}) const"


@dataclass(frozen=True)
class UfuncSignature:
    g: NativeFunctionsGroup
    name: str
    compute_t: CType

    def arguments(self) -> list[Binding]:
        return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)

    def call(self, ctx: Sequence[Binding | Expr]) -> str:
        return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})"


# steps:
#   1. take the functional signature
#   2. use api.ufunc to convert it to a template signature.  this establishes
#      the type of the template function
#   3. use api.ufunc (II) to generate a split struct / operator() signature.
#      this establishes the context in which we call the template signature
#
# StructuredImplSignature context
#   ~> functor constructor sig
#
# Functor constructor context
#   ~> functor fields sig
#
# Functor apply context (functor fields + functor apply sig)
#   ~> template sig
#
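# As a concrete (hypothetical) illustration for add: with scalar_tensor_idx=1,
# the ufunctor bindings split roughly like this --
#
#   sig = UfunctorSignature(g, scalar_tensor_idx=1, name="CUDAFunctorOnSelf_add")
#   sig.arguments().ctor   # [other, alpha]   -> functor constructor parameters
#   sig.arguments().apply  # [self]           -> operator() parameters
#   sig.fields()           # [other_, alpha_] -> functor member variables
#
# which corresponds to the CUDAFunctorOnSelf_add functor shown earlier.
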

def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
    num_tensors = sum(
        1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
    )
    return num_tensors == 2


def compute_ufunc_cuda_functors(
    g: NativeFunctionsGroup,
) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]:
    # First, build the functors.
    ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {}
    ufunctors: list[str] = []
    loops = g.out.ufunc_inner_loop
    scalar_tensor_idx_lookup = {
        UfuncKey.CUDAFunctorOnSelf: 1,
        UfuncKey.CUDAFunctorOnOther: 0,
        UfuncKey.CUDAFunctor: None,
    }
    if eligible_for_binary_scalar_specialization(g):
        keys = [
            UfuncKey.CUDAFunctorOnSelf,
            UfuncKey.CUDAFunctorOnOther,
            UfuncKey.CUDAFunctor,
        ]
    else:
        keys = [UfuncKey.CUDAFunctor]
        for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
            assert k not in loops, f"cannot use {k} on non-binary function"
    for k in keys:
        # If the key was directly defined, skip functor codegen; we assume the
        # user has already done it for us
        if k in loops:
            ufunctor_sig = UfunctorSignature(
                g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
            )
            for dtype in loops[k].supported_dtypes:
                ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
            continue

        # Note [ScalarOnly and Generic must match names for CUDA]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Otherwise, look in ANY of the generic entries.  For simplicity of
        # codegen, if both ScalarOnly and Generic are defined, the ufunc name
        # must match (if they didn't match, we'd have to generate distinct
        # functors per dtype, which is awful, so we're not going to do it unless
        # someone really forces us to)
        ufunc_name = None
        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
        for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
            if lk not in loops:
                continue
            if ufunc_name is None:
                ufunc_name = loops[lk].name
            else:
                # See Note [ScalarOnly and Generic must match names for CUDA]
                assert (
                    ufunc_name == loops[lk].name
                ), "ScalarOnly and Generic must have same ufunc name"
            supported_dtypes |= loops[lk].supported_dtypes
        assert ufunc_name is not None

        name = f"{k}_{ufunc_name}"
        ufunctor_sig = UfunctorSignature(
            g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
        )
        for dtype in supported_dtypes:
            ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig

        ufunc_sig = UfuncSignature(
            g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
        )
        apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
        ufunctors.append(
            f"""
template <typename scalar_t>
struct {ufunctor_sig.name} {{
  using opmath_t = at::opmath_type<scalar_t>;
  {ufunctor_sig.decl_fields()}
  {ufunctor_sig.inline_defn_ctor()}
  __device__ {ufunctor_sig.decl_apply()} {{
    return {ufunc_sig.call(apply_ctx)};
  }}
}};
"""
        )

    return ufunctor_sigs, "\n".join(ufunctors)

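# To make the keys concrete (a sketch for a binary op like add): the three
# functor flavors generated by compute_ufunc_cuda_functors above differ only
# in which operand is burned into the functor's constructor state --
#
#   CUDAFunctorOnSelf_add   -- other is a CPU scalar; operator()(scalar_t self)
#   CUDAFunctorOnOther_add  -- self is a CPU scalar;  operator()(scalar_t other)
#   CUDAFunctor_add         -- no CPU scalar;         operator()(scalar_t self, scalar_t other)
#
# ufunctor_sigs then maps each supported dtype to whichever of these apply.
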

@dataclass(frozen=True)
class BinaryScalarSpecializationConfig:
    scalar_idx: int
    ctor_tensor: str
    ufunc_key: UfuncKey


BinaryScalarSpecializationConfigs = [
    BinaryScalarSpecializationConfig(
        scalar_idx=0,
        ctor_tensor="self",
        ufunc_key=UfuncKey.CUDAFunctorOnOther,
    ),
    BinaryScalarSpecializationConfig(
        scalar_idx=1,
        ctor_tensor="other",
        ufunc_key=UfuncKey.CUDAFunctorOnSelf,
    ),
]


def compute_ufunc_cuda_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: dict[UfuncKey, UfunctorSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    body = "using opmath_t = at::opmath_type<scalar_t>;"
    body += "if (false) {}\n"  # for ease of codegen
    for config in BinaryScalarSpecializationConfigs:
        if config.ufunc_key not in inner_loops:
            continue
        ufunctor_sig = inner_loops[config.ufunc_key]
        scalar_idx = config.scalar_idx + 1
        # Make a copy and at the same time widen the type (not permissible
        # without copy; we don't want to mutate the input argument anyway)
        ctx: list[Expr | Binding] = list(parent_ctx)
        ctx.append(
            Expr(
                expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
                type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
            )
        )
        ufunctor_ctor_exprs_str = ", ".join(
            a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
        )

        # NB: ufunctor must be allocated before iter.remove_operand is called,
        # as it relies on iter
        body += f"""\
else if (iter.is_cpu_scalar({scalar_idx})) {{
  {ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
  iter.remove_operand({scalar_idx});
  gpu_kernel(iter, ufunctor);
}}"""

    ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
    ufunctor_ctor_exprs_str = ", ".join(
        a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
    )
    body += f"""
else {{
  gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
}}
    """
    return body

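# Schematically (a sketch of the generated output, not verbatim), the body
# assembled by compute_ufunc_cuda_dtype_body for a binary op like add reads:
#
#   using opmath_t = at::opmath_type<scalar_t>;
#   if (false) {}
#   else if (iter.is_cpu_scalar(1)) {
#     CUDAFunctorOnOther_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(1), ...);
#     iter.remove_operand(1);
#     gpu_kernel(iter, ufunctor);
#   }
#   else if (iter.is_cpu_scalar(2)) {
#     CUDAFunctorOnSelf_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(2), ...);
#     iter.remove_operand(2);
#     gpu_kernel(iter, ufunctor);
#   }
#   else {
#     gpu_kernel(iter, CUDAFunctor_add<scalar_t>(...));
#   }
#
# (scalar_idx is offset by one because operand 0 of the TensorIterator is the
# output; the "..." elides the remaining constructor arguments such as alpha)
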

@with_native_function
def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
    # First, build the functors, indexing them by dtype
    ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)

    # Next, build the conditionals
    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
    dtype_cases = []
    for dtype, inner_ufunc_sigs in ufunctor_sigs.items():
        dtype_cases.append(
            f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())}
  }}
)
"""
        )

    dtype_cases_str = "\n".join(dtype_cases)

    stub_sig = StubSignature(g)

    return f"""
{ufunctors}

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};

{stub_sig.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
    {dtype_cases_str}
  );
}}
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});

{sig.defn()} {{
  {stub_sig.direct_call(sig.arguments())};
}}
"""


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                                   CPU STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


@dataclass(frozen=True)
class StubSignature:
    g: NativeFunctionsGroup

    @property
    def name(self) -> str:
        return f"{str(self.g.functional.func.name.name)}_stub"

    @property
    def kernel_name(self) -> str:
        return f"{str(self.g.functional.func.name.name)}_kernel"

    @property
    def type_name(self) -> str:
        return f"{str(self.g.functional.func.name.name)}_fn"

    def arguments(self) -> list[Binding]:
        return ufunc.stub_arguments(self.g)

    def type(self) -> str:
        cpp_args = self.arguments()
        return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})"

    def dispatch_decl(self) -> str:
        return f"DECLARE_DISPATCH({self.type_name}, {self.name})"

    def dispatch_defn(self) -> str:
        return f"DEFINE_DISPATCH({self.name})"

    def kernel_defn(self) -> str:
        return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})"

    def type_defn(self) -> str:
        return f"using {self.type_name} = {self.type()}"

    # must be called from a context where `this` is a TensorIteratorBase*
    def call(self, ctx: Sequence[Binding]) -> str:
        return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"

    # used in CUDA to skip the unnecessary dynamic dispatch
    def direct_call(self, ctx: Sequence[Binding]) -> str:
        return f"{self.kernel_name}(*this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"

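# For a hypothetical op named add with one extra Scalar argument alpha, the
# stub pieces above would expand to roughly (a sketch, not verbatim output):
#
#   using add_fn = void(*)(TensorIteratorBase&, const at::Scalar&);
#   DECLARE_DISPATCH(add_fn, add_stub);       // dispatch_decl()
#   DEFINE_DISPATCH(add_stub);                // dispatch_defn()
#
#   add_stub(device_type(), *this, alpha);    // call()
#   add_kernel(*this, alpha);                 // direct_call()
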

@with_native_function
def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
    stub_sig = StubSignature(g)
    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))

    return f"""
{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
{stub_sig.dispatch_defn()};

{sig.defn()} {{
  {stub_sig.call(sig.arguments())};
}}
"""


def compute_ufunc_cpu_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: dict[UfuncKey, UfuncSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
    assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
    scalar_loop = inner_loops[UfuncKey.CPUScalar]
    vec_loop = None
    if UfuncKey.CPUVector in inner_loops:
        vec_loop = inner_loops[UfuncKey.CPUVector]

    # NB: We DON'T use translate here, because translate is
    # incapable of CSE'ing the scalar accesses in case it is also
    # used by Vectorized; also, the unpacking here is very simple
    # and only affects Scalar; everything else is implicitly captured
    # by the lambda

    # Setup scalar in scope
    body = []
    ctx = []
    for b in parent_ctx:
        if isinstance(b.argument, Argument) and b.argument.type != BaseType(
            BaseTy.Scalar
        ):
            continue
        body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
        ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
    if vec_loop is not None:
        for b in parent_ctx:
            if isinstance(b.argument, Argument) and b.argument.type != BaseType(
                BaseTy.Scalar
            ):
                continue
            body.append(
                f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
            )
            ctx.append(
                Expr(
                    f"_v_{b.name}",
                    NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
                )
            )

    # Setup lambda signature
    # NB: simplified version of ufunctor_arguments
    scalar_bindings = []
    vec_bindings = []
    for a in g.functional.func.arguments.flat_non_out:
        if not a.type.is_tensor_like():
            continue
        assert a.type == BaseType(BaseTy.Tensor)
        scalar_bindings.append(
            Binding(
                name=a.name,
                nctype=NamedCType(a.name, BaseCType(scalar_t)),
                argument=a,
            )
        )
        if vec_loop is not None:
            vec_bindings.append(
                Binding(
                    name=a.name,
                    nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
                    argument=a,
                )
            )

    def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]:
        r: list[Expr | Binding] = []
        r.extend(ctx)
        r.extend(b)
        return r

    body_str = "\n".join(body)
    if vec_loop is not None:
        return f"""
{body_str}
cpu_kernel_vec(iter,
  [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
  [=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
);
"""
    else:
        return f"""
{body_str}
cpu_kernel(iter,
  [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
);
"""

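# Schematically (a sketch, assuming add with both CPUScalar and CPUVector
# loops), the returned dtype body looks like:
#
#   auto _s_alpha = alpha.to<scalar_t>();
#   auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
#   cpu_kernel_vec(iter,
#     [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
#     [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) {
#       return ufunc::add(self, other, _v_alpha);
#     }
#   );
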

@with_native_function
def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
    stub_sig = StubSignature(g)

    # Reindex the ufunc by dtypes; processing generic/scalaronly as well
    loops = g.out.ufunc_inner_loop
    ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {}
    for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
        lks = []
        # ORDER MATTERS: this specifies overriding precedence
        if k in loops:  # should happen rarely
            lks.append(k)
        if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
            lks.append(UfuncKey.ScalarOnly)
        if UfuncKey.Generic in loops:
            lks.append(UfuncKey.Generic)
        # TODO: don't hardcode ufunc:: namespace here, should be centralized smh
        for lk in lks:
            for dtype in loops[lk].supported_dtypes:
                compute_t: CType
                if k is UfuncKey.CPUScalar:
                    compute_t = BaseCType(scalar_t)
                elif k is UfuncKey.CPUVector:
                    compute_t = VectorizedCType(BaseCType(scalar_t))
                else:
                    raise AssertionError
                inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
                if k not in inner_ufunc_sigs:
                    inner_ufunc_sigs[k] = UfuncSignature(
                        g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t
                    )

    # Build the conditionals
    dtype_cases = []
    for dtype, inner_ufunc_sigs in ufunc_sigs.items():
        dtype_cases.append(
            f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())}
  }}
)
"""
        )

    dtype_cases_str = "\n".join(dtype_cases)
    return f"""
namespace {{

{stub_sig.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
    {dtype_cases_str}
  );
}}

}} // anonymous namespace

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
"""