
Implementing a CNN in Pure Java

airoot 2024. 8. 18. 15:13

I want to implement a CNN from scratch.

Using TensorFlow in Python would make this easy, but for study purposes, and to build more advanced algorithms later on, I implement it in Java. Java does have the Deeplearning4j library, but I don't use that either; everything is written in plain code.

The implementation tackles the classic handwritten-digit recognition task.

 

Input dataset: https://yann.lecun.com/exdb/mnist/

The recognition rate comes out to around 90-92%.
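
For reference, the pipeline implemented below is: 28x28 input image -> 3x3 convolution with 8 filters and no padding -> 8x26x26 feature maps -> 2x2 max pooling -> 8x13x13 -> flatten to a 1x1352 vector -> fully connected layer -> 10 logits -> softmax. Training is plain per-sample stochastic gradient descent on the cross-entropy loss: with softmax probabilities p_i = exp(t_i) / sum_j exp(t_j) over the logits t, the loss for the correct class c is L = -log(p_c).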

 

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;

public class CNN_Final {
	public static float[][][] trainImages = null;
	public static float[][] trainLabels = null;
	public static float[][][] testImages = null;
	public static float[][] testLabels = null;

	public static float[][][] filters = null;
	public static float[][] fcWeights = null;
    public static float[][] fcBiases = null;

    /**
     * creates 3X3 convolution filters with random initial weights.
     * @param size number of 3X3 filters to be randomly initialized
     * @return a [size] X [3] X [3] 3d array with size filters
     */
    public static float[][][] init_filters(int size) {
        float[][][] result = new float[size][3][3];
        for (int k = 0; k < size; k++) {
			for (int i = 0; i < 3; i++) {
				for (int j = 0; j < 3; j++) {
					result[k][i][j] = (float) Math.random();
				}
			}
        }
        return result;
    }
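    // Note: the filters above are drawn uniformly from [0, 1). A zero-centered,
    // scaled init such as (float)((Math.random() - 0.5) / 9.0) is a common
    // alternative for small conv filters, but the uniform init still reaches
    // the reported 90-92% accuracy.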
	public static void init_SoftMax(int input, int output) {
		fcWeights = new float[input][output];
        for (int i = 0; i < input; i++) {
            for (int j = 0; j < output; j++) {
                fcWeights[i][j] = ((float) Math.random()) * (1.0f / input); //scale down
            }
        }
		fcBiases = new float[1][output];
        zerosVector(fcBiases[0]);
    }

	public static float[][][] loadImages(String filename, int batchSize, int imageSize) throws IOException {
        FileInputStream fis = new FileInputStream(filename);
        byte[] buffer = new byte[4];
        fis.read(buffer, 0, 4); // Magic number
        fis.read(buffer, 0, 4); // Number of images
        int numImages = ByteBuffer.wrap(buffer).order(ByteOrder.BIG_ENDIAN).getInt();
		System.out.println("numImages is "+numImages);
        fis.read(buffer, 0, 4); // Number of rows
        fis.read(buffer, 0, 4); // Number of columns
        float[][][] images = new float[batchSize][imageSize][imageSize];
        for (int i = 0; i < batchSize; i++) {
            for (int j = 0; j < imageSize; j++) {
                for (int k = 0; k < imageSize; k++) {
                    images[i][j][k] = (float)fis.read() / (float)255.0;
                }
            }
			if(i==0) {
				printImage(images[i],imageSize,imageSize);
			}
			if(i % 100 == 99){
				System.out.print("."+(i+1));
			} else {
				System.out.print(".");
			}
        }
        fis.close();
		System.out.println("");
		System.out.println("loadImages completed.");
        return images;
    }
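
	// MNIST IDX layout (big-endian), as read above: the image file holds a
	// 4-byte magic number, image count, row count and column count, followed by
	// one unsigned byte per pixel; the label file holds a 4-byte magic number
	// and label count, followed by one unsigned byte per label.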

    public static float[][] loadLabels(String filename, int batchSize, int numClasses) throws IOException {
        FileInputStream fis = new FileInputStream(filename);
        byte[] buffer = new byte[4];
        fis.read(buffer, 0, 4); // Magic number
        fis.read(buffer, 0, 4); // Number of labels
        float[][] labels = new float[batchSize][numClasses];
        for (int i = 0; i < batchSize; i++) {
            int label = fis.read();
            labels[i][label] = (float)1.0;
			if(i==0) {
				System.out.println("label is "+label);
			}
        }
        fis.close();
		System.out.println("");
		System.out.println("loadLabels completed.");
        return labels;
    }

	private static void printImage(float[][] image, int rows, int cols) {
		System.out.println("printImage rows is "+rows);
		System.out.println("printImage cols is "+cols);
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                float pixel = image[i][j];
                if (pixel > 0.5f) {
                    System.out.print("#"); // bright pixel -> '#'
                } else {
                    System.out.print(" "); // dark pixel -> ' '
                }
            }
            System.out.println();
        }
    }
    
    /**
     * performs both the forward and back-propagation passes of the CNN.
     * @param batchSize the number of images used for training the CNN.
     * @throws IOException if an image cannot be found.
     */
    public static void train(int batchSize) throws IOException {
        float ce_loss=0;
        int accuracy=0;
        float acc_sum=0.0f;
        float learn_rate=0.005f;

        float[][] out_l = new float[1][10];    
        for (int i = 0; i < batchSize; i++) {            
            //FORWARD PROPAGATION
            
            //convert to pixel array
            float[][] pxl = trainImages[i];
            // perform convolution 28*28 --> 8x26x26
            float[][][] out = convolution(pxl, filters);

            // perform maximum pooling  8x26x26 --> 8x13x13
			out = maxPooling(out);
            
            // flatten + fully connected + softmax:  8*13*13 --> 10
			//flatten the [8] X [13] X [13] input to a [1] X [8*13*13] vector.
			float[][] flatOutput = flatten(out);  //1X1352
			float[][] fcOutput = fullyConnected(flatOutput, fcWeights, fcBiases);
			out_l = softmax(fcOutput);
            
            // compute cross-entropy loss
			int correct_label = maxIndex(trainLabels[i]);
            ce_loss += (float) -Math.log(out_l[0][correct_label]);
            accuracy += correct_label == maxIndex(out_l[0]) ? 1 : 0;
            
            //BACKWARD PROPAGATION --- STOCHASTIC GRADIENT DESCENT
            //gradient of the cross entropy loss
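            //for cross-entropy L = -log(p_c), dL/dp_i = -1/p_c at the correct
            //class c and 0 for every other class, hence the one-hot gradient below.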
            float[][] gradient= new float[1][10];
			zerosVector(gradient[0]);
            gradient[0][correct_label]=-1/out_l[0][correct_label];
            float[][][] sm_gradient=fullyConnected_backprop(gradient,learn_rate);
            float[][][] mp_gradient=maxpool_backprop(sm_gradient);
            convolution_backprop(mp_gradient, learn_rate);
            if(i % 100 == 99){
                System.out.println(" step: "+ i+ " loss: "+ce_loss/100.0+" accuracy: "+accuracy);
                ce_loss=0;
                acc_sum+=accuracy;
                accuracy=0;
            }
        }
        System.out.println("average training accuracy: "+((acc_sum*100)/batchSize)+"%");
    }

    public static void test(int batchSize) throws IOException {
        int accuracy=0;
        float acc_sum=0.0f;
        float ce_loss=0.0f;

        float[][] out_l = new float[1][10];    
        for (int i = 0; i < batchSize; i++) {
            
            //FORWARD PROPAGATION
            
            //convert to pixel array
            float[][] pxl = testImages[i];
            // perform convolution 28*28 --> 8x26x26
            float[][][] out = convolution(pxl, filters);

            // perform maximum pooling  8x26x26 --> 8x13x13
            out = maxPooling(out);
            
            // flatten + fully connected + softmax:  8*13*13 --> 10
			float[][] flatOutput = flatten(out);  //1X1352
			float[][] fcOutput = fullyConnected(flatOutput, fcWeights, fcBiases);
			out_l = softmax(fcOutput);
            
            // compute cross-entropy loss
			int correct_label = maxIndex(testLabels[i]);
            ce_loss += (float) -Math.log(out_l[0][correct_label]);
			int detected_label = maxIndex(out_l[0]);
            accuracy += correct_label == detected_label ? 1 : 0;
            
            if(i % 100 == 99){
                System.out.println(" step: "+ i+ " loss: "+ce_loss/100.0+" accuracy: "+accuracy);
                ce_loss=0;
                acc_sum+=accuracy;
                accuracy=0;
            }
        }
        System.out.println("average accuracy: "+((acc_sum*100)/batchSize)+"%");
    }
      
    /**
     * Entry point: loads the MNIST data, trains the network, then evaluates it.
     * @param args unused
     * @throws IOException if the MNIST files cannot be read.
     */
    public static void main(String[] args) throws IOException {      
		int imageSize = 28;
        int numClasses = 10;
        int batchSize = 60000;  // Full training set size
        int testBatchSize = 10000;  // Full test set size
        
		System.out.println("Load training data <=================");
        // Load MNIST training data
        String trainImagesFile = "train-images.idx3-ubyte";
        String trainLabelsFile = "train-labels.idx1-ubyte";
        trainImages = loadImages(trainImagesFile, batchSize, imageSize);
        trainLabels = loadLabels(trainLabelsFile, batchSize, numClasses);
        
		System.out.println("Load test data <=================");
		// Load MNIST test data
        String testImagesFile = "t10k-images.idx3-ubyte";
        String testLabelsFile = "t10k-labels.idx1-ubyte";
        testImages = loadImages(testImagesFile, testBatchSize, imageSize);
        testLabels = loadLabels(testLabelsFile, testBatchSize, numClasses);

        filters = init_filters(8);
		init_SoftMax(13*13*8,10); // 13*13*8 = 1352 flattened inputs, 10 classes

        train(batchSize);
		test(testBatchSize);
    }

	/**
     * caches the input data (the image) for use in the back-propagation phase.
     */
    public static float[][] inputConvolution; // shape --> [28] X [28]

    /**
     * caches filters that were used in the convolution phase for use in the 
     * back-propagation phase.
     */
    public static float[][][] filtersConvolution; // shape --> [8] X [3] X [3]

	/**
     * Convolves the image with respect to a 3X3 filter
     * @param image the image matrix with shape [28] X [28]
     * @param filter a 3X3 filter used in the convolution process.
     * @return a 2D matrix with shape [26] X [26].
     */
    public static float[][] convolve3x3(float[][] image, float[][] filter) {
        inputConvolution=image;
        float[][] result = new float[image.length - 2][image[0].length - 2];
        //slide the filter over every valid 3X3 region (centers 1..26 for 28X28),
        //writing output position (i-1, j-1) so the full 26X26 result is filled
        for (int i = 1; i <= image.length - 2; i++) {
            for (int j = 1; j <= image[0].length - 2; j++) {
                float[][] conv_region = subMatrix(image, i - 1, i + 1, j - 1, j + 1);
                result[i - 1][j - 1] = elsumxMatrix(conv_region, filter);
            }
        }
        return result;
    }

	/**
     * the forward convolution pass that convolves the image w.r.t. each filter
     * in the filter array. No padding has been used in this case, so output matrix
     * shape decreases by 2 w.r.t row width and column height.
     * @param image the input image matrix. [28] X [28]
     * @param filter a 3D matrix containing an array of 3X3 filters ([8]X[3]X[3])
     * @return a 3D array containing an array of the convolved images w.r.t.
     * each filter. [8] X [26] X [26]
     */
	public static float[][][] convolution(float[][] image, float[][][] filter) {
        filtersConvolution=filter; // 8 X 3 X 3
        float[][][] result = new float[filter.length][image.length - 2][image[0].length - 2];
        for (int k = 0; k < filter.length; k++) {
            float[][] res = convolve3x3(image, filter[k]);
			//res = relu(res); //Added by Kimbs (see the relu sketch below)
            result[k] = res;
        }
        return result;
    }
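
    /**
     * A minimal sketch of the relu helper that the commented-out line in
     * convolution() refers to; it is not part of the original listing, so treat
     * it as an assumption about what "Added by Kimbs" intended: element-wise
     * max(0, x) over the convolved feature map.
     */
    public static float[][] relu(float[][] mat) {
        float[][] result = new float[mat.length][mat[0].length];
        for (int i = 0; i < mat.length; i++) {
            for (int j = 0; j < mat[0].length; j++) {
                result[i][j] = Math.max(0.0f, mat[i][j]);
            }
        }
        return result;
    }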

	/**
     * 
     * @param d_L_d_out the input gradient matrix retrieved from the back-propagation
     *  phase of the maximum pooling stage. shape = [8] X [26] X [26]
     * @param learning_rate the learning rate factor used in the neural network.
     */
    public static void convolution_backprop(float[][][] d_L_d_out,float learning_rate){
        //the output gradient which is dL/dfilter= (dL/dout)*(dout/dfilter)
        float[][][] d_L_d_filters= new float[filtersConvolution.length][filtersConvolution[0].length][filtersConvolution[0][0].length];
        //reverses the convolution phase by creating a 3X3 gradient filter 
        //and assigning its elements with the input gradient values scaled by
        //the corresponding pixels of the image.
        for(int i=1;i<=inputConvolution.length-2;i++){
            for(int j=1;j<=inputConvolution[0].length-2;j++){
                //get the 3X3 region of the image centered at (i,j)
                float[][] region=subMatrix(inputConvolution,  i - 1, i + 1, j - 1, j + 1);
                for(int k=0;k<filtersConvolution.length;k++){
                    //for each 3X3 region of the input image:
                    // d_L_d_filter(kth filter) = d_L_d_filter(kth filter)+ d_L_d_out(k,i,j)* sub_image(3,3)i,j
                    //       [3] X [3]          =       [3] X [3]         +     gradient    *      [3] X [3]
                    d_L_d_filters[k]=addMatrix(d_L_d_filters[k], scaleMatrix(region,d_L_d_out[k][i-1][j-1]));
                }
            }
        }
        
        //update the filter matrix with the gradient matrix obtained above.
        for(int m=0;m<filtersConvolution.length;m++){
          // [3] X [3]  =   [3] X [3] + -lr * [3] X [3]   
            filtersConvolution[m]= addMatrix(filtersConvolution[m], scaleMatrix(d_L_d_filters[m],-learning_rate));
        }  
    }

	 /**
     * caches the input data (the image) for use in the back-propagation phase.
     */
     public static float[][][] inputMaxpool;  // [8] X [26] X [26]
    

    /**
     * caches the output data (the image) for use in the back-propagation phase.
     */
    public static float[][][] outputMaxpool;

	/**
     * performs a 2X2 maximum pooling operation which computes the maximum value
     * contained in each 2X2 sub-region of the input matrix, consequently reducing the
     * size of the input array by half.
     * for example, if A is the max value, then
     * | X Y | ---> | A |
     * | Z A |  
     * @param img the input image matrix. [26] X [26]
     * @return a [13] X [13] 2D array. 
     */
    public static float[][] max_pool(float[][] img) {
        //final array shape is half of the original input shape
        float[][] pool = new float[img.length / 2][img[0].length / 2];
        for (int i = 0; i < pool.length; i++) {
            for (int j = 0; j < pool[0].length; j++) {
                //get the maximum value from the (i,j)th 2X2 sub-region of the input. 
                pool[i][j] = maxMatrix(subMatrix(img, i * 2, i * 2 + 1, j * 2, j * 2 + 1));
            }
        }
        return pool;
    }

	/**
     * performs max pooling for each convolved images (8 in this cases)
     * @param dta the array of convolved images [8] X [26] X [26]
     * @return a [8] X [13] X [13] array
     */
	public static float[][][] maxPooling(float[][][] dta) {
        inputMaxpool = dta;
        //each pooled slice is half the size of the corresponding input slice
        float[][][] result = new float[dta.length][dta[0].length / 2][dta[0][0].length / 2];
        for (int k = 0; k < dta.length; k++) {
            result[k] = max_pool(dta[k]);
        }
        outputMaxpool = result;
        return result;
    }

	public static float[][][] maxpool_backprop(float[][][] d_L_d_out) {
        float[][][] d_L_d_input = new float[inputMaxpool.length][inputMaxpool[0].length][inputMaxpool[0][0].length];
        for (int i = 0; i < outputMaxpool.length; i++) { // filter index 0 - 7 [8 values]
            for (int j = 0; j < outputMaxpool[0].length; j++) { //pool row index 0 -12 [13 values]
                for (int k = 0; k < outputMaxpool[0][0].length; k++) { //pool column index
                    //get 2X2 image region.
                    float[][] region = subMatrix(inputMaxpool[i], j * 2, j * 2 + 1, k * 2, k * 2 + 1);
                    //loop through image region to get row,column index of the maximum value.
                    for (int m = 0; m < region.length; m++) {
                        for (int n = 0; n < region[0].length; n++) {
                            //if the current value in the 2X2 region is the maximum 
                            //then assign the output gradient value to this index on the
                            // [8]X[26]X[26] output gradient matrix.
                            if (Math.abs(outputMaxpool[i][j][k] - region[m][n]) < 0.00000001) {
                                //the index should be translated from local 2X2 to global
                                //i.e. from [m][n] of the 2X2 region to [j*2+m][k*2+n] of the grad matrix
                                d_L_d_input[i][j * 2 + m][k * 2 + n] = d_L_d_out[i][j][k];
                            }
                        }
                    }
                }
            }
        }
        return d_L_d_input;
    }

	//cache
	public static float[][] inputFullyConnected;
	public static float[][] outputFullyConnected;	

	public static float[][] flatten(float[][][] mat) {
        float[][] v = new float[1][mat.length * mat[0].length * mat[0][0].length];
        int l = 0; //vector iterator
        for (int i = 0; i < mat.length; i++) {
            for (int j = 0; j < mat[0].length; j++) {
                for (int k = 0; k < mat[0][0].length; k++) {
                    v[0][l] = mat[i][j][k];
                    l++;
                }
            }
        }
        return v;
    }

	public static float[][] fullyConnected(float[][] input, float[][] weights, float[][] biases) {
		// evaluate the total activation value --> t=[i][w]+[b] and cache the totals for backprop
		// [1] X [10] =  [1] X [1352]  * [1352] X [10] + [1] X [10]
        outputFullyConnected = addMatrix(multMatrix(input, weights), biases);
		//cache input
		inputFullyConnected = input;
		return outputFullyConnected;
    }

	public static float[][] softmax(float[][] input) {
        //compute softmax probabilities; subtracting the max logit before exp()
        //guards against float overflow without changing the probabilities.
        float max = input[0][maxIndex(input[0])];
        float[][] shifted = new float[1][input[0].length];
        for (int i = 0; i < input[0].length; i++) {
            shifted[0][i] = input[0][i] - max;
        }
        float[][] totals = expVector(shifted);
        float inv_activation_sum = 1 / sumVector(totals);
        return scaleVector(totals, inv_activation_sum);
    }
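
    /*
     * For reference, the softmax derivative used in fullyConnected_backprop below.
     * With logits t_i, e_i = exp(t_i) and S = sum_j e_j, so that p_i = e_i / S:
     *   d p_c / d t_c =  e_c * (S - e_c) / S^2            (the correct class c)
     *   d p_c / d t_j = -e_c * e_j / S^2     for j != c   (every other class)
     * These are exactly the two expressions assigned to d_out_d_t in the code.
     */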

	/**
     * performs the back-propagation phase of the softmax layer. 
     * @param d_L_d_out the gradient vector obtained from the cross-entropy loss vector.
     * @param learning_rate the learning rate of the neural network.
     * @return a gradient matrix with the shape [8] X [13] X [13] to be fed to the
     * maxpooling layer.
     */
    public static float[][][] fullyConnected_backprop(float[][] d_L_d_out, float learning_rate) {
        //gradient of loss w.r.t. the totals (logits) of the softmax layer.
        float[][] d_L_d_t = new float[1][d_L_d_out[0].length];
        //repeat softmax probability computations (caching can be used to avoid this.)
        float[][] t_exp = expVector(outputFullyConnected);
        float S = sumVector(t_exp);
        float[][] d_L_d_inputs=null;
        
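        //d_L_d_out is non-zero only at the correct class (the one-hot cross-entropy
        //gradient built in train()), so the loop body below runs once per sample.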
        for (int i = 0; i < d_L_d_out[0].length; i++) {
            float grad = d_L_d_out[0][i];
            if (grad == 0) {
                continue;
            }
            //gradient of the output layer w.r.t. the totals [1] X [10]
            float[][] d_out_d_t = scaleVector(t_exp, -t_exp[0][i] / (S * S));
            d_out_d_t[0][i] = t_exp[0][i] * (S - t_exp[0][i]) / (S * S);
            
            d_L_d_t = scaleMatrix(d_out_d_t, grad); 
            //gradient of totals w.r.t weights -- [1352] X [1]
            float[][] d_t_d_weight = transposeMatrix(inputFullyConnected);
            //gradient of totals w.r.t inputs -- [1352] X [10] 
            float[][] d_t_d_inputs = fcWeights;
            //gradient of Loss w.r.t. weights ---> chain rule 
            //        [1352] X [10] = [1352] X [1] * [1] X [10]
            float[][] d_L_d_w = multMatrix(d_t_d_weight, d_L_d_t);
            //gradient of Loss w.r.t. inputs ---> chain rule
            // [1352] X [1]      [1352] X [10]    *   [10] X [1](transposed)
            d_L_d_inputs = multMatrix(d_t_d_inputs, transposeMatrix(d_L_d_t));
            //gradient of loss w.r.t. bias
            float[][] d_L_d_b = d_L_d_t;
            //update the weight and bias matrices.
            fcWeights = addMatrix(scaleMatrix(d_L_d_w, -learning_rate), fcWeights);
            fcBiases = addMatrix(scaleMatrix(d_L_d_b, -learning_rate), fcBiases);
        }
        // reshape the final gradient matrix to the input shape of the maxpooling layer.
        // [1] X [1352](transposed) ----> [8] X [13] X [13]
        return reshapeMatrix(transposeMatrix(d_L_d_inputs),8,13,13);
    }	

	public static void zerosVector(float[] input) {
        for (int i = 0; i < input.length; i++) {
            input[i] = 0.0f;
        }
    }

	public static float sumVector(float[][] v) {
        float sum = 0;
        for (int i = 0; i < v[0].length; i++) {
            sum += v[0][i];
        }
        return sum;
    }

	public static float[][] expVector(float[][] v) {
        float[][] exp = new float[1][v[0].length];
        for (int i = 0; i < v[0].length; i++) {
            exp[0][i] = (float) Math.exp(v[0][i]);
        }
        return exp;
    }

	public static float[][] scaleVector(float[][] v, float scale) {
        float[][] scl = new float[1][v[0].length];
        for (int i = 0; i < v[0].length; i++) {
            scl[0][i] = (float) v[0][i] * scale;
        }
        return scl;
    }

	public static int maxIndex(float[] array) {
        int maxIndex = 0;
        for (int i = 1; i < array.length; i++) {
            if (array[i] > array[maxIndex]) {
                maxIndex = i;
            }
        }
        return maxIndex;
    }

	public static float[][] addMatrix(float[][] m1, float[][] m2) {
        float[][] result = new float[m1.length][m1[0].length];
        for (int i = 0; i < m1.length; i++) {
            for(int j = 0; j < m1[0].length; j++){
            result[i][j] = m1[i][j] + m2[i][j];
        }
        }
        return result;
    }

	public static float[][] subMatrix(float[][] mat, int r_s, int r_e, int c_s, int c_e) {
        float[][] sub = new float[r_e - r_s + 1][c_e - c_s + 1];
        for (int i = 0; i < sub.length; i++) {
            for (int j = 0; j < sub[0].length; j++) {
                sub[i][j] = mat[r_s + i][c_s + j];
            }
        }
        return sub;
    }

	public static float[][] multMatrix(float[][] m1, float[][] m2) {
        float[][] result = new float[m1.length][m2[0].length];
        for (int i = 0; i < m1.length; i++) {//row index
            for (int j = 0; j < m2[0].length; j++) {//column index
                for (int k = 0; k < m1[0].length; k++) {
                    result[i][j] += m1[i][k] * m2[k][j];
                }
            }
        }
        return result;
    }

	public static float maxMatrix(float[][] mat) {
        float max = mat[0][0];
        for (int i = 0; i < mat.length; i++) {
            for (int j = 0; j < mat[0].length; j++) {
                max = max < mat[i][j] ? mat[i][j] : max;
            }
        }
        return max;
    }

	public static float elsumxMatrix(float[][] mat1, float[][] mat2) {
        float sum = 0;
        for (int i = 0; i < mat1.length; i++) {
            for (int j = 0; j < mat1[0].length; j++) {
                sum += mat1[i][j] * mat2[i][j];
            }
        }
        return sum;
    }

	public static float[][] scaleMatrix(float[][] mat, float scale) {
        float[][] scl = new float[mat.length][mat[0].length];
        for (int i = 0; i < mat.length; i++) {
            for (int j = 0; j < mat[0].length; j++) {
                scl[i][j] = (float) mat[i][j] * scale;
            }
        }
        return scl;
    }

	public static float[][] transposeMatrix(float[][] mat) {
        float[][] transpose = new float[mat[0].length][mat.length];
        for (int i = 0; i < mat.length; i++) {
            for (int j = 0; j < mat[0].length; j++) {
                transpose[j][i] = mat[i][j];
            }
        }
        return transpose;
    }

	public static float[][][] reshapeMatrix(float[][] input, int d, int h, int w){
        //input --> [1Xn]  output --> [d][h][w]
        float[][][] output=new float[d][h][w];
        int input_index=0;
        for(int i=0;i<d;i++){
            for(int j=0;j<h;j++){
                for(int k=0;k<w;k++){
                    output[i][j][k]=input[0][input_index];
                    input_index++;
                }
            }
        }
        return output;
    }
}
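
To run it, place the four MNIST files (train-images.idx3-ubyte, train-labels.idx1-ubyte, t10k-images.idx3-ubyte, t10k-labels.idx1-ubyte) next to the class, then compile and run with a standard JDK, for example "javac CNN_Final.java" followed by "java CNN_Final" (exact commands may differ in your setup). Training prints the running loss and per-100-step accuracy, and the test pass finishes with the average accuracy line.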
