Why are the gradients not backpropagating into the encoder in this custom loop?
I am building a convolutional autoencoder using a custom training loop. When I attempt to reconstruct the images, the network’s output degenerates to guessing the same incorrect value for all inputs. However, training the autoencoder in a single stack with the trainnet function works fine, indicating that the gradient updates are unable to bridge the bottleneck layer in the custom training loop. Unfortunately, I need to use the custom training loop for a different task and am prohibited from using TensorFlow or PyTorch.
What is the syntax to ensure that the encoder is able to update based on the decoder’s reconstruction performance?
%% Functional ‘trainnet’ loop
clear
close all
clc
% Get handwritten digit data
xTrain = digitTrain4DArrayData;
xTest = digitTest4DArrayData;
% Check that all pixel values are min-max scaled
assert(max(xTrain(:)) == 1); assert(min(xTrain(:)) == 0);
assert(max(xTest(:)) == 1); assert(min(xTest(:)) == 0);
imageSize = [28 28 1];
%% Layer definitions
% Latent projection
projectionSize = [7 7 64];
numInputChannels = imageSize(3);
% Decoder
aeLayers = [
imageInputLayer(imageSize)
convolution2dLayer(3,32,Padding="same",Stride=2)
reluLayer
convolution2dLayer(3,64,Padding="same",Stride=2)
reluLayer
transposedConv2dLayer(3,64,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,32,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,numInputChannels,Cropping="same")
sigmoidLayer(Name=’Output’)
];
autoencoder = dlnetwork(aeLayers);
%% Training Parameters
numEpochs = 150;
miniBatchSize = 25;
learnRate = 1e-3;
options = trainingOptions("adam", …
InitialLearnRate=learnRate,…
MaxEpochs=30, …
Plots="training-progress", …
TargetDataFormats="SSCB", …
InputDataFormats="SSCB", …
MiniBatchSize=miniBatchSize, …
OutputNetwork="last-iteration", …
Shuffle="every-epoch");
autoencoder = trainnet(dlarray(xTrain, ‘SSCB’),dlarray(xTrain, ‘SSCB’), …
autoencoder, ‘mse’, options);
%% Testing
YTest = predict(autoencoder, dlarray(xTest, ‘SSCB’));
indices = randi(size(xTest, 4), [1, size(xTest, 4)]); % Shuffle YTest & xTest
xTest = xTest(:,:,:,indices); YTest = YTest(:,:,:,indices);
% Display test images
numImages = 64;
figure
subplot(1,2,1)
preds = extractdata(YTest(:,:,:,1:numImages));
I = imtile(preds);
imshow(I)
title("Reconstructed Images")
subplot(1,2,2)
orgs = xTest(:,:,:,1:numImages);
I = imtile(orgs);
imshow(I)
title("Original Images")
%% Nonfunctional Custom Training Loop
clear
close all
clc
% Get handwritten digit data
xTrain = digitTrain4DArrayData;
xTest = digitTest4DArrayData;
% Check that all pixel values are min-max scaled
assert(max(xTrain(:)) == 1); assert(min(xTrain(:)) == 0);
assert(max(xTest(:)) == 1); assert(min(xTest(:)) == 0);
imageSize = [28 28 1];
%% Layer definitions
% Encoder
layersE = [
imageInputLayer(imageSize)
convolution2dLayer(3,32,Padding="same",Stride=2)
reluLayer
convolution2dLayer(3,64,Padding="same",Stride=2)
reluLayer];
% Latent projection
projectionSize = [7 7 64];
numInputChannels = imageSize(3);
% Decoder
layersD = [
imageInputLayer(projectionSize)
transposedConv2dLayer(3,64,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,32,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,numInputChannels,Cropping="same")
sigmoidLayer(Name=’Output’)
];
netE = dlnetwork(layersE);
netD = dlnetwork(layersD);
%% Training Parameters
numEpochs = 150;
miniBatchSize = 25;
learnRate = 1e-3;
% Create training minibatchqueue
dsTrain = arrayDatastore(xTrain,IterationDimension=4);
numOutputs = 1;
mbq = minibatchqueue(dsTrain,numOutputs, …
MiniBatchSize = miniBatchSize, …
MiniBatchFormat="SSCB", …
MiniBatchFcn=@preprocessMiniBatch,…
PartialMiniBatch="return");
%Initialize the parameters for the Adam solver.
trailingAvgE = [];
trailingAvgSqE = [];
trailingAvgD = [];
trailingAvgSqD = [];
%Calculate the total number of iterations for the training progress monitor
numIterationsPerEpoch = ceil(size(xTrain, 4) / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
epoch = 0;
iteration = 0;
%Initialize the training progress monitor.
monitor = trainingProgressMonitor( …
Metrics="TrainingLoss", …
Info=["Epoch", "LearningRate"], …
XLabel="Iteration");
%% Training
while epoch < numEpochs && ~monitor.Stop
epoch = epoch + 1;
% Shuffle data.
shuffle(mbq);
% Loop over mini-batches.
while hasdata(mbq) && ~monitor.Stop
% Assess validation criterion
iteration = iteration + 1;
% Read mini-batch of data.
X = next(mbq);
% Evaluate loss and gradients.
[loss,gradientsE,gradientsD] = dlfeval(@modelLoss,netE,netD,X);
% Update learnable parameters.
[netE,trailingAvgE,trailingAvgSqE] = adamupdate(netE, …
gradientsE,trailingAvgE,trailingAvgSqE,iteration,learnRate);
[netD, trailingAvgD, trailingAvgSqD] = adamupdate(netD, …
gradientsD,trailingAvgD,trailingAvgSqD,iteration,learnRate);
updateInfo(monitor, …
LearningRate=learnRate, …
Epoch=string(epoch) + " of " + string(numEpochs));
recordMetrics(monitor,iteration, …
TrainingLoss=loss);
monitor.Progress = 100*iteration/numIterations;
end
end
%% Testing
dsTest = arrayDatastore(xTest,IterationDimension=4);
numOutputs = 1;
ntest = size(xTest, 4);
indices = randi(ntest,[1,ntest]);
xTest = xTest(:,:,:,indices);% Shuffle test data
mbqTest = minibatchqueue(dsTest,numOutputs, …
MiniBatchSize = miniBatchSize, …
MiniBatchFcn=@preprocessMiniBatch, …
MiniBatchFormat="SSCB");
YTest = modelPredictions(netE,netD,mbqTest);
% Display test images
numImages = 64;
figure
subplot(1,2,1)
preds = YTest(:,:,:,1:numImages);
I = imtile(preds);
imshow(I)
title("Reconstructed Images")
subplot(1,2,2)
orgs = xTest(:,:,:,1:numImages);
I = imtile(orgs);
imshow(I)
title("Original Images")
%% Functions
function [loss,gradientsE,gradientsD] = modelLoss(netE,netD,X)
% Forward through encoder.
Z = forward(netE,X);
% Forward through decoder.
Xrecon = forward(netD,Z);
% Calculate loss and gradients.
loss = regularizedLoss(Xrecon,X);
[gradientsE,gradientsD] = dlgradient(loss,netE.Learnables,netD.Learnables);
end
function loss = regularizedLoss(Xrecon,X)
% Image Reconstruction loss.
reconstructionLoss = l2loss(Xrecon, X, ‘NormalizationFactor’,’all-elements’);
% Combined loss.
loss = reconstructionLoss;
end
function Xrecon = modelPredictions(netE,netD,mbq)
Xrecon = [];
shuffle(mbq)
% Loop over mini-batches.
while hasdata(mbq)
X = next(mbq);
% Pass through encoder
Z = predict(netE,X);
% Pass through decoder to get reconstructed images
XGenerated = predict(netD,Z);
% Extract and concatenate predictions.
Xrecon = cat(4,Xrecon,extractdata(XGenerated));
end
end
function X = preprocessMiniBatch(Xcell)
% Concatenate.
X = cat(4,Xcell{:});
endI am building a convolutional autoencoder using a custom training loop. When I attempt to reconstruct the images, the network’s output degenerates to guessing the same incorrect value for all inputs. However, training the autoencoder in a single stack with the trainnet function works fine, indicating that the gradient updates are unable to bridge the bottleneck layer in the custom training loop. Unfortunately, I need to use the custom training loop for a different task and am prohibited from using TensorFlow or PyTorch.
What is the syntax to ensure that the encoder is able to update based on the decoder’s reconstruction performance?
%% Functional ‘trainnet’ loop
clear
close all
clc
% Get handwritten digit data
xTrain = digitTrain4DArrayData;
xTest = digitTest4DArrayData;
% Check that all pixel values are min-max scaled
assert(max(xTrain(:)) == 1); assert(min(xTrain(:)) == 0);
assert(max(xTest(:)) == 1); assert(min(xTest(:)) == 0);
imageSize = [28 28 1];
%% Layer definitions
% Latent projection
projectionSize = [7 7 64];
numInputChannels = imageSize(3);
% Decoder
aeLayers = [
imageInputLayer(imageSize)
convolution2dLayer(3,32,Padding="same",Stride=2)
reluLayer
convolution2dLayer(3,64,Padding="same",Stride=2)
reluLayer
transposedConv2dLayer(3,64,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,32,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,numInputChannels,Cropping="same")
sigmoidLayer(Name=’Output’)
];
autoencoder = dlnetwork(aeLayers);
%% Training Parameters
numEpochs = 150;
miniBatchSize = 25;
learnRate = 1e-3;
options = trainingOptions("adam", …
InitialLearnRate=learnRate,…
MaxEpochs=30, …
Plots="training-progress", …
TargetDataFormats="SSCB", …
InputDataFormats="SSCB", …
MiniBatchSize=miniBatchSize, …
OutputNetwork="last-iteration", …
Shuffle="every-epoch");
autoencoder = trainnet(dlarray(xTrain, ‘SSCB’),dlarray(xTrain, ‘SSCB’), …
autoencoder, ‘mse’, options);
%% Testing
YTest = predict(autoencoder, dlarray(xTest, ‘SSCB’));
indices = randi(size(xTest, 4), [1, size(xTest, 4)]); % Shuffle YTest & xTest
xTest = xTest(:,:,:,indices); YTest = YTest(:,:,:,indices);
% Display test images
numImages = 64;
figure
subplot(1,2,1)
preds = extractdata(YTest(:,:,:,1:numImages));
I = imtile(preds);
imshow(I)
title("Reconstructed Images")
subplot(1,2,2)
orgs = xTest(:,:,:,1:numImages);
I = imtile(orgs);
imshow(I)
title("Original Images")
%% Nonfunctional Custom Training Loop
clear
close all
clc
% Get handwritten digit data
xTrain = digitTrain4DArrayData;
xTest = digitTest4DArrayData;
% Check that all pixel values are min-max scaled
assert(max(xTrain(:)) == 1); assert(min(xTrain(:)) == 0);
assert(max(xTest(:)) == 1); assert(min(xTest(:)) == 0);
imageSize = [28 28 1];
%% Layer definitions
% Encoder
layersE = [
imageInputLayer(imageSize)
convolution2dLayer(3,32,Padding="same",Stride=2)
reluLayer
convolution2dLayer(3,64,Padding="same",Stride=2)
reluLayer];
% Latent projection
projectionSize = [7 7 64];
numInputChannels = imageSize(3);
% Decoder
layersD = [
imageInputLayer(projectionSize)
transposedConv2dLayer(3,64,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,32,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,numInputChannels,Cropping="same")
sigmoidLayer(Name=’Output’)
];
netE = dlnetwork(layersE);
netD = dlnetwork(layersD);
%% Training Parameters
numEpochs = 150;
miniBatchSize = 25;
learnRate = 1e-3;
% Create training minibatchqueue
dsTrain = arrayDatastore(xTrain,IterationDimension=4);
numOutputs = 1;
mbq = minibatchqueue(dsTrain,numOutputs, …
MiniBatchSize = miniBatchSize, …
MiniBatchFormat="SSCB", …
MiniBatchFcn=@preprocessMiniBatch,…
PartialMiniBatch="return");
%Initialize the parameters for the Adam solver.
trailingAvgE = [];
trailingAvgSqE = [];
trailingAvgD = [];
trailingAvgSqD = [];
%Calculate the total number of iterations for the training progress monitor
numIterationsPerEpoch = ceil(size(xTrain, 4) / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
epoch = 0;
iteration = 0;
%Initialize the training progress monitor.
monitor = trainingProgressMonitor( …
Metrics="TrainingLoss", …
Info=["Epoch", "LearningRate"], …
XLabel="Iteration");
%% Training
while epoch < numEpochs && ~monitor.Stop
epoch = epoch + 1;
% Shuffle data.
shuffle(mbq);
% Loop over mini-batches.
while hasdata(mbq) && ~monitor.Stop
% Assess validation criterion
iteration = iteration + 1;
% Read mini-batch of data.
X = next(mbq);
% Evaluate loss and gradients.
[loss,gradientsE,gradientsD] = dlfeval(@modelLoss,netE,netD,X);
% Update learnable parameters.
[netE,trailingAvgE,trailingAvgSqE] = adamupdate(netE, …
gradientsE,trailingAvgE,trailingAvgSqE,iteration,learnRate);
[netD, trailingAvgD, trailingAvgSqD] = adamupdate(netD, …
gradientsD,trailingAvgD,trailingAvgSqD,iteration,learnRate);
updateInfo(monitor, …
LearningRate=learnRate, …
Epoch=string(epoch) + " of " + string(numEpochs));
recordMetrics(monitor,iteration, …
TrainingLoss=loss);
monitor.Progress = 100*iteration/numIterations;
end
end
%% Testing
dsTest = arrayDatastore(xTest,IterationDimension=4);
numOutputs = 1;
ntest = size(xTest, 4);
indices = randi(ntest,[1,ntest]);
xTest = xTest(:,:,:,indices);% Shuffle test data
mbqTest = minibatchqueue(dsTest,numOutputs, …
MiniBatchSize = miniBatchSize, …
MiniBatchFcn=@preprocessMiniBatch, …
MiniBatchFormat="SSCB");
YTest = modelPredictions(netE,netD,mbqTest);
% Display test images
numImages = 64;
figure
subplot(1,2,1)
preds = YTest(:,:,:,1:numImages);
I = imtile(preds);
imshow(I)
title("Reconstructed Images")
subplot(1,2,2)
orgs = xTest(:,:,:,1:numImages);
I = imtile(orgs);
imshow(I)
title("Original Images")
%% Functions
function [loss,gradientsE,gradientsD] = modelLoss(netE,netD,X)
% Forward through encoder.
Z = forward(netE,X);
% Forward through decoder.
Xrecon = forward(netD,Z);
% Calculate loss and gradients.
loss = regularizedLoss(Xrecon,X);
[gradientsE,gradientsD] = dlgradient(loss,netE.Learnables,netD.Learnables);
end
function loss = regularizedLoss(Xrecon,X)
% Image Reconstruction loss.
reconstructionLoss = l2loss(Xrecon, X, ‘NormalizationFactor’,’all-elements’);
% Combined loss.
loss = reconstructionLoss;
end
function Xrecon = modelPredictions(netE,netD,mbq)
Xrecon = [];
shuffle(mbq)
% Loop over mini-batches.
while hasdata(mbq)
X = next(mbq);
% Pass through encoder
Z = predict(netE,X);
% Pass through decoder to get reconstructed images
XGenerated = predict(netD,Z);
% Extract and concatenate predictions.
Xrecon = cat(4,Xrecon,extractdata(XGenerated));
end
end
function X = preprocessMiniBatch(Xcell)
% Concatenate.
X = cat(4,Xcell{:});
end I am building a convolutional autoencoder using a custom training loop. When I attempt to reconstruct the images, the network’s output degenerates to guessing the same incorrect value for all inputs. However, training the autoencoder in a single stack with the trainnet function works fine, indicating that the gradient updates are unable to bridge the bottleneck layer in the custom training loop. Unfortunately, I need to use the custom training loop for a different task and am prohibited from using TensorFlow or PyTorch.
What is the syntax to ensure that the encoder is able to update based on the decoder’s reconstruction performance?
%% Functional ‘trainnet’ loop
clear
close all
clc
% Get handwritten digit data
xTrain = digitTrain4DArrayData;
xTest = digitTest4DArrayData;
% Check that all pixel values are min-max scaled
assert(max(xTrain(:)) == 1); assert(min(xTrain(:)) == 0);
assert(max(xTest(:)) == 1); assert(min(xTest(:)) == 0);
imageSize = [28 28 1];
%% Layer definitions
% Latent projection
projectionSize = [7 7 64];
numInputChannels = imageSize(3);
% Decoder
aeLayers = [
imageInputLayer(imageSize)
convolution2dLayer(3,32,Padding="same",Stride=2)
reluLayer
convolution2dLayer(3,64,Padding="same",Stride=2)
reluLayer
transposedConv2dLayer(3,64,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,32,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,numInputChannels,Cropping="same")
sigmoidLayer(Name=’Output’)
];
autoencoder = dlnetwork(aeLayers);
%% Training Parameters
numEpochs = 150;
miniBatchSize = 25;
learnRate = 1e-3;
options = trainingOptions("adam", …
InitialLearnRate=learnRate,…
MaxEpochs=30, …
Plots="training-progress", …
TargetDataFormats="SSCB", …
InputDataFormats="SSCB", …
MiniBatchSize=miniBatchSize, …
OutputNetwork="last-iteration", …
Shuffle="every-epoch");
autoencoder = trainnet(dlarray(xTrain, ‘SSCB’),dlarray(xTrain, ‘SSCB’), …
autoencoder, ‘mse’, options);
%% Testing
YTest = predict(autoencoder, dlarray(xTest, ‘SSCB’));
indices = randi(size(xTest, 4), [1, size(xTest, 4)]); % Shuffle YTest & xTest
xTest = xTest(:,:,:,indices); YTest = YTest(:,:,:,indices);
% Display test images
numImages = 64;
figure
subplot(1,2,1)
preds = extractdata(YTest(:,:,:,1:numImages));
I = imtile(preds);
imshow(I)
title("Reconstructed Images")
subplot(1,2,2)
orgs = xTest(:,:,:,1:numImages);
I = imtile(orgs);
imshow(I)
title("Original Images")
%% Nonfunctional Custom Training Loop
clear
close all
clc
% Get handwritten digit data
xTrain = digitTrain4DArrayData;
xTest = digitTest4DArrayData;
% Check that all pixel values are min-max scaled
assert(max(xTrain(:)) == 1); assert(min(xTrain(:)) == 0);
assert(max(xTest(:)) == 1); assert(min(xTest(:)) == 0);
imageSize = [28 28 1];
%% Layer definitions
% Encoder
layersE = [
imageInputLayer(imageSize)
convolution2dLayer(3,32,Padding="same",Stride=2)
reluLayer
convolution2dLayer(3,64,Padding="same",Stride=2)
reluLayer];
% Latent projection
projectionSize = [7 7 64];
numInputChannels = imageSize(3);
% Decoder
layersD = [
imageInputLayer(projectionSize)
transposedConv2dLayer(3,64,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,32,Cropping="same",Stride=2)
reluLayer
transposedConv2dLayer(3,numInputChannels,Cropping="same")
sigmoidLayer(Name=’Output’)
];
netE = dlnetwork(layersE);
netD = dlnetwork(layersD);
%% Training Parameters
numEpochs = 150;
miniBatchSize = 25;
learnRate = 1e-3;
% Create training minibatchqueue
dsTrain = arrayDatastore(xTrain,IterationDimension=4);
numOutputs = 1;
mbq = minibatchqueue(dsTrain,numOutputs, …
MiniBatchSize = miniBatchSize, …
MiniBatchFormat="SSCB", …
MiniBatchFcn=@preprocessMiniBatch,…
PartialMiniBatch="return");
%Initialize the parameters for the Adam solver.
trailingAvgE = [];
trailingAvgSqE = [];
trailingAvgD = [];
trailingAvgSqD = [];
%Calculate the total number of iterations for the training progress monitor
numIterationsPerEpoch = ceil(size(xTrain, 4) / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
epoch = 0;
iteration = 0;
%Initialize the training progress monitor.
monitor = trainingProgressMonitor( …
Metrics="TrainingLoss", …
Info=["Epoch", "LearningRate"], …
XLabel="Iteration");
%% Training
while epoch < numEpochs && ~monitor.Stop
epoch = epoch + 1;
% Shuffle data.
shuffle(mbq);
% Loop over mini-batches.
while hasdata(mbq) && ~monitor.Stop
% Assess validation criterion
iteration = iteration + 1;
% Read mini-batch of data.
X = next(mbq);
% Evaluate loss and gradients.
[loss,gradientsE,gradientsD] = dlfeval(@modelLoss,netE,netD,X);
% Update learnable parameters.
[netE,trailingAvgE,trailingAvgSqE] = adamupdate(netE, …
gradientsE,trailingAvgE,trailingAvgSqE,iteration,learnRate);
[netD, trailingAvgD, trailingAvgSqD] = adamupdate(netD, …
gradientsD,trailingAvgD,trailingAvgSqD,iteration,learnRate);
updateInfo(monitor, …
LearningRate=learnRate, …
Epoch=string(epoch) + " of " + string(numEpochs));
recordMetrics(monitor,iteration, …
TrainingLoss=loss);
monitor.Progress = 100*iteration/numIterations;
end
end
%% Testing
dsTest = arrayDatastore(xTest,IterationDimension=4);
numOutputs = 1;
ntest = size(xTest, 4);
indices = randi(ntest,[1,ntest]);
xTest = xTest(:,:,:,indices);% Shuffle test data
mbqTest = minibatchqueue(dsTest,numOutputs, …
MiniBatchSize = miniBatchSize, …
MiniBatchFcn=@preprocessMiniBatch, …
MiniBatchFormat="SSCB");
YTest = modelPredictions(netE,netD,mbqTest);
% Display test images
numImages = 64;
figure
subplot(1,2,1)
preds = YTest(:,:,:,1:numImages);
I = imtile(preds);
imshow(I)
title("Reconstructed Images")
subplot(1,2,2)
orgs = xTest(:,:,:,1:numImages);
I = imtile(orgs);
imshow(I)
title("Original Images")
%% Functions
function [loss,gradientsE,gradientsD] = modelLoss(netE,netD,X)
% Forward through encoder.
Z = forward(netE,X);
% Forward through decoder.
Xrecon = forward(netD,Z);
% Calculate loss and gradients.
loss = regularizedLoss(Xrecon,X);
[gradientsE,gradientsD] = dlgradient(loss,netE.Learnables,netD.Learnables);
end
function loss = regularizedLoss(Xrecon,X)
% Image Reconstruction loss.
reconstructionLoss = l2loss(Xrecon, X, ‘NormalizationFactor’,’all-elements’);
% Combined loss.
loss = reconstructionLoss;
end
function Xrecon = modelPredictions(netE,netD,mbq)
Xrecon = [];
shuffle(mbq)
% Loop over mini-batches.
while hasdata(mbq)
X = next(mbq);
% Pass through encoder
Z = predict(netE,X);
% Pass through decoder to get reconstructed images
XGenerated = predict(netD,Z);
% Extract and concatenate predictions.
Xrecon = cat(4,Xrecon,extractdata(XGenerated));
end
end
function X = preprocessMiniBatch(Xcell)
% Concatenate.
X = cat(4,Xcell{:});
end deep learning, autoencoder, regularization, initialization, custom loops MATLAB Answers — New Questions