Public API

This package provides one type, DiscreteMeasure, which describes a measure on a finite set for use in Sinkhorn's algorithm and related functions. The first step in computing, e.g., the Sinkhorn divergence (sinkhorn_divergence!) is to construct DiscreteMeasure objects describing the quantities of interest.

UnbalancedOptimalTransport.DiscreteMeasure (Type)
DiscreteMeasure(density, [log_density], set) -> DiscreteMeasure

Construct a DiscreteMeasure object for use in unbalanced_sinkhorn! and related functions.

  • density should be strictly positive; elements with zero density should instead be removed from set
  • log_density should equal log.(density) and can be omitted (in which case it is computed automatically)
  • set is a collection such that density[i] is the mass assigned to the element set[i] (where i ∈ eachindex(density, set)).
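Because zero-mass points must be removed from set rather than given a zero density, it can help to filter them out before construction. The following is a minimal sketch; drop_zero_mass is a hypothetical helper, not part of the package, and it simply returns the three pieces accepted by the documented DiscreteMeasure(density, [log_density], set) signature.

```julia
# Hypothetical helper (not part of UnbalancedOptimalTransport): drop points with
# zero mass so that the remaining density is strictly positive.
function drop_zero_mass(density::AbstractVector{<:Real}, set::AbstractVector)
    length(density) == length(set) ||
        throw(DimensionMismatch("density and set must have the same length"))
    any(<(0), density) && throw(ArgumentError("density must be nonnegative"))
    keep = density .> 0                 # mask of points with positive mass
    d = density[keep]
    return (density = d, log_density = log.(d), set = set[keep])
end

# The zero-mass point 2.0 is removed together with its density entry.
m = drop_zero_mass([0.5, 0.0, 1.5], [1.0, 2.0, 3.0])
```

The fields of m could then be passed on as DiscreteMeasure(m.density, m.log_density, m.set).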

Functions

This package provides three functions that act on DiscreteMeasure objects to calculate quantities of interest:

UnbalancedOptimalTransport.OT! (Function)
function OT!(
    D::AbstractDivergence,
    C,
    a::DiscreteMeasure,
    b::DiscreteMeasure,
    ϵ = 1e-1;
    kwargs...,
) -> Number

Computes the optimal transport cost between a and b, using unbalanced_sinkhorn!; see that function for the meaning of the parameters and the keyword arguments. Implements Equation (15) of [SFVTP19].
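For intuition about the value OT! returns, the sketch below evaluates the dual objective of entropic unbalanced OT, F(f, g) = ⟨α, -φ*(-f)⟩ + ⟨β, -φ*(-g)⟩ - ϵ⟨α⊗β, exp((f ⊕ g - C)/ϵ) - 1⟩, at given potentials. We assume this is the quantity Equation (15) of [SFVTP19] refers to when evaluated at the optimal potentials; the exact form should be checked against the paper. dual_objective and kl_phistar are hypothetical names, and C is taken to be a precomputed matrix.

```julia
# Sketch: dual objective of entropic unbalanced OT at potentials f, g,
#   F(f, g) = ⟨α, -φ*(-f)⟩ + ⟨β, -φ*(-g)⟩ - ϵ ⟨α⊗β, exp((f ⊕ g - C) / ϵ) - 1⟩,
# where `phistar` is the convex conjugate φ* (a one-argument function).
function dual_objective(phistar, ϵ, α, β, f, g, C::AbstractMatrix)
    term_a = sum(-α[i] * phistar(-f[i]) for i in eachindex(α, f))
    term_b = sum(-β[j] * phistar(-g[j]) for j in eachindex(β, g))
    coupling_term = 0.0
    for j in eachindex(β, g), i in eachindex(α, f)   # column-major traversal of C
        coupling_term += α[i] * β[j] * (exp((f[i] + g[j] - C[i, j]) / ϵ) - 1)
    end
    return term_a + term_b - ϵ * coupling_term
end

# φ* for ρ*KL, i.e. φ*(q) = ρ (exp(q / ρ) - 1)
kl_phistar(ρ) = q -> ρ * (exp(q / ρ) - 1)
```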

UnbalancedOptimalTransport.sinkhorn_divergence! (Function)
sinkhorn_divergence!(
    D::AbstractDivergence,
    C,
    a::DiscreteMeasure,
    b::DiscreteMeasure,
    ϵ = 1e-1;
    kwargs...,
) -> Number

Computes the unbalanced Sinkhorn divergence between a and b, as defined in Def. 6 of [SFVTP19], using unbalanced_sinkhorn!; see that function for the meaning of the parameters and the keyword arguments. Sets the dual_potential fields of a and b to their optimal values.
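The Sinkhorn divergence debiases OT by combining three OT values. As a hedged sketch, assuming Def. 6 takes the form S(a, b) = OT(a, b) - OT(a, a)/2 - OT(b, b)/2 + ϵ/2 (m(a) - m(b))^2, with m the total mass, the final combination step looks like this. debiased_divergence is a hypothetical helper; the three OT values would come from separate OT! solves.

```julia
# Sketch of the debiasing step, assuming Def. 6 has the form
#   S(a, b) = OT(a, b) - OT(a, a)/2 - OT(b, b)/2 + ϵ/2 * (m(a) - m(b))^2,
# where m(⋅) is the total mass (e.g. the sum of the density vector).
function debiased_divergence(ot_ab, ot_aa, ot_bb, mass_a, mass_b, ϵ)
    return ot_ab - (ot_aa + ot_bb) / 2 + ϵ / 2 * (mass_a - mass_b)^2
end
```

When a and b coincide, all three OT values and both masses agree, so the combination vanishes.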

UnbalancedOptimalTransport.optimal_coupling! (Function)
function optimal_coupling!(
    D::AbstractDivergence,
    C,
    a::DiscreteMeasure,
    b::DiscreteMeasure,
    ϵ = 1e-1;
    dual_potentials_populated::Bool = false,
    kwargs...,
) -> Matrix

Computes the optimal coupling between a and b using the optimal dual potentials, the regularization parameter ϵ, and the cost function C.

If dual_potentials_populated = false, unbalanced_sinkhorn! is called to populate the dual potentials of a and b, using the divergence D. If dual_potentials_populated = true, one of unbalanced_sinkhorn!, OT!, or sinkhorn_divergence! must already have been called to set the optimal dual potentials, with the same choice of ϵ and C; in this case, a and b are not mutated.

This function implements Prop. 6 of [SFVTP19].
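Given optimal potentials f and g, the primal-dual relation we understand Prop. 6 of [SFVTP19] to express is π[i, j] = exp((f[i] + g[j] - C[i, j]) / ϵ) * α[i] * β[j], where α and β are the densities. A minimal sketch with a precomputed cost matrix; coupling_from_potentials is a hypothetical name, not the package's method.

```julia
# Sketch: coupling from dual potentials,
#   π[i, j] = exp((f[i] + g[j] - C[i, j]) / ϵ) * α[i] * β[j]
function coupling_from_potentials(f, g, C::AbstractMatrix, α, β, ϵ)
    # f .+ g' broadcasts to the matrix f[i] + g[j]; α .* β' is the product measure α⊗β
    return exp.((f .+ g' .- C) ./ ϵ) .* (α .* β')
end
```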


Sinkhorn's algorithm

The functions above rely on unbalanced_sinkhorn!, which uses Sinkhorn's algorithm to calculate the optimal dual potentials.

UnbalancedOptimalTransport.unbalanced_sinkhorn! (Function)
function unbalanced_sinkhorn!(
    D::AbstractDivergence,
    C,
    a::DiscreteMeasure,
    b::DiscreteMeasure,
    ϵ = 1e-1;
    tol = 1e-5,
    max_iters = 10^5,
    warn::Bool = true,
) -> NamedTuple

Implements Algorithm 1 of [SFVTP19]. The dual_potential fields of a and b are updated to hold the optimal dual potentials; the density, log_density, and set fields are not modified. The parameters are:

  • D: the AbstractDivergence used for measuring the cost of creating or destroying mass
  • ϵ: the regularization parameter
  • C: either a function from a.set × b.set to the real numbers, which should satisfy C(x,y) = C(y,x) and C(x,x) = 0 when applicable, or a precomputed cost matrix, as generated by e.g. cost_matrix
  • tol: the convergence tolerance
  • max_iters: the maximum number of iterations to perform
  • warn: whether to warn when the maximum number of iterations is reached

Returns a NamedTuple with two fields: iters, the number of iterations performed, and max_residual, the maximum infinity-norm difference between consecutive iterates of the dual potentials at the end of the process. Iteration stops when max_residual falls below tol or when max_iters iterations have been performed, whichever comes first.
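To make the iteration concrete, here is a minimal log-domain sketch of unbalanced Sinkhorn restricted to ρ*KL penalties, written from the update f ← -aprox(-Smin_β(C - g)), with Smin_β(h) = -ϵ log Σ_j β_j exp(-h_j/ϵ), as we read Algorithm 1 of [SFVTP19]. It is not the package's implementation: sinkhorn_kl is a hypothetical name, C must be a precomputed matrix, the warn option is omitted, and only the returned iters and max_residual mirror the documented NamedTuple.

```julia
# Sketch only: log-domain unbalanced Sinkhorn with ρ*KL penalties on both marginals.
function sinkhorn_kl(α, β, C, ρ, ϵ; tol = 1e-5, max_iters = 10^5)
    f = zeros(length(α))
    g = zeros(length(β))
    logα, logβ = log.(α), log.(β)
    # aprox for ρ*KL: the minimizer of ϵ exp((p - q)/ϵ) + ρ (exp(q/ρ) - 1)
    aprox(p) = ρ / (ρ + ϵ) * p
    # Smin_w(h) = -ϵ log Σ_k w_k exp(-h_k/ϵ), via a stabilized log-sum-exp
    function smin(logw, h)
        z = logw .- h ./ ϵ
        zmax = maximum(z)
        return -ϵ * (zmax + log(sum(exp.(z .- zmax))))
    end
    iters, max_residual = 0, Inf
    while iters < max_iters
        iters += 1
        f_new = [-aprox(-smin(logβ, C[i, :] .- g)) for i in eachindex(α)]
        g_new = [-aprox(-smin(logα, C[:, j] .- f_new)) for j in eachindex(β)]
        # infinity-norm change of the potentials between consecutive iterates
        max_residual = max(maximum(abs.(f_new .- f)), maximum(abs.(g_new .- g)))
        f, g = f_new, g_new
        max_residual < tol && break
    end
    return (f = f, g = g, iters = iters, max_residual = max_residual)
end
```

As ρ grows, the factor ρ / (ρ + ϵ) tends to 1 and the updates reduce to the balanced log-domain Sinkhorn updates.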


Divergences

As described on the page Optimal transport, we use $\varphi$-divergences as penalty terms for mass creation and destruction. This package implements four such divergences, which are all described in Section 2.4 of [SFVTP19], and are listed below. To add your own divergences, see the AbstractDivergence interface section.

UnbalancedOptimalTransport.KL (Type)
KL{ρ} <: AbstractDivergence

Represents the divergence ρ*KL(a|b), where KL is the Kullback-Leibler divergence. The parameter ρ is simply a scaling.
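For measures that need not have equal total mass, KL is usually taken in its generalized form, KL(a|b) = Σ_i a_i log(a_i / b_i) - a_i + b_i; we assume this is the form meant here, as in [SFVTP19]. A direct sketch for strictly positive vectors, with the hypothetical name scaled_kl:

```julia
# ρ * generalized KL between strictly positive vectors:
#   KL(a | b) = Σ a_i log(a_i / b_i) - a_i + b_i
# (the -a_i + b_i terms keep it nonnegative even when total masses differ)
function scaled_kl(ρ, a, b)
    return ρ * sum(a[i] * log(a[i] / b[i]) - a[i] + b[i] for i in eachindex(a, b))
end
```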

UnbalancedOptimalTransport.TV (Type)
TV{ρ} <: AbstractDivergence

Represents the divergence ρ*TV(u,v) = ρ*norm(u-v,1), where TV is twice the total variation distance. The parameter ρ is simply a scaling.
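The factor of two comes from the fact that, for two probability vectors, the total variation distance max_A |u(A) - v(A)| equals half of norm(u - v, 1). The sketch below checks this by brute force over all subsets A; scaled_tv and tv_distance_bruteforce are hypothetical names, and the brute force is only practical for small vectors.

```julia
using LinearAlgebra

# ρ*TV(u, v) as documented: ρ times the ℓ¹ distance
scaled_tv(ρ, u, v) = ρ * norm(u - v, 1)

# Total variation distance by brute force: max over subsets A of |u(A) - v(A)|
function tv_distance_bruteforce(u, v)
    n = length(u)
    best = 0.0
    for mask in 0:(2^n - 1)
        A = [i for i in 1:n if (mask >> (i - 1)) & 1 == 1]   # subset encoded by the bits of mask
        best = max(best, abs(sum(u[A]) - sum(v[A])))
    end
    return best
end

u = [0.2, 0.5, 0.3]
v = [0.4, 0.4, 0.2]
scaled_tv(1.0, u, v) ≈ 2 * tv_distance_bruteforce(u, v)
```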


AbstractDivergence interface

To add a divergence MyDivergence, create a struct

struct MyDivergence <: UnbalancedOptimalTransport.AbstractDivergence end

and implement methods for UnbalancedOptimalTransport.aprox and UnbalancedOptimalTransport.φstar. Optionally, one can also implement methods for UnbalancedOptimalTransport.initialize_dual_potential! and sinkhorn_divergence!, since a specialized implementation may perform better.
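For orientation, here are the closed forms of φ* and aprox for the KL and TV divergences above, where aprox^ϵ_{φ*}(p) is the minimizer over q of ϵ exp((p - q)/ϵ) + φ*(q). These are standalone functions with hypothetical names; the package's actual method signatures for aprox and φstar are not shown on this page, so check its source before adding methods for a new divergence.

```julia
# Standalone closed forms; the names are hypothetical, not the package's methods.
# Convex conjugates φ* of the entropy functions for ρ*KL and ρ*TV:
phistar_kl(ρ, q) = ρ * (exp(q / ρ) - 1)        # φ(s) = ρ (s log s - s + 1)
phistar_tv(ρ, q) = q <= ρ ? max(q, -ρ) : Inf   # φ(s) = ρ |s - 1|; +∞ for q > ρ

# aprox^ϵ_{φ*}(p) = argmin_q  ϵ exp((p - q) / ϵ) + φ*(q)
aprox_kl(ρ, ϵ, p) = ρ / (ρ + ϵ) * p            # from setting the derivative to zero
aprox_tv(ρ, ϵ, p) = clamp(p, -ρ, ρ)            # independent of ϵ for TV
```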

Utilities

sinkhorn_divergence! uses the following function, which may be specialized to improve performance.

UnbalancedOptimalTransport.fdot (Function)
fdot(f, u, v) -> Number

A generic, allocation-free implementation of dot(u, f.(v)). It may be faster to provide a specialized method that dispatches to BLAS or a similar library.
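A minimal sketch of the generic behavior described above: an explicit loop that accumulates dot(u, f.(v)) without allocating the intermediate vector f.(v). fdot_sketch is a hypothetical name, not the package's method.

```julia
using LinearAlgebra

# Accumulate Σ conj(u[i]) * f(v[i]) in a loop instead of materializing f.(v).
function fdot_sketch(f, u, v)
    s = zero(promote_type(eltype(u), eltype(v)))
    @inbounds for i in eachindex(u, v)   # eachindex(u, v) checks that the shapes match
        s += conj(u[i]) * f(v[i])
    end
    return s
end

u = [0.5, 1.5]
v = [1.0, 2.0]
fdot_sketch(exp, u, v) ≈ dot(u, exp.(v))
```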


The following function, cost_matrix, is used by unbalanced_sinkhorn!, OT!, and optimal_coupling! to precompute the costs from a given cost function.