function outstruct=rp_xor(eta,momentum,target_sse,rseed)
%RP_XOR  Train a small backprop network (2-3-1, no bias units) on XOR.
%
%   outstruct = rp_xor(eta,momentum,target_sse,rseed)
%
%   Inputs:
%     eta        - learning rate (required). If omitted or empty, the usage
%                  string is printed and [] is returned.
%     momentum   - factor applied to the previous weight change
%                  (default 1 when omitted or empty).
%     target_sse - sum-squared-error level regarded as a good solution
%                  (default 0.01 when omitted, empty or non-numeric).
%     rseed      - optional generator state, passed to rand('state',rseed)
%                  before the weights are drawn.
%
%   Output structure fields:
%     epochs         - epoch counter after its final increment, i.e. the
%                      number of epochs actually run plus one
%     sse            - 1 x nEpochsRun record of the SSE of every epoch
%     hidden_act     - hidden activations from the last forward pass
%     output_act     - output activations from the last forward pass
%     output_saved   - 4 x nEpochsRun output activations of every epoch
%     desired_output - target outputs
%     w_ho, w_ih     - final hidden->output and input->hidden weights
%     patterns       - input patterns
%     rseed          - rand('state') as read back before weight
%                      initialisation (allows the run to be repeated)
%
%   Requires a function SIGMOID on the path (not defined in this file).

USAGE='USAGE: function outstruct=rp_xor(eta,momentum,target_sse,rseed)';

% check input
% nargin is used rather than exist('name'): exist without the 'var'
% qualifier also matches files and functions on the path, so an unrelated
% eta.m or rseed.m would make a missing argument look present.

if nargin<1 || isempty(eta),
  fprintf('%s\n',USAGE);
  outstruct=[];
  return;
end;

if nargin<2 || isempty(momentum),
  momentum=1;
end;

if nargin>=4,
  rand('state',rseed);
end;
% read the state back so that it can be returned to the caller
rseed=rand('state');

default_target_sse=0.01;  % stop when sse gets to this level
if nargin<3 || isempty(target_sse) || ~isnumeric(target_sse),
    target_sse=default_target_sse;
end;

% set the number of units in each layer

n_in=2;
n_out=1;
n_hid=3;

% set patterns for xor problem
patterns=[0 0; 1 0; 0 1; 1 1];

desired_output=[0 1 1 0]';

% set weights randomly: (rand-0.5) lies in [-0.5, 0.5], so after scaling
% the weights lie in [-0.05, 0.05]
weight_scale=0.1;
w_ih=(rand(n_in,n_hid)-0.5)*weight_scale;
w_ho=(rand(n_hid,n_out)-0.5)*weight_scale;

dw_ih_last=zeros(size(w_ih));
dw_ho_last=zeros(size(w_ho));

epoch=1; % set epoch counter to 1

sse_last=100; % set large for first epoch

% the momentum term smooths out the
% descent in error space
alpha=momentum; % momentum term

% We need to decide that training has ended.
% First, we look to see that the error has stopped
% changing by much from one epoch to the next
convergence_point=.01;

% sometimes this doesn't happen,
% e.g., weights are oscillating;
% in this case, bail after
% this number of epochs
convergence_failure=2500;
% never stop before this many epochs have been run
min_epochs=1000;

% The loop below runs at most convergence_failure epochs, so the records
% are preallocated to that size and trimmed after training.
sse_rec=zeros(1,convergence_failure);
output_saved=zeros(size(patterns,1),convergence_failure);

% now, train the network
converged=0;

while ~converged,

      % forward pass: input -> hidden -> output
      act_to_hidden=patterns*w_ih;
      hidden_act=sigmoid(act_to_hidden);

      act_to_output=hidden_act*w_ho;
      output_act=sigmoid(act_to_output);
      output_saved(:,epoch)=output_act;

      output_error=desired_output-output_act;

      % error signals; act.*(1-act) is the slope term used for each layer
      deltas_out=output_error .* output_act .* (1-output_act);
      deltas_hid=(deltas_out*w_ho') .* hidden_act .* (1-hidden_act);

      % The key backprop step, in matrix form.
      % The weight change is determined by two factors:
      % - the delta for the current step (times the learning rate)
      % - the weight change of the previous step (times the momentum factor)
      dw_ih=eta * patterns' * deltas_hid + alpha * dw_ih_last;
      dw_ho=eta * hidden_act' * deltas_out + alpha * dw_ho_last;

      % Weight update
      w_ih=w_ih + dw_ih;
      w_ho=w_ho + dw_ho;

      % Update momentum records
      dw_ih_last=dw_ih;
      dw_ho_last=dw_ho;

      % SSE of this epoch (computed from the pre-update forward pass)
      sse=sum(output_error(:).^2);
      sse_rec(epoch)=sse;          % Record keeping

      % print out a report occasionally
      if mod(epoch,10)==0,
        fprintf('epoch %d:\t%0.4f\n',epoch,sse);
      end;

      % NOTE: the counter is advanced before the stopping tests, so the
      % value printed and returned is (epochs run + 1). This is kept
      % as-is so existing callers see the same numbers.
      epoch=epoch+1;

      % check for convergence
      % abs() is essential: without it a rising error gives a negative
      % difference, which would count as "stable" and end training early
      % on a diverging or oscillating run.
      sse_stable=abs(sse_last - sse)<convergence_point;
      past_min_epochs=epoch>min_epochs;

      if sse_stable && past_min_epochs && all(sse<=target_sse(:)),
         fprintf('converged after %d epochs-exiting!\n',epoch);
         converged=1;
      elseif sse_stable && past_min_epochs,
         fprintf('WARNING: Model may have converged on non-optimal solution\n');
         fprintf('converged after %d epochs-exiting!\n',epoch);
         converged=1;
      elseif epoch>convergence_failure,
        fprintf('failed to converge after %d epochs-exiting!\n',epoch);
        converged=1;
      end;
      sse_last=sse;

end; % while ~converged

% drop the unused tail of the preallocated records
n_run=epoch-1;
sse_rec=sse_rec(1:n_run);
output_saved=output_saved(:,1:n_run);

% print a brief report
fprintf('desired: \t');
fprintf('%0.3f\t',desired_output);
fprintf('\n');
fprintf('actual: \t');
fprintf('%0.3f\t',output_act);
fprintf('\n');

% create the output structure

outstruct = struct(...
  'epochs', epoch,...
  'sse',    sse_rec,...
  'hidden_act', hidden_act,...
  'output_act', output_act,...
  'output_saved',output_saved,...
  'desired_output',desired_output,...
  'w_ho',   w_ho,...
  'w_ih',   w_ih,...
  'patterns',patterns,...
  'rseed',rseed);