Tuesday, April 08, 2014

Drools - Bayesian Belief Network Integration Part 2

A while back I mentioned I was working on Bayesian Belief Network integration, and I outlined the work I was doing around Junction Tree building, and ensuring we had good unit testing.
http://blog.athico.com/2014/02/drools-bayesian-belief-network.html

Today I finally got everything working end to end, including the the addition of hard evidence. The next stage is to integrate this into our Pluggable Belief System. One of the things we hope to do is use Defeasible style superiority rules as a way to resolving conflicting evidence.

For those interested, here is the fruits of my labours, showing end to end unit testing of the Eathquake example, as covered here.



Graph<BayesVariable> graph = new BayesNetwork();

GraphNode<BayesVariable> burglaryNode    = graph.addNode();
GraphNode<BayesVariable> earthquakeNode = graph.addNode();
GraphNode<BayesVariable> alarmNode      = graph.addNode();
GraphNode<BayesVariable> johnCallsNode  = graph.addNode();
GraphNode<BayesVariable> maryCallsNode  = graph.addNode();

BayesVariable burglary    = new BayesVariable<String>("Burglary", burglaryNode.getId(), new String[]{"true", "false"}, new double[][]{{0.001, 0.999}});
BayesVariable earthquake = new BayesVariable<String>("Earthquake", earthquakeNode.getId(), new String[]{"true", "false"}, new double[][]{{0.002, 0.998}});
BayesVariable alarm      = new BayesVariable<String>("Alarm", alarmNode.getId(), new String[]{"true", "false"}, new double[][]{{0.95, 0.05}, {0.94, 0.06}, {0.29, 0.71}, {0.001, 0.999}});
BayesVariable johnCalls  = new BayesVariable<String>("JohnCalls", johnCallsNode.getId(), new String[]{"true", "false"}, new double[][]{{0.90, 0.1}, {0.05, 0.95}});
BayesVariable maryCalls  = new BayesVariable<String>("MaryCalls", maryCallsNode.getId(), new String[]{"true", "false"}, new double[][]{{0.7, 0.3}, {0.01, 0.99}});

JunctionTree jTree;

@Before
public void setUp() {
    connectParentToChildren( burglaryNode, alarmNode);
    connectParentToChildren( earthquakeNode, alarmNode);
    connectParentToChildren( alarmNode, johnCallsNode, maryCallsNode);

    burglaryNode.setContent(burglary);
    earthquakeNode.setContent(earthquake);
    alarmNode.setContent( alarm );
    johnCallsNode.setContent( johnCalls );
    maryCallsNode.setContent( maryCalls );

    JunctionTreeBuilder jtBuilder = new JunctionTreeBuilder( graph );
    jTree = jtBuilder.build();
    jTree.initialize();
}

@Test
public void testInitialize() {
    JunctionTreeNode jtNode = jTree.getRoot();

    // johnCalls
    assertArray(new double[]{0.90, 0.1, 0.05, 0.95}, scaleDouble( 3, jtNode.getPotentials() ));


    // burglary, earthquake, alarm
    jtNode = jTree.getRoot().getChildren().get(0).getChild();
    assertArray( new double[]{0.0000019, 0.0000001, 0.0009381, 0.0000599, 0.0005794, 0.0014186, 0.0009970, 0.9960050 },
                 scaleDouble( 7, jtNode.getPotentials() ));

    // maryCalls
    jtNode = jTree.getRoot().getChildren().get(1).getChild();
    assertArray( new double[]{ 0.7, 0.3, 0.01, 0.99 }, scaleDouble( 3, jtNode.getPotentials() ));
}

@Test
public void testNoEvidence() {
    NetworkUpdateEngine nue = new NetworkUpdateEngine(graph, jTree);
    nue.globalUpdate();

    JunctionTreeNode jtNode = jTree.getRoot();
    marginalize(johnCalls, jtNode);
    assertArray( new double[]{0.052139, 0.947861},  scaleDouble( 6, johnCalls.getDistribution() ) );

    jtNode = jTree.getRoot().getChildren().get(0).getChild();
    marginalize(burglary, jtNode);
    assertArray( new double[]{0.001, 0.999},  scaleDouble( 3, burglary.getDistribution() ) );

    marginalize(earthquake, jtNode);
    assertArray( new double[]{ 0.002, 0.998},  scaleDouble( 3, earthquake.getDistribution() ) );

    marginalize(alarm, jtNode);
    assertArray( new double[]{0.002516, 0.997484},  scaleDouble( 6, alarm.getDistribution() ) );

    jtNode = jTree.getRoot().getChildren().get(1).getChild();
    marginalize(maryCalls, jtNode);
    assertArray( new double[]{0.011736, 0.988264 },  scaleDouble( 6, maryCalls.getDistribution() ) );
}

