-
Notifications
You must be signed in to change notification settings - Fork 479
Expand file tree
/
Copy pathIrisTest.java
More file actions
84 lines (70 loc) · 2.35 KB
/
IrisTest.java
File metadata and controls
84 lines (70 loc) · 2.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package opt.test;
import func.nn.BackpropNetworkBuilder;
import func.nn.LayeredNetwork;
import func.nn.OptNetworkBuilder;
import func.nn.backprop.BackPropagationNetwork;
import func.nn.feedfwd.FeedForwardNetwork;
import shared.*;
import shared.filt.*;
import shared.normalizer.StandardMeanAndVariance;
import shared.reader.CSVDataSetReader;
import shared.reader.DataSetReader;
import java.io.File;
/**
 * Iris neural-network example using the network builders.
 *
 * <p>Loads a CSV data set, splits off the last attribute as the label, one-hot encodes it,
 * shuffles and splits the instances into training and testing sets, standardizes them, and then
 * trains two networks with the same layer layout: one with {@link BackpropNetworkBuilder} and one
 * with {@link OptNetworkBuilder}.
 *
 * https://archive.ics.uci.edu/ml/datasets/Iris
 *
 * @author John Mansfield
 * @version 1.0
 */
public class IrisTest {
    /** Data file read when no path is passed on the command line; relative to the working directory. */
    private static final String DEFAULT_DATA_PATH = "src/opt/test/iris.txt";
    /** Percentage of the shuffled instances assigned to the training set. */
    private static final int PERCENT_TRAIN = 75;
    /** Training iterations for the backpropagation network. */
    private static final int BACKPROP_ITERATIONS = 5000;
    /** Iterations for the network built by {@link OptNetworkBuilder}. */
    private static final int OPT_ITERATIONS = 1000;

    /** Width of the one-hot encoded label, i.e. the output-layer size; set by initializeData. */
    private static int outputLayerSize;
    /** Standardized training instances; set by initializeData. */
    private static DataSet train;
    /** Standardized testing instances; set by initializeData. */
    private static DataSet test;

    /**
     * Loads the data set, prepares the labels, splits it into train/test and standardizes it.
     *
     * <p>Populates {@link #outputLayerSize}, {@link #train} and {@link #test}.
     *
     * @param dataPath path of the CSV data file; relative paths resolve against the working directory
     * @throws Exception propagated from reading the data set
     */
    private static void initializeData(String dataPath) throws Exception {
        // import data
        DataSetReader dsr = new CSVDataSetReader(new File(dataPath).getAbsolutePath());
        DataSet ds = dsr.read();
        System.out.println(new DataSetDescription(ds));

        // split last attribute for label
        LabelSplitFilter lsf = new LabelSplitFilter();
        lsf.filter(ds);

        // encode label as one-hot array; its width becomes the output-layer size
        DiscreteToBinaryFilter dbf = new DiscreteToBinaryFilter();
        dbf.filter(ds.getLabelDataSet());
        outputLayerSize = dbf.getNewAttributeCount();

        // randomize instance order before the split, then take PERCENT_TRAIN% for training
        RandomOrderFilter randomOrderFilter = new RandomOrderFilter();
        randomOrderFilter.filter(ds);
        TestTrainSplitFilter testTrainSplit = new TestTrainSplitFilter(PERCENT_TRAIN);
        testTrainSplit.filter(ds);
        train = testTrainSplit.getTrainingSet();
        test = testTrainSplit.getTestingSet();

        // standardize: statistics are fitted on the training set only, then applied to both sets,
        // so the test set does not influence the scaling
        StandardMeanAndVariance smv = new StandardMeanAndVariance();
        smv.fit(train);
        smv.transform(train);
        smv.transform(test);
    }

    /**
     * Trains two networks with the same layer layout, handing both builders the training and
     * testing sets. Must be called after {@link #initializeData(String)}.
     *
     * <p>The built networks are not used further here. NOTE(review): presumably the builders
     * report training/testing performance themselves; confirm in the builder classes.
     */
    private static void runNetwork() {
        // network trained by backpropagation
        new BackpropNetworkBuilder()
                .withLayers(layers())
                .withDataSet(train, test)
                .withIterations(BACKPROP_ITERATIONS)
                .train();

        // network built by the optimization builder, configured with withSA
        // NOTE(review): SA arguments presumably (start temperature, cooling factor) — confirm
        new OptNetworkBuilder()
                .withLayers(layers())
                .withDataSet(train, test)
                .withSA(100000, .975)
                .withIterations(OPT_ITERATIONS)
                .train();
    }

    /**
     * Returns a fresh layer-size array so each builder gets its own copy.
     *
     * <p>The last entry is the one-hot output width. NOTE(review): whether the leading 25 is the
     * input layer or a hidden layer depends on the builder (it may derive the input size from the
     * data) — confirm.
     *
     * @return {@code {25, 10, outputLayerSize}}
     */
    private static int[] layers() {
        return new int[] {25, 10, outputLayerSize};
    }

    /**
     * Runs the example.
     *
     * @param args optional; {@code args[0]} overrides the data file path
     *     (default {@code src/opt/test/iris.txt})
     * @throws Exception propagated from loading the data
     */
    public static void main(String[] args) throws Exception {
        String dataPath = args.length > 0 ? args[0] : DEFAULT_DATA_PATH;
        initializeData(dataPath);
        runNetwork();
    }
}