What?
Persistent Message Passing (PMP): persist node states in a smart way instead of overwriting them.
Why?
GNNs work well with Markovian dynamics, i.e. rewriting entities' states and querying only the latest. The non-Markovian case is important and worth further investigation.
How?
Source: original paper.
The setup assumes a set of entities $\mathcal{E}$ (these are not edges!) which are related. At every time step, we apply some operation that changes some of the entities' states and emits an output. An operation may also be applied to a set of past states.
More formally, we have a sequence of pairs $(\mathcal{E}^{(t)}, s^{(t)})$, where $\mathcal{E}^{(t)}=(e_1^{(t)}, ..., e_n^{(t)})$ are feature vectors and $s^{(t)}\in\{1,...,t-1\}$ are snapshot indices. A problem the authors consider is to predict operation outputs $y^{(t)}$, where we have persistency: $(\mathcal{E}^{(t)}, s^{(t)}) = (\mathcal{E}^{(t')}, s^{(t')}) \implies y^{(t)} = y^{(t')}\;\forall t'<t.$
A naive way of satisfying the above would be to store a full copy of all entity states at every step, which is memory-costly. We can do better.
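For instance, here is a minimal sketch of that naive approach (my illustration, not from the paper): keep a full copy of all entity states at every step, so that any past snapshot can be queried.

```python
import copy

class NaiveHistory:
    """Supports operations on past states by storing full per-step copies."""
    def __init__(self, states):
        self.snapshots = [list(states)]          # snapshot 0

    def step(self, update_fn):
        new = copy.deepcopy(self.snapshots[-1])  # O(n) extra memory every step
        update_fn(new)
        self.snapshots.append(new)

    def at(self, s):
        return self.snapshots[s]                 # any past snapshot s

h = NaiveHistory([5, 3, 8])
h.step(lambda a: a.__setitem__(1, 9))            # set a[1] = 9
print(h.at(0), h.at(1))                          # [5, 3, 8] [5, 9, 8]
```

After $t$ steps over $n$ entities this stores $O(tn)$ states, even when each step touches a single entity; persistent data structures (and PMP) instead share the unchanged states across snapshots.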
The picture above gives an intuitive explanation of how PMP works (I highly recommend looking at the figure before reading the descriptions below/in the paper):
- An encode-process-decode architecture is used. There are two encoders: one for relevance and one for the operation. Message passing (the processing step) is likewise done separately on the operation latents and the relevance latents.
- After a message-passing step, a binary per-state relevance mask is computed from the relevance latents to decide which hidden states to select for the further computation of the response $y^{(t)}$ and of persistency.
- A binary per-state persistency mask is predicted from the latent features as well. The states that are chosen are copied into the set of hidden states for the next round.
- After that, the adjacency matrix is updated so that the old adjacency is preserved and the newly added nodes get the same relations as their predecessors: if $a \to b \to c$ and $b$ persists as $\hat{b}$, then $\hat{b}$ gets an incoming edge from $a$ and an outgoing edge to $c$ (see the sketch after this list).
- Finally, a readout outputs the prediction $y^{(t)}$, which is a function of the aggregated latents and relevance latents.
- Linear encoders, decoders, masking, and query networks are used. Max is used for edge-to-vertex aggregation and for the readout (i.e., for aggregating the output graph into the prediction).
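To make the pipeline above concrete, here is a schematic sketch of one PMP step (my reading of the figure; all weight names are hypothetical, plain linear maps stand in for the paper's encoders/decoders, and this is not the authors' implementation):

```python
import numpy as np

def max_message_pass(h, adj, w):
    """One message-passing round with max edge-to-vertex aggregation."""
    msgs = h @ w                                   # linear message function
    out = np.zeros_like(h)
    for v in range(h.shape[0]):
        senders = np.nonzero(adj[:, v])[0]
        if senders.size:
            out[v] = msgs[senders].max(axis=0)     # elementwise max over in-edges
    return out

def pmp_step(h, adj, op, p):
    """h: (n, d) persisted hidden states; adj: (n, n) bool; op: (d_op,)
    operation features; p: dict of weight matrices (names are mine)."""
    n, d = h.shape
    x = np.concatenate([h, np.tile(op, (n, 1))], axis=1)
    # two encoders + two processors: operation latents and relevance latents
    z_op = max_message_pass(x @ p["enc_op"], adj, p["proc_op"])
    z_rel = max_message_pass(x @ p["enc_rel"], adj, p["proc_rel"])
    # per-state relevance mask from the relevance latents
    relevant = (z_rel @ p["rel_head"]).ravel() > 0
    # per-state persistency mask, taken among the relevant states
    persist = ((z_op @ p["per_head"]).ravel() > 0) & relevant
    # persistent states are *copied*; old states are never overwritten
    h_next = np.concatenate([h, z_op[persist]], axis=0)
    # adjacency update: old edges kept; each copy inherits its predecessor's edges
    k = int(persist.sum())
    adj_next = np.zeros((n + k, n + k), dtype=bool)
    adj_next[:n, :n] = adj
    for j, i in enumerate(np.nonzero(persist)[0]):
        adj_next[n + j, :n] = adj[i]       # outgoing, like the predecessor
        adj_next[:n, n + j] = adj[:, i]    # incoming, like the predecessor
    # readout: max-aggregate the relevant latents, then decode the prediction
    agg = z_op[relevant].max(axis=0) if relevant.any() else np.zeros(d)
    return agg @ p["dec"], h_next, adj_next
```

The masks are hard-thresholded here; during training their logits are supervised directly, as described next.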
The whole thing is trained with cross-entropy losses on the relevance and persistency masks against the ground-truth masks.
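A minimal sketch of that objective, assuming per-node mask logits and 0/1 ground-truth masks (all names and values hypothetical):

```python
import numpy as np

def bce_with_logits(logits, targets):
    # numerically stable binary cross-entropy, averaged over nodes
    return np.mean(np.maximum(logits, 0) - logits * targets
                   + np.log1p(np.exp(-np.abs(logits))))

rel_logits, rel_true = np.array([2.0, -1.5, 0.3]), np.array([1.0, 0.0, 1.0])
per_logits, per_true = np.array([-0.3, 0.8, 1.2]), np.array([0.0, 1.0, 1.0])
loss = bce_with_logits(rel_logits, rel_true) + bce_with_logits(per_logits, per_true)
```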
The motivation for all of the above becomes clearer if we look at the evaluation task. The authors consider range minimum query (RMQ), where we have an array of $K$ integers and two operations: i) set an element of the array to a new value; ii) query the minimum value over a particular contiguous range, at a previous point in time.
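A made-up trace of what these operations look like (indexing conventions are mine):

```python
a = [5, 3, 8, 6, 2]              # initial array, K = 5
# t = 1: set(i=1, v=9)           -> snapshot 1 is [5, 9, 8, 6, 2]
# t = 2: set(i=4, v=1)           -> snapshot 2 is [5, 9, 8, 6, 1]
# t = 3: query(l=0, r=3, s=1)    -> min of a[0..3] in snapshot 1 = min(5, 9, 8, 6) = 5
```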
And?
- The authors imitate segment trees because i) it's an important data structure and ii) it's a good use case for the persistent computation mechanism.
- For the imitation, teacher forcing is used. The model is trained on arrays of size $K=5$ (9 initial nodes, 26 nodes after 5 updates) and tested out-of-distribution on $K=10$ (19 initial nodes, 63 nodes after 10 steps); see the path-copying sketch at the end of this note.
- PMP matches an oracle MPNN that is given the correct snapshot (and is about $2\times$ better than a plain MPNN), showing that the network remembers the right states.
- The introduction is a nicely condensed summary of the emerging subfield of algorithmic reasoning.
- Using $\mathcal{E}$ for entities is a bit confusing; I always think of edges in a graph when I see $\mathcal{E}$. Calling the feature vectors operations is confusing to me as well.
- It's not clear to me whether the processor network is shared between the operation latents and the relevance latents. If it is, what's the motivation for that? I find it interesting how the paper uses two sets of nodes/edges working separately.
- PMP makes it possible to run computations on only a tiny portion of the input. This is a cool idea.
- This paper is somewhat related to Working Memory Graphs, which has a circular list of nodes whose states persist over time. However, in WMG these memos are an aggregated version of the graph rather than the state of a single node.
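To make the segment-tree connection concrete, here is a minimal path-copying persistent segment tree for RMQ, i.e. the classic data structure PMP is trained to imitate (my own sketch, not the authors' code). Updating each of the $K=5$ elements once is one schedule that reproduces the node counts quoted above (9 nodes initially, 26 across all snapshots after 5 updates; likewise 19 → 63 for $K=10$ with one update per element).

```python
import math

class Node:
    def __init__(self, lo, hi, mn, left=None, right=None):
        self.lo, self.hi, self.mn = lo, hi, mn
        self.left, self.right = left, right

def build(a, lo, hi):
    """Segment tree over a[lo:hi] with 2*(hi - lo) - 1 nodes."""
    if hi - lo == 1:
        return Node(lo, hi, a[lo])
    mid = (lo + hi) // 2
    l, r = build(a, lo, mid), build(a, mid, hi)
    return Node(lo, hi, min(l.mn, r.mn), l, r)

def update(node, i, v):
    """Path copying: only the O(log K) nodes on the root-to-leaf path are
    re-created; everything else is shared with the previous snapshot."""
    if node.hi - node.lo == 1:
        return Node(node.lo, node.hi, v)
    mid = (node.lo + node.hi) // 2
    if i < mid:
        l, r = update(node.left, i, v), node.right
    else:
        l, r = node.left, update(node.right, i, v)
    return Node(node.lo, node.hi, min(l.mn, r.mn), l, r)

def query(node, lo, hi):
    """Minimum over [lo, hi) in the snapshot rooted at `node`."""
    if hi <= node.lo or node.hi <= lo:
        return math.inf
    if lo <= node.lo and node.hi <= hi:
        return node.mn
    return min(query(node.left, lo, hi), query(node.right, lo, hi))

def count_nodes(roots):
    seen = set()
    def walk(n):
        if n is not None and id(n) not in seen:
            seen.add(id(n)); walk(n.left); walk(n.right)
    for r in roots:
        walk(r)
    return len(seen)

roots = [build([5, 3, 8, 6, 2], 0, 5)]     # snapshot 0: 9 nodes
for i in range(5):
    roots.append(update(roots[-1], i, i))  # one update per element
print(count_nodes(roots))                  # 26 nodes shared across 6 snapshots
print(query(roots[2], 0, 4))               # RMQ against past snapshot 2 -> 0
```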