xref: /aosp_15_r20/external/pytorch/functorch/notebooks/minifier.ipynb (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using the Minifier\n",
    "We have a fairly convenient test-case minifier with this interface:\n",
    "```\n",
    "def minifier(fail_f: fx.GraphModule, inps, module_fails):\n",
    "    \"\"\"\n",
    "    Minimizes an FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.\n",
    "\n",
    "    Uses 2 main strategies:\n",
    "    1. Truncates suffix: Removes some suffix from the graph and sets a new output.\n",
    "    2. Delta Debugging: Tries replacing half of the graph with inputs. If that fails,\n",
    "        tries replacing a quarter of the graph, etc.\n",
    "\n",
    "    >>> failing_function = fx.symbolic_trace(f)\n",
    "    >>> minifier(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))\n",
    "\n",
    "    note: module_fails returns True if it fails.\n",
    "    ...\n",
    "```\n",
    "\n",
    "Specifically, it takes your FX graph and repeatedly tries to minify it with the following 4 strategies (checking each time that the resulting graph still returns True for `module_fails`), until none of them can shrink it any further.\n",
    "\n",
    "1. Truncates Suffix: Given an FX graph, it tries to remove some suffix from the graph. For example, given this:\n",
    "\n",
    "```\n",
    "def f(a):\n",
    "    b = a * 2\n",
    "    c = b + 3\n",
    "    d = c / 4\n",
    "    return d\n",
    "```\n",
    "It might try truncating the suffix, and get\n",
    "```\n",
    "def f(a):\n",
    "    b = a * 2\n",
    "    c = b + 3\n",
    "    return c\n",
    "```\n",
    "It does this in a binary-search manner, trying to remove the last 1/2, then 3/4 and 1/4, then 7/8, 5/8, 3/8, ...\n",
    "\n",
    "2. [Delta Debugging](https://en.wikipedia.org/wiki/Delta_debugging): Of course, removing a suffix isn't always sufficient to minify a graph. What if the error is caused by the first instruction? So we take an approach inspired by delta debugging: we try removing intermediate nodes of the graph. Unlike suffix nodes, intermediate nodes still have users later in the graph, so instead of removing them entirely, we promote them to inputs. For example, given the same function as above:\n",
    "\n",
    "```\n",
    "def f(a):\n",
    "    b = a * 2\n",
    "    c = b + 3\n",
    "    d = c / 4\n",
    "    return d\n",
    "```\n",
    "We might remove a middle node (say, `c`) and promote it to an input:\n",
    "```\n",
    "def f(a, c):\n",
    "    b = a * 2\n",
    "    d = c / 4\n",
    "    return d\n",
    "```\n",
    "\n",
    "Finally, there are 2 auxiliary strategies: eliminating dead code and removing unused inputs. These are fairly self-explanatory. For instance, in the last example `b` is now dead code, and once it is eliminated, `a` becomes an unused input."
   ]
  },
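  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make these strategies more concrete, here is a toy, pure-Python sketch. It is *not* functorch's implementation. The \"graph\" is just a list of input names plus a list of `(name, op, args)` nodes, with the last node as the output. Every helper in it (`truncation_cuts`, `truncate_suffix`, `promote_to_inputs`, `delta_debug`, `remove_unused_inputs`, `has_mul`, `minify`) is hypothetical, and dead-code elimination is left out for brevity.\n",
    "\n",
    "We run it on a hand-written copy of the `failing_f` graph used below. It shrinks that graph to a single `mul` whose inputs are `div` and `add`, which has the same shape as the real minifier's result.\n",
    "\n",
    "Design note: each strategy only ever removes nodes, so the `minify` loop always terminates. Inputs that become unused are dropped after every successful step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy model only, NOT functorch's implementation: a 'graph' is a list of\n",
    "# input names plus a list of (name, op, args) nodes; the last node is the output.\n",
    "\n",
    "def truncation_cuts(n):\n",
    "    # Yield how many trailing nodes to cut, coarse to fine:\n",
    "    # 1/2 of the graph, then 3/4 and 1/4, then 7/8, 5/8, 3/8, 1/8, ...\n",
    "    seen = set()\n",
    "    denom = 2\n",
    "    while denom <= 2 * n:\n",
    "        for num in range(denom - 1, 0, -2):\n",
    "            cut = n * num // denom\n",
    "            if 0 < cut < n and cut not in seen:\n",
    "                seen.add(cut)\n",
    "                yield cut\n",
    "        denom *= 2\n",
    "\n",
    "def truncate_suffix(inputs, nodes, fails):\n",
    "    # Return the first truncated graph (coarse to fine) that still fails, else None.\n",
    "    for cut in truncation_cuts(len(nodes)):\n",
    "        candidate = nodes[:len(nodes) - cut]  # the last kept node becomes the output\n",
    "        if fails(inputs, candidate):\n",
    "            return inputs, candidate\n",
    "    return None\n",
    "\n",
    "def promote_to_inputs(inputs, nodes, start, end):\n",
    "    # Drop nodes[start:end] and turn their names into new inputs.\n",
    "    removed = [name for name, _, _ in nodes[start:end]]\n",
    "    return inputs + removed, nodes[:start] + nodes[end:]\n",
    "\n",
    "def delta_debug(inputs, nodes, fails):\n",
    "    # Try promoting large contiguous ranges first (halves, then quarters, ...).\n",
    "    n = len(nodes)\n",
    "    size = n // 2\n",
    "    while size >= 1:\n",
    "        for start in range(0, n - size + 1, size):\n",
    "            candidate = promote_to_inputs(inputs, nodes, start, start + size)\n",
    "            if fails(*candidate):\n",
    "                return candidate\n",
    "        size //= 2\n",
    "    return None\n",
    "\n",
    "def remove_unused_inputs(inputs, nodes):\n",
    "    used = {arg for _, _, args in nodes for arg in args}\n",
    "    return [i for i in inputs if i in used], nodes\n",
    "\n",
    "def has_mul(inputs, nodes):\n",
    "    # Same toy failure condition as pass_checker below.\n",
    "    return any(op == 'mul' for _, op, _ in nodes)\n",
    "\n",
    "def minify(inputs, nodes, fails):\n",
    "    # Keep applying strategies until none of them makes progress.\n",
    "    while True:\n",
    "        for strategy in (truncate_suffix, delta_debug):\n",
    "            result = strategy(inputs, nodes, fails)\n",
    "            if result is not None:\n",
    "                inputs, nodes = remove_unused_inputs(*result)\n",
    "                break\n",
    "        else:\n",
    "            return inputs, nodes\n",
    "\n",
    "# Mirrors failing_f below: div -> add -> mul -> sub\n",
    "inputs = ['x', 'y']\n",
    "nodes = [\n",
    "    ('div', 'div', ('x', 'y')),\n",
    "    ('add', 'add', ('x', 3)),\n",
    "    ('mul', 'mul', ('add', 'div')),\n",
    "    ('sub', 'sub', ('mul', 'div')),\n",
    "]\n",
    "min_inputs, min_nodes = minify(inputs, nodes, has_mul)\n",
    "print(min_inputs, min_nodes)  # ['div', 'add'] [('mul', 'mul', ('add', 'div'))]"
   ]
  },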
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So, let's take a look at a toy example. Let's pretend that our graph fails whenever it contains a multiply (`aten.mul`). First, we create a failing graph and a checker that reports this \"failure\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[W OperatorEntry.cpp:133] Warning: Overriding a previously registered kernel for the same operator and the same dispatch key\n",
      "  operator: aten::multiply.Tensor(Tensor self, Tensor other) -> (Tensor)\n",
      "    registered at aten/src/ATen/RegisterSchema.cpp:6\n",
      "  dispatch key: FuncTorchBatched\n",
      "  previous kernel: registered at aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:10338\n",
      "       new kernel: registered at /fsx/users/chilli/work/functorch/functorch/csrc/BatchRulesDecompositions.cpp:108 (function registerKernel)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Started off with 7 nodes\n",
      "###################\n",
      "Current size: 7\n",
      "###################\n",
      "Strategy: Remove suffix\n",
      "\n",
      "SUCCESS: Removed [4:7)\n",
      "\n",
      "###################\n",
      "Current size: 6\n",
      "###################\n",
      "Strategy: Delta Debugging\n",
      "SUCCESS: Removed (0:4] - Went from 2 placeholders to 4\n",
      "\n",
      "###################\n",
      "Current size: 6\n",
      "###################\n",
      "Strategy: Remove unused inputs\n",
      "SUCCESS: Went from 4 inputs to 2 inputs\n",
      "\n",
      "###################\n",
      "Current size: 4\n",
      "###################\n",
      "Strategy: Remove suffix\n",
      "FAIL: Could not remove suffix\n",
      "Strategy: Delta Debugging\n",
      "FAIL: Could not remove prefix\n",
      "\n",
      "inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]\n",
      "inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]\n",
      "\n",
      "\n",
      "\n",
      "def forward(self, div, add):\n",
      "    mul = torch.ops.aten.mul(add, div);  add = div = None\n",
      "    return (mul,)\n",
      "    \n",
      "f = torch.jit.script(forward)\n",
      "with torch.jit.fuser(\"fuser2\"):\n",
      "  for _ in range(5):\n",
      "    f(*inps)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.fx as fx\n",
    "from functorch.compile import minifier\n",
    "\n",
    "def failing_f(x, y):\n",
    "    y = torch.ops.aten.div(x, y)\n",
    "    x = torch.ops.aten.add(x, 3)\n",
    "    x = torch.ops.aten.mul(x, y)\n",
    "    return torch.ops.aten.sub(x, y)\n",
    "\n",
    "inps = [torch.randn(3), torch.randn(3)]\n",
    "\n",
    "# The minifier calls this on every candidate graph. Returning True means\n",
    "# \"this graph still fails\"; here, \"fails\" means it contains an aten.mul node.\n",
    "def pass_checker(fx_g, inps):\n",
    "    return (torch.ops.aten.mul in {i.target for i in fx_g.graph.nodes})\n",
    "\n",
    "# Returns the minified graph together with the inputs to call it with.\n",
    "min_f, inps = minifier(fx.symbolic_trace(failing_f), inps, pass_checker)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tada! Our graph is now a minimal example that still fails.\n",
    "\n",
    "Since the primary use case of this minifier (for now) is NVFuser repros, it also prints, for convenience, a self-contained script that runs the minified graph with NVFuser.\n",
    "\n",
    "Note that in practice, we provide 2 main \"graph checkers\": `check_nvfuser_subprocess` and `check_nvfuser_correctness_subprocess`. These check for errors and for correctness (i.e. whether the results match eager), respectively. They can be used like so:\n",
    "\n",
    "```\n",
    "from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess\n",
    "minifier(failing_graph, inps, check_nvfuser_subprocess)\n",
    "```"
   ]
  },
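  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Every checker passed to `minifier` follows the same convention as `pass_checker`: it takes `(fx_g, inps)` and returns True while the graph still \"fails\". Below is a rough sketch of an error checker and a correctness checker, using two hypothetical helpers, `raises_checker` and `make_mismatch_checker`.\n",
    "\n",
    "This is *not* how `check_nvfuser_subprocess` or `check_nvfuser_correctness_subprocess` are implemented. Judging by their names, those run the graph in a subprocess under NVFuser. The sketch only assumes the graph can be called as `fx_g(*inps)`.\n",
    "\n",
    "To stay dependency-free, the demo uses plain Python callables and floats in place of FX graphs and tensors. For tensors you would compare with something like `torch.allclose` instead of `math.isclose`.\n",
    "\n",
    "One caveat: catching every `Exception` also matches unrelated errors that a half-minified graph might raise. In practice it is often safer to check for the specific error message you are chasing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def raises_checker(fx_g, inps):\n",
    "    # Hypothetical error checker: the graph still fails if running it raises.\n",
    "    try:\n",
    "        fx_g(*inps)\n",
    "    except Exception:\n",
    "        return True\n",
    "    return False\n",
    "\n",
    "def make_mismatch_checker(reference_fn, rel_tol=1e-5):\n",
    "    # Hypothetical correctness checker: the graph still fails if its result\n",
    "    # differs from a reference (e.g. eager) result. Results are plain floats here.\n",
    "    def checker(fx_g, inps):\n",
    "        return not math.isclose(fx_g(*inps), reference_fn(*inps), rel_tol=rel_tol)\n",
    "    return checker\n",
    "\n",
    "# Plain callables stand in for FX graphs in this demo.\n",
    "print(raises_checker(lambda a, b: a / b, [1.0, 0.0]))  # True (ZeroDivisionError)\n",
    "print(raises_checker(lambda a, b: a / b, [1.0, 2.0]))  # False\n",
    "mismatch = make_mismatch_checker(lambda a, b: a * b)\n",
    "print(mismatch(lambda a, b: a * b + 1e-3, [2.0, 3.0]))  # True\n",
    "print(mismatch(lambda a, b: a * b, [2.0, 3.0]))  # False"
   ]
  },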
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, if you're using AOTAutograd, there's another problem: how do you obtain the FX graph to pass to the minifier in the first place? One simple way is to use `print_compile` as the compiler. It prints the code of every graph it is given."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "def forward(self, primals_1):\n",
      "    cos = torch.ops.aten.cos(primals_1)\n",
      "    cos_1 = torch.ops.aten.cos(cos)\n",
      "    return [cos_1, primals_1, cos]\n",
      "    \n",
      "\n",
      "\n",
      "\n",
      "def forward(self, primals_1, cos, tangents_1):\n",
      "    sin = torch.ops.aten.sin(cos);  cos = None\n",
      "    neg = torch.ops.aten.neg(sin);  sin = None\n",
      "    mul = torch.ops.aten.mul(tangents_1, neg);  tangents_1 = neg = None\n",
      "    sin_1 = torch.ops.aten.sin(primals_1);  primals_1 = None\n",
      "    neg_1 = torch.ops.aten.neg(sin_1);  sin_1 = None\n",
      "    mul_1 = torch.ops.aten.mul(mul, neg_1);  mul = neg_1 = None\n",
      "    return [mul_1]\n",
      "    \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([0.6062, 0.9982, 0.6474], grad_fn=<CompiledFunctionBackward>)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from functorch.compile import aot_function\n",
    "\n",
    "from functorch.compile import print_compile\n",
    "# Or define it yourself: a \"compiler\" that prints the graph's code\n",
    "# and returns the graph unchanged.\n",
    "def print_compile(fx_g, _):\n",
    "    print(fx_g.code)\n",
    "    return fx_g\n",
    "\n",
    "def foo(x):\n",
    "    return x.cos().cos()\n",
    "inp = torch.randn(3, requires_grad=True)\n",
    "# Both the forward and the backward graph are printed below.\n",
    "aot_function(foo, print_compile)(inp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, this doesn't provide the inputs, nor does it handle any tensor constants that might be saved in the graph. To resolve this, we have another \"compiler\" called `debug_compile`. It simply prints out a string that can be copy-pasted and run from another file. It leverages FX's `to_folder` feature to serialize the graph to disk, along with any constants.\n",
    "\n",
    "You can pass it as either the `fw_compiler` to dump the forward graph or the `bw_compiler` to dump the backward graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "##############################################################\n",
      "# To minimize FX graph, copy and paste the below and run it  #\n",
      "##############################################################\n",
      "\n",
      "import torch\n",
      "import torch.fx as fx\n",
      "from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess\n",
      "\n",
      "inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]\n",
      "inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]\n",
      "from foo import FxModule\n",
      "mod = FxModule().cuda()\n",
      "\n",
      "with torch.jit.fuser(\"fuser2\"):\n",
      "  # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess\n",
      "  minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([0.6062, 0.9982, 0.6474], grad_fn=<CompiledFunctionBackward>)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from functorch.compile import memory_efficient_fusion, debug_compile\n",
    "\n",
    "# Dump the backward graph; passing fw_compiler=debug_compile instead would dump the forward graph.\n",
    "memory_efficient_fusion(foo, bw_compiler=debug_compile)(inp)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So, let's copy-paste it and see how it works. Note that I made a couple of minor modifications. The repro runs on CPU, so the `device='cuda'`, `.cuda()` and `torch.jit.fuser` parts are dropped. It also uses the previous \"graph fails if there's a multiply in it\" checker (`pass_checker`) instead of `check_nvfuser_subprocess`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Started off with 10 nodes\n",
      "###################\n",
      "Current size: 10\n",
      "###################\n",
      "Strategy: Remove suffix\n",
      "\n",
      "SUCCESS: Removed [6:10)\n",
      "\n",
      "###################\n",
      "Current size: 8\n",
      "###################\n",
      "Strategy: Delta Debugging\n",
      "SUCCESS: Removed (0:4] - Went from 2 placeholders to 4\n",
      "\n",
      "###################\n",
      "Current size: 8\n",
      "###################\n",
      "Strategy: Remove unused inputs\n",
      "SUCCESS: Went from 4 inputs to 3 inputs\n",
      "\n",
      "###################\n",
      "Current size: 7\n",
      "###################\n",
      "Strategy: Remove suffix\n",
      "\n",
      "SUCCESS: Removed [4:7)\n",
      "\n",
      "###################\n",
      "Current size: 6\n",
      "###################\n",
      "Strategy: Remove unused inputs\n",
      "SUCCESS: Went from 3 inputs to 2 inputs\n",
      "\n",
      "###################\n",
      "Current size: 5\n",
      "###################\n",
      "Strategy: Delta Debugging\n",
      "SUCCESS: Removed (2:3] - Went from 2 placeholders to 3\n",
      "\n",
      "###################\n",
      "Current size: 5\n",
      "###################\n",
      "Strategy: Remove unused inputs\n",
      "SUCCESS: Went from 3 inputs to 2 inputs\n",
      "\n",
      "###################\n",
      "Current size: 4\n",
      "###################\n",
      "Strategy: Remove suffix\n",
      "FAIL: Could not remove suffix\n",
      "Strategy: Delta Debugging\n",
      "FAIL: Could not remove prefix\n",
      "\n",
      "inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]\n",
      "inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]\n",
      "\n",
      "\n",
      "\n",
      "def forward(self, tangents_1, neg):\n",
      "    mul = torch.ops.aten.mul(tangents_1, neg);  tangents_1 = neg = None\n",
      "    return (mul,)\n",
      "    \n",
      "f = torch.jit.script(forward)\n",
      "with torch.jit.fuser(\"fuser2\"):\n",
      "  for _ in range(5):\n",
      "    f(*inps)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(GraphModule(), [tensor([1., 1., 1.]), tensor([-0.5144, -0.5144, -0.5144])])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.fx as fx\n",
    "from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess\n",
    "\n",
    "inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]\n",
    "inps = [torch.ones(shape, dtype=dtype) for (shape, dtype) in inps]\n",
    "# debug_compile serialized the graph (and any constants) into the `foo` package imported here.\n",
    "from foo import FxModule\n",
    "mod = FxModule()\n",
    "\n",
    "minifier(fx.symbolic_trace(mod), inps, pass_checker)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hopefully that was useful :)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "a1cf69278e4496ab232105d2fffcc75678d2dcbec1c795483197519eb80161c7"
  },
  "kernelspec": {
   "display_name": "Python 3.8.12 ('py38')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}