RRN
Paper: Recurrent Relational Networks
Motivation
Previous models fall short when faced with problems that require basic relational reasoning.
MLPs and CNNs take the entire Sudoku as input and output the entire solution in a single forward pass, ignoring the following:
The inductive bias that objects exist in the world.
Objects affect each other in a consistent manner.
The Relational Network is a first-step attempt towards a simple module for reasoning.
Relational Networks
Paper: Relational Networks
The Relational Network was proposed for relational reasoning: the capacity to compute relations is baked into the RN's architecture without needing to be learned.
Strengths
RNs learn to infer relations
No prior knowledge about which object relations exist is required.
Considering only some of the object pairs is also possible.
RNs are data efficient: a single function $g_\theta$ is shared across all object pairs, which can be viewed as batching every pair through one function.
RNs operate on a set of objects
The RN is invariant to the order of objects in the input
However, it is limited to performing a single relational operation.
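For reference, the core RN computation from the Relational Networks paper composes a pair function and a readout, both MLPs:

$$\mathrm{RN}(O) = f_\phi\left(\sum_{i,j} g_\theta(o_i, o_j)\right)$$

where $g_\theta$ infers the relation between objects $o_i$ and $o_j$, and $f_\phi$ aggregates all inferred relations into an answer.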
Experiments
CLEVR dataset
CLEVR is a dataset for visual question answering; the questions in CLEVR test various aspects of visual reasoning, including:
attribute identification
counting
comparison
spatial relations
logical operations
Example image (omitted):
Question: Are there an equal number of large things and metal spheres?
There are two versions of CLEVR:
pixel version
state description version (3D coordinates, color, shape, …)
For more examples, see CLEVR.
A key feature of CLEVR is that many questions are explicitly relational in nature; traditionally powerful QA architectures are unable to solve CLEVR.
For compare attribute and count questions (those involving relations across objects), the previous model (CNN+LSTM) performed little better than the baseline.
Sort-of-CLEVR
A simplified version of CLEVR: 2D scenes of simple colored shapes, with questions split into relational and non-relational types.
bAbI
Text-based reasoning tasks.
Model
Dealing with pixels (CLEVR): each of the $k$-dimensional cells in the final $d \times d$ feature map was treated as an object for the RN.
Conditioning on queries (CLEVR): use an LSTM to process the question words and take its final state as the question embedding $q$; the RN is then modified to $a = f_\phi\left(\sum_{i,j} g_\theta(o_i, o_j, q)\right)$ to take the query into account (see the sketch after this list).
Dealing with natural language (bAbI): tag the support sentences with their relative positions, then process each sentence word-by-word with an LSTM; each sentence (its final LSTM state) is regarded as an object.
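A minimal PyTorch-style sketch of this query-conditioned RN; the module and dimension names (`g_theta`, `f_phi`, `obj_dim`, ...) are illustrative, not the paper's exact architecture:

```python
import torch
import torch.nn as nn

class RelationalNetwork(nn.Module):
    """Sketch of the query-conditioned RN: a = f_phi(sum_{i,j} g_theta(o_i, o_j, q))."""

    def __init__(self, obj_dim, query_dim, hidden_dim, out_dim):
        super().__init__()
        # g_theta infers a relation for each ordered object pair, conditioned on the query.
        self.g_theta = nn.Sequential(
            nn.Linear(2 * obj_dim + query_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        # f_phi maps the summed pair representations to answer logits.
        self.f_phi = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, objects, query):
        # objects: (batch, n, obj_dim); query: (batch, query_dim)
        b, n, d = objects.shape
        o_i = objects.unsqueeze(2).expand(b, n, n, d)    # first element of each pair
        o_j = objects.unsqueeze(1).expand(b, n, n, d)    # second element of each pair
        q = query.unsqueeze(1).unsqueeze(2).expand(b, n, n, query.size(-1))
        pairs = torch.cat([o_i, o_j, q], dim=-1)         # all n*n ordered pairs
        relations = self.g_theta(pairs).sum(dim=(1, 2))  # summing makes it order invariant
        return self.f_phi(relations)
```

Because the same $g_\theta$ is applied to every pair and the pair outputs are summed, the answer does not depend on the order in which objects are presented.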
Results
On CLEVR: 95.5% accuracy from pixels, exceeding human performance.
On bAbI: solved 18/20 tasks.
Recurrent Relational Network
Message Passing
At each step $t$, each node $i$ has a hidden state vector $h_i^t$, and each node sends a message to each of its neighboring nodes:

$$m_{ij}^t = f\left(h_i^{t-1}, h_j^{t-1}\right)$$

where $f$ is a multi-layer perceptron. A node needs to consider all the incoming messages, so they are summed:

$$m_j^t = \sum_{i \in N(j)} m_{ij}^t$$

where $N(j)$ are the nodes adjacent to $j$. The hidden state is then updated via a node function $g$:

$$h_j^t = g\left(h_j^{t-1}, x_j, m_j^t\right)$$
The process is similar to the Universal Transformer (though this paper predates the Universal Transformer).
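A sketch of one message-passing step in PyTorch, under the simplifying assumption that both $f$ and $g$ are MLPs (the paper actually uses an LSTM for the node update):

```python
import torch
import torch.nn as nn

class RRNStep(nn.Module):
    """Sketch of one message-passing step; f and g are plain MLPs here."""

    def __init__(self, feat_dim, hidden_dim):
        super().__init__()
        # f: message function over a pair of node hidden states.
        self.f = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # g: node update from previous state, input features, and summed messages.
        self.g = nn.Sequential(
            nn.Linear(2 * hidden_dim + feat_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, h, x, edges):
        # h: (n, hidden_dim) node states; x: (n, feat_dim) node features;
        # edges: (e, 2) long tensor of directed (sender, receiver) pairs.
        senders, receivers = edges[:, 0], edges[:, 1]
        msgs = self.f(torch.cat([h[senders], h[receivers]], dim=-1))  # m_ij = f(h_i, h_j)
        summed = torch.zeros_like(h).index_add_(0, receivers, msgs)   # m_j = sum of incoming
        return self.g(torch.cat([h, x, summed], dim=-1))              # h_j = g(h_j, x_j, m_j)
```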
Training
The output distribution of node $i$ at step $t$ is given by:

$$o_i^t = r\left(h_i^t\right)$$

where $r$ is an MLP followed by a softmax. The loss function adopted is the cross-entropy, summed over every step and node:

$$L = -\sum_{t=1}^{T} \sum_{i=1}^{I} \log o_i^t\left[y_i\right]$$

where $y_i$ is the target digit at position $i$.
Convergent Message Passing
In other words, we minimize the cross-entropy between output and target at every step, which encourages the network to converge to the solution and stay converged; during the test phase we only consider the output probability at the last step.
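A sketch of this training objective, assuming the network returns the per-node logits for every step (names are hypothetical):

```python
import torch.nn.functional as F

def rrn_loss(step_logits, targets):
    # step_logits: list of (n_nodes, n_classes) tensors, the logits o_i^t for t = 1..T.
    # targets: (n_nodes,) long tensor of target digits y_i.
    # A loss at every step pushes the network to converge early and stay converged.
    return sum(F.cross_entropy(logits, targets) for logits in step_logits)
```

At test time the prediction is simply `step_logits[-1].argmax(dim=-1)`.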
Experiment
Solving Sudokus
Solving Sudoku requires many steps of methodical deduction, intermediate results, and possibly trying several partial solutions before the right one is found.
The authors trained an RRN to solve Sudokus by considering each cell an object; each cell affects the other cells in the same row, column, and box (see the edge-list sketch under Details below).
The network learned a strategy that solves 96.6% of the hardest Sudokus (only 17 digits given).
Details
Feature definition: each cell's feature $x_j$ is an MLP over embeddings of its digit ($0$ if not given), row, and column.
Message function and reduce function: an MLP over the concatenation of the two cells' hidden states, with incoming messages reduced by summation.
Run the network for a fixed number of steps (32 for Sudoku in the paper).
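As a concrete illustration of the graph structure, a sketch that builds the directed edge list (each cell sees the 20 other cells in its row, column, and box):

```python
def sudoku_edges():
    """Directed edges of the 81-cell Sudoku graph: i -> j iff the cells share
    a row, a column, or a 3x3 box (20 neighbours per cell, 1620 edges)."""
    edges = []
    for i in range(81):
        ri, ci = divmod(i, 9)
        for j in range(81):
            if i == j:
                continue
            rj, cj = divmod(j, 9)
            if ri == rj or ci == cj or (ri // 3, ci // 3) == (rj // 3, cj // 3):
                edges.append((i, j))
    return edges
```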
Natural Language Inference
Uses the same experimental setting as Relational Networks.
RRN solves 20/20 tasks in 13 out of 15 runs.
bAbI does not require multi-step reasoning; a single step is enough.
Details
Feature definition: each fact sentence is encoded with an LSTM, and the node feature $x_i$ combines the sentence encoding with the question encoding.
Message function and reduce function: an MLP message function with sum aggregation, as in the Sudoku experiment.
Then we need a graph-level output conditioned on all node hidden states: the node states are summed and passed through the output MLP, i.e. $o^t = r\left(\sum_i h_i^t\right)$.
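A sketch of that readout, where `r` stands for the output MLP from the Training section:

```python
def graph_readout(h, r):
    # h: (n_nodes, hidden_dim) node hidden states at step t; r: the output MLP.
    # Sum-pool over nodes, then map to a single answer distribution for the graph.
    return r(h.sum(dim=0))
```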
Pretty-CLEVR
The dataset is an extension of Sort-of-CLEVR that also has questions requiring varying degrees of relational reasoning.
Question form:
Starting at object X, which object is N jumps away?
Jumps are defined as moving to the closest object, without going to an object already visited.
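A sketch of how the ground-truth answer could be computed under this definition (the function name and input representation are assumptions, not from the paper):

```python
import math

def n_jumps(positions, start, n):
    # positions: list of (x, y) object coordinates; start: index of object X.
    # Each jump moves to the closest object that has not been visited yet.
    current, visited = start, {start}
    for _ in range(n):
        origin = positions[current]
        current = min((i for i in range(len(positions)) if i not in visited),
                      key=lambda i: math.dist(origin, positions[i]))
        visited.add(current)
    return current
```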