Monday, December 21, 2015

8x3x8 Neural Network to learn the Identity function for 8-bit Inputs - Java

This neural network has 8 input units (+1 bias unit), 3 hidden-layer units (+1 bias unit) and 8 output units. It is trained with backpropagation using stochastic gradient descent, i.e. the weights are updated after every training example, with a learning rate of 0.3.
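
For reference, the update rules implemented by computeError, sumErrors and fixError in the source below can be written out as follows. The notation is mine, and I assume the usual squared error E = ½Σ(t_k − o_k)², which is the loss the (t − output) factor corresponds to. Note that the hidden deltas use the output weights from before they are updated, which matches the order of the calls in main.

```latex
\begin{aligned}
o_j &= \sigma\Big(\sum_i w_{ji}\,x_{ji}\Big), \qquad \sigma(z) = \frac{1}{1+e^{-z}}, \qquad x_{j0} = 1 \text{ (bias)}\\
\delta_k &= o_k\,(1-o_k)\,(t_k - o_k) && \text{output unit } k\\
\delta_h &= o_h\,(1-o_h)\sum_{k} w_{kh}\,\delta_k && \text{hidden unit } h\\
w_{ji} &\leftarrow w_{ji} + \eta\,\delta_j\,x_{ji}, \qquad \eta = 0.3 && \text{after every example}
\end{aligned}
```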

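In most runs, training finds correct weights in fewer than 5000 epochs. Training counts as complete when, for every example, each output thresholded at 0.5 matches its target bit; runs that have not converged after 90000 epochs are stopped. Weights are initialized randomly between -1 and +1.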
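
A minimal sketch of one way to draw each weight independently and uniformly from [-1, 1). The class name WeightInit and the method initWeights are hypothetical, used only for illustration. Drawing a single sign per unit and multiplying it by nextDouble() would instead give all of that unit's weights the same sign, which is a narrower starting point than "between -1 and +1".

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical helper class, not part of the original program.
public class WeightInit {
    // Draws every weight independently and uniformly from [-1, 1).
    static double[] initWeights(int n, Random rand) {
        double[] w = new double[n];
        for (int i = 0; i < n; ++i) {
            w[i] = 2 * rand.nextDouble() - 1; // nextDouble() returns a value in [0, 1)
        }
        return w;
    }

    public static void main(String[] args) {
        Random rand = new Random(42);      // fixed seed, only for reproducibility
        double[] w = initWeights(9, rand); // 8 inputs + 1 bias, as for a hidden unit
        System.out.println(Arrays.toString(w));
    }
}
```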

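Notice that the 3 hidden units learn an internal representation of the 8 inputs that resembles a 3-bit binary encoding of them. Thresholding the hidden values at 0.5 gives codes such as 1 0 1 or 0 1 1. These codes are not always all distinct, but the real-valued hidden outputs are enough for the output layer to tell the 8 inputs apart.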
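
To see how close the hidden layer comes to a true 3-bit code, you can threshold the hidden values and count how many distinct codes appear. This sketch does that on hidden values copied from two of the runs in the Output section (765 and 998 iterations). HiddenCodes, code and distinctCodes are hypothetical names. In the 998-iteration run all 8 codes are distinct; in the 765-iteration run two inputs both map to 0 0 0, so only 7 distinct codes appear.

```java
import java.util.HashSet;
import java.util.Set;

// Hypothetical analysis helper, not part of the original program.
public class HiddenCodes {
    // Threshold each hidden activation at 0.5 (as displayOutputs does) and pack
    // the bits into an int, with the first hidden unit as the most significant bit.
    static int code(double[] hidden) {
        int c = 0;
        for (double h : hidden) c = (c << 1) | (h < 0.5 ? 0 : 1);
        return c;
    }

    // Number of different 3-bit codes produced over all examples.
    static int distinctCodes(double[][] hiddenPerExample) {
        Set<Integer> seen = new HashSet<>();
        for (double[] h : hiddenPerExample) seen.add(code(h));
        return seen.size();
    }

    // Hidden values copied from the Output section, one row per example.
    static final double[][] RUN_765 = {
        {0.87, 0.06, 0.63}, {0.08, 0.97, 0.96}, {0.09, 0.04, 0.29}, {0.08, 0.38, 0.03},
        {0.11, 0.96, 0.15}, {0.89, 0.43, 0.03}, {0.93, 0.77, 0.95}, {0.08, 0.11, 0.97}
    };
    static final double[][] RUN_998 = {
        {0.52, 0.05, 0.98}, {0.39, 0.04, 0.03}, {0.03, 0.93, 0.93}, {0.97, 0.04, 0.45},
        {0.81, 0.91, 0.02}, {0.06, 0.09, 0.53}, {0.96, 0.97, 0.95}, {0.01, 0.71, 0.06}
    };

    public static void main(String[] args) {
        System.out.println("765-iteration run: " + distinctCodes(RUN_765) + " distinct codes");
        System.out.println("998-iteration run: " + distinctCodes(RUN_998) + " distinct codes");
    }
}
```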

All units use the sigmoid activation function.
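
The sigmoid is convenient for backpropagation because its derivative can be computed from the unit's own output s as s(1 − s), which is the output*(1-output) factor in both computeError methods. This small sketch (the class and method names are hypothetical) checks that identity against a central finite difference. Math.exp(-x) computes the same quantity as Math.pow(Math.E, -x).

```java
// Hypothetical helper class illustrating the sigmoid and its derivative.
public class Sigmoid {
    // Logistic sigmoid: 1 / (1 + e^-x).
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Derivative expressed through the unit's output s = sigmoid(x).
    static double sigmoidDerivFromOutput(double s) {
        return s * (1 - s);
    }

    public static void main(String[] args) {
        double x = 0.7, h = 1e-6;
        double analytic = sigmoidDerivFromOutput(sigmoid(x));
        // Central difference approximation of d(sigmoid)/dx at x.
        double numeric = (sigmoid(x + h) - sigmoid(x - h)) / (2 * h);
        System.out.printf("analytic=%.6f numeric=%.6f%n", analytic, numeric);
    }
}
```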

Source:

import java.util.Random;  

class Example{
    double input[];
    double output[];
    
    Example(double[] in){
        this.input = in;
        this.output = in;
    } 
}

class HiddenSigmoidUnit{
    double output, error;
    static double input[];   //shared inputs; input[0] is the bias (+1)
    double w[];  
    double n = 0.3;          //learning rate
    
    static Random rand = new Random();
    HiddenSigmoidUnit(int nInputs){ 
        w = new double[nInputs];   
        
        //bias weight is w[0]; every weight is drawn independently from [-1, 1)
        for(int i=0;i<w.length;++i)w[i] = 2*rand.nextDouble()-1;  
    } 
    
    public double computeOutput(){ 
        output = 0;
        for(int i=0;i<input.length;++i)output+=input[i]*w[i]; 
        output = 1/(1+Math.exp(-output));   //sigmoid
        return output;
    }
         
    //Accumulate the error propagated back from one output unit
    public void sumErrors(double weight, double outputUnitError){
        error += outputUnitError*weight;
    }
    
    //Multiply the accumulated error by the sigmoid derivative output*(1-output)
    public double computeError(){
        error = (output)*(1-output)*(error);  
        return error;        
    }
    
    public void fixError(){  
        for(int i=0;i<w.length;++i) w[i] += n*error*input[i]; 
    }
}
 
class OutputUnit{
    double t, output, error;
    static double input[];   //hidden-layer outputs; input[0] is the bias (+1)
    double w[];   
    static double n = 0.3;   //learning rate

    static Random rand = new Random();
    
    OutputUnit(int hUnits){  
        w = new double[hUnits]; 
        //every weight is drawn independently from [-1, 1)
        for(int i=0;i<w.length;++i)w[i] = 2*rand.nextDouble()-1;
    } 
    
    public void computeOutput() { 
        output = 0;
        for(int i=0;i<input.length;++i)output += input[i]*w[i]; 
        output = 1/(1+Math.exp(-output));   //sigmoid
    } 
    
    //Error term (delta) of a sigmoid output unit with target t
    public void computeError(double t){ 
        this.t = t; 
        error = (output)*(1-output)*(t-output); 
    }
    
    public void fixError() {   
        for(int i=0;i<w.length;++i)w[i]+=n*error*input[i]; 
    }   
} 
 
public class EightInputNN {
    static Random rand = new Random();
    static OutputUnit[] outputUnits;
    static HiddenSigmoidUnit[] hlayer;
    
    static Example[] examples = {
        //first column is the bias
            new Example(new double[]{1, 1, 0, 0, 0, 0, 0, 0, 0}),
            new Example(new double[]{1, 0, 1, 0, 0, 0, 0, 0, 0}),
            new Example(new double[]{1, 0, 0, 1, 0, 0, 0, 0, 0}),
            new Example(new double[]{1, 0, 0, 0, 1, 0, 0, 0, 0}),
            new Example(new double[]{1, 0, 0, 0, 0, 1, 0, 0, 0}),
            new Example(new double[]{1, 0, 0, 0, 0, 0, 1, 0, 0}),
            new Example(new double[]{1, 0, 0, 0, 0, 0, 0, 1, 0}),
            new Example(new double[]{1, 0, 0, 0, 0, 0, 0, 0, 1}), 
    }; 
    
    public static void main(String[] args){   
        int nhidden = 3; 
        
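        //+1 unit for bias to output units: index 0 stands for the bias, whose
        //output is fixed at 1, so the unit object stored there never computes an output
        hlayer = new HiddenSigmoidUnit[nhidden+1]; 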
           
        int correct=0;
        int nOutput = 8;
        
        //Train 10 times with different random initial weights to see how often training converges
        for(int check=0;check<10;++check){
            
            //Initializations
            for(int i=0;i<hlayer.length;++i)hlayer[i] = new HiddenSigmoidUnit(8+1);
            outputUnits = new OutputUnit[nOutput];  
            for(int i=0;i<outputUnits.length;++i)outputUnits[i] = new OutputUnit(hlayer.length);
            
            //epochs
            for(int iteration = 0;;++iteration){  
                //When all the examples are classified correctly, count==8, we stop training
                int count=0; 
                 
                //For all examples
                for(int i=0;i<examples.length;++i){
                    
                    Example example = examples[i]; 
                    
                    //This array stores outputs by hidden layer units
                    double[] outputsHLayer = new double[hlayer.length]; 
                     
                    //Set inputs for hidden sigmoid units
                    HiddenSigmoidUnit.input = example.input; 
                    
                    //The bias value is always +1. We train bias weights
                    outputsHLayer[0] = 1;  
                    
                    //Compute each hidden layer unit's output
                    for(int j=1;j<outputsHLayer.length;++j) outputsHLayer[j] = hlayer[j].computeOutput(); 

                    //Set output layer unit's inputs to the array just computed.
                    OutputUnit.input = outputsHLayer; 
                    
                    //Compute output units' outputs and errors
                    for(int j=0;j<outputUnits.length;++j){
                        outputUnits[j].computeOutput();
                        outputUnits[j].computeError(example.output[j+1]); //+1 as first unit is bias
                    } 
                    
                    //reset error values, compute them again
                    for(int j=0;j<hlayer.length;++j) hlayer[j].error = 0;
                    for(int j=0;j<hlayer.length;++j){
                        for(int k=0;k<outputUnits.length;++k){
                            hlayer[j].sumErrors(outputUnits[k].w[j], outputUnits[k].error); 
                        }
                        hlayer[j].computeError();
                    }
                     
                    //fix errors by changing weights correctly
                    for(int j=0;j<outputUnits.length;++j)outputUnits[j].fixError(); 
                    for(int j=0;j<hlayer.length;++j) hlayer[j].fixError();  
                    
                    //check all inputs give correct output values or not
                    boolean f = true;
                    for(int j=0;j<outputUnits.length;++j){
                        int o = outputUnits[j].output<0.5?0:1;
                        if(o!=example.output[j+1]){
                            f = false;
                        }
                    } 
                    
                    if(f)count++;
                } //end-for each example
                
                if(iteration>90000){
                    //Probably not going to converge to optimal values
                    System.out.println("\nIterations > 90000, stop...");
                    displayOutputs(); 
                    break;
                }
                
                
                if(count==examples.length){
                    //All inputs give correct outputs, stop training
                    System.out.println("\nTraining complete. No of iterations = "+iteration);
                    displayOutputs();
                    correct++;
                    break;
                }
            } //end epoch loop
        }//end for-each re-initialization
        
        System.out.println("\n\nCorrectly converged "+(correct)+"/10 times");
    }//end main
    
    //Apply the learned weights to all the examples
    public static void displayOutputs(){
        System.out.println("Displaying outputs for all examples... ");
        double[] outputsHLayer = new double[hlayer.length]; 
        outputsHLayer[0] = 1;
        
        System.out.print("\nHidden values\t\t\tOutput Values\n");
        for(int e=0;e<examples.length;++e){
            HiddenSigmoidUnit.input = examples[e].input;
            for(int j=1;j<hlayer.length;++j) {
                outputsHLayer[j] = hlayer[j].computeOutput();
            } 
            OutputUnit.input = outputsHLayer; 
            for(int j=0;j<outputUnits.length;++j){
                outputUnits[j].computeOutput(); 
            } 
             
            for(int j=1;j<hlayer.length;++j)System.out.printf("%.2f ", hlayer[j].output);
            for(int j=1;j<hlayer.length;++j)System.out.printf(" %d", hlayer[j].output<0.5?0:1);
//            System.out.print("\nOutput values: ");
            System.out.print("\t\t");
            for(int j=0;j<outputUnits.length;++j)System.out.print((outputUnits[j].output<0.5?0:1)+" "); 
            
            System.out.println();
        }
    }
}

Output:

Training complete. No of iterations = 765
Displaying outputs for all examples... 

Hidden values   Output Values
0.87 0.06 0.63  1 0 1  1 0 0 0 0 0 0 0 
0.08 0.97 0.96  0 1 1  0 1 0 0 0 0 0 0 
0.09 0.04 0.29  0 0 0  0 0 1 0 0 0 0 0 
0.08 0.38 0.03  0 0 0  0 0 0 1 0 0 0 0 
0.11 0.96 0.15  0 1 0  0 0 0 0 1 0 0 0 
0.89 0.43 0.03  1 0 0  0 0 0 0 0 1 0 0 
0.93 0.77 0.95  1 1 1  0 0 0 0 0 0 1 0 
0.08 0.11 0.97  0 0 1  0 0 0 0 0 0 0 1 

Training complete. No of iterations = 2247
Displaying outputs for all examples... 

Hidden values   Output Values
0.28 0.98 0.02  0 1 0  1 0 0 0 0 0 0 0 
0.93 0.99 0.96  1 1 1  0 1 0 0 0 0 0 0 
0.99 0.16 0.91  1 0 1  0 0 1 0 0 0 0 0 
0.06 0.58 0.98  0 1 1  0 0 0 1 0 0 0 0 
0.94 0.84 0.13  1 1 0  0 0 0 0 1 0 0 0 
0.01 0.44 0.06  0 0 0  0 0 0 0 0 1 0 0 
0.24 0.01 0.45  0 0 0  0 0 0 0 0 0 1 0 
0.94 0.16 0.01  1 0 0  0 0 0 0 0 0 0 1 

Training complete. No of iterations = 1661
Displaying outputs for all examples... 

Hidden values   Output Values
0.13 0.37 0.01  0 0 0  1 0 0 0 0 0 0 0 
0.01 0.93 0.54  0 1 1  0 1 0 0 0 0 0 0 
0.92 0.99 0.96  1 1 1  0 0 1 0 0 0 0 0 
0.99 0.24 0.12  1 0 0  0 0 0 1 0 0 0 0 
0.25 0.02 0.33  0 0 0  0 0 0 0 1 0 0 0 
0.71 0.96 0.04  1 1 0  0 0 0 0 0 1 0 0 
0.95 0.07 0.98  1 0 1  0 0 0 0 0 0 1 0 
0.03 0.28 0.96  0 0 1  0 0 0 0 0 0 0 1 

Training complete. No of iterations = 1232
Displaying outputs for all examples... 

Hidden values   Output Values
0.03 0.66 0.97  0 1 1  1 0 0 0 0 0 0 0 
0.98 0.33 0.05  1 0 0  0 1 0 0 0 0 0 0 
0.97 0.98 0.94  1 1 1  0 0 1 0 0 0 0 0 
0.02 0.82 0.16  0 1 0  0 0 0 1 0 0 0 0 
0.56 0.02 0.13  1 0 0  0 0 0 0 1 0 0 0 
0.83 0.05 0.98  1 0 1  0 0 0 0 0 1 0 0 
0.75 0.96 0.02  1 1 0  0 0 0 0 0 0 1 0 
0.03 0.03 0.39  0 0 0  0 0 0 0 0 0 0 1 

Training complete. No of iterations = 1539
Displaying outputs for all examples... 

Hidden values   Output Values
0.97 0.35 0.95  1 0 1  1 0 0 0 0 0 0 0 
0.69 0.95 0.02  1 1 0  0 1 0 0 0 0 0 0 
0.92 0.06 0.03  1 0 0  0 0 1 0 0 0 0 0 
0.95 0.98 0.92  1 1 1  0 0 0 1 0 0 0 0 
0.60 0.01 0.92  1 0 1  0 0 0 0 1 0 0 0 
0.03 0.98 0.70  0 1 1  0 0 0 0 0 1 0 0 
0.02 0.23 0.97  0 0 1  0 0 0 0 0 0 1 0 
0.02 0.21 0.04  0 0 0  0 0 0 0 0 0 0 1 

Training complete. No of iterations = 998
Displaying outputs for all examples... 

Hidden values   Output Values
0.52 0.05 0.98  1 0 1  1 0 0 0 0 0 0 0 
0.39 0.04 0.03  0 0 0  0 1 0 0 0 0 0 0 
0.03 0.93 0.93  0 1 1  0 0 1 0 0 0 0 0 
0.97 0.04 0.45  1 0 0  0 0 0 1 0 0 0 0 
0.81 0.91 0.02  1 1 0  0 0 0 0 1 0 0 0 
0.06 0.09 0.53  0 0 1  0 0 0 0 0 1 0 0 
0.96 0.97 0.95  1 1 1  0 0 0 0 0 0 1 0 
0.01 0.71 0.06  0 1 0  0 0 0 0 0 0 0 1 

Training complete. No of iterations = 1288
Displaying outputs for all examples... 

Hidden values   Output Values
0.07 0.93 0.04  0 1 0  1 0 0 0 0 0 0 0 
0.97 0.77 0.02  1 1 0  0 1 0 0 0 0 0 0 
0.70 0.02 0.17  1 0 0  0 0 1 0 0 0 0 0 
0.99 0.25 0.96  1 0 1  0 0 0 1 0 0 0 0 
0.04 0.92 0.98  0 1 1  0 0 0 0 1 0 0 0 
0.28 0.03 0.97  0 0 1  0 0 0 0 0 1 0 0 
0.02 0.10 0.23  0 0 0  0 0 0 0 0 0 1 0 
0.90 0.96 0.70  1 1 1  0 0 0 0 0 0 0 1 

Training complete. No of iterations = 2741
Displaying outputs for all examples... 

Hidden values   Output Values
0.99 0.17 0.90  1 0 1  1 0 0 0 0 0 0 0 
0.20 0.03 0.97  0 0 1  0 1 0 0 0 0 0 0 
0.54 0.02 0.37  1 0 0  0 0 1 0 0 0 0 0 
0.03 0.98 0.13  0 1 0  0 0 0 1 0 0 0 0 
0.01 0.23 0.21  0 0 0  0 0 0 0 1 0 0 0 
0.97 0.99 0.42  1 1 0  0 0 0 0 0 1 0 0 
0.67 0.51 0.01  1 1 0  0 0 0 0 0 0 1 0 
0.08 0.96 0.99  0 1 1  0 0 0 0 0 0 0 1 

Training complete. No of iterations = 916
Displaying outputs for all examples... 

Hidden values   Output Values
0.70 0.93 0.05  1 1 0  1 0 0 0 0 0 0 0 
0.06 0.03 0.30  0 0 0  0 1 0 0 0 0 0 0 
0.95 0.20 0.05  1 0 0  0 0 1 0 0 0 0 0 
0.02 0.59 0.05  0 1 0  0 0 0 1 0 0 0 0 
0.87 0.02 0.85  1 0 1  0 0 0 0 1 0 0 0 
0.15 0.95 0.55  0 1 1  0 0 0 0 0 1 0 0 
0.96 0.86 0.94  1 1 1  0 0 0 0 0 0 1 0 
0.04 0.44 0.98  0 0 1  0 0 0 0 0 0 0 1 

Training complete. No of iterations = 2130
Displaying outputs for all examples... 

Hidden values   Output Values
0.98 0.97 0.26  1 1 0  1 0 0 0 0 0 0 0 
0.92 0.02 0.21  1 0 0  0 1 0 0 0 0 0 0 
0.02 0.97 0.33  0 1 0  0 0 1 0 0 0 0 0 
0.03 0.07 0.99  0 0 1  0 0 0 1 0 0 0 0 
0.45 0.98 0.99  0 1 1  0 0 0 0 1 0 0 0 
0.97 0.05 0.99  1 0 1  0 0 0 0 0 1 0 0 
0.37 0.39 0.02  0 0 0  0 0 0 0 0 0 1 0 
0.01 0.03 0.32  0 0 0  0 0 0 0 0 0 0 1 


Correctly converged 10/10 times