import contextlib
import threading
from typing import Callable, Generator, Iterable, Optional, Union

from .custom_ops import custom_op
from .infer_schema import infer_schema


def triton_op(
    name: str,
    fn: Optional[Callable] = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    schema: Optional[str] = None,
) -> Callable:
    """Create a custom operator whose implementation is backed by 1+ triton kernels.

    Use this instead of :func:`torch.library.custom_op` when the implementation
    consists of 1+ triton kernels. :func:`torch.library.custom_op` treats
    custom operators as opaque (:func:`torch.compile` and
    :func:`torch.export.export` will never trace into them), but ``triton_op``
    makes the implementation visible to these subsystems, allowing them
    to optimize the triton kernel(s).

    Note that ``fn`` must only consist of calls to PyTorch-understood
    operators and triton kernels. Any triton kernels called inside ``fn``
    must be wrapped in a call to :func:`torch._library.capture_triton`.

    Args:
        name (str): A name for the custom op that looks like "{namespace}::{name}",
            e.g. "mylib::my_linear". The name is used as the op's stable identifier
            in PyTorch subsystems (e.g. torch.export, FX graphs).
            To avoid name collisions, please use your project name as the namespace;
            e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
        fn (Callable, optional): The implementation of the operator. If omitted,
            ``triton_op`` returns a decorator that accepts the implementation.
        mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
            This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
            it pessimistically assumes that all inputs to the operator are being mutated.
        schema (None | str): A schema string for the operator. If None
            (recommended) we'll infer a schema for the operator from its type
            annotations. We recommend letting us infer a schema unless you
            have a specific reason not to.
            Example: "(Tensor x, int y) -> (Tensor, Tensor)".

    Returns:
        If ``fn`` is given, the custom operator built from it; otherwise a
        decorator that takes ``fn`` and returns that custom operator.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch._library import triton_op, capture_triton
        >>>
        >>> import triton
        >>> from triton import language as tl
        >>>
        >>> @triton.jit
        >>> def add_kernel(
        >>>     in_ptr0,
        >>>     in_ptr1,
        >>>     out_ptr,
        >>>     n_elements,
        >>>     BLOCK_SIZE: "tl.constexpr",
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     block_start = pid * BLOCK_SIZE
        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     output = x + y
        >>>     tl.store(out_ptr + offsets, output, mask=mask)
        >>>
        >>> @triton_op("mylib::add", mutates_args=())
        >>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     # NB: we need to wrap the triton kernel in a call to capture_triton
        >>>     capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
        >>>     return output
        >>>
        >>> @torch.compile
        >>> def f(x, y):
        >>>     return add(x, y)
        >>>
        >>> x = torch.randn(3, device="cuda")
        >>> y = torch.randn(3, device="cuda")
        >>>
        >>> z = f(x, y)
        >>> assert torch.allclose(z, x + y)

    """

    def dec(fn: Callable) -> Callable:
        def backend_fn(*args, **kwargs):  # type: ignore[no-untyped-def]
            # Optimization: we're passing regular Tensors into the triton kernel, so
            # no need to go through HOP dispatch
            with set_capture_triton_enabled(False):
                return fn(*args, **kwargs)

        # Honor an explicitly supplied schema (previously it was silently
        # ignored). Otherwise infer it from the user's annotated ``fn``:
        # ``backend_fn`` only has a generic ``*args, **kwargs`` signature.
        op_schema = (
            schema
            if schema is not None
            else infer_schema(fn, mutates_args=mutates_args)
        )
        result = custom_op(
            name,
            backend_fn,
            mutates_args=mutates_args,
            schema=op_schema,
        )
        from .._subclasses.functional_tensor import FunctionalTensorMode

        # We require that the user pass us a function that is make_fx traceable,
        # so we can just register it as the Fake/meta kernel.
        result.register_fake(fn)

        # We decompose the operator when FunctionalTensorMode is active.
        # The goal is to decompose the operator in AOTDispatcher.
        # - With torch.compile, this means that the backend (usually Inductor)
        #   can see a call to the triton kernel(s) and so it can directly optimize
        #   them by inlining them into the lowering process.
        # - With post-dispatch torch.export, this means that there will
        #   be a call(s) to the triton_kernel_wrapper_functional HOP in the
        #   graph (that we have yet to figure out how to serialize).
        def functional_decomp(  # type: ignore[no-untyped-def]
            mode, _, types, args, kwargs
        ):
            with mode:
                return fn(*args, **kwargs)

        result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
        return result

    if fn is None:
        return dec
    else:
        return dec(fn)


# Per-thread storage for the capture_triton switch. A thread that has never
# set ``.value`` falls back to ``capture_triton_enabled_default``.
capture_triton_enabled = threading.local()
capture_triton_enabled_default = True


@contextlib.contextmanager
def set_capture_triton_enabled(enabled: bool) -> Generator[None, None, None]:
    """If triton kernels annotated with @capture_triton should dispatch via HOP
    or go straight to the triton kernel execution.

    We have this switch because eager-mode performance of HOP dispatch is slow
    enough to matter (~1ms) and we know that capture_triton isn't necessary in
    some situations (eager-mode with regular Tensors).

    Args:
        enabled: The value to set for the current thread while the context
            is active. The previous value is restored on exit, including
            when the body raises.
    """
    # Read the previous value before entering ``try`` so the ``finally``
    # clause can never reference an unbound ``prev``.
    prev = is_capture_triton_enabled()
    capture_triton_enabled.value = enabled
    try:
        yield
    finally:
        capture_triton_enabled.value = prev


def is_capture_triton_enabled() -> bool:
    """Return the capture_triton switch for the current thread.

    Falls back to ``capture_triton_enabled_default`` when this thread has not
    set a value via :func:`set_capture_triton_enabled`.
    """
    return getattr(capture_triton_enabled, "value", capture_triton_enabled_default)


def capture_triton(triton_kernel: Callable, /) -> Callable:
    """Allows capture of a triton kernel into a graph via make_fx or
    non-strict export (coming soon).

    These technologies perform Dispatcher-based tracing (via
    ``__torch_dispatch__``) and cannot see calls to raw triton kernels.
    The ``capture_triton`` API returns a new callable that can actually
    be traced into a graph.

    Args:
        triton_kernel: A function decorated with ``triton.jit`` or
            ``triton.autotune``.

    Returns:
        ``triton_kernel`` itself when the switch read by
        :func:`is_capture_triton_enabled` is off for the current thread;
        otherwise a ``TraceableTritonKernelWrapper`` around it.

    Raises:
        RuntimeError: If ``triton_kernel`` is neither a triton ``JITFunction``
            nor an ``Autotuner``.

    Examples::

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import triton
        >>> from triton import language as tl
        >>> from torch.fx.experimental.proxy_tensor import make_fx
        >>> from torch._higher_order_ops.triton_kernel_wrap import capture_triton
        >>>
        >>> @triton.jit
        >>> def add_kernel(
        >>>     in_ptr0,
        >>>     in_ptr1,
        >>>     out_ptr,
        >>>     n_elements,
        >>>     BLOCK_SIZE: "tl.constexpr",
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     block_start = pid * BLOCK_SIZE
        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     output = x + y
        >>>     tl.store(out_ptr + offsets, output, mask=mask)
        >>>
        >>> def add(x, y):
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid_fn(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     capture_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
        >>>     return output
        >>>
        >>> x = torch.randn(3, device="cuda")
        >>> y = torch.randn(3, device="cuda")
        >>> gm = make_fx(add)(x, y)
        >>> print(gm.code)
        >>> # def forward(self, x_1, y_1):
        >>> #     empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
        >>> #     triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
        >>> #         kernel_idx = 0, constant_args_idx = 0,
        >>> #         grid = [(1, 1, 1)], kwargs = {
        >>> #             'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
        >>> #             'n_elements': 3, 'BLOCK_SIZE': 16
        >>> #         })
        >>> #     return empty_like

    """
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction

    from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper

    if not isinstance(triton_kernel, (JITFunction, Autotuner)):
        raise RuntimeError(
            "capture_triton only works on functions annotated with triton.jit or triton.autotune"
        )
    # Fast path: when capture is disabled for this thread (e.g. inside
    # set_capture_triton_enabled(False)), hand back the raw kernel unwrapped.
    if not is_capture_triton_enabled():
        return triton_kernel
    return TraceableTritonKernelWrapper(triton_kernel, None, None)