// Kapitel 13: "Methoden der Künstlichen Intelligenz"
// File: KMeans.java
import java.io.*;
import java.util.Random;
import java.util.StringTokenizer;
/**
 * k-means classifier for digit data.
 *
 * <p>Each sample is a vector of 192 values (stored in the input file as 16 lines of 12
 * numbers) together with a class label 0..9. Training places {@code k} prototypes with
 * k-means; afterwards every prototype is labelled with the majority class of the samples
 * assigned to it, and {@link #classify(double[])} returns the label of the nearest prototype.
 */
public class KMeans {
    /** Number of samples the data arrays are sized for. */
    private static final int ANZ_DATEN = 1000;
    /** Number of text lines per sample in the input file. */
    private static final int ZEILEN = 16;
    /** Number of values per text line in the input file. */
    private static final int SPALTEN = 12;
    /** Dimension of a feature vector. */
    private static final int DIM = ZEILEN * SPALTEN;
    /** Number of classes (digits 0..9). */
    private static final int ANZ_KLASSEN = 10;

    /** Feature vectors, one row of 192 values per sample; filled by {@link #readData}. */
    public double[][] trainingsdaten;
    /** Class of each sample (0..9), or -1 if its label line contained no value 1.0. */
    public int[] label;
    /** Index of the prototype each sample was assigned to in the last training iteration. */
    private int[] myprototype;
    /** Prototype vectors, one row per prototype. */
    private double[][] prototypen;
    /** Class label attached to each prototype after training. */
    private int[] protolabel;
    /** Number of prototypes. */
    private int k;

    /**
     * Creates a classifier with room for 1000 samples.
     *
     * @param anzprototypen number of prototypes (k)
     */
    public KMeans(int anzprototypen) {
        k = anzprototypen;
        trainingsdaten = new double[ANZ_DATEN][DIM];
        label = new int[ANZ_DATEN];
        myprototype = new int[ANZ_DATEN];
        prototypen = new double[k][DIM];
        protolabel = new int[k];
    }

    /**
     * Reads samples from a text file into {@link #trainingsdaten} and {@link #label}.
     *
     * <p>Each record consists of 16 lines of 12 space-separated numbers, then one line of
     * 10 numbers (the class is the index of the first value equal to 1.0, or -1 if none),
     * then one further line that is skipped. Blank lines where a record would start are
     * ignored. Rows after the last record read keep their previous contents.
     *
     * @param filename path of the input file
     * @return {@code true} if the whole file was read; {@code false} if it could not be
     *         opened or read, a record is truncated or malformed, or the file holds more
     *         records than fit into the arrays (records read up to that point are kept)
     */
    public boolean readData(String filename) {
        String filenameIn = filename;
        try (BufferedReader bur = new BufferedReader(
                new InputStreamReader(new FileInputStream(filenameIn)))) {
            int zaehler = 0;
            String sLine;
            while ((sLine = bur.readLine()) != null) {
                if (sLine.trim().isEmpty()) {
                    continue; // e.g. extra blank lines at the end of the file
                }
                if (zaehler >= trainingsdaten.length) {
                    // Explicit capacity check instead of catching ArrayIndexOutOfBoundsException.
                    System.out.println("Es gab einen Indexfehler.");
                    return false;
                }
                int count = 0;
                for (int i = 0; i < ZEILEN; i++) {
                    StringTokenizer st = new StringTokenizer(sLine, " ");
                    for (int j = 0; j < SPALTEN; j++) {
                        trainingsdaten[zaehler][count++] = Double.parseDouble(st.nextToken());
                    }
                    // After the last data line this reads the label line.
                    sLine = bur.readLine();
                    if (sLine == null) {
                        System.out.println("Unvollständiger Datensatz in Datei " + filenameIn);
                        return false;
                    }
                }
                label[zaehler] = parseLabel(sLine);
                bur.readLine(); // separator line following the label line
                zaehler++;
            }
            return true;
        } catch (NumberFormatException | java.util.NoSuchElementException eFormat) {
            System.out.println("Ungültiges Format in Datei " + filenameIn + ": " + eFormat);
            return false;
        } catch (IOException eIO) {
            System.out.println("Konnte Datei " + filenameIn + " nicht öffnen!");
            return false;
        }
    }

