RRN

Paper: Recurrent Relational Networks

Motivation

  • Previous models fall short when faced with problems that require basic relational reasoning.

    • MLPs and CNNs take the entire Sudoku as input and output the entire solution in a single forward pass, ignoring the following:

      • The inductive bias that objects exist in the world.

      • Objects affect each other in a consistent manner.

The Relational Network is a first-step attempt towards a simple module for reasoning.

Relational Networks

Paper: Relational Networks

The Relational Network was proposed for relational reasoning: the capacity to compute relations is baked into the RN architecture without needing to be learned.

$$\operatorname{RN}(O) = f_{\phi}\left( \sum_{i,j} g_{\theta}\left( o_i, o_j \right) \right)$$
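
A minimal PyTorch sketch of this formula (the layer sizes, object dimension, and two-layer MLPs below are illustrative assumptions, not the paper's exact configuration):

```python
import torch
import torch.nn as nn

class RelationalNetwork(nn.Module):
    """Minimal sketch of RN(O) = f_phi(sum_{i,j} g_theta(o_i, o_j))."""
    def __init__(self, obj_dim=32, hidden=128, out_dim=10):
        super().__init__()
        self.g = nn.Sequential(nn.Linear(2 * obj_dim, hidden), nn.ReLU(),
                               nn.Linear(hidden, hidden), nn.ReLU())
        self.f = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(),
                               nn.Linear(hidden, out_dim))

    def forward(self, objects):                         # objects: (batch, n, obj_dim)
        b, n, d = objects.shape
        o_i = objects.unsqueeze(2).expand(b, n, n, d)   # o_i repeated along dim 2
        o_j = objects.unsqueeze(1).expand(b, n, n, d)   # o_j repeated along dim 1
        pairs = torch.cat([o_i, o_j], dim=-1)           # all ordered pairs (o_i, o_j)
        return self.f(self.g(pairs).sum(dim=(1, 2)))    # sum over pairs, then f_phi
```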

Strengths

  • RNs learn to infer relations

    • No prior about what object relations exist is required.

    • Considering only some object pairs is also possible.

  • RNs are data efficient: a single function $g_\theta$ is applied to every object pair, so each example effectively provides a whole batch of pair-wise training signals for $g_\theta$.

  • RNs operate on a set of objects

    • The RN is invariant to the order of objects in the input.

However, it is limited to performing a single relational operation.

Experiments

CLEVR dataset

CLEVR is a dataset for visual QA; its questions test various aspects of visual reasoning, including:

  1. attribute identification

  2. counting

  3. comparison

  4. spatial relations

  5. logical operations

Example image:

Question:

  • Are there an equal number of large things and metal spheres?

There are two versions of CLEVR:

  1. pixel version.

  2. state description version.

    • 3D coordinates, color, shape, …

For more examples, see CLEVR.

A key feature of CLEVR is that many questions are explicitly relational in nature; traditionally powerful QA architectures are unable to solve CLEVR.

  • For compare-attribute and count questions (which involve relations across objects), the previous CNN+LSTM models performed little better than the baseline.

Sort-of-CLEVR

A simplified version of CLEVR.

bAbI

A text-based reasoning dataset.

Model

  • Dealing with pixels (CLEVR): each of the $d^2$ $k$-dimensional cells in the final CNN feature map was treated as an object for the RN.

  • Conditioning on queries (CLEVR): use an LSTM to process the question words and take its final state as the question embedding $q$; the RN is then modified to $a = f_{\phi}\left( \sum_{i,j} g_{\theta}\left( o_i, o_j, q \right) \right)$ to take the query into account (see the sketch after this list).

  • Dealing with natural language (bAbI): tag support sentences with their relative positions, then process each sentence word by word with an LSTM; each sentence is regarded as an object.
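
A hedged sketch of the query-conditioned variant, building on the RN sketch above; `g` and `f` are assumed MLPs and `q` the question LSTM's final state:

```python
import torch

def rn_with_query(objects, q, g, f):
    """Sketch of a = f(sum_{i,j} g(o_i, o_j, q)); shapes are illustrative."""
    b, n, d = objects.shape                                    # (batch, n, obj_dim)
    o_i = objects.unsqueeze(2).expand(b, n, n, d)
    o_j = objects.unsqueeze(1).expand(b, n, n, d)
    q_rep = q.unsqueeze(1).unsqueeze(2).expand(b, n, n, q.shape[-1])
    pairs = torch.cat([o_i, o_j, q_rep], dim=-1)               # append q to every pair
    return f(g(pairs).sum(dim=(1, 2)))
```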

Results

  • On CLEVR: super-human performance ($95.5\%$ accuracy from pixels).

  • On bAbI: solved 18/20 tasks.

Recurrent Relational Network

Message Passing

At each step $t$, each node has a hidden state vector $h_i^t$, and each node sends a message to each of its neighboring nodes:

$$m_{ij}^{t} = f\left( h_{i}^{t-1}, h_{j}^{t-1} \right)$$

where $f$ is a multi-layer perceptron. Each node aggregates all incoming messages by summing them:

$$m_{j}^{t} = \sum_{i \in N(j)} m_{ij}^{t}$$

where $N(j)$ is the set of nodes adjacent to $j$. The hidden state is then updated via a node update function $g$:

$$h_{j}^{t} = g\left( h_{j}^{t-1}, x_{j}, m_{j}^{t} \right)$$

The process is similar to the Universal Transformer (though this paper predates it).
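
A hedged Python sketch of one message-passing step over an explicit edge list (a dense loop for clarity; `f` is assumed to map a concatenated pair of node states to a hidden-sized message, and `g` to consume the concatenation of previous state, input features, and summed messages):

```python
import torch

def message_passing_step(h_prev, x, edges, f, g):
    """One RRN step: m_ij = f(h_i, h_j), m_j = sum_i m_ij, h_j = g(h_j, x_j, m_j).

    h_prev : (n, hidden) node states at step t-1
    x      : (n, feat)   fixed per-node input features
    edges  : iterable of (i, j) pairs; node i sends a message to node j
    """
    n, hidden = h_prev.shape
    m = torch.zeros(n, hidden)
    for i, j in edges:
        m[j] = m[j] + f(torch.cat([h_prev[i], h_prev[j]]))  # m_ij^t, summed into node j
    return g(torch.cat([h_prev, x, m], dim=-1))             # h_j^t for all nodes j
```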

Training

The output distribution of node $i$ at time $t$ is given by:

$$o_{i}^{t} = r\left( h_{i}^{t} \right)$$

where $r$ is an MLP followed by a softmax. The loss function adopted is the cross-entropy:

$$l^{t} = -\sum_{i=1}^{I} \log o_{i}^{t}\left[ y_{i} \right]$$

where $y_i$ is the target digit at position $i$.

Convergent Message Passing

That is, the cross-entropy between output and target is minimized at every step; at test time only the output probability at the last step is used.
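
A hedged sketch of this step-wise loss, assuming the per-step logits have already been collected into a list:

```python
import torch.nn.functional as F

def rrn_loss(step_logits, targets):
    """Sum the per-node cross-entropy over every message-passing step.

    step_logits : list of (n_nodes, n_classes) tensors, one entry per step t
    targets     : (n_nodes,) tensor with the target class (digit) for each node
    """
    # Training signal at every step; at test time only step_logits[-1] is used.
    return sum(F.cross_entropy(logits, targets, reduction="sum")
               for logits in step_logits)
```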

Experiment

Solving Sudokus

  • Solving a Sudoku requires many steps of methodical deduction, keeping intermediate results, and possibly trying several partial solutions before the right one is found.

  • The authors trained an RRN to solve Sudokus by considering each cell an object; each cell affects every other cell in the same row, column, and box (see the adjacency sketch after this list).

  • The network learned a strategy that solves $96.6\%$ of the hardest Sudokus (only 17 given numbers).
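
A small sketch of how the Sudoku constraint graph could be constructed (plain Python; the cell indexing is an assumption for illustration):

```python
def sudoku_edges():
    """Directed edges of the Sudoku constraint graph: every cell is connected to
    the other cells in its row, column, and 3x3 box (81 nodes, 20 neighbors each)."""
    def box(r, c):
        return (r // 3) * 3 + c // 3
    edges = []
    for a in range(81):
        ra, ca = divmod(a, 9)
        for b in range(81):
            rb, cb = divmod(b, 9)
            if a != b and (ra == rb or ca == cb or box(ra, ca) == box(rb, cb)):
                edges.append((a, b))
    return edges  # 81 * 20 = 1620 directed edges
```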

Details

Feature definition ($d_j = 0$ if the digit is not given):

$$x_{j} = \operatorname{MLP}\left( \operatorname{concat}\left( \operatorname{embed}\left( d_{j} \right), \operatorname{embed}\left( \mathrm{row}_{j} \right), \operatorname{embed}\left( \mathrm{column}_{j} \right) \right) \right)$$
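
A hedged sketch of this node feature; embedding and MLP sizes are illustrative assumptions:

```python
import torch
import torch.nn as nn

class SudokuNodeFeatures(nn.Module):
    """x_j = MLP(concat(embed(d_j), embed(row_j), embed(column_j))); d_j = 0 means 'not given'."""
    def __init__(self, emb=16, out=96):
        super().__init__()
        self.digit = nn.Embedding(10, emb)   # digits 0..9, 0 = unfilled cell
        self.row = nn.Embedding(9, emb)
        self.col = nn.Embedding(9, emb)
        self.mlp = nn.Sequential(nn.Linear(3 * emb, out), nn.ReLU(),
                                 nn.Linear(out, out))

    def forward(self, digits, rows, cols):   # each: (n_cells,) integer tensors
        feats = torch.cat([self.digit(digits), self.row(rows), self.col(cols)], dim=-1)
        return self.mlp(feats)
```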

Node update (the node hidden state $h_j$ and LSTM state $s_j$ are updated by an LSTM across the message-passing steps):

$$h_{j}^{t}, s_{j}^{t} = \mathrm{LSTM}_{G}\left( \operatorname{MLP}\left( \operatorname{concat}\left( x_{j}, m_{j}^{t} \right) \right), s_{j}^{t-1} \right)$$

Run the network for $32$ steps.
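
A hedged sketch of this LSTM-based node update using `nn.LSTMCell`; sizes are illustrative:

```python
import torch
import torch.nn as nn

class NodeUpdate(nn.Module):
    """h_j^t, s_j^t = LSTM_G(MLP(concat(x_j, m_j^t)), s_j^{t-1}); state carried across steps."""
    def __init__(self, feat=96, hidden=96):
        super().__init__()
        self.pre = nn.Sequential(nn.Linear(feat + hidden, hidden), nn.ReLU())
        self.lstm = nn.LSTMCell(hidden, hidden)

    def forward(self, x, m, state):            # x: (n, feat), m: (n, hidden)
        h_prev, c_prev = state                 # previous LSTM hidden and cell states
        h, c = self.lstm(self.pre(torch.cat([x, m], dim=-1)), (h_prev, c_prev))
        return h, (h, c)                       # new node states and LSTM state
```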

Natural Language Inference

  • Use the same experiment setting as Relational Networks.

  • RRN solves 20/20 tasks in 13 out of 15 runs.

  • bAbI does not require multi-step reasoning; a single step is enough.

Details

Feature definition:

$$x_{i} = \operatorname{MLP}\left( \operatorname{concat}\left( \operatorname{last}\left( \operatorname{LSTM}_{S}\left( s_{i} \right) \right), \operatorname{last}\left( \operatorname{LSTM}_{Q}\left( q \right) \right), \operatorname{onehot}\left( p_{i} + o \right) \right) \right)$$

Node update (same form as for Sudoku):

$$h_{j}^{t}, s_{j}^{t} = \mathrm{LSTM}_{G}\left( \operatorname{MLP}\left( \operatorname{concat}\left( x_{j}, m_{j}^{t} \right) \right), s_{j}^{t-1} \right)$$

A graph-level output, conditioned on all node hidden states, is then computed as:

$$o^{t} = \operatorname{MLP}\left( \sum_{i} h_{i}^{t} \right)$$
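
A hedged sketch of this sum-pooled readout; the hidden size and number of answer classes are illustrative assumptions:

```python
import torch.nn as nn

hidden, n_answers = 96, 32
readout = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(),
                        nn.Linear(hidden, n_answers))

def graph_output(h):              # h: (n_nodes, hidden) node states at step t
    return readout(h.sum(dim=0))  # sum over nodes, then the readout MLP
```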

Pretty-CLEVR

The dataset is an extension of Sort-of-CLEVR, but also has questions requiring varying degrees of relational reasoning.

Question form:

  • Starting at object X, which object is N jumps away?

Jumps are defined as moving to the closest object, without going to an object already visited.
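
A hedged sketch of these question semantics as plain Python (the function name and 2D coordinate representation are illustrative assumptions):

```python
import math

def n_jumps_away(positions, start, n):
    """Starting at object `start`, jump n times, each time to the closest object
    not yet visited. positions: list of (x, y) coordinates, one per object."""
    current, visited = start, {start}
    for _ in range(n):
        candidates = [i for i in range(len(positions)) if i not in visited]
        nxt = min(candidates, key=lambda i: math.dist(positions[current], positions[i]))
        visited.add(nxt)
        current = nxt
    return current
```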
