/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/sharding_util.h"

#include "absl/strings/match.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace {
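// Device-type substring and node attribute names used below:
//  - kDeviceSuffixReplicatedCore identifies device types that refer to a
//    replicated XLA core.
//  - kShardingAttribute and kShardingOpAttribute name node attributes that
//    carry a serialized xla::OpSharding proto.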
const char kDeviceSuffixReplicatedCore[] = "REPLICATED_CORE";
const char kShardingAttribute[] = "_XlaSharding";
const char kShardingOpAttribute[] = "sharding";
}  // namespace

namespace {
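// Builds an xla::OpMetadata proto from an op type and op name.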
xla::OpMetadata CreateOpMetadata(const std::string& op_type,
                                 const std::string& op_name) {
  xla::OpMetadata metadata;
  metadata.set_op_type(op_type);
  metadata.set_op_name(op_name);
  return metadata;
}

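// Attaches op type/name metadata to `sharding`. For a tuple sharding the
// metadata is attached to every tuple element; otherwise it is attached to
// the sharding itself.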
void AssignOpMetadataToSharding(xla::OpSharding& sharding,
                                const string& op_type, const string& op_name) {
  auto metadata = CreateOpMetadata(op_type, op_name);
  if (sharding.type() == xla::OpSharding::TUPLE) {
    for (auto& sharding_element : *sharding.mutable_tuple_shardings()) {
      *sharding_element.add_metadata() = metadata;
    }
  } else {
    *sharding.add_metadata() = metadata;
  }
}

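// Returns an InvalidArgument error for a replicated core id outside the
// range [0, num_cores_per_replica).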
Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
  return errors::InvalidArgument(
      "Invalid replicated core id: ", core,
      "; num_cores_per_replica=", num_cores_per_replica);
}
}  // namespace

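// Parses a sharding from a device name. An explicit sharding, if provided,
// takes precedence (after validating that a non-empty device name is well
// formed). Otherwise a device whose type contains REPLICATED_CORE yields a
// single-device (AssignDevice) sharding for that core, and any other device
// yields no sharding.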
StatusOr<std::optional<xla::OpSharding>> ParseShardingFromDevice(
    const string& device_name, int num_cores_per_replica,
    std::optional<xla::OpSharding> explicit_sharding,
    std::optional<xla::OpMetadata> metadata) {
  if (device_name.empty()) {
    return explicit_sharding;
  }
  DeviceNameUtils::ParsedName parsed_device;
  if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) {
    return errors::InvalidArgument("Malformed assigned device '", device_name,
                                   "'");
  }

  if (explicit_sharding.has_value()) {
    return explicit_sharding;
  } else if (!parsed_device.has_type || !parsed_device.has_id ||
             !absl::StrContains(parsed_device.type,
                                kDeviceSuffixReplicatedCore)) {
    return std::optional<xla::OpSharding>();
  } else {
    const int core = parsed_device.id;
    if (core < 0 || core >= num_cores_per_replica) {
      return CoreOutOfRangeError(core, num_cores_per_replica);
    }
    auto sharding = xla::sharding_builder::AssignDevice(core);
    if (metadata.has_value()) {
      *sharding.add_metadata() = metadata.value();
    }
    return std::optional<xla::OpSharding>(sharding);
  }
}

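// As above, reading the device name and any explicit sharding attribute from
// `node_def`. When `add_metadata` is true, op type/name metadata is attached
// to the resulting sharding.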
StatusOr<std::optional<xla::OpSharding>> ParseShardingFromDevice(
    const NodeDef& node_def, int num_cores_per_replica, bool add_metadata) {
  const string& device_name = node_def.device();
  TF_ASSIGN_OR_RETURN(std::optional<xla::OpSharding> sharding,
                      GetShardingFromNodeDef(node_def, add_metadata));
  return ParseShardingFromDevice(
      device_name, num_cores_per_replica, sharding,
      add_metadata ? std::optional<xla::OpMetadata>(
                         CreateOpMetadata(node_def.op(), node_def.name()))
                   : std::nullopt);
}

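// As above for a graph Node, preferring the assigned device name and falling
// back to the requested device name when no device has been assigned.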
StatusOr<std::optional<xla::OpSharding>> ParseShardingFromDevice(
    const Node& node, int num_cores_per_replica, bool add_metadata) {
  string device_name = node.assigned_device_name();
  if (device_name.empty()) {
    device_name = node.requested_device();
  }
  TF_ASSIGN_OR_RETURN(std::optional<xla::OpSharding> sharding,
                      GetShardingFromNodeDef(node.def(), add_metadata));
  return ParseShardingFromDevice(
      device_name, num_cores_per_replica, sharding,
      add_metadata ? std::optional<xla::OpMetadata>(
                         CreateOpMetadata(node.type_string(), node.name()))
                   : std::nullopt);
}

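// Returns the sharding of the source node of `edge`. If the source carries a
// tuple sharding, returns the element selected by the edge's source output
// index; an out-of-range index is an error.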
StatusOr<std::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
    const Edge& edge, int num_cores_per_replica, bool add_metadata) {
  if (edge.src() == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "Null src for ParseShardingFromEdgeSource edge=", edge.DebugString());
  }
  TF_ASSIGN_OR_RETURN(std::optional<xla::OpSharding> sharding,
                      ParseShardingFromDevice(
                          *edge.src(), num_cores_per_replica, add_metadata));
  if (sharding.has_value() &&
      sharding.value().type() == xla::OpSharding::TUPLE) {
    if (edge.src_output() < 0 ||
        edge.src_output() >= sharding.value().tuple_shardings_size()) {
      return tensorflow::errors::InvalidArgument(
          "Tuple index out of bound: edge=", edge.DebugString(),
          " sharding=", sharding->DebugString());
    }
    std::optional<xla::OpSharding> subsharding =
        sharding.value().tuple_shardings(edge.src_output());
    return subsharding;
  }
  return sharding;
}

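// Copies the device assignment (assigned device name, or requested device
// name when none has been assigned) and any _XlaSharding attribute from
// `src` to `dst`.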
void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
  string device_name = src.assigned_device_name();
  if (device_name.empty()) {
    device_name = src.requested_device();
  }
  dst->set_assigned_device_name(device_name);
  if (const AttrValue* attr = src.attrs().Find(kShardingAttribute)) {
    dst->AddAttr(kShardingAttribute, *attr);
  }
}

namespace {

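// Reads `attribute` from `node_def` and parses it as a serialized
// xla::OpSharding proto. Returns an empty optional when the attribute is
// absent and an error when it cannot be parsed.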
StatusOr<std::optional<xla::OpSharding>> GetShardingFromNodeDefInternal(
    const NodeDef& node_def, bool add_metadata, const char* attribute) {
  if (!HasNodeAttr(node_def, attribute)) {
    return std::optional<xla::OpSharding>();
  }
  string value;
  xla::OpSharding sharding;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, attribute, &value));
  if (!sharding.ParseFromString(value)) {
    return xla::InvalidArgument(
        "Experimental %s attribute was not a valid encoded xla::OpSharding "
        "proto.",
        attribute);
  }
  if (add_metadata) {
    AssignOpMetadataToSharding(sharding, node_def.op(), node_def.name());
  }
  return std::optional<xla::OpSharding>(sharding);
}

}  // namespace

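// Returns the sharding attached to `node_def`. For an XlaSharding op the
// "sharding" attribute is consulted first; otherwise (or if it is unset) the
// generic _XlaSharding attribute is used.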
xla::StatusOr<std::optional<xla::OpSharding>> GetShardingFromNodeDef(
    const NodeDef& node_def, bool add_metadata) {
  if (node_def.op() == "XlaSharding") {
    TF_ASSIGN_OR_RETURN(auto sharding,
                        GetShardingFromNodeDefInternal(node_def, add_metadata,
                                                       kShardingOpAttribute));
    if (sharding.has_value()) {
      return sharding;
    }
  }
  return GetShardingFromNodeDefInternal(node_def, add_metadata,
                                        kShardingAttribute);
}

}  // namespace tensorflow