# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import collections
import itertools
import logging
from functools import partial
from typing import Iterable, List, Optional, Tuple

import torch
from executorch.backends.cadence.aot.utils import MemoryConfig

from executorch.exir import ExecutorchProgramManager
from executorch.exir.memory_planning import collect_specs_from_nodes, Verifier
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.tensor import TensorSpec
from tabulate import tabulate
from torch.export.exported_program import ExportGraphSignature
from torch.fx.passes.infra.pass_base import PassResult


# Get the number of memories, indexed from 1..N to be compatible with EXIR's spec.mem_id
def get_num_memories(memory_config: MemoryConfig) -> int:
    return len(memory_config.memory_sizes) + 1


# The memory_space module provides num_memories indexed 0..num_memories-1.
def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
    return memory_config.memory_sizes[exir_id - 1]


def collect_specs_from_graph_module(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> Iterable[TensorSpec]:
    """
    Return the specs for all the nodes in the graph module in
    topological order.
    """
    # Collect the specs from all the nodes in the graph module, and return them
    return collect_specs_from_nodes(
        graph_module.graph.nodes,
        ignore_graph_input=not alloc_graph_input,
        ignore_graph_output=not alloc_graph_output,
    )


# Baseline tensor placement algorithm that greedily tries to place each tensor
# in the fastest memory available
def position_based_greedy_with_hierarchy(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: ExportGraphSignature,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
    *,
    memory_config: MemoryConfig,
) -> List[int]:
    num_memories = get_num_memories(memory_config)
    bufsizes = [0] * num_memories
    allocated_buffers: List[List[TensorSpec]] = [[] for _ in range(num_memories)]

    def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
        for allocated_spec in allocated_buffers[spec.mem_id]:
            if Verifier.lifetime_overlap(
                spec, allocated_spec
            ) and Verifier.storage_overlap(spec, allocated_spec):
                return allocated_spec
        return None

    def memory_available(spec: TensorSpec) -> bool:
        return spec.mem_offset + spec.allocated_memory <= get_size(
            memory_config, spec.mem_id
        )

    # Iterate over all the specs in sorted order (largest allocation first)
    for spec in sorted(
        collect_specs_from_graph_module(
            graph_module, alloc_graph_input, alloc_graph_output
        ),
        key=lambda spec: spec.allocated_memory,
        reverse=True,
    ):
        for spec.mem_id in range(1, num_memories):
            spec.mem_offset = 0
            while memory_available(spec) and (overlapped := overlap(spec)):
                spec.mem_offset = overlapped.mem_offset + overlapped.allocated_memory
            if memory_available(spec):
                allocated_buffers[spec.mem_id].append(spec)
                bufsizes[spec.mem_id] = max(
                    spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id]
                )
                break
        if (
            not allocated_buffers[spec.mem_id]
            or allocated_buffers[spec.mem_id][-1] is not spec
        ):
            raise MemoryError(f"Cannot fit {spec} in any memory hierarchy")

    logging.debug(
        f"position based greedy algorithm with hierarchy returns bufsizes: {bufsizes}"
    )
    return bufsizes


# Greedy tensor placement with the heuristics from arxiv.org/pdf/2001.03288.pdf
def greedy_by_size_for_offset_calculation_with_hierarchy(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: ExportGraphSignature,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
    *,
    memory_config: MemoryConfig,
) -> List[int]:
    num_memories = get_num_memories(memory_config)
    bufsizes = [0] * num_memories
    allocated_buffers: List[List[TensorSpec]] = [[] for _ in range(num_memories)]

    # Iterate over all the specs in sorted order (largest allocation first)
    for spec in sorted(
        collect_specs_from_graph_module(
            graph_module, alloc_graph_input, alloc_graph_output
        ),
        key=lambda spec: spec.allocated_memory,
        reverse=True,
    ):
        for spec.mem_id in range(1, num_memories):
            prev_offset, smallest_gap = 0, float("inf")
            for allocated_spec in allocated_buffers[spec.mem_id]:
                if Verifier.lifetime_overlap(spec, allocated_spec):
                    if (
                        gap := allocated_spec.mem_offset - prev_offset
                    ) >= spec.allocated_memory and gap < smallest_gap:
                        smallest_gap = gap
                        spec.mem_offset = prev_offset
                    # Note that, unlike the paper, which updates prev_offset for all
                    # allocated tensors, we only update prev_offset for tensors with an
                    # overlapping lifetime. Updating prev_offset outside this if statement
                    # would include tensors without overlapping lifetimes, wasting memory
                    # and making the gap calculation incorrect: the algorithm would
                    # degenerate to the naive one that reuses no tensor. The paper may
                    # have a typo here.
                    prev_offset = max(
                        allocated_spec.mem_offset + allocated_spec.allocated_memory,
                        prev_offset,
                    )
            if spec.mem_offset is None:
                if prev_offset + spec.allocated_memory > get_size(
                    memory_config, spec.mem_id
                ):
                    continue
                else:
                    spec.mem_offset = prev_offset
            bufsizes[spec.mem_id] = max(
                spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id]
            )
            allocated_buffers[spec.mem_id].append(spec)
            # Keep the allocated buffers ordered by offset; this is the data structure
            # named ordered_allocated_ids in the paper.
            allocated_buffers[spec.mem_id].sort(key=lambda spec: spec.mem_offset)
            break
        if spec not in allocated_buffers[spec.mem_id]:
            raise MemoryError(f"Cannot fit {spec} in any memory hierarchy")

    logging.debug(
        f"greedy by size for offset calculation with hierarchy returns bufsizes: {bufsizes}"
    )
    return bufsizes


def find_peak_memory_usages_per_memory(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> List[int]:
    """
    Given a GraphModule with a memory plan, find the peak memory usage for each memory
    in the memory hierarchy.
    """
    # Keep track of memory usages as {mem_id: mem_usage}. Use a defaultdict because
    # we don't know in advance how many unique mem_ids the memory plan uses.
    usages = collections.defaultdict(int)

    # Go through all the nodes in the graph and collect the memory usage per spec.mem_id
    for spec in collect_specs_from_graph_module(
        graph_module, alloc_graph_input, alloc_graph_output
    ):
        usages[spec.mem_id] = max(
            usages[spec.mem_id], spec.mem_offset + spec.allocated_memory
        )

    # Convert the usages dictionary into a list of length max_mem_id, with index 0
    # holding the usage of mem_id 1.
    # Ex: {1: 20, 3: 30} -> [20, 0, 30]
    #                        ^   ^  ^
    #                        |   |  |_ mem_id 3
    #                        |   |_ mem_id 2
    #                        |_ mem_id 1
    max_mem_id = max(usages.keys(), default=0)
    usages = [usages[i] for i in range(1, max_mem_id + 1)]

    return usages


def find_peak_memory_usage(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> Tuple[int, int]:
    """
    Given a GraphModule with a memory plan, find the peak usage over time across all
    memories in the memory hierarchy. The resulting peak memory usage should be:
    1. >= min(find_peak_memory_usages_per_memory(graph_module))
    2. <= sum(find_peak_memory_usages_per_memory(graph_module))
    """
    # memory allocations over time (measured in node index)
    byte_allocated = [0] * (len(graph_module.graph.nodes) + 1)

    # Iterate over all the node specs
    for spec in collect_specs_from_graph_module(
        graph_module, alloc_graph_input, alloc_graph_output
    ):
        if spec.lifetime[0] is None:
            continue

        # lifetime is [start, end], both ends inclusive
        start, end = spec.lifetime
        byte_allocated[start] += spec.allocated_memory
        byte_allocated[end + 1] -= spec.allocated_memory

    # accumulate the bytes allocated/deallocated to get memory usages over time
    memory_usages = list(itertools.accumulate(byte_allocated))

    # find the peak memory usage and the node index at which it occurs
    peak_memory_usage = max(memory_usages, default=0)
    peak_memory_usage_node_idx = (
        memory_usages.index(peak_memory_usage) if memory_usages else 0
    )

    return peak_memory_usage, peak_memory_usage_node_idx


# Print two tables with relevant memory planning information:
#
# Per Memory Space Usage Table:
# +--------------------------------------+----------------+-----------------------+-----------------------------+
# | Memory Space                         | Base Address   |   Memory Size (Bytes) |   Peak Memory Usage (Bytes) |
# +======================================+================+=======================+=============================+
# | MEMORY SPACE A                       | 0x57be0000     |                 65213 |                       64544 |
# | MEMORY SPACE B                       | 0x57bf0000     |                 65521 |                       36864 |
# | MEMORY SPACE ...                     | ...            |                   ... |                         ... |
# +--------------------------------------+----------------+-----------------------+-----------------------------+
#
# Total Memory Space Usage Table:
# +-------------------------------------+---------------+---------+
# | Peak memory usage across all spaces | 2380032 bytes | Node 86 |
# +-------------------------------------+---------------+---------+
def print_memory_planning_info(
    # pyre-fixme[11]: Annotation `ExecutorchProgramManager` is not defined as a type.
    executorch_prog: ExecutorchProgramManager,
    memory_config: MemoryConfig,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> None:
    # Get the peak memory usages per memory space
    peak_memory_usages_per_memory = find_peak_memory_usages_per_memory(
        executorch_prog.exported_program().graph_module,
        alloc_graph_input,
        alloc_graph_output,
    )

    # Create a table of memory spaces with their base addresses, total memory sizes,
    # and peak memory usages
    memory_names, base_addrs = memory_config.memory_names, memory_config.base_addrs
    memory_usage_table = [
        [
            f"{(i + 1) if memory_names is None else memory_names[i]}",
            None if base_addrs is None else hex(base_addrs[i]),
            memory_config.memory_sizes[i],
            peak_memory_usages_per_memory[i],
        ]
        for i in range(len(peak_memory_usages_per_memory))
    ]

    # Print the memory usage per memory space as a table
    logging.info(
        tabulate(
            memory_usage_table,
            headers=[
                "Memory Space",
                "Base Address",
                "Memory Size (Bytes)",
                "Peak Memory Usage (Bytes)",
            ],
            tablefmt="outline",
        )
    )

    # Get the total peak memory usage across all memory spaces
    total_peak_memory_usage = find_peak_memory_usage(
        executorch_prog.exported_program().graph_module,
        alloc_graph_input,
        alloc_graph_output,
    )

    # Create a table with the total peak memory usage and the node at which it occurs
    total_memory_usage_table = [
        [
            "Peak memory usage across all spaces",
            f"{total_peak_memory_usage[0]} bytes",
            f"Node {total_peak_memory_usage[1]}",
        ]
    ]

    # Print the total memory usage as a table
    logging.info(
        tabulate(
            total_memory_usage_table,
            tablefmt="outline",
        )
    )


class CadenceMemoryPlanning:
    def __init__(
        self,
        memory_config: MemoryConfig,
        mem_algo: int,
        alloc_graph_input: bool = True,
        alloc_graph_output: bool = True,
    ) -> None:
        self._init_mem_algos()

        self.memory_config = memory_config
        self.mem_algo = mem_algo
        self.alloc_graph_input = alloc_graph_input
        self.alloc_graph_output = alloc_graph_output

    def _init_mem_algos(self) -> None:
        self.available_mem_algos = [
            position_based_greedy_with_hierarchy,
            greedy_by_size_for_offset_calculation_with_hierarchy,
        ]

    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
        algo = partial(
            self.available_mem_algos[self.mem_algo],
            memory_config=self.memory_config,
        )
        # Create the memory planning pass. We allocate memory for input
        # (output) tensors if alloc_graph_input (alloc_graph_output) is
        # True.
        mem_planning = MemoryPlanningPass(
            algo,
            allow_lifetime_and_storage_overlap=False,
            alloc_graph_input=self.alloc_graph_input,
            alloc_graph_output=self.alloc_graph_output,
        )
        mem_planning(graph_module)

        return PassResult(graph_module, True)
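
# Example usage: a minimal, hypothetical sketch of wiring these pieces together by
# hand, kept as a comment so the module's import behavior is unchanged. The two-bank
# MemoryConfig literal and the `program_manager` variable below are assumptions, not
# part of this module; mem_algo=1 selects
# greedy_by_size_for_offset_calculation_with_hierarchy per _init_mem_algos above.
#
#   memory_config = MemoryConfig(memory_sizes=[64 * 1024, 512 * 1024])  # fastest bank first
#   planner = CadenceMemoryPlanning(memory_config, mem_algo=1)
#   planner(program_manager.exported_program().graph_module)
#   print_memory_planning_info(
#       program_manager,
#       memory_config,
#       alloc_graph_input=True,
#       alloc_graph_output=True,
#   )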