@Test
public void testAlarmEvidence() {
    NetworkUpdateEngine nue = new NetworkUpdateEngine(graph, jTree);

    JunctionTreeNode jtNode = jTree.getJunctionTreeNodes( )[alarm.getFamily()];
    nue.setLikelyhood( new BayesLikelyhood( graph, jtNode,  alarmNode, new double[] { 1.0, 0.0 }) );

    nue.globalUpdate();

    jtNode = jTree.getRoot();
    marginalize(johnCalls, jtNode);
    assertArray( new double[]{0.9, 0.1},  scaleDouble( 6, johnCalls.getDistribution() ) );

    jtNode = jTree.getRoot().getChildren().get(0).getChild();
    marginalize(burglary, jtNode);
    assertArray( new double[]{.374, 0.626},  scaleDouble( 3, burglary.getDistribution() ) );

    marginalize(earthquake, jtNode);
    assertArray( new double[]{ 0.231, 0.769},  scaleDouble( 3, earthquake.getDistribution() ) );

    marginalize(alarm, jtNode);
    assertArray( new double[]{1.0, 0.0},  scaleDouble( 6, alarm.getDistribution() ) );

    jtNode = jTree.getRoot().getChildren().get(1).getChild();
    marginalize(maryCalls, jtNode);
    assertArray( new double[]{0.7, 0.3 },  scaleDouble( 6, maryCalls.getDistribution() ) );
}

@Test
public void testEathQuakeEvidence() {
    NetworkUpdateEngine nue = new NetworkUpdateEngine(graph, jTree);

    JunctionTreeNode jtNode = jTree.getJunctionTreeNodes( )[earthquake.getFamily()];
    nue.setLikelyhood( new BayesLikelyhood( graph, jtNode,  earthquakeNode, new double[] { 1.0, 0.0 }) );
    nue.globalUpdate();

    jtNode = jTree.getRoot();
    marginalize(johnCalls, jtNode);
    assertArray( new double[]{0.297, 0.703},  scaleDouble( 3, johnCalls.getDistribution() ) );

    jtNode = jTree.getRoot().getChildren().get(0).getChild();
    marginalize(burglary, jtNode);
    assertArray( new double[]{.001, 0.999},  scaleDouble( 3, burglary.getDistribution() ) );

    marginalize(earthquake, jtNode);
    assertArray( new double[]{ 1.0, 0.0},  scaleDouble( 3, earthquake.getDistribution() ) );

    marginalize(alarm, jtNode);
    assertArray( new double[]{0.291, 0.709},  scaleDouble( 3, alarm.getDistribution() ) );

    jtNode = jTree.getRoot().getChildren().get(1).getChild();
    marginalize(maryCalls, jtNode);
    assertArray( new double[]{0.211, 0.789 },  scaleDouble( 3, maryCalls.getDistribution() ) );
}

@Test
public void testJoinCallsEvidence() {
    NetworkUpdateEngine nue = new NetworkUpdateEngine(graph, jTree);

    JunctionTreeNode jtNode = jTree.getJunctionTreeNodes( )[johnCalls.getFamily()];
    nue.setLikelyhood( new BayesLikelyhood( graph, jtNode,  johnCallsNode, new double[] { 1.0, 0.0 }) );
    nue.globalUpdate();

    jtNode = jTree.getRoot();
    marginalize(johnCalls, jtNode);
    assertArray( new double[]{1.0, 0.0},  scaleDouble( 2, johnCalls.getDistribution() ) );

    jtNode = jTree.getRoot().getChildren().get(0).getChild();
    marginalize(burglary, jtNode);
    assertArray( new double[]{0.016, 0.984},  scaleDouble( 3, burglary.getDistribution() ) );

    marginalize(earthquake, jtNode);
    assertArray( new double[]{ 0.011, 0.989},  scaleDouble( 3, earthquake.getDistribution() ) );

    marginalize(alarm, jtNode);
    assertArray( new double[]{0.043, 0.957},  scaleDouble( 3, alarm.getDistribution() ) );

    jtNode = jTree.getRoot().getChildren().get(1).getChild();
    marginalize(maryCalls, jtNode);
    assertArray( new double[]{0.04, 0.96 },  scaleDouble( 3, maryCalls.getDistribution() ) );
}

@Test
public void testEathquakeAndJohnCallsEvidence() {
    JunctionTreeBuilder jtBuilder = new JunctionTreeBuilder( graph );
    JunctionTree jTree = jtBuilder.build();
    jTree.initialize();

    NetworkUpdateEngine nue = new NetworkUpdateEngine(graph, jTree);

    JunctionTreeNode jtNode = jTree.getJunctionTreeNodes( )[johnCalls.getFamily()];
    nue.setLikelyhood( new BayesLikelyhood( graph, jtNode,  johnCallsNode, new double[] { 1.0, 0.0 }) );

    jtNode = jTree.getJunctionTreeNodes( )[earthquake.getFamily()];
    nue.setLikelyhood( new BayesLikelyhood( graph, jtNode,  earthquakeNode, new double[] { 1.0, 0.0 }) );
    nue.globalUpdate();

    jtNode = jTree.getRoot();
    marginalize(johnCalls, jtNode);
    assertArray( new double[]{1.0, 0.0},  scaleDouble( 2, johnCalls.getDistribution() ) );

    jtNode = jTree.getRoot().getChildren().get(0).getChild();
    marginalize(burglary, jtNode);
    assertArray( new double[]{0.003, 0.997},  scaleDouble( 3, burglary.getDistribution() ) );

    marginalize(earthquake, jtNode);
    assertArray( new double[]{ 1.0, 0.0},  scaleDouble( 3, earthquake.getDistribution() ) );

    marginalize(alarm, jtNode);
    assertArray( new double[]{0.881, 0.119},  scaleDouble( 3, alarm.getDistribution() ) );

    jtNode = jTree.getRoot().getChildren().get(1).getChild();
    marginalize(maryCalls, jtNode);
    assertArray( new double[]{0.618, 0.382 },  scaleDouble( 3, maryCalls.getDistribution() ) );
}


1 comment: