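Java Learning - Day 37

Preface

The code in this post comes from the CSDN article: 日撸 Java 三百行 (Days 81-90, CNN convolutional neural network)

I will use this code to build a deeper understanding of CNNs.

Convolutional Neural Network (Code Part)

I. Reading and Storing the Dataset

1. Dataset description

Here is a brief description of the dataset we need to read.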

0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0

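At first glance this is just a collection of 0s and 1s. If we treat these numbers as an image, a suitable tool can render them as a picture.

The picture is \(28 \times 28\) pixels, yet the row contains one extra number at the end. What does that number mean? Look closely at the picture: it resembles the digit '0'. To check this guess, let us look at another row.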

0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3

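Although the digit in this picture is not written neatly, it is still recognizable as a '3', and the extra number at the end is exactly 3. We can conclude that each row of the dataset represents one image, with '0' and '1' marking its black and white pixels, and the last number in the row giving the digit shown in the image.

Reading this dataset therefore means storing each image's pixels in an array whose length equals the number of pixels in the image (\(28 \times 28 = 784\)), and storing the digit shown in the image as a separate value that serves as the image's label.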
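The row layout described above can be sketched in a few lines. This is a minimal illustration, not the Dataset class itself: the class name RowDemo, its three methods, and the tiny 3 x 3 sample row are all assumptions made for the example, standing in for a real 28 x 28 row.

```java
// Hypothetical helper for illustration; not part of the original code base.
final class RowDemo {
    /** Returns the pixel values, i.e. every value of the row except the last. */
    static double[] pixels(String row) {
        String[] parts = row.trim().split(",");
        double[] result = new double[parts.length - 1];
        for (int i = 0; i < result.length; i++) {
            result[i] = Double.parseDouble(parts[i]);
        }
        return result;
    }

    /** Returns the label, i.e. the last value of the row. */
    static int label(String row) {
        String[] parts = row.trim().split(",");
        return (int) Double.parseDouble(parts[parts.length - 1]);
    }

    /** Renders the pixels as text, with width characters per line. */
    static String render(double[] pixels, int width) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < pixels.length; i++) {
            sb.append(pixels[i] > 0.5 ? '#' : '.');
            if ((i + 1) % width == 0) {
                sb.append('\n');
            }
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // Assumed sample: a 3 x 3 "image" followed by the label 7.
        String row = "0,1,0,1,1,1,0,1,0,7";
        System.out.print(render(pixels(row), 3));
        System.out.println("label = " + label(row));
    }
}
```

For a real row of this dataset, the same calls would be made with a width of 28.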

2. Code

package cnn;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
* Manage the dataset.
*
* @author Shi-Huai Wen Email: shihuaiwen@outlook.com.
*/
public class Dataset {
/**
* All instances organized by a list.
*/
private List<Instance> instances;

/**
* The label index.
*/
private int labelIndex;

/**
* The max label (label start from 0).
*/
private double maxLabel = -1;

/**
* **********************
* The first constructor.
* **********************
*/
public Dataset() {
labelIndex = -1;
instances = new ArrayList<>();
}// Of the first constructor

/**
* **********************
* The second constructor.
*
* @param paraFilename The filename.
* @param paraSplitSign Often comma.
* @param paraLabelIndex Often the last column.
* **********************
*/
public Dataset(String paraFilename, String paraSplitSign, int paraLabelIndex) {
instances = new ArrayList<>();
labelIndex = paraLabelIndex;

File tempFile = new File(paraFilename);
// try-with-resources closes the reader even if reading fails.
try (BufferedReader tempReader = new BufferedReader(new FileReader(tempFile))) {
String tempLine;
while ((tempLine = tempReader.readLine()) != null) {
// Skip blank lines: "".split(",") yields one empty string, which parseDouble rejects.
if (tempLine.trim().isEmpty()) {
continue;
} // Of if
String[] tempDatum = tempLine.split(paraSplitSign);

double[] tempData = new double[tempDatum.length];
for (int i = 0; i < tempDatum.length; i++)
tempData[i] = Double.parseDouble(tempDatum[i]);
Instance tempInstance = new Instance(tempData);
append(tempInstance);
} // Of while
} catch (IOException e) {
e.printStackTrace();
System.out.println("Unable to load " + paraFilename);
System.exit(0);
}//Of try
}// Of the second constructor

/**
* **********************
* Append an instance.
*
* @param paraInstance The given record.
* **********************
*/
public void append(Instance paraInstance) {
instances.add(paraInstance);
}// Of append

/**
* **********************
* Append an instance specified by double values.
* **********************
*/
public void append(double[] paraAttributes, Double paraLabel) {
instances.add(new Instance(paraAttributes, paraLabel));
}// Of append

/**
* **********************
* Getter.
* **********************
*/
public Instance getInstance(int paraIndex) {
return instances.get(paraIndex);
}// Of getInstance

/**
* **********************
* Getter.
* **********************
*/
public int size() {
return instances.size();
}// Of size

/**
* **********************
* Getter.
* **********************
*/
public double[] getAttributes(int paraIndex) {
return instances.get(paraIndex).getAttributes();
}// Of getAttrs

/**
* **********************
* Getter.
* **********************
*/
public Double getLabel(int paraIndex) {
return instances.get(paraIndex).getLabel();
}// Of getLabel

/**
* **********************
* Unit test.
* **********************
*/
public static void main(String[] args) {
Dataset tempData = new Dataset("D:/Work/Data/sampledata/train.format", ",", 784);
Instance tempInstance = tempData.getInstance(0);
System.out.println("The first instance is: " + tempInstance);
System.out.println("The first instance label is: " + tempInstance.label);

tempInstance = tempData.getInstance(1);
System.out.println("The second instance is: " + tempInstance);
System.out.println("The second instance label is: " + tempInstance.label);
}// Of main

/**
* **********************
* An instance.
* **********************
*/
public class Instance {
/**
* Conditional attributes.
*/
private double[] attributes;

/**
* Label.
*/
private Double label;

/**
* **********************
* The first constructor.
* **********************
*/
private Instance(double[] paraAttrs, Double paraLabel) {
attributes = paraAttrs;
label = paraLabel;
}//Of the first constructor

/**
* **********************
* The second constructor.
* **********************
*/
public Instance(double[] paraData) {
if (labelIndex == -1) {
// No label
attributes = paraData;
} else {
label = paraData[labelIndex];
if (label > maxLabel) {
// It is a new label
maxLabel = label;
} // Of if

if (labelIndex == 0) {
// The first column is the label
attributes = Arrays.copyOfRange(paraData, 1, paraData.length);
} else {
// The last column is the label
attributes = Arrays.copyOfRange(paraData, 0, paraData.length - 1);
} // Of if
} // Of if
}// Of the second constructor

/**
* **********************
* Getter.
* **********************
*/
public double[] getAttributes() {
return attributes;
}// Of getAttributes

/**
* **********************
* Getter.
* **********************
*/
public Double getLabel() {
if (labelIndex == -1)
return null;
return label;
}// Of getLabel

/**
* **********************
* toString.
* **********************
*/
public String toString() {
return Arrays.toString(attributes) + ", " + label;
}//Of toString
}// Of class Instance
} //Of class Dataset

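3. Run screenshot

II. Basic Operations on the Convolution Kernel Size

1. Operations

These operations act on the size of a convolution kernel, that is, on its width and height.

One operation divides the width and the height by two integers at the same time and throws an error if either is not evenly divisible. For example: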

(4, 12) / (2, 3) -> (2, 4)
(2, 2) / (4, 6) -> Error

The other operation subtracts two integers from the width and the height and then adds 1. For example:

(4, 6) - (2, 2) + 1 -> (3, 5)
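These two operations correspond to how a map's size changes in a CNN: subtracting the kernel size and adding 1 gives the output size of a convolution, and dividing by the scale gives the output size of non-overlapping pooling. The sketch below works on a single side length; the class name SizeArithmeticDemo, its method names, and the kernel size 5 and scale 2 are assumptions chosen for illustration.

```java
// Hypothetical helper for illustration; not part of the original code base.
final class SizeArithmeticDemo {
    /** Side length after a convolution without padding: input - kernel + 1. */
    static int afterConvolution(int input, int kernel) {
        return input - kernel + 1;
    }

    /** Side length after non-overlapping pooling; fails if not evenly divisible. */
    static int afterPooling(int input, int scale) {
        if (input % scale != 0) {
            throw new IllegalArgumentException(input + " is not divisible by " + scale);
        }
        return input / scale;
    }

    public static void main(String[] args) {
        int side = 28;                    // one side of a 28 x 28 image
        side = afterConvolution(side, 5); // 28 - 5 + 1 = 24
        side = afterPooling(side, 2);     // 24 / 2 = 12
        System.out.println(side);
    }
}
```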

2. Code

package cnn;

/**
* The size of a convolution core.
*
* @author Shi-Huai Wen Email: shihuaiwen@outlook.com.
*/
public class Size {
/**
* Cannot be changed after initialization.
*/
public final int width;

/**
* Cannot be changed after initialization.
*/
public final int height;

/**
* **********************
* The first constructor.
*
* @param paraWidth The given width.
* @param paraHeight The given height.
* **********************
*/
public Size(int paraWidth, int paraHeight) {
width = paraWidth;
height = paraHeight;
}// Of the first constructor

/**
* **********************
* Divide a scale with another one. For example (4, 12) / (2, 3) = (2, 4).
*
* @param paraScaleSize The given scale size.
* @return The new size.
* **********************
*/
public Size divide(Size paraScaleSize) {
int resultWidth = width / paraScaleSize.width;
int resultHeight = height / paraScaleSize.height;
if (resultWidth * paraScaleSize.width != width || resultHeight * paraScaleSize.height != height) {
throw new RuntimeException("Unable to divide " + this + " with " + paraScaleSize);
}
return new Size(resultWidth, resultHeight);
}// Of divide

/**
* **********************
* Subtract a scale with another one, and add a value. For example (4, 12) -
* (2, 3) + 1 = (3, 10).
*
* @param paraScaleSize The given scale size.
* @param paraAppend The appended size to both dimensions.
* @return The new size.
* **********************
*/
public Size subtract(Size paraScaleSize, int paraAppend) {
int resultWidth = width - paraScaleSize.width + paraAppend;
int resultHeight = height - paraScaleSize.height + paraAppend;
return new Size(resultWidth, resultHeight);
}// Of subtract


public String toString() {
String resultString = "(" + width + ", " + height + ")";
return resultString;
}// Of toString

/**
* **********************
* Unit test.
* **********************
*/
public static void main(String[] args) {
Size tempSize1 = new Size(4, 6);
Size tempSize2 = new Size(2, 2);
System.out.println("" + tempSize1 + " divide " + tempSize2 + " = " + tempSize1.divide(tempSize2));

try {
System.out.println("" + tempSize2 + " divide " + tempSize1 + " = " + tempSize2.divide(tempSize1));
} catch (Exception ee) {
System.out.println("Error is :" + ee);
} // Of try

System.out.println("" + tempSize1 + " - " + tempSize2 + " + 1 = " + tempSize1.subtract(tempSize2, 1));
}// Of main
} //Of class Size

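3. Run screenshot

III. Math Utility Class

1. Utility functions

The class defines operators whose main purpose is to apply the same operation to every element of a matrix. Some act on a single matrix, for example subtracting each value from 1 or applying the \(Sigmoid\) function to each value. Others act on two matrices, for example element-wise addition and subtraction.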
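The operator idea can also be sketched with the functional interfaces from java.util.function. This is an alternative illustration, not the MathUtils code: the class ElementwiseDemo and its methods map and zip are hypothetical names, and unlike MathUtils.matrixOp, which writes the result back into one of its arguments, this sketch returns a new matrix.

```java
import java.util.Arrays;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;

// Hypothetical helper for illustration; not part of the original code base.
final class ElementwiseDemo {
    /** Applies op to every element and returns a new matrix. */
    static double[][] map(double[][] m, DoubleUnaryOperator op) {
        double[][] out = new double[m.length][];
        for (int i = 0; i < m.length; i++) {
            out[i] = new double[m[i].length];
            for (int j = 0; j < m[i].length; j++) {
                out[i][j] = op.applyAsDouble(m[i][j]);
            }
        }
        return out;
    }

    /** Combines two equally sized matrices element by element. */
    static double[][] zip(double[][] a, double[][] b, DoubleBinaryOperator op) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("row counts differ");
        }
        double[][] out = new double[a.length][];
        for (int i = 0; i < a.length; i++) {
            if (a[i].length != b[i].length) {
                throw new IllegalArgumentException("column counts differ in row " + i);
            }
            out[i] = new double[a[i].length];
            for (int j = 0; j < a[i].length; j++) {
                out[i][j] = op.applyAsDouble(a[i][j], b[i][j]);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] a = {{0.25, 1.0}, {0.0, 0.5}};
        double[][] b = {{1.0, 2.0}, {3.0, 4.0}};
        // One-minus-the-value, sigmoid, and element-wise plus.
        System.out.println(Arrays.deepToString(map(a, v -> 1 - v)));
        System.out.println(Arrays.deepToString(map(a, v -> 1 / (1 + Math.exp(-v)))));
        System.out.println(Arrays.deepToString(zip(a, b, (x, y) -> x + y)));
    }
}
```

Returning a new matrix avoids surprising the caller, at the cost of an extra allocation per call; modifying an argument in place, as matrixOp does, saves that allocation.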

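Rotating a matrix by 180 degrees is the same as rotating it by 90 degrees twice. For an \(n \times n\) matrix, a 90-degree clockwise rotation is \[ matrix[row][col] \overset{rotate}{=} matrix_{new}[col][n - row - 1] \] The rot180 method in the code reaches the same result in a different way: it reverses every row and then reverses the order of the rows.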
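The equivalence between two 90-degree rotations and a direct 180-degree flip can be checked with a small sketch. The class RotationDemo and its methods rot90 and flip180 are hypothetical names for this example. For a non-square \(m \times n\) matrix, the index formula uses the row count \(m\) in place of \(n\).

```java
import java.util.Arrays;

// Hypothetical helper for illustration; not part of the original code base.
final class RotationDemo {
    /** Rotates an m x n matrix 90 degrees clockwise: out[col][m - 1 - row] = in[row][col]. */
    static double[][] rot90(double[][] in) {
        int m = in.length;
        int n = in[0].length;
        double[][] out = new double[n][m];
        for (int row = 0; row < m; row++) {
            for (int col = 0; col < n; col++) {
                out[col][m - 1 - row] = in[row][col];
            }
        }
        return out;
    }

    /** Rotates 180 degrees directly: out[i][j] = in[m - 1 - i][n - 1 - j]. */
    static double[][] flip180(double[][] in) {
        int m = in.length;
        int n = in[0].length;
        double[][] out = new double[m][n];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                out[i][j] = in[m - 1 - i][n - 1 - j];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] matrix = {{1, 2, 3}, {4, 5, 6}};
        double[][] twice = rot90(rot90(matrix));
        // Both lines print [[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]]
        System.out.println(Arrays.deepToString(twice));
        System.out.println(Arrays.deepToString(flip180(matrix)));
    }
}
```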

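convnValid is the convolution operation used in the forward pass. convnFull is its counterpart for back-propagation: it pads the matrix with zeros on every side and then calls convnValid, so the result is larger than the input.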
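The size relation between the two operations can be written out as a short worked equation. It is derived from the padding of \(k_m - 1\) rows and \(k_n - 1\) columns that convnFull adds on each side; \(m \times n\) is the matrix size and \(k_m \times k_n\) the kernel size, following the variable names in the code, and the numbers 28 and 5 are only an illustrative assumption.

\[ \text{convnValid}: (m - k_m + 1) \times (n - k_n + 1) \]

\[ \text{convnFull}: \big(m + 2(k_m - 1) - k_m + 1\big) \times \big(n + 2(k_n - 1) - k_n + 1\big) = (m + k_m - 1) \times (n + k_n - 1) \]

For example, a \(28 \times 28\) map and a \(5 \times 5\) kernel give a \(24 \times 24\) result under convnValid, while a \(24 \times 24\) matrix and the same kernel give \(28 \times 28\) under convnFull. In other words, the full version restores the size that the valid version removed.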

scaleMatrix is mean pooling. kronecker is the reverse of pooling: it expands each value into a block of the scale size.
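A small worked example may help here. The matrix values and the scale \((2, 2)\) are assumptions chosen so that the averages come out as whole numbers.

\[ \begin{pmatrix} 1 & 3 & 2 & 2 \\ 5 & 7 & 2 & 2 \end{pmatrix} \xrightarrow{\text{scaleMatrix}} \begin{pmatrix} 4 & 2 \end{pmatrix} \xrightarrow{\text{kronecker}} \begin{pmatrix} 4 & 4 & 2 & 2 \\ 4 & 4 & 2 & 2 \end{pmatrix} \]

The first block averages to \((1 + 3 + 5 + 7) / 4 = 4\) and the second to \((2 + 2 + 2 + 2) / 4 = 2\). Note that kronecker restores the size of the original matrix but not its individual values.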

2. Code

package cnn;

import java.io.Serializable;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;

/**
* Math operations. Adopted from cnn-master.
*
* @author Shi-Huai Wen Email: shihuaiwen@outlook.com.
*/
public class MathUtils {
/**
* An interface for different on-demand operators.
*/
public interface Operator extends Serializable {
double process(double value);
}// Of interface Operator

/**
* The one-minus-the-value operator.
*/
public static final Operator one_value = new Operator() {
private static final long serialVersionUID = 3752139491940330714L;

@Override
public double process(double value) {
return 1 - value;
}// Of process
};

/**
* The sigmoid operator.
*/
public static final Operator sigmoid = new Operator() {
private static final long serialVersionUID = -1952718905019847589L;

@Override
public double process(double value) {
return 1 / (1 + Math.pow(Math.E, -value));
}// Of process
};

/**
* An interface for operations with two operands.
*/
interface OperatorOnTwo extends Serializable {
double process(double a, double b);
}// Of interface OperatorOnTwo

/**
* Plus.
*/
public static final OperatorOnTwo plus = new OperatorOnTwo() {
private static final long serialVersionUID = -6298144029766839945L;

@Override
public double process(double a, double b) {
return a + b;
}// Of process
};

/**
* Multiply.
*/
public static OperatorOnTwo multiply = new OperatorOnTwo() {

private static final long serialVersionUID = -7053767821858820698L;

@Override
public double process(double a, double b) {
return a * b;
}// Of process
};

/**
* Minus.
*/
public static OperatorOnTwo minus = new OperatorOnTwo() {

private static final long serialVersionUID = 7346065545555093912L;

@Override
public double process(double a, double b) {
return a - b;
}// Of process
};

/**
* **********************
* Print a matrix
* **********************
*/
public static void printMatrix(double[][] matrix) {
for (double[] array : matrix) {
String line = Arrays.toString(array);
line = line.replaceAll(", ", "\t");
System.out.println(line);
} // Of for i
System.out.println();
}// Of printMatrix

/**
* **********************
* Clone a matrix, so that the original reference is not used directly.
* **********************
*/
public static double[][] cloneMatrix(final double[][] matrix) {
final int m = matrix.length;
int n = matrix[0].length;
final double[][] outMatrix = new double[m][n];

for (int i = 0; i < m; i++) {
System.arraycopy(matrix[i], 0, outMatrix[i], 0, n);
} // Of for i
return outMatrix;
}// Of cloneMatrix

/**
* **********************
* Rotate the matrix 180 degrees.
* **********************
*/
public static double[][] rot180(double[][] matrix) {
matrix = cloneMatrix(matrix);
int m = matrix.length;
int n = matrix[0].length;
for (int i = 0; i < m; i++) {
for (int j = 0; j < n / 2; j++) {
double tmp = matrix[i][j];
matrix[i][j] = matrix[i][n - 1 - j];
matrix[i][n - 1 - j] = tmp;
}
}
for (int j = 0; j < n; j++) {
for (int i = 0; i < m / 2; i++) {
double tmp = matrix[i][j];
matrix[i][j] = matrix[m - 1 - i][j];
matrix[m - 1 - i][j] = tmp;
}
}
return matrix;
}// Of rot180

private static final Random myRandom = new Random(2);

/**
* **********************
* Generate a random matrix with the given size. Each value takes value in
* [-0.005, 0.095].
* **********************
*/
public static double[][] randomMatrix(int x, int y, boolean b) {
double[][] matrix = new double[x][y];
// int tag = 1;
for (int i = 0; i < x; i++) {
for (int j = 0; j < y; j++) {
matrix[i][j] = (myRandom.nextDouble() - 0.05) / 10;
} // Of for j
} // Of for i
return matrix;
}// Of randomMatrix

/**
* **********************
* Generate an array with the given length. The random initialization is
* commented out, so every value is currently 0.
* **********************
*/
public static double[] randomArray(int len) {
double[] data = new double[len];
for (int i = 0; i < len; i++) {
//data[i] = myRandom.nextDouble() / 10 - 0.05;
data[i] = 0;
} // Of for i
return data;
}// Of randomArray

/**
* **********************
* Generate a random perm with the batch size.
* **********************
*/
public static int[] randomPerm(int size, int batchSize) {
Set<Integer> set = new HashSet<>();
while (set.size() < batchSize) {
set.add(myRandom.nextInt(size));
}
int[] randPerm = new int[batchSize];
int i = 0;
for (Integer value : set) {
randPerm[i++] = value;
}
return randPerm;
}// Of randomPerm

/**
* **********************
* Matrix operation with the given operator on single operand.
* **********************
*/
public static double[][] matrixOp(final double[][] ma, Operator operator) {
final int m = ma.length;
int n = ma[0].length;
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
ma[i][j] = operator.process(ma[i][j]);
} // Of for j
} // Of for i
return ma;
}// Of matrixOp

/**
* **********************
* Matrix operation with the given operator on two operands.
* **********************
*/
public static double[][] matrixOp(final double[][] ma, final double[][] mb,
final Operator operatorA, final Operator operatorB, OperatorOnTwo operator) {
final int m = ma.length;
int n = ma[0].length;
if (m != mb.length || n != mb[0].length)
throw new RuntimeException("ma.length:" + ma.length + " mb.length:" + mb.length);

for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
double a = ma[i][j];
if (operatorA != null) {
a = operatorA.process(a);
}

double b = mb[i][j];
if (operatorB != null) {
b = operatorB.process(b);
}

mb[i][j] = operator.process(a, b);
} // Of for j
} // Of for i
return mb;
}// Of matrixOp

/**
* **********************
* Extend the matrix to a bigger one (a number of times).
* **********************
*/
public static double[][] kronecker(final double[][] matrix, final Size scale) {
final int m = matrix.length;
int n = matrix[0].length;
final double[][] outMatrix = new double[m * scale.width][n * scale.height];

for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
for (int ki = i * scale.width; ki < (i + 1) * scale.width; ki++) {
for (int kj = j * scale.height; kj < (j + 1) * scale.height; kj++) {
outMatrix[ki][kj] = matrix[i][j];
}
}
}
}
return outMatrix;
}// Of kronecker

/**
* **********************
* Scale the matrix.
* **********************
*/
public static double[][] scaleMatrix(final double[][] matrix, final Size scale) {
int m = matrix.length;
int n = matrix[0].length;
final int sm = m / scale.width;
final int sn = n / scale.height;
final double[][] outMatrix = new double[sm][sn];
if (sm * scale.width != m || sn * scale.height != n)
throw new RuntimeException("scale matrix");
final int size = scale.width * scale.height;
for (int i = 0; i < sm; i++) {
for (int j = 0; j < sn; j++) {
double sum = 0.0;
for (int si = i * scale.width; si < (i + 1) * scale.width; si++) {
for (int sj = j * scale.height; sj < (j + 1) * scale.height; sj++) {
sum += matrix[si][sj];
} // Of for sj
} // Of for si
outMatrix[i][j] = sum / size;
} // Of for j
} // Of for i
return outMatrix;
}// Of scaleMatrix

/**
* **********************
* Convolution full to obtain a bigger size. It is used in back-propagation.
* **********************
*/
public static double[][] convnFull(double[][] matrix, final double[][] kernel) {
int m = matrix.length;
int n = matrix[0].length;
final int km = kernel.length;
final int kn = kernel[0].length;
final double[][] extendMatrix = new double[m + 2 * (km - 1)][n + 2 * (kn - 1)];
for (int i = 0; i < m; i++) {
System.arraycopy(matrix[i], 0, extendMatrix[i + km - 1], kn - 1, n);
} // Of for i
return convnValid(extendMatrix, kernel);
}// Of convnFull

/**
* **********************
* Convolution operation, from a given matrix and a kernel, sliding and sum
* to obtain the result matrix. It is used in forward.
* **********************
*/
public static double[][] convnValid(final double[][] matrix, double[][] kernel) {
int m = matrix.length;
int n = matrix[0].length;
final int km = kernel.length;
final int kn = kernel[0].length;
int kns = n - kn + 1;
final int kms = m - km + 1;
final double[][] outMatrix = new double[kms][kns];

for (int i = 0; i < kms; i++) {
for (int j = 0; j < kns; j++) {
double sum = 0.0;
for (int ki = 0; ki < km; ki++) {
for (int kj = 0; kj < kn; kj++)
sum += matrix[i + ki][j + kj] * kernel[ki][kj];
}
outMatrix[i][j] = sum;

}
}
return outMatrix;
}// Of convnValid

/**
* **********************
* Convolution on a tensor.
* **********************
*/
public static double[][] convnValid(final double[][][][] matrix, int mapNoX, double[][][][] kernel, int mapNoY) {
int m = matrix.length;
int n = matrix[0][mapNoX].length;
int h = matrix[0][mapNoX][0].length;
int km = kernel.length;
int kn = kernel[0][mapNoY].length;
int kh = kernel[0][mapNoY][0].length;
int kms = m - km + 1;
int kns = n - kn + 1;
int khs = h - kh + 1;
if (matrix.length != kernel.length)
throw new RuntimeException("length");
final double[][][] outMatrix = new double[kms][kns][khs];
for (int i = 0; i < kms; i++) {
for (int j = 0; j < kns; j++)
for (int k = 0; k < khs; k++) {
double sum = 0.0;
for (int ki = 0; ki < km; ki++) {
for (int kj = 0; kj < kn; kj++)
for (int kk = 0; kk < kh; kk++) {
sum += matrix[i + ki][mapNoX][j + kj][k + kk] * kernel[ki][mapNoY][kj][kk];
}
}
outMatrix[i][j][k] = sum;
}
}
return outMatrix[0];
}// Of convnValid

/**
* **********************
* The sigmoid operation.
* **********************
*/
public static double sigmoid(double x) {
return 1 / (1 + Math.pow(Math.E, -x));
}// Of sigmoid

/**
* **********************
* Sum all values of a matrix.
* **********************
*/
public static double sum(double[][] error) {
int n = error[0].length;
double sum = 0.0;
for (double[] array : error) {
for (int i = 0; i < n; i++) {
sum += array[i];
}
}
return sum;
}// Of sum

/**
* **********************
* Ad hoc sum.
* **********************
*/
public static double[][] sum(double[][][][] errors, int j) {
int m = errors[0][j].length;
int n = errors[0][j][0].length;
double[][] result = new double[m][n];
for (int mi = 0; mi < m; mi++) {
for (int nj = 0; nj < n; nj++) {
double sum = 0;
for (double[][][] error : errors) {
sum += error[j][mi][nj];
}
result[mi][nj] = sum;
}
}
return result;
}// Of sum

/**
* **********************
* Get the index of the maximal value for the final classification.
* **********************
*/
public static int getMaxIndex(double[] out) {
double max = out[0];
int index = 0;
for (int i = 1; i < out.length; i++)
if (out[i] > max) {
max = out[i];
index = i;
}
return index;
}// Of getMaxIndex
} //Of class MathUtils

An enum class is defined here to identify the type of each layer, such as the input layer or a convolution layer.

package cnn;

/**
* Enumerate all layer types.
*
* @author Shi-Huai Wen Email: shihuaiwen@outlook.com.
*/
public enum LayerTypeEnum {
INPUT, CONVOLUTION, SAMPLING, OUTPUT;
} //Of enum LayerTypeEnum

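IV. Network Structure and Parameters

This class provides utility methods for a single layer and uses the enum LayerTypeEnum above to distinguish the different layers of the network, such as the input layer, convolution layers, and pooling layers.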

package cnn;

/**
* One layer, support all four layer types. The code mainly initializes, gets,
* and sets variables. Essentially no algorithm is implemented.
*
* @author Shi-Huai Wen Email: shihuaiwen@outlook.com.
*/
public class CnnLayer {
/**
* The type of the layer.
*/
LayerTypeEnum type;

/**
* The number of out map.
*/
int outMapNum;

/**
* The map size.
*/
Size mapSize;

/**
* The kernel size.
*/
Size kernelSize;

/**
* The scale size.
*/
Size scaleSize;

/**
* The number of classes (labels). Only set for the OUTPUT layer.
*/
int classNum = -1;

/**
* Kernel. Dimensions: [front map][out map][width][height].
*/
private double[][][][] kernel;

/**
* Bias. The length is outMapNum.
*/
private double[] bias;

/**
* Out maps. Dimensions:
* [batchSize][outMapNum][mapSize.width][mapSize.height].
*/
private double[][][][] outMaps;

/**
* Errors.
*/
private double[][][][] errors;

/**
* For batch processing.
*/
private static int recordInBatch = 0;

/**
* **********************
* The first constructor.
*
* @param paraType Describe which Layer
* @param paraNum When the type is CONVOLUTION, it is the out map number. when
* the type is OUTPUT, it is the class number.
* @param paraSize When the type is INPUT, it is the map size; when the type is
* CONVOLUTION, it is the kernel size; when the type is SAMPLING,
* it is the scale size.
* **********************
*/
public CnnLayer(LayerTypeEnum paraType, int paraNum, Size paraSize) {
type = paraType;
switch (type) {
case INPUT:
outMapNum = 1;
mapSize = paraSize; // No deep copy.
break;
case CONVOLUTION:
outMapNum = paraNum;
kernelSize = paraSize;
break;
case SAMPLING:
scaleSize = paraSize;
break;
case OUTPUT:
classNum = paraNum;
mapSize = new Size(1, 1);
outMapNum = classNum;
break;
default:
System.out.println("Internal error occurred in CnnLayer.java constructor.");
}// Of switch
}// Of the first constructor

/**
* **********************
* Initialize the kernel.
*
* @param paraFrontMapNum The number of maps in the front (previous) layer.
* **********************
*/
public void initKernel(int paraFrontMapNum) {
kernel = new double[paraFrontMapNum][outMapNum][][];
for (int i = 0; i < paraFrontMapNum; i++) {
for (int j = 0; j < outMapNum; j++) {
kernel[i][j] = MathUtils.randomMatrix(kernelSize.width, kernelSize.height, true);
} // Of for j
} // Of for i
}// Of initKernel

/**
* **********************
* Initialize the output kernel. The code is revised to invoke initKernel(int).
* **********************
*/
public void initOutputKernel(int paraFrontMapNum, Size paraSize) {
kernelSize = paraSize;
initKernel(paraFrontMapNum);
}// Of initOutputKernel

/**
* **********************
* Initialize the bias. No parameter. "int frontMapNum" is claimed however not used.
* **********************
*/
public void initBias() {
bias = MathUtils.randomArray(outMapNum);
}// Of initBias

/**
* **********************
* Initialize the errors.
*
* @param paraBatchSize The batch size.
* **********************
*/
public void initErrors(int paraBatchSize) {
errors = new double[paraBatchSize][outMapNum][mapSize.width][mapSize.height];
}// Of initErrors

/**
* **********************
* Initialize out maps.
*
* @param paraBatchSize The batch size.
* **********************
*/
public void initOutMaps(int paraBatchSize) {
outMaps = new double[paraBatchSize][outMapNum][mapSize.width][mapSize.height];
}// Of initOutMaps

/**
* **********************
* Prepare for a new batch.
* **********************
*/
public static void prepareForNewBatch() {
recordInBatch = 0;
}// Of prepareForNewBatch

/**
* **********************
* Prepare for a new record.
* **********************
*/
public static void prepareForNewRecord() {
recordInBatch++;
}// Of prepareForNewRecord

/**
* **********************
* Set one value of outMaps.
* **********************
*/
public void setMapValue(int paraMapNo, int paraX, int paraY, double paraValue) {
outMaps[recordInBatch][paraMapNo][paraX][paraY] = paraValue;
}// Of setMapValue

/**
* **********************
* Set values of the whole map.
* **********************
*/
public void setMapValue(int paraMapNo, double[][] paraOutMatrix) {
outMaps[recordInBatch][paraMapNo] = paraOutMatrix;
}// Of setMapValue

/**
* **********************
* Getter.
* **********************
*/
public Size getMapSize() {
return mapSize;
}// Of getMapSize

/**
* **********************
* Setter.
* **********************
*/
public void setMapSize(Size paraMapSize) {
mapSize = paraMapSize;
}// Of setMapSize

/**
* **********************
* Getter.
* **********************
*/
public LayerTypeEnum getType() {
return type;
}// Of getType

/**
* **********************
* Getter.
* **********************
*/
public int getOutMapNum() {
return outMapNum;
}// Of getOutMapNum

/**
* **********************
* Setter.
* **********************
*/
public void setOutMapNum(int paraOutMapNum) {
outMapNum = paraOutMapNum;
}// Of setOutMapNum

/**
* **********************
* Getter.
* **********************
*/
public Size getKernelSize() {
return kernelSize;
}// Of getKernelSize

/**
* **********************
* Getter.
* **********************
*/
public Size getScaleSize() {
return scaleSize;
}// Of getScaleSize

/**
* **********************
* Getter.
* **********************
*/
public double[][] getMap(int paraIndex) {
return outMaps[recordInBatch][paraIndex];
}// Of getMap

/**
* **********************
* Getter.
* **********************
*/
public double[][] getKernel(int paraFrontMap, int paraOutMap) {
return kernel[paraFrontMap][paraOutMap];
}// Of getKernel

/**
* **********************
* Setter. Set one error.
* **********************
*/
public void setError(int paraMapNo, int paraMapX, int paraMapY, double paraValue) {
errors[recordInBatch][paraMapNo][paraMapX][paraMapY] = paraValue;
}// Of setError

/**
* **********************
* Setter. Set one error matrix.
* **********************
*/
public void setError(int paraMapNo, double[][] paraMatrix) {
errors[recordInBatch][paraMapNo] = paraMatrix;
}// Of setError

/**
* **********************
* Getter. Get one error matrix.
* **********************
*/
public double[][] getError(int paraMapNo) {
return errors[recordInBatch][paraMapNo];
}// Of getError

/**
* **********************
* Getter. Get the whole error tensor.
* **********************
*/
public double[][][][] getErrors() {
return errors;
}// Of getErrors

/**
* **********************
* Setter. Set one kernel.
* **********************
*/
public void setKernel(int paraLastMapNo, int paraMapNo, double[][] paraKernel) {
kernel[paraLastMapNo][paraMapNo] = paraKernel;
}// Of setKernel

/**
* **********************
* Getter.
* **********************
*/
public double getBias(int paraMapNo) {
return bias[paraMapNo];
}// Of getBias

/**
* **********************
* Setter.
* **********************
*/
public void setBias(int paraMapNo, double paraValue) {
bias[paraMapNo] = paraValue;
}// Of setBias

/**
* **********************
* Getter.
* **********************
*/
public double[][][][] getMaps() {
return outMaps;
}// Of getMaps

/**
* **********************
* Getter.
* **********************
*/
public double[][] getError(int paraRecordId, int paraMapNo) {
return errors[paraRecordId][paraMapNo];
}// Of getError

/**
* **********************
* Getter.
* **********************
*/
public double[][] getMap(int paraRecordId, int paraMapNo) {
return outMaps[paraRecordId][paraMapNo];
}// Of getMap

/**
* **********************
* Getter.
* **********************
*/
public int getClassNum() {
return classNum;
}// Of getClassNum

/**
* **********************
* Getter. Get the whole kernel tensor.
* **********************
*/
public double[][][][] getKernel() {
return kernel;
} // Of getKernel
}//Of class CnnLayer

We wrap the CnnLayer class in one more class, which makes it easier to create the individual layers of the network.

package cnn;

import java.util.ArrayList;
import java.util.List;

/**
* CnnLayer builder.
*
* @author Shi-Huai Wen Email: shihuaiwen@outlook.com.
*/
public class LayerBuilder {
/**
* Layers.
*/
private List<CnnLayer> layers;

/**
* **********************
* The first constructor.
* **********************
*/
public LayerBuilder() {
layers = new ArrayList<>();
}// Of the first constructor

/**
* **********************
* The second constructor.
* **********************
*/
public LayerBuilder(CnnLayer paraLayer) {
this();
layers.add(paraLayer);
}// Of the second constructor

/**
* **********************
* Add a layer.
*
* @param paraLayer The new layer.
* **********************
*/
public void addLayer(CnnLayer paraLayer) {
layers.add(paraLayer);
}// Of addLayer

/**
* **********************
* Get the specified layer.
*
* @param paraIndex The index of the layer.
* **********************
*/
public CnnLayer getLayer(int paraIndex) throws RuntimeException {
if (paraIndex < 0 || paraIndex >= layers.size()) {
throw new RuntimeException("CnnLayer " + paraIndex + " is out of range: " + layers.size() + ".");
}//Of if

return layers.get(paraIndex);
}//Of getLayer

/**
* **********************
* Get the output layer.
* **********************
*/
public CnnLayer getOutputLayer() {
return layers.get(layers.size() - 1);
}//Of getOutputLayer

/**
* **********************
* Get the number of layers.
* **********************
*/
public int getNumLayers() {
return layers.size();
}//Of getNumLayers
} //Of class LayerBuilder

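V. Building the Neural Network

1. Forward Propagation

The basics of forward propagation were covered earlier, so here is only a brief recap.

An image is convolved with the kernels to produce feature maps. The feature maps are then pooled by whichever pooling layer we choose, an activation function is applied, and the activated output becomes the input of the next convolution layer. (In the code of this section, the sigmoid activation is applied right after the convolution, before pooling.)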
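The steps above can be sketched on a tiny input. This is a minimal illustration, not the code from this post: the class `ForwardSketch` and its methods are hypothetical names, the convolution does not flip the kernel, and mean pooling is an assumption, since `MathUtils.convnValid` and `MathUtils.scaleMatrix` are not shown in this section.

```java
import java.util.Arrays;

class ForwardSketch {
    /** Valid convolution without kernel flip: output is (h - kh + 1) x (w - kw + 1). */
    static double[][] convValid(double[][] in, double[][] k) {
        int outHeight = in.length - k.length + 1;
        int outWidth = in[0].length - k[0].length + 1;
        double[][] out = new double[outHeight][outWidth];
        for (int i = 0; i < outHeight; i++) {
            for (int j = 0; j < outWidth; j++) {
                double sum = 0.0;
                for (int a = 0; a < k.length; a++) {
                    for (int b = 0; b < k[0].length; b++) {
                        sum += in[i + a][j + b] * k[a][b];
                    }
                }
                out[i][j] = sum;
            }
        }
        return out;
    }

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** Mean pooling over non-overlapping s x s windows (assumed pooling type). */
    static double[][] meanPool(double[][] in, int s) {
        double[][] out = new double[in.length / s][in[0].length / s];
        for (int i = 0; i < out.length; i++) {
            for (int j = 0; j < out[0].length; j++) {
                double sum = 0.0;
                for (int a = 0; a < s; a++) {
                    for (int b = 0; b < s; b++) {
                        sum += in[i * s + a][j * s + b];
                    }
                }
                out[i][j] = sum / (s * s);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] image = new double[6][6];
        for (double[] row : image) Arrays.fill(row, 1.0);
        double[][] kernel = new double[3][3];
        for (double[] row : kernel) Arrays.fill(row, 0.1);
        double bias = 0.0;

        double[][] conv = convValid(image, kernel); // 6 - 3 + 1 = 4
        for (double[] row : conv) {
            for (int j = 0; j < row.length; j++) {
                row[j] = sigmoid(row[j] + bias); // activation after convolution
            }
        }
        double[][] pooled = meanPool(conv, 2); // 4 / 2 = 2
        System.out.println(conv.length + "x" + conv[0].length
                + " -> " + pooled.length + "x" + pooled[0].length); // 4x4 -> 2x2
    }
}
```

The size arithmetic is the same one the network relies on: a valid convolution shrinks each side by the kernel size minus one, and pooling divides each side by the scale.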

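After repeated convolution, pooling, and activation, the data enters the fully connected layer. This layer also performs a convolution: it turns the \(m \times n\) feature maps into a \(1 \times n\) vector, which is then processed and normalized by the Softmax function. The index of the largest value in this vector is the index of the most likely class.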
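As a small illustration of that last step, the sketch below normalizes a score vector with Softmax and picks the index of the largest value. `SoftmaxSketch` and its methods are hypothetical names used only here. Note that the `FullCnn` code in this post applies a sigmoid to the output layer and then takes `MathUtils.getMaxIndex`, so this shows the Softmax variant described in the text rather than what the code does.

```java
class SoftmaxSketch {
    /** Numerically stable softmax: subtract the maximum before exponentiating. */
    static double[] softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) max = Math.max(max, s);

        double[] probs = new double[scores.length];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            probs[i] = Math.exp(scores[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) probs[i] /= sum;
        return probs;
    }

    /** Index of the largest value; the first one wins on ties. */
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) best = i;
        }
        return best;
    }

    public static void main(String[] args) {
        double[] scores = {1.0, 3.0, 0.5};
        double[] probs = softmax(scores);
        System.out.println("Predicted class: " + argMax(probs)); // Predicted class: 1
    }
}
```

Because Softmax is monotonic, the predicted index is the same whether the maximum is taken before or after normalization; the normalization only matters when the values are interpreted as probabilities.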

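2. Backpropagation

Backpropagation is a familiar topic by now. Because the initial convolution kernels are random, we use the loss function to find the best kernels.

Backpropagation starts with the fully connected layer. This works much as it does in an ANN: the weights of the layer are updated.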
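For a sigmoid output with a one-hot target, the error term of each output unit has the form `out * (1 - out) * (target - out)`, which is the same expression `setOutputLayerErrors` uses in the code of this section. A minimal sketch, with `OutputErrorSketch` and `outputDeltas` as hypothetical names:

```java
import java.util.Arrays;

class OutputErrorSketch {
    /**
     * Error term of each output unit for a sigmoid output and a one-hot target.
     * out[m] * (1 - out[m]) is the sigmoid derivative; (target - out[m]) is the raw error.
     */
    static double[] outputDeltas(double[] out, int label) {
        double[] deltas = new double[out.length];
        for (int m = 0; m < out.length; m++) {
            double target = (m == label) ? 1.0 : 0.0;
            deltas[m] = out[m] * (1 - out[m]) * (target - out[m]);
        }
        return deltas;
    }

    public static void main(String[] args) {
        double[] out = {0.5, 0.5};
        System.out.println(Arrays.toString(outputDeltas(out, 0))); // [0.125, -0.125]
    }
}
```

The unit for the correct class gets a positive error (its output should grow), and every other unit gets a negative one.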

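Next comes the pooling layer, which is the simplest case because it has no weights of its own to update; the error is only passed back through it. Take max pooling as an example: if the pooled value is 6 and the error arriving from the next layer is +1, that +1 is assigned to the position in the window that held the 6, and the other positions in the window receive 0.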
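Two common ways of passing an error back through a pooling layer can be sketched as follows. `PoolBackwardSketch` and its methods are hypothetical names. The first method routes each error to the position that held the window maximum (max pooling); the second replicates each error over its window, which is what a Kronecker product with a matrix of ones does and is presumably what `MathUtils.kronecker` is used for in `setConvolutionErrors`, although that utility is not shown in this section.

```java
import java.util.Arrays;

class PoolBackwardSketch {
    /** Max pooling: each error goes to the position that held the window maximum. */
    static double[][] maxPoolBackward(double[][] in, double[][] err, int s) {
        double[][] grad = new double[in.length][in[0].length];
        for (int i = 0; i < err.length; i++) {
            for (int j = 0; j < err[0].length; j++) {
                int bestRow = i * s;
                int bestCol = j * s;
                for (int a = 0; a < s; a++) {
                    for (int b = 0; b < s; b++) {
                        if (in[i * s + a][j * s + b] > in[bestRow][bestCol]) {
                            bestRow = i * s + a;
                            bestCol = j * s + b;
                        }
                    }
                }
                grad[bestRow][bestCol] = err[i][j];
            }
        }
        return grad;
    }

    /**
     * Replicate each error over its s x s window.
     * A strict mean-pooling gradient would additionally divide by s * s.
     */
    static double[][] upsample(double[][] err, int s) {
        double[][] out = new double[err.length * s][err[0].length * s];
        for (int i = 0; i < out.length; i++) {
            for (int j = 0; j < out[0].length; j++) {
                out[i][j] = err[i / s][j / s];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] window = {{1, 6}, {2, 3}}; // pooled value is 6
        double[][] error = {{1}};             // error arriving from the next layer
        System.out.println(Arrays.deepToString(maxPoolBackward(window, error, 2)));
        // [[0.0, 1.0], [0.0, 0.0]]
        System.out.println(Arrays.deepToString(upsample(error, 2)));
        // [[1.0, 1.0], [1.0, 1.0]]
    }
}
```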

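The convolution layer is the most troublesome, and I still have not fully worked out the derivation. My rough understanding is that one first derives the formula for the two-dimensional case and then generalizes it to the multi-dimensional case used in the network.

The Zhihu article 卷积神经网络(CNN)反向传播算法推导 (Derivation of the CNN Backpropagation Algorithm) gives a detailed derivation and explanation.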
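Without repeating the derivation, the update rules can at least be read off the `FullCnn` code in this section. The summary below is such a reading, not a derivation, and it rests on assumptions about `MathUtils`, which is not shown here: that `one_value` maps \(v\) to \(1 - v\), that `matrixOp` applies its two operators to its two matrix arguments before combining them, and that `kronecker` replicates each error over the pooling window (written \(\operatorname{up}\) below). Here \(a\) denotes a layer's output maps, \(\delta\) its errors, \(t\) the one-hot target, \(B\) the batch size, \(\alpha\) = `ALPHA`, and \(\lambda\) = `LAMBDA`.

```latex
\[
\begin{aligned}
\delta^{L}_{m} &= a^{L}_{m}\,(1 - a^{L}_{m})\,(t_{m} - a^{L}_{m})
  && \text{output layer} \\
\delta^{l}_{m} &= a^{l}_{m} \odot (1 - a^{l}_{m}) \odot \operatorname{up}\!\left(\delta^{l+1}_{m}\right)
  && \text{convolution layer followed by a sampling layer} \\
\delta^{l}_{i} &= \sum_{j} \operatorname{convnFull}\!\left(\delta^{l+1}_{j},\ \operatorname{rot180}\!\left(k^{l+1}_{ij}\right)\right)
  && \text{sampling layer followed by a convolution or output layer} \\
k^{l}_{ij} &\leftarrow (1 - \lambda\alpha)\,k^{l}_{ij}
  + \frac{\alpha}{B} \sum_{r=1}^{B} \operatorname{convnValid}\!\left(a^{l-1}_{i,r},\ \delta^{l}_{j,r}\right)
  && \text{kernel update} \\
b^{l}_{j} &\leftarrow b^{l}_{j} + \frac{\alpha}{B} \sum_{r=1}^{B} \sum_{x,y} \delta^{l}_{j,r}(x, y)
  && \text{bias update}
\end{aligned}
\]
```

Because the error is defined as target minus output, the updates add the scaled gradient term instead of subtracting it.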

4. Code

package cnn;

import java.util.Arrays;

import cnn.Dataset.Instance;
import cnn.MathUtils.Operator;

/**
* CNN.
*
* @author Shi-Huai Wen Email: shihuaiwen@outlook.com.
*/
public class FullCnn {
/**
* The learning rate. It is decayed during training.
*/
private static double ALPHA = 0.85;

/**
* The weight-decay coefficient used when updating kernels.
*/
public static double LAMBDA = 0;

/**
* Manage layers.
*/
private static LayerBuilder layerBuilder;

/**
* Train using a number of instances simultaneously.
*/
private int batchSize;

/**
* Divide the given value by the batch size.
*/
private Operator divideBatchSize;

/**
* Multiply alpha with the given value.
*/
private Operator multiplyAlpha;

/**
* Multiply the given value by (1 - lambda * alpha).
*/
private Operator multiplyLambda;

/**
* **********************
* The first constructor.
* **********************
*/
public FullCnn(LayerBuilder paraLayerBuilder, int paraBatchSize) {
layerBuilder = paraLayerBuilder;
batchSize = paraBatchSize;
setup();
initOperators();
}// Of the first constructor

/**
* **********************
* Initialize operators using anonymous classes.
* **********************
*/
private void initOperators() {
divideBatchSize = new Operator() {
private static final long serialVersionUID = 7424011281732651055L;

@Override
public double process(double value) {
return value / batchSize;
}// Of process
};

multiplyAlpha = new Operator() {
private static final long serialVersionUID = 5761368499808006552L;

@Override
public double process(double value) {
return value * ALPHA;
}// Of process
};

multiplyLambda = new Operator() {
private static final long serialVersionUID = 4499087728362870577L;

@Override
public double process(double value) {
return value * (1 - LAMBDA * ALPHA);
}// Of process
};
}// Of initOperators

/**
* **********************
* Setup according to the layer builder.
* **********************
*/
public void setup() {
CnnLayer tempInputLayer = layerBuilder.getLayer(0);
tempInputLayer.initOutMaps(batchSize);

for (int i = 1; i < layerBuilder.getNumLayers(); i++) {
CnnLayer tempLayer = layerBuilder.getLayer(i);
CnnLayer tempFrontLayer = layerBuilder.getLayer(i - 1);
int tempFrontMapNum = tempFrontLayer.getOutMapNum();
switch (tempLayer.getType()) {
case INPUT:
// Should not be input. Maybe an error should be thrown out.
break;
case CONVOLUTION:
tempLayer.setMapSize(
tempFrontLayer.getMapSize().subtract(tempLayer.getKernelSize(), 1));
tempLayer.initKernel(tempFrontMapNum);
tempLayer.initBias();
tempLayer.initErrors(batchSize);
tempLayer.initOutMaps(batchSize);
break;
case SAMPLING:
tempLayer.setOutMapNum(tempFrontMapNum);
tempLayer.setMapSize(tempFrontLayer.getMapSize().divide(tempLayer.getScaleSize()));
tempLayer.initErrors(batchSize);
tempLayer.initOutMaps(batchSize);
break;
case OUTPUT:
tempLayer.initOutputKernel(tempFrontMapNum, tempFrontLayer.getMapSize());
tempLayer.initBias();
tempLayer.initErrors(batchSize);
tempLayer.initOutMaps(batchSize);
break;
}// Of switch
} // Of for i
}// Of setup

/**
* **********************
* Forward computing.
* **********************
*/
private void forward(Instance instance) {
setInputLayerOutput(instance);
for (int l = 1; l < layerBuilder.getNumLayers(); l++) {
CnnLayer tempCurrentLayer = layerBuilder.getLayer(l);
CnnLayer tempLastLayer = layerBuilder.getLayer(l - 1);
switch (tempCurrentLayer.getType()) {
case CONVOLUTION:
case OUTPUT:
setConvolutionOutput(tempCurrentLayer, tempLastLayer);
break;
case SAMPLING:
setSampOutput(tempCurrentLayer, tempLastLayer);
break;
default:
break;
}// Of switch
} // Of for l
}// Of forward

/**
* **********************
* Set the input layer output. Given a record, copy its values to the input map.
* **********************
*/
private void setInputLayerOutput(Instance paraRecord) {
CnnLayer tempInputLayer = layerBuilder.getLayer(0);
Size tempMapSize = tempInputLayer.getMapSize();
double[] tempAttributes = paraRecord.getAttributes();
if (tempAttributes.length != tempMapSize.width * tempMapSize.height)
throw new RuntimeException("input record does not match the map size.");

for (int i = 0; i < tempMapSize.width; i++) {
for (int j = 0; j < tempMapSize.height; j++) {
tempInputLayer.setMapValue(0, i, j, tempAttributes[tempMapSize.height * i + j]);
} // Of for j
} // Of for i
}// Of setInputLayerOutput

/**
* **********************
* Compute the convolution output according to the output of the last layer.
*
* @param paraLastLayer the last layer.
* @param paraLayer the current layer.
* **********************
*/
private void setConvolutionOutput(final CnnLayer paraLayer, final CnnLayer paraLastLayer) {

final int lastMapNum = paraLastLayer.getOutMapNum();

// Attention: paraLayer.getOutMapNum() may not be right.
for (int j = 0; j < paraLayer.getOutMapNum(); j++) {
double[][] tempSumMatrix = null;
for (int i = 0; i < lastMapNum; i++) {
double[][] lastMap = paraLastLayer.getMap(i);
double[][] kernel = paraLayer.getKernel(i, j);
if (tempSumMatrix == null) {
// On the first map.
tempSumMatrix = MathUtils.convnValid(lastMap, kernel);
} else {
// Sum up convolution maps
tempSumMatrix = MathUtils.matrixOp(MathUtils.convnValid(lastMap, kernel),
tempSumMatrix, null, null, MathUtils.plus);
} // Of if
} // Of for i

// Activation.
final double bias = paraLayer.getBias(j);
tempSumMatrix = MathUtils.matrixOp(tempSumMatrix, new Operator() {
private static final long serialVersionUID = 2469461972825890810L;

@Override
public double process(double value) {
return MathUtils.sigmoid(value + bias);
}

});

paraLayer.setMapValue(j, tempSumMatrix);
} // Of for j
}// Of setConvolutionOutput

/**
* **********************
* Compute the sampling output according to the output of the last layer.
*
* @param paraLastLayer the last layer.
* @param paraLayer the current layer.
* **********************
*/
private void setSampOutput(final CnnLayer paraLayer, final CnnLayer paraLastLayer) {
// int tempLastMapNum = paraLastLayer.getOutMapNum();

// Attention: paraLayer.getOutMapNum() may not be right.
for (int i = 0; i < paraLayer.getOutMapNum(); i++) {
double[][] lastMap = paraLastLayer.getMap(i);
Size scaleSize = paraLayer.getScaleSize();
double[][] sampMatrix = MathUtils.scaleMatrix(lastMap, scaleSize);
paraLayer.setMapValue(i, sampMatrix);
} // Of for i
}// Of setSampOutput

/**
* **********************
* Train the cnn.
* **********************
*/
public void train(Dataset paraDataset, int paraRounds) {
for (int t = 0; t < paraRounds; t++) {
System.out.println("Iteration: " + t);
int tempNumEpochs = paraDataset.size() / batchSize;
if (paraDataset.size() % batchSize != 0)
tempNumEpochs++;

double tempNumCorrect = 0;
int tempCount = 0;
for (int i = 0; i < tempNumEpochs; i++) {
int[] tempRandomPerm = MathUtils.randomPerm(paraDataset.size(), batchSize);
CnnLayer.prepareForNewBatch();

for (int index : tempRandomPerm) {
boolean isRight = train(paraDataset.getInstance(index));
if (isRight)
tempNumCorrect++;
tempCount++;
CnnLayer.prepareForNewRecord();
} // Of for index

updateParameters();
if (i % 50 == 0) {
System.out.print("..");
if (i + 50 > tempNumEpochs)
System.out.println();
}
} // Of for i
double p = 1.0 * tempNumCorrect / tempCount;
if (t % 10 == 1 && p > 0.96) {
ALPHA = 0.001 + ALPHA * 0.9;
// logger.info("Set alpha = {}", ALPHA);
} // Of if
System.out.println("Training precision: " + p);
// logger.info("Precision: {}/{}={}.", right, count, p);
} // Of for t
}// Of train

/**
* **********************
* Train the cnn with only one record.
*
* @param paraRecord The given record.
* **********************
*/
private boolean train(Instance paraRecord) {
forward(paraRecord);
boolean result = backPropagation(paraRecord);
return result;
}// Of train

/**
* **********************
* Back-propagation.
*
* @param paraRecord The given record.
* **********************
*/
private boolean backPropagation(Instance paraRecord) {
boolean result = setOutputLayerErrors(paraRecord);
setHiddenLayerErrors();
return result;
}// Of backPropagation

/**
* **********************
* Update parameters.
* **********************
*/
private void updateParameters() {
for (int l = 1; l < layerBuilder.getNumLayers(); l++) {
CnnLayer layer = layerBuilder.getLayer(l);
CnnLayer lastLayer = layerBuilder.getLayer(l - 1);
switch (layer.getType()) {
case CONVOLUTION:
case OUTPUT:
updateKernels(layer, lastLayer);
updateBias(layer, lastLayer);
break;
default:
break;
}// Of switch
} // Of for l
}// Of updateParameters

/**
* **********************
* Update bias.
* **********************
*/
private void updateBias(final CnnLayer paraLayer, CnnLayer paraLastLayer) {
final double[][][][] errors = paraLayer.getErrors();
// int mapNum = paraLayer.getOutMapNum();

// Attention: getOutMapNum() may not be correct.
for (int j = 0; j < paraLayer.getOutMapNum(); j++) {
double[][] error = MathUtils.sum(errors, j);
double deltaBias = MathUtils.sum(error) / batchSize;
double bias = paraLayer.getBias(j) + ALPHA * deltaBias;
paraLayer.setBias(j, bias);
} // Of for j
}// Of updateBias

/**
* **********************
* Update kernels.
* **********************
*/
private void updateKernels(final CnnLayer paraLayer, final CnnLayer paraLastLayer) {
// int mapNum = paraLayer.getOutMapNum();
int tempLastMapNum = paraLastLayer.getOutMapNum();

// Attention: getOutMapNum() may not be right
for (int j = 0; j < paraLayer.getOutMapNum(); j++) {
for (int i = 0; i < tempLastMapNum; i++) {
double[][] tempDeltaKernel = null;
for (int r = 0; r < batchSize; r++) {
double[][] error = paraLayer.getError(r, j);
if (tempDeltaKernel == null)
tempDeltaKernel = MathUtils.convnValid(paraLastLayer.getMap(r, i), error);
else {
tempDeltaKernel = MathUtils.matrixOp(
MathUtils.convnValid(paraLastLayer.getMap(r, i), error),
tempDeltaKernel, null, null, MathUtils.plus);
} // Of if
} // Of for r

tempDeltaKernel = MathUtils.matrixOp(tempDeltaKernel, divideBatchSize);

double[][] kernel = paraLayer.getKernel(i, j);
tempDeltaKernel = MathUtils.matrixOp(kernel, tempDeltaKernel, multiplyLambda, multiplyAlpha, MathUtils.plus);
paraLayer.setKernel(i, j, tempDeltaKernel);
} // Of for i
} // Of for j
}// Of updateKernels

/**
* **********************
* Set errors of all hidden layers.
* **********************
*/
private void setHiddenLayerErrors() {
// System.out.println("setHiddenLayerErrors");
for (int l = layerBuilder.getNumLayers() - 2; l > 0; l--) {
CnnLayer layer = layerBuilder.getLayer(l);
CnnLayer nextLayer = layerBuilder.getLayer(l + 1);

switch (layer.getType()) {
case SAMPLING:
setSamplingErrors(layer, nextLayer);
break;
case CONVOLUTION:
setConvolutionErrors(layer, nextLayer);
break;
default:
break;
}// Of switch
} // Of for l
}// Of setHiddenLayerErrors

/**
* **********************
* Set errors of a sampling layer.
* **********************
*/
private void setSamplingErrors(final CnnLayer paraLayer, final CnnLayer paraNextLayer) {
// int mapNum = layer.getOutMapNum();
int tempNextMapNum = paraNextLayer.getOutMapNum();
// Attention: getOutMapNum() may not be correct
for (int i = 0; i < paraLayer.getOutMapNum(); i++) {
double[][] sum = null;
for (int j = 0; j < tempNextMapNum; j++) {
double[][] nextError = paraNextLayer.getError(j);
double[][] kernel = paraNextLayer.getKernel(i, j);
if (sum == null) {
sum = MathUtils.convnFull(nextError, MathUtils.rot180(kernel));
} else {
sum = MathUtils.matrixOp(
MathUtils.convnFull(nextError, MathUtils.rot180(kernel)), sum, null,
null, MathUtils.plus);
} // Of if
} // Of for j
paraLayer.setError(i, sum);
} // Of for i
}// Of setSamplingErrors

/**
* **********************
* Set errors of a convolution layer.
* **********************
*/
private void setConvolutionErrors(final CnnLayer paraLayer, final CnnLayer paraNextLayer) {
// System.out.println("setConvErrors");
for (int m = 0; m < paraLayer.getOutMapNum(); m++) {
Size tempScale = paraNextLayer.getScaleSize();
double[][] tempNextLayerErrors = paraNextLayer.getError(m);
double[][] tempMap = paraLayer.getMap(m);
double[][] tempOutMatrix = MathUtils.matrixOp(tempMap, MathUtils.cloneMatrix(tempMap),
null, MathUtils.one_value, MathUtils.multiply);
tempOutMatrix = MathUtils.matrixOp(tempOutMatrix,
MathUtils.kronecker(tempNextLayerErrors, tempScale), null, null,
MathUtils.multiply);
paraLayer.setError(m, tempOutMatrix);
} // Of for m
}// Of setConvolutionErrors

/**
* **********************
* Set errors of the output layer and report whether the prediction is correct.
* **********************
*/
private boolean setOutputLayerErrors(Instance paraRecord) {
CnnLayer tempOutputLayer = layerBuilder.getOutputLayer();
int tempMapNum = tempOutputLayer.getOutMapNum();

double[] tempTarget = new double[tempMapNum];
double[] tempOutMaps = new double[tempMapNum];
for (int m = 0; m < tempMapNum; m++) {
double[][] outmap = tempOutputLayer.getMap(m);
tempOutMaps[m] = outmap[0][0];
} // Of for m

int tempLabel = paraRecord.getLabel().intValue();
tempTarget[tempLabel] = 1;

for (int m = 0; m < tempMapNum; m++) {
tempOutputLayer.setError(m, 0, 0,
tempOutMaps[m] * (1 - tempOutMaps[m]) * (tempTarget[m] - tempOutMaps[m]));
} // Of for m

return tempLabel == MathUtils.getMaxIndex(tempOutMaps);
}// Of setOutputLayerErrors

/**
* **********************
* Setup the network.
* **********************
*/
public void setup(int paraBatchSize) {
CnnLayer tempInputLayer = layerBuilder.getLayer(0);
tempInputLayer.initOutMaps(paraBatchSize);

for (int i = 1; i < layerBuilder.getNumLayers(); i++) {
CnnLayer tempLayer = layerBuilder.getLayer(i);
CnnLayer tempLastLayer = layerBuilder.getLayer(i - 1);
int tempLastMapNum = tempLastLayer.getOutMapNum();
switch (tempLayer.getType()) {
case INPUT:
break;
case CONVOLUTION:
tempLayer.setMapSize(
tempLastLayer.getMapSize().subtract(tempLayer.getKernelSize(), 1));
tempLayer.initKernel(tempLastMapNum);
tempLayer.initBias();
tempLayer.initErrors(paraBatchSize);
tempLayer.initOutMaps(paraBatchSize);
break;
case SAMPLING:
tempLayer.setOutMapNum(tempLastMapNum);
tempLayer.setMapSize(tempLastLayer.getMapSize().divide(tempLayer.getScaleSize()));
tempLayer.initErrors(paraBatchSize);
tempLayer.initOutMaps(paraBatchSize);
break;
case OUTPUT:
tempLayer.initOutputKernel(tempLastMapNum, tempLastLayer.getMapSize());
tempLayer.initBias();
tempLayer.initErrors(paraBatchSize);
tempLayer.initOutMaps(paraBatchSize);
break;
}// Of switch
} // Of for i
}// Of setup

/**
* **********************
* Predict for the dataset.
* **********************
*/
public int[] predict(Dataset paraDataset) {
System.out.println("Predicting ... ");
CnnLayer.prepareForNewBatch();

int[] resultPredictions = new int[paraDataset.size()];
double tempCorrect = 0.0;

Instance tempRecord;
for (int i = 0; i < paraDataset.size(); i++) {
tempRecord = paraDataset.getInstance(i);
forward(tempRecord);
CnnLayer outputLayer = layerBuilder.getOutputLayer();

int tempMapNum = outputLayer.getOutMapNum();
double[] tempOut = new double[tempMapNum];
for (int m = 0; m < tempMapNum; m++) {
double[][] outmap = outputLayer.getMap(m);
tempOut[m] = outmap[0][0];
} // Of for m

resultPredictions[i] = MathUtils.getMaxIndex(tempOut);
if (resultPredictions[i] == tempRecord.getLabel().intValue()) {
tempCorrect++;
} // Of if
} // Of for

System.out.println("Accuracy: " + tempCorrect / paraDataset.size());
return resultPredictions;
}// Of predict

/**
* **********************
* The main entrance.
* **********************
*/
public static void main(String[] args) {
LayerBuilder builder = new LayerBuilder();
// Input layer, the maps are 28*28
builder.addLayer(new CnnLayer(LayerTypeEnum.INPUT, -1, new Size(28, 28)));
// Convolution output has size 24*24, 24=28+1-5
builder.addLayer(new CnnLayer(LayerTypeEnum.CONVOLUTION, 6, new Size(5, 5)));
// Sampling output has size 12*12, 12=24/2
builder.addLayer(new CnnLayer(LayerTypeEnum.SAMPLING, -1, new Size(2, 2)));
// Convolution output has size 8*8, 8=12+1-5
builder.addLayer(new CnnLayer(LayerTypeEnum.CONVOLUTION, 12, new Size(5, 5)));
// Sampling output has size 4*4, 4=8/2
builder.addLayer(new CnnLayer(LayerTypeEnum.SAMPLING, -1, new Size(2, 2)));
// output layer, digits 0 - 9.
builder.addLayer(new CnnLayer(LayerTypeEnum.OUTPUT, 10, null));
// Construct the full CNN.
FullCnn tempCnn = new FullCnn(builder, 10);

Dataset tempTrainingSet = new Dataset("D:/Work/Data/sampledata/train.format", ",", 784);

// Train the model.
tempCnn.train(tempTrainingSet, 10);
// tempCnn.predict(tempTrainingSet);
}// Of main
}//Of class FullCnn
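The map sizes stated in the comments of `main` can be checked with a few lines of arithmetic. The sketch below is only an illustration; `MapSizeSketch` and its two helper methods are hypothetical names, and the rules (a valid convolution shrinks each side by the kernel size minus one, sampling divides each side by the scale) are the ones `setup()` applies through `Size.subtract` and `Size.divide`.

```java
class MapSizeSketch {
    /** Side length after a valid convolution with a square kernel. */
    static int afterConvolution(int mapSize, int kernelSize) {
        return mapSize - kernelSize + 1;
    }

    /** Side length after sampling with a square, non-overlapping window. */
    static int afterSampling(int mapSize, int scale) {
        if (mapSize % scale != 0) {
            throw new IllegalArgumentException(mapSize + " is not divisible by " + scale);
        }
        return mapSize / scale;
    }

    public static void main(String[] args) {
        int size = 28; // input maps are 28*28
        StringBuilder trace = new StringBuilder(String.valueOf(size));

        size = afterConvolution(size, 5); // 24
        trace.append(" -> ").append(size);
        size = afterSampling(size, 2);    // 12
        trace.append(" -> ").append(size);
        size = afterConvolution(size, 5); // 8
        trace.append(" -> ").append(size);
        size = afterSampling(size, 2);    // 4
        trace.append(" -> ").append(size);

        System.out.println(trace); // 28 -> 24 -> 12 -> 8 -> 4
    }
}
```

The output layer is initialized with the size of the last sampling map (4*4) as its kernel size, so each of its 10 output maps shrinks to 1*1. That is why the code reads the prediction scores from `outmap[0][0]`.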

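5. Run Screenshot

Summary

Convolutional neural networks are easy to understand in principle, but actually writing a framework for one was painful and difficult for me.

First, there is the derivation of the formulas for backpropagation. I know gradient descent and matrix derivatives, but only from isolated exercises; when it came to applying them in practice, I did not know where to start.

Then there is the code itself, both the math utilities and the matrix utilities; the matrix rotation part in particular made no sense to me at first.

If we only had to turn given formulas into code, these problems would presumably solve themselves. But things are not that convenient: all of this is something I have to experience, derive, and implement for myself.

The road ahead is long and far; I shall search high and low.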