or how to effectively exploit orthogonality constraints
Tensor networks (TNs) and probabilistic circuits (PCs) are powerful tools for tractable probabilistic modeling. They can represent complex distributions while enabling efficient computation of marginals, conditionals, and other probabilistic queries. But there’s a problem when we try to make them more expressive.
In short, with both model families we can model probabilities as:
\[\newcommand{\dom}{\text{dom}} \newcommand{\bbC}{\mathbb{C}} \newcommand{\uscope}{ {\text{sc}}} \newcommand{\inscope}{ {\text{in}}} \newcommand{\vX}{X} \newcommand{\calO}{\mathcal{O}} p(x) = \frac{\lvert\psi(x)\rvert^2}{Z} = \frac{\psi(x) \cdot \psi^*(x)}{Z}\]where $\psi(x)$ is a complex-valued function, $\psi^*(x)$ is its conjugate, and $Z = \int \lvert\psi(x)\rvert^2 dx$ is the partition function. This “squared” formulation allows us to use real or complex parameters, making our models strictly more expressive than traditional PCs with only positive parameters, as Lorenzo showed at AAAI last year.
The cost? Computing \(Z\) or any marginal \(p(y) = \int_z \lvert\psi(y,z)\rvert^2 dz\) requires time \(\calO(\lvert c \rvert^2)\) instead of \(\calO(\lvert c\rvert)\), where \(\lvert c\rvert\) is the circuit size. This quadratic overhead happens because we’re effectively evaluating \(\psi \cdot \psi^*\) as a product of two circuits.
In our latest ICLR article, we show how to avoid this quadratic overhead by exploiting orthogonality constraints, while keeping the expressiveness of complex parameters.
Circuits are parameterized computational graphs over a set of variables \(\vX\) encoding a function $c\colon\dom(\vX)\to\bbC$. They hierarchically compose the following types of circuit units, where each unit $n$ takes as input another set of units denoted by $\inscope(n)$: input units (computing simple functions over their variables), sum units (computing a weighted sum of their inputs), and product units (computing the product of their inputs).
Tensor networks factorize high-dimensional tensors into products of smaller tensors. The most famous example is the Matrix Product State (MPS), which factorizes a function over $d$ variables as
\[\psi(x_1,\ldots,x_d) = \sum_{i_1=1}^{R}\cdots\sum_{i_{d-1}=1}^{R} \psi_1^{i_1}(x_1)\, \psi_2^{i_1 i_2}(x_2) \cdots \psi_d^{i_{d-1}}(x_d)\]
where \(R\) is the rank of the factorization and each \(\psi_j\) is a factor. Computing \(\psi(x)\) naively takes \(\calO(R^d)\) time, but with the right contraction order (left-to-right), it only takes \(\calO(R^2 d)\).
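As a concrete illustration, here is a minimal NumPy sketch of that left-to-right contraction (the function name mps_evaluate and the factors data structure are illustrative, not from the paper): each factor is stored as a lookup from the value of \(x_j\) to a matrix, with a \(1 \times R\) row at the left boundary and an \(R \times 1\) column at the right boundary.

import numpy as np

# A minimal sketch (illustrative names) of evaluating an MPS left-to-right in O(R^2 d).
# factors[j] maps the value x[j] to the j-th factor matrix:
#   shape (1, R) for j = 0, (R, R) in the middle, (R, 1) for j = d - 1.
def mps_evaluate(factors, x):
    v = factors[0][x[0]]                 # left boundary, shape (1, R)
    for j in range(1, len(factors)):
        v = v @ factors[j][x[j]]         # each step costs O(R^2)
    return v.item()                      # shape (1, 1) at the end: psi(x)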
The computational graph of sums and products resulting from contracting a TN in a particular order turns out to be a circuit.
Efficient marginalization in PCs—that is, exact and in time $\calO(\lvert c\rvert)$—is achieved by imposing two structural constraints over them: smoothness (the inputs of every sum unit depend on the same variables) and decomposability (the inputs of every product unit depend on disjoint sets of variables).
Together, these ensure we can integrate out any variable in time $\calO(\lvert c\rvert)$.
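As a one-line illustration (not from the paper, just unpacking the definitions): integrals commute with smooth sums and split across decomposable products,
\[\int \sum_i w_i\, c_i(x)\, dx = \sum_i w_i \int c_i(x)\, dx, \qquad \iint c_1(y)\, c_2(z)\, dy\, dz = \left(\int c_1(y)\, dy\right)\left(\int c_2(z)\, dz\right),\]so any marginal can be computed bottom-up by integrating only the input units, in a single pass over the circuit.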
To model \(p(x) = \lvert\psi(x)\rvert^2 / Z\), we need to represent the product \(\psi \cdot \psi^*\) as a circuit and then marginalize it to compute \(Z\) and any other quantity of interest.
The product \(\psi \cdot \psi^*\) can be represented as a decomposable circuit if \(\psi\) is structured-decomposable (all products with the same scope factorize it the same way). But this circuit has size \(\calO(\lvert c\rvert^2)\), making marginalization quadratically expensive!
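To see where the quadratic blow-up comes from, consider squaring a single sum of \(R\) decomposable products (a standard argument, not specific to the paper):
\[\left(\sum_{i=1}^{R} f_i(y)\, g_i(z)\right)\left(\sum_{j=1}^{R} f_j(y)\, g_j(z)\right)^{\!*} = \sum_{i=1}^{R}\sum_{j=1}^{R} \big(f_i(y)\, f_j(y)^*\big)\big(g_i(z)\, g_j(z)^*\big),\]which needs \(R^2\) product units where the original circuit had \(R\).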
Prior solution: Tensor network canonical forms use unitary matrices to simplify marginals. For example, in an MPS with left-canonical form:
\[\int \psi_i^{k_1}(x) \psi_i^{k_2}(x)^* dx = \delta_{k_1,k_2}\]This orthonormality makes certain marginals trivial to compute. But canonical forms are TN-specific and don’t extend to general circuits.
Our first key contribution is defining and identifying orthogonality as a generalization of determinism that enables efficient marginalization in squared circuits.
A sum unit \(n\) computing \(c_n(x) = \sum_i w_i c_i(x)\) is deterministic if its inputs have disjoint supports: \(\forall i \neq j: \text{supp}(c_i) \cap \text{supp}(c_j) = \emptyset\).
When squaring a deterministic sum:
\[\lvert c_n(x)\rvert^2 = \left|\sum_i w_i c_i(x)\right|^2 = \sum_i \sum_j w_i w_j^* c_i(x) c_j(x)^* = \sum_i \lvert w_i\rvert ^2 \lvert c_i(x)\rvert ^2\]The cross terms vanish! So determinism reduces the number of terms in the square from quadratic to linear in the number of inputs of the sum unit.
But: Deterministic squared circuits become equivalent to monotonic circuits (only \(\lvert w_i\rvert^2\) appears), losing the expressiveness of complex parameters.
Definition (Orthogonality): A sum unit \(n\) with scope \(Z\) is orthogonal if:
\[\forall i \neq j: \int c_i(z) c_j(z)^* dz = 0\]Unlike determinism, orthogonality allows inputs to have overlapping support: They just need to be orthogonal functions!
Example: Hermite functions or Fourier basis functions are orthogonal but have full support over \(\mathbb{R}\).
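A quick numerical sanity check (just a sketch with NumPy on a dense grid, not code from the paper): the first few Hermite functions overlap everywhere on \(\mathbb{R}\), yet their Gram matrix is the identity.

import numpy as np
from numpy.polynomial.hermite import hermval
from math import factorial, pi

def hermite_fn(n, x):
    # n-th Hermite function: orthonormal over R, but with full support
    coeffs = np.zeros(n + 1)
    coeffs[n] = 1.0
    norm = (2.0**n * factorial(n) * np.sqrt(pi)) ** -0.5
    return norm * np.exp(-x**2 / 2) * hermval(x, coeffs)

x = np.linspace(-12, 12, 24001)                       # dense grid; the functions decay fast
G = np.stack([hermite_fn(n, x) for n in range(5)])    # shape (5, len(x))
gram = (G @ G.T) * (x[1] - x[0])                      # Riemann-sum approximation of the integrals
print(np.round(gram, 4))                              # ≈ 5x5 identity matrix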
Theorem 1 (Simplified): If circuit \(c\) is smooth, decomposable, and orthogonal, then computing \(Z = \int \lvert c(x)\rvert^2 dx\) takes time \(\calO(\lvert c\rvert)\).
Proof sketch: When integrating \(\lvert c_n(x)\rvert^2 = \sum_i \sum_j w_i w_j^* c_i(x) c_j(x)^*\), orthogonality makes cross terms vanish after integration, leaving only \(\sum_i \lvert w_i\rvert ^2 \int \lvert c_i(z)\rvert ^2 dz\).
Crucially, orthogonality preserves complex parameters in the distribution, unlike determinism!
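And a tiny continuation of the check above (again just a sketch, this time with a complex Fourier basis on \([0,1)\)): with genuinely complex weights, \(Z = \int \lvert \sum_k w_k \phi_k(x)\rvert^2 dx\) still collapses to \(\sum_k \lvert w_k\rvert^2\).

import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=4) + 1j * rng.normal(size=4)          # complex sum-unit weights
x = np.linspace(0.0, 1.0, 100000, endpoint=False)          # grid over one period
basis = np.exp(2j * np.pi * np.arange(4)[:, None] * x)     # orthonormal Fourier basis phi_k
c = w @ basis                                               # c(x) = sum_k w_k phi_k(x)
Z = np.sum(np.abs(c) ** 2) * (x[1] - x[0])                  # numerical integral of |c|^2
print(Z, np.sum(np.abs(w) ** 2))                            # the two values agree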
Orthogonality works even for non-structured-decomposable circuits—those encoding multiple variable partitionings. This is significant because such circuits can be exponentially more compact than structured-decomposable ones, and they cannot be represented by any single TN.
How do we construct circuits that satisfy orthogonality? The paper introduces regular orthogonality.
Basis decomposability: A sum unit \(n\) is basis decomposable if there exists a variable \(\vX\) such that its inputs depend on disjoint sets of input functions over \(\vX\).
Regular orthogonality: A circuit is regular orthogonal if its input units compute orthogonal functions and every sum unit is basis decomposable.
Theorem 2: Regular orthogonality \(\Rightarrow\) Orthogonality.
This gives us a practical recipe: choose orthogonal input functions and ensure sum units split their dependencies appropriately.
Discrete variables: indicator (one-hot) functions over the finite domain are trivially orthogonal.
Continuous variables: Hermite functions or Fourier basis functions, as in the example above.
The choice depends on your domain and the functions you want to approximate.
Regular orthogonality works well for circuits where each sum connects to different input functions. But modern circuits often have dense layers where multiple sums share inputs—these aren’t basis decomposable!
The solution: unitarity, which defines conditions at the layer level.
A tensorized circuit groups units into layers: input layers (vectors of input functions), sum layers (which multiply their input by a weight matrix \(W\)), and product layers (Hadamard or Kronecker products of their inputs).
This is just a convenient representation that allows for efficient GPU usage and over-parametrization. In general, it is equivalent to a regular circuit.
A tensorized circuit is unitary if it satisfies:
(U1) Orthonormal inputs: Each input layer uses orthonormal functions, i.e.,
\[\int \ell(x) \otimes \ell(x)^* dx = I_K \;.\](U2) Separated dependencies: Each sum layer’s inputs depend on different input layers (for at least one variable).
(U3) Semi-unitary weights: Each sum layer’s weight matrix \(W\) satisfies \(WW^\dagger = I\).
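Here is a minimal numeric illustration of (U1) and (U3) (a sketch with illustrative choices, not code from the paper): one-hot indicator functions give an orthonormal input layer for a discrete variable, and a QR decomposition gives a semi-unitary weight matrix.

import numpy as np

# (U1) For a discrete variable with 5 states, one-hot indicators form an
# orthonormal input layer: sum_x l(x) l(x)^dagger = I_5 (the integral becomes a sum).
ell = np.eye(5)                                   # ell[x] is the one-hot vector for state x
U1 = sum(np.outer(ell[x], ell[x].conj()) for x in range(5))
print(np.allclose(U1, np.eye(5)))                 # True

# (U3) A semi-unitary weight matrix W with W W^dagger = I, built via QR.
rng = np.random.default_rng(0)
A = rng.normal(size=(8, 3)) + 1j * rng.normal(size=(8, 3))
Q, _ = np.linalg.qr(A)                            # Q has 3 orthonormal columns
W = Q.conj().T                                    # shape (3, 8): rows are orthonormal
print(np.allclose(W @ W.conj().T, np.eye(3)))     # True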
Theorem 3: If \(c\) is unitary, then the squared circuit is automatically normalized: \(Z = \int \lvert c(x)\rvert^2 dx = 1\), with no extra computation needed.
Proof sketch: By induction on layers, show \(\int \ell(z) \otimes \ell(z)^* dz = I_K\) for every layer. For a sum layer, the semi-unitary property ensures \(\int (W\ell) \otimes (W\ell)^* dz = W \left(\int \ell \otimes \ell^* dz\right) W^\dagger = W I_K W^\dagger = I\).
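For completeness, a sketch of the product-layer step of the same induction (reading \(\ell \otimes \ell^*\) as the outer product \(\ell\,\ell^\dagger\) as in (U1), and using decomposability so the two inputs depend on disjoint variables):
\[\iint \big(\ell_1(y)\otimes\ell_2(z)\big)\big(\ell_1(y)\otimes\ell_2(z)\big)^{\dagger}\, dy\, dz = \left(\int \ell_1(y)\,\ell_1(y)^{\dagger}\, dy\right)\otimes\left(\int \ell_2(z)\,\ell_2(z)^{\dagger}\, dz\right) = I_{K_1}\otimes I_{K_2} = I_{K_1 K_2}\]by the mixed-product property of the Kronecker product.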
Unitarity generalizes the upper-canonical form of tree tensor networks (TTNs). TTNs in upper-canonical form have orthonormal functions at the leaves and semi-unitary tensors at the internal nodes.
When represented as a tensorized circuit, these become exactly conditions (U1) and (U3)!
But unitarity extends to non-TN structures—circuits that encode multiple variable partitionings simultaneously, which no TN can represent.
Even with unitarity ensuring \(Z = 1\), we still need to compute marginals \(p(y) = \int \lvert c(y,z)\rvert^2 dz\). The naive approach: materialize the squared circuit (size \(\calO(L^2 S_{\max}^2)\)) and marginalize it (time \(\calO(L^2 S_{\max}^2)\)), where \(L\) is the number of layers and \(S_{\max}\) the maximum layer width.
We present an algorithm that achieves:
\[\calO(\lvert\phi_Y \setminus \phi_Z\rvert S_{\max} + \lvert\phi_Y \cap \phi_Z\rvert S_{\max}^2)\]where \(\phi_Y\) is the set of layers depending on \(Y\) and \(\phi_Z\) is the set of layers depending on \(Z\).
Key insights: layers whose scope is entirely marginalized collapse to the identity matrix (thanks to unitarity), layers whose scope is entirely observed are simply evaluated, and only layers depending on both \(Y\) and \(Z\) require quadratic-size intermediate results.
This requires strengthening (U2) into a stricter separation condition, referred to as (U4) in the pseudocode below.
Best case: \(\calO(\lvert\phi_Y \setminus \phi_Z\rvert S_{\max})\) when only a few boundary layers depend on both \(Y\) and \(Z\).
Example: Marginalizing the left half of an image. Most layers depend entirely on either left or right pixels. Only layers near the output depend on both, so \(\lvert\phi_Y \cap \phi_Z\rvert = \calO(1)\) while \(\lvert\phi_Y\rvert = \calO(L)\).
import numpy as np

def marginalize(layer, y, Z):
    # Returns the matrix int l(y, z) l(y, z)^dagger dz for this layer, where the
    # variables in Z are integrated out and y assigns the remaining variables.
    if scope(layer) <= Z:
        # Only marginalized variables: the integral is the identity, from unitarity!
        return np.eye(layer.size)
    if not (scope(layer) & Z):
        # Only kept variables: a plain forward evaluation, then an outer product.
        r = evaluate(layer, y)
        return np.outer(r, r.conj())
    # The layer depends on both kept and marginalized variables.
    if layer.is_sum:
        results = [marginalize(inp, y, Z) for inp in layer.inputs]
        # Crucially: only the diagonal terms survive, thanks to (U4)
        return sum(W @ R @ W.conj().T
                   for W, R in zip(layer.weights, results))
    if layer.is_product:
        r1 = marginalize(layer.input1, y, Z & scope(layer.input1))
        r2 = marginalize(layer.input2, y, Z & scope(layer.input2))
        return r1 * r2   # Hadamard layer (a permuted Kronecker product otherwise)
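A hedged usage sketch (assuming a hypothetical root layer object exposing the scope/evaluate/weights interface used above): for a circuit with a scalar output, the returned \(1 \times 1\) matrix is already the marginal, since \(Z = 1\) by unitarity.

# Hypothetical usage: root is the scalar output layer, y assigns the kept
# variables Y, and Z is the set of variables to integrate out.
p_y = marginalize(root, y, Z)[0, 0].real   # = int |c(y, z)|^2 dz = p(y)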
The experiments address three key questions: (1) are squared unitary PCs faster and more memory-efficient than baseline squared PCs? (2) do they match the baselines as distribution estimators? (3) can non-structured-decomposable squared PCs be trained tractably?
Setup: Compare \(\pm^2_C\) (baseline squared PCs) vs \(\perp^2_C\) (squared unitary PCs) with increasing model size.
Results: \(\perp^2_C\) requires less time and memory than \(\pm^2_C\), with the gap growing as model size increases.
Why? Not materializing the squared circuit saves both memory and compute.
Setup: Train on MNIST and FashionMNIST, measure bits-per-dimension (lower = better).
Results: \(\perp^2_C\) with Kronecker layers matches \(\pm^2_C\) with Hadamard layers across all model sizes.
Key: we adapted the LandingSGD optimizer to keep the weight matrices semi-unitary (condition (U3)) throughout training.
We called the final variant LandingPC.
Setup: Train a \(\perp^2_C\) circuit that encodes multiple variable partitionings (non-structured-decomposable).
Results: Competitive with structured circuits, especially at large scales, though harder to train.
Significance: First demonstration of tractably training non-structured squared PCs! These can be exponentially more expressive than structured ones.
This work bridges tensor networks and probabilistic circuits by showing that orthogonality unifies two seemingly different ideas: determinism in probabilistic circuits and canonical forms in tensor networks.
Theoretical: orthogonality generalizes determinism and enables linear-time marginalization of squared circuits while keeping complex parameters; regular orthogonality and unitarity give practical sufficient conditions; and the results extend to non-structured-decomposable circuits, which no single TN can represent.
Practical: an efficient marginalization algorithm for unitary tensorized circuits, a LandingSGD-based training recipe (LandingPC), and the first demonstration of tractably training non-structured squared PCs.