A neural network built from scratch in Java.


Neural Network

A simple neural network library built from scratch in Java.


  1. Clone the repository: $ git clone
  2. Change into the source directory: $ cd simple_nn_in_java/src
  3. Pick the libraries you want to use and copy their source files into your codebase

Example code for using the neural network:

public static void main(String[] args) {
    NeuralNetwork nn = new NeuralNetwork();
    // Add layers: 2 inputs -> 7 hidden units -> 1 output, each with a sigmoid activation
    nn.add(2, 7, "sigmoid");
    nn.add(7, 1, "sigmoid");
    // Run a feedforward pass on a random 1x2 input matrix and show the output
    nn.feedforward(Matrix.random(new Shape(1, 2))).show();
}
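To make the example above more concrete, here is a minimal plain-Java sketch of what a forward pass through a network with layer sizes 2, 7, and 1 involves. It is not the library's code: the class and method names (`FeedforwardSketch`, `dense`, `randomMatrix`) are hypothetical, and it assumes each layer computes sigmoid(x · W + b) with a bias vector, which the library may or may not do.

```java
import java.util.Random;

// Hypothetical sketch: NOT the library's NeuralNetwork class.
// Assumption: each layer computes sigmoid(x . W + b).
public class FeedforwardSketch {

    // Logistic sigmoid: maps any real number into the range (0, 1).
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // One dense layer: weights has shape [inputs][outputs], bias has length outputs.
    static double[] dense(double[] x, double[][] weights, double[] bias) {
        int outputs = bias.length;
        double[] y = new double[outputs];
        for (int j = 0; j < outputs; j++) {
            double z = bias[j];
            for (int i = 0; i < x.length; i++) {
                z += x[i] * weights[i][j];
            }
            y[j] = sigmoid(z);
        }
        return y;
    }

    // Fill a [rows][cols] array with uniform random values in [0, 1).
    static double[][] randomMatrix(int rows, int cols, Random rng) {
        double[][] m = new double[rows][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                m[r][c] = rng.nextDouble();
            }
        }
        return m;
    }

    public static void main(String[] args) {
        Random rng = new Random(42); // fixed seed so runs are repeatable
        double[][] w1 = randomMatrix(2, 7, rng);   // layer 1: 2 -> 7
        double[] b1 = randomMatrix(1, 7, rng)[0];
        double[][] w2 = randomMatrix(7, 1, rng);   // layer 2: 7 -> 1
        double[] b2 = randomMatrix(1, 1, rng)[0];

        double[] input = randomMatrix(1, 2, rng)[0]; // a random 1x2 input
        double[] hidden = dense(input, w1, b1);      // 7 activations
        double[] output = dense(hidden, w2, b2);     // 1 activation
        System.out.println(output[0]);
    }
}
```

Running it prints a single value between 0 and 1, because the last layer also uses a sigmoid.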

Example code for using the matrix library:

public static void main(String[] args) {
    // Initialization
    // Simply give it a 2D array (assuming the constructor accepts a double[][])
    Matrix m = new Matrix(new double[][]{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
    // Populate a matrix with zeros, ones, tens, or any arbitrary value, given a shape
    Matrix m2 = Matrix.zeros(new Shape(1,2));
    Matrix m3 = Matrix.ones(new Shape(1,2));
    Matrix m4 = Matrix.tens(new Shape(1,2));
    Matrix m5 = Matrix.fillShapeWithValue(new Shape(1,2), 4444); 
    Matrix m6 = Matrix.random(new Shape(1,2)); // fills each element with a random value between 0 and 1
    // Scalar matrix operations
    m.add(4); // adds 4 to each element
    m.div(4); // divides each element by 4
    m.mul(4); // multiplies each element by 4
    m.sub(4); // subtracts 4 from each element
    // Element-wise multiplication, addition, subtraction, and division of two matrices
    // Matrix transposition
    Matrix itGotTransposed = Matrix.transpose(m);
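    // Dot products
    // Vector dot product: 1*5 + 2*6 + 3*7 + 4*8 = 70 (assuming the method accepts double[] arguments)
    Matrix.vectorDotProduct(new double[]{1, 2, 3, 4}, new double[]{5, 6, 7, 8});
    // Pretty printing to the console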