Department of Biomedical Engineering and Computational Science

MLP network in a 3-class classification problem, 'demo_3class'

In this demonstration program an MLP network with Bayesian learning is used for a three-class classification problem. The example also demonstrates the benefit of an ARD prior when some of the inputs are irrelevant.

The data used in the demonstration program is the same as that used by Radford M. Neal in his three-way classification example in Software for Flexible Bayesian Modeling (http://www.cs.toronto.edu/~radford/fbm.software.html). The data consists of 1000 4-dimensional vectors which are classified into three classes. The data is generated by drawing the components of each vector, x1, x2, x3 and x4, uniformly from (0,1). The class of each vector is selected according to the first two components, x1 and x2. After this, Gaussian noise with a standard deviation of 0.1 is added to every component of the vector.
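
For illustration, data with the same structure can be generated roughly as follows. The class rule below is only a placeholder depending on x1 and x2; the exact rule used by Neal is given in the FBM documentation.

% Sketch of a generator for data with the same structure. The class rule
% below is an illustrative placeholder, not the exact rule used by Neal.
n = 1000;
x = rand(n,4);                 % components x1..x4 drawn uniformly from (0,1)
c = zeros(n,1);                % class labels 0, 1 or 2
c(x(:,1) > x(:,2)) = 1;        % placeholder rule, depends only on x1 and x2
c(x(:,1)+x(:,2) > 1.2) = 2;    % placeholder rule, depends only on x1 and x2
x = x + 0.1*randn(n,4);        % Gaussian noise with std 0.1 added to every component
data = [x c];                  % same layout as demos/cdata: four inputs, class in column 5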

The data is divided into training and test parts. After the posterior parameters of the network have been sampled, only about 13% of the test cases are misclassified.

The effect of the ARD prior can be seen by studying the posterior densities of the weight prior variances a_k. The weights of the input-to-hidden connections are sampled from w_kj ~ N(0, a_k), where w_kj is the weight from the k:th input to the j:th hidden unit. In the figure below it can be seen that the distributions of a_3 and a_4 (corresponding to the irrelevant inputs x_3 and x_4) are concentrated near zero, whereas the distributions of a_1 and a_2 are more spread out.

Figure 1. Posterior distributions of the ARD prior variances a_1, a_2, a_3 and a_4.
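
The figure can be reproduced along the following lines, assuming the sampled values of a_1,...,a_4 have been collected into an N-by-4 matrix a_samples. The name a_samples is used here only for illustration; how the values are extracted from the sample record depends on the toolbox version.

% Sketch for plotting the posterior distributions of the ARD parameters.
% a_samples is assumed to be an N-by-4 matrix of sampled values of a_1..a_4.
for k=1:4
  subplot(2,2,k)
  hist(a_samples(:,k), 30)     % histogram approximates the posterior density
  title(sprintf('a_%d', k))
end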

The code of the demonstration program is shown below.

function demo_3class

% Load the data
x=load('demos/cdata');
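% Convert the class labels in the 5th column into 1-of-3 coding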
y=repmat(0,size(x,1),3);
y(x(:,5)==0,1) = 1;
y(x(:,5)==1,2) = 1;
y(x(:,5)==2,3) = 1;
x(:,end)=[];

% Divide the data into training and test parts.
xt = x(401:end,:);
x=x(1:400,:);
yt=y(401:end,:);
y=y(1:400,:);

nin=size(x,2);
nhid=8;
nout=size(y,2);
% create MLP network for the three-class classification model ('mlp2c')
net = mlp2('mlp2c', nin, nhid, nout);

%Create a Gaussian multivariate hierarchical prior with ARD
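% (repmat gives each input its own prior variance a_k, which makes this an ARD prior)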
net=mlp2normp(net, {{repmat(1,1,net.nin) 0.5 0.05 -0.05 1}... % input-hidden weights
                    {1 0.5 0.05} ...                          % hidden unit biases
                    {1 -0.5 0.05} ...                         % hidden-output weights
                    {1}})                                     % output unit biases

% Initialize weights to small random values and set the optimization parameters...
w=randn(size(mlp2pak(net)))*0.01;

fe=str2fun('mlp2c_e');
fg=str2fun('mlp2c_g');
n=length(y);
itr=1:floor(0.5*n);     % training set of data for early stop
its=floor(0.5*n)+1:n;   % test set of data for early stop
optes=scges_opt;
optes.display=1;
optes.tolfun=1e-1;
optes.tolx=1e-1;

% ... Start scaled conjugate gradient optimization with early stopping.
[w,fs,vs]=scges(fe, w, optes, fg, net, x(itr,:),y(itr,:), net,x(its,:),y(its,:));
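% Put the optimized weights back into the network structure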
net=mlp2unpak(net,w);

% First we initialize random seed for Monte
% Carlo sampling and set the sampling options to default.
hmc2('state', sum(100*clock));
opt=mlp2c_mcopt;
opt.sample_inputs=0; % do not use RJMCMC for input variables

%  Here we do the sampling without persistence.
opt.repeat=70;
opt.hmc_opt.steps=10;
opt.hmc_opt.stepadj=0.2;
opt.gibbs=1;
opt.nsamples=1;
[r,net,rstate]=mlp2c_mc(opt, net, x, y);

% Now that the starting values are found we start the main sampling.
opt.hmc_opt.stepadj=0.5;
opt.hmc_opt.persistence=1;
opt.hmc_opt.steps=40;
opt.hmc_opt.decay=0.95;
opt.repeat=50;
opt.nsamples=250;
opt.hmc_opt.window=5;

[r,net,rstate]=mlp2c_mc(opt, net, x, y, [], [], r, rstate);


% Thin the sample chain.
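% (discard burn-in and keep only a subset of samples to reduce autocorrelation)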
r1=thin(r,50,6)

% Let's test how well the network works on the test data.
forw = mlp2fwds(r1,xt);
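% Compute class probabilities for each sampled network with softmax
% and average them over the posterior samples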
for i=1:size(forw,3)
  tga(:,:,i) = softmax(forw(:,:,i)')';
end
tga = mean(tga,3);
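% Pick the class with the highest mean probability (1-of-3 coding)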
tt = tga==repmat(max(tga,[],2),1,size(tga,2));

% Let's calculate the percentage of misclassified points
missed = (sum(sum(abs(tt-yt)))/2)/size(yt,1)

% The effect of ARD on the input-to-hidden weights
% can be checked by looking at the values of the first-layer weights.
% Below are shown the means of the thinned posterior chain of
% the weights. It can be seen that the weights from the 3rd and
% 4th input units are much smaller in magnitude than the ones
% from the 1st and 2nd (every column represents the weights from
% one input unit).

weightsMean = reshape(mean(r1.inputWeights),8,4)