#include <ATen/FunctionalTensorWrapper.h>

#include <ATen/FunctionalInverses.h>
#include <ATen/TensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/util/Exception.h>

#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_propagate_xla_data.h>
#include <ATen/ops/_to_copy.h>
#endif

namespace at {

void FunctionalTensorWrapper::set_constructor_metadata() {
  TORCH_INTERNAL_ASSERT(value_.defined());
  // Note: "level" is a concept that we don't know how to compute in core.
  // For now I'm retroactively setting this in functorch,
  // but once Open Multiple Dispatch lands we should be able to calculate this in core.
  level_ = -1;
  // mirror all of the generic tensor metadata onto the wrapper
  copy_generic_tensor_metadata(value_.getIntrusivePtr().get(), this);
  refresh_numel();
  refresh_contiguous();
  storage_access_should_throw_ = false;
  // In general, the sizes/stride metadata on a tensor can change as it is mutated,
  // and these changes need to be reflected in the metadata of the wrapper.
  set_allow_tensor_metadata_change(true);
  key_set_ = c10::DispatchKeySet(c10::DispatchKey::Functionalize) | value_.key_set();
  // All of the keys corresponding to functorch transforms should not be copied over.
  // Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect
  // to participate in the functorch transforms.
  key_set_ = key_set_ - c10::functorch_transforms_ks - c10::python_ks;
  // We override a bunch of _custom(), so make sure they get called
  // TODO: metadata copying may not actually be necessary then
  set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
  set_custom_device(true);
  // E.g. when running torch.compile under inference mode, we need to make sure that
  // for any inputs that were created outside of inference mode (so they are not inference tensors),
  // the functional wrappers that we wrap them with should also not be inference tensors.
  version_counter_ = value_.unsafeGetTensorImpl()->version_counter();
}

FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value)
  : c10::TensorImpl(
      c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(value)),
      c10::DispatchKeySet(DispatchKey::Functionalize) | value.key_set(),
      value.dtype()
    ),
    value_(value)
{
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
  set_constructor_metadata();
}

void FunctionalTensorWrapper::freeze_storage() const {
  functional_storage_impl()->freeze();
}

// Note [Functionalization: Alias Removal]
// When someone calls a view() op during the functionalization pass, e.g. 'b = a.view(...)',
// we link `b` and `a` to a shared Alias object to preserve the aliasing relationship.
//
// How do we do that?
//
// Every FunctionalTensorWrapper contains a dummy FunctionalStorageImpl, which subclasses from c10::StorageImpl.
// It doesn't contain any data (similar to MetaTensor storage), but it contains an Alias object that knows about the base tensor.
// When a tensor is created through a view operation, both the new and old tensor point to the same FunctionalStorageImpl.
//
// As mutations are applied to any of the views, we also queue each mutation up on the Alias object, so we can replay them.
// When the user requests a tensor that's had a view taken, we check if it's up to date.
// If it's not up to date, we first replay all of the queued up mutations onto the alias, and then re-apply the current view
// on top of the newly updated alias.
//
// Why do we queue up and lazily run mutations on the alias, instead of updating the alias eagerly?
// This behavior was taken from pytorch/xla, which inspired the alias-removal logic here.
// One benefit of the laziness is that we save work in the cases where a user has multiple views and mutates one of them,
// but never uses the other views later in the program (in which case we'll never update the alias).
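// An illustrative sketch of the laziness (this example is mine, not from the original note;
// it assumes both tensors are wrapped by the functionalization pass):
//
//   at::Tensor a = at::ones({4});
//   at::Tensor b = a.view({2, 2});
//   b.add_(1);                              // the mutation is queued on the shared Alias; `a` is untouched
//   at::functionalization::impl::sync(a);   // replays the queued mutation onto the alias,
//                                           // then regenerates `a`'s value from it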
// It also has downsides though: repeatedly applying mutations to the same view without syncing
// will silently use up more and more memory as more mutations are queued up.
//
// Corresponding diagram:
//
// b = a.view(...)
//
//        a                                                      b
//        |                                                      |   If the user asks for b and it's out of date,
//       \/                                                     \/   we regenerate b by replaying its views from the alias.
// . - - - - - - - - - - - - - .                  . - - - - - - - - - - - - - .
// |  FunctionalTensorWrapper  |                  |  FunctionalTensorWrapper  |
// . - - - - - - - - - - - - - .                  . - - - - - - - - - - - - - .
// |    value    |   storage   |                  |   storage   |    value    |
// . - - - - - - - - - - - - - .                  . - - - - - - - - - - - - - .
//        |                    \                  /                    |
//        |                     \                /                     |
//        |                   . - - - - - - - - - - - - .              |
//        |                   |  FunctionalStorageImpl  |              |
//        |                   . - - - - - - - - - - - - .              |
//        |                   |          Alias          |              |
//        |                   . - - - - - - - - - - - - .              |
//        |                   /   mutations to a or b                  |
//        |                  /    are queued onto Alias                |
//        |                 /                                          |
//       \/                /                                          \/
// . - - - - - - - - - - - - - .                  . - - - - - - - - - - - - - - - .
// |         TensorImpl        |                  |           TensorImpl          |
// . - - - - - - - - - - - - - .                  . - - - - - - - - - - - - - - - .
// |    value    |   storage   |                  |   storage   |      value      |
// . - - - - - - - - - - - - - .                  . - - - - - - - - - - - - - - - .
//        |                                                            |
//        |                                                            |
//        |                                                            |
//        |   In this picture the two tensor views have their own      |
//        |   storages, but backends like functorch are allowed        |
//       \/   to re-alias underneath the pass                         \/
// . - - - - - - - - - - - - - .                  . - - - - - - - - - - - - - - - .
// |    underlying_storage     |                  |       underlying_storage      |
// . - - - - - - - - - - - - - .                  . - - - - - - - - - - - - - - - .
//
// This constructor is only used by view ops.
// - view_value: The output tensor that we need to wrap.
// - base: The "base" of the view that `view_value` was generated from.
// See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic.
FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const FunctionalTensorWrapper* base, const functionalization::ViewMeta& meta)
  : c10::TensorImpl(
      c10::DispatchKeySet(DispatchKey::Functionalize),
      view_value.dtype(),
      view_value.device()
    ),
    value_(view_value),
    is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output),
    was_storage_changed_(base->was_storage_changed_),
    is_symbolic_(base->is_symbolic_)
{
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
  set_constructor_metadata();
  // Copy the original tensor's ViewMeta vector and push the current one.
  if (!base->view_metas_.empty()) {
    view_metas_ = base->view_metas_;  // copy
  }
  view_metas_.push_back(meta);
  maybe_mark_symbolic(meta);
  storage_ = base->storage_; // alias this tensor's storage with the base tensor's
}

functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const {
  return static_cast<functionalization::FunctionalStorageImpl*>(storage_.unsafeGetStorageImpl());
}

void FunctionalTensorWrapper::commit_update() {
  auto storage_impl = functional_storage_impl();
  storage_impl->add_update(value_, view_metas_);
  // As an optimization, we used to mark the tensor here as "up-to-date".
  // That way, code like:
  //   x = torch.ones(1'000'000)
  //   x[0].add_(1)
  // doesn't result in an unnecessary materialization of the base.
  // This optimization results in the slice temporarily having incorrect
  // stride/storage_offset though, and DCE should handle that optimization anyway.
  // generation_ = storage_impl->generation();
}

bool FunctionalTensorWrapper::is_up_to_date() const {
  auto alias_generation = functional_storage_impl()->generation();
  return generation_ == alias_generation;
}

// See Note [Functionalization Pass - Inplace View Ops]
void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::ViewMeta& meta) {
  view_metas_.push_back(meta);
  // Manually track the fact that this tensor received a metadata mutation!
  has_metadata_mutation_ = true;
  // Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation.
  maybe_mark_symbolic(meta);
  // Note [Functionalization Pass - Inplace View Ops]
  // So, these ops are special - they're mutation AND view ops. They get special codegen.
  // An example is transpose_, e.g. `a.transpose_()`
  // Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
  at::AutoDispatchSkipFunctionalize guard;
  value_ = meta.forward_fn(value_, meta.out_index);
  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
}

// Note [Functionalization: Mutation Removal]
// Mutation removal is used to take a program like this:
//
//   a.add_(b)
//
// and replace it with a slightly different program that has the same semantics:
//
//   tmp = a.add(b)
//   a.replace_(tmp)
//
// where the replace_() call is implemented directly in the functionalization pass, so it is transparent to the backend.
// This is useful for backends that aren't able to handle certain types of mutations, like functorch.
//
// Why do we need to wrap every tensor in a FunctionalTensorWrapper? Consider this program:
//
// Before:
//   tensor.add_(batched_tensor)
//
// After:
//   tmp = tensor.add(batched_tensor)
//   tensor.replace_(tmp)
//
// In the above, tmp is a batched tensor (because adding a normal tensor to a batched tensor does broadcasting and creates a batched tensor).
// But we can't just replace the underlying memory backing `tensor` with `tmp` - a batched tensor takes up more space!
// Instead, every input, intermediate and output of the program is wrapped in a FunctionalTensorWrapper, which wraps the underlying tensor.
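// A hedged sketch of what a functionalization kernel for add_ roughly does, written with the
// helpers defined in this file. The kernel name and exact argument handling below are
// illustrative assumptions; the real kernels are generated (see RegisterFunctionalization*.cpp):
//
//   at::Tensor& add__functionalization_glue(at::Tensor& self, const at::Tensor& other) {
//     auto self_  = at::functionalization::impl::from_functional_tensor(self);
//     auto other_ = at::functionalization::impl::from_functional_tensor(other);
//     at::Tensor tmp;
//     {
//       at::AutoDispatchSkipFunctionalize guard;
//       tmp = at::add(self_, other_);                      // "tmp = a.add(b)"
//     }
//     at::functionalization::impl::replace_(self, tmp);    // "a.replace_(tmp)"
//     at::functionalization::impl::commit_update(self);    // queue the update on the Alias
//     return self;
//   }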
void FunctionalTensorWrapper::replace_(const Tensor& other, bool from_lazy_regenerate) {
  // TODO: going to need to change this if we want nested functionalize() transforms.
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(other));
  value_ = other;
  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
  // out= ops are allowed to resize the output tensors, mutating both the data and metadata of the tensor.
  // We need to propagate that metadata mutation to the wrapper (new size).
  auto sizes_ = value_.sym_sizes();
  auto strides_ = value_.sym_strides();
  auto storage_offset_ = value_.sym_storage_offset();
  set_sizes_and_strides(sizes_, strides_, storage_offset_);
  if (dtype() != value_.unsafeGetTensorImpl()->dtype() || layout() != value_.unsafeGetTensorImpl()->layout()) {
    // .to() should not re-entrantly go through functionalization.
    at::AutoDispatchSkipFunctionalize guard;
    // and we want _to_copy() to show up in the graph, not the composite .to() operator
    // (this can happen if autograd has already run by the time we enter this code)
    value_ = at::_to_copy(value_, c10::TensorOptions().dtype(dtype()).layout(layout()));
    TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
  }
  // If this is a lazy regeneration, then it is guaranteed that we have already
  // done the mutation for the storage alias (when we originally performed the mutation),
  // so no further counter update is needed here.
  // Example: if a mutation happens to a view under a no_grad,
  // we won't call replace_() on the other alias until the alias is later used, which
  // might not be until after the no_grad region is exited.
  // Therefore, it is not unconditionally safe for replace_() to check the current no_grad state.
  if (!from_lazy_regenerate) {
    mark_mutation();
    if (!at::GradMode::is_enabled() || InferenceMode::is_enabled()) {
      // This mutation happened under no_grad or inference_mode
      mark_mutation_during_no_grad_or_inference_mode();
    }
  }
}

bool FunctionalTensorWrapper::has_data_mutation() {
  // Current tensor's data was mutated if its storage saw any mutations.
  return functional_storage_impl()->generation() > 0;
}

void FunctionalTensorWrapper::set__impl(const FunctionalTensorWrapper* other) {
  // self.set_(src) will cause self to have all of the tensor properties of src.
  value_ = other->value_;
  generation_ = other->generation_;
  view_metas_ = other->view_metas_;
  is_symbolic_ = other->is_symbolic_;
  // FREEZE the old storage, preventing mutations to it.
  // This is a huge pain to handle properly in all cases, so we ban it.
  functional_storage_impl()->freeze();
  // Unsafely swap out the storage with other's storage,
  // disconnecting `self` from its view chain.
  storage_ = other->storage_;
  // Explicitly mark the tensor as having its storage changed from set_().
  // Otherwise, we don't actually have a 100% accurate way to check this.
  // (We could check if the updated value has a different storage than the original value,
  // but this won't also let us uniquely determine if the tensor **also**
  // experienced a data mutation).
  was_storage_changed_ = true;

  auto sizes_ = value_.sym_sizes();
  auto strides_ = value_.sym_strides();
  auto storage_offset_ = value_.sym_storage_offset();
  set_sizes_and_strides(sizes_, strides_, storage_offset_);
}
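// Illustrative user-level consequence of the set_() handling above (this example is mine,
// not from the original code, and assumes the tensors are functionalized):
//
//   at::Tensor x = at::zeros({4});
//   at::Tensor v = x.view({2, 2});   // v shares x's FunctionalStorageImpl
//   at::Tensor y = at::ones({4});
//   x.set_(y);                       // x now aliases y's storage; x's old storage is frozen,
//                                    // so a later mutation routed through `v` is expected to error.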
void FunctionalTensorWrapper::storage_resize_(const c10::SymInt& new_size) {
  auto curr_storage_size = value_.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl()->sym_nbytes();
  // Storage resizing is severely limited: we only support resizing either to zero, or from zero bytes.
  TORCH_CHECK(new_size == 0 || curr_storage_size == 0, "new_size: ", new_size, ". curr_storage_size: ", curr_storage_size);
  // The "functionalization rule" for storage resizing is a giant no-op, mainly because we don't want
  // resize_() calls to actually emit any ops in the functional graph.
  // How does it work?
  // Resizing up (old size == 0):
  //   We do nothing in this case.
  //   The expectation is that for the user code to be valid, the next op that should run against the current tensor "x"
  //   will be a x.copy_(y) (or similar), that will fully overwrite the data of x.
  //   If there are any outstanding aliases of x, we expect them not to be used until after the copy_() call
  //   (otherwise the eager code would be invalid),
  //   and therefore functionalization will regenerate the aliases off of the result of `x.copy_(y)`.
  // Resizing down (new size == 0):
  //   We also do nothing in this case. The assumption is that after resizing a tensor down,
  //   it is fully unused in the program (unless it is later resized back up first and has data copied in),
  //   although it might be saved for backward, which happens in FSDP.
  //   The expected pattern is that the param will then be resized back up from zero in the backward.

  // Mark the tensor as having its storage resized.
  // This is so we can detect it for inputs in AOTAutograd and error / emit
  // an input mutation resize_() appropriately.
  functional_storage_impl()->mark_inductor_storage_resize(new_size);
}

void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) {
  // Note [resize_() in functionalization pass]
  // resize_() is a special operator in functionalization because it can reallocate its underlying storage.
  // This function is only ever called in the case that resize_() needs to reallocate its storage to a larger size.
  //
  // However, functionalization currently bans the following code:
  //   a = torch.ones(2)
  //   b = a.view(2)
  //   b.resize_(4)  # b is a view tensor, that we are trying to increase the storage size of
  //
  // Why is this code difficult to handle?
  // The functionalization pass currently keeps aliases in sync by making the following assumptions:
  // - The "base" tensor always refers to "all of the data".
  // - Whenever you have b = view_op(a), "b" should always refer to a subset of "a"s memory.
  //
  // The code above breaks that assumption: b.resize_(4) actually needs to update "a"
  // to tell it that it is now actually some slice of a pre-existing larger storage.
  // We also can no longer re-generate "b" fully from "a", since "a" now refers to a slice of "b"'s data.
  //
  // This is probably fixable in theory, but:
  // - the fix would likely complicate the functionalization logic quite a bit.
  // - the primary use case for resize_() today is resizing zero-sized tensors in out= variants of operators
  // - resize_() also can give you weird results today if you try to resize_() a weirdly strided tensor.
  //
  // Given all of the above, for now we're just banning the above usage.
  TORCH_CHECK(storage().use_count() == 1, "Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass");
  TORCH_CHECK(view_metas_.empty(), "Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass");
  // If this tensor is not a view (and has no outstanding views taken out on it),
  // then it's safe to throw out the old storage and replace it with the new, larger one.
  storage_ = c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(other));
  value_ = other;
  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
  generation_ = 0;
  // And update the metadata on the wrapper to reflect the new sizes and strides
  set_sizes_and_strides(value_.sizes(), value_.strides());
  refresh_numel();
  // (Technically we should be guaranteed that the tensor was already contiguous,
  // since it's guaranteed not to have been a view. Doesn't hurt to run though.)
  refresh_contiguous();
  // Swapping out the storage of a tensor (aka from a resize_() call) will update the sizes and strides of the tensor,
  // so we need to record the fact that metadata was mutated.
  has_metadata_mutation_ = true;
}

void FunctionalTensorWrapper::_unsafe_reset_storage() {
  // Reset the storage with the current value_ tensor as the base
  storage_ = c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(value_));
  // Reset the generation so that it matches the new storage
  generation_ = 0;
  // Clear any pre-existing view metas so that base and value_ are semantically the same
  view_metas_.clear();
}

void FunctionalTensorWrapper::sync_() {
  if (is_up_to_date()) {
    return;
  }
  apply_updates();
  regenerate_from_base();
}

Tensor FunctionalTensorWrapper::apply_view_metas(const Tensor& base) {
  auto t = base;
  // Reapply views to get the viewed tensor from the base in alias_
  for (auto& view_meta : view_metas_) {
    t = view_meta.forward_fn(t, view_meta.out_index);
  }
  return t;
}

void FunctionalTensorWrapper::regenerate_from_base() {
  at::AutoDispatchSkipFunctionalize guard;
  auto storage_impl = functional_storage_impl();
  auto t = storage_impl->base();
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  t = apply_view_metas(t);
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  replace_(t, /*from_lazy_regenerate=*/true);
  generation_ = storage_impl->generation();
}

bool FunctionalTensorWrapper::apply_updates() {
  // Apply all updates on alias_
  auto storage_impl = functional_storage_impl();
  return storage_impl->apply_updates();
}

const char* FunctionalTensorWrapper::tensorimpl_type_name() const {
  return "FunctionalTensorWrapper";
}

void FunctionalTensorWrapper::copy_tensor_metadata(
    const FunctionalTensorWrapper* src_impl,
    FunctionalTensorWrapper* dest_impl,
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) {
  TensorImpl::copy_tensor_metadata(
      src_impl,
      dest_impl,
      version_counter,
      allow_tensor_metadata_change);
  // FunctionalTensorWrapper-specific fields.
  dest_impl->value_ = src_impl->value_;
  dest_impl->level_ = src_impl->level_;
  dest_impl->has_metadata_mutation_ = src_impl->has_metadata_mutation_;
  dest_impl->is_multi_output_view_ = src_impl->is_multi_output_view_;
  dest_impl->was_storage_changed_ = src_impl->was_storage_changed_;
  dest_impl->is_symbolic_ = src_impl->is_symbolic_;
  dest_impl->generation_ = src_impl->generation_;
  dest_impl->view_metas_ = src_impl->view_metas_;
}

void FunctionalTensorWrapper::copy_tensor_metadata_and_refresh(
    const FunctionalTensorWrapper* src_impl,
    FunctionalTensorWrapper* dest_impl,
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  copy_tensor_metadata(src_impl, dest_impl, version_counter, allow_tensor_metadata_change);
  dest_impl->refresh_numel();
  dest_impl->refresh_contiguous();
}

template <typename VariableVersion>
c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach_core(
    VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  if (key_set_.has(DispatchKey::Python) &&
      !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
    auto r = pyobj_slot_.load_pyobj_interpreter()->detach(this);
    if (r) {
      r->set_version_counter(std::forward<VariableVersion>(version_counter));
      r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
      return r;
    }
  }

  auto impl = c10::make_intrusive<FunctionalTensorWrapper>(value_);
  copy_tensor_metadata_and_refresh(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/std::forward<VariableVersion>(version_counter),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  return impl;
}

c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  return shallow_copy_and_detach_core(
      version_counter, allow_tensor_metadata_change);
}

c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  return shallow_copy_and_detach_core(
      std::move(version_counter), allow_tensor_metadata_change);
}

void FunctionalTensorWrapper::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {
  AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
  auto functional_impl = static_cast<FunctionalTensorWrapper*>(impl.get());
  copy_tensor_metadata_and_refresh(
      /*src_impl=*/functional_impl,
      /*dest_impl=*/this,
      /*version_counter=*/version_counter(),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
}

c10::Device FunctionalTensorWrapper::device_custom() const {
  return value_.unsafeGetTensorImpl()->device();
}
at::IntArrayRef FunctionalTensorWrapper::sizes_custom() const {
  return value_.unsafeGetTensorImpl()->sizes();
}
at::IntArrayRef FunctionalTensorWrapper::strides_custom() const {
  return value_.unsafeGetTensorImpl()->strides();
}
int64_t FunctionalTensorWrapper::dim_custom() const {
  return value_.unsafeGetTensorImpl()->dim();
}
int64_t FunctionalTensorWrapper::numel_custom() const {
  return value_.unsafeGetTensorImpl()->numel();
}
bool FunctionalTensorWrapper::is_contiguous_custom(at::MemoryFormat memory_format) const {
  return value_.unsafeGetTensorImpl()->is_contiguous(memory_format);
}
c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const {
  return value_.unsafeGetTensorImpl()->sym_sizes();
}
c10::SymIntArrayRef FunctionalTensorWrapper::sym_strides_custom() const {
  return value_.unsafeGetTensorImpl()->sym_strides();
}
c10::SymInt FunctionalTensorWrapper::sym_size_custom(int64_t d) const {
  return value_.unsafeGetTensorImpl()->sym_size(d);
}
c10::SymInt FunctionalTensorWrapper::sym_storage_offset_custom() const {
  return value_.unsafeGetTensorImpl()->sym_storage_offset();
}
c10::Layout FunctionalTensorWrapper::layout_impl() const {
  return value_.unsafeGetTensorImpl()->layout();
}

namespace functionalization {
namespace impl {

Tensor to_functional_tensor(const Tensor& tensor) {
  // Note [Wrapped Numbers <> Functionalization]
  if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    return tensor;
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isFunctionalTensor(tensor));
  return at::detail::make_tensor<FunctionalTensorWrapper>(tensor);
}
std::optional<Tensor> to_functional_tensor(const std::optional<Tensor>& tensor) {
  if (tensor.has_value()) {
    return std::make_optional(to_functional_tensor(*tensor));
  }
  return std::nullopt;
}
c10::List<::std::optional<Tensor>> to_functional_tensor(const c10::List<::std::optional<Tensor>>& t_list) {
  c10::List<::std::optional<Tensor>> outputs;
  outputs.reserve(t_list.size());
  for (const auto i : c10::irange(t_list.size())) {
    outputs.push_back(to_functional_tensor(t_list[i]));
  }
  return outputs;
}
std::vector<Tensor> to_functional_tensor(ITensorListRef t_list) {
  std::vector<Tensor> outputs;
  outputs.reserve(t_list.size());
  for (const auto& tensor : t_list) {
    outputs.push_back(to_functional_tensor(tensor));
  }
  return outputs;
}

Tensor from_functional_tensor(const Tensor& tensor, bool assert_functional) {
  // Note [Wrapped Numbers <> Functionalization]
  if (!tensor.defined() || tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    return tensor;
  }
  if (isFunctionalTensor(tensor)) {
    auto impl = unsafeGetFunctionalWrapper(tensor);
    return impl->value();
  } else {
    // If the current tensor is not functional, then raise an error
    // if assert_functional is true. Otherwise, return the input.
    TORCH_INTERNAL_ASSERT(!assert_functional)
    return tensor;
  }
}
std::optional<Tensor> from_functional_tensor(const std::optional<Tensor>& t, bool assert_functional) {
  if (t.has_value()) {
    return std::make_optional(from_functional_tensor(*t, assert_functional));
  }
  return std::nullopt;
}
std::vector<Tensor> from_functional_tensor(ITensorListRef t_list) {
  std::vector<Tensor> outputs;
  outputs.reserve(t_list.size());
  for (const auto& tensor : t_list) {
    // from_functional_tensor(Tensor) has asserts to make sure you don't accidentally call
    // it on a non-functional input,
    // but from_functional_tensor(TensorList) can receive a list containing both
    // functional and non-functional tensors.
    // Example of when that can happen: torch.cat(function_input_tensor, global_state_tensor).
    // When that happens, we're okay with only unwrapping the functional tensors.
    outputs.push_back(from_functional_tensor(tensor, /*assert_functional=*/false));
  }
  return outputs;
}
c10::List<::std::optional<Tensor>> from_functional_tensor(const c10::List<::std::optional<Tensor>>& t_list) {
  c10::List<::std::optional<Tensor>> outputs;
  outputs.reserve(t_list.size());
  for (const auto i : c10::irange(t_list.size())) {
    outputs.push_back(from_functional_tensor(t_list[i], /*assert_functional=*/false));
  }
  return outputs;
}
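// A minimal wrap/unwrap round trip using the helpers above (a sketch of my own, not a test
// from this file):
//
//   at::Tensor t = at::randn({2, 3});
//   at::Tensor f = at::functionalization::impl::to_functional_tensor(t);
//   TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(f));
//   at::Tensor u = at::functionalization::impl::from_functional_tensor(f);
//   // `u` is the wrapper's inner value_ tensor - here, `t` itself, since no mutations ran.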
void sync(const Tensor& t) {
  if (t.unsafeGetTensorImpl()->is_wrapped_number()) {
    // Note [Wrapped Numbers <> Functionalization]
    // Unfortunately, we can't easily guarantee that wrapped numbers (scalar-tensors)
    // get wrapped up in a FunctionalTensorWrapper object, since they skip the dispatcher.
    // That shouldn't matter, since I don't think we're allowed to assign to wrapped numbers anyway.
    return;
  }
  // Not every tensor that hits a functionalization kernel is necessarily a functional tensor.
  // For example, xla_tensor.copy_(cpu_tensor) needs to hit the functionalization kernel
  // to sync xla_tensor, but not cpu_tensor.
  if (!at::functionalization::impl::isFunctionalTensor(t)) {
    return;
  }
  auto functional_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
  functional_impl->sync_();
}
void sync(const std::optional<Tensor>& t) {
  if (t.has_value()) {
    sync(*t);
  }
}
void sync(ITensorListRef t_list) {
  for (const auto& t : t_list) {
    sync(t);
  }
}
void sync(const c10::List<::std::optional<Tensor>>& t_list) {
  for (const auto i : c10::irange(t_list.size())) {
    sync(t_list[i]);
  }
}

void replace_(const Tensor& functional_tensor, const Tensor& other) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
  unsafeGetFunctionalWrapper(functional_tensor)->replace_(other);
}

void replace_(const ITensorListRef functional_tensor, ITensorListRef other) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_tensor.size() == other.size());
  auto functional_tensor_it = functional_tensor.begin();
  auto other_it = other.begin();
  for (C10_UNUSED const auto i : c10::irange(functional_tensor.size())) {
    replace_(*functional_tensor_it++, *other_it++);
  }
}

void propagate_xla_data(const Tensor& functional_tensor, const Tensor& other) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
  if (functional_tensor.key_set().has(c10::DispatchKey::XLA)) {
    at::_propagate_xla_data(
        at::functionalization::impl::unsafeGetFunctionalWrapper(functional_tensor)->value(),
        other);
  }
}

void propagate_xla_data(const ITensorListRef functional_tensor, ITensorListRef other) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_tensor.size() == other.size());
  auto functional_tensor_it = functional_tensor.begin();
  auto other_it = other.begin();
  for (C10_UNUSED const auto i : c10::irange(functional_tensor.size())) {
    propagate_xla_data(*functional_tensor_it++, *other_it++);
  }
}

void propagate_xla_data_direct(const Tensor& tensor, const Tensor& other) {
  if (tensor.key_set().has(c10::DispatchKey::XLA)) {
    at::_propagate_xla_data(tensor, other);
  }
}

void propagate_xla_data_direct(const ITensorListRef tensor, ITensorListRef other) {
  auto tensor_it = tensor.begin();
  auto other_it = other.begin();
  for (C10_UNUSED const auto i : c10::irange(tensor.size())) {
    propagate_xla_data_direct(*tensor_it++, *other_it++);
  }
}

void commit_update(const Tensor& functional_tensor) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
  unsafeGetFunctionalWrapper(functional_tensor)->commit_update();
}

void commit_update(ITensorListRef functional_tensor) {
  for (const auto& t : functional_tensor) {
    commit_update(t);
  }
}

void unsafe_reset_storage(const Tensor& functional_tensor) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
  unsafeGetFunctionalWrapper(functional_tensor)->_unsafe_reset_storage();
}

void mark_mutation_hidden_from_autograd(const Tensor& functional_tensor) {
  TORCH_CHECK(isFunctionalTensor(functional_tensor));
  unsafeGetFunctionalWrapper(functional_tensor)->mark_mutation_hidden_from_autograd();
}

bool are_all_mutations_hidden_from_autograd(const Tensor& functional_tensor) {
  TORCH_CHECK(isFunctionalTensor(functional_tensor));
  return unsafeGetFunctionalWrapper(functional_tensor)->are_all_mutations_hidden_from_autograd();
}

bool are_all_mutations_under_no_grad_or_inference_mode(const Tensor& functional_tensor) {
  TORCH_CHECK(isFunctionalTensor(functional_tensor));
  return unsafeGetFunctionalWrapper(functional_tensor)->are_all_mutations_under_no_grad_or_inference_mode();
}

bool isFunctionalTensor(const at::Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize);
}
bool isBaseTensor(const at::Tensor& tensor) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(tensor));
  return unsafeGetFunctionalWrapper(tensor)->isBaseTensor();
}

bool isFunctionalTensor(const std::optional<Tensor>& t) {
  if (t.has_value()) {
    return isFunctionalTensor(*t);
  } else {
    return false;
  }
}

bool isFunctionalTensor(const c10::List<::std::optional<Tensor>>& t_list) {
  if (t_list.empty()) return false;
  auto functional_count = 0;
  for (const auto i : c10::irange(t_list.size())) {
    if (!t_list[i].has_value() || !t_list[i]->defined()) continue;
    if (isFunctionalTensor(t_list[i])) {
      ++functional_count;
    }
  }
  return functional_count > 0;
}

template <typename T>
bool isFunctionalTensorIListRef(c10::IListRef<T> list) {
  if (list.size() == 0) return false;
  auto functional_count = 0;
  for (const auto& tensor : list) {
    if (!tensor.defined()) continue;
    if (isFunctionalTensor(tensor)) {
      ++functional_count;
    }
  }
  return functional_count > 0;
}

bool isFunctionalTensor(ITensorListRef list) {
  return isFunctionalTensorIListRef(list);
}

void freeze_functional_tensor(const Tensor& tensor) {
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(tensor));
  auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
  functional_base_impl->freeze_storage();
}

Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) {
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap));
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base));
  auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base);
  if (out_idx != 0) {
    // Note [out_idx in ViewMeta]
    // When a view op outputs multiple tensors, each output needs its own separate ViewMeta.
    // Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function.
    meta = meta.to_out_idx(out_idx);
  }
  return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta);
}

std::vector<Tensor> create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, const functionalization::ViewMeta& meta) {
  std::vector<Tensor> outputs(view_to_wrap.size());
  int64_t i = 0;
  for (const auto& tensor : view_to_wrap) {
    outputs[i] = create_functional_tensor_with_view_meta(tensor, base, meta, i);
    i++;
  }
  return outputs;
}

void mutate_view_meta(const at::Tensor& self, const functionalization::ViewMeta& meta) {
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
  auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
  self_impl->mutate_view_meta(meta);
}

// Note [Propagating strides in the functionalization pass]
// In order to properly compute stride information, the functionalization pass
// calls each {view} reference implementation with meta tensors.
// The output meta tensor's stride info serves as a reference for what the correct strides should be.
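// A hedged sketch of that pattern (the names below are illustrative, not the actual generated
// code; the real call sites live in the generated RegisterFunctionalization kernels):
//
//   // Run the view op's reference implementation on meta tensors to compute the correct
//   // output strides, then copy them onto the functional output:
//   at::Tensor reference_out = some_view_op(self_meta, args...);   // hypothetical meta-tensor call
//   set_sizes_strides_offset(out, reference_out);                   // helper defined just below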
void set_sizes_strides_offset(const Tensor& out, const Tensor& reference_out) {
  out.unsafeGetTensorImpl()->set_sizes_and_strides(reference_out.sym_sizes(), reference_out.sym_strides(), reference_out.sym_storage_offset());
}

void set_sizes_strides_offset(const std::vector<Tensor>& outs, const std::vector<Tensor>& reference_outs) {
  TORCH_INTERNAL_ASSERT(outs.size() == reference_outs.size());
  for (const auto i : c10::irange(reference_outs.size())) {
    set_sizes_strides_offset(outs[i], reference_outs[i]);
  }
}

thread_local bool _functionalizationReapplyViews;

bool getFunctionalizationReapplyViewsTLS() {
  return _functionalizationReapplyViews;
}

void setFunctionalizationReapplyViewsTLS(bool reapply_views) {
  _functionalizationReapplyViews = reapply_views;
}

} // namespace impl

// Given an **out-of-place** op that might internally call view/inplace ops,
// this function will "functionalize" it.
// That is, it will call the operator, but removing any intermediate views/mutations
// that are performed inside of it.
// This is useful for LTC/XLA, which would like to re-use some of our composite kernels
// from pytorch core but not have to worry about the view ops that they might call.
// e.g. at::block_diag
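// Hedged usage sketch: callers typically go through the wrapper template declared in
// FunctionalTensorWrapper.h rather than invoking this boxed helper directly. The exact
// spelling below is an assumption on my part:
//
//   auto out = at::functionalization::functionalize_aten_op<ATEN_OP(block_diag)>::call(tensors);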
void functionalize_op_helper(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_arguments = schema.arguments().size();
  const auto arguments_begin = stack->size() - num_arguments;
  auto arguments = torch::jit::last(stack, num_arguments);

  // Wrap all tensor-like inputs into FunctionalTensorWrappers.
  // When we re-invoke the dispatcher, this will automatically enable the functionalization pass.
  for (uint64_t idx = 0; idx < num_arguments; ++idx) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      const auto& t = ivalue.toTensor();
      if (t.defined()) {
        TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t),
          "The composite op functionalization fallback expects its inputs all not to be functional tensors");
        auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
        (*stack)[arguments_begin + idx] = t_new;
      }
    } else if (ivalue.isTensorList()) {
      auto tensors = ivalue.toTensorList();
      TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(tensors),
        "The composite op functionalization fallback expects its inputs all not to be functional tensors");
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(tensors));
      (*stack)[arguments_begin + idx] = t_new;
    } else if (ivalue.isOptionalTensorList()) {
      auto opt_tensors = ivalue.toOptionalTensorList();
      TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(opt_tensors),
        "The composite op functionalization fallback expects its inputs all not to be functional tensors");
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(opt_tensors));
      (*stack)[arguments_begin + idx] = t_new;
    }
  }

  {
    // Today when you call at::empty(device=lazy), the lazy backend decides whether or not to wrap
    // the output in a functional tensor based on TLS.
    // In this code, we're re-entrantly entering functionalization in the same call-stack,
    // so we need to manually fix up TLS as if it hadn't already been called.
    auto curr_tls = c10::impl::tls_local_dispatch_key_set();
    auto tls_reenable_functionalize = c10::impl::PODLocalDispatchKeySet();
    tls_reenable_functionalize.set_included(curr_tls.included_);
    tls_reenable_functionalize.set_excluded(curr_tls.excluded_.remove(c10::DispatchKey::Functionalize));
    c10::impl::ForceDispatchKeyGuard guard_(tls_reenable_functionalize);
    // So, we should probably provide a way to directly call a kernel registered to
    // the `CompositeExplicitAutograd` key.
    // We can't do that today, so this should be a reasonably good proxy
    // (It won't work in cases where an op has both a CompositeExplicitAutograd kernel
    // AND a dedicated meta kernel, but that probably shouldn't ever happen).
    op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::Meta), stack);
  }

  const auto num_returns = schema.returns().size();
  const auto returns_begin = stack->size() - num_returns;
  auto returns = torch::jit::last(stack, num_returns);

  for (const auto idx : c10::irange(num_returns)) {
    const auto& ivalue = returns[idx];
    if (ivalue.isTensor()) {
      const auto& t = ivalue.toTensor();
      if (!t.defined()) continue;
      at::functionalization::impl::sync(t);
      auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isTensorList()) {
      auto tensors = ivalue.toTensorList();
      at::functionalization::impl::sync(tensors);
      auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(tensors));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isOptionalTensorList()) {
      auto opt_tensors = ivalue.toOptionalTensorList();
      at::functionalization::impl::sync(opt_tensors);
      auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(opt_tensors));
      (*stack)[returns_begin + idx] = t_new;
    }
  }
}

} // namespace functionalization
} // namespace at