Simple Deep Neural Network to Classify Digits

Download the dataset from here: https://drive.google.com/drive/folders/1SVZC3eoU_Cbgweov2M_Dts-DzZpq-gIb?usp=sharing

Welcome to this tutorial on training a simple deep neural network to classify digits. In this tutorial, we are going to learn:
1. How to load, inspect, and prepare the dataset
2. Then we will see how to construct the network layers
3. After that, we will learn how to specify the training options
4. Then we will see how to train the network
5. After training, we will see how to test the network, and
6. Finally, we will learn how to evaluate the network using accuracy and a confusion matrix

---------------------------------------------------------
MATLAB Code to Train the Network:
---------------------------------------------------------

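% Load the images; the name of each subfolder becomes the label of the images inside it
dataset = imageDatastore("Dataset", 'IncludeSubfolders',true,'LabelSource','foldernames');

% Display 20 randomly chosen images (sampled from however many files were found)
figure;
idx = randperm(numel(dataset.Files),20);
for i = 1:20
    subplot(4,5,i);
    imshow(dataset.Files{idx(i)});
end

% Inspect the number of images per label and the size of one image
image_label = countEachLabel(dataset)
img = readimage(dataset,1);
size(img)

% Use 75% of the images of each label for training and the rest for validation
splitting_ratio = 0.75;
[training,validation] = splitEachLabel(dataset,splitting_ratio,'randomize');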

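% Network architecture: three convolution blocks followed by a classifier
layers = [
    imageInputLayer([28 28 1])                  % 28x28 grayscale images

    convolution2dLayer(3,8,'Padding','same')    % 8 filters of size 3x3
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)             % halves the height and width

    convolution2dLayer(3,16,'Padding','same')   % 16 filters of size 3x3
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)             % halves the height and width

    convolution2dLayer(3,32,'Padding','same')   % 32 filters of size 3x3
    batchNormalizationLayer
    reluLayer

    fullyConnectedLayer(10)                     % one output per digit class
    softmaxLayer                                % turns the outputs into class probabilities
    classificationLayer];                       % cross-entropy loss for training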

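% Training options: stochastic gradient descent with momentum, 4 epochs,
% reshuffled every epoch, validated every 30 iterations
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',validation, ...
    'ValidationFrequency',30, ...
    'Verbose',false, ...
    'Plots','training-progress');

% Train the network
net = trainNetwork(training,layers,options);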
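
To see why the last convolution block hands a 7x7x32 feature map to the fully connected layer, the sizes and the number of learnable parameters can be worked out by hand. The sketch below is plain MATLAB (no toolbox needed). It assumes the standard pooling size formula floor((H - K)/S) + 1 for an H-pixel input, pool size K, and stride S, and it counts two learnables (scale and offset) per channel for each batch normalization layer; the variable names are illustrative only.

```matlab
% Worked size and parameter arithmetic for the layer stack above.
% Assumes 28x28x1 inputs, as set by imageInputLayer([28 28 1]).
inSize     = 28;          % input height = width
inChannels = 1;           % grayscale
filters    = [8 16 32];   % filters in the three convolution layers

sz = inSize;
ch = inChannels;
learnables = 0;
for k = 1:numel(filters)
    % 3x3 convolution with 'same' padding keeps the spatial size
    convParams = 3*3*ch*filters(k) + filters(k);   % weights + biases
    bnParams   = 2*filters(k);                     % scale + offset per channel
    learnables = learnables + convParams + bnParams;
    ch = filters(k);
    fprintf('conv%d: %dx%dx%d, %d conv + %d BN learnables\n', ...
        k, sz, sz, ch, convParams, bnParams);
    % Only the first two blocks are followed by 2x2 max pooling with stride 2
    if k < numel(filters)
        sz = floor((sz - 2)/2) + 1;
        fprintf('pool%d: %dx%dx%d\n', k, sz, sz, ch);
    end
end

% Fully connected layer: every input value connects to each of the 10 outputs
fcParams   = sz*sz*ch*10 + 10;
learnables = learnables + fcParams;
fprintf('fc: %d inputs -> 10 outputs, %d learnables\n', sz*sz*ch, fcParams);
fprintf('total learnables: %d\n', learnables);
```

By this count the network has 21,690 learnable parameters, most of them (15,690) in the fully connected layer. analyzeNetwork(layers) in Deep Learning Toolbox is one way to cross-check these numbers.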

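ValidationFrequency is measured in iterations (mini-batches), not epochs, so it helps to estimate how many iterations the four epochs contain. This is a back-of-the-envelope sketch under assumptions the tutorial does not state: the dataset holds 1,000 images for each of the 10 digits, trainingOptions uses its default MiniBatchSize of 128, and the last incomplete mini-batch of each epoch is dropped.

```matlab
% Back-of-the-envelope schedule for the training options above.
% ASSUMPTIONS (not stated in the tutorial): 10 digit classes with 1000 images
% each, the default MiniBatchSize of 128, and the last incomplete mini-batch
% of every epoch being dropped.
numClasses      = 10;
imagesPerClass  = 1000;       % assumed; check with countEachLabel(dataset)
splitting_ratio = 0.75;       % same value as in the tutorial

numTrain      = numClasses * splitting_ratio * imagesPerClass;  % training images
numValidation = numClasses * imagesPerClass - numTrain;         % validation images

miniBatchSize       = 128;    % assumed default of trainingOptions
maxEpochs           = 4;      % 'MaxEpochs' in the tutorial
validationFrequency = 30;     % 'ValidationFrequency' in the tutorial (iterations)

itersPerEpoch = floor(numTrain / miniBatchSize);   % complete mini-batches per epoch
totalIters    = itersPerEpoch * maxEpochs;

% Iterations that are multiples of ValidationFrequency
validationIters = validationFrequency:validationFrequency:totalIters;

fprintf('%d training / %d validation images\n', numTrain, numValidation);
fprintf('%d iterations per epoch, %d in total\n', itersPerEpoch, totalIters);
fprintf('validation at iterations: %s\n', mat2str(validationIters));
```

Under these assumptions the split gives 7,500 training and 2,500 validation images, each epoch has 58 iterations, and training runs for 232 iterations. If countEachLabel reports different counts, only imagesPerClass needs to change.
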
---------------------------------------------------------
MATLAB Code to Test the Network:
---------------------------------------------------------

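% Classify the validation images and compare the results with the true labels
prediction = classify(net,validation);
actual = validation.Labels;

correct = sum(prediction==actual)
total = numel(actual)

% Percentage of correctly classified images
accuracy = (correct/total)*100

% Classic confusion plot
figure;
plotconfusion(actual, prediction)

% Confusion chart with per-class row and column summaries
figure;
cm = confusionchart(actual, prediction)
cm.RowSummary = 'row-normalized';
cm.ColumnSummary = 'column-normalized';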
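
The accuracy and the row and column summaries of the confusion chart can be reproduced by hand, which makes clear what they measure. The sketch below uses a hypothetical 3-class toy example with numeric labels (made-up values, not results from this tutorial) and base MATLAB's accumarray. Rows of C are true classes and columns are predicted classes, the same orientation confusionchart uses.

```matlab
% Hand-computed accuracy and confusion matrix for a hypothetical 3-class example.
% The labels below are made up for illustration; they are not tutorial results.
actualToy    = [0 0 1 1 1 2 2 2 2 0];
predictedToy = [0 1 1 1 2 2 2 0 2 2];
numClasses   = 3;

% C(i,j) counts images whose true class is i-1 and predicted class is j-1
C = accumarray([actualToy(:) + 1, predictedToy(:) + 1], 1, [numClasses numClasses]);

% Accuracy: correct predictions (the diagonal) over all predictions, in percent
accuracy = 100 * trace(C) / sum(C(:));

% Row-normalized diagonal: fraction of each true class predicted correctly (recall)
recall = diag(C) ./ sum(C, 2);

% Column-normalized diagonal: fraction of each predicted class that is correct (precision)
precision = diag(C)' ./ sum(C, 1);

disp(C);
fprintf('accuracy = %.1f%%\n', accuracy);
disp(recall');
disp(precision);
```

Here 6 of the 10 toy images lie on the diagonal, so the accuracy is 60%, matching (correct/total)*100 from the test code. In the tutorial itself, prediction and actual are categorical arrays and confusionchart(actual, prediction) builds this matrix directly. Note that plotconfusion uses the transposed layout, with predicted classes (Output Class) on the rows and true classes (Target Class) on the columns.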
