Squaring tensor networks and circuits without squaring them

or how to effectively exploit orthogonality constraints

The quadratic bottleneck

Tensor networks (TNs) and probabilistic circuits (PCs) are powerful tools for tractable probabilistic modeling. They can represent complex distributions while enabling efficient computation of marginals, conditionals, and other probabilistic queries. But there’s a problem when we try to make them more expressive.

In short, with both model families we can model probabilities as:

\[\newcommand{\dom}{\text{dom}} \newcommand{\bbC}{\mathbb{C}} \newcommand{\uscope}{ {\text{sc}}} \newcommand{\inscope}{ {\text{in}}} \newcommand{\vX}{X} \newcommand{\calO}{\mathcal{O}} p(x) = \frac{\lvert\psi(x)\rvert^2}{Z} = \frac{\psi(x) \cdot \psi^*(x)}{Z}\]

where $\psi(x)$ is a complex-valued function, $\psi^*(x)$ is its conjugate, and $Z = \int \lvert\psi(x)\rvert^2 dx$ is the partition function. This “squared” formulation allows us to use real or complex parameters, making our models strictly more expressive than traditional PCs with only positive parameters, as Lorenzo showed at AAAI last year.

The cost? Computing \(Z\) or any marginal \(p(y) = \int_z \lvert\psi(y,z)\rvert^2 dz\) requires time \(\calO(\lvert c \rvert^2)\) instead of \(\calO(\lvert c\rvert)\), where \(\lvert c\rvert\) is the circuit size. This quadratic overhead happens because we’re effectively evaluating \(\psi \cdot \psi^*\) as a product of two circuits.

In our latest ICLR article we show how to escape this quadratic bottleneck in squared PCs. The key insight: borrow ideas from both TN canonical forms (orthogonality) and PCs (determinism) to enable linear-time marginalization!

Background

Circuits and tensor networks

Circuits are parameterized computational graphs over a set of variables \(\vX\), encoding a function $c\colon\dom(\vX)\to\bbC$ by hierarchically composing the following types of circuit units, where each unit $n$ receives as inputs another set of units denoted by $\inscope(n)$:

  1. Input units, which compute a (possibly complex-valued) function over a subset of the variables in \(\vX\).
  2. Product units, which multiply their inputs: $c_n(x) = \prod_{i \in \inscope(n)} c_i(x)$.
  3. Sum units, which compute a weighted sum of their inputs: $c_n(x) = \sum_{i \in \inscope(n)} w_i c_i(x)$.

Tensor networks factorize high-dimensional tensors into products of smaller tensors. The most famous example is the Matrix Product State (MPS):

\[\psi(x) = \sum_{i_1=1}^R \sum_{i_2=1}^R \cdots \sum_{i_{d-1}=1}^R \psi_1^{i_1}(x_1) \psi_2^{i_1,i_2}(x_2) \cdots \psi_d^{i_{d-1}}(x_d)\]

where \(R\) is the rank of the factorization and each \(\psi_j\) is a factor. Computing \(\psi(x)\) naively takes \(\calO(R^d)\) time, but with the right contraction order (left-to-right), it only takes \(\calO(R^2 d)\).
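
As an illustration (a minimal sketch, not code from the paper), here is how the left-to-right contraction looks in NumPy; the factors left, middle, and right are random placeholders:

import numpy as np

d, R, V = 6, 4, 10          # number of variables, rank, domain size per variable
rng = np.random.default_rng(0)

# Hypothetical complex MPS factors: boundary factors are matrices, inner factors are 3-way tensors.
left   = rng.standard_normal((V, R)) + 1j * rng.standard_normal((V, R))          # psi_1^{i_1}(x_1)
middle = [rng.standard_normal((V, R, R)) + 1j * rng.standard_normal((V, R, R))   # psi_j^{i_{j-1}, i_j}(x_j)
          for _ in range(d - 2)]
right  = rng.standard_normal((V, R)) + 1j * rng.standard_normal((V, R))          # psi_d^{i_{d-1}}(x_d)

def mps_evaluate(x):
    # Contract left-to-right: one vector-matrix product per variable, O(R^2) each, O(R^2 d) total.
    message = left[x[0]]                      # vector of size R
    for j, factor in enumerate(middle):
        message = message @ factor[x[j + 1]]  # (R,) @ (R, R) -> (R,)
    return message @ right[x[-1]]             # final contraction yields a scalar

print(mps_evaluate([0, 1, 2, 3, 0, 1]))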

The computational graph of sums and products resulting from the TN contraction in a particular ordering turns out to be a circuit.

An MPS tensor network (lower right, Penrose notation) and its representation as a circuit (left). The circuit encodes the left-to-right contraction order.

Efficient marginalization in PCs, that is, exact and in time polynomial in the circuit size $\lvert c\rvert$, is achieved by imposing two structural constraints over them:

  1. Smoothness: all inputs of a sum unit have the same scope, i.e., they depend on the same variables.
  2. Decomposability: the inputs of a product unit have pairwise disjoint scopes.

Together, these ensure we can integrate out any variable in time $\calO(\lvert c\rvert)$.
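
For instance, consider a smooth sum unit whose inputs are decomposable products over \(x\) and \(y\); the integral over \(y\) pushes all the way down to the input units:

\[\int c(x, y)\, dy = \int \Big(\sum_i w_i\, c_{i,1}(x)\, c_{i,2}(y)\Big)\, dy = \sum_i w_i\, c_{i,1}(x) \int c_{i,2}(y)\, dy \;.\]

Each input unit is integrated exactly once, so the whole marginal costs a single feed-forward pass, i.e., \(\calO(\lvert c\rvert)\).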

The Born rule and squaring

To model \(p(x) = \lvert\psi(x)\rvert^2 / Z\), we need to:

  1. Represent \(\psi(x) \cdot \psi^*(x)\) as a circuit.
  2. Compute \(Z = \int \lvert\psi(x)\rvert^2 dx\).

The product \(\psi \cdot \psi^*\) can be represented as a decomposable circuit if \(\psi\) is structured-decomposable (all products with the same scope factorize it the same way). But this circuit has size \(\calO(\lvert c\rvert^2)\), making marginalization quadratically expensive!
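
Concretely, squaring a single sum unit with \(n\) inputs already produces \(n^2\) pairwise products:

\[\Big\lvert \sum_{i=1}^{n} w_i c_i(x) \Big\rvert^2 = \sum_{i=1}^{n} \sum_{j=1}^{n} w_i w_j^*\, c_i(x)\, c_j(x)^* \;,\]

and each pairwise product \(c_i \cdot c_j^*\) must itself be materialized as a circuit, which is where the quadratic size comes from.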

Prior solution: Tensor network canonical forms use unitary matrices to simplify marginals. For example, in an MPS with left-canonical form:

\[\int \psi_i^{k_1}(x) \psi_i^{k_2}(x)^* dx = \delta_{k_1,k_2}\]

This orthonormality makes certain marginals trivial to compute. But canonical forms are TN-specific and don’t extend to general circuits.

From determinism to orthogonality

Our first key contribution is identifying orthogonality as a generalization of determinism that enables efficient marginalization in squared circuits.

Determinism enables simplification

A sum unit \(n\) computing \(c_n(x) = \sum_i w_i c_i(x)\) is deterministic if its inputs have disjoint supports: \(\forall i \neq j: \text{supp}(c_i) \cap \text{supp}(c_j) = \emptyset\).

When squaring a deterministic sum:

\[\lvert c_n(x)\rvert^2 = \left|\sum_i w_i c_i(x)\right|^2 = \sum_i \sum_j w_i w_j^* c_i(x) c_j(x)^* = \sum_i \lvert w_i\rvert ^2 \lvert c_i(x)\rvert ^2\]

The cross terms vanish! So determinism reduces the squared sum from a quadratic to a linear number of terms in its inputs.

But: Deterministic squared circuits become equivalent to monotonic circuits (only \(\lvert w_i\rvert^2\) appears), losing the expressiveness of complex parameters.

Orthogonality: The key generalization

Definition (Orthogonality): A sum unit \(n\) with scope \(Z\) is orthogonal if:

\[\forall i \neq j: \int c_i(z) c_j(z)^* dz = 0\]

Unlike determinism, orthogonality allows inputs to have overlapping support: They just need to be orthogonal functions!

Example: Hermite functions or Fourier basis functions are orthogonal but have full support over \(\mathbb{R}\).
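
As a quick numerical sanity check (our own sketch, not code from the paper), the first few Hermite functions are nonzero essentially everywhere on a grid, yet their pairwise integrals vanish:

import numpy as np
from numpy.polynomial.hermite import hermval
from math import factorial, pi, sqrt

x = np.linspace(-10.0, 10.0, 4001)
dx = x[1] - x[0]

def hermite_function(n):
    # Orthonormal Hermite function: H_n(x) * exp(-x^2 / 2) / sqrt(2^n n! sqrt(pi))
    coeffs = np.zeros(n + 1)
    coeffs[n] = 1.0
    return hermval(x, coeffs) * np.exp(-x**2 / 2) / sqrt(2**n * factorial(n) * sqrt(pi))

H = np.stack([hermite_function(n) for n in range(4)])
gram = (H @ H.T) * dx     # Riemann-sum approximation of all pairwise integrals
print(np.round(gram, 4))  # close to the identity: orthonormal functions with overlapping support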

Determinism requires disjoint supports (left). Orthogonality allows overlapping supports as long as functions are orthogonal (right).

Theorem 1 (Simplified): If circuit \(c\) is smooth, decomposable, and orthogonal, then computing \(Z = \int \lvert c(x)\rvert^2 dx\) takes time \(\calO(\lvert c\rvert)\).

Proof sketch: When integrating \(\lvert c_n(x)\rvert^2 = \sum_i \sum_j w_i w_j^* c_i(x) c_j(x)^*\), orthogonality makes cross terms vanish after integration, leaving only \(\sum_i \lvert w_i\rvert^2 \int \lvert c_i(x)\rvert^2 dx\).

Crucially, orthogonality preserves complex parameters in the distribution, unlike determinism!

Unlocking non-structured-decomposable circuits

Orthogonality works even for non-structured-decomposable circuits, i.e., those encoding multiple variable partitionings. This is significant because such circuits can be exponentially more expressive than structured-decomposable ones, and they go beyond what any single tensor network can represent.

Building orthogonal circuits

How do we construct circuits that satisfy orthogonality? The paper introduces regular orthogonality.

Regular orthogonality

Basis decomposability: A sum unit \(n\) is basis decomposable if there exists a variable \(X\) such that its inputs depend on disjoint sets of input functions over \(X\).

Regular orthogonality: A circuit is regular orthogonal if:

  1. It’s basis decomposable.
  2. All input functions over the same variable are orthogonal.

Theorem 2: Regular orthogonality \(\Rightarrow\) Orthogonality.

This gives us a practical recipe: choose orthogonal input functions and ensure sum units split their dependencies appropriately.
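
Here is a minimal sketch of this recipe for a single discrete variable; the discrete Fourier basis and the particular split of basis functions are illustrative choices, not the paper's construction:

import numpy as np

rng = np.random.default_rng(0)
V = 8                                         # domain size of a discrete variable X
basis = np.fft.fft(np.eye(V), norm="ortho")   # rows form an orthonormal (discrete Fourier) basis of C^V

# Each input of the sum unit mixes a *disjoint* subset of the basis functions (basis decomposability).
alpha = rng.standard_normal(4) + 1j * rng.standard_normal(4)
beta  = rng.standard_normal(4) + 1j * rng.standard_normal(4)
c1 = alpha @ basis[:4]   # input unit 1: uses basis functions 0..3
c2 = beta  @ basis[4:]   # input unit 2: uses basis functions 4..7

# Regular orthogonality: the two inputs are orthogonal, yet both are (generically) nonzero everywhere.
print(np.round(np.sum(c1 * np.conj(c2)), 10))      # ~ 0, the "integral" over the discrete domain
print(np.abs(c1).min() > 0, np.abs(c2).min() > 0)  # overlapping support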

Choices for orthogonal functions

Discrete variables: indicator (one-hot) functions are trivially orthogonal; more generally, the columns of any unitary matrix over the domain work, e.g., the discrete Fourier basis.

Continuous variables: Hermite functions or Fourier basis functions, as mentioned above, are natural orthogonal choices.

The choice depends on your domain and the functions you want to approximate.

Scaling up: Unitarity for tensorized circuits

Regular orthogonality works well for circuits where each sum connects to different input functions. But modern circuits often have dense layers where multiple sums share inputs—these aren’t basis decomposable!

Basis decomposable: each sum connects to different inputs.
NOT basis decomposable: sums share input connections.

The solution: unitarity, which defines conditions at the layer level.

Tensorized circuits

A tensorized circuit groups units into layers:

  1. Input layers, computing a vector \(\ell(x)\) of \(K\) input functions over some variable.
  2. Sum layers, applying a weight matrix \(W\) to the output vector of another layer.
  3. Product layers, combining the outputs of two layers, e.g., element-wise (Hadamard) or via a Kronecker product.

This is just a convenient representation that allows for efficient GPU usage and over-parametrization. In general, any tensorized circuit is equivalent to an ordinary circuit.

Unitarity conditions

A tensorized circuit is unitary if it satisfies:

(U1) Orthonormal inputs: Each input layer uses orthonormal functions, i.e.,

\[\int \ell(x) \otimes \ell(x)^* dx = I_K \;.\]

(U2) Separated dependencies: Each sum layer’s inputs depend on different input layers (for at least one variable).

(U3) Semi-unitary weights: Each sum layer’s weight matrix \(W\) satisfies \(WW^\dagger = I\).

Theorem 3: If \(c\) is unitary, then:

  1. \(c\) is orthogonal.
  2. \(Z = \int \lvert c(x)\rvert^2 dx = 1\) (already normalized!).

Proof sketch: By induction on layers, show \(\int \ell(z) \otimes \ell(z)^* dz = I_K\) for every layer. For a sum layer, the semi-unitary property ensures \(\int (W\ell(z)) \otimes (W\ell(z))^* dz = W \left(\int \ell(z) \otimes \ell(z)^* dz\right) W^\dagger = WW^\dagger = I\).
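
The inductive step can also be checked numerically. In this sketch (our own illustration, written in matrix rather than vectorized form), a semi-unitary \(W\) is obtained from the QR decomposition of a random complex matrix, and conjugating the identity by \(W\) gives back the identity:

import numpy as np

rng = np.random.default_rng(0)
K_out, K_in = 3, 5   # a sum layer mapping K_in input units to K_out output units

# Build a semi-unitary weight matrix (W W^dagger = I) from a QR decomposition.
A = rng.standard_normal((K_in, K_out)) + 1j * rng.standard_normal((K_in, K_out))
Q, _ = np.linalg.qr(A)        # Q has orthonormal columns: Q^dagger Q = I_{K_out}
W = Q.conj().T                # so W W^dagger = I_{K_out}
print(np.allclose(W @ W.conj().T, np.eye(K_out)))   # True: (U3) holds

# Inductive step: if the input layer integrates to I_{K_in}, the sum layer integrates to I_{K_out}.
I_in = np.eye(K_in)
print(np.allclose(W @ I_in @ W.conj().T, np.eye(K_out)))  # True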

Connection to TNs

Unitarity generalizes the upper-canonical form of tree tensor networks (TTNs). TTNs in upper-canonical form have:

  1. Orthonormal functions attached to the leaves.
  2. Isometric (semi-unitary) internal tensors.

When represented as a tensorized circuit, these become exactly conditions (U1) and (U3)!

But unitarity extends to non-TN structures—circuits that encode multiple variable partitionings simultaneously, which no TN can represent.

Tree tensor network (TTN) represented as a tensorized circuit.
Tensorized circuit with no corresponding TTN.

Efficient marginalization

Even with unitarity ensuring \(Z = 1\), we still need to compute marginals \(p(y) = \int \lvert c(y,z)\rvert^2 dz\). The naive approach is to materialize the squared circuit and then marginalize it, which takes \(\calO(L^2 S_{\max}^2)\) time and memory, where \(L\) is the number of layers and \(S_{\max}\) is the maximum layer size.

We present an algorithm that achieves:

\[\calO(\lvert\phi_Y \setminus \phi_Z\rvert S_{\max} + \lvert\phi_Y \cap \phi_Z\rvert S_{\max}^2)\]

where \(\phi_Y\) = layers depending on \(Y\), \(\phi_Z\) = layers depending on \(Z\).

Key insights:

  1. Layers depending only on \(Z\) integrate to identity matrices (from unitarity) — don’t evaluate them!
  2. Layers depending only on \(Y\) don’t need squaring — evaluate once: \(\ell(y) \otimes \ell(y)^*\)
  3. Only layers depending on both \(Y\) and \(Z\) need full squaring

This requires strengthening (U2) with an additional separation condition, referred to as (U4) in the pseudocode below, so that cross terms between a sum layer’s inputs still vanish when only part of the scope is integrated out.

Best case: \(\calO(\lvert\phi_Y \setminus \phi_Z\rvert S_{\max})\) when only a few boundary layers depend on both \(Y\) and \(Z\).

Example: Marginalizing the left half of an image. Most layers depend entirely on either left or right pixels. Only layers near the output depend on both, so \(\lvert\phi_Y \cap \phi_Z\rvert = \calO(1)\) while \(\lvert\phi_Y\rvert = \calO(L)\). Plugging into the bound above, the marginal costs \(\calO(L S_{\max} + S_{\max}^2)\) instead of \(\calO(L^2 S_{\max}^2)\).

In Python-style pseudocode (scope, evaluate, and the layer attributes below are schematic):

import numpy as np

def marginalize(layer, y, Z):
    # Recursively integrate the variables in Z out of the squared circuit.
    if scope(layer) <= Z:
        # Only marginalized variables: integrates to the identity, from unitarity!
        return I

    if not (scope(layer) & Z):
        # Only kept variables: evaluate once, no squaring needed
        r = evaluate(layer, y)
        return np.outer(r, r.conj())

    # Both kept and marginalized variables
    if layer is sum:
        results = [marginalize(inp, y, Z) for inp in layer.inputs]
        # Crucially: only diagonal terms survive (U4)
        return sum(layer.W[i] @ R @ layer.W[i].conj().T for i, R in enumerate(results))

    if layer is product:
        r1 = marginalize(layer.input1, y, Z & scope(layer.input1))
        r2 = marginalize(layer.input2, y, Z & scope(layer.input2))
        return np.kron(r1, r2)  # or a permuted Kronecker product

Experimental results

The experiments address three key questions:

RQ1: Do unitary circuits scale better?

Time and memory per training iteration vs. parameter count. Unitary circuits (\(\perp^2_{\bbC}\)) are consistently faster and more memory-efficient.

Setup: Compare \(\pm^2_{\bbC}\) (baseline squared PCs) vs \(\perp^2_{\bbC}\) (squared unitary PCs) with increasing model size.

Results: \(\perp^2_{\bbC}\) circuits are consistently faster and more memory-efficient per training iteration than \(\pm^2_{\bbC}\) circuits, as the plot above shows.

Why? Not materializing the squared circuit saves both memory and compute.

RQ2: Can we train unitary circuits without losing performance?

Test bits-per-dimension on MNIST and FashionMNIST. Unitary circuits (\(\perp^2_{\bbC}\)) match baseline performance across scales.

Setup: Train on MNIST and FashionMNIST, measure bits-per-dimension (lower = better).

Results: \(\perp^2_{\bbC}\) with Kronecker layers matches \(\pm^2_{\bbC}\) with Hadamard layers across all model sizes.

Key: we adapted the LandingSGD optimizer (designed for optimization under orthogonality constraints) to the semi-unitary weight matrices of our circuits, and called the resulting variant LandingPC.

RQ3: Do non-structured circuits work?

Setup: Train a \(\perp^2_{\bbC}\) circuit that encodes multiple variable partitionings (non-structured-decomposable).

Results: Competitive with structured circuits, especially at large scales, though harder to train.

Significance: First demonstration of tractably training non-structured squared PCs! These can be exponentially more expressive than structured ones.

Conclusion

This work bridges tensor networks and probabilistic circuits by showing that orthogonality unifies two seemingly different ideas: determinism in circuits and canonical forms in tensor networks.

Key Takeaways

Theoretical:

  1. Orthogonality enables \(\calO(\lvert c\rvert)\) marginalization in squared circuits.
  2. Unitarity provides practical layer-level conditions.
  3. Works for non-structured circuits (beyond what TNs can represent).

Practical:

  1. Unitary circuits are faster and use less memory.
  2. No performance loss with proper optimization (LandingPC).
  3. Enables new architectures (Kronecker layers, non-structured).