#include <torch/csrc/autograd/variable.h>

#include <torch/csrc/autograd/InferenceMode.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/functions/tensor.h>
#include <torch/csrc/autograd/generated/Functions.h>
#include <torch/csrc/autograd/generated/ViewFuncs.h>
#include <torch/csrc/autograd/utils/error_messages.h>

#include <ATen/ATen.h>
#include <ATen/FuncTorchTLS.h>
#include <ATen/MemoryOverlap.h>
#include <c10/util/Exception.h>

#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

namespace torch::autograd {

// Returns a ViewFunc with a corresponding view that matches the shape,
// stride, and storage offset of the given tensor.
// NB: On mobile, the as_strided() op and thus the generated AsStridedViewFunc
// may not be available.
static std::unique_ptr<ViewFunc> create_view_func_matching(const Variable& t) {
#ifdef AS_STRIDED_VIEW_FUNC_AVAILABLE
  return std::make_unique<torch::autograd::generated::AsStridedViewFunc>(
      t.sym_sizes(), t.sym_strides(), t.sym_storage_offset());
#else
  return std::make_unique<ErroringViewFunc>("as_strided() not available");
#endif
}

DifferentiableViewMeta::DifferentiableViewMeta(
    at::TensorImpl* self_impl,
    std::optional<ViewInfo> backward_info,
    std::optional<ViewInfo> forward_info,
    bool shared_view_info,
    CreationMeta creation_meta)
    : AutogradMeta(self_impl),
      backward_info_(std::move(backward_info)),
      forward_info_(std::move(forward_info)),
      shared_view_info_(shared_view_info),
      creation_meta_(creation_meta) {
  is_view_ = true;
  if (backward_info_.has_value()) {
    self_impl->set_version_counter(
        impl::version_counter(backward_info_.value().base_));
    attr_version_ = self_impl->version_counter().current_version();
    TORCH_INTERNAL_ASSERT(
        backward_info_.value().base_.unsafeGetTensorImpl() != self_impl);
  }
  if (shared_view_info_) {
    TORCH_INTERNAL_ASSERT(
        backward_info_.has_value(),
        "Shared view info requires a backward view info.");
    TORCH_INTERNAL_ASSERT(
        !forward_info_.has_value(),
        "Shared view info requires forward view info to be empty");
  }
}

// Chain this view info with the new view op between base and tensor
ViewInfo ViewInfo::chain(
    const Variable& base,
    const Variable& tensor,
    std::unique_ptr<ViewFunc> view_func,
    std::function<Variable(const Variable&)> rev_view_func) const {
  // Set `view_func` using the root base as input.
  // `view_func` is used to recover views in backward when either as_strided is
  // not supported or the view function changes the metadata which is not
  // recorded by as_strided.
  // See Note [View + Inplace update on base tensor] and
  // Note [View + Inplace update on view tensor] for more details on how we use
  // this function in backward.
  if (view_func) {
    // both current_view and its parent have a view_func
    if (view_fn_) {
      view_func = std::make_unique<ChainedViewFunc>(
          view_fn_->clone_and_set(), std::move(view_func));

      // assume view_fn_ / rev_view_fn_ always exist together or neither are
      // set
      auto prev_rev_fn = rev_view_fn_;
      rev_view_func = [=](const at::Tensor& root_view) {
        auto temp = rev_view_func(root_view);
        return prev_rev_fn(temp);
      };
    } else {
      // current_view has a view_func but its parent doesn't have one
      if (base.unsafeGetTensorImpl()->support_as_strided()) {
        auto match_base_view_func = create_view_func_matching(base);
        view_func = std::make_unique<ChainedViewFunc>(
            std::move(match_base_view_func), std::move(view_func));

        // assume view_fn_ / rev_view_fn_ always exist together or neither are
        // set
        const auto& root_base = base._base();
        auto root_base_size = root_base.sym_sizes().vec();
        auto root_base_stride = root_base.sym_strides().vec();
        auto root_base_storage_offset = root_base.sym_storage_offset();
        rev_view_func = [=](const at::Tensor& root_view) {
          auto temp = rev_view_func(root_view);
          return temp.as_strided_symint(
              root_base_size, root_base_stride, root_base_storage_offset);
        };
      } else {
        // This case should be relatively rare: the parent view doesn't have a
        // view_func() AND as_strided() isn't supported; there's no obvious way
        // to chain the two views.
        auto error_msg =
            ("Attempted to chain views when the parent view has no view_func() and "
             "does not support as_strided(). This is not supported.");
        view_func = std::make_unique<ErroringViewFunc>(error_msg);
        rev_view_func = [=](const at::Tensor& root_view) {
          TORCH_CHECK(false, error_msg);
          return root_view;
        };
      }
    }
  } else if (view_fn_) {
    // current_view doesn't have a view_func but its parent has one
    auto match_tensor_view_func = create_view_func_matching(tensor);
    view_func = std::make_unique<ChainedViewFunc>(
        view_fn_->clone_and_set(), std::move(match_tensor_view_func));

    // assume view_fn_ / rev_view_fn_ always exist together or neither are set
    auto prev_rev_view_fn = rev_view_fn_;
    auto base_size = base.sym_sizes().vec();
    auto base_stride = base.sym_strides().vec();
    auto base_storage_offset = base.sym_storage_offset();
    rev_view_func = [=](const at::Tensor& root_view) {
      auto temp = root_view.as_strided_symint(
          base_size, base_stride, base_storage_offset);
      return prev_rev_view_fn(temp);
    };
  }

  return ViewInfo(base_, std::move(view_func), std::move(rev_view_func));
}
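// Illustrative sketch of what chaining buys us (hypothetical user code, not a
// call site of this file): when a view is taken of a view, the composed
// view_func maps the *root* base to the leaf view, and rev_view_func goes the
// other way, back to a tensor with the root base's geometry. Roughly:
//
//   // auto base = at::zeros({4, 4});
//   // auto v1 = base.select(0, 0);   // ViewInfo{base, select view_func}
//   // auto v2 = v1.narrow(0, 0, 2);  // ViewInfo::chain composes both ops,
//   //                                // so the stored view_func(base)
//   //                                // reproduces v2 during backward, and
//   //                                // rev_view_func undoes that mapping.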
namespace {

at::Tensor singleton_undefined_tensor;

struct ConcreteAutogradMetaFactory : public c10::impl::AutogradMetaFactory {
  std::unique_ptr<c10::AutogradMetaInterface> make() const override {
    return std::make_unique<AutogradMeta>();
  }
  const at::Tensor& undefined_tensor() const override {
    return singleton_undefined_tensor;
  }
};

ConcreteAutogradMetaFactory meta_factory;

static c10::impl::AutogradMetaFactoryRegisterer meta_factory_registerer(
    &meta_factory);

} // namespace

namespace impl {

AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) {
  TORCH_CHECK(
      self.defined(),
      "cannot call materialize_autograd_meta() on undefined tensor");
  auto p = self.unsafeGetTensorImpl();
  if (!p->autograd_meta()) {
    p->set_autograd_meta(std::make_unique<AutogradMeta>());
  }
  return get_autograd_meta(self);
}

static void update_tensor_hooks_on_new_gradfn(
    const at::TensorBase& self,
    const std::shared_ptr<torch::autograd::Node>& old_fn,
    const std::shared_ptr<torch::autograd::Node>& new_fn) {
  // This function is called whenever the grad_fn of the tensor is
  // changed. We assume here that new_fn does not yet have hooks of
  // its own.
  //
  // This function does two things:
  // (1) reset the list when grad_fn is updated, so new hooks don't
  //     get erroneously registered to the old grad_fn.
  //     Note that the old cpp_hooks_list_ is still kept alive by the
  //     old grad_fn so hooks registered to the older version of the tensor
  //     will continue to be active.
  // (2) If there is a retains_grad hook registered, move that from the
  //     old cpp_hooks_list_ to the new one
  const auto& meta = impl::get_autograd_meta(self);
  TORCH_INTERNAL_ASSERT(meta);
  TORCH_INTERNAL_ASSERT(new_fn);
  meta->cpp_hooks_list_ = nullptr;
  const c10::impl::PyInterpreter* interp =
      self.unsafeGetTensorImpl()->pyobj_slot()->pyobj_interpreter();
  if (interp) {
    (*interp)->reset_backward_hooks(self.unsafeGetTensorImpl());
  }
  if (self.retains_grad()) {
    TORCH_INTERNAL_ASSERT(old_fn);
    auto out = old_fn->pop_retains_grad_hook(self.output_nr());
    TORCH_INTERNAL_ASSERT(out != nullptr);
    new_fn->add_retains_grad_hook(std::move(out), self.output_nr());
  }
}
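// Illustrative sketch (assumed behavior, not a call site of this file): a
// freshly created tensor carries no AutogradMeta at all; the factory above
// (via TensorImpl) or materialize_autograd_meta() creates it lazily the first
// time autograd state is attached:
//
//   // at::Tensor t = at::zeros({2, 2});
//   // // t.unsafeGetTensorImpl()->autograd_meta() is nullptr here
//   // t.set_requires_grad(true);  // autograd meta gets materialized
//   // // impl::get_autograd_meta(t) now returns a non-null AutogradMeta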
void rebase_history(const Variable& self, Edge gradient_edge) {
  TORCH_INTERNAL_ASSERT(gradient_edge.function != nullptr);
  const auto& meta = impl::get_autograd_meta(self);
  auto old_fn = meta != nullptr ? meta->grad_fn_ : nullptr;
  auto diff_view_meta = get_view_autograd_meta(self);
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    // See NOTE [ View + Inplace detection ]
    auto creation_meta = diff_view_meta->get_creation_meta();
    // Do not use handle_view_on_rebase here: check_inplace should have been
    // called before this point and would already have raised an error for any
    // view that cannot be rebased.
    TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::DEFAULT);
    TORCH_INTERNAL_ASSERT(gradient_edge.input_nr == 0);
    TORCH_INTERNAL_ASSERT(gradient_edge.function);
    TORCH_CHECK(
        gradient_edge.function->num_inputs() == 1,
        "Functions which modify views in-place must return a single Variable");
    const auto& view_info = diff_view_meta->get_backward_view();
    diff_view_meta->output_nr_ = gradient_edge.input_nr;
    auto copy_slices = std::make_shared<CopySlices>(
        view_info.base_,
        at::TensorGeometry(self),
        view_info.has_view_fn() ? view_info.view_fn().clone_and_set() : nullptr,
        std::move(gradient_edge.function));
    if (self.requires_grad()) {
      // If self did not previously require grad, there are no hooks to move
      torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
          view_info.base_, view_info.base_.grad_fn(), copy_slices);
    }
    set_gradient_edge(view_info.base_, {std::move(copy_slices), 0});
    self.grad_fn(); // trigger an update to the view's grad_fn
    return;
  }
  set_gradient_edge(self, std::move(gradient_edge));
  // Pass both self and its grad_fn to avoid calling into grad_fn reentrantly
  torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
      self, old_fn, self.grad_fn());
}

void create_cpp_hook(const at::TensorBase& self, bool is_retains_grad_hook) {
  const auto& fn = self.grad_fn();
  std::shared_ptr<hooks_list>& list =
      materialize_autograd_meta(self)->cpp_hooks_list_;
  list = std::make_shared<hooks_list>();
  auto hook_ptr =
      std::make_unique<CppFunctionTensorPreHook>(list, self.output_nr());
  // NB: we could potentially only update hooks_ if !fn, but it shouldn't
  // matter and this was the way before, so we keep it like this for now.
  clear_hooks(self);
  add_hook(self, std::make_unique<CppFunctionTensorPreHook>(list, 0));
  if (fn) {
    fn->add_tensor_pre_hook(std::move(hook_ptr));
  }
}

void set_grad_accumulator(
    const Variable& self,
    std::weak_ptr<Node> grad_accumulator) {
  materialize_autograd_meta(self)->grad_accumulator_ =
      std::move(grad_accumulator);
}

std::shared_ptr<Node> try_get_grad_accumulator(const Variable& self) {
  if (get_autograd_meta(self)) {
    return get_autograd_meta(self)->grad_accumulator_.lock();
  } else {
    return nullptr;
  }
}

std::shared_ptr<Node> grad_accumulator(const Variable& self) {
  auto autograd_meta = get_autograd_meta(self);
  if (!autograd_meta) {
    return nullptr;
  }
  if (autograd_meta->grad_fn_) {
    throw std::logic_error(
        "grad_accumulator() should be only called on leaf Variables");
  }
  if (!autograd_meta->requires_grad_) {
    return nullptr;
  }

  std::lock_guard<std::mutex> lock(autograd_meta->mutex_);

  auto result = autograd_meta->grad_accumulator_.lock();
  if (result)
    return result;

  c10::raw::intrusive_ptr::incref(self.unsafeGetTensorImpl());
  auto intrusive_from_this =
      c10::intrusive_ptr<at::TensorImpl>::reclaim(self.unsafeGetTensorImpl());
  result = std::make_shared<AccumulateGrad>(
      Variable(std::move(intrusive_from_this)));
  autograd_meta->grad_accumulator_ = result;
  return result;
}
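// Illustrative sketch (hypothetical snippet, only meant to show the caching
// contract of grad_accumulator() above): for a leaf that requires grad, the
// AccumulateGrad node is created lazily and cached through a weak_ptr, so
// repeated lookups return the same node as long as something keeps it alive
// (typically the graph built by a forward pass):
//
//   // auto leaf = at::zeros({2, 2}).requires_grad_(true);
//   // auto acc1 = impl::grad_accumulator(leaf);
//   // auto acc2 = impl::grad_accumulator(leaf);
//   // TORCH_INTERNAL_ASSERT(acc1 == acc2);  // same cached AccumulateGrad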
Edge gradient_edge(const Variable& self) {
  // If grad_fn is null (as is the case for a leaf node), we instead
  // interpret the gradient function to be a gradient accumulator, which will
  // accumulate its inputs into the grad property of the variable. These
  // nodes get suppressed in some situations, see "suppress gradient
  // accumulation" below. Note that only variables which have `requires_grad =
  // True` can have gradient accumulators.
  if (const auto& gradient = self.grad_fn()) {
    return Edge(gradient, self.output_nr());
  } else {
    return Edge(grad_accumulator(self), 0);
  }
}

void set_gradient_edge(const Variable& self, Edge edge) {
  auto* meta = materialize_autograd_meta(self);
  meta->grad_fn_ = std::move(edge.function);
  meta->output_nr_ = edge.input_nr;
  // For views, make sure this new grad_fn_ is not overwritten unless it is
  // necessary in the VariableHooks::grad_fn below.
  // This logic is only relevant for custom autograd Functions for which
  // multiple operations can happen on a given Tensor before its gradient edge
  // is set when exiting the custom Function.
  auto diff_view_meta = get_view_autograd_meta(self);
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    diff_view_meta->set_attr_version(self._version());
  }
}

Node* grad_fn_unsafe(const Variable& self) {
  if (get_autograd_meta(self)) {
    return get_autograd_meta(self)->grad_fn_.get();
  } else {
    return nullptr;
  }
}

// Versions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

void set_version_counter(
    const Variable& self,
    const c10::VariableVersion& version_counter) {
  TORCH_CHECK(
      self.defined(), "cannot call set_version_counter() on undefined tensor");
  self.unsafeGetTensorImpl()->set_version_counter(version_counter);
}

void bump_version(const Variable& self) {
  TORCH_CHECK(self.defined(), "cannot call bump_version() on undefined tensor");
  self.unsafeGetTensorImpl()->bump_version();
}

const c10::VariableVersion& version_counter(const Variable& self) {
  TORCH_CHECK(
      self.defined(), "cannot call version_counter() on undefined tensor");
  return self.unsafeGetTensorImpl()->version_counter();
}

// Hooks
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

void add_hook(
    const at::TensorBase& self,
    std::unique_ptr<FunctionPreHook> hook) {
  AutogradMeta* meta = materialize_autograd_meta(self);
  TORCH_INTERNAL_ASSERT(meta->hooks_.empty());
  meta->hooks_.push_back(std::move(hook));
}

std::vector<std::unique_ptr<FunctionPreHook>>& hooks(const Variable& self) {
  TORCH_INTERNAL_ASSERT(get_autograd_meta(self));
  return get_autograd_meta(self)->hooks_;
}

void clear_hooks(const at::TensorBase& self) {
  // This is a little goofy, but usually this should be a no-op
  materialize_autograd_meta(self)->hooks_.clear();
}

void set_post_acc_grad_hooks(
    const at::TensorBase& self,
    std::unique_ptr<PostAccumulateGradHook> dict) {
  AutogradMeta* meta = materialize_autograd_meta(self);
  meta->post_acc_grad_hooks_ = std::move(dict);
}

std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
    const Variable& self) {
  TORCH_INTERNAL_ASSERT(get_autograd_meta(self));
  return get_autograd_meta(self)->post_acc_grad_hooks_;
}

void set_name(const Variable& self, const std::string& name) {
  materialize_autograd_meta(self)->name_ = name;
}

// Miscellaneous
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

AutogradMeta* get_autograd_meta(const at::TensorBase& self) {
  // NB: could return nullptr
  TORCH_CHECK(
      self.defined(), "cannot call get_autograd_meta() on undefined tensor");
  return static_cast<AutogradMeta*>(
      self.unsafeGetTensorImpl()->autograd_meta());
}

DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase& self) {
  // NB: return nullptr if self is not a view
  AutogradMeta* meta = get_autograd_meta(self);
  if (meta && meta->is_view_) {
    return static_cast<DifferentiableViewMeta*>(meta);
  } else {
    return nullptr;
  }
}

} // namespace impl

using at::Tensor;

VariableHooks variableHooks;
at::impl::VariableHooksRegisterer registerVariableHooks(&variableHooks);
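// Illustrative sketch of how the version counter above is used by autograd
// (assumed flow; the actual check lives in SavedVariable): every in-place op
// bumps the version, and backward compares the version recorded when a tensor
// was saved against the current one, raising the familiar "modified by an
// inplace operation" error on mismatch:
//
//   // auto a = at::zeros({2, 2}).requires_grad_(true);
//   // auto b = a.exp();    // exp's backward saves its output b
//   // auto v0 = b._version();
//   // b.mul_(2);           // bump_version(b): b._version() == v0 + 1
//   // b.sum().backward();  // would throw: saved version != current version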
at::TensorBase VariableHooks::variable_data(const at::TensorBase& self) const {
  TORCH_CHECK(
      self.defined(), "cannot call variable_data() on undefined tensor");
  auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach(
      /*version_counter=*/0,
      /*allow_tensor_metadata_change=*/false);
  self_impl_copy->set_autograd_meta(nullptr);
  return at::Tensor(self_impl_copy);
}

at::TensorBase VariableHooks::tensor_data(const at::TensorBase& self) const {
  TORCH_CHECK(self.defined(), "cannot call tensor_data() on undefined tensor");
  auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach(
      /*version_counter=*/self.unsafeGetTensorImpl()->version_counter(),
      /*allow_tensor_metadata_change=*/
      self.unsafeGetTensorImpl()->allow_tensor_metadata_change());
  return at::Tensor(self_impl_copy);
}

bool VariableHooks::is_leaf(const at::TensorBase& self) const {
  if (impl::get_autograd_meta(self)) {
    return impl::get_autograd_meta(self)->grad_fn_ == nullptr;
  } else {
    return true;
  }
}

int64_t VariableHooks::output_nr(const at::TensorBase& self) const {
  if (impl::get_autograd_meta(self)) {
    return impl::get_autograd_meta(self)->output_nr_;
  } else {
    return 0;
  }
}

void VariableHooks::set_data(
    const at::TensorBase& self_base,
    const at::TensorBase& new_data_base) const {
  at::OptionalTensorRef self_ref(self_base);
  const Tensor& self = *self_ref;
  at::OptionalTensorRef new_data_ref(new_data_base);
  const Tensor& new_data = *new_data_ref;

  // `var.set_data(new_data)` shallow-copies all non-autograd TensorImpl fields
  // from `new_data` to `var`. It requires that `new_data` and `var` have
  // compatible tensor type.
  TORCH_CHECK(
      _has_compatible_shallow_copy_type(self, new_data),
      "Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have incompatible tensor type.");

  TORCH_CHECK(
      !self.requires_grad() ||
          isDifferentiableType(at::typeMetaToScalarType(new_data.dtype())),
      "data set to a tensor that requires gradients must be floating point or complex dtype");

  // Resets gradient accumulator if metadata is out of date
  AutogradMeta* autograd_meta = impl::get_autograd_meta(self);
  if (autograd_meta) {
    std::lock_guard<std::mutex> lock(autograd_meta->mutex_);
    auto prior_accumulator = autograd_meta->grad_accumulator_.lock();
    if (prior_accumulator) {
      const auto prior_device = prior_accumulator->input_metadata(0).device();
      const auto new_device = new_data.device();

      if (!new_data.options().type_equal(self.options()) ||
          prior_device != new_device) {
        autograd_meta->grad_accumulator_.reset();
      }
    }
  }

  // Version counter is not shared when we replace a `Variable`'s tensor data
  // by calling `set_data(...)`. The original version of the `Variable` is
  // always preserved. See NOTE [ Version Counter Sharing ] for details.
  //
  // `var.set_data(new_data)` always ignores `var`'s
  // `allow_tensor_metadata_change_`, because users need this API as an escape
  // hatch for changing a tensor's metadata regardless of its
  // `allow_tensor_metadata_change_` value, and the users are responsible for
  // ensuring this is the behavior they want.
  self.unsafeGetTensorImpl()->shallow_copy_from(new_data.getIntrusivePtr());
}
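// Illustrative sketch of the difference between the two shallow copies above
// (assumed usage; both are reachable as `variable_data()` / `tensor_data()`
// on at::Tensor):
//
//   // auto t = at::zeros({2, 2}).requires_grad_(true);
//   // auto v = t.variable_data();  // shares storage, autograd meta stripped,
//   //                              // fresh version counter, metadata frozen
//   // auto d = t.tensor_data();    // shares storage *and* keeps t's version
//   //                              // counter and metadata-change flag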
at::TensorBase VariableHooks::data(const at::TensorBase& self) const {
  return self.variable_data();
}

int64_t VariableHooks::_version(const at::TensorBase& self) const {
  return self.unsafeGetTensorImpl()->version_counter().current_version();
}

void VariableHooks::retain_grad(const at::TensorBase& self) const {
  TORCH_CHECK(
      self.requires_grad(),
      "can't retain_grad on Tensor that has requires_grad=False");

  // temporary hack to improve functorch UX.
  const auto& functorch_tls = at::functorch::functorchTLSAccessor();
  if (functorch_tls) {
    functorch_tls->checkSupportsRetainGrad();
  }

  if (self.is_leaf()) { // no-op for leaves
    return;
  }
  if (impl::get_autograd_meta(self)->retains_grad_) {
    return;
  }
  c10::weak_intrusive_ptr<c10::TensorImpl> weak_self(self.getIntrusivePtr());

  auto retain_grad_hook = [weak_self](const at::TensorBase& grad_base) {
    at::Tensor grad{grad_base};
    if (!weak_self.expired() && grad.defined()) {
      auto var = weak_self.lock();
      if (!var->grad().defined()) {
        if (grad.is_sparse()) {
          var->mutable_grad() = grad.clone();
        } else {
          var->mutable_grad() = grad.clone(at::MemoryFormat::Contiguous);
        }
      } else {
        var->mutable_grad() = var->grad() + grad;
      }
    }
    return at::TensorBase{};
  };

  const auto& fn = self.grad_fn();
  fn->add_retains_grad_hook(
      std::make_unique<CppFunctionSingleTensorPreHook>(
          std::move(retain_grad_hook), self.output_nr()),
      self.output_nr());
  impl::get_autograd_meta(self)->retains_grad_ = true;
}

bool VariableHooks::retains_grad(const at::TensorBase& self) const {
  if (impl::get_autograd_meta(self)) {
    return impl::get_autograd_meta(self)->retains_grad_;
  } else {
    return false;
  }
}

void VariableHooks::_backward(
    const Tensor& self,
    at::TensorList inputs,
    const std::optional<Tensor>& gradient,
    std::optional<bool> keep_graph,
    bool create_graph) const {
  // TODO: torch::autograd::backward should take the std::optional<Tensor>
  // gradient directly instead of us having to unwrap it to Tensor _gradient
  // here.
  Tensor _gradient = gradient.has_value() ? *gradient : Tensor();
  std::vector<torch::autograd::Variable> input_vars(
      inputs.begin(), inputs.end());
  torch::autograd::backward(
      {self}, {std::move(_gradient)}, keep_graph, create_graph, input_vars);
}
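// Illustrative sketch (assumed usage of the hook installed above): retaining
// the gradient of a non-leaf makes the pre-hook accumulate the incoming
// gradient into the tensor's `grad()` before it flows further back:
//
//   // auto a = at::ones({2, 2}).requires_grad_(true);
//   // auto b = a * 2;    // non-leaf: its grad is normally discarded
//   // b.retain_grad();   // installs the retains_grad pre-hook
//   // b.sum().backward();
//   // // b.grad() is now defined (all ones); a.grad() is filled as usual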
void VariableHooks::requires_grad_(
    const at::TensorBase& self,
    bool _requires_grad) const {
  if (!self.is_leaf() && !_requires_grad) {
    throw std::runtime_error(
        autograd::utils::requires_grad_leaf_error(_requires_grad));
  }
  self.set_requires_grad(_requires_grad);
}

// Backward View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

bool VariableHooks::is_view(const at::TensorBase& self) const {
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  if (diff_view_meta) {
    return diff_view_meta->has_bw_view();
  } else {
    return false;
  }
}

const at::TensorBase& VariableHooks::base(const at::TensorBase& self) const {
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  if (diff_view_meta) {
    TORCH_CHECK(
        diff_view_meta->has_bw_view(),
        "Can't get base of non-backward view Tensor");
    return diff_view_meta->get_backward_view().base_;
  } else {
    throw std::runtime_error("Can't get base of non-view Tensor");
  }
}

namespace {
std::string singleton_string;
}

const std::string& VariableHooks::name(const at::TensorBase& self) const {
  TORCH_CHECK(self.defined(), "cannot call name() on undefined tensor");
  if (torch::autograd::impl::get_autograd_meta(self)) {
    return torch::autograd::impl::get_autograd_meta(self)->name_;
  } else {
    return singleton_string;
  }
}

namespace {
std::shared_ptr<torch::autograd::Node> singleton_shared_ptr;
}

const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
    const at::TensorBase& self) const {
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    // See NOTE [ View + Inplace detection ]
    std::lock_guard<std::mutex> lock(diff_view_meta->mutex_);
    auto& view_info = diff_view_meta->get_backward_view();
    if (!diff_view_meta->grad_fn_ && !view_info.base_.requires_grad()) {
      return diff_view_meta->grad_fn_;
    }
    auto current_version = self._version();
    auto old_fn = diff_view_meta->grad_fn_;
    if (diff_view_meta->get_attr_version() != current_version) {
      // This is an indirect rebase_history due to another view or the base
      // being modified inplace
      handle_view_on_rebase(diff_view_meta, /* indirect */ true);
      TORCH_INTERNAL_ASSERT(diff_view_meta->output_nr_ == 0);
      // Note [View + Inplace update for view tensor]
      // An inplace update happened on Tensor `self` (which is a view).
      // For example:
      //   view_1 = view_op_1(diff_view_meta->base_)
      //   view_2 = view_op_2(view_1)
      //   ...
      //   self = view_op_n(view_n-1)
      //   self = inplace_op(self)
      //
      // For CPU/CUDA backends, we employ one AsStridedBackward0 Node to
      // represent the chain of view backward ops for efficiency.
      //
      // However, the XLA backend does not have full support for
      // AsStridedBackward0; we instead run a full forward pass with a tensor
      // that requires gradient to get a proper grad_fn setup, then save it to
      // DifferentiableViewMeta for future use. This is fairly cheap for the
      // XLA lazy tensor approach (but would be really expensive for CPU/CUDA).
      // An XLA tensor only runs through the VariableType dispatch and lowers
      // the forward pass to an XLA HLO graph; we then take grad_fn and never
      // materialize the tensor content. So we only construct the graph but do
      // not execute it, which is a fairly cheap operation to do.
      //
      // See Note [View + Inplace update for base tensor] for what we do to the
      // base tensor when an in-place operation happens.
      //
      // TODO: Potentially the following logic can be replaced by special logic
      // in VariableType_x.cpp that would provide a way to recreate the grad_fn
      // chain.
      if (view_info.has_view_fn()) {
        auto& view_fn = view_info.view_fn();
        Tensor diff_view;
        {
          // We can reach this path with grad mode disabled, e.g. when called
          // from the engine
          AutoGradMode grad_mode(true);
          diff_view = view_fn(view_info.base_);
        }
        diff_view_meta->grad_fn_ = diff_view.grad_fn();
      } else {
        auto fn =
            std::make_shared<torch::autograd::generated::AsStridedBackward0>();
        fn->self_geometry = at::TensorGeometry(view_info.base_);
        fn->size = self.sym_sizes().vec();
        fn->stride = self.sym_strides().vec();
        fn->storage_offset = self.sym_storage_offset();
        fn->set_next_edges(
            torch::autograd::collect_next_edges(view_info.base_));
        fn->add_input_metadata(
            view_info.base_.options(),
            self.sym_sizes(), // Note: sizes(), not base_.sizes(), is
                              // intentional
            self.unsafeGetTensorImpl()->is_python_dispatch(),
            self.is_nested());
        diff_view_meta->grad_fn_ = std::move(fn);
      }
      diff_view_meta->set_attr_version(current_version);

      torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
          self, old_fn, diff_view_meta->grad_fn_);
    }
    return diff_view_meta->grad_fn_;
  }

  if (torch::autograd::impl::get_autograd_meta(self)) {
    return torch::autograd::impl::get_autograd_meta(self)->grad_fn_;
  } else {
    return singleton_shared_ptr;
  }
}

void VariableHooks::remove_hook(const at::TensorBase& self, unsigned pos)
    const {
  auto& list =
      torch::autograd::impl::materialize_autograd_meta(self)->cpp_hooks_list_;
  TORCH_CHECK(
      list && pos < list->size(), "Invalid index, no hook at position ", pos);
  // Hook will be ignored
  (*list)[pos] = nullptr;
}

unsigned VariableHooks::_register_hook(
    const at::TensorBase& self,
    std::function<at::TensorBase(const at::TensorBase&)> hook) const {
  TORCH_CHECK(
      self.requires_grad(),
      "cannot register a hook on a variable that "
      "doesn't require gradient");
  // NB: materialize_autograd_meta unnecessary due to requires grad check
  auto& list = torch::autograd::impl::get_autograd_meta(self)->cpp_hooks_list_;
  if (!list) {
    torch::autograd::impl::create_cpp_hook(
        self, /*is_retains_grad_hooks=*/false);
  }
  unsigned idx = list->size();
  list->push_back(hook);
  return idx;
}
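// Illustrative sketch of the situation handle_view_on_rebase() below guards
// against (assumed user code; the exact message depends on the CreationMeta):
//
//   // auto base = at::ones({2, 2}).requires_grad_(true);
//   // at::Tensor view;
//   // {
//   //   at::NoGradGuard no_grad;
//   //   view = base.view({4});  // view gets CreationMeta::NO_GRAD_MODE
//   // }
//   // view.add_(1);  // rebasing this view's history is ambiguous, so the
//   //                // NO_GRAD_MODE branch below raises an error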
void handle_view_on_rebase(
    DifferentiableViewMeta* diff_view_meta,
    bool indirect) {
  /// See NOTE [ View + Inplace detection ] for justification of the logic
  /// below
  auto creation_meta = diff_view_meta->get_creation_meta();
  if (creation_meta != CreationMeta::DEFAULT) {
    auto grad_fn = diff_view_meta->grad_fn_.get();
    std::string msg;
    std::string modified_obj;
    // Create the header for the error message.
    if (indirect) {
      modified_obj = "its base or another view of its base has been";
    } else {
      modified_obj = "is being";
    }

    if (creation_meta == CreationMeta::INFERENCE_MODE ||
        creation_meta == CreationMeta::NO_GRAD_MODE || !grad_fn) {
      std::string prefix;
      if (grad_fn) {
        prefix = c10::str(
            "Output ",
            diff_view_meta->output_nr_,
            " of ",
            grad_fn->name(),
            " is a view of a view which was created in");
      } else {
        prefix = "A view was created in";
      }
      if (creation_meta == CreationMeta::INFERENCE_MODE) {
        msg = c10::str(
            prefix,
            " inference mode and ",
            modified_obj,
            " modified inplace in normal mode.");
      } else {
        // creation_meta is not necessarily CreationMeta::NO_GRAD_MODE here,
        // e.g. CreationMeta::IN_CUSTOM_FUNCTION is possible, but we know that
        // if there is no grad_fn, the view was performed in no-grad mode
        msg = c10::str(
            prefix,
            " no_grad mode and ",
            modified_obj,
            " modified inplace with grad mode enabled.");
      }
    } else {
      msg = c10::str(
          "Output ",
          diff_view_meta->output_nr_,
          " of ",
          grad_fn->name(),
          " is a view and ",
          modified_obj,
          " modified inplace.");
    }

    if (creation_meta == CreationMeta::MULTI_OUTPUT_NODE) {
      msg = c10::str(
          msg,
          " This view is the output of a function that returns multiple views. Such functions do not"
          " allow the output views to be modified inplace. You should replace the inplace operation by an"
          " out-of-place one.");
    } else if (creation_meta == CreationMeta::NO_GRAD_MODE) {
      msg = c10::str(
          msg,
          " Given that this use case is ambiguous and error-prone, it is forbidden."
          " You can clarify your code by moving both the view and the inplace"
          " either inside the no_grad block (if you don't want the inplace to be tracked)"
          " or outside of it (if you want the inplace to be tracked).");
    } else if (creation_meta == CreationMeta::INFERENCE_MODE) {
      msg = c10::str(
          msg,
          " Given that this use case is ambiguous and error-prone, it is forbidden."
          " You can clarify your code by moving both the view and the inplace"
          " either inside the inference_mode block (if you don't want the inplace to be tracked)"
          " or outside of it (if you want the inplace to be tracked).");
    } else if (creation_meta == CreationMeta::IN_CUSTOM_FUNCTION) {
      msg = c10::str(
          msg,
          " This view was created inside a custom Function (or because an input was returned as-is) and the"
          " autograd logic to handle view+inplace would override the custom backward associated with the custom"
          " Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by"
          " cloning the output of the custom Function.");
    } else {
      TORCH_INTERNAL_ASSERT(false, "Invalid CreationMeta state");
    }

    TORCH_CHECK(false, msg);
  }
}
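// The ChainedViewFunc methods below compose two ViewFuncs: operator() applies
// `first` (root base -> intermediate view) and then `second` (intermediate ->
// leaf view), and the saved state exposed by get_symints()/get_tensors() is
// simply `first`'s state followed by `second`'s. clone_and_set() relies on
// that ordering and splits any replacement state at first->num_symints() /
// first->num_tensors(). A rough picture (hypothetical single-op view funcs):
//
//   // chained(base) == (*narrow_view_fn)((*select_view_fn)(base));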
std::vector<c10::SymInt> ChainedViewFunc::get_symints() const {
  auto symints = first->get_symints();
  auto second_symints = second->get_symints();
  symints.reserve(symints.size() + second_symints.size());
  symints.insert(
      symints.end(),
      std::make_move_iterator(second_symints.begin()),
      std::make_move_iterator(second_symints.end()));
  return symints;
}

std::vector<at::Tensor> ChainedViewFunc::get_tensors() const {
  auto tensors = first->get_tensors();
  auto second_tensors = second->get_tensors();
  tensors.reserve(tensors.size() + second_tensors.size());
  tensors.insert(
      tensors.end(),
      std::make_move_iterator(second_tensors.begin()),
      std::make_move_iterator(second_tensors.end()));
  return tensors;
}

at::Tensor ChainedViewFunc::operator()(const at::Tensor& input_base) const {
  return (*second)((*first)(input_base));
}

std::unique_ptr<ViewFunc> ChainedViewFunc::clone_and_set(
    std::optional<std::vector<c10::SymInt>> symints,
    std::optional<std::vector<at::Tensor>> tensors) const {
  std::optional<std::vector<c10::SymInt>> first_symints;
  std::optional<std::vector<c10::SymInt>> second_symints;
  if (symints.has_value()) {
    TORCH_INTERNAL_ASSERT(symints->size() == num_symints());
    first_symints = std::vector<c10::SymInt>(
        symints->begin(), symints->begin() + first->num_symints());
    second_symints = std::vector<c10::SymInt>(
        symints->begin() + first->num_symints(), symints->end());
  }

  std::optional<std::vector<at::Tensor>> first_tensors;
  std::optional<std::vector<at::Tensor>> second_tensors;
  if (tensors.has_value()) {
    TORCH_INTERNAL_ASSERT(tensors->size() == num_tensors());
    first_tensors = std::vector<at::Tensor>(
        tensors->begin(), tensors->begin() + first->num_tensors());
    second_tensors = std::vector<at::Tensor>(
        tensors->begin() + first->num_tensors(), tensors->end());
  }

  return std::make_unique<ChainedViewFunc>(
      first->clone_and_set(first_symints, first_tensors),
      second->clone_and_set(second_symints, second_tensors));
}

} // namespace torch::autograd