Public API
This package provides one type, DiscreteMeasure, which describes a measure on a finite set for use in Sinkhorn's algorithm and the related functions. The first step in computing e.g. the Sinkhorn divergence (sinkhorn_divergence!) is to construct DiscreteMeasure objects describing the quantities of interest.
UnbalancedOptimalTransport.DiscreteMeasure — Type
DiscreteMeasure(density, [log_density], set) -> DiscreteMeasure
Construct a DiscreteMeasure object for use in unbalanced_sinkhorn! and related functions.
density should be strictly positive; zero elements should instead be removed from set.
log_density should be equal to log.(density) and can be omitted (in which case it is computed automatically).
set is a collection such that density[i] is the probability of the element set[i] occurring (where i ∈ eachindex(density, set)).
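The requirements on density and set can be met with a small preprocessing step. The following is a minimal sketch in plain Julia: prepare_inputs is a hypothetical helper that is not part of the package, and the constructor call in the final comment simply follows the signature documented above.

```julia
# Hypothetical helper (not part of the package): drop zero-mass points so that
# the remaining density is strictly positive, as DiscreteMeasure expects.
function prepare_inputs(density::AbstractVector{<:Real}, set::AbstractVector)
    length(density) == length(set) ||
        throw(DimensionMismatch("density and set must have the same length"))
    any(<(0), density) && throw(ArgumentError("density must be non-negative"))
    keep = findall(>(0), density)   # indices of points that carry mass
    d = float.(density[keep])
    return (density = d, log_density = log.(d), set = set[keep])
end

inputs = prepare_inputs([0.2, 0.0, 0.8], [1.0, 2.0, 3.0])
# inputs.density == [0.2, 0.8] and inputs.set == [1.0, 3.0]

# With the package loaded, these would then be passed on as
# DiscreteMeasure(inputs.density, inputs.log_density, inputs.set)
```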
Functions
This package provides three functions which act on DiscreteMeasure objects to calculate quantities of interest:
UnbalancedOptimalTransport.OT! — Function
function OT!(
    D::AbstractDivergence,
    C,
    a::DiscreteMeasure,
    b::DiscreteMeasure,
    ϵ = 1e-1;
    kwargs...,
) -> Number
Computes the optimal transport cost between a and b, using unbalanced_sinkhorn!; see that function for the meaning of the parameters and the keyword arguments. Implements Equation (15) of [SFVTP19].
UnbalancedOptimalTransport.sinkhorn_divergence! — Function
sinkhorn_divergence!(
    D::AbstractDivergence,
    C,
    a::DiscreteMeasure,
    b::DiscreteMeasure,
    ϵ = 1e-1;
    kwargs...,
) -> Number
Computes the unbalanced Sinkhorn divergence between a and b as defined in Def. 6 of [SFVTP19], using unbalanced_sinkhorn!; see that function for the meaning of the parameters and the keyword arguments. Sets the optimal dual_potential fields of a and b.
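As a usage illustration of OT! and sinkhorn_divergence!, the sketch below follows the signatures documented above. It assumes the package is installed and that Balanced can be constructed with no arguments; neither is confirmed by this page, so treat it as a starting point rather than a reference.

```julia
using UnbalancedOptimalTransport: DiscreteMeasure, Balanced, OT!, sinkhorn_divergence!

# Two measures on points of the real line, each with total mass 1,
# so that the balanced problem is well posed.
a = DiscreteMeasure([0.5, 0.5], [0.0, 1.0])
b = DiscreteMeasure([0.25, 0.75], [0.0, 2.0])

# A symmetric cost function with C(x, x) == 0.
C = (x, y) -> abs(x - y)

D = Balanced()   # assumption: constructed with no arguments
ϵ = 0.1          # regularization parameter

cost = OT!(D, C, a, b, ϵ)                  # regularized transport cost
sd = sinkhorn_divergence!(D, C, a, b, ϵ)   # Sinkhorn divergence
```

Both calls update the dual potentials stored in a and b, so the measures can be reused by optimal_coupling! with dual_potentials_populated = true as long as ϵ and C stay the same.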
UnbalancedOptimalTransport.optimal_coupling! — Function
function optimal_coupling!(
    D::AbstractDivergence,
    C,
    a::DiscreteMeasure,
    b::DiscreteMeasure,
    ϵ = 1e-1;
    dual_potentials_populated::Bool = false,
    kwargs...,
) -> Matrix
Computes the optimal coupling between a and b using the optimal dual potentials, the regularization parameter ϵ, and the cost function C.
If dual_potentials_populated = false, unbalanced_sinkhorn! is called to populate the dual potentials of a and b, using the divergence D. If dual_potentials_populated = true, one of unbalanced_sinkhorn!, OT!, or sinkhorn_divergence! must be called first to set the optimal dual potentials, with the same choice of ϵ and C. In this case, a and b are not mutated.
This function implements Prop. 6 of [SFVTP19].
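To make the role of the dual potentials concrete, here is a minimal sketch that assumes the coupling takes the usual entropic primal-dual form, π[i, j] = a[i] * b[j] * exp((f[i] + g[j] - C[i, j]) / ϵ). The helper coupling_sketch is hypothetical and works on plain arrays; it is not the package's implementation.

```julia
# Hypothetical helper: build the coupling matrix from dual potentials f and g,
# densities a and b, a precomputed cost matrix C, and the regularization ϵ.
function coupling_sketch(f, g, a, b, C, ϵ)
    size(C) == (length(a), length(b)) ||
        throw(DimensionMismatch("C must be of size length(a) × length(b)"))
    return [a[i] * b[j] * exp((f[i] + g[j] - C[i, j]) / ϵ)
            for i in eachindex(a), j in eachindex(b)]
end

# With zero potentials and zero cost, the coupling is the product measure.
a = [0.5, 0.5]
b = [0.25, 0.75]
π0 = coupling_sketch(zeros(2), zeros(2), a, b, zeros(2, 2), 0.1)
```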
Sinkhorn's algorithm
The above functions rely on unbalanced_sinkhorn!, which uses Sinkhorn's algorithm to calculate the optimal dual potentials.
UnbalancedOptimalTransport.unbalanced_sinkhorn! — Function
function unbalanced_sinkhorn!(
    D::AbstractDivergence,
    C,
    a::DiscreteMeasure,
    b::DiscreteMeasure,
    ϵ = 1e-1;
    tol = 1e-5,
    max_iters = 10^5,
    warn::Bool = true,
) -> NamedTuple
Implements Algorithm 1 of [SFVTP19]. The dual_potential fields of a and b are updated to hold the optimal dual potentials. The density, log_density, and set fields are not modified. The parameters are:
D: the AbstractDivergence used for measuring the cost of creating or destroying mass.
ϵ: the regularization parameter.
C: either a function from a.set × b.set to the real numbers, which should satisfy C(x,y) = C(y,x) and C(x,x) = 0 when applicable, or a precomputed cost matrix as generated by e.g. cost_matrix.
tol: the convergence tolerance.
max_iters: the maximum number of iterations to perform.
warn: whether or not to warn when the maximum number of iterations is reached.
Returns a NamedTuple with f and g, the dual potentials of a and b respectively, the number of iterations performed (iters), and the maximum residual (max_residual), which is the maximum infinity-norm difference between consecutive iterates of the dual potentials at the end of the process. If max_iters is not reached, iteration stops when max_residual falls below tol.
If a and b alias (that is, they refer to the same object), then one must use the returned values f and g rather than a.dual_potential and b.dual_potential.
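The iteration and its stopping rule can be sketched in plain Julia for the balanced special case, where this sketch assumes the aprox step reduces to the identity. The function balanced_sinkhorn_sketch is hypothetical and operates on raw density vectors and a cost matrix; it mirrors the documented tol, max_iters, iters, and max_residual, but it is not the package's implementation.

```julia
# Numerically stable log-sum-exp of a vector.
function logsumexp(x)
    m = maximum(x)
    return m + log(sum(exp.(x .- m)))
end

# Hypothetical log-domain Sinkhorn loop for the balanced case.
# a, b: strictly positive densities; C: length(a) × length(b) cost matrix.
function balanced_sinkhorn_sketch(a, b, C, ϵ; tol = 1e-5, max_iters = 10^5)
    f, g = zeros(length(a)), zeros(length(b))
    loga, logb = log.(a), log.(b)
    iters, max_residual = 0, Inf
    while iters < max_iters && max_residual > tol
        iters += 1
        # Update f from the current g, then g from the new f.
        f_new = [-ϵ * logsumexp(logb .+ (g .- C[i, :]) ./ ϵ) for i in eachindex(a)]
        g_new = [-ϵ * logsumexp(loga .+ (f_new .- C[:, j]) ./ ϵ) for j in eachindex(b)]
        # Largest infinity-norm change between consecutive iterates.
        max_residual = max(maximum(abs.(f_new .- f)), maximum(abs.(g_new .- g)))
        f, g = f_new, g_new
    end
    return (f = f, g = g, iters = iters, max_residual = max_residual)
end

a = [0.5, 0.5]
b = [0.25, 0.75]
C = [abs(x - y) for x in [0.0, 1.0], y in [0.0, 2.0]]
res = balanced_sinkhorn_sketch(a, b, C, 0.5)
```

Working with log densities and log-sum-exp avoids overflow when ϵ is small, which is presumably why DiscreteMeasure stores log_density alongside density.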
Divergences
As described on the page Optimal transport, we use $\varphi$-divergences as penalty terms for mass creation and destruction. This package implements four such divergences, which are all described in Section 2.4 of [SFVTP19], and are listed below. To add your own divergences, see the AbstractDivergence interface section.
UnbalancedOptimalTransport.KL — Type
KL{ρ} <: AbstractDivergence
Represents the divergence ρ*KL(a|b), where KL is the Kullback-Leibler divergence. The parameter ρ is simply a scaling factor.
UnbalancedOptimalTransport.TV — Type
TV{ρ} <: AbstractDivergence
Represents the divergence ρ*TV(u,v) = ρ*norm(u-v,1), where TV is twice the total variation distance. The parameter ρ is simply a scaling factor.
UnbalancedOptimalTransport.Balanced — Type
Balanced <: AbstractDivergence
Represents the divergence Dᵩ(a|b) which is zero if a == b and infinite otherwise. Generalized by RG.
UnbalancedOptimalTransport.RG — Type
RG{l,u} <: AbstractDivergence
Represents the divergence Dᵩ(a|b) which is zero if l*b .<= a .<= u*b and infinite otherwise. Equivalent to Balanced when l == u.
AbstractDivergence interface
To add a divergence MyDivergence, create a struct
struct MyDivergence <: UnbalancedOptimalTransport.AbstractDivergence end
and implement methods for UnbalancedOptimalTransport.aprox and UnbalancedOptimalTransport.φstar. Optionally, one can also implement methods for UnbalancedOptimalTransport.initialize_dual_potential! and sinkhorn_divergence!, as a specialized implementation may obtain better performance.
UnbalancedOptimalTransport.AbstractDivergence — Type
abstract type AbstractDivergence
An abstract type representing Csiszár φ-divergences. Subtypes should implement φstar and aprox, and optionally can implement initialize_dual_potential! and/or sinkhorn_divergence!.
UnbalancedOptimalTransport.aprox — Function
aprox(::AbstractDivergence, ϵ::Number, x::Number) -> Number
The anisotropic proximity operator defined in Def. 2 of [SFVTP19].
UnbalancedOptimalTransport.φstar — Function
φstar(::AbstractDivergence, q::Number) -> Number
The Legendre conjugate of the function φ associated to the divergence.
UnbalancedOptimalTransport.initialize_dual_potential! — Function
initialize_dual_potential!(::AbstractDivergence, a::DiscreteMeasure) -> Nothing
Apply an initialization for the dual potential, for use in unbalanced_sinkhorn!; falls back to zeroing out the dual potential. Specialized implementations can improve performance, but should not affect correctness.
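The two required methods can be sketched for a KL-type penalty. MyKL is a hypothetical type introduced only for illustration (the package already provides KL), and the formulas are an assumption based on the standard closed forms for the KL case: φstar(q) = ρ*(exp(q/ρ) - 1) and aprox(x) = x/(1 + ϵ/ρ). The sketch assumes the package is installed.

```julia
import UnbalancedOptimalTransport
const UOT = UnbalancedOptimalTransport

# Hypothetical divergence: a KL-type penalty with scaling ρ stored in a field.
struct MyKL <: UOT.AbstractDivergence
    ρ::Float64
end

# Anisotropic proximity operator; argument order follows the documented
# signature aprox(::AbstractDivergence, ϵ, x).
UOT.aprox(D::MyKL, ϵ::Number, x::Number) = x / (1 + ϵ / D.ρ)

# Legendre conjugate of φ(p) = ρ*(p*log(p) - p + 1).
UOT.φstar(D::MyKL, q::Number) = D.ρ * (exp(q / D.ρ) - 1)
```

Since initialize_dual_potential! and sinkhorn_divergence! have fallbacks, these two methods should be enough for MyKL(1.0) to be passed wherever an AbstractDivergence is expected.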
Utilities
sinkhorn_divergence! uses the following function, which may be specialized to improve performance.
UnbalancedOptimalTransport.fdot — Function
fdot(f, u, v) -> Number
A generic, allocation-free implementation of dot(u, f.(v)). It may be faster to provide a specialized method that dispatches to BLAS or a similar optimized routine.
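What "allocation-free" means here can be shown with a short loop. The function fdot_sketch below is a hypothetical re-implementation for real-valued u, written only to illustrate the idea; it is not the package's code.

```julia
using LinearAlgebra: dot

# Hypothetical re-implementation: accumulates u[i] * f(v[i]) in a loop, so the
# intermediate array f.(v) is never materialized.
function fdot_sketch(f, u::AbstractVector{<:Real}, v::AbstractVector)
    length(u) == length(v) ||
        throw(DimensionMismatch("u and v must have the same length"))
    s = zero(float(eltype(u)))
    for i in eachindex(u, v)
        s += u[i] * f(v[i])
    end
    return s
end

u = [1.0, 2.0, 3.0]
v = [0.0, 1.0, 2.0]
result = fdot_sketch(exp, u, v)   # agrees with dot(u, exp.(v))
```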
The following function, cost_matrix, is used in unbalanced_sinkhorn!, OT!, and optimal_coupling! to precompute the costs given a cost function.
UnbalancedOptimalTransport.cost_matrix — Function
cost_matrix([C,] a, b) -> Matrix
Precompute the cost matrix given a cost function C. If no function C is provided, the default is C(x,y) = norm(x-y).
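The precomputation amounts to evaluating C on every pair of points. The sketch below uses a hypothetical stand-in, cost_matrix_sketch, which works directly on the underlying point sets rather than on DiscreteMeasure objects. The orientation (rows indexed by the first set, columns by the second) is an assumption.

```julia
using LinearAlgebra: norm

# Hypothetical stand-in: evaluate the cost on every pair (x, y).
cost_matrix_sketch(C, xs, ys) = [C(x, y) for x in xs, y in ys]

# Default cost, mirroring the documented default C(x, y) = norm(x - y).
cost_matrix_sketch(xs, ys) = cost_matrix_sketch((x, y) -> norm(x - y), xs, ys)

xs = [[0.0, 0.0], [1.0, 0.0]]
ys = [[0.0, 0.0], [0.0, 1.0], [3.0, 4.0]]
M = cost_matrix_sketch(xs, ys)   # 2×3 matrix with M[i, j] == norm(xs[i] - ys[j])
```

Precomputing the matrix once is useful when the same pair of supports is reused across several calls, since the cost function is then not re-evaluated each time.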