Public API
This package provides one type, DiscreteMeasure, which describes a measure on a finite set for use in Sinkhorn's algorithm and the related functions. The first step in computing, e.g., the Sinkhorn divergence (sinkhorn_divergence!) is to construct DiscreteMeasure objects describing the measures of interest.
UnbalancedOptimalTransport.DiscreteMeasure — Type
DiscreteMeasure(density, [log_density], set) -> DiscreteMeasure
Construct a DiscreteMeasure object for use in unbalanced_sinkhorn! and related functions.
- density should be strictly positive; zero elements should instead be removed from set.
- log_density should be equal to log.(density) and can be omitted (in which case it is computed automatically).
- set is a collection such that density[i] is the mass (for a probability measure, the probability) of the element set[i], where i ∈ eachindex(density, set).
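Because zero-mass points must be dropped from set rather than kept with weight zero, it can help to filter the inputs before construction. The following is a standalone sketch using a hypothetical helper, prune_zero_mass, which is not part of the package; it only relies on Base and returns data satisfying the requirements listed above.

```julia
# Hypothetical helper (not part of UnbalancedOptimalTransport): drop points
# with zero mass and return (density, log_density, set) meeting the
# requirements of DiscreteMeasure.
function prune_zero_mass(density::AbstractVector{<:Real}, set::AbstractVector)
    eachindex(density) == eachindex(set) ||
        throw(DimensionMismatch("density and set must have the same indices"))
    all(>=(0), density) || throw(ArgumentError("density must be nonnegative"))
    keep = density .> 0                 # keep strictly positive entries only
    d = density[keep]
    return d, log.(d), set[keep]
end

# Example: three points on the line, one of which carries no mass.
d, logd, s = prune_zero_mass([0.5, 0.0, 1.5], [[0.0], [1.0], [2.0]])
# With the package loaded, these could then be passed as, e.g.,
# DiscreteMeasure(d, logd, s) or DiscreteMeasure(d, s).
```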
Functions
This package provides three functions which act on DiscreteMeasure objects to calculate quantities of interest:
UnbalancedOptimalTransport.OT! — Function
function OT!(
D::AbstractDivergence,
C,
a::DiscreteMeasure,
b::DiscreteMeasure,
ϵ = 1e-1;
kwargs...,
) -> Number
Computes the optimal transport cost between a and b using unbalanced_sinkhorn!; see that function for the meaning of the parameters and keyword arguments. Implements Equation (15) of [SFVTP19].
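For reference, as we read Equation (15) of [SFVTP19], the cost is the dual objective <a, -φ*(-f)> + <b, -φ*(-g)> - ϵ <a⊗b, exp((f⊕g - C)/ϵ) - 1> evaluated at the optimal potentials f and g. The standalone sketch below evaluates this expression for given weights, potentials, and cost matrix; the name ot_dual_value is hypothetical, and φstar is passed as a plain function instead of being dispatched on a divergence type as in the package.

```julia
# Hypothetical sketch (not the package's OT!): evaluate the dual objective
#   <a, -φstar(-f)> + <b, -φstar(-g)> - ϵ * Σᵢⱼ a[i] b[j] (exp((f[i] + g[j] - C[i, j])/ϵ) - 1)
# for weight vectors a, b, potentials f, g and a cost matrix C.
function ot_dual_value(φstar, a, b, f, g, C::AbstractMatrix, ϵ)
    size(C) == (length(a), length(b)) ||
        throw(DimensionMismatch("C must be length(a)×length(b)"))
    val = sum(a[i] * -φstar(-f[i]) for i in eachindex(a, f)) +
          sum(b[j] * -φstar(-g[j]) for j in eachindex(b, g))
    for j in eachindex(b, g), i in eachindex(a, f)   # column-major loop order
        val -= ϵ * a[i] * b[j] * (exp((f[i] + g[j] - C[i, j]) / ϵ) - 1)
    end
    return val
end

# Example with the Balanced divergence, for which φstar(q) = q.
a, b = [0.5, 0.5], [1.0]
C = reshape([0.0, 1.0], 2, 1)
v = ot_dual_value(identity, a, b, zeros(2), zeros(1), C, 0.1)
```

With zero potentials only the entropic term contributes, which is why v is positive here.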
UnbalancedOptimalTransport.sinkhorn_divergence! — Function
sinkhorn_divergence!(
D::AbstractDivergence,
C,
a::DiscreteMeasure,
b::DiscreteMeasure,
ϵ = 1e-1;
kwargs...,
) -> Number
Computes the unbalanced Sinkhorn divergence between a and b, as defined in Def. 6 of [SFVTP19], using unbalanced_sinkhorn!; see that function for the meaning of the parameters and keyword arguments. Sets the optimal dual potentials of a and b.
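As we understand Def. 6 of [SFVTP19], the divergence combines three OT values with a mass correction: S(a, b) = OT(a, b) - OT(a, a)/2 - OT(b, b)/2 + (ϵ/2)(m(a) - m(b))^2, where m denotes total mass. A minimal sketch of that final combination, assuming this form, with the hypothetical name combine_sinkhorn_divergence and the three OT values passed in as plain numbers:

```julia
# Hypothetical sketch: combine precomputed OT values into the unbalanced
# Sinkhorn divergence, assuming the form
#   S(a, b) = OT(a, b) - OT(a, a)/2 - OT(b, b)/2 + (ϵ/2) * (m(a) - m(b))^2,
# where m(a) = sum(a.density) is the total mass of a.
function combine_sinkhorn_divergence(ot_ab, ot_aa, ot_bb, mass_a, mass_b, ϵ)
    return ot_ab - (ot_aa + ot_bb) / 2 + (ϵ / 2) * (mass_a - mass_b)^2
end

s_equal = combine_sinkhorn_divergence(1.0, 0.4, 0.6, 1.0, 1.0, 0.1)    # equal masses
s_unequal = combine_sinkhorn_divergence(1.0, 0.4, 0.6, 2.0, 1.0, 0.1)  # mass gap of 1
```

The mass term vanishes for measures of equal total mass, so s_equal only reflects the debiasing by the two self-transport terms.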
UnbalancedOptimalTransport.optimal_coupling! — Function
function optimal_coupling!(
D::AbstractDivergence,
C,
a::DiscreteMeasure,
b::DiscreteMeasure,
ϵ = 1e-1;
dual_potentials_populated::Bool = false,
kwargs...) -> Matrix
Computes the optimal coupling between a and b using the optimal dual potentials, the regularization parameter ϵ, and the cost function C.
If dual_potentials_populated = false, unbalanced_sinkhorn! is called to populate the dual potentials of a and b, using the divergence D. If dual_potentials_populated = true, one of unbalanced_sinkhorn!, OT!, or sinkhorn_divergence! must already have been called to set the optimal dual potentials, with the same choice of ϵ and C; in this case, a and b are not mutated.
This function implements Prop. 6 of [SFVTP19].
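As we read Prop. 6 of [SFVTP19], the optimal coupling is obtained in closed form from the optimal potentials f and g as π[i, j] = exp((f[i] + g[j] - C[i, j]) / ϵ) * a[i] * b[j], with a and b the densities. A standalone sketch of that formula; the name coupling_from_potentials is hypothetical, and the potentials are passed explicitly rather than read from the measures.

```julia
# Hypothetical sketch: build the coupling matrix
#   π[i, j] = exp((f[i] + g[j] - C[i, j]) / ϵ) * a[i] * b[j]
# from densities a, b, dual potentials f, g and a cost matrix C.
function coupling_from_potentials(a, b, f, g, C::AbstractMatrix, ϵ)
    size(C) == (length(a), length(b)) ||
        throw(DimensionMismatch("C must be length(a)×length(b)"))
    # Column vectors (f, a) broadcast against row vectors (g', b').
    return exp.((f .+ g' .- C) ./ ϵ) .* a .* b'
end

a, b = [0.25, 0.75], [0.5, 0.5]
C = [0.0 1.0; 1.0 0.0]
π0 = coupling_from_potentials(a, b, zeros(2), zeros(2), zeros(2, 2), 0.1)  # equals a * b'
π1 = coupling_from_potentials(a, b, zeros(2), zeros(2), C, 0.1)
```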
Sinkhorn's algorithm
The above functions rely on unbalanced_sinkhorn!, which uses Sinkhorn's algorithm to calculate the optimal dual potentials.
UnbalancedOptimalTransport.unbalanced_sinkhorn! — Function
function unbalanced_sinkhorn!(
D::AbstractDivergence,
C,
a::DiscreteMeasure,
b::DiscreteMeasure,
ϵ = 1e-1;
tol = 1e-5,
max_iters = 10^5,
warn::Bool = true,
) -> NamedTuple
Implements Algorithm 1 of [SFVTP19]. The dual_potential fields of a and b are updated to hold the optimal dual potentials; the density, log_density, and set fields are not modified. The parameters are:
- D: the AbstractDivergence used for measuring the cost of creating or destroying mass
- ϵ: the regularization parameter
- C: either a function from a.set × b.set to real numbers (which should satisfy C(x,y) = C(y,x) and C(x,x) = 0 when applicable), or a precomputed cost matrix as generated by, e.g., cost_matrix
- tol: the convergence tolerance
- max_iters: the maximum number of iterations to perform
- warn: whether or not to warn when the maximum number of iterations is reached
Returns a NamedTuple containing the number of iterations performed (iters) and the maximum residual at the end of the process (max_residual), i.e., the largest infinity-norm difference between consecutive iterates of the dual potentials. Iteration stops as soon as max_residual falls below tol, or after max_iters iterations, whichever comes first.
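To make the iteration concrete, here is a standalone log-domain sketch of alternating Sinkhorn updates in the spirit of Algorithm 1 of [SFVTP19], as we read it: f ← -aprox(-Smin_b(C - g)) and g ← -aprox(-Smin_a(Cᵀ - f)), where Smin_b(h)_i = -ϵ log Σ_j b[j] exp(-h[i, j]/ϵ) is a soft-min. It is not the package's implementation: it works on plain vectors and a cost matrix, takes aprox as a function of (ϵ, x), and the names softmin and sketch_sinkhorn are hypothetical. Its return value mirrors the iters/max_residual fields described above.

```julia
# Numerically stable soft-min: -ϵ * log(Σⱼ w[j] * exp(-h[j] / ϵ)), for w[j] > 0.
function softmin(ϵ, w, h)
    m = minimum(h)                  # shift by the minimum to avoid overflow
    s = zero(m)
    for j in eachindex(w, h)
        s += w[j] * exp(-(h[j] - m) / ϵ)
    end
    return m - ϵ * log(s)
end

# Hypothetical sketch of unbalanced Sinkhorn iterations on plain arrays
# (not the package's unbalanced_sinkhorn!).
#   aprox(ϵ, x): anisotropic proximal operator of the chosen divergence
#   a, b: densities; C: length(a)×length(b) cost matrix
function sketch_sinkhorn(aprox, a, b, C, ϵ; tol = 1e-5, max_iters = 10^5)
    f = zeros(length(a))            # dual potential on a's points, zero-initialized
    g = zeros(length(b))            # dual potential on b's points, zero-initialized
    iters = 0
    max_residual = Inf
    while iters < max_iters && max_residual >= tol
        iters += 1
        max_residual = 0.0
        # f-update: f[i] = -aprox(-Smin_b(C[i, :] - g))
        for i in eachindex(a)
            fi = -aprox(ϵ, -softmin(ϵ, b, C[i, :] .- g))
            max_residual = max(max_residual, abs(fi - f[i]))
            f[i] = fi
        end
        # g-update uses the freshly updated f: g[j] = -aprox(-Smin_a(C[:, j] - f))
        for j in eachindex(b)
            gj = -aprox(ϵ, -softmin(ϵ, a, C[:, j] .- f))
            max_residual = max(max_residual, abs(gj - g[j]))
            g[j] = gj
        end
    end
    return (f = f, g = g, iters = iters, max_residual = max_residual)
end

# Example: the Balanced divergence, whose aprox is the identity.
a, b = [0.5, 0.5], [0.25, 0.75]
C = [0.0 1.0; 1.0 0.0]
res = sketch_sinkhorn((ϵ, x) -> x, a, b, C, 0.5; tol = 1e-10)
```

In the balanced case the resulting coupling has marginals a and b (up to the tolerance); swapping in a different aprox, such as the KL one, changes only the two update lines.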
Divergences
As described on the page Optimal transport, we use $\varphi$-divergences as penalty terms for mass creation and destruction. This package implements four such divergences, which are all described in Section 2.4 of [SFVTP19], and are listed below. To add your own divergences, see the AbstractDivergence interface section.
UnbalancedOptimalTransport.KL — Type
KL{ρ} <: AbstractDivergence
Represents the divergence ρ*KL(a|b), where KL is the Kullback-Leibler divergence. The parameter ρ is simply a scaling.
UnbalancedOptimalTransport.TV — Type
TV{ρ} <: AbstractDivergence
Represents the divergence ρ*TV(u,v) = ρ*norm(u-v,1), where TV is twice the total variation distance. The parameter ρ is simply a scaling.
UnbalancedOptimalTransport.Balanced — Type
Balanced <: AbstractDivergence
Represents the divergence Dᵩ(a|b), which is zero if a == b and infinite otherwise. Generalized by RG.
UnbalancedOptimalTransport.RG — Type
RG{l,u} <: AbstractDivergence
Represents the divergence Dᵩ(a|b), which is zero if l*b .<= a .<= u*b and infinite otherwise. Equivalent to Balanced when l == u == 1.
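The two required interface functions, φstar and aprox, have closed forms for all four divergences. The standalone sketch below writes them out as plain functions of the parameters, following our own reading and derivation from Section 2.4 of [SFVTP19]; the function names are hypothetical and are not the package's methods, which dispatch on KL{ρ}, TV{ρ}, Balanced, and RG{l,u} instead.

```julia
# Hypothetical standalone formulas; ϵ is the regularization parameter.

# ρ*KL: φ(p) = ρ(p log p - p + 1)
φstar_kl(ρ, q) = ρ * (exp(q / ρ) - 1)
aprox_kl(ρ, ϵ, x) = ρ / (ρ + ϵ) * x

# ρ*TV: φ(p) = ρ|p - 1|; its conjugate is +∞ above ρ
φstar_tv(ρ, q) = q > ρ ? Inf : max(-ρ, q)
aprox_tv(ρ, ϵ, x) = clamp(x, -ρ, ρ)

# Balanced: φ is zero at p = 1 and +∞ elsewhere
φstar_balanced(q) = q
aprox_balanced(ϵ, x) = x

# RG{l,u}: φ is zero on [l, u] and +∞ elsewhere
φstar_rg(l, u, q) = max(l * q, u * q)
aprox_rg(l, u, ϵ, x) = max(zero(x), x - ϵ * log(u)) + min(zero(x), x - ϵ * log(l))
```

Setting l = u = 1 in the RG formulas recovers the Balanced ones, consistent with RG generalizing Balanced.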
AbstractDivergence interface
To add a divergence MyDivergence, create a struct
    struct MyDivergence <: UnbalancedOptimalTransport.AbstractDivergence end
and implement methods for UnbalancedOptimalTransport.aprox and UnbalancedOptimalTransport.φstar. Optionally, one can also implement methods for UnbalancedOptimalTransport.initialize_dual_potential! and sinkhorn_divergence!, as a specialized implementation may obtain better performance.
UnbalancedOptimalTransport.AbstractDivergence — Type
abstract type AbstractDivergence
An abstract type representing Csiszár φ-divergences. Subtypes should implement φstar and aprox, and can optionally implement initialize_dual_potential! and/or sinkhorn_divergence!.
UnbalancedOptimalTransport.aprox — Function
aprox(::AbstractDivergence, ϵ::Number, x::Number) -> Number
The anisotropic proximity operator defined in Def. 2 of [SFVTP19].
UnbalancedOptimalTransport.φstar — Function
φstar(::AbstractDivergence, q::Number) -> Number
The Legendre conjugate of the function φ associated with the divergence.
UnbalancedOptimalTransport.initialize_dual_potential! — Function
initialize_dual_potential!(::AbstractDivergence, a::DiscreteMeasure) -> Nothing
Apply an initialization to the dual potential of a, for use in unbalanced_sinkhorn!; the fallback zeros out the dual potential. Specialized implementations can improve performance but should not affect correctness.
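The interface can be illustrated with a hypothetical asymmetric variant of TV that penalizes density ratios below 1 with slope ρ₁ and above 1 with slope ρ₂, i.e. φ(p) = ρ₁(1 - p) for p ≤ 1 and ρ₂(p - 1) for p ≥ 1. By our own derivation (the same steps as for TV), its conjugate is max(-ρ₁, q) for q ≤ ρ₂ and +∞ otherwise, and its aprox clamps to [-ρ₁, ρ₂]. So that the sketch runs without the package, it declares a stand-in abstract type and stand-in generic functions; with the package loaded, one would instead subtype UnbalancedOptimalTransport.AbstractDivergence and add methods to UnbalancedOptimalTransport.aprox and UnbalancedOptimalTransport.φstar.

```julia
# Stand-ins so this sketch runs on its own. With the package, subtype
# UnbalancedOptimalTransport.AbstractDivergence and extend
# UnbalancedOptimalTransport.aprox / UnbalancedOptimalTransport.φstar instead.
abstract type AbstractDivergence end
function aprox end
function φstar end

# Hypothetical divergence: φ(p) = ρ₁(1 - p) for p ≤ 1, ρ₂(p - 1) for p ≥ 1.
# With ρ₁ == ρ₂ == ρ this is the same φ as ρ*TV.
struct AsymmetricTV{T<:Real} <: AbstractDivergence
    ρ₁::T   # slope for density ratios below 1
    ρ₂::T   # slope for density ratios above 1
end

# Legendre conjugate of φ: max(-ρ₁, q) for q ≤ ρ₂, +∞ beyond ρ₂.
φstar(D::AsymmetricTV, q::Number) = q > D.ρ₂ ? Inf : max(-D.ρ₁, q)

# Anisotropic proximity operator: clamp x to [-ρ₁, ρ₂] (independent of ϵ).
aprox(D::AsymmetricTV, ϵ::Number, x::Number) = clamp(x, -D.ρ₁, D.ρ₂)

D = AsymmetricTV(0.5, 2.0)
```

Storing ρ₁ and ρ₂ as fields rather than type parameters (as KL{ρ} and TV{ρ} do) is only a design choice for the sketch; either works with multiple dispatch.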
Utilities
sinkhorn_divergence! uses the following function, which may be specialized to improve performance.
UnbalancedOptimalTransport.fdot — Function
fdot(f, u, v) -> Number
A generic, allocation-free implementation of dot(u, f.(v)). It may be faster to provide a specialized method that dispatches to BLAS or similar.
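A standalone sketch of such a function and of one possible specialization (both hypothetical re-implementations, not the package's code): the generic method loops without materializing f.(v), and the identity method forwards to LinearAlgebra.dot, which uses BLAS for dense Float64 vectors. Unlike dot, the sketch does not conjugate u, so it assumes real-valued inputs.

```julia
using LinearAlgebra

# Generic, allocation-free computation of dot(u, f.(v)) for real-valued inputs.
function fdot(f, u, v)
    acc = zero(promote_type(eltype(u), eltype(v)))
    for i in eachindex(u, v)        # throws if the indices do not match
        acc += u[i] * f(v[i])
    end
    return acc
end

# Example specialization: for f = identity this is a plain dot product,
# so dispatch to LinearAlgebra.dot (BLAS for Vector{Float64}).
fdot(::typeof(identity), u::Vector{Float64}, v::Vector{Float64}) = dot(u, v)
```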
The following function, cost_matrix, is used in unbalanced_sinkhorn!, OT!, and optimal_coupling! to precompute the costs given a cost function.
UnbalancedOptimalTransport.cost_matrix — Function
cost_matrix([C,] a, b) -> Matrix
Precompute the cost matrix given a cost function C. If no function C is provided, the default is C(x,y) = norm(x-y).
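A standalone sketch of the computation, assuming it amounts to evaluating C on every pair of points. The hypothetical name cost_matrix_sketch takes the two point sets directly (whereas cost_matrix takes the measures a and b), and the default cost is the one stated above.

```julia
using LinearAlgebra

# Hypothetical sketch: Cmat[i, j] = C(xs[i], ys[j]) for every pair of points.
cost_matrix_sketch(C, xs, ys) = [C(x, y) for x in xs, y in ys]
# Default cost C(x, y) = norm(x - y), as documented for cost_matrix.
cost_matrix_sketch(xs, ys) = cost_matrix_sketch((x, y) -> norm(x - y), xs, ys)

xs = [[0.0, 0.0], [1.0, 0.0]]
ys = [[0.0, 0.0], [0.0, 2.0], [3.0, 4.0]]
Cmat = cost_matrix_sketch(xs, ys)                          # Euclidean distances
Csq = cost_matrix_sketch((x, y) -> norm(x - y)^2, xs, ys)  # squared distances
```

Rows index the first set and columns the second, matching the length(a)×length(b) shape used by the coupling.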