# mypy: allow-untyped-defs
"""
This closely follows the implementation in NumPyro (https://github.com/pyro-ppl/numpyro).

Original copyright notice:

# Copyright: Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
"""

import math

import torch
from torch.distributions import Beta, constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all


__all__ = ["LKJCholesky"]


class LKJCholesky(Distribution):
    r"""
    LKJ distribution for lower Cholesky factor of correlation matrices.
    The distribution is controlled by the ``concentration`` parameter :math:`\eta`
    to make the probability of the correlation matrix :math:`M` generated from
    a Cholesky factor proportional to :math:`\det(M)^{\eta - 1}`. Because of that,
    when ``concentration == 1``, we have a uniform distribution over Cholesky
    factors of correlation matrices::

        L ~ LKJCholesky(dim, concentration)
        X = L @ L' ~ LKJCorr(dim, concentration)

    Note that this distribution samples the Cholesky factor of correlation
    matrices and not the correlation matrices themselves and thereby differs
    slightly from the derivations in [1] for the `LKJCorr` distribution. For
    sampling, this uses the Onion method from [1] Section 3.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> l = LKJCholesky(3, 0.5)
        >>> l.sample()  # l @ l.T is a sample of a 3x3 correlation matrix
        tensor([[ 1.0000,  0.0000,  0.0000],
                [ 0.3516,  0.9361,  0.0000],
                [-0.1899,  0.4748,  0.8593]])

    Args:
        dim (int): dimension of the matrices
        concentration (float or Tensor): concentration/shape parameter of the
            distribution (often referred to as eta)

    **References**

    [1] `Generating random correlation matrices based on vines and extended onion method` (2009),
    Daniel Lewandowski, Dorota Kurowicka, Harry Joe.
    Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
    """

    arg_constraints = {"concentration": constraints.positive}
    support = constraints.corr_cholesky

    def __init__(self, dim, concentration=1.0, validate_args=None):
        if dim < 2:
            raise ValueError(
                f"Expected dim to be an integer greater than or equal to 2. Found dim={dim}."
            )
        self.dim = dim
        (self.concentration,) = broadcast_all(concentration)
        batch_shape = self.concentration.size()
        event_shape = torch.Size((dim, dim))
        # This is used to draw vectorized samples from the beta distribution in Sec. 3.2 of [1].
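        # Per the onion construction, the squared norm of the below-diagonal part of
        # row i (for i = 1, ..., dim - 1) of the Cholesky factor is distributed as
        #   y_i ~ Beta(i / 2, concentration + (dim - 1 - i) / 2).
        # Row 0 has no below-diagonal entries; it receives a placeholder
        # Beta(0.5, concentration + (dim - 2) / 2) slot purely so that all rows can
        # share a single vectorized Beta distribution. Its draw never affects the
        # sample because the first row's direction vector is zeroed out in `sample()`.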
        marginal_conc = self.concentration + 0.5 * (self.dim - 2)
        offset = torch.arange(
            self.dim - 1,
            dtype=self.concentration.dtype,
            device=self.concentration.device,
        )
        offset = torch.cat([offset.new_zeros((1,)), offset])
        beta_conc1 = 0.5 * offset + 0.5
        beta_conc0 = marginal_conc.unsqueeze(-1) - 0.5 * offset
        self._beta = Beta(beta_conc1, beta_conc0)
        super().__init__(batch_shape, event_shape, validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(LKJCholesky, _instance)
        batch_shape = torch.Size(batch_shape)
        new.dim = self.dim
        new.concentration = self.concentration.expand(batch_shape)
        new._beta = self._beta.expand(batch_shape + (self.dim,))
        super(LKJCholesky, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size()):
        # This uses the Onion method, but there are a few differences from [1] Sec. 3.2:
        # - This vectorizes the for loop and also works for heterogeneous eta.
        # - Same algorithm generalizes to n=1.
        # - The procedure is simplified since we are sampling the Cholesky factor of
        #   the correlation matrix instead of the correlation matrix itself. As such,
        #   we only need to generate `w`.
        y = self._beta.sample(sample_shape).unsqueeze(-1)
        u_normal = torch.randn(
            self._extended_shape(sample_shape), dtype=y.dtype, device=y.device
        ).tril(-1)
        u_hypersphere = u_normal / u_normal.norm(dim=-1, keepdim=True)
        # Replace NaNs in the first row, which has no below-diagonal entries.
        u_hypersphere[..., 0, :].fill_(0.0)
        w = torch.sqrt(y) * u_hypersphere
        # Fill diagonal elements; clamp for numerical stability.
        eps = torch.finfo(w.dtype).tiny
        diag_elems = torch.clamp(1 - torch.sum(w**2, dim=-1), min=eps).sqrt()
        w += torch.diag_embed(diag_elems)
        return w

    def log_prob(self, value):
        # See: https://mc-stan.org/docs/2_25/functions-reference/cholesky-lkj-correlation-distribution.html
        # The probability of a correlation matrix is proportional to
        #   determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1))
        # Additionally, the Jacobian of the transformation from Cholesky factor to
        # correlation matrix is:
        #   prod(L_ii ^ (D - i))
        # So the probability of a Cholesky factor is proportional to
        #   prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i)
        # with order_i = 2 * concentration - 2 + D - i.
        if self._validate_args:
            self._validate_sample(value)
        diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
        order = torch.arange(2, self.dim + 1, device=self.concentration.device)
        order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
        unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
        # Compute normalization constant (page 1999 of [1])
        dm1 = self.dim - 1
        alpha = self.concentration + 0.5 * dm1
        denominator = torch.lgamma(alpha) * dm1
        numerator = torch.mvlgamma(alpha - 0.5, dm1)
        # pi_constant in [1] is D * (D - 1) / 4 * log(pi)
        # pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi)
        # hence, we need to add a pi_constant = (D - 1) * log(pi) / 2
        pi_constant = 0.5 * dm1 * math.log(math.pi)
        normalize_term = pi_constant + numerator - denominator
        return unnormalized_log_pdf - normalize_term
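

# A minimal usage sketch, assuming the module is run directly (the guard keeps it
# inert on import): draw one sample, check that it is the Cholesky factor of a
# valid correlation matrix, and evaluate its log-density.
if __name__ == "__main__":
    lkj = LKJCholesky(3, concentration=2.0)
    sample = lkj.sample()
    # Rows of the factor have unit norm, so diag(L @ L.T) should be all ones.
    corr = sample @ sample.T
    assert torch.allclose(corr.diagonal(), torch.ones(3), atol=1e-6)
    print("Cholesky factor sample:\n", sample)
    print("log_prob:", lkj.log_prob(sample).item())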