-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit be5a42c
Showing
12 changed files
with
2,222 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
target | ||
DigitRecognition.iml | ||
/.idea/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
<?xml version="1.0" encoding="UTF-8"?> | ||
<project xmlns="http://maven.apache.org/POM/4.0.0" | ||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||
<modelVersion>4.0.0</modelVersion> | ||
|
||
<groupId>DigitRecognition</groupId> | ||
<artifactId>DigitRecognition</artifactId> | ||
<version>1.0-SNAPSHOT</version> | ||
|
||
|
||
</project> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
package neural; | ||
|
||
import static java.lang.Math.max; | ||
import static java.lang.Math.min; | ||
|
||
import java.awt.Color; | ||
|
||
import edu.princeton.cs.introcs.In; | ||
import static edu.princeton.cs.introcs.StdDraw.*; | ||
|
||
/** Handwritten digit recognizer with graphic user interface. */ | ||
public class DigitRecognizer { | ||
|
||
/** Half width of each square in the grid. */ | ||
private static final double SQUARE_RADIUS = 0.5 / 32; | ||
|
||
public static void main(String[] args) { | ||
new DigitRecognizer().run(); | ||
} | ||
|
||
/** The neural network. */ | ||
private Network net; | ||
|
||
/** The input values. */ | ||
private double[] pixels; | ||
|
||
private DigitRecognizer() { | ||
System.out.println("Reading data file..."); | ||
ClassLoader classLoader = getClass().getClassLoader(); | ||
In input = new In(classLoader.getResource("semeion.data.txt")); | ||
// The "magic numbers" below correspond to the size of the raw data | ||
double[][] inputs = new double[1593][256]; | ||
double[][] correct = new double[1593][10]; | ||
int i = 0; | ||
while (input.hasNextLine()) { | ||
String[] values = input.readLine().split(" "); | ||
for (int j = 0; j < 256; j++) { | ||
inputs[i][j] = Double.valueOf(values[j]); | ||
} | ||
for (int j = 0; j < 10; j++) { | ||
correct[i][j] = Double.valueOf(values[j + 256]); | ||
} | ||
i++; | ||
} | ||
System.out.println("Training neural network..."); | ||
net = new Network(256, 20, 10); | ||
for (int epoch = 0; epoch < 500; epoch++) { | ||
for (i = 0; i < inputs.length; i++) { | ||
net.train(inputs[i], correct[i]); | ||
} | ||
} | ||
pixels = new double[256]; | ||
net.run(pixels); | ||
} | ||
|
||
/** Draws the two buttons on the screen. */ | ||
private void drawControls() { | ||
text(0.25, 0.95, "Draw a digit in the grid at left,"); | ||
text(0.25, 0.9, "then click on Classify."); | ||
rectangle(0.2, 0.1, 0.1, 0.1); | ||
text(0.2, 0.1, "Clear"); | ||
rectangle(0.5, 0.1, 0.1, 0.1); | ||
text(0.5, 0.1, "Classify"); | ||
} | ||
|
||
/** Draws the grid where the user draws a digit. */ | ||
private void drawGrid() { | ||
for (int r = 0; r < 16; r++) { | ||
for (int c = 0; c < 16; c++) { | ||
if (pixels[r * 16 + c] > 0.5) { | ||
filledRectangle(c / 32.0, 0.75 - r / 32.0, | ||
SQUARE_RADIUS, SQUARE_RADIUS); | ||
} else { | ||
rectangle(c / 32.0, 0.75 - r / 32.0, SQUARE_RADIUS, | ||
SQUARE_RADIUS); | ||
} | ||
} | ||
} | ||
} | ||
|
||
/** Draws shaded circles indicating the network's output. */ | ||
private void drawOutput() { | ||
for (int i = 0; i < 10; i++) { | ||
text(0.85, 0.05 + 0.1 * i, "" + i); | ||
int brightness = (int) (256 * (1.0 - net.getNeuron(2, i) | ||
.getOutput())); | ||
setPenColor(new Color(brightness, brightness, brightness)); | ||
filledCircle(0.95, 0.05 + 0.1 * i, 0.05); | ||
setPenColor(); | ||
circle(0.95, 0.05 + 0.1 * i, 0.05); | ||
} | ||
} | ||
|
||
/** Respond to mouse actions. */ | ||
private void handleMouse() { | ||
if (mousePressed()) { | ||
double x = mouseX(); | ||
double y = mouseY(); | ||
if (0.1 < x && x < 0.3 && 0.0 < y && y < 0.2) { | ||
// Click on Clear | ||
pixels = new double[pixels.length]; | ||
while (mousePressed()) { | ||
// Wait for mouse release | ||
} | ||
} else if (0.4 < x && x < 0.6 && 0.0 < y && y < 0.2) { | ||
// Click on Classify | ||
net.run(pixels); | ||
while (mousePressed()) { | ||
// Wait for mouse release | ||
} | ||
} else if (0.0 - SQUARE_RADIUS < x | ||
&& x < 15.0 / 32 + SQUARE_RADIUS | ||
&& 0.75 - 15.0 / 32 - SQUARE_RADIUS < y | ||
&& y < 0.75 + SQUARE_RADIUS) { | ||
// Mouse down in grid | ||
int r = (int) ((0.75 - (y - SQUARE_RADIUS)) * 32); | ||
r = min((max(r, 0)), 15); | ||
int c = (int) ((x + SQUARE_RADIUS) * 32); | ||
c = min((max(c, 0)), 15); | ||
pixels[r * 16 + c] = 1.0; | ||
} | ||
} | ||
} | ||
|
||
/** Main interactive loop. */ | ||
private void run() { | ||
show(0); | ||
while (true) { | ||
clear(); | ||
drawGrid(); | ||
drawOutput(); | ||
drawControls(); | ||
handleMouse(); | ||
show(0); | ||
} | ||
} | ||
|
||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
package neural; | ||
|
||
/** A "neuron" that does no computation, for specifying inputs. Its output can be manually set. */ | ||
public class InputNeuron extends Neuron { | ||
|
||
protected InputNeuron() { | ||
super(new Neuron[0], null, 0.0); | ||
} | ||
|
||
/** Sets the output of this neuron. */ | ||
public void setOutput(double output) { | ||
super.setOutput(output); | ||
} | ||
|
||
@Override | ||
public double squash(double sum) { | ||
// Irrelevant | ||
return -1; | ||
} | ||
|
||
@Override | ||
public void update() { | ||
// Does nothing | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
package neural; | ||
|
||
/** A feed-forward network with input, hidden, and output layers. */ | ||
public class Network { | ||
|
||
/** The input neurons. */ | ||
private InputNeuron[] inputLayer; | ||
|
||
/** The hidden neurons. */ | ||
private SigmoidNeuron[] hiddenLayer; | ||
|
||
/** The output neurons. */ | ||
private SigmoidNeuron[] outputLayer; | ||
|
||
/** | ||
* @param in | ||
* Number of input neurons. | ||
* @param hid | ||
* Number of hidden neurons. | ||
* @param out | ||
* Number of output neurons. | ||
*/ | ||
protected Network(int in, int hid, int out) { | ||
inputLayer = new InputNeuron[in]; | ||
for (int i = 0; i < in; i++) | ||
inputLayer[i] = new InputNeuron(); | ||
|
||
hiddenLayer = new SigmoidNeuron[hid]; | ||
for (int i = 0; i < hid; i++) | ||
hiddenLayer[i] = new SigmoidNeuron(inputLayer); | ||
|
||
outputLayer = new SigmoidNeuron[out]; | ||
for (int i = 0; i < out; i++) | ||
outputLayer[i] = new SigmoidNeuron(hiddenLayer); | ||
} | ||
|
||
/** | ||
* Returns the specified neuron. The input layer is layer 0, hidden 1, | ||
* output 2. | ||
*/ | ||
public Neuron getNeuron(int layer, int index) { | ||
switch (layer) { | ||
case 0: return inputLayer[index]; | ||
case 1: return hiddenLayer[index]; | ||
case 2: return outputLayer[index]; | ||
default: return null; | ||
} | ||
} | ||
|
||
/** | ||
* Returns the sum, over a set of training examples and across all outputs, | ||
* of the square of the difference between actual and correct outputs. If | ||
* learning is working properly, this should decrease over the course of | ||
* training. | ||
*/ | ||
public double meanSquaredError(double[][] inputs, double[][] correct) { | ||
double sum = 0.0; | ||
for (int i = 0; i < inputs.length; i++) { | ||
run(inputs[i]); | ||
for (int j = 0; j < outputLayer.length; j++) { | ||
sum += Math.pow(correct[i][j] - outputLayer[j].getOutput(), 2); | ||
} | ||
} | ||
return sum / (inputs.length * outputLayer.length); | ||
} | ||
|
||
/** Feeds inputs through the network, updating the output of each neuron. */ | ||
public double[] run(double[] inputs) { | ||
for (int i = 0; i < inputs.length; i++) | ||
inputLayer[i].setOutput(inputs[i]); | ||
for (int i = 0; i < hiddenLayer.length; i++) | ||
hiddenLayer[i].update(); | ||
for (int i = 0; i < outputLayer.length; i++) | ||
outputLayer[i].update(); | ||
double[] result = new double[outputLayer.length]; | ||
for (int i = 0; i < outputLayer.length; i++) | ||
result[i] = outputLayer[i].getOutput(); | ||
return result; | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
String result = ""; | ||
result += "OUTPUT UNITS:\n"; | ||
for (int i = 0; i < outputLayer.length; i++) { | ||
result += i + ": " + outputLayer[i] + "\n"; | ||
} | ||
result += "HIDDEN UNITS:\n"; | ||
for (int i = 0; i < hiddenLayer.length; i++) { | ||
result += i + ": " + hiddenLayer[i] + "\n"; | ||
} | ||
return result + "(" + inputLayer.length + " INPUT UNITS)\n"; | ||
} | ||
|
||
/** | ||
* Slightly modifies this network's weights to cause it to response to | ||
* inputs with something closer to the correct outputs. | ||
*/ | ||
public void train(double[] inputs, double[] correct) { | ||
// This is a long method, with the following steps: | ||
|
||
// Feed the input forward through the network | ||
run(inputs); | ||
// Update deltas for output layer | ||
for (int i = 0; i < outputLayer.length; i++) { | ||
double error = outputLayer[i].getOutput() | ||
* (1 - outputLayer[i].getOutput()) | ||
* (correct[i] - outputLayer[i].getOutput()); | ||
outputLayer[i].setDelta(error); | ||
} | ||
// Update weights for output layer | ||
for (int i = 0; i < outputLayer.length; i++) { | ||
outputLayer[i].updateWeights(); | ||
} | ||
// Update deltas for hidden layer | ||
for (int k = 0; k < hiddenLayer.length; k++) { | ||
double g = 0.0; | ||
for (int i = 0; i < outputLayer.length; i++) { | ||
double w = (outputLayer[i].getWeights())[k]; | ||
g += outputLayer[i].getDelta() * w; | ||
} | ||
double error = hiddenLayer[k].getOutput() | ||
* (1 - hiddenLayer[k].getOutput()) * g; | ||
hiddenLayer[k].setDelta(error); | ||
} | ||
// Update weights for hidden layer | ||
for (int i = 0; i < hiddenLayer.length; i++) { | ||
hiddenLayer[i].updateWeights(); | ||
} | ||
} | ||
} |
Oops, something went wrong.