/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/sharding_util.h"

#include "absl/strings/match.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace {
const char kDeviceSuffixReplicatedCore[] = "REPLICATED_CORE";
const char kShardingAttribute[] = "_XlaSharding";
const char kShardingOpAttribute[] = "sharding";
}  // namespace

namespace {
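// Builds an xla::OpMetadata recording the op type and op name of the node a
// sharding originated from.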
xla::OpMetadata CreateOpMetadata(const std::string& op_type,
                                 const std::string& op_name) {
  xla::OpMetadata metadata;
  metadata.set_op_type(op_type);
  metadata.set_op_name(op_name);
  return metadata;
}

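// Attaches `op_type`/`op_name` metadata to `sharding`. For tuple shardings
// the metadata is added to every element sharding.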
void AssignOpMetadataToSharding(xla::OpSharding& sharding,
                                const string& op_type, const string& op_name) {
  auto metadata = CreateOpMetadata(op_type, op_name);
  if (sharding.type() == xla::OpSharding::TUPLE) {
    for (auto& sharding_element : *sharding.mutable_tuple_shardings()) {
      *sharding_element.add_metadata() = metadata;
    }
  } else {
    *sharding.add_metadata() = metadata;
  }
}

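// Returns an InvalidArgument error for a core id outside
// [0, num_cores_per_replica).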
Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
  return errors::InvalidArgument(
      "Invalid replicated core id: ", core,
      "; num_cores_per_replica=", num_cores_per_replica);
}
}  // namespace

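// Derives a sharding from a device assignment. If `explicit_sharding` is set
// it takes precedence; otherwise a device whose type contains
// "REPLICATED_CORE" is mapped to an AssignDevice sharding for that core id.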
StatusOr<std::optional<xla::OpSharding>> ParseShardingFromDevice(
    const string& device_name, int num_cores_per_replica,
    std::optional<xla::OpSharding> explicit_sharding,
    std::optional<xla::OpMetadata> metadata) {
  if (device_name.empty()) {
    return explicit_sharding;
  }
  DeviceNameUtils::ParsedName parsed_device;
  if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) {
    return errors::InvalidArgument("Malformed assigned device '", device_name,
                                   "'");
  }

  if (explicit_sharding.has_value()) {
    return explicit_sharding;
  } else if (!parsed_device.has_type || !parsed_device.has_id ||
             !absl::StrContains(parsed_device.type,
                                kDeviceSuffixReplicatedCore)) {
    return std::optional<xla::OpSharding>();
  } else {
    const int core = parsed_device.id;
    if (core < 0 || core >= num_cores_per_replica) {
      return CoreOutOfRangeError(core, num_cores_per_replica);
    }
    auto sharding = xla::sharding_builder::AssignDevice(core);
    if (metadata.has_value()) {
      *sharding.add_metadata() = metadata.value();
    }
    return std::optional<xla::OpSharding>(sharding);
  }
}

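// As above, but reads the device name and any sharding attribute from
// `node_def`.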
StatusOr<std::optional<xla::OpSharding>> ParseShardingFromDevice(
    const NodeDef& node_def, int num_cores_per_replica, bool add_metadata) {
  const string& device_name = node_def.device();
  TF_ASSIGN_OR_RETURN(std::optional<xla::OpSharding> sharding,
                      GetShardingFromNodeDef(node_def, add_metadata));
  return ParseShardingFromDevice(
      device_name, num_cores_per_replica, sharding,
      add_metadata ? std::optional<xla::OpMetadata>(
                         CreateOpMetadata(node_def.op(), node_def.name()))
                   : std::nullopt);
}

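// As above, but for a graph Node; falls back to the requested device when no
// device has been assigned yet.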
StatusOr<std::optional<xla::OpSharding>> ParseShardingFromDevice(
    const Node& node, int num_cores_per_replica, bool add_metadata) {
  string device_name = node.assigned_device_name();
  if (device_name.empty()) {
    device_name = node.requested_device();
  }
  TF_ASSIGN_OR_RETURN(std::optional<xla::OpSharding> sharding,
                      GetShardingFromNodeDef(node.def(), add_metadata));
  return ParseShardingFromDevice(
      device_name, num_cores_per_replica, sharding,
      add_metadata ? std::optional<xla::OpMetadata>(
                         CreateOpMetadata(node.type_string(), node.name()))
                   : std::nullopt);
}

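// Returns the sharding of an edge's source node. For tuple shardings, the
// element sharding corresponding to the edge's source output is returned.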
StatusOr<std::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
    const Edge& edge, int num_cores_per_replica, bool add_metadata) {
  if (edge.src() == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "Null src for ParseShardingFromEdgeSource edge=", edge.DebugString());
  }
  TF_ASSIGN_OR_RETURN(std::optional<xla::OpSharding> sharding,
                      ParseShardingFromDevice(
                          *edge.src(), num_cores_per_replica, add_metadata));
  if (sharding.has_value() &&
      sharding.value().type() == xla::OpSharding::TUPLE) {
    if (edge.src_output() < 0 ||
        edge.src_output() >= sharding.value().tuple_shardings_size()) {
      return tensorflow::errors::InvalidArgument(
          "Tuple index out of bound: edge=", edge.DebugString(),
          " sharding=", sharding->DebugString());
    }
    std::optional<xla::OpSharding> subsharding =
        sharding.value().tuple_shardings(edge.src_output());
    return subsharding;
  }
  return sharding;
}

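// Copies the device assignment and the _XlaSharding attribute (if present)
// from `src` to `dst`.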
void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
  string device_name = src.assigned_device_name();
  if (device_name.empty()) {
    device_name = src.requested_device();
  }
  dst->set_assigned_device_name(device_name);
  if (const AttrValue* attr = src.attrs().Find(kShardingAttribute)) {
    dst->AddAttr(kShardingAttribute, *attr);
  }
}

namespace {

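// Reads and parses the xla::OpSharding proto stored in `attribute` of
// `node_def`, optionally tagging it with the node's op type and name.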
StatusOr<std::optional<xla::OpSharding>> GetShardingFromNodeDefInternal(
    const NodeDef& node_def, bool add_metadata, const char* attribute) {
  if (!HasNodeAttr(node_def, attribute)) {
    return std::optional<xla::OpSharding>();
  }
  string value;
  xla::OpSharding sharding;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, attribute, &value));
  if (!sharding.ParseFromString(value)) {
    return xla::InvalidArgument(
        "Experimental %s attribute was not a valid encoded xla::OpSharding "
        "proto.",
        attribute);
  }
  if (add_metadata) {
    AssignOpMetadataToSharding(sharding, node_def.op(), node_def.name());
  }
  return std::optional<xla::OpSharding>(sharding);
}

}  // namespace

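// Returns the sharding attached to a node: for XlaSharding ops the "sharding"
// attribute is preferred, otherwise the "_XlaSharding" attribute is used.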
xla::StatusOr<std::optional<xla::OpSharding>> GetShardingFromNodeDef(
    const NodeDef& node_def, bool add_metadata) {
  if (node_def.op() == "XlaSharding") {
    TF_ASSIGN_OR_RETURN(auto sharding,
                        GetShardingFromNodeDefInternal(node_def, add_metadata,
                                                       kShardingOpAttribute));
    if (sharding.has_value()) {
      return sharding;
    }
  }
  return GetShardingFromNodeDefInternal(node_def, add_metadata,
                                        kShardingAttribute);
}

}  // namespace tensorflow