我们需要做的第⼀件事情是获取 MNIST 数据。如果你是⼀个 git ⽤⼾,那么你能够通过克隆这本书的代码仓库获得数据,实现我们的⽹络来分类数字
git clone https://github.com/mnielsen/neural-networks-and-deep-learning.git
class Network(object):def __init__(self, sizes):self.num_layers = len(sizes)self.sizes = sizesself.biases = [np.random.randn(y, 1) for y in sizes[1:]]self.weights = [np.random.randn(y, x)for x, y in zip(sizes[:-1], sizes[1:])]
在这段代码中,列表 sizes 包含各层神经元的数量。例如,如果我们想创建⼀个在第⼀层有2 个神经元,第⼆层有 3 个神经元,最后层有 1 个神经元的 Network 对象,我们应这样写代码:
net = Network([2, 3, 1])
Network 对象中的偏置和权重都是被随机初始化的,使⽤ Numpy 的 np.random.randn 函数来⽣成均值为 0,标准差为 1 的⾼斯分布。这样的随机初始化给了我们的随机梯度下降算法⼀个起点。在后⾯的章节中我们将会发现更好的初始化权重和偏置的⽅法,但是⽬前随机地将其初始化。注意 Network 初始化代码假设第⼀层神经元是⼀个输⼊层,并对这些神经元不设置任何偏置,因为偏置仅在后⾯的层中⽤于计算输出。有了这些,很容易写出从⼀个 Network 实例计算输出的代码。我们从定义 S 型函数开始:
def sigmoid(z):return 1.0/(1.0+np.exp(-z))
注意,当输⼊ z 是⼀个向量或者 Numpy 数组时,Numpy ⾃动地按元素应⽤ sigmoid 函数,即以向量形式。
我们然后对 Network 类添加⼀个 feedforward ⽅法,对于⽹络给定⼀个输⼊ a,返回对应的输出 6 。这个⽅法所做的是对每⼀层应⽤⽅程 (22):
def feedforward(self, a):"""Return the output of the network if "a" is input."""for b, w in zip(self.biases, self.weights):a = sigmoid(np.dot(w, a)+b)return a
当然,我们想要 Network 对象做的主要事情是学习。为此我们给它们⼀个实现随即梯度下降算法的 SGD ⽅法。代码如下。其中⼀些地⽅看似有⼀点神秘,我会在代码后⾯逐个分析
def SGD(self, training_data, epochs, mini_batch_size, eta,test_data=None):"""Train the neural network using mini-batch stochasticgradient descent. The "training_data" is a list of tuples"(x, y)" representing the training inputs and the desiredoutputs. The other non-optional parameters areself-explanatory. If "test_data" is provided then thenetwork will be evaluated against the test data after eachepoch, and partial progress printed out. This is useful fortracking progress, but slows things down substantially."""if test_data: n_test = len(test_data)n = len(training_data)for j in xrange(epochs):random.shuffle(training_data)mini_batches = [training_data[k:k+mini_batch_size]for k in xrange(0, n, mini_batch_size)]for mini_batch in mini_batches:self.update_mini_batch(mini_batch, eta)if test_data:print "Epoch {0}: {1} / {2}".format(j, self.evaluate(test_data), n_test)else:print "Epoch {0} complete".format(j)
新闻热点
疑难解答