    /**
     * Decodes a label line of 10 space-separated values.
     *
     * @return index of the first value equal to 1.0, or -1 if there is none
     */
    private static int parseLabel(String line) {
        StringTokenizer st = new StringTokenizer(line, " ");
        for (int i = 0; i < ANZ_KLASSEN; i++) {
            if (Double.parseDouble(st.nextToken()) == 1.0) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Centres the data: subtracts the per-component mean over all rows of
     * {@link #trainingsdaten} from every row. The mean is accumulated in double
     * precision; an integer accumulator would truncate each partial sum.
     */
    public void zentrierung() {
        int n = trainingsdaten.length;
        if (n == 0) {
            return;
        }
        double[] mittel = new double[DIM];
        for (double[] x : trainingsdaten) {
            for (int j = 0; j < DIM; j++) {
                mittel[j] += x[j];
            }
        }
        for (int j = 0; j < DIM; j++) {
            mittel[j] /= n;
        }
        for (double[] x : trainingsdaten) {
            for (int j = 0; j < DIM; j++) {
                x[j] -= mittel[j];
            }
        }
    }

    /** Euclidean distance between {@code v1} and {@code v2} over the length of {@code v1}. */
    private double distance(double[] v1, double[] v2) {
        double dist = 0;
        for (int i = 0; i < v1.length; i++) {
            double diff = v1[i] - v2[i];
            dist += diff * diff;
        }
        return Math.sqrt(dist);
    }

    /**
     * Trains the prototypes with k-means and labels them by majority vote.
     *
     * <p>Prototypes are initialised with copies of randomly chosen samples.
     *
     * @param abbruchkrit stopping criterion: 1 = run {@code anzIter} iterations,
     *                    2 = iterate until the total prototype movement is at most
     *                    {@code epsilon}; any other value runs no iteration
     * @param anzIter     number of iterations for criterion 1
     * @param epsilon     movement threshold for criterion 2
     */
    public void train(int abbruchkrit, int anzIter, double epsilon) {
        int n = trainingsdaten.length;
        if (n == 0 || k == 0) {
            return;
        }
        if (myprototype.length != n) {
            myprototype = new int[n];
        }
        Random random = new Random();
        for (int i = 0; i < k; i++) {
            int rand = random.nextInt(n);
            // Copy the sample: prototypes are updated in place and must neither overwrite
            // the training data nor share storage with another prototype from the same sample.
            prototypen[i] = trainingsdaten[rand].clone();
            protolabel[i] = label[rand];
        }
        double d = Double.MAX_VALUE; // criterion 2 runs at least once for any finite epsilon
        while ((abbruchkrit == 1 && anzIter > 0) || (abbruchkrit == 2 && d > epsilon)) {
            zuordnen();
            d = prototypenAktualisieren();
            anzIter--;
        }
        prototypenBeschriften();
    }

    /** Assigns every sample to its nearest prototype and stores the index in myprototype. */
    private void zuordnen() {
        for (int i = 0; i < trainingsdaten.length; i++) {
            int p = naechsterPrototyp(trainingsdaten[i]);
            // p < 0 only if no distance is below Double.MAX_VALUE (NaN/infinite data).
            myprototype[i] = p < 0 ? 0 : p;
        }
    }

    /**
     * Returns the index of the prototype closest to {@code x} (the first one on ties),
     * or -1 if no distance is below {@code Double.MAX_VALUE}.
     */
    private int naechsterPrototyp(double[] x) {
        int best = -1;
        double bestDist = Double.MAX_VALUE;
        for (int j = 0; j < k; j++) {
            double dist = distance(prototypen[j], x);
            if (dist < bestDist) {
                bestDist = dist;
                best = j;
            }
        }
        return best;
    }

    /**
     * Moves every prototype towards the mean of its assigned samples.
     *
     * <p>The old prototype counts as one extra sample: a prototype without samples stays
     * where it is, and once the assignment is stable the update converges to the plain
     * cluster mean.
     *
     * @return sum over all prototypes of the distance each one moved
     */
    private double prototypenAktualisieren() {
        double[][] summe = new double[k][DIM];
        int[] anzahl = new int[k];
        for (int i = 0; i < trainingsdaten.length; i++) {
            int p = myprototype[i];
            anzahl[p]++;
            for (int j = 0; j < DIM; j++) {
                summe[p][j] += trainingsdaten[i][j];
            }
        }
        double bewegung = 0;
        for (int i = 0; i < k; i++) {
            double[] alt = prototypen[i].clone();
            for (int j = 0; j < DIM; j++) {
                prototypen[i][j] = (prototypen[i][j] + summe[i][j]) / (anzahl[i] + 1);
            }
            bewegung += distance(prototypen[i], alt);
        }
        return bewegung;
    }

    /**
     * Labels every prototype with the most frequent class among its assigned samples.
     * Samples labelled -1 are ignored. Ties go to the smaller class index, and a
     * prototype with no labelled samples gets class 0.
     */
    private void prototypenBeschriften() {
        int[][] stimmen = new int[k][ANZ_KLASSEN];
        for (int j = 0; j < trainingsdaten.length; j++) {
            if (label[j] != -1) {
                stimmen[myprototype[j]][label[j]]++;
            }
        }
        for (int i = 0; i < k; i++) {
            int max = 0;
            for (int l = 1; l < ANZ_KLASSEN; l++) {
                if (stimmen[i][l] > stimmen[i][max]) {
                    max = l;
                }
            }
            protolabel[i] = max;
        }
    }

    /**
     * Classifies a feature vector.
     *
     * @param y feature vector
     * @return label of the nearest prototype, or -1 if there is no prototype (or no
     *         distance below {@code Double.MAX_VALUE})
     */
    public int classify(double[] y) {
        int p = naechsterPrototyp(y);
        return p < 0 ? -1 : protolabel[p];
    }

    /**
     * Reads "digits-training.txt", centres and trains, and prints the recognition rate
     * on the training data. Stops early if the file cannot be read completely.
     */
    public static void main(String[] args) {
        int abbruchKriterium = 2;
        double epsilon = 0.000001;
        int anzPrototypen = 20;
        int anzIterationen = 3;
        KMeans classifier = new KMeans(anzPrototypen);
        if (!classifier.readData("digits-training.txt")) {
            return;
        }
        classifier.zentrierung();
        classifier.train(abbruchKriterium, anzIterationen, epsilon);
        int n = classifier.trainingsdaten.length;
        double er = 0;
        for (int i = 0; i < n; i++) {
            if (classifier.label[i] == classifier.classify(classifier.trainingsdaten[i])) {
                er++;
            }
        }
        er /= n;
        System.out.println("Erkennungsrate = " + er);
    }
}