
Searched defs: _mask_mod_signature (Results 1–1 of 1), sorted by relevance

/aosp_15_r20/external/pytorch/torch/nn/attention/flex_attention.py
39 _mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] (variable)
273 mask_mod: _mask_mod_signature,
307 mask_mod: Optional[_mask_mod_signature] = None,
641 def or_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature:
655 def and_masks(*mask_mods: _mask_mod_signature) -> _mask_mod_signature:
708 mod_fn: Union[_score_mod_signature, _mask_mod_signature], (argument)
786 mask_mod: _mask_mod_signature,
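In FlexAttention, a mask_mod takes four arguments (batch index b, head index h, query position q_idx, key/value position kv_idx) and returns whether that query may attend to that key. That is why _mask_mod_signature is a Callable of four Tensors returning a Tensor. The or_masks and and_masks helpers listed above combine several mask_mods into one. The following is a minimal pure-Python sketch of that contract, not the library's code. It assumes plain ints and bools in place of torch.Tensor so it runs without PyTorch. The names MaskMod, WINDOW, causal, sliding_window, or_masks_sketch, and_masks_sketch, and render are hypothetical and exist only for illustration.

```python
from typing import Callable, List

# Stand-in for _mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor].
# Assumption: plain ints in and a plain bool out replace torch.Tensor so the
# sketch runs without PyTorch. Argument order is (b, h, q_idx, kv_idx).
MaskMod = Callable[[int, int, int, int], bool]

WINDOW = 2  # hypothetical sliding-window width used only in this example


def causal(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
    # A query position may attend to itself and to earlier key positions.
    return q_idx >= kv_idx


def sliding_window(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
    # A query position may attend to keys at most WINDOW positions away.
    return abs(q_idx - kv_idx) <= WINDOW


def _check_callables(mask_mods) -> None:
    # Reject anything that cannot be called like a mask_mod.
    if not all(callable(m) for m in mask_mods):
        raise TypeError("all inputs must be callable mask_mods")


def or_masks_sketch(*mask_mods: MaskMod) -> MaskMod:
    # Union: a (q_idx, kv_idx) pair is allowed if ANY input mask allows it.
    _check_callables(mask_mods)

    def or_mask(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
        return any(m(b, h, q_idx, kv_idx) for m in mask_mods)

    return or_mask


def and_masks_sketch(*mask_mods: MaskMod) -> MaskMod:
    # Intersection: a pair is allowed only if EVERY input mask allows it.
    _check_callables(mask_mods)

    def and_mask(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
        return all(m(b, h, q_idx, kv_idx) for m in mask_mods)

    return and_mask


def render(mask_mod: MaskMod, q_len: int, kv_len: int) -> List[str]:
    # Evaluate the mask for batch 0, head 0 over every (q_idx, kv_idx) pair.
    # Each string is one query row; "1" means attention is allowed.
    return [
        "".join("1" if mask_mod(0, 0, q, kv) else "0" for kv in range(kv_len))
        for q in range(q_len)
    ]


if __name__ == "__main__":
    # Causal AND sliding window gives a local causal pattern.
    for row in render(and_masks_sketch(causal, sliding_window), 4, 4):
        print(row)
```

Running the script prints 1000, 1100, 1110, 0111: each query sees itself and at most two earlier positions. Because each combinator returns another callable with the same four-argument shape, combined masks can be nested further. The real helpers at lines 641 and 655 follow the same closure pattern but operate on Tensors.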