From 1baac6dad54f2966f1f515574938b03ac5fd0900 Mon Sep 17 00:00:00 2001 From: YCAyca Date: Thu, 9 May 2019 13:58:03 +0300 Subject: [PATCH] Update network.py I am working on your online book and projects using Python 3.5. Since I used a different and newer python version, I came across several problems and solved them by making some changes in network.py and mnist_loader files. Now the code is compatible with Python 3.5 The biggest problem is zip variables haven't len() attribute in 3.5. So I converted training_data, test_data to the list variables. I had to make a deep copy of original training and test data because when they are converted to list variables, they can't be reused. --- src/network.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/network.py b/src/network.py index f66c362d4..652777b83 100644 --- a/src/network.py +++ b/src/network.py @@ -51,20 +51,20 @@ def SGD(self, training_data, epochs, mini_batch_size, eta, network will be evaluated against the test data after each epoch, and partial progress printed out. This is useful for tracking progress, but slows things down substantially.""" - if test_data: n_test = len(test_data) - n = len(training_data) - for j in xrange(epochs): - random.shuffle(training_data) - mini_batches = [ - training_data[k:k+mini_batch_size] - for k in xrange(0, n, mini_batch_size)] + temp_test_data = copy.deepcopy(test_data) + test_data_list = list(temp_test_data) + n_test= len(test_data_list) + + temp_training_data = copy.deepcopy(training_data) + training_data_list = list(temp_training_data) + n_training = len(training_data_list) + + for j in range(epochs): + random.shuffle(training_data_list) + mini_batches = [training_data_list[k:k+mini_batch_size] for k in range(0, n_training, mini_batch_size)] for mini_batch in mini_batches: - self.update_mini_batch(mini_batch, eta) - if test_data: - print "Epoch {0}: {1} / {2}".format( - j, self.evaluate(test_data), n_test) - else: - print "Epoch {0} complete".format(j) + self.update_mini_batch(mini_batch, mini_batch_size,eta) + print ("Epoch {0}: {1} / {2}".format(j, self.evaluate(test_data_list), n_test)) def update_mini_batch(self, mini_batch, eta): """Update the network's weights and biases by applying