%% 1. 加载数据集，并划分训练集和验证集 (Load the dataset and split it into training and validation sets)

% Reset the command window, figures and workspace, then load every image
% under the training folder (including sub-folders). Each image's label is
% the name of the folder that contains it.
% NOTE(review): in a single-quoted MATLAB char array '\\' is two literal
% backslashes; Windows usually tolerates the doubled separators, but
% confirm the path resolves on the target machine.
clc; close all; clear;
imds = imageDatastore('C:\\Users\\Administrator\\Desktop\\class08\\train_small\\train_small',...
    'IncludeSubfolders',true,'LabelSource','foldernames');

% Randomly split each label 70% / 30% into training and validation sets.
[imdsTrain, imdsValidation] = splitEachLabel(imds, 0.7, 'randomized');

% Preview a random sample of training images in a 4x4 grid.
% randperm(n, k) errors when k > n, so the sample is capped at the number
% of available training images.
numTrainImages = numel(imdsTrain.Labels);
numPreview = min(16, numTrainImages);
idx = randperm(numTrainImages, numPreview);
figure
for i = 1:numPreview
    subplot(4, 4, i)
    I = readimage(imdsTrain, idx(i));
    imshow(I)
end

%% 2. 加载预训练网络 (Load the pretrained network)

% Load the pretrained SqueezeNet network.
net = squeezenet;
% InputSize of the network's first layer; its first two elements are used
% later as the target size when resizing images.
inputSize = net.Layers(1).InputSize;
% Open the Network Analyzer to inspect the pretrained architecture.
analyzeNetwork(net);

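%% 3. 替换分类层：分类层指定网络的输出类。将分类层替换为没有类标签的新分类层，trainNetwork 会在训练时自动设置该层的输出类。
% (Replace the classification layer, which specifies the network's output classes, with a new one that has no class labels; trainNetwork sets the output classes automatically during training.)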

% Convert the pretrained network into an editable layer graph. A
% SeriesNetwork is rebuilt from its layer array; any other network type is
% converted directly.
if isa(net, 'SeriesNetwork')
    lgraph = layerGraph(net.Layers);
else
    lgraph = layerGraph(net);
end

% Locate the last learnable layer and the classification output layer.
% NOTE(review): findLayersToReplace is a helper defined outside this view.
[learnableLayer, classLayer] = findLayersToReplace(lgraph);

numClasses = numel(categories(imdsTrain.Labels));

% Replace the last learnable layer with a fresh layer that has numClasses
% outputs. The learn-rate factors of 10 multiply the global learning rate,
% so the new layer learns faster than the transferred layers.
if isa(learnableLayer, 'nnet.cnn.layer.FullyConnectedLayer')
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        'Name', 'new_fc', ...
        'WeightLearnRateFactor', 10, ...
        'BiasLearnRateFactor', 10);
elseif isa(learnableLayer, 'nnet.cnn.layer.Convolution2DLayer')
    % Networks whose last learnable layer is a convolution get a 1x1
    % convolution with one filter per class.
    newLearnableLayer = convolution2dLayer(1, numClasses, ...
        'Name', 'new_conv', ...
        'WeightLearnRateFactor', 10, ...
        'BiasLearnRateFactor', 10);
else
    % Fail loudly instead of hitting an undefined newLearnableLayer below.
    error('Unsupported learnable layer type: %s', class(learnableLayer));
end
lgraph = replaceLayer(lgraph, learnableLayer.Name, newLearnableLayer);

% Replace the classification layer with one that has no class labels;
% trainNetwork sets its output classes automatically during training.
newClassLayer = classificationLayer('Name', 'new_classoutput');
lgraph = replaceLayer(lgraph, classLayer.Name, newClassLayer);
analyzeNetwork(lgraph);

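%% 4. 冻结初始层：提取层图中的层和连接，并选择要冻结的层 (Freeze the initial layers: extract the layers and connections of the layer graph and choose which layers to freeze)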

	layers=lgraph.Layers;
 connections=lgraph.Connections;
 
 layers(1:43)=freezeWeights(layers(1:43));
 lgraph=createLgraphUsingConnections(layers,connections);

%% 5. 调整图像尺寸，使其与网络输入层一致（227x227 像素） (Resize images to the network input size, 227x227 pixels)

% Wrap both datastores so every image is resized on the fly to the first
% two elements of the network's input size. No augmentation options are
% passed, so resizing is the only transformation.
augimdsValidation = augmentedImageDatastore(inputSize(1:2), imdsValidation);
augimdsTrain = augmentedImageDatastore(inputSize(1:2), imdsTrain);

%% 6. 训练网络 (Train the network)

% Fine-tune with SGD with momentum, 6 epochs, a small initial learning
% rate, reshuffling the training data every epoch.
miniBatchSize = 10;
% Validate about once per epoch: the number of full mini-batches in the
% training set. ValidationFrequency must be a positive integer, and the
% floor is 0 when there are fewer training images than miniBatchSize, so
% clamp it to at least 1.
valFrequency = max(1, floor(numel(augimdsTrain.Files)/miniBatchSize));
options = trainingOptions('sgdm', ...
    'MiniBatchSize', miniBatchSize, ...
    'MaxEpochs', 6, ...
    'InitialLearnRate', 3e-4, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', augimdsValidation, ...
    'ValidationFrequency', valFrequency, ...
    'Verbose', false, ...
    'Plots', 'training-progress');
net = trainNetwork(augimdsTrain, lgraph, options);

%% 7. 对验证图像进行分类 (Classify the validation images)

% Predict labels and class scores for the validation images, then display
% the fraction predicted correctly (no semicolon, so it is printed).
[YPred, probs] = classify(net, augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)

% Show a random sample of validation images in a 4x4 grid. Each title is
% the predicted label and its top score as a percentage (3 significant
% digits). randperm(n, k) errors when k > n, so the sample is capped at the
% number of validation images.
numValidation = numel(imdsValidation.Files);
numShow = min(16, numValidation);
idx = randperm(numValidation, numShow);
figure
for i = 1:numShow
    subplot(4, 4, i)
    I = readimage(imdsValidation, idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label)+","+num2str(100*max(probs(idx(i),:)),3)+"%")
end