xref: /aosp_15_r20/external/pytorch/torch/distributions/lkj_cholesky.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
"""
This closely follows the implementation in NumPyro (https://github.com/pyro-ppl/numpyro).

Original copyright notice:

# Copyright: Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
"""

import math

import torch
from torch.distributions import Beta, constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all


__all__ = ["LKJCholesky"]


class LKJCholesky(Distribution):
23    r"""
24    LKJ distribution for lower Cholesky factor of correlation matrices.
25    The distribution is controlled by ``concentration`` parameter :math:`\eta`
26    to make the probability of the correlation matrix :math:`M` generated from
27    a Cholesky factor proportional to :math:`\det(M)^{\eta - 1}`. Because of that,
28    when ``concentration == 1``, we have a uniform distribution over Cholesky
29    factors of correlation matrices::
30
31        L ~ LKJCholesky(dim, concentration)
32        X = L @ L' ~ LKJCorr(dim, concentration)
33
34    Note that this distribution samples the
35    Cholesky factor of correlation matrices and not the correlation matrices
36    themselves and thereby differs slightly from the derivations in [1] for
37    the `LKJCorr` distribution. For sampling, this uses the Onion method from
38    [1] Section 3.
39
    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> l = LKJCholesky(3, 0.5)
        >>> l.sample()  # l @ l.T is a sample of a 3x3 correlation matrix
        tensor([[ 1.0000,  0.0000,  0.0000],
                [ 0.3516,  0.9361,  0.0000],
                [-0.1899,  0.4748,  0.8593]])

    Args:
        dim (int): dimension of the matrices
        concentration (float or Tensor): concentration/shape parameter of the
            distribution (often referred to as eta)

    **References**

    [1] `Generating random correlation matrices based on vines and extended onion method` (2009),
    Daniel Lewandowski, Dorota Kurowicka, Harry Joe.
    Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
    """
    arg_constraints = {"concentration": constraints.positive}
    support = constraints.corr_cholesky

    def __init__(self, dim, concentration=1.0, validate_args=None):
        if dim < 2:
            raise ValueError(
                f"Expected dim to be an integer greater than or equal to 2. Found dim={dim}."
            )
        self.dim = dim
        (self.concentration,) = broadcast_all(concentration)
        batch_shape = self.concentration.size()
        event_shape = torch.Size((dim, dim))
        # This is used to draw vectorized samples from the beta distribution in Sec. 3.2 of [1].
        marginal_conc = self.concentration + 0.5 * (self.dim - 2)
        offset = torch.arange(
            self.dim - 1,
            dtype=self.concentration.dtype,
            device=self.concentration.device,
        )
        offset = torch.cat([offset.new_zeros((1,)), offset])
        beta_conc1 = 0.5 * offset + 0.5
        beta_conc0 = marginal_conc.unsqueeze(-1) - 0.5 * offset
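        # Per [1] Sec. 3.2, the squared norm of the off-diagonal part of (0-indexed) row i >= 1
        # is marginally Beta(i / 2, concentration + (dim - 1 - i) / 2). The zero prepended to
        # `offset` only pads row 0; that draw is unused by `sample()`, since the first row of the
        # factor is always [1, 0, ..., 0].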
        self._beta = Beta(beta_conc1, beta_conc0)
        super().__init__(batch_shape, event_shape, validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(LKJCholesky, _instance)
        batch_shape = torch.Size(batch_shape)
        new.dim = self.dim
        new.concentration = self.concentration.expand(batch_shape)
        new._beta = self._beta.expand(batch_shape + (self.dim,))
        super(LKJCholesky, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size()):
        # This uses the Onion method, but there are a few differences from [1] Sec. 3.2:
        # - This vectorizes the for loop and also works for heterogeneous eta.
        # - The same algorithm generalizes to n=1.
        # - The procedure is simplified since we are sampling the Cholesky factor of
        #   the correlation matrix instead of the correlation matrix itself. As such,
        #   we only need to generate `w`.
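        # Each (0-indexed) row i of `w` is built as sqrt(y_i) * u_i, where y_i is the Beta draw
        # set up in __init__ (the squared norm of the row's off-diagonal part) and u_i is a
        # uniform direction on the unit sphere spanned by the i entries below the diagonal.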
        y = self._beta.sample(sample_shape).unsqueeze(-1)
        u_normal = torch.randn(
            self._extended_shape(sample_shape), dtype=y.dtype, device=y.device
        ).tril(-1)
        u_hypersphere = u_normal / u_normal.norm(dim=-1, keepdim=True)
        # The first row of u_normal is all zeros after tril(-1), so normalizing it yields NaNs;
        # zero them out (the first row of the factor is fixed to [1, 0, ..., 0] by the diagonal fill below).
        u_hypersphere[..., 0, :].fill_(0.0)
        w = torch.sqrt(y) * u_hypersphere
        # Fill diagonal elements so that each row has unit norm; clamp 1 - ||w_i||^2 away from
        # zero for numerical stability.
        eps = torch.finfo(w.dtype).tiny
        diag_elems = torch.clamp(1 - torch.sum(w**2, dim=-1), min=eps).sqrt()
        w += torch.diag_embed(diag_elems)
        return w

    def log_prob(self, value):
        # See: https://mc-stan.org/docs/2_25/functions-reference/cholesky-lkj-correlation-distribution.html
        # The probability of a correlation matrix is proportional to
        #   determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1))
        # Additionally, the Jacobian of the transformation from Cholesky factor to
        # correlation matrix is:
        #   prod(L_ii ^ (D - i))
        # So the probability of a Cholesky factor is proportional to
        #   prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i)
        # with order_i = 2 * concentration - 2 + D - i
        if self._validate_args:
            self._validate_sample(value)
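        # The first diagonal element is always 1, so its log is 0 regardless of the exponent;
        # it is dropped from the sum below.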
        diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
        order = torch.arange(2, self.dim + 1, device=self.concentration.device)
        order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
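        # `order` now holds the exponent 2 * (concentration - 1) + D - i for the i-th diagonal
        # element, i = 2, ..., D, matching order_i above.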
        unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
        # Compute normalization constant (page 1999 of [1])
        dm1 = self.dim - 1
        alpha = self.concentration + 0.5 * dm1
        denominator = torch.lgamma(alpha) * dm1
        numerator = torch.mvlgamma(alpha - 0.5, dm1)
        # pi_constant in [1] is D * (D - 1) / 4 * log(pi)
        # pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi)
        # hence, we need to add a pi_constant = (D - 1) * log(pi) / 2
        pi_constant = 0.5 * dm1 * math.log(math.pi)
        normalize_term = pi_constant + numerator - denominator
        return unnormalized_log_pdf - normalize_term