1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17
18 #include "mlir-hlo/Analysis/userange_analysis.h"
19 #include "mlir-hlo/Transforms/PassDetail.h"
20 #include "mlir-hlo/Transforms/passes.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/IR/Operation.h"
23 #include "mlir/Interfaces/LoopLikeInterface.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Transforms/BufferUtils.h"
26
27 namespace mlir {
28
29 namespace {
30
31 /// Reuses already allocated buffer to save allocation operations.
32 class BufferReuse : BufferPlacementTransformationBase {
33 using ValueSetMap = llvm::MapVector<Value, DenseSet<Value>>;
34 using ValueVectorMap = llvm::MapVector<Value, SmallVector<Value, 4>>;
35
36 public:
BufferReuse(Operation * op)37 explicit BufferReuse(Operation *op)
38 : BufferPlacementTransformationBase(op),
39 dominators(op),
40 postDominators(op),
41 userange(op, allocs, aliases) {}
42
43 /// Reuses already allocated buffers to save allocation operations.
reuse(Operation * operation)44 void reuse(Operation *operation) {
45 // Create a list of values that can potentially be replaced for each value
46 // in the useRangeMap. The potentialReuseMap maps each value to the
47 // respective list.
48 ValueVectorMap potentialReuseMap;
49 for (BufferPlacementAllocs::AllocEntry entry : allocs) {
50 Value itemA = std::get<0>(entry);
51 SmallVector<Value, 4> potReuseVector;
52 for (BufferPlacementAllocs::AllocEntry entry : allocs) {
53 Value itemB = std::get<0>(entry);
54 // Do not compare an item to itself and make sure that the value of item
55 // B is not a BlockArgument. BlockArguments cannot be reused. Also
56 // perform a reuse compatibility check.
57 if (itemA == itemB || !checkReuseCompatibility(itemA, itemB)) continue;
58
59 // Check if itemA interferes with itemB. If this is the case no reuse is
60 // possible.
61 if (userange.rangesInterfere(itemA, itemB)) continue;
62
63 // The defining block of itemA has to dominate all uses of itemB.
64 if (!dominatesAllUses(itemA.getParentBlock(), itemB)) continue;
65
66 // Insert itemB into the right place of the potReuseVector. The order of
67 // the vector is defined via the program order of the first use of each
68 // item.
69 auto insertionPoint = potReuseVector.begin();
70 while (insertionPoint != potReuseVector.end()) {
71 if (userange.getFirstUseIndex(itemB) <
72 userange.getFirstUseIndex(*insertionPoint))
73 break;
74 ++insertionPoint;
75 }
76 potReuseVector.insert(insertionPoint, itemB);
77 }
78
79 potentialReuseMap.insert({itemA, potReuseVector});
80 }
81
82 // Replace all uses of the value that is replaced and
83 // delete the DefiningOp.
84 for (auto &reuse : computeActualReuse(potentialReuseMap)) {
85 for (Value reuseValue : reuse.second) {
86 reuseValue.replaceAllUsesWith(reuse.first);
87 reuseValue.getDefiningOp()->erase();
88 }
89 }
90 }
91
92 private:
93 /// Check if all uses of item are dominated by the given block.
dominatesAllUses(Block * block,Value item)94 bool dominatesAllUses(Block *block, Value item) {
95 for (OpOperand &operand : item.getUses()) {
96 if (!dominators.dominates(block, operand.getOwner()->getBlock()))
97 return false;
98 }
99 return true;
100 }
101
102 /// Checks if there is a transitive interference between potReuseValue and the
103 /// value that may replace it, we call this value V. potReuses is the vector
104 /// of all values that can potentially be replaced by V. If potReuseValue
105 /// already replaces any other value that is not part of the potReuses vector
106 /// it cannot be replaced by V anymore.
transitiveInterference(Value potReuseValue,SmallVector<Value,4> & potReuses,ValueSetMap & actualReuseMap)107 bool transitiveInterference(Value potReuseValue,
108 SmallVector<Value, 4> &potReuses,
109 ValueSetMap &actualReuseMap) {
110 auto actualReuser = actualReuseMap.find(potReuseValue);
111 return actualReuser != actualReuseMap.end() &&
112 llvm::any_of(actualReuser->second, [&](Value vReuse) {
113 return std::find(potReuses.begin(), potReuses.end(), vReuse) ==
114 potReuses.end();
115 });
116 }
117
118 /// Checks if the types of the given values are compatible for a
119 /// replacement.
checkReuseCompatibility(Value a,Value b)120 bool checkReuseCompatibility(Value a, Value b) {
121 auto shapedA = a.getType().cast<ShapedType>();
122 auto shapedB = b.getType().cast<ShapedType>();
123
124 // If both types are shaped we can check for equality.
125 if (shapedA.hasStaticShape() && shapedB.hasStaticShape())
126 return a.getType() == b.getType();
127 // If only one of the types is shaped we cannot detect compatibility since
128 // we do not know how the allocation operation behaves on its operands.
129 if (shapedA.hasStaticShape() != shapedB.hasStaticShape()) return false;
130
131 // Compare the element Types of both shapes.
132 if (shapedA.getElementType() != shapedB.getElementType()) return false;
133
134 // If the shapes have different ranks, we cannot reuse them.
135 if (shapedA.getRank() != shapedB.getRank()) return false;
136
137 // Compare each dimension. If the dimensions are not equal no reuse is
138 // possible.
139 for (unsigned idx = 0, e = shapedA.getRank(); idx < e; ++idx) {
140 if (shapedA.getDimSize(idx) != shapedB.getDimSize(idx)) return false;
141 }
142
143 // We need the actual alloc operation of both types. For aliases we need
144 // to check for the defining OP of the alias' origin.
145 Operation *defOpA = a.getDefiningOp();
146 Operation *defOpB = b.getDefiningOp();
147
148 // If the alloc method or the number of operands is not the same the types
149 // might not be compatible.
150 if (defOpA->getName() != defOpB->getName() ||
151 defOpA->getNumOperands() != defOpB->getNumOperands())
152 return false;
153
154 // If all operands are equal the types are compatible.
155 auto operandsA = defOpA->getOperands();
156 auto operandsB = defOpB->getOperands();
157 return std::equal(operandsA.begin(), operandsA.end(), operandsB.begin(),
158 operandsB.end());
159 }
160
161 /// A Fixpoint iteration over the potential reuses to compute the actual
162 /// reuses.
computeActualReuse(ValueVectorMap & potentialReuseMap)163 ValueSetMap computeActualReuse(ValueVectorMap &potentialReuseMap) {
164 // The replacedSet contains all values that are going to be replaced.
165 DenseSet<Value> replacedSet;
166
167 // The currentReuserSet contains all values that are replacing another
168 // value in the current iteration. Note: This is necessary because the
169 // replacing property is not transitive.
170 DenseSet<Value> currentReuserSet;
171
172 /// Maps a value to the set of values that it replaces.
173 ValueSetMap actualReuseMap;
174
175 for (;;) {
176 // Clear the currentReuserSet for this iteration.
177 currentReuserSet.clear();
178 // Step 1 of the fixpoint iteration: Choose a value to be replaced for
179 // each value in the potentialReuseMap.
180 choosePotentialReuses(replacedSet, currentReuserSet, potentialReuseMap,
181 actualReuseMap);
182
183 // If the currentReuseSet is empty we can terminate the fixpoint
184 // iteration.
185 if (currentReuserSet.empty()) break;
186
187 // Step 2 of the fixpoint iteration: Update the potentialReuseVectors for
188 // each value in the potentialReuseMap. Due to the chosen replacements in
189 // step 1 some values might not be replaceable anymore. Also remove all
190 // replaced values from the potentialReuseMap.
191 updatePotentialReuses(replacedSet, potentialReuseMap, actualReuseMap);
192 }
193 return actualReuseMap;
194 }
195
196 /// For each value in the potentialReuseMap, check if another value tries to
197 /// reuse it or if it is already replaced by another value. If neither is the
198 /// case add the value and its reuses (if any) to the actualReuseMap.
choosePotentialReuses(DenseSet<Value> & replacedSet,DenseSet<Value> & currentReuserSet,ValueVectorMap & potentialReuseMap,ValueSetMap & actualReuseMap)199 void choosePotentialReuses(DenseSet<Value> &replacedSet,
200 DenseSet<Value> ¤tReuserSet,
201 ValueVectorMap &potentialReuseMap,
202 ValueSetMap &actualReuseMap) {
203 for (auto &potReuser : potentialReuseMap) {
204 Value item = potReuser.first;
205 SmallVector<Value, 4> &potReuses = potReuser.second;
206
207 // If the current value is replaced already we have to skip it.
208 if (replacedSet.contains(item)) continue;
209
210 // Find a value that can be reused. If the value is already in the
211 // currentReuserSet then we have to break. Due to the order of the
212 // values we must not skip it, because it can potentially be replaced in
213 // the next iteration. However, we may skip the value if it is replaced
214 // by another value.
215 for (Value v : potReuses) {
216 if (currentReuserSet.contains(v)) break;
217 if (replacedSet.contains(v)) continue;
218
219 // Update the actualReuseMap.
220 actualReuseMap[item].insert(v);
221
222 // Check if the replaced value already replaces other values and also
223 // add them to the reused set.
224 auto alreadyReplaced = actualReuseMap.find(v);
225 if (alreadyReplaced != actualReuseMap.end()) {
226 actualReuseMap[item].insert(alreadyReplaced->second.begin(),
227 alreadyReplaced->second.end());
228 actualReuseMap.erase(v);
229 }
230
231 // Merge the userange of v into the userange of item.
232 userange.unionRanges(item, v);
233
234 currentReuserSet.insert(item);
235 replacedSet.insert(v);
236 break;
237 }
238 }
239 }
240
241 /// Update the potentialReuseVectors for each value in the potentialReuseMap.
updatePotentialReuses(DenseSet<Value> & replacedSet,ValueVectorMap & potentialReuseMap,ValueSetMap & actualReuseMap)242 void updatePotentialReuses(DenseSet<Value> &replacedSet,
243 ValueVectorMap &potentialReuseMap,
244 ValueSetMap &actualReuseMap) {
245 for (auto itReuseMap = potentialReuseMap.begin();
246 itReuseMap != potentialReuseMap.end();) {
247 Value item = itReuseMap->first;
248 SmallVector<Value, 4> &potReuses = itReuseMap->second;
249
250 // If the item is already reused, we can remove it from the
251 // potentialReuseMap.
252 if (replacedSet.contains(item)) {
253 potentialReuseMap.erase(itReuseMap);
254 continue;
255 }
256
257 // Remove all potential reuses that cannot be reused for this value.
258 potReuses.erase(
259 std::remove_if(potReuses.begin(), potReuses.end(),
260 [&](Value potReuseValue) {
261 return replacedSet.contains(potReuseValue) ||
262 transitiveInterference(potReuseValue,
263 potReuses,
264 actualReuseMap) ||
265 userange.rangesInterfere(item, potReuseValue);
266 }),
267 potReuses.end());
268 ++itReuseMap;
269 }
270 }
271
272 /// The current dominance info.
273 DominanceInfo dominators;
274
275 /// The current postdominance info.
276 PostDominanceInfo postDominators;
277
278 /// The current userange info.
279 UserangeAnalysis userange;
280 };
281
282 /// The buffer reuse pass that uses already allocated buffers if all critera
283 /// are met.
284 struct BufferReusePass : BufferReuseBase<BufferReusePass> {
runOnFunctionmlir::__anon1754083f0111::BufferReusePass285 void runOnFunction() override {
286 // Reuse allocated buffer instead of new allocation.
287 Operation *funcOp = getFunction();
288 BufferReuse optimizer(funcOp);
289 optimizer.reuse(funcOp);
290 }
291 };
292
293 } // end namespace
294
createBufferReusePass()295 std::unique_ptr<FunctionPass> createBufferReusePass() {
296 return std::make_unique<BufferReusePass>();
297 }
298
299 } // end namespace mlir
300