xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/hlo/lib/Transforms/buffer_reuse.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 
18 #include "mlir-hlo/Analysis/userange_analysis.h"
19 #include "mlir-hlo/Transforms/PassDetail.h"
20 #include "mlir-hlo/Transforms/passes.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/IR/Operation.h"
23 #include "mlir/Interfaces/LoopLikeInterface.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Transforms/BufferUtils.h"
26 
27 namespace mlir {
28 
29 namespace {
30 
31 /// Reuses already allocated buffer to save allocation operations.
32 class BufferReuse : BufferPlacementTransformationBase {
33   using ValueSetMap = llvm::MapVector<Value, DenseSet<Value>>;
34   using ValueVectorMap = llvm::MapVector<Value, SmallVector<Value, 4>>;
35 
36  public:
BufferReuse(Operation * op)37   explicit BufferReuse(Operation *op)
38       : BufferPlacementTransformationBase(op),
39         dominators(op),
40         postDominators(op),
41         userange(op, allocs, aliases) {}
42 
43   /// Reuses already allocated buffers to save allocation operations.
reuse(Operation * operation)44   void reuse(Operation *operation) {
45     // Create a list of values that can potentially be replaced for each value
46     // in the useRangeMap. The potentialReuseMap maps each value to the
47     // respective list.
48     ValueVectorMap potentialReuseMap;
49     for (BufferPlacementAllocs::AllocEntry entry : allocs) {
50       Value itemA = std::get<0>(entry);
51       SmallVector<Value, 4> potReuseVector;
52       for (BufferPlacementAllocs::AllocEntry entry : allocs) {
53         Value itemB = std::get<0>(entry);
54         // Do not compare an item to itself and make sure that the value of item
55         // B is not a BlockArgument. BlockArguments cannot be reused. Also
56         // perform a reuse compatibility check.
57         if (itemA == itemB || !checkReuseCompatibility(itemA, itemB)) continue;
58 
59         // Check if itemA interferes with itemB. If this is the case no reuse is
60         // possible.
61         if (userange.rangesInterfere(itemA, itemB)) continue;
62 
63         // The defining block of itemA has to dominate all uses of itemB.
64         if (!dominatesAllUses(itemA.getParentBlock(), itemB)) continue;
65 
66         // Insert itemB into the right place of the potReuseVector. The order of
67         // the vector is defined via the program order of the first use of each
68         // item.
69         auto insertionPoint = potReuseVector.begin();
70         while (insertionPoint != potReuseVector.end()) {
71           if (userange.getFirstUseIndex(itemB) <
72               userange.getFirstUseIndex(*insertionPoint))
73             break;
74           ++insertionPoint;
75         }
76         potReuseVector.insert(insertionPoint, itemB);
77       }
78 
79       potentialReuseMap.insert({itemA, potReuseVector});
80     }
81 
82     // Replace all uses of the value that is replaced and
83     // delete the DefiningOp.
84     for (auto &reuse : computeActualReuse(potentialReuseMap)) {
85       for (Value reuseValue : reuse.second) {
86         reuseValue.replaceAllUsesWith(reuse.first);
87         reuseValue.getDefiningOp()->erase();
88       }
89     }
90   }
91 
92  private:
93   /// Check if all uses of item are dominated by the given block.
dominatesAllUses(Block * block,Value item)94   bool dominatesAllUses(Block *block, Value item) {
95     for (OpOperand &operand : item.getUses()) {
96       if (!dominators.dominates(block, operand.getOwner()->getBlock()))
97         return false;
98     }
99     return true;
100   }
101 
102   /// Checks if there is a transitive interference between potReuseValue and the
103   /// value that may replace it, we call this value V. potReuses is the vector
104   /// of all values that can potentially be replaced by V. If potReuseValue
105   /// already replaces any other value that is not part of the potReuses vector
106   /// it cannot be replaced by V anymore.
transitiveInterference(Value potReuseValue,SmallVector<Value,4> & potReuses,ValueSetMap & actualReuseMap)107   bool transitiveInterference(Value potReuseValue,
108                               SmallVector<Value, 4> &potReuses,
109                               ValueSetMap &actualReuseMap) {
110     auto actualReuser = actualReuseMap.find(potReuseValue);
111     return actualReuser != actualReuseMap.end() &&
112            llvm::any_of(actualReuser->second, [&](Value vReuse) {
113              return std::find(potReuses.begin(), potReuses.end(), vReuse) ==
114                     potReuses.end();
115            });
116   }
117 
118   /// Checks if the types of the given values are compatible for a
119   /// replacement.
checkReuseCompatibility(Value a,Value b)120   bool checkReuseCompatibility(Value a, Value b) {
121     auto shapedA = a.getType().cast<ShapedType>();
122     auto shapedB = b.getType().cast<ShapedType>();
123 
124     // If both types are shaped we can check for equality.
125     if (shapedA.hasStaticShape() && shapedB.hasStaticShape())
126       return a.getType() == b.getType();
127     // If only one of the types is shaped we cannot detect compatibility since
128     // we do not know how the allocation operation behaves on its operands.
129     if (shapedA.hasStaticShape() != shapedB.hasStaticShape()) return false;
130 
131     // Compare the element Types of both shapes.
132     if (shapedA.getElementType() != shapedB.getElementType()) return false;
133 
134     // If the shapes have different ranks, we cannot reuse them.
135     if (shapedA.getRank() != shapedB.getRank()) return false;
136 
137     // Compare each dimension. If the dimensions are not equal no reuse is
138     // possible.
139     for (unsigned idx = 0, e = shapedA.getRank(); idx < e; ++idx) {
140       if (shapedA.getDimSize(idx) != shapedB.getDimSize(idx)) return false;
141     }
142 
143     // We need the actual alloc operation of both types. For aliases we need
144     // to check for the defining OP of the alias' origin.
145     Operation *defOpA = a.getDefiningOp();
146     Operation *defOpB = b.getDefiningOp();
147 
148     // If the alloc method or the number of operands is not the same the types
149     // might not be compatible.
150     if (defOpA->getName() != defOpB->getName() ||
151         defOpA->getNumOperands() != defOpB->getNumOperands())
152       return false;
153 
154     // If all operands are equal the types are compatible.
155     auto operandsA = defOpA->getOperands();
156     auto operandsB = defOpB->getOperands();
157     return std::equal(operandsA.begin(), operandsA.end(), operandsB.begin(),
158                       operandsB.end());
159   }
160 
161   /// A Fixpoint iteration over the potential reuses to compute the actual
162   /// reuses.
computeActualReuse(ValueVectorMap & potentialReuseMap)163   ValueSetMap computeActualReuse(ValueVectorMap &potentialReuseMap) {
164     // The replacedSet contains all values that are going to be replaced.
165     DenseSet<Value> replacedSet;
166 
167     // The currentReuserSet contains all values that are replacing another
168     // value in the current iteration. Note: This is necessary because the
169     // replacing property is not transitive.
170     DenseSet<Value> currentReuserSet;
171 
172     /// Maps a value to the set of values that it replaces.
173     ValueSetMap actualReuseMap;
174 
175     for (;;) {
176       // Clear the currentReuserSet for this iteration.
177       currentReuserSet.clear();
178       // Step 1 of the fixpoint iteration: Choose a value to be replaced for
179       // each value in the potentialReuseMap.
180       choosePotentialReuses(replacedSet, currentReuserSet, potentialReuseMap,
181                             actualReuseMap);
182 
183       // If the currentReuseSet is empty we can terminate the fixpoint
184       // iteration.
185       if (currentReuserSet.empty()) break;
186 
187       // Step 2 of the fixpoint iteration: Update the potentialReuseVectors for
188       // each value in the potentialReuseMap. Due to the chosen replacements in
189       // step 1 some values might not be replaceable anymore. Also remove all
190       // replaced values from the potentialReuseMap.
191       updatePotentialReuses(replacedSet, potentialReuseMap, actualReuseMap);
192     }
193     return actualReuseMap;
194   }
195 
196   /// For each value in the potentialReuseMap, check if another value tries to
197   /// reuse it or if it is already replaced by another value. If neither is the
198   /// case add the value and its reuses (if any) to the actualReuseMap.
choosePotentialReuses(DenseSet<Value> & replacedSet,DenseSet<Value> & currentReuserSet,ValueVectorMap & potentialReuseMap,ValueSetMap & actualReuseMap)199   void choosePotentialReuses(DenseSet<Value> &replacedSet,
200                              DenseSet<Value> &currentReuserSet,
201                              ValueVectorMap &potentialReuseMap,
202                              ValueSetMap &actualReuseMap) {
203     for (auto &potReuser : potentialReuseMap) {
204       Value item = potReuser.first;
205       SmallVector<Value, 4> &potReuses = potReuser.second;
206 
207       // If the current value is replaced already we have to skip it.
208       if (replacedSet.contains(item)) continue;
209 
210       // Find a value that can be reused. If the value is already in the
211       // currentReuserSet then we have to break. Due to the order of the
212       // values we must not skip it, because it can potentially be replaced in
213       // the next iteration. However, we may skip the value if it is replaced
214       // by another value.
215       for (Value v : potReuses) {
216         if (currentReuserSet.contains(v)) break;
217         if (replacedSet.contains(v)) continue;
218 
219         // Update the actualReuseMap.
220         actualReuseMap[item].insert(v);
221 
222         // Check if the replaced value already replaces other values and also
223         // add them to the reused set.
224         auto alreadyReplaced = actualReuseMap.find(v);
225         if (alreadyReplaced != actualReuseMap.end()) {
226           actualReuseMap[item].insert(alreadyReplaced->second.begin(),
227                                       alreadyReplaced->second.end());
228           actualReuseMap.erase(v);
229         }
230 
231         // Merge the userange of v into the userange of item.
232         userange.unionRanges(item, v);
233 
234         currentReuserSet.insert(item);
235         replacedSet.insert(v);
236         break;
237       }
238     }
239   }
240 
241   /// Update the potentialReuseVectors for each value in the potentialReuseMap.
updatePotentialReuses(DenseSet<Value> & replacedSet,ValueVectorMap & potentialReuseMap,ValueSetMap & actualReuseMap)242   void updatePotentialReuses(DenseSet<Value> &replacedSet,
243                              ValueVectorMap &potentialReuseMap,
244                              ValueSetMap &actualReuseMap) {
245     for (auto itReuseMap = potentialReuseMap.begin();
246          itReuseMap != potentialReuseMap.end();) {
247       Value item = itReuseMap->first;
248       SmallVector<Value, 4> &potReuses = itReuseMap->second;
249 
250       // If the item is already reused, we can remove it from the
251       // potentialReuseMap.
252       if (replacedSet.contains(item)) {
253         potentialReuseMap.erase(itReuseMap);
254         continue;
255       }
256 
257       // Remove all potential reuses that cannot be reused for this value.
258       potReuses.erase(
259           std::remove_if(potReuses.begin(), potReuses.end(),
260                          [&](Value potReuseValue) {
261                            return replacedSet.contains(potReuseValue) ||
262                                   transitiveInterference(potReuseValue,
263                                                          potReuses,
264                                                          actualReuseMap) ||
265                                   userange.rangesInterfere(item, potReuseValue);
266                          }),
267           potReuses.end());
268       ++itReuseMap;
269     }
270   }
271 
272   /// The current dominance info.
273   DominanceInfo dominators;
274 
275   /// The current postdominance info.
276   PostDominanceInfo postDominators;
277 
278   /// The current userange info.
279   UserangeAnalysis userange;
280 };
281 
282 /// The buffer reuse pass that uses already allocated buffers if all critera
283 /// are met.
284 struct BufferReusePass : BufferReuseBase<BufferReusePass> {
runOnFunctionmlir::__anon1754083f0111::BufferReusePass285   void runOnFunction() override {
286     // Reuse allocated buffer instead of new allocation.
287     Operation *funcOp = getFunction();
288     BufferReuse optimizer(funcOp);
289     optimizer.reuse(funcOp);
290   }
291 };
292 
293 }  // end namespace
294 
createBufferReusePass()295 std::unique_ptr<FunctionPass> createBufferReusePass() {
296   return std::make_unique<BufferReusePass>();
297 }
298 
299 }  // end namespace mlir
300