What?

Learning to do combinatorial optimisation (mixed-integer programming) via embedding the discrete solution space into a continuous one, optimising in it and decoding the solution back.

Why?

Combinatorial optimisation is hard. There are two main approaches. The traditional one is to hand-design a general exact algorithm or heuristic. However, if approximate solutions are acceptable, we can leverage data and use ML to learn an approximate search algorithm.

How?

Main idea, source: https://openreview.net/pdf?id=P3FX9pUev-

Turn an optimisation problem over discrete $x$ into another one involving minimisation of a function of a continuous variable $z$ and recovering the solution from a decoder $g(z)$.

<aside> 💡 Another way to think about this: we want to learn a solution-improvement operator analogous to the gradient, which is unavailable in discrete problems.

</aside>

There are two functions to train here: the surrogate $s(\cdot)$ and the decoder $p(\cdot|z)$. The authors use gradient-based meta-learning to learn the parameters $\theta$:

```python
z = Normal(0, I).sample()  # sample a random latent solution
for k in range(K):
    # one gradient descent step on the surrogate s_theta per iteration
    z = z - eta * grad(s_theta, z)
```
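As a runnable toy illustration of those descent steps, here is a sketch that replaces autodiff with a central finite-difference gradient; the quadratic surrogate and all names here are illustrative, not the paper's setup:

```python
import numpy as np

def latent_descent(s, z0, eta=0.1, steps=20, eps=1e-5):
    """Run `steps` gradient-descent steps on a surrogate s(z),
    estimating the gradient by central finite differences
    (a stand-in for autodiff in this self-contained sketch)."""
    z = z0.astype(float)
    for _ in range(steps):
        grad = np.array([
            (s(z + eps * e) - s(z - eps * e)) / (2 * eps)
            for e in np.eye(len(z))  # one unit vector per coordinate
        ])
        z = z - eta * grad
    return z

# Quadratic surrogate with minimum at (1, -2); descent should approach it.
z_final = latent_descent(lambda z: np.sum((z - np.array([1.0, -2.0])) ** 2),
                         np.zeros(2), eta=0.2, steps=50)
```

In the real method the decoded sample $x \sim p(\cdot|z^{(k)})$ would be read off after each step; here we only track $z$.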

All of this is trained with the following loss:

$$ \mathbb{E}_{p(\tau)}\sum_{k=1}^K\mathbb{E}_{x\sim p_{\tau, \theta}(\cdot|z^{(k)})}\big[\gamma(x) + \alpha\,\|\max\{0, Ax-b\}\|^2 + \beta\max\{0, s_{\tau, \theta}(z^*) - s_{\tau, \theta}(z^{(k)})\} \big], $$

where $\tau=\{c, A, b\}$ is the optimisation problem. The three terms in the square brackets are the objective, feasibility, and latent-supervision terms. The objective term $\gamma(x)$ is the original objective $c^Tx$ rescaled to $[0,1]$. The feasibility term penalises constraint violations (i.e. $Ax > b$). The last term (latent supervision) exploits knowledge of the optimal solution during training, pushing the surrogate value at the optimum's embedding below that of every other point on the trajectory.
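A minimal numpy sketch of those three terms for a single decoded sample; `latent_mip_loss` and its argument names are mine, and the $[0,1]$ rescaling of the objective is assumed to have happened upstream:

```python
import numpy as np

def latent_mip_loss(c, A, b, x, s_traj, s_star, alpha=1.0, beta=1.0):
    """Objective + feasibility + latent supervision for one sample x.

    c, A, b -- MIP data (minimise c^T x subject to A x <= b)
    x       -- decoded candidate solution
    s_traj  -- surrogate values s(z^(1)), ..., s(z^(K)) along the trajectory
    s_star  -- surrogate value at the embedding z* of the known optimum
    """
    objective = c @ x                                # assumed rescaled to [0, 1]
    violation = np.maximum(0.0, A @ x - b)           # nonzero only where Ax > b
    feasibility = alpha * np.sum(violation ** 2)     # squared-norm penalty
    # latent supervision: s(z*) should lie below every trajectory point
    supervision = beta * np.sum(np.maximum(0.0, s_star - s_traj))
    return objective + feasibility + supervision
```

For a feasible sample of a problem whose optimum embedding already scores lowest, all three terms vanish.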

The surrogate and the decoder are GNNs operating on a bipartite representation of MIPs: variables and constraints are the two node sets, and an edge indicates that a variable participates in a constraint (we used the same encoding in Graph-Q-SAT, where edges encoded participation of variables in clauses). The surrogate value is computed as a linear function of the global component's output (see Battaglia et al. for details), while the variable nodes' output features serve as logits for the decoder.
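To make the bipartite encoding concrete, here is a sketch extracting node features and the variable-constraint edge list from dense MIP data; the feature choice (objective coefficients and right-hand sides) is a simplification of the paper's actual features:

```python
import numpy as np

def mip_to_bipartite(c, A, b):
    """Build a bipartite variable-constraint graph for min c^T x s.t. A x <= b.

    Returns per-variable features, per-constraint features, and an edge list
    (var_idx, cons_idx, coefficient) for every nonzero entry of A.
    """
    var_feats = c.reshape(-1, 1)   # objective coefficient per variable node
    cons_feats = b.reshape(-1, 1)  # right-hand side per constraint node
    rows, cols = np.nonzero(A)     # constraint index i, variable index j
    edges = [(j, i, A[i, j]) for i, j in zip(rows, cols)]
    return var_feats, cons_feats, edges
```

In practice $A$ would be sparse and the features richer (bounds, variable types, etc.); this only shows the graph's shape.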

In addition to the above, a second loss is used to train the decoder (it is not clear how the two losses are combined):

$$ -\sum_{k=1}^K{\log{p_{\theta}(x^{(k)}|g_{\theta}(x^{(k)}))}} - \log{p_{\theta}(x^*|g_{\theta}(x^*))}, \quad g_{\theta}(x)\approx \arg\max_z{\log{p_{\theta}(x|z)}} $$
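Assuming binary decision variables, each $-\log p_\theta(x|z)$ term is a Bernoulli negative log-likelihood over the decoder's per-variable logits; a numerically stable sketch (function name mine):

```python
import numpy as np

def reconstruction_nll(x, logits):
    """Bernoulli negative log-likelihood -log p(x | z) for a binary
    solution x, given one decoder logit per variable."""
    log_p1 = -np.logaddexp(0.0, -logits)  # log sigmoid(logit)   = log P(x_i=1)
    log_p0 = -np.logaddexp(0.0, logits)   # log(1 - sigmoid)     = log P(x_i=0)
    return -np.sum(x * log_p1 + (1 - x) * log_p0)
```

Using `logaddexp` avoids overflow for large-magnitude logits, which a naive `log(sigmoid(...))` would not.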

To evaluate, the approach above is used as a primal heuristic inside the SCIP solver. Good incumbent solutions found by the heuristic let branch-and-bound prune suboptimal branches, so the computational budget is not spent in futile regions of the search space.