# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import collections
import itertools
import logging
from functools import partial
from typing import Iterable, List, Optional, Tuple

import torch
from executorch.backends.cadence.aot.utils import MemoryConfig

from executorch.exir import ExecutorchProgramManager
from executorch.exir.memory_planning import collect_specs_from_nodes, Verifier
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.tensor import TensorSpec
from tabulate import tabulate
from torch.export.exported_program import ExportGraphSignature
from torch.fx.passes.infra.pass_base import PassResult


# Get the number of memories, indexed from 1..N to stay compatible with EXIR's spec.mem_id.
def get_num_memories(memory_config: MemoryConfig) -> int:
    return len(memory_config.memory_sizes) + 1


# memory_sizes is indexed 0..num_memories-1, while EXIR mem_ids start at 1, hence the -1.
def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
    return memory_config.memory_sizes[exir_id - 1]
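# For example (hypothetical config): with memory_sizes=[65536, 262144],
# get_num_memories(memory_config) returns 3 and get_size(memory_config, 1)
# returns 65536, since EXIR mem_ids start at 1 while memory_sizes is 0-indexed.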


def collect_specs_from_graph_module(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> Iterable[TensorSpec]:
    """
    Return the specs for all the nodes in the graph module in
    topological order.
    """
    # Collect the specs from all the nodes in the graph module, and return them
    return collect_specs_from_nodes(
        graph_module.graph.nodes,
        ignore_graph_input=not alloc_graph_input,
        ignore_graph_output=not alloc_graph_output,
    )


# Baseline tensor placement algorithm that greedily tries to place each tensor
# in the fastest memory available.
def position_based_greedy_with_hierarchy(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: ExportGraphSignature,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
    *,
    memory_config: MemoryConfig,
) -> List[int]:
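    """
    Place each spec at the first offset in the fastest memory (lowest mem_id)
    found by starting at offset 0 and bumping past already-placed tensors whose
    storage and lifetime overlap; fall back to slower memories when the spec
    does not fit. Specs are visited in decreasing size order. Returns the
    buffer size per memory, indexed by mem_id (index 0 is unused).

    For example (hypothetical): with a single 1024-byte memory and three specs
    of 256, 256 and 128 bytes whose lifetimes all overlap, they are placed at
    offsets 0, 256 and 512, and the returned bufsizes are [0, 640].
    """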
    num_memories = get_num_memories(memory_config)
    bufsizes = [0] * num_memories
    allocated_buffers: List[List[TensorSpec]] = [[] for _ in range(num_memories)]

    def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
        for allocated_spec in allocated_buffers[spec.mem_id]:
            if Verifier.lifetime_overlap(
                spec, allocated_spec
            ) and Verifier.storage_overlap(spec, allocated_spec):
                return allocated_spec
        return None

    def memory_available(spec: TensorSpec) -> bool:
        return spec.mem_offset + spec.allocated_memory <= get_size(
            memory_config, spec.mem_id
        )

    # Iterate over all the specs in sorted order
    for spec in sorted(
        collect_specs_from_graph_module(
            graph_module, alloc_graph_input, alloc_graph_output
        ),
        key=lambda spec: spec.allocated_memory,
        reverse=True,
    ):
        for spec.mem_id in range(1, num_memories):
            spec.mem_offset = 0
            while memory_available(spec) and (overlapped := overlap(spec)):
                spec.mem_offset = overlapped.mem_offset + overlapped.allocated_memory
            if memory_available(spec):
                allocated_buffers[spec.mem_id].append(spec)
                bufsizes[spec.mem_id] = max(
                    spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id]
                )
                break
        if (
            not allocated_buffers[spec.mem_id]
            or allocated_buffers[spec.mem_id][-1] is not spec
        ):
            raise MemoryError(f"Cannot fit {spec} in any memory hierarchy")

    logging.debug(
        f"position based greedy algorithm with hierarchy returns bufsizes: {bufsizes}"
    )
    return bufsizes


# Greedy tensor placement with the heuristics from arxiv.org/pdf/2001.03288.pdf
def greedy_by_size_for_offset_calculation_with_hierarchy(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: ExportGraphSignature,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
    *,
    memory_config: MemoryConfig,
) -> List[int]:
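    """
    Greedy-by-size placement using the smallest-gap heuristic from the paper
    referenced above: visit specs in decreasing size order and, within the
    fastest memory that can hold the spec, place it in the smallest gap between
    already-placed tensors with overlapping lifetimes, or after the last such
    tensor if no gap is large enough. Returns the buffer size per memory,
    indexed by mem_id (index 0 is unused).
    """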
    num_memories = get_num_memories(memory_config)
    bufsizes = [0] * num_memories
    allocated_buffers = [[] for _ in range(num_memories)]

    # Iterate over all the specs in sorted order
    for spec in sorted(
        collect_specs_from_graph_module(
            graph_module, alloc_graph_input, alloc_graph_output
        ),
        key=lambda spec: spec.allocated_memory,
        reverse=True,
    ):
        for spec.mem_id in range(1, num_memories):
            prev_offset, smallest_gap = 0, float("inf")
            for allocated_spec in allocated_buffers[spec.mem_id]:
                if Verifier.lifetime_overlap(spec, allocated_spec):
                    if (
                        gap := allocated_spec.mem_offset - prev_offset
                    ) >= spec.allocated_memory and gap < smallest_gap:
                        smallest_gap = gap
                        spec.mem_offset = prev_offset
                    # Note that, unlike the paper, which updates prev_offset for
                    # all allocated tensors, we only update it for tensors with an
                    # overlapping lifetime. Updating prev_offset outside this if
                    # statement would include tensors without overlapping lifetimes,
                    # wasting memory and making the gap calculation incorrect: the
                    # algorithm would degenerate to the naive one that reuses no
                    # tensor storage. The paper may have a typo here.
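                    # For example (hypothetical layout): if a tensor at
                    # [0, 4096) does not overlap the current spec's lifetime,
                    # skipping it here lets the gap check above place the spec
                    # inside [0, 4096); advancing prev_offset past it would
                    # hide that reusable region.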
                    prev_offset = max(
                        allocated_spec.mem_offset + allocated_spec.allocated_memory,
                        prev_offset,
                    )
            if spec.mem_offset is None:
                if prev_offset + spec.allocated_memory > get_size(
                    memory_config, spec.mem_id
                ):
                    continue
                else:
                    spec.mem_offset = prev_offset
            bufsizes[spec.mem_id] = max(
                spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id]
            )
            allocated_buffers[spec.mem_id].append(spec)
            allocated_buffers[spec.mem_id].sort(key=lambda spec: spec.mem_offset)
            # allocated_buffers[spec.mem_id] is kept sorted by offset; it plays
            # the role of ordered_allocated_ids in the paper.
            break
        if spec not in allocated_buffers[spec.mem_id]:
            raise MemoryError(f"Cannot fit {spec} in any memory hierarchy")

    logging.debug(
        f"greedy by size for offset calculation with hierarchy returns bufsizes: {bufsizes}"
    )
    return bufsizes


def find_peak_memory_usages_per_memory(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> List[int]:
    """
    Given a GraphModule with a memory plan, find the peak memory usages for each memory
    in the memory hierarchy.
    """
    # Create a defaultdict to keep track of memory usages: {mem_id: mem_usage}.
    # Use a defaultdict because we don't know in advance how many unique mem_ids
    # the memory plan uses.
    usages = collections.defaultdict(int)

    # go through all nodes in the graph, collect memory usage per spec.mem_id
    for spec in collect_specs_from_graph_module(
        graph_module, alloc_graph_input, alloc_graph_output
    ):
        usages[spec.mem_id] = max(
            usages[spec.mem_id], spec.mem_offset + spec.allocated_memory
        )

    # Convert the usages dictionary into a list whose length is the max mem_id,
    # with index i holding the usage of mem_id i + 1.
    # Ex: {1: 20, 3: 30} -> [20, 0, 30]
    #                        ^   ^   ^
    #                        |   |   |_ mem_id 3
    #                        |   |_ mem_id 2
    #                        |_ mem_id 1
    max_mem_id = max(usages.keys(), default=0)
    usages = [usages[i] for i in range(1, max_mem_id + 1)]

    return usages


def find_peak_memory_usage(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> Tuple[int, int]:
    """
    Given a GraphModule with a memory plan, find the peak usage over time across all
    memories in the memory hierarchy. The resulting peak memory usage should be:
    1. >= min(find_peak_memory_usages_per_memory(graph_module))
    2. <= sum(find_peak_memory_usages_per_memory(graph_module))
    """
    # memory allocations over time (measured in node index)
    byte_allocated = [0] * (len(graph_module.graph.nodes) + 1)

    # Iterate over all the node specs
    for spec in collect_specs_from_graph_module(
        graph_module, alloc_graph_input, alloc_graph_output
    ):
        if spec.lifetime[0] is None:
            continue

        # lifetime is [start, end], both ends inclusive
        start, end = spec.lifetime
        byte_allocated[start] += spec.allocated_memory
        byte_allocated[end + 1] -= spec.allocated_memory

    # accumulate the bytes allocated/deallocated to get memory usages
    memory_usages = list(itertools.accumulate(byte_allocated))
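    # For example (hypothetical spec): a 1024-byte tensor live over nodes [2, 5]
    # adds +1024 at index 2 and -1024 at index 6 above, so the running sum at
    # index i is the total number of bytes live while node i executes.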

    # find the peak memory usage and the index
    peak_memory_usage = max(memory_usages, default=0)
    peak_memory_usage_node_idx = (
        memory_usages.index(peak_memory_usage) if memory_usages else 0
    )

    return peak_memory_usage, peak_memory_usage_node_idx


# Print two tables with relevant memory planning information
#
# Per Memory Space Usage Table:
# +--------------------------------------+----------------+-----------------------+-----------------------------+
# | Memory Space                         |   Base Address |   Memory Size (Bytes) |   Peak Memory Usage (Bytes) |
# +======================================+================+=======================+=============================+
# | MEMORY SPACE A                       |     0x57be0000 |                 65213 |                       64544 |
# | MEMORY SPACE B                       |     0x57bf0000 |                 65521 |                       36864 |
# | MEMORY SPACE ...                     |            ... |                   ... |                         ... |
# +--------------------------------------+----------------+-----------------------+-----------------------------+
#
# Total Memory Space Usage Table:
# +-------------------------------------+---------------+---------+
# | Peak memory usage across all spaces | 2380032 bytes | Node 86 |
# +-------------------------------------+---------------+---------+
def print_memory_planning_info(
    # pyre-fixme[11]: Annotation `ExecutorchProgramManager` is not defined as a type.
    executorch_prog: ExecutorchProgramManager,
    memory_config: MemoryConfig,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> None:
    # Get the peak memory usages per memory space
    peak_memory_usages_per_memory = find_peak_memory_usages_per_memory(
        executorch_prog.exported_program().graph_module,
        alloc_graph_input,
        alloc_graph_output,
    )

    # Create a table of memory spaces and their base addresses, total memory sizes, and peak memory usage
    memory_names, base_addrs = memory_config.memory_names, memory_config.base_addrs
    memory_usage_table = [
        [
            f"{(i + 1) if memory_names is None else memory_names[i]}",
            None if base_addrs is None else hex(base_addrs[i]),
            memory_config.memory_sizes[i],
            peak_memory_usages_per_memory[i],
        ]
        for i in range(len(peak_memory_usages_per_memory))
    ]

    # Print the memory usage per memory space as a table
    logging.info(
        tabulate(
            memory_usage_table,
            headers=[
                "Memory Space",
                "Base Address",
                "Memory Size (Bytes)",
                "Peak Memory Usage (Bytes)",
            ],
            tablefmt="outline",
        )
    )

    # Get the total peak memory usage across all memory spaces
    total_peak_memory_usage = find_peak_memory_usage(
        executorch_prog.exported_program().graph_module,
        alloc_graph_input,
        alloc_graph_output,
    )

    # Create a table with total peak memory usage and node at which this occurs
    total_memory_usage_table = [
        [
            "Peak memory usage across all spaces",
            f"{total_peak_memory_usage[0]} bytes",
            f"Node {total_peak_memory_usage[1]}",
        ]
    ]

    # Print the total memory usage as a table
    logging.info(
        tabulate(
            total_memory_usage_table,
            tablefmt="outline",
        )
    )


class CadenceMemoryPlanning:
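    """
    Wrapper that runs EXIR's MemoryPlanningPass with one of the Cadence
    placement algorithms. `mem_algo` indexes into available_mem_algos:
    0 selects position_based_greedy_with_hierarchy and 1 selects
    greedy_by_size_for_offset_calculation_with_hierarchy.
    """
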
    def __init__(
        self,
        memory_config: MemoryConfig,
        mem_algo: int,
        alloc_graph_input: bool = True,
        alloc_graph_output: bool = True,
    ) -> None:
        self._init_mem_algos()

        self.memory_config = memory_config
        self.mem_algo = mem_algo
        self.alloc_graph_input = alloc_graph_input
        self.alloc_graph_output = alloc_graph_output

    def _init_mem_algos(self) -> None:
        self.available_mem_algos = [
            position_based_greedy_with_hierarchy,
            greedy_by_size_for_offset_calculation_with_hierarchy,
        ]

    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
        algo = partial(
            self.available_mem_algos[self.mem_algo],
            memory_config=self.memory_config,
        )
        # Create the memory planning pass. We allocate memory for input
        # (output) tensors if alloc_graph_input (alloc_graph_output) is
        # True.
        mem_planning = MemoryPlanningPass(
            algo,
            allow_lifetime_and_storage_overlap=False,
            alloc_graph_input=self.alloc_graph_input,
            alloc_graph_output=self.alloc_graph_output,
        )
        mem_planning(graph_module)

        return PassResult(graph_module, True)
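

# Example usage (illustrative sketch; assumes `memory_config` and `graph_module`
# are constructed elsewhere):
#
#     planning = CadenceMemoryPlanning(memory_config, mem_algo=1)
#     result = planning(graph_module)  # PassResult wrapping the planned graph_module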