| Name | Date | Size | #Lines | LOC |
|---|---|---|---|---|
| LlgaTensorImpl.cpp | 25-Apr-2025 | 5.3 KiB | 161 | 122 |
| LlgaTensorImpl.h | 25-Apr-2025 | 7.5 KiB | 277 | 208 |
| README.md | 25-Apr-2025 | 5.4 KiB | 132 | 89 |
| decompose_silu.cpp | 25-Apr-2025 | 1.6 KiB | 66 | 53 |
| decompose_silu.h | 25-Apr-2025 | 263 | 16 | 11 |
| defer_size_check.cpp | 25-Apr-2025 | 2.4 KiB | 89 | 63 |
| defer_size_check.h | 25-Apr-2025 | 257 | 16 | 11 |
| graph_fuser.cpp | 25-Apr-2025 | 1.1 KiB | 32 | 22 |
| graph_fuser.h | 25-Apr-2025 | 1.3 KiB | 54 | 41 |
| graph_helper.cpp | 25-Apr-2025 | 24.7 KiB | 622 | 521 |
| graph_helper.h | 25-Apr-2025 | 2.5 KiB | 105 | 79 |
| graph_rewriter.cpp | 25-Apr-2025 | 5.2 KiB | 145 | 91 |
| guard_shape.cpp | 25-Apr-2025 | 1.4 KiB | 46 | 30 |
| guard_shape.h | 25-Apr-2025 | 259 | 16 | 11 |
| interface.cpp | 25-Apr-2025 | 6.2 KiB | 183 | 149 |
| interface.h | 25-Apr-2025 | 1.4 KiB | 63 | 50 |
| kernel.cpp | 25-Apr-2025 | 10.3 KiB | 300 | 253 |
| kernel.h | 25-Apr-2025 | 2.7 KiB | 96 | 67 |
| layout_propagation.cpp | 25-Apr-2025 | 1.4 KiB | 54 | 45 |
| layout_propagation.h | 25-Apr-2025 | 264 | 16 | 11 |
| operator.h | 25-Apr-2025 | 3.9 KiB | 153 | 121 |
| prepare_binary.cpp | 25-Apr-2025 | 7.2 KiB | 186 | 137 |
| prepare_binary.h | 25-Apr-2025 | 574 | 27 | 11 |
| register_interface.cpp | 25-Apr-2025 | 1 KiB | 55 | 48 |
README.md

# PyTorch - oneDNN Graph API Bridge
This is a PyTorch JIT graph fuser based on the [oneDNN Graph API](https://spec.oneapi.io/onednn-graph/latest/programming_model.html), which provides a flexible API for aggressive fusion. Float and BFloat16 inference are supported. However, BFloat16 only performs well on Intel Xeon Cooper Lake platforms and beyond, as they have native BFloat16 support. Also, PyTorch currently has divergent AMP support in JIT and eager modes, so to use BFloat16 you should disable JIT AMP support and leverage eager-mode AMP support instead. Please refer to the BFloat16 example below.

Currently, speedup is achieved only for static shapes, although we plan to add dynamic-shape support soon. When oneDNN Graph is enabled, weights are cached, as they are constant during inference.

## Graph Optimization
We have registered optimization passes in the custom pre-passes set of PyTorch:

1. Alias and mutation reduction

    The operators of oneDNN Graph are purely functional, while PyTorch has operators with in-place forms or operators that create views for buffer sharing.
    Due to the semantic gaps between the backend operators and the PyTorch operators, we have a pass that reduces mutation on a best-effort basis at the beginning.

2. Graph passing

    Given a PyTorch TorchScript graph, the integration maps PyTorch operators on the graph to the corresponding oneDNN Graph operators to form a backend graph.

3. Partitioning

    The backend selects regions to be fused in the graph and returns a list of partitions. Each partition corresponds to a set of fused operators.

4. Graph rewriting

    The original PyTorch JIT graph is rewritten based on the partitions returned from the backend. The operators in one partition are grouped together to form a JIT operator, referred to as a oneDNN Graph fusion group (see the inspection sketch after this list).

5. Layout propagation

    This pass eliminates unnecessary layout conversions at partition boundaries. We set different formats for the outputs of a partition so that the backend can perform layout conversions internally. When `ANY` is set, the layout at the boundary is fully decided by the backend. Otherwise, the backend should follow the layout set by PyTorch. Currently, we set the `ANY` layout for a tensor that is an output of one oneDNN Graph partition and an input to another.

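A quick way to see the result of these passes is to run a traced model a couple of times and print the last optimized graph, where fused partitions appear as single fusion-group nodes instead of individual `aten` ops. This is only a minimal sketch: the toy Conv-ReLU model and input shape are illustrative, and `torch.jit.last_executed_optimized_graph()` is a general PyTorch JIT debugging helper rather than part of this bridge.

```python
import torch
import torch.nn as nn

torch.jit.enable_onednn_fusion(True)  # enable the oneDNN Graph fuser

# illustrative toy model: Conv-ReLU, similar to the Quick Start test case
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU()).eval()
x = torch.rand(1, 3, 224, 224)

with torch.no_grad():
    traced = torch.jit.freeze(torch.jit.trace(model, x))
    # warm up so that profiling, partitioning, and graph rewriting have run
    traced(x)
    traced(x)
    # fused partitions show up as single fusion-group nodes in the optimized graph
    print(torch.jit.last_executed_optimized_graph())
```
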
## Graph Executor
During runtime execution of a (rewritten) PyTorch JIT graph, oneDNN Graph partitions are dispatched to the oneDNN Graph JIT variadic operator.
Inside the oneDNN Graph JIT op, the input PyTorch tensors of each partition are mapped to oneDNN Graph tensors. The partition is then [compiled](https://spec.oneapi.io/onednn-graph/latest/programming_model.html#partition) and [executed](https://spec.oneapi.io/onednn-graph/latest/programming_model.html#compiled-partition). The output oneDNN Graph tensors are mapped back to PyTorch tensors to be fed to the next operator on the PyTorch JIT graph.

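Since partitions are compiled at runtime, the compilation cost lands in the first iterations and is expected to disappear once the compiled partitions are cached. A rough way to observe this (a sketch with an illustrative model and input shape, not a rigorous benchmark) is to time successive calls:

```python
import time
import torch
import torch.nn as nn

torch.jit.enable_onednn_fusion(True)

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU()).eval()  # illustrative
x = torch.rand(1, 3, 224, 224)

with torch.no_grad():
    fused = torch.jit.freeze(torch.jit.trace(model, x))
    for i in range(5):
        start = time.perf_counter()
        fused(x)
        elapsed_ms = (time.perf_counter() - start) * 1e3
        # early iterations include profiling, graph rewriting, and partition
        # compilation; later iterations are expected to reuse the cached
        # compiled partitions and run faster
        print(f"iteration {i}: {elapsed_ms:.2f} ms")
```
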

## Tests

```bash
pytest test/test_jit_llga_fuser.py
```

## Quick Start

A simple cascaded Conv-Relu example is provided in the tests. Please consider enabling log outputs to familiarize yourself with the whole pipeline:

**Mutation Removal -> Prepare Binary -> Defer Size Check -> Graph Fuser -> Layout Propagation -> Type Guard -> Kernel Execution**

oneDNN Graph was formerly known as LLGA (Low Level Graph API), so LLGA in the codebase corresponds to oneDNN Graph.

```bash
DNNL_VERBOSE=1 PYTORCH_JIT_LOG_LEVEL=">>graph_helper:>>graph_fuser:>>kernel:>>interface" python -u test/test_jit_llga_fuser.py -k test_conv2d_eltwise
```

## Codebase structure

Most of the source code is placed in

```bash
torch/csrc/jit/codegen/onednn/*
```

Tensor-related code is located at

```bash
torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h
torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp
```

CMake file where the bridge code is included:

```bash
caffe2/CMakeLists.txt
```

CMake files where the oneDNN Graph submodule is included:

```bash
third_party/ideep/mkl-dnn
cmake/public/mkldnn.cmake
cmake/Modules/FindMKLDNN.cmake
cmake/Dependencies.cmake
```

To map another op to oneDNN Graph, you should add an entry for it in `createOperator` in `torch/csrc/jit/codegen/onednn/graph_helper.cpp`.
If it has an in-place variant, you should add it to the lambda passed to `RemoveTensorMutation` in
`torch/csrc/jit/codegen/onednn/interface.cpp`. You might also want to add it to `canFuseNode` in `torch/csrc/jit/codegen/onednn/register_interface.cpp`.

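Once the new op is wired up, a quick sanity check is to run a small scripted function that uses it and scan the optimized graph for fusion-group nodes. This is only a sketch under a couple of assumptions: the toy function is illustrative (replace it with one exercising your newly mapped op), and the check assumes the fusion group's node kind contains the string `oneDNN`.

```python
import torch

torch.jit.enable_onednn_fusion(True)

def toy_fn(x, w):
    # conv + relu is used here because it is a known-fusible pattern;
    # replace this with a function exercising your newly mapped op
    return torch.relu(torch.conv2d(x, w))

def collect_fusion_groups(graph_or_block, found):
    # walk the graph recursively, since fusion groups may sit inside
    # guarded sub-blocks created by the profiling executor
    for node in graph_or_block.nodes():
        if "oneDNN" in node.kind():  # assumption: fusion-group kind contains "oneDNN"
            found.append(node.kind())
        for block in node.blocks():
            collect_fusion_groups(block, found)
    return found

x = torch.rand(1, 3, 32, 32)
w = torch.rand(8, 3, 3, 3)
scripted = torch.jit.script(toy_fn)
with torch.no_grad():
    for _ in range(3):  # 3 warm-ups for scripting without an example input
        scripted(x, w)
    graph = torch.jit.last_executed_optimized_graph()

print(collect_fusion_groups(graph, []) or "no oneDNN Graph fusion group found")
```
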
## Example with Float

```python
# enable oneDNN graph fusion globally
torch.jit.enable_onednn_fusion(True)

# define the model
class MyModel(torch.nn.Module):
    ...

# construct the model
model = MyModel(…)
with torch.no_grad():
    model.eval()
    model = torch.jit.trace(model, torch.rand(args.batch_size, 3, 224, 224))

# run the model
with torch.no_grad():
    # oneDNN graph fusion will be triggered during runtime
    output = model(images)
```

## Example with BFloat16

```python
# Assuming we have a model named 'model'

example_input = torch.rand(1, 3, 224, 224)

# enable oneDNN Graph
torch.jit.enable_onednn_fusion(True)
# disable AMP for JIT
torch._C._jit_set_autocast_mode(False)
with torch.no_grad(), torch.cpu.amp.autocast():
    model = torch.jit.trace(model, (example_input))
    model = torch.jit.freeze(model)
    # 2 warm-ups (2 for tracing/scripting with an example, 3 without an example)
    model(example_input)
    model(example_input)

    # speedup would be observed in subsequent runs
    model(example_input)
```