diff --git a/.gitignore b/.gitignore index 865f99b..46bdc1b 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,7 @@ __pycache__/ *$py.class # C extensions -*.so +# *.so # Distribution / packaging .Python @@ -15,7 +15,7 @@ dist/ downloads/ eggs/ .eggs/ -lib/ +# lib/ lib64/ parts/ sdist/ diff --git a/data.py b/data.py index d8ae2ef..ac417a1 100644 --- a/data.py +++ b/data.py @@ -5,35 +5,31 @@ np.random.seed(42) -class DataLoader(): +def load_dataset(flatten=False): + (X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data() + + # normalize x + X_train = X_train / 255. + X_test = X_test / 255. - @staticmethod - def load_dataset(flatten=False): - (X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data() - - # normalize x - X_train = X_train.astype(float) / 255. - X_test = X_test.astype(float) / 255. + # we reserve the last 10000 training examples for validation + X_train, X_val = X_train[:-10000], X_train[-10000:] + y_train, y_val = y_train[:-10000], y_train[-10000:] - # we reserve the last 10000 training examples for validation - X_train, X_val = X_train[:-10000], X_train[-10000:] - y_train, y_val = y_train[:-10000], y_train[-10000:] + if flatten: + X_train = X_train.reshape([X_train.shape[0], -1]) + X_val = X_val.reshape([X_val.shape[0], -1]) + X_test = X_test.reshape([X_test.shape[0], -1]) - if flatten: - X_train = X_train.reshape([X_train.shape[0], -1]) - X_val = X_val.reshape([X_val.shape[0], -1]) - X_test = X_test.reshape([X_test.shape[0], -1]) + return X_train, y_train, X_val, y_val, X_test, y_test - return X_train, y_train, X_val, y_val, X_test, y_test - - @staticmethod - def iterate_minibatches(inputs, targets, batchsize, shuffle=False): - assert len(inputs) == len(targets) +def iterate_minibatches(inputs, targets, batchsize, shuffle=False): + assert len(inputs) == len(targets) + if shuffle: + indices = np.random.permutation(len(inputs)) + for start_idx in trange(0, len(inputs) - batchsize + 1, batchsize): if shuffle: - indices = np.random.permutation(len(inputs)) - for start_idx in trange(0, len(inputs) - batchsize + 1, batchsize): - if shuffle: - excerpt = indices[start_idx:start_idx + batchsize] - else: - excerpt = slice(start_idx, start_idx + batchsize) - yield inputs[excerpt], targets[excerpt] \ No newline at end of file + excerpt = indices[start_idx:start_idx + batchsize] + else: + excerpt = slice(start_idx, start_idx + batchsize) + yield inputs[excerpt], targets[excerpt] \ No newline at end of file diff --git a/functions.py b/functions.py new file mode 100644 index 0000000..5aee5cd --- /dev/null +++ b/functions.py @@ -0,0 +1,91 @@ +""" FUNCTIONAL API: functions.py +* Purpose: Core functional API for performing computation on CPU/GPU. 
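For reference, a minimal usage sketch of the functional API that functions.py (added in full below) exposes. This is hypothetical caller code, not part of the diff: the array shapes and variable names are illustrative, and the `method='gpu'` path assumes a CUDA-capable device with the numba kernels from ops/numba_ops.py importable.

```python
# Hypothetical caller code -- not part of the diff itself.
import numpy as np
import functions as F

a = np.random.randn(256, 512)
b = np.random.randn(512, 128)

c_cpu = F.matmul(a, b, method='cpu')   # dispatches to np.dot via ops/cpu_ops.py
c_gpu = F.matmul(a, b, method='gpu')   # dispatches to the numba CUDA kernel
assert np.allclose(c_cpu, c_gpu)

# Mixed 2-D/1-D operands (e.g. adding a bias vector) fall back to the CPU
# path even when method='gpu', because the GPU kernels expect 2-D arrays.
z = F.matsum(c_cpu, np.zeros(128), method='gpu')
```

As a worked example of the launch configuration used below: with NUM_THREADS = 32, get_cuda_execution_config(1000, 256) yields 32x32 thread blocks on a (9, 32) grid, i.e. one extra block per dimension so that matrices whose sizes are not multiples of 32 are still fully covered.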
+* @author Prabhsimran Singh +* @version 2.0 17/10/18 +""" +import numpy as np + +from ops.cpu_ops import ( + cpu_matmul, + cpu_matsum, + cpu_matprod, + cpu_sum, + cpu_prod, + cpu_maximum +) +from ops.numba_ops import ( + numba_matmul, + numba_matsum, + numba_matprod, + numba_sum, + numba_prod, + numba_maximum +) + +NUM_THREADS = 32 + +def get_cuda_execution_config(m, n): + gridBlock = (NUM_THREADS, NUM_THREADS) + gridDim = ((n // gridBlock[0]) + 1, (m // gridBlock[1]) + 1) + return gridDim, gridBlock + +def matmul(a, b, method='cpu'): + # fall back to cpu if dim inconsistency (numpy handle) + if method == 'cpu' or len(a.shape) != len(b.shape) or len(a.shape) == 1 or len(b.shape) == 1: + return cpu_matmul(a, b) + elif method == 'gpu': + m, n, k = a.shape[0], a.shape[1], b.shape[1] + c = np.zeros(shape=(m, k)) + gridDim, gridBlock = get_cuda_execution_config(m, k) + numba_matmul[gridDim, gridBlock](a, b, c, m, n, k) + return c + +def matsum(a, b, method='cpu'): + if method == 'cpu' or len(a.shape) != len(b.shape) or len(a.shape) == 1 or len(b.shape) == 1: + return cpu_matsum(a, b) + if method == 'gpu': + m, n = a.shape[0], a.shape[1] + c = np.zeros(shape=(m, n)) + gridDim, gridBlock = get_cuda_execution_config(m, n) + numba_matsum[gridDim, gridBlock](a, b, c, m, n) + return c.reshape((m, n)) + +def matprod(a, b, method='cpu'): + if method == 'cpu' or len(a.shape) != len(b.shape) or len(a.shape) == 1 or len(b.shape) == 1: + return cpu_matprod(a, b) + if method == 'gpu': + m, n = a.shape[0], a.shape[1] + c = np.zeros(shape=(m, n)) + gridDim, gridBlock = get_cuda_execution_config(m, n) + numba_matprod[gridDim, gridBlock](a, b, c, m, n) + return c.reshape((m, n)) + +def add(a, value, method='cpu'): + if method == 'cpu' or len(a.shape) == 1: + return cpu_sum(a, value) + if method == 'gpu': + m, n = a.shape[0], a.shape[1] + c = np.zeros(shape=(m, n)) + gridDim, gridBlock = get_cuda_execution_config(m, n) + numba_sum[gridDim, gridBlock](a, value, c, m, n) + return c.reshape((m, n)) + +def prod(a, value, method='cpu'): + if method == 'cpu' or len(a.shape) == 1: + return cpu_prod(a, value) + if method == 'gpu': + m, n = a.shape[0], a.shape[1] + c = np.zeros(shape=(m, n)) + gridDim, gridBlock = get_cuda_execution_config(m, n) + numba_prod[gridDim, gridBlock](a, value, c, m, n) + return c.reshape((m, n)) + +def maximum(a, value, method='cpu'): + if method == 'cpu' or len(a.shape) == 1: + return cpu_maximum(a, value) + if method == 'gpu': + m, n = a.shape[0], a.shape[1] + c = np.zeros(shape=(m, n)) + gridDim, gridBlock = get_cuda_execution_config(m, n) + numba_maximum[gridDim, gridBlock](a, value, c, m, n) + return c \ No newline at end of file diff --git a/layers.py b/layers.py index fe0b84c..675e6c9 100644 --- a/layers.py +++ b/layers.py @@ -1,6 +1,8 @@ from __future__ import print_function import numpy as np +import functions as F + np.random.seed(42) class Layer: @@ -34,23 +36,26 @@ def backward(self, inputs, gradients, **kwargs): """ pass -class Dense(Layer): +class Dense(): """ Dense layer. A dense layer is a layer which performs a learned affine transformation: f(x) = + b + input shape: [batch, input_units] + output shape: [batch, output units] """ - - def __init__(self, input_units, output_units): + + def __init__(self, input_units, output_units, method='cpu'): self.type = 'dense' + self.method = method - # initialize weights with glorot/xavier uniform initialization - self.weights = np.random.randn(input_units, output_units) * np.sqrt(6. 
/ (input_units + output_units)) + # initialize weights with small random numbers. We use xavier initialization + self.weights = F.prod(np.random.randn(input_units, output_units), np.sqrt(2. / (input_units + output_units)), method=self.method) self.biases = np.zeros(output_units) def _init_g2(self): self.g2_weights = np.zeros_like(self.weights) self.g2_biases = np.zeros_like(self.biases) - + def forward(self, inputs): """ Forward pass of the dense layer. Perform an affine transformation: @@ -59,59 +64,68 @@ def forward(self, inputs): input shape: [batch, input_units] output shape: [batch, output units] """ - return np.dot(inputs, self.weights) + self.biases - + Wx = F.matmul(inputs, self.weights, method=self.method) + Z = F.matsum(Wx, self.biases, method=self.method) + return Z + def backward(self, inputs, gradients, **kwargs): + """ Backward pass of the layer. + Performs a backpropagation step through the layer, with respect to the given input. + To compute loss gradients w.r.t input, you need to apply chain rule (backprop): + dL / dx = (dL / dZ) * (dZ / dx) + """ lr = kwargs.get('lr', 0.001) gamma = kwargs.get('gamma', 0.9) epsilon = kwargs.get('epsilon', 1e-7) optim = kwargs.get('optim', 'rmsprop') # dL / dx = dL / dZ * dZ / dx = gradients * W - grad_input = np.dot(gradients, self.weights.T) + grad_input = F.matmul(gradients, self.weights.T, method=self.method) # m -> batch size m = inputs.shape[0] # compute gradient w.r.t. weights and biases # dL / dW = dL / dZ * dZ / dW = gradients * inputs - grad_weights = np.dot(inputs.T, gradients) / m + grad_weights = F.prod(F.matmul(inputs.T, gradients, method=self.method), 1. / m, method=self.method) # dL / db = dL / dZ * dZ / db = gradients * 1 - grad_biases = gradients.sum(axis=0) / m - + grad_biases = F.prod(gradients.sum(axis=0), 1. / m, method=self.method) assert grad_weights.shape == self.weights.shape and grad_biases.shape == self.biases.shape - update_weights = lr * grad_weights - update_biases = lr * grad_biases + update_weights = F.prod(grad_weights, lr, method=self.method) + update_biases = F.prod(grad_biases, lr, method=self.method) if optim == 'rmsprop': if not hasattr(self, 'g2_weights'): self._init_g2() - self.g2_weights = (self.g2_weights * gamma) + np.square(grad_weights) * (1 - gamma) - self.g2_biases = (self.g2_biases * gamma) + np.square(grad_biases) * (1 - gamma) - - self.weights -= update_weights / (np.sqrt(self.g2_weights) + epsilon) - self.biases -= update_biases / (np.sqrt(self.g2_biases) + epsilon) + self.g2_weights = F.matsum(F.prod(self.g2_weights, gamma, method=self.method), F.prod(np.square(grad_weights), (1 - gamma), method=self.method), method=self.method) + self.g2_biases = F.matsum(F.prod(self.g2_biases, gamma, method=self.method), F.prod(np.square(grad_biases), (1 - gamma), method=self.method), method=self.method) + self.weights = F.matsum(self.weights, -F.matprod(update_weights, 1. / np.sqrt(F.add(self.g2_weights, epsilon, method=self.method)), method=self.method), method=self.method) + self.biases = F.matsum(self.biases, -F.matprod(update_biases, 1. / np.sqrt(F.add(self.g2_biases, epsilon, method=self.method)), method=self.method), method=self.method) elif optim == 'gd': - self.weights -= update_weights - self.biases -= update_biases + self.weights = F.matsum(self.weights, -update_weights, method=self.method) + self.biases = F.matsum(self.biases, -update_biases, method=self.method) # propagate back the gradients of Loss wrt to layer inputs # dL / dx return grad_input - -class ReLU(Layer): - """ReLU layer. 
- Simply applies elementwise rectified linear unit to all inputs. + +class ReLU(): + """ ReLU layer. + + Applies elementwise rectified linear unit to all inputs: + f(x) = max(0, x) + + input shape: [batch, input_units] + output shape: [batch, input_units] """ - def __init__(self): + def __init__(self, method='cpu'): self.type = 'relu' + self.method = method def forward(self, inputs): - """Apply elementwise ReLU to [batch, input_units] matrix""" - return np.maximum(0, inputs) + return F.maximum(inputs, 0., method=self.method) def backward(self, inputs, gradients, **kwargs): - """Compute gradient of loss w.r.t. ReLU input""" - grad_relu = inputs > 0 - return gradients * grad_relu \ No newline at end of file + grad_relu = inputs > 0. + return F.matprod(gradients, grad_relu, method=self.method) diff --git a/lib/cuda_c.cu b/lib/cuda_c.cu new file mode 100644 index 0000000..22fe5f4 --- /dev/null +++ b/lib/cuda_c.cu @@ -0,0 +1,338 @@ +/** + * CUDA PARALLEL PROGRAMMING: cuda_c.cu + * Purpose: Matrix Operations using CUDA C/C++ + * @author Prabhsimran Singh + * @version 2.4 17/10/18 + * + * Build using: nvcc -Xcompiler -fPIC -shared -o lib/cuda_c.so lib/cuda_c.cu --gpu-architecture=compute_61 --gpu-code=sm_61,compute_61 + */ +#include +#include "utils/devices.cu" +#include "utils/utils.cpp" + +#define NUM_THREADS 32 + +/** +* Computes dot-product of two matrices (using parallel threads on CUDA capable device) +* +* @param a the double pointer to first input array +* @param b the double pointer to second input array +* @param c the double pointer to output array +* @param m the no. rows in a(m x n) and c(m x k) +* @param n the no. cols in a(m x n) and rows in b(n x k) +* @param k the no. cols in b(n x k) and c(m x k) +* @return void +*/ +__global__ void matmul(double *a, double *b, double *c, int m, int n, int k) { + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + // strides are unnecessary if you're not using variable sized blocks + // however I'm going to leave it here since the focus of this is not + // optimality or performance but readability and overall coherence. + int stride_row = gridDim.y * blockDim.y; + int stride_col = gridDim.x * blockDim.x; + + for (; row < m && col < k; row += stride_row, col += stride_col) { + double sum = 0; + #pragma unroll // unrolls the for loop (optimization - cuts on exec time) + for (int i = 0; i < n; i++) { + sum += a[row * n + i] * b[i * k + col]; + } + c[row * k + col] = sum; + } +} + +/** +* Calculates element-wise sum of two matrices (using parallel threads on CUDA capable device) +* +* @param a the double pointer to first input array +* @param b the double pointer to second input array +* @param c the double pointer to output array +* @param m the no. of rows in the arrays +* @param n the no. 
of cols in the arrays +* @return void +*/ +__global__ void matsum(double *a, double *b, double *c, int m, int n) { + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + int stride_row = gridDim.y * blockDim.y; + int stride_col = gridDim.x * blockDim.x; + + for (; row < m && col < n; row += stride_row, col += stride_col) { + c[row * n + col] = a[row * n + col] + b[row * n + col]; + } +} + +/** +* Calculates element-wise product of two matrices (using parallel threads on CUDA capable device) +* +* @param a the double pointer to first input array +* @param b the double pointer to second input array +* @param c the double pointer to output array +* @param m the no. of rows in the arrays +* @param n the no. of cols in the arrays +* @return void +*/ +__global__ void matprod(double *a, double *b, double *c, int m, int n) { + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + int stride_row = gridDim.y * blockDim.y; + int stride_col = gridDim.x * blockDim.x; + + for (; row < m && col < n; row += stride_row, col += stride_col) { + c[row * n + col] = a[row * n + col] * b[row * n + col]; + } +} + +/** +* Adds a value element-wise to the matrix (using parallel threads on CUDA capable device) +* +* @param a the double pointer to first input array +* @param b the double value to add to the array +* @param c the double pointer to output array +* @param m the no. of rows in the arrays +* @param n the no. of cols in the arrays +* @return void +*/ +__global__ void sum(double *a, double b, double *c, int m, int n) { + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + int stride_row = gridDim.y * blockDim.y; + int stride_col = gridDim.x * blockDim.x; + + for (; row < m && col < n; row += stride_row, col += stride_col) { + c[row * n + col] = a[row * n + col] + b; + } +} + +/** +* Multiplies a value element-wise to matrix (using parallel threads on CUDA capable device) +* +* @param a the double pointer to first input array +* @param b the double value to multiply the array with +* @param c the double pointer to output array +* @param m the no. of rows in the arrays +* @param n the no. of cols in the arrays +* @return void +*/ +__global__ void prod(double *a, double b, double *c, int m, int n) { + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + int stride_row = gridDim.y * blockDim.y; + int stride_col = gridDim.x * blockDim.x; + + for (; row < m && col < n; row += stride_row, col += stride_col) { + c[row * n + col] = a[row * n + col] * b; + } +} + +/** +* Computes the element-wise maximum ofs a matrix and a value (using parallel threads on CUDA capable device) +* +* @param a the double pointer to first input array +* @param b the double value to check maximum against +* @param c the double pointer to output array +* @param m the no. of rows in the arrays +* @param n the no. of cols in the arrays +* @return void +*/ +__global__ void maximum(double *a, double b, double *c, int m, int n) { + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + int stride_row = gridDim.y * blockDim.y; + int stride_col = gridDim.x * blockDim.x; + + for (; row < m && col < n; row += stride_row, col += stride_col) { + c[row * n + col] = (a[row * n + col] > b) ? 
a[row * n + col] : b; + } +} + +extern "C" { + + void cuda_device_info() { + getCudaDeviceInfo(); + } + + void cuda_matmul(double *a, double *b, double *c, int m, int n, int k) { + double *d_a, *d_b, *d_c; + + cudaMallocManaged(&d_a, (m * n) * sizeof(double)); + cudaMallocManaged(&d_b, (n * k) * sizeof(double)); + cudaMallocManaged(&d_c, (m * k) * sizeof(double)); + + cudaMemcpy(d_a, a, (m * n) * sizeof(double), cudaMemcpyHostToDevice); + cudaMemcpy(d_b, b, (n * k) * sizeof(double), cudaMemcpyHostToDevice); + + dim3 dimBlock(NUM_THREADS, NUM_THREADS, 1); + dim3 dimGrid((k / dimBlock.x) + 1, (m / dimBlock.y) + 1, 1); + + cudaError_t syncErr, asyncErr; + matmul<<>>(d_a, d_b, d_c, m, n, k); + + syncErr = cudaGetLastError(); + asyncErr = cudaDeviceSynchronize(); + + if (syncErr != cudaSuccess) + cout << "CUDA Error: " << cudaGetErrorString(syncErr) << endl; + if (asyncErr != cudaSuccess) + cout << "CUDA Error: " << cudaGetErrorString(asyncErr) << endl; + + cudaMemcpy(c, d_c, (m * k) * sizeof(double), cudaMemcpyDeviceToHost); + + cudaFree(d_a); + cudaFree(d_b); + cudaFree(d_c); + } + + void cuda_matsum(double *a, double *b, double *c, int m, int n) { + double *d_a, *d_b, *d_c; + + cudaMallocManaged(&d_a, (m * n) * sizeof(double)); + cudaMallocManaged(&d_b, (m * n) * sizeof(double)); + cudaMallocManaged(&d_c, (m * n) * sizeof(double)); + + cudaMemcpy(d_a, a, (m * n) * sizeof(double), cudaMemcpyHostToDevice); + cudaMemcpy(d_b, b, (m * n) * sizeof(double), cudaMemcpyHostToDevice); + + dim3 dimBlock(NUM_THREADS, NUM_THREADS, 1); + dim3 dimGrid((n / dimBlock.x) + 1, (m / dimBlock.y) + 1, 1); + + cudaError_t syncErr, asyncErr; + matsum<<>>(d_a, d_b, d_c, m, n); + + syncErr = cudaGetLastError(); + asyncErr = cudaDeviceSynchronize(); + + if (syncErr != cudaSuccess) + cout << "CUDA Error: " << cudaGetErrorString(syncErr) << endl; + if (asyncErr != cudaSuccess) + cout << "CUDA Error: " << cudaGetErrorString(asyncErr) << endl; + + cudaMemcpy(c, d_c, (m * n) * sizeof(double), cudaMemcpyDeviceToHost); + + cudaFree(d_a); + cudaFree(d_b); + cudaFree(d_c); + } + + void cuda_matprod(double *a, double *b, double *c, int m, int n) { + double *d_a, *d_b, *d_c; + + cudaMallocManaged(&d_a, (m * n) * sizeof(double)); + cudaMallocManaged(&d_b, (m * n) * sizeof(double)); + cudaMallocManaged(&d_c, (m * n) * sizeof(double)); + + cudaMemcpy(d_a, a, (m * n) * sizeof(double), cudaMemcpyHostToDevice); + cudaMemcpy(d_b, b, (m * n) * sizeof(double), cudaMemcpyHostToDevice); + + dim3 dimBlock(NUM_THREADS, NUM_THREADS, 1); + dim3 dimGrid((n / dimBlock.x) + 1, (m / dimBlock.y) + 1, 1); + + cudaError_t syncErr, asyncErr; + matprod<<>>(d_a, d_b, d_c, m, n); + + syncErr = cudaGetLastError(); + asyncErr = cudaDeviceSynchronize(); + + if (syncErr != cudaSuccess) + cout << "CUDA Error: " << cudaGetErrorString(syncErr) << endl; + if (asyncErr != cudaSuccess) + cout << "CUDA Error: " << cudaGetErrorString(asyncErr) << endl; + + cudaMemcpy(c, d_c, (m * n) * sizeof(double), cudaMemcpyDeviceToHost); + + cudaFree(d_a); + cudaFree(d_b); + cudaFree(d_c); + } + + void cuda_sum(double *a, double b, double *c, int m, int n) { + double *d_a, *d_c; + + cudaMallocManaged(&d_a, (m * n) * sizeof(double)); + cudaMallocManaged(&d_c, (m * n) * sizeof(double)); + + cudaMemcpy(d_a, a, (m * n) * sizeof(double), cudaMemcpyHostToDevice); + + dim3 dimBlock(NUM_THREADS, NUM_THREADS, 1); + dim3 dimGrid((n / dimBlock.x) + 1, (m / dimBlock.y) + 1, 1); + + cudaError_t syncErr, asyncErr; + sum<<>>(d_a, b, d_c, m, n); + + syncErr = cudaGetLastError(); + asyncErr = 
cudaDeviceSynchronize(); + + if (syncErr != cudaSuccess) + cout << "CUDA Error: " << cudaGetErrorString(syncErr) << endl; + if (asyncErr != cudaSuccess) + cout << "CUDA Error: " << cudaGetErrorString(asyncErr) << endl; + + cudaMemcpy(c, d_c, (m * n) * sizeof(double), cudaMemcpyDeviceToHost); + + cudaFree(d_a); + cudaFree(d_c); + } + + void cuda_prod(double *a, double b, double *c, int m, int n) { + double *d_a, *d_c; + + cudaMallocManaged(&d_a, (m * n) * sizeof(double)); + cudaMallocManaged(&d_c, (m * n) * sizeof(double)); + + cudaMemcpy(d_a, a, (m * n) * sizeof(double), cudaMemcpyHostToDevice); + + dim3 dimBlock(NUM_THREADS, NUM_THREADS, 1); + dim3 dimGrid((n / dimBlock.x) + 1, (m / dimBlock.y) + 1, 1); + + cudaError_t syncErr, asyncErr; + prod<<>>(d_a, b, d_c, m, n); + + syncErr = cudaGetLastError(); + asyncErr = cudaDeviceSynchronize(); + + if (syncErr != cudaSuccess) + cout << "CUDA Error: " << cudaGetErrorString(syncErr) << endl; + if (asyncErr != cudaSuccess) + cout << "CUDA Error: " << cudaGetErrorString(asyncErr) << endl; + + cudaMemcpy(c, d_c, (m * n) * sizeof(double), cudaMemcpyDeviceToHost); + + cudaFree(d_a); + cudaFree(d_c); + } + + void cuda_maximum(double *a, double b, double *c, int m, int n) { + double *d_a, *d_c; + + cudaMallocManaged(&d_a, (m * n) * sizeof(double)); + cudaMallocManaged(&d_c, (m * n) * sizeof(double)); + + cudaMemcpy(d_a, a, (m * n) * sizeof(double), cudaMemcpyHostToDevice); + + dim3 dimBlock(NUM_THREADS, NUM_THREADS, 1); + dim3 dimGrid((n / dimBlock.x) + 1, (m / dimBlock.y) + 1, 1); + + cudaError_t syncErr, asyncErr; + maximum<<>>(d_a, b, d_c, m, n); + + syncErr = cudaGetLastError(); + asyncErr = cudaDeviceSynchronize(); + + if (syncErr != cudaSuccess) + cout << "CUDA Error: " << cudaGetErrorString(syncErr) << endl; + if (asyncErr != cudaSuccess) + cout << "CUDA Error: " << cudaGetErrorString(asyncErr) << endl; + + cudaMemcpy(c, d_c, (m * n) * sizeof(double), cudaMemcpyDeviceToHost); + + cudaFree(d_a); + cudaFree(d_c); + } +} \ No newline at end of file diff --git a/lib/cuda_c.so b/lib/cuda_c.so new file mode 100755 index 0000000..51d3ef2 Binary files /dev/null and b/lib/cuda_c.so differ diff --git a/lib/utils/devices.cu b/lib/utils/devices.cu new file mode 100644 index 0000000..29e7261 --- /dev/null +++ b/lib/utils/devices.cu @@ -0,0 +1,20 @@ +#include +using namespace std; + +void getCudaDeviceInfo() { + int nDevices; + cudaGetDeviceCount(&nDevices); + for (int i = 0; i < nDevices; i++) { + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, i); + cout << "GPU Device Id: " << i << endl; + cout << "Device name: " << prop.name << endl; + cout << "Memory Clock Rate (KHz): " << + prop.memoryClockRate << endl; + cout << "Memory Bus Width (bits): " << + prop.memoryBusWidth << endl; + cout << "Peak Memory Bandwidth (GB/s): " << + 2.0 * prop.memoryClockRate * (prop.memoryBusWidth / 8) / 1.0e6 << endl; + cout << endl; + } +} \ No newline at end of file diff --git a/lib/utils/utils.cpp b/lib/utils/utils.cpp new file mode 100644 index 0000000..0ecbafe --- /dev/null +++ b/lib/utils/utils.cpp @@ -0,0 +1,9 @@ +#include + +void printMatrix(double *matrix, int m, int n) { + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) + std::cout << matrix[i * m + j] << " "; + std::cout << std::endl; + } +} \ No newline at end of file diff --git a/main.py b/main.py index 781bb3b..d2b96d1 100644 --- a/main.py +++ b/main.py @@ -1,56 +1,53 @@ from __future__ import print_function -import matplotlib.pyplot as plt -import numpy as np import argparse -from data import 
DataLoader +import numpy as np + from train import Trainer +from data import iterate_minibatches, load_dataset +from ops.cuda_c_ops import cuda_device_info np.random.seed(42) def parse_args(): parser = argparse.ArgumentParser(description='Training Configuration') - parser.add_argument('--epochs', type=int, default=10, dest='epochs', + parser.add_argument('--epochs', type=int, default=20, dest='epochs', help='Number of iterations for training') - parser.add_argument('--batch-size', type=int, default=64, dest='batch_size', + parser.add_argument('--batch-size', type=int, default=128, dest='batch_size', help='Batch size for one epoch in training') - parser.add_argument('--lr', type=float, default=0.001, dest='lr', + parser.add_argument('--lr', type=float, default=0.005, dest='lr', help='Initial learning rate') - parser.add_argument('--plot', type=bool, default=False, dest='plot', - help='Flag that indicates whether plot the accuracy during training') + parser.add_argument('--backend', type=str, default='cpu', dest='backend', + help='Type of computation backend to use [CPU/GPU]') return parser.parse_args() def main(): args = parse_args() + if args.backend.lower() == 'gpu': + cuda_device_info() - X_train, y_train, X_val, y_val, X_test, y_test = DataLoader.load_dataset(flatten=True) + X_train, y_train, X_val, y_val, X_test, y_test = load_dataset(flatten=True) input_dim = X_train.shape[1] num_classes = 10 - dims = [input_dim, 100, 200, 200, num_classes] + dims = [input_dim, 1024, 1024, 256, num_classes] - trainer = Trainer(dims=dims) + trainer = Trainer(dims=dims, backend=args.backend.lower()) train_log = [] val_log = [] for epoch in range(1, args.epochs + 1): - for x_batch, y_batch in DataLoader.iterate_minibatches(X_train, y_train, batchsize=args.batch_size, shuffle=True): + for x_batch, y_batch in iterate_minibatches(X_train, y_train, batchsize=args.batch_size, shuffle=True): trainer.fit(x_batch, y_batch, lr=args.lr) train_log.append(np.mean(trainer.predict(X_train) == y_train)) val_log.append(np.mean(trainer.predict(X_val) == y_val)) print("Epoch[{}/{}] train acc: {:.4f} - val acc: {:.4f}".format(epoch, args.epochs, train_log[-1], val_log[-1])) - if args.plot: - plt.plot(train_log,label='train accuracy') - plt.plot(val_log,label='val accuracy') - plt.legend(loc='best') - plt.grid() - plt.show() print('\nTesting on {} samples'.format(len(X_test))) - accuracy = np.mean(trainer.predict(X_test) == y_test) * 100 + accuracy = np.mean(trainer.predict(X_test) == y_test) print('test acc: {:.4f}'.format(accuracy)) if __name__ == '__main__': diff --git a/ops/cpu_ops.py b/ops/cpu_ops.py new file mode 100644 index 0000000..95fbec5 --- /dev/null +++ b/ops/cpu_ops.py @@ -0,0 +1,24 @@ +""" NUMPY CPU API: cpu_ops.py +* Purpose: NumPy API for performing computation on the CPU. +* @author Prabhsimran Singh +* @version 1.0 17/10/18 +""" +import numpy as np + +def cpu_matmul(a, b): + return np.dot(a, b) + +def cpu_matsum(a, b): + return np.add(a, b) + +def cpu_matprod(a, b): + return np.multiply(a, b) + +def cpu_sum(a, value): + return cpu_matsum(a, value) + +def cpu_prod(a, value): + return cpu_matprod(a, value) + +def cpu_maximum(a, value): + return np.maximum(a, value) \ No newline at end of file diff --git a/ops/cuda_c_ops.py b/ops/cuda_c_ops.py new file mode 100644 index 0000000..636bce8 --- /dev/null +++ b/ops/cuda_c_ops.py @@ -0,0 +1,148 @@ +""" CUDA GPU API: cuda_c_ops.py +* Purpose: Python interface exposing CUDA C/C++ API for performing computation on the GPU . 
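As a condensed sketch of the calling convention the ctypes wrappers below implement (hypothetical caller code with small illustrative shapes, assuming lib/cuda_c.so has been built and a CUDA-capable device is present): the shared library expects flat, C-contiguous float64 buffers plus explicit dimensions, and writes the result into the output buffer in place.

```python
import ctypes
from ctypes import POINTER, c_double, c_int

import numpy as np

# Inputs must be C-contiguous float64; the output array is filled in place.
a = np.ascontiguousarray(np.random.randn(4, 3), dtype=np.float64)
b = np.ascontiguousarray(np.random.randn(3, 2), dtype=np.float64)
c = np.zeros((4, 2), dtype=np.float64)

lib = ctypes.CDLL('./lib/cuda_c.so', mode=ctypes.RTLD_GLOBAL)
lib.cuda_matmul.argtypes = [POINTER(c_double)] * 3 + [c_int] * 3
lib.cuda_matmul(a.ctypes.data_as(POINTER(c_double)),
                b.ctypes.data_as(POINTER(c_double)),
                c.ctypes.data_as(POINTER(c_double)),
                4, 3, 2)  # m, n, k
assert np.allclose(c, a @ b)
```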
+* @author Prabhsimran Singh +* @version 2.2 17/10/18 +* Build shared object library using: + nvcc -Xcompiler -fPIC -shared -o lib/cuda_c.so lib/cuda_c.cu +""" +import ctypes +import numpy as np + +from ctypes import POINTER, c_double, c_int + +# extract cuda function pointers in the shared object cuda_c.so +dll = ctypes.CDLL('./lib/cuda_c.so', mode=ctypes.RTLD_GLOBAL) + +# get the required functions exposed by CUDA C/C++ API +def get_cuda_device_info(dll): + func = dll.cuda_device_info + return func + +def get_cuda_matmul(dll): + func = dll.cuda_matmul + func.argtypes = [POINTER(c_double), POINTER(c_double), POINTER(c_double), c_int, c_int, c_int] + return func + +def get_cuda_matsum(dll): + func = dll.cuda_matsum + func.argtypes = [POINTER(c_double), POINTER(c_double), POINTER(c_double), c_int, c_int] + return func + +def get_cuda_matprod(dll): + func = dll.cuda_matprod + func.argtypes = [POINTER(c_double), POINTER(c_double), POINTER(c_double), c_int, c_int] + return func + +def get_cuda_sum(dll): + func = dll.cuda_sum + func.argtypes = [POINTER(c_double), c_double, POINTER(c_double), c_int, c_int] + return func + +def get_cuda_prod(dll): + func = dll.cuda_prod + func.argtypes = [POINTER(c_double), c_double, POINTER(c_double), c_int, c_int] + return func + +def get_cuda_maximum(dll): + func = dll.cuda_maximum + func.argtypes = [POINTER(c_double), c_double, POINTER(c_double), c_int, c_int] + return func + +__cuda_device_info = get_cuda_device_info(dll) +__cuda_matmul = get_cuda_matmul(dll) +__cuda_matsum = get_cuda_matsum(dll) +__cuda_matprod = get_cuda_matprod(dll) +__cuda_sum = get_cuda_sum(dll) +__cuda_prod = get_cuda_prod(dll) +__cuda_maximum = get_cuda_maximum(dll) + +# convenient python wrappers for cuda functions +def cuda_device_info(): + __cuda_device_info() + +def cuda_matmul(a, b, c, m, n, k): + a_p = a.ctypes.data_as(POINTER(c_double)) + b_p = b.ctypes.data_as(POINTER(c_double)) + c_p = c.ctypes.data_as(POINTER(c_double)) + __cuda_matmul(a_p, b_p, c_p, m, n, k) + +def cuda_matsum(a, b, c, m, n): + a_p = a.ctypes.data_as(POINTER(c_double)) + b_p = b.ctypes.data_as(POINTER(c_double)) + c_p = c.ctypes.data_as(POINTER(c_double)) + __cuda_matsum(a_p, b_p, c_p, m, n) + +def cuda_matprod(a, b, c, m, n): + a_p = a.ctypes.data_as(POINTER(c_double)) + b_p = b.ctypes.data_as(POINTER(c_double)) + c_p = c.ctypes.data_as(POINTER(c_double)) + __cuda_matprod(a_p, b_p, c_p, m, n) + +def cuda_sum(a, b, c, m, n): + a_p = a.ctypes.data_as(POINTER(c_double)) + b_f = ctypes.c_double(b) + c_p = c.ctypes.data_as(POINTER(c_double)) + __cuda_sum(a_p, b_f, c_p, m, n) + +def cuda_prod(a, b, c, m, n): + a_p = a.ctypes.data_as(POINTER(c_double)) + b_f = ctypes.c_double(b) + c_p = c.ctypes.data_as(POINTER(c_double)) + __cuda_prod(a_p, b_f, c_p, m, n) + +def cuda_maximum(a, b, c, m, n): + a_p = a.ctypes.data_as(POINTER(c_double)) + b_f = ctypes.c_double(b) + c_p = c.ctypes.data_as(POINTER(c_double)) + __cuda_maximum(a_p, b_f, c_p, m, n) + +def get_test_params(): + size = int(16) + a = np.array([3.0] * (size * size)) + b = np.array([3.0] * (size * size)) + c = np.zeros(shape=(size * size)) + return a, b, c, size + +if __name__ == '__main__': + cuda_device_info() + + a, b, c, size = get_test_params() + # basic checks for all ops + cuda_matmul(a, b, c, size, size, size) + assert np.all(c==144.0), "Matrix dot-product operation is buggy" + cuda_matsum(a, b, c, size, size) + assert np.all(c==6.0), "Matrix sum operation is buggy" + cuda_matprod(a, b, c, size, size) + assert np.all(c==9.0), "Matrix product operation 
is buggy" + cuda_sum(a, 5.0, c, size, size) + assert np.all(c==8.0), "Element-wise sum operation is buggy" + cuda_prod(a, 2.5, c, size, size) + assert np.all(c==7.5), "Element-wise product operation is buggy" + cuda_maximum(a, 4.0, c, size, size) + assert np.all(c==4.0), "Element-wise max operation is buggy" + + # robust check for matmul + a = np.random.randn(205, 510) + b = np.random.randn(510, 340) + c = np.zeros(205 * 340) + cuda_matmul(a.flatten(), b.flatten(), c, 205, 510, 340) + actual_dot = np.dot(a, b) + c = c.reshape(205, 340) + assert np.allclose(actual_dot, c), "Matrix dot-product operation is buggy" + + # robust checks for other ops + a = np.random.randn(100 * 200) + b = np.random.randn(100 * 200) + c = np.zeros_like(a) + cuda_matsum(a, b, c, 100, 200) + assert np.all(a + b == c), "Matrix sum operation is buggy" + cuda_matprod(a, b, c, 100, 200) + assert np.all(a * b == c), "Matrix product operation is buggy" + cuda_sum(a, 5.3, c, 100, 200) + assert np.all(a + 5.3 == c), "Element-wise sum operation is buggy" + cuda_prod(a, 6, c, 100, 200) + assert np.all(a * 6 == c), "Element-wise product operation is buggy" + cuda_maximum(a, 0, c, 100, 200) + assert np.all(np.maximum(0, a) == c), "Element-wise max operation is buggy" + + print('Passed all tests!') \ No newline at end of file diff --git a/ops/numba_ops.py b/ops/numba_ops.py new file mode 100644 index 0000000..2583281 --- /dev/null +++ b/ops/numba_ops.py @@ -0,0 +1,52 @@ +""" NUMBA GPU OPS: numba_ops.py +* Purpose: Numba API for performing computation on the GPU. +* @author Prabhsimran Singh +* @version 1.0 17/10/18 +""" +import numpy as np +from numba import cuda + +@cuda.jit +def numba_matmul(a, b, c, m, n, k): + row = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.y + col = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x + if row < m and col < k: + summ = 0 + for i in range(n): + summ += a[row, i] * b[i, col] + c[row, col] = summ + +@cuda.jit +def numba_matsum(a, b, c, m, n): + row = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.y + col = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x + if row < m and col < n: + c[row, col] = a[row, col] + b[row, col] + +@cuda.jit +def numba_matprod(a, b, c, m, n): + row = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.y + col = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x + if row < m and col < n: + c[row, col] = a[row, col] * b[row, col] + +@cuda.jit +def numba_sum(a, value, c, m, n): + row = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.y + col = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x + if row < m and col < n: + c[row, col] = a[row, col] + value + +@cuda.jit +def numba_prod(a, value, c, m, n): + row = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.y + col = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x + if row < m and col < n: + c[row, col] = a[row, col] * value + +@cuda.jit +def numba_maximum(a, value, c, m, n): + row = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.y + col = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x + if row < m and col < n: + c[row, col] = a[row, col] if a[row, col] > value else value \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index cddc231..3e4deae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,8 @@ -autograd==1.2 numpy==1.15.2 keras==2.2.2 matplotlib==2.2.3 tqdm==4.25.0 dill==0.2.8.2 -nose==1.3.7 \ No newline at end of file +nose==1.3.7 +numba==0.40.1 +autograd==1.2 \ No newline at end of file diff --git a/train.py b/train.py index 
a00d71d..88cfbb8 100644 --- a/train.py +++ b/train.py @@ -3,7 +3,6 @@ import dill import numpy as np -from autograd import elementwise_grad as grad from loss import ( softmax_crossentropy_with_logits, @@ -15,21 +14,24 @@ class Trainer(): - def __init__(self, dims=None): + def __init__(self, dims=None, backend='cpu'): if dims is None: - raise UserWarning('Model dims should not be none') - self._create(dims) + raise UserWarning('Model dims should not be none.') + if backend.lower() in ['cpu', 'gpu']: + self._create(dims, backend.lower()) + else: + raise UserWarning('Unknown computation backend. Should be one of [CPU, GPU]') - def _create(self, dims): + def _create(self, dims, backend): model = [] input_shape = dims[0] num_classes = dims[-1] - model.append(Dense(input_shape, dims[1])) - model.append(ReLU()) + model.append(Dense(input_shape, dims[1], method=backend)) + model.append(ReLU(method=backend)) for i in range(2, len(dims) - 1): - model.append(Dense(dims[i - 1], dims[i])) - model.append(ReLU()) - model.append(Dense(dims[-2], num_classes)) + model.append(Dense(dims[i - 1], dims[i], method=backend)) + model.append(ReLU(method=backend)) + model.append(Dense(dims[-2], num_classes, method=backend)) self._network = model def _forward(self, X): @@ -39,11 +41,9 @@ def _forward(self, X): """ activations = [] A = X - for layer in self._network: activations.append(layer.forward(A)) A = activations[-1] - assert len(activations) == len(self._network) return activations @@ -58,7 +58,6 @@ def fit(self, X, y, **kwargs): """ Train your network on a given batch of X and y. You first need to run forward to get all layer activations. Then you can run layer.backward going from last to first layer. - After you called backward for all layers, all Dense layers have already made one gradient step. """
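To tie the refactor together, a minimal end-to-end sketch in the spirit of main.py. The layer sizes, epoch count, and hyperparameters here are illustrative; backend='gpu' additionally assumes a CUDA-capable device.

```python
import numpy as np

from data import load_dataset, iterate_minibatches
from train import Trainer

X_train, y_train, X_val, y_val, X_test, y_test = load_dataset(flatten=True)

# dims = [input_dim, hidden..., num_classes]; the backend decides whether the
# Dense/ReLU layers route their ops through NumPy ('cpu') or CUDA ('gpu').
trainer = Trainer(dims=[X_train.shape[1], 256, 10], backend='cpu')

for epoch in range(5):
    for x_batch, y_batch in iterate_minibatches(X_train, y_train, batchsize=128, shuffle=True):
        trainer.fit(x_batch, y_batch, lr=0.005)
    print('val acc:', np.mean(trainer.predict(X_val) == y_val))
```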