Download the dataset from here: https://drive.google.com/drive/folders/1SVZC3eoU_Cbgweov2M_Dts-DzZpq-gIb?usp=sharing
Welcome to this tutorial: Training a Simple Deep Neural Network to Classify Digits.
In this tutorial, we are going to learn:
1. How to load, inspect and prepare the dataset
2. Then we will see how to construct the network layers
3. After that, we will learn how to specify the training options
4. Then we will see how to train the network
5. After training, we will see how to test the network, and
6. Finally, we will learn how to evaluate the network using accuracy and a confusion matrix
---------------------------------------------------------
MATLAB Code to Train the Network:
---------------------------------------------------------
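% Load the images; each subfolder name becomes the label of the images inside it
dataset = imageDatastore("Dataset", 'IncludeSubfolders',true,'LabelSource','foldernames');

% Show 20 randomly chosen images (indices drawn from however many files the dataset holds)
figure;
randomIndices = randperm(numel(dataset.Files),20);
for i = 1:20
    subplot(4,5,i);
    imshow(dataset.Files{randomIndices(i)});
end

% Number of images in each class (label)
image_label = countEachLabel(dataset)

% Size of one image; it must match the input size given to imageInputLayer
img = readimage(dataset,1);
size(img)

% Randomly assign 75% of the images of each label to training, the rest to validation
splitting_ratio = 0.75;
[training,validation] = splitEachLabel(dataset,splitting_ratio,'randomize');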
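% Network architecture: three convolution blocks followed by a classifier
layers = [
    imageInputLayer([28 28 1])                 % 28x28 grayscale images

    convolution2dLayer(3,8,'Padding','same')   % 8 filters of size 3x3, output keeps 28x28
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)            % halves height and width

    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)

    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer

    fullyConnectedLayer(10)                    % one output per digit class (0-9)
    softmaxLayer                               % turns the outputs into class probabilities
    classificationLayer];                      % cross-entropy loss used for training

% Training options: stochastic gradient descent with momentum ('sgdm')
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',4, ...                 % number of full passes over the training set
    'Shuffle','every-epoch', ...
    'ValidationData',validation, ...
    'ValidationFrequency',30, ...      % evaluate on the validation set every 30 iterations
    'Verbose',false, ...
    'Plots','training-progress');

% Train the network
net = trainNetwork(training,layers,options);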
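To see where the sizes in the layer stack and the training options come from, the short script below works through the arithmetic by hand. It is a sketch under assumptions the tutorial does not state: the dataset holds 10 classes with 1000 images each (which is what randperm(10000,20) in the original code suggests), and the mini-batch size is MATLAB's default of 128, because trainingOptions above does not set MiniBatchSize. It also assumes, as the trainingOptions documentation describes, that an incomplete final mini-batch is dropped each epoch. It needs no toolbox and does not read the dataset.

```matlab
% Worked arithmetic for the layers and training options above.
% ASSUMPTIONS: 10 classes x 1000 images each; default MiniBatchSize = 128.
inputSize  = [28 28 1];
numFilters = [8 16 32];          % filters in the three convolution2dLayer calls

% 'Padding','same' with stride 1 keeps height and width unchanged.
% A 2x2 max pool with stride 2 gives floor((H - 2)/2) + 1.
h = inputSize(1);
w = inputSize(2);
for b = 1:3
    fprintf('After conv block %d: %d x %d x %d\n', b, h, w, numFilters(b));
    if b < 3                     % only the first two blocks end with maxPooling2dLayer
        h = floor((h - 2) / 2) + 1;
        w = floor((w - 2) / 2) + 1;
        fprintf('After max pool %d:   %d x %d x %d\n', b, h, w, numFilters(b));
    end
end
fcInputs = h * w * numFilters(end);   % values flattened into fullyConnectedLayer(10)
fprintf('Inputs to the fully connected layer: %d\n', fcInputs);

% Iterations implied by the training options
numPerLabel   = 0.75 * 1000;                      % splitEachLabel keeps 75% of each label (assumed 1000)
numTraining   = 10 * numPerLabel;                 % 10 classes (assumed)
miniBatch     = 128;                              % default MiniBatchSize (assumed, not set above)
itersPerEpoch = floor(numTraining / miniBatch);   % incomplete final mini-batch is dropped
totalIters    = 4 * itersPerEpoch;                % 'MaxEpochs',4
fprintf('%d iterations per epoch, %d in total\n', itersPerEpoch, totalIters);
```

Under these assumptions the final 7 x 7 x 32 feature map gives 1568 inputs to fullyConnectedLayer(10), and one epoch is 58 iterations, so 'ValidationFrequency',30 checks the validation set roughly twice per epoch.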
---------------------------------------------------------
MATLAB Code to Test the Network:
---------------------------------------------------------
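% Predict a label for every validation image
prediction = classify(net,validation);
actual = validation.Labels;

% Accuracy = correctly classified images / all images, in percent
correct = sum(prediction==actual)
total = numel(actual)
accuracy = (correct/total)*100

% Confusion matrix plot (rows = predicted/output class, columns = true/target class)
figure;
plotconfusion(actual, prediction)

% Confusion chart (rows = true class, columns = predicted class) with per-class summaries
figure;
cm = confusionchart(actual, prediction);
cm.RowSummary = 'row-normalized';
cm.ColumnSummary = 'column-normalized';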
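The accuracy and the confusion matrix can also be computed by hand, which makes clear what the row and column summaries of confusionchart report. The sketch below uses a small set of made-up labels (hypothetical values, not output from the trained network) in place of validation.Labels and classify(net,validation), and builds the matrix with accumarray in the same orientation as confusionchart: rows are true classes, columns are predicted classes.

```matlab
% HYPOTHETICAL toy labels standing in for validation.Labels and classify(net,validation)
classNames = {'0','1','2'};
actual     = categorical({'0';'0';'1';'1';'2';'2'}, classNames);
prediction = categorical({'0';'1';'1';'1';'2';'0'}, classNames);

% Accuracy: share of predictions that match the true label, in percent
correct  = sum(prediction == actual);
total    = numel(actual);
accuracy = (correct / total) * 100;

% Confusion matrix: entry (r,c) counts images of true class r predicted as class c
k = numel(classNames);
C = accumarray([double(actual), double(prediction)], 1, [k k]);

% Correct share of each row (recall) and of each column (precision)
recall    = diag(C) ./ sum(C, 2);
precision = diag(C) ./ sum(C, 1)';

disp(C)
fprintf('Accuracy: %.2f%%\n', accuracy)
```

For these toy labels, 4 of the 6 predictions are correct (about 66.67% accuracy). The recall values match the correctly classified percentages shown by the 'row-normalized' RowSummary, and the precision values match the 'column-normalized' ColumnSummary. plotconfusion draws the transposed layout, with predicted classes as rows and true classes as columns.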