How can I make custom RNN training as efficient / fast as standard LSTM / GRU training?
Hi everyone!
Context: I’m working on a research project to perform dynamical system identification using a custom Recurrent Neural Network (RNN), specifically a Minimal Gated Unit (MGU) as described in Minimal Gated Unit for Recurrent Neural Networks.
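For reference, the recurrence implemented in the `predict` method below, matching the equations in its comments, is the following, for input $u_k$ and previous hidden state $h_{k-1}$:

$$
\begin{aligned}
f_k &= \sigma\left(W_f u_k + b_f + R_f h_{k-1}\right) \\
\tilde{h}_k &= \tanh\left(W_{\tilde h} u_k + b_{\tilde h} + R_{\tilde h}\,(f_k \odot h_{k-1})\right) \\
h_k &= (1 - f_k) \odot h_{k-1} + f_k \odot \tilde{h}_k
\end{aligned}
$$

Here $W_f, W_{\tilde h}$ are the two row blocks of `InputWeights` (indices `idxf` and `idxh`), $R_f, R_{\tilde h}$ are the corresponding blocks of `RecurrentWeights`, and $b_f, b_{\tilde h}$ are the blocks of `Bias`. Because $\tilde{h}_k$ depends on $f_k$, the two recurrent products have to be computed one after the other within each time step.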
I’ve implemented my MGU layer and a custom training loop following MATLAB’s tutorials (e.g., "Define Custom Recurrent Deep Learning Layer", "Train Network Using Custom Training Loop", …).
Problem: Training with my custom MGU layer is significantly slower than with MATLAB’s built-in `lstmLayer` or `gruLayer`. As noted in the discussion "Why peepholeLSTMLayer implemented in a tutorial is much slower than built-in lstmlayer?", custom RNN implementations can be slow, and my experience matches this.
I’ve analyzed the `lstmLayer` and `gruLayer` implementations in the Deep Learning Toolbox. Their performance advantage appears to come from highly optimized, likely C-based, built-in functions that are not accessible when developing custom layers.
I’ve tried acceleration techniques within MATLAB, but they haven’t significantly reduced the custom layer’s bottleneck. I also briefly explored writing custom C/MEX code, but I lack the expertise to implement it effectively (my attempts neither compiled nor worked).
Question: Is this performance gap an inherent limitation of implementing custom RNN layers in MATLAB, or are there advanced strategies or techniques that can significantly accelerate the training of custom recurrent layers (like the MGU) to bring it closer to that of the built-in `lstmLayer` or `gruLayer`? I’m aiming for at least comparable training speed.
Additional Context
- I am using MATLAB R2024b.
- I train on a GPU (switching to the CPU makes no difference).
- For comparison, I am also implementing custom LSTM and GRU layers to benchmark their speed against my custom MGU layer and the built-in layers.
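As a toolbox-free baseline for such benchmarks, here is a minimal sketch of the MGU forward recurrence on plain numeric arrays. It mirrors the `predict` method of the layer below (same weight layout with forget-gate rows first and candidate rows second, and a forget-gate bias of 1 as in `initializeBias`). The sizes and the variable names (`W`, `R`, `b`, `X`, `H`, `sigm`) are illustrative assumptions. The sketch skips `dlarray` and automatic differentiation entirely and slices the recurrent weights once before the loop rather than at every time step, so it is not a drop-in replacement for the layer.

```matlab
% Minimal MGU forward pass on plain numeric arrays (no Deep Learning Toolbox).
% All sizes are hypothetical, chosen only for illustration.
numChannels    = 3;    % input features
numHiddenUnits = 4;    % hidden units
miniBatchSize  = 2;
numTimeSteps   = 5;

W = 0.1*randn(2*numHiddenUnits, numChannels, 'single');      % [W_f; W_htilde]
R = 0.1*randn(2*numHiddenUnits, numHiddenUnits, 'single');   % [R_f; R_htilde]
b = zeros(2*numHiddenUnits, 1, 'single');
b(1:numHiddenUnits) = 1;                                      % forget-gate bias = 1
X = randn(numChannels, miniBatchSize, numTimeSteps, 'single'); % "CBT" layout

idxf = 1:numHiddenUnits;                    % forget-gate rows
idxh = numHiddenUnits+1:2*numHiddenUnits;   % candidate-state rows

% Input projection for all time steps at once (same result as pagemtimes(W,X) + b):
% (2H x C) * (C x B*T), reshaped back to 2H x B x T, bias added by implicit expansion.
WX = reshape(W*reshape(X, numChannels, []), ...
    2*numHiddenUnits, miniBatchSize, numTimeSteps) + b;

% Slice the recurrent weights once, outside the time loop
Rf = R(idxf,:);
Rh = R(idxh,:);

sigm = @(z) 1 ./ (1 + exp(-z));                         % logistic sigmoid
h = zeros(numHiddenUnits, miniBatchSize, 'single');     % initial hidden state
H = zeros(numHiddenUnits, miniBatchSize, numTimeSteps, 'single');
for t = 1:numTimeSteps
    f      = sigm(WX(idxf,:,t) + Rf*h);                  % forget gate f_k
    htilde = tanh(WX(idxh,:,t) + Rh*(f .* h));           % candidate state
    h      = (1 - f) .* h + f .* htilde;                 % hidden-state update h_k
    H(:,:,t) = h;                                        % store output for step t
end
```

Timing this loop with `tic`/`toc` at the hidden size, batch size, and sequence length used in training gives a rough figure for the forward recurrence alone. Comparing it with the custom layer's forward time at the same sizes is one way to see how much of the cost comes from the per-time-step loop itself.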
Code
Below are the code snippets for my MGU layer and the custom training loop, which might be helpful for understanding the context:
classdef mguLayer < nnet.layer.Layer & nnet.layer.Formattable
%MGULAYER Minimal Gated Unit Layer
properties
% Layer properties.
NumHiddenUnits
OutputMode
end
properties (Learnable)
% Layer learnable parameters.
InputWeights
RecurrentWeights
Bias
end
properties (State)
% Layer state parameters.
HiddenState
end
methods
function layer = mguLayer(numHiddenUnits,args)
%MGULAYER Minimal Gated Unit Layer
% layer = mguLayer(numHiddenUnits)
% creates an MGU layer with the specified number of
% hidden units.
%
% layer = mguLayer(numHiddenUnits,Name=Value)
% creates an MGU layer and specifies additional
% options using one or more name-value arguments:
%
% Name – Name of the layer, specified as a string.
% The default is "".
%
% OutputMode – Output mode, specified as one of the
% following:
% "sequence" – Output the entire sequence
% of data.
%
% "last" – Output the last time step
% of the data.
% The default is "sequence".
% Parse input arguments.
arguments
numHiddenUnits
args.Name = "";
args.OutputMode = "sequence";
end
layer.NumHiddenUnits = numHiddenUnits;
layer.Name = args.Name;
layer.OutputMode = args.OutputMode;
% Set layer description.
layer.Description = "MGU with " + numHiddenUnits + " hidden units";
end
function layer = initialize(layer,layout)
% layer = initialize(layer,layout) initializes the layer
% learnable and state parameters.
%
% Inputs:
% layer – Layer to initialize.
% layout – Data layout, specified as a
% networkDataLayout object.
%
% Outputs:
% layer – Initialized layer.
numHiddenUnits = layer.NumHiddenUnits;
% Find number of channels.
idx = finddim(layout,"C");
numChannels = layout.Size(idx);
% Initialize input weights.
if isempty(layer.InputWeights)
sz = [2*numHiddenUnits numChannels]; % Forget gate + candidate state
numOut = 2*numHiddenUnits;
numIn = numChannels;
layer.InputWeights = initializeGlorot(sz,numOut,numIn);
end
% Initialize recurrent weights.
if isempty(layer.RecurrentWeights)
sz = [2*numHiddenUnits numHiddenUnits]; % Forget gate + candidate state
layer.RecurrentWeights = initializeOrthogonal(sz);
end
% Initialize bias.
if isempty(layer.Bias)
layer.Bias = initializeBias(numHiddenUnits);
end
% Initialize hidden state.
if isempty(layer.HiddenState)
layer.HiddenState = zeros(numHiddenUnits,1);
end
end
function [Y,hiddenState] = predict(layer,X)
%PREDICT MGU predict function
% [Y,hiddenState] = predict(layer,X) forward
% propagates the data X through the layer and returns the
% layer output Y and the updated hidden state. X
% is a dlarray with format "CBT" and Y is a dlarray with
% format "CB" or "CBT", depending on the layer OutputMode
% property.
% Initialize sequence output.
numHiddenUnits = layer.NumHiddenUnits;
miniBatchSize = size(X,finddim(X,"B"));
numTimeSteps = size(X,finddim(X,"T"));
if layer.OutputMode == "sequence"
Y = zeros(numHiddenUnits,miniBatchSize,numTimeSteps,like=X);
Y = dlarray(Y,"CBT");
end
% Calculate WX + b.
X = stripdims(X);
WX = pagemtimes(layer.InputWeights,X) + layer.Bias; % Input weight * input + bias
% Indices of concatenated weight arrays.
idxf = 1:numHiddenUnits; % Forget gate
idxh = 1+numHiddenUnits:2*numHiddenUnits; % Candidate hidden state
% Initial states.
hiddenState = layer.HiddenState;
% Loop over time steps.
for t = 1:numTimeSteps
% Forget Gate computation
% f_k = sigma( (W_f * u_k + b_f) + R_f * h_{k-1} )
recurrent_term_f = layer.RecurrentWeights(idxf,:) * hiddenState;
ft = sigmoid(WX(idxf,:,t) + recurrent_term_f);
% Candidate hidden state computation
% h_tilde_k = tanh( (W_h_tilde * u_k + b_h_tilde) + R_h_tilde * (f_k .* h_{k-1}) )
modulated_hidden_state_for_candidate = ft .* hiddenState;
recurrent_term_h_tilde = layer.RecurrentWeights(idxh,:) * modulated_hidden_state_for_candidate;
htildet = tanh(WX(idxh,:,t) + recurrent_term_h_tilde);
% Update hidden state
% h_k = (1 - f_k) .* h_{k-1} + f_k .* h_tilde_k
hiddenState = (1 - ft) .* hiddenState + ft .* htildet;
% Update sequence output
if layer.OutputMode == "sequence"
Y(:,:,t) = hiddenState;
end
end
% Last time step output.
if layer.OutputMode == "last"
Y = dlarray(hiddenState,"CB");
end
end
function layer = resetState(layer)
%RESETSTATE Reset layer state
% layer = resetState(layer) resets the state properties of the
% layer.
numHiddenUnits = layer.NumHiddenUnits;
layer.HiddenState = zeros(numHiddenUnits,1);
end
end
end
function weights = initializeGlorot(sz,numOut,numIn,className)
arguments
sz
numOut
numIn
className = 'single'
end
Z = 2*rand(sz,className) - 1;
bound = sqrt(6 / (numIn + numOut));
weights = bound * Z;
weights = dlarray(weights);
end
function parameter = initializeOrthogonal(sz)
Z = randn(sz,'single');
[Q,R] = qr(Z,0);
D = diag(R);
Q = Q * diag(D ./ abs(D));
parameter = dlarray(Q);
end
function bias = initializeBias(numHiddenUnits)
bias = zeros(2*numHiddenUnits,1,'single');
idx = 1:numHiddenUnits;
bias(idx) = 1;
bias = dlarray(bias);
end
% Trains a GRNN network, either LSTM or GRU. Provides a network with specified characteristics, at the minimum validation RMSE.
function [net,info,monitor,net_name] = train_loop(train_dataset, valid_dataset, train_options)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% UNPACK OPTIONS
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
hidden_units = train_options.hidden_units;
dropout_rate = train_options.dropout_rate;
learn_rate = train_options.learn_rate;
max_epochs = train_options.max_epochs;
decay_rate = train_options.decay_rate;
decay_points = train_options.decay_points;
mini_batch_size = train_options.mini_batch_size;
is_visible = train_options.is_visible;
is_verbose = train_options.is_verbose;
early_stop_weights = train_options.early_stop_weights;
toy_system = train_options.toy_system;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% SETTING PARAMETERS
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% General parameters to set, CAN BE CHANGED
validation_frequency = max(5e-3, 1 / max_epochs); % Relative validation checks
window_RMSE = 5; % Smoothing window for RMSE reporting
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% DATA SCRAPING & PREALLOCATIONS
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Scrape and initialize data
num_layers = numel(hidden_units);
num_features = size(train_dataset.inputs_norm, 3);
num_responses = size(train_dataset.outputs_norm, 3);
network_layers = net_maker(hidden_units, num_features, num_responses, dropout_rate);
net_name = generate_net_name(num_layers, toy_system, hidden_units, learn_rate);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% PRE-INITIALIZE LOOP
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Turn dataset as in cells (legacy transformation)
x_train = tensor_to_cell(train_dataset.inputs_norm);
y_train = tensor_to_cell(train_dataset.outputs_norm);
x_valid = tensor_to_cell(valid_dataset.inputs_norm);
y_valid = tensor_to_cell(valid_dataset.outputs_norm);
n_train_seq = size(x_train,2);
% Create network from layers
net = dlnetwork(network_layers);
if canUseGPU, net = dlupdate(@gpuArray, net); end
% To load another net in case for parameter initialization
% loaded_data = load("net_resultstestmgu_tank_2L_6_4_1e-02");
% net_data = loaded_data.net_data;
% net = net_data.net;
% Get validation data in dlarray
if ~canUseGPU
x_valid = cellfun(@dlarray, x_valid, 'UniformOutput', false);
y_valid = cellfun(@dlarray, y_valid, 'UniformOutput', false);
else
x_valid = cellfun(@gpuArray, x_valid, 'UniformOutput', false);
y_valid = cellfun(@gpuArray, y_valid, 'UniformOutput', false);
end
% Select custom loss function
custom_loss = @clf;
% Initialize monitor data
num_iterations_per_epoch = floor(n_train_seq / mini_batch_size); % Whole mini-batches per epoch (zeros() below needs an integer size)
num_iterations = num_iterations_per_epoch * max_epochs;
validation_steps = ceil(validation_frequency * num_iterations);
monitor = generate_monitor(is_visible);
monitor.Progress = 0;
% Initialize iteration and epoch counters
iteration = 0;
epoch = 0;
window_data = zeros(window_RMSE,1);
% Initialize data for solver
average_grad = [];
average_sqgrad = [];
% Select best execution environment, nice if GPU available
executionEnvironment = "auto";
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
updateInfo(monitor, ExecutionEnvironment="GPU");
else
updateInfo(monitor, ExecutionEnvironment="CPU");
end
% Preallocate arrays to store metrics and iterations
max_iterations = max_epochs * num_iterations_per_epoch;
rmse_train = zeros(max_iterations, 1);
rmse_train_smooth = zeros(max_iterations, 1);
rmse_validation = zeros(max_iterations, 1);
iterations_store = zeros(max_iterations, 1);
% Initialize min validation RMSE and corresponding network, to later save
% min val network in the end
min_early_stop_condition = inf;
min_cond_net = [];
min_info = struct('training_rmse', [], 'validation_rmse', []);
min_monitor_data = struct(...
'rmse_train', [], ...
'rmse_train_smooth', [], ...
'rmse_validation', [], ...
'iterations_store', []);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% CUSTOM LOOP
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
monitor_data_idx = 0; % Index for storing data in monitor_data
while epoch < max_epochs && ~monitor.Stop
epoch = epoch + 1;
if is_verbose, fprintf('%d. ', epoch); end
% Check if the current epoch matches any defined decay points
if ismember(epoch, ceil(decay_points * max_epochs))
learn_rate = learn_rate * decay_rate;
end
% Shuffle training sequences
sequence_shuffle_indices = randperm(n_train_seq);
x_train_sequences_shuffled_this_epoch = x_train(sequence_shuffle_indices);
y_train_sequences_shuffled_this_epoch = y_train(sequence_shuffle_indices);
% Create mini batches
[x_batches_cells_epoch, y_batches_cells_epoch] = create_mini_batches(x_train_sequences_shuffled_this_epoch, y_train_sequences_shuffled_this_epoch, mini_batch_size);
[dlx_batched_this_epoch, dly_batched_this_epoch] = preprocess_mini_batches(x_batches_cells_epoch, y_batches_cells_epoch);
% Determine the number of batches created in this epoch.
% This should ideally match num_iterations_per_epoch if batching is consistent.
current_num_batches_in_epoch = size(dlx_batched_this_epoch, 1);
% Or use num_iterations_per_epoch if it’s reliably fixed
% Shuffle the order of the batches
batch_order_in_epoch_indices = randperm(current_num_batches_in_epoch);
x_train_batch = dlx_batched_this_epoch(batch_order_in_epoch_indices);
y_train_batch = dly_batched_this_epoch(batch_order_in_epoch_indices);
for batch = 1:num_iterations_per_epoch
if monitor.Stop, break; end
iteration = iteration + 1;
monitor_data_idx = monitor_data_idx + 1;
% Evaluate loss function
[MSE, gradients, ~] = dlfeval(custom_loss, net, x_train_batch{batch}, y_train_batch{batch}, is_verbose);
% Update network parameters based on loss
[net, average_grad, average_sqgrad] = adamupdate(net, gradients,average_grad, average_sqgrad, iteration, learn_rate);
current_rmse_train = double(sqrt(MSE));
rmse_train(monitor_data_idx) = current_rmse_train;
if rem(iteration, window_RMSE) == 0 || iteration == 1
if iteration == 1
smooth_rmse = current_rmse_train;
else
smooth_rmse = double(mean(rmse_train(iteration-window_RMSE+1:iteration)));
end
recordMetrics(monitor, iteration, TrainingRMSE_smooth=smooth_rmse);
rmse_train_smooth(monitor_data_idx) = smooth_rmse;
end
if rem(iteration, validation_steps) == 0 || iteration == 1
validation_rmse = rmse_validator(net, x_valid, y_valid);
recordMetrics(monitor, iteration, ValidationRMSE=double(validation_rmse));
rmse_validation(monitor_data_idx) = double(validation_rmse);
% Save min_val_net if new early stop condition is minimal
early_stop_condition = validation_rmse * early_stop_weights(1) + ...
MSE * early_stop_weights(2);
if early_stop_condition < min_early_stop_condition
min_early_stop_condition = early_stop_condition;
min_cond_net = net;
% Save the relevant information for min_val_net
min_info.training_rmse = sqrt(MSE);
min_info.validation_rmse = validation_rmse;
min_monitor_data.rmse_train = rmse_train(1:monitor_data_idx);
min_monitor_data.rmse_train_smooth = rmse_train_smooth(1:monitor_data_idx);
min_monitor_data.rmse_validation = rmse_validation(1:monitor_data_idx);
min_monitor_data.iterations_store = iterations_store(1:monitor_data_idx);
% Save the iteration when min_val_net is triggered
recorded_monitor.min_val_iteration = iteration;
end
end
iterations_store(monitor_data_idx) = iteration;
recordMetrics(monitor, iteration, TrainingRMSE=current_rmse_train);
updateInfo(monitor, Epoch=[num2str(epoch) ' of ' num2str(max_epochs)], Iteration=[num2str(iteration) ' of ' num2str(num_iterations)], LearnRate=learn_rate);
monitor.Progress = 100 * iteration/num_iterations;
end
end
% Save all data in monitor
recorded_monitor.rmse_train = rmse_train; % Save RMSE for training
recorded_monitor.rmse_train_smooth = rmse_train_smooth; % Save smoothed training RMSE
recorded_monitor.rmse_validation = rmse_validation; % Save validation RMSE
recorded_monitor.iterations_store = iterations_store; % Save the iterations
net = min_cond_net;
info = min_info;
monitor = recorded_monitor;
endHi everyone!
Context: I’m working on a research project to perform dynamical system identification using a custom Recurrent Neural Network (RNN), specifically a Minimal Gated Unit (MGU) as described in Minimal Gated Unit for Recurrent Neural Networks.
I’ve implemented my MGU layer and a custom training loop following MATLAB’s tutorials (e.g.,Define Custom Recurrent Deep Learning Layer, Train Network Using Custom Training Loop, …).
Problem: I’m encountering significantly slower training times with my custom MGU layer compared to using MATLAB’s built-in `lstmLayer` or `gruLayer`. As noted in discussion Why peepholeLSTMLayer implemented in a tutorial is much slower than built-in lstmlayer?, custom RNN implementations can be slow, and my experience aligns with this. My training has much slower training if used a MGU layer rather than for instance a built-in `gruLayer` or `lstmLayer`.
I’ve thoroughly analyzed the `lstmLayer` and `gruLayer` implementations within the Deep Learning Toolbox. It appears their performance advantage stems from highly optimized, likely C-based, built-in functions that are not accessible for custom layer development.
I’ve attempted to use acceleration techniques within MATLAB, but these haven’t yielded significant improvements for the custom layer’s speed bottleneck. I briefly explored creating custom C/MEX scripts but lack the expertise to implement this effectively (i.e., did not even compile or work in any manner).
Question: Is the observed performance disparity an inherent limitation when implementing custom RNN layers in MATLAB, or are there advanced strategies or techniques to significantly accelerate the training of custom recurrent layers (like MGU) to achieve performance closer to that of built-in `lstmLayer` or `gruLayer`? I’m aiming to at least reach comparable training velocity.
Additional Context
I am using MATLAB R2024b
I use a GPU for training (but even switching to CPU does no change)
For comparison, I am also implementing custom LSTM and GRU layers to benchmark their relative speeds against my custom MGU and the built-in layers.
Code
Below are the code snippets for my MGU layer and the custom training loop, which might be helpful for understanding the context:
classdef mguLayer < nnet.layer.Layer & nnet.layer.Formattable
%MGULAYER Minimal Gated Unit Layer
properties
% Layer properties.
NumHiddenUnits
OutputMode
end
properties (Learnable)
% Layer learnable parameters.
InputWeights
RecurrentWeights
Bias
end
properties (State)
% Layer state parameters.
HiddenState
end
methods
function layer = mguLayer(numHiddenUnits,args)
%MGULAYER Minimal Gated Unit Layer
% layer = MGULayer(numHiddenUnits)
% creates a MGU layer with the specified number of
% hidden units.
%
% layer = MGULayer(numHiddenUnits,Name=Value)
% creates a MGU layer and specifies additional
% options using one or more name-value arguments:
%
% Name – Name of the layer, specified as a string.
% The default is "".
%
% OutputMode – Output mode, specified as one of the
% following:
% "sequence" – Output the entire sequence
% of data.
%
% "last" – Output the last time step
% of the data.
% The default is "sequence".
% Parse input arguments.
arguments
numHiddenUnits
args.Name = "";
args.OutputMode = "sequence";
end
layer.NumHiddenUnits = numHiddenUnits;
layer.Name = args.Name;
layer.OutputMode = args.OutputMode;
% Set layer description.
layer.Description = "MGU with " + numHiddenUnits + " hidden units";
end
function layer = initialize(layer,layout)
% layer = initialize(layer,layout) initializes the layer
% learnable and state parameters.
%
% Inputs:
% layer – Layer to initialize.
% layout – Data layout, specified as a
% networkDataLayout object.
%
% Outputs:
% layer – Initialized layer.
numHiddenUnits = layer.NumHiddenUnits;
% Find number of channels.
idx = finddim(layout,"C");
numChannels = layout.Size(idx);
% Initialize input weights.
if isempty(layer.InputWeights)
sz = [2*numHiddenUnits numChannels]; % Only 2 gates
numOut = 2*numHiddenUnits;
numIn = numChannels;
layer.InputWeights = initializeGlorot(sz,numOut,numIn);
end
% Initialize recurrent weights.
if isempty(layer.RecurrentWeights)
sz = [2*numHiddenUnits numHiddenUnits]; % Only 2 gates
layer.RecurrentWeights = initializeOrthogonal(sz);
end
% Initialize bias.
if isempty(layer.Bias)
layer.Bias = initializeBias(numHiddenUnits);
end
% Initialize hidden state.
if isempty(layer.HiddenState)
layer.HiddenState = zeros(numHiddenUnits,1);
end
end
function [Y,hiddenState] = predict(layer,X)
%PREDICT MGU predict function
% [Y,hiddenState] = predict(layer,X) forward
% propagates the data X through the layer and returns the
% layer output Y and the updated hidden state. X
% is a dlarray with format "CBT" and Y is a dlarray with
% format "CB" or "CBT", depending on the layer OutputMode
% property.
% Initialize sequence output.
numHiddenUnits = layer.NumHiddenUnits;
miniBatchSize = size(X,finddim(X,"B"));
numTimeSteps = size(X,finddim(X,"T"));
if layer.OutputMode == "sequence"
Y = zeros(numHiddenUnits,miniBatchSize,numTimeSteps,like=X);
Y = dlarray(Y,"CBT");
end
% Calculate WX + b.
X = stripdims(X);
WX = pagemtimes(layer.InputWeights,X) + layer.Bias; % Input weight * input + bias
% Indices of concatenated weight arrays.
idxf = 1:numHiddenUnits; % Forget gate
idxh = 1+numHiddenUnits:2*numHiddenUnits; % Candidate hidden state
% Initial states.
hiddenState = layer.HiddenState;
% Loop over time steps.
for t = 1:numTimeSteps
% Forget Gate computation
% f_k = sigma( (W_f * u_k + b_f) + R_f * h_{k-1} )
recurrent_term_f = layer.RecurrentWeights(idxf,:) * hiddenState;
ft = sigmoid(WX(idxf,:,t) + recurrent_term_f);
% Candidate hidden state computation
% h_tilde_k = tanh( (W_h_tilde * u_k + b_h_tilde) + R_h_tilde * (f_k .* h_{k-1}) )
modulated_hidden_state_for_candidate = ft .* hiddenState;
recurrent_term_h_tilde = layer.RecurrentWeights(idxh,:) * modulated_hidden_state_for_candidate;
htildet = tanh(WX(idxh,:,t) + recurrent_term_h_tilde);
% Update hidden state
% h_k = (1 – f_k) .* h_{k-1} + f_k .* h_tilde_k
hiddenState = (1 – ft) .* hiddenState + ft .* htildet;
% Update sequence output
if layer.OutputMode == "sequence"
Y(:,:,t) = hiddenState;
end
end
% Last time step output.
if layer.OutputMode == "last"
Y = dlarray(hiddenState,"CB");
end
end
function layer = resetState(layer)
%RESETSTATE Reset layer state
% layer = resetState(layer) resets the state properties of the
% layer.
numHiddenUnits = layer.NumHiddenUnits;
layer.HiddenState = zeros(numHiddenUnits,1);
end
end
end
function weights = initializeGlorot(sz,numOut,numIn,className)
arguments
sz
numOut
numIn
className = ‘single’
end
Z = 2*rand(sz,className) – 1;
bound = sqrt(6 / (numIn + numOut));
weights = bound * Z;
weights = dlarray(weights);
end
function parameter = initializeOrthogonal(sz)
Z = randn(sz,’single’);
[Q,R] = qr(Z,0);
D = diag(R);
Q = Q * diag(D ./ abs(D));
parameter = dlarray(Q);
end
function bias = initializeBias(numHiddenUnits)
bias = zeros(2*numHiddenUnits,1,’single’);
idx = 1:numHiddenUnits;
bias(idx) = 1;
bias = dlarray(bias);
end
% Trains a GRNN network, either LSTM or GRU. Provides a network with specified characteristics, at the minimum validation RMSE.
function [net,info,monitor,net_name] = train_loop(train_dataset, valid_dataset, train_options)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% UNPACK OPTIONS
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
hidden_units = train_options.hidden_units;
dropout_rate = train_options.dropout_rate;
learn_rate = train_options.learn_rate;
max_epochs = train_options.max_epochs;
decay_rate = train_options.decay_rate;
decay_points = train_options.decay_points;
mini_batch_size = train_options.mini_batch_size;
is_visible = train_options.is_visible;
is_verbose = train_options.is_verbose;
early_stop_weights = train_options.early_stop_weights;
toy_system = train_options.toy_system;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% SETTING PARAMETERS
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% General parameters to set, CAN BE CHANGED
validation_frequency = max(5e-3, 1 / max_epochs); % Relative validation checks
window_RMSE = 5; % Smoothing window for RMSE reporting
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% DATA SCRAPING & PREALLOCATIONS
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Scrape and initialize data
num_layers = numel(hidden_units);
num_features = size(train_dataset.inputs_norm, 3);
num_responses = size(train_dataset.outputs_norm, 3);
network_layers = net_maker(hidden_units, num_features, num_responses, dropout_rate);
net_name = generate_net_name(num_layers, toy_system, hidden_units, learn_rate);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% PRE-INITIALIZE LOOP
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Turn dataset as in cells (legacy transformation)
x_train = tensor_to_cell(train_dataset.inputs_norm);
y_train = tensor_to_cell(train_dataset.outputs_norm);
x_valid = tensor_to_cell(valid_dataset.inputs_norm);
y_valid = tensor_to_cell(valid_dataset.outputs_norm);
n_train_seq = size(x_train,2);
% Create network from layers
net = dlnetwork(network_layers);
if canUseGPU, net = dlupdate(@gpuArray, net); end
% To load another net in case for parameter initialization
% loaded_data = load("net_resultstestmgu_tank_2L_6_4_1e-02");
% net_data = loaded_data.net_data;
% net = net_data.net;
% Get validation data in dlarray
if ~canUseGPU
x_valid = cellfun(@dlarray, x_valid, ‘UniformOutput’, false);
y_valid = cellfun(@dlarray, y_valid, ‘UniformOutput’, false);
else
x_valid = cellfun(@gpuArray, x_valid, ‘UniformOutput’, false);
y_valid = cellfun(@gpuArray, y_valid, ‘UniformOutput’, false);
end
% Select custom loss function
custom_loss = @clf;
% Initialize monitor data
num_iterations_per_epoch = n_train_seq / mini_batch_size;
num_iterations = num_iterations_per_epoch * max_epochs;
validation_steps = ceil(validation_frequency * num_iterations);
monitor = generate_monitor(is_visible);
monitor.Progress = 0;
% Initialize iteration and epoch counters
iteration = 0;
epoch = 0;
window_data = zeros(window_RMSE,1);
% Initialize data for solver
average_grad = [];
average_sqgrad = [];
% Select best execution environment, nice if GPU available
executionEnvironment = "auto";
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
updateInfo(monitor, ExecutionEnvironment="GPU");
else
updateInfo(monitor, ExecutionEnvironment="CPU");
end
% Preallocate arrays to store metrics and iterations
max_iterations = max_epochs * num_iterations_per_epoch;
rmse_train = zeros(max_iterations, 1);
rmse_train_smooth = zeros(max_iterations, 1);
rmse_validation = zeros(max_iterations, 1);
iterations_store = zeros(max_iterations, 1);
% Initialize min validation RMSE and corresponding network, to later save
% min val network in the end
min_early_stop_condition = inf;
min_cond_net = [];
min_info = struct(‘training_rmse’, [], ‘validation_rmse’, []);
min_monitor_data = struct(…
‘rmse_train’, [], …
‘rmse_train_smooth’, [], …
‘rmse_validation’, [], …
‘iterations_store’, []);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% CUSTOM LOOP
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
monitor_data_idx = 0; % Index for storing data in monitor_data
while epoch < max_epochs && ~monitor.Stop
epoch = epoch + 1;
if is_verbose, fprintf(‘%d. ‘, epoch); end
% Check if the current epoch matches any defined decay points
if ismember(epoch, ceil(decay_points * max_epochs))
learn_rate = learn_rate * decay_rate;
end
% Shuffle training sequences
sequence_shuffle_indices = randperm(n_train_seq);
x_train_sequences_shuffled_this_epoch = x_train(sequence_shuffle_indices);
y_train_sequences_shuffled_this_epoch = y_train(sequence_shuffle_indices);
% Create mini batches
[x_batches_cells_epoch, y_batches_cells_epoch] = create_mini_batches(x_train_sequences_shuffled_this_epoch, y_train_sequences_shuffled_this_epoch, mini_batch_size);
[dlx_batched_this_epoch, dly_batched_this_epoch] = preprocess_mini_batches(x_batches_cells_epoch, y_batches_cells_epoch);
% Determine the number of batches created in this epoch.
% This should ideally match num_iterations_per_epoch if batching is consistent.
current_num_batches_in_epoch = size(dlx_batched_this_epoch, 1);
% Or use num_iterations_per_epoch if it’s reliably fixed
% Shuffle the order of the batches
batch_order_in_epoch_indices = randperm(current_num_batches_in_epoch);
x_train_batch = dlx_batched_this_epoch(batch_order_in_epoch_indices);
y_train_batch = dly_batched_this_epoch(batch_order_in_epoch_indices);
for batch = 1:num_iterations_per_epoch
if monitor.Stop, break; end
iteration = iteration + 1;
monitor_data_idx = monitor_data_idx + 1;
% Evaluate loss function
[MSE, gradients, ~] = dlfeval(custom_loss, net, x_train_batch{batch}, y_train_batch{batch}, is_verbose);
% Update network parameters based on loss
[net, average_grad, average_sqgrad] = adamupdate(net, gradients,average_grad, average_sqgrad, iteration, learn_rate);
current_rmse_train = double(sqrt(MSE));
rmse_train(monitor_data_idx) = current_rmse_train;
if rem(iteration, window_RMSE) == 0 || iteration == 1
if iteration == 1
smooth_rmse = current_rmse_train;
else
smooth_rmse = double(mean(rmse_train(iteration-window_RMSE+1:iteration)));
end
recordMetrics(monitor, iteration, TrainingRMSE_smooth=smooth_rmse);
rmse_train_smooth(monitor_data_idx) = smooth_rmse;
end
if rem(iteration, validation_steps) == 0 || iteration == 1
validation_rmse = rmse_validator(net, x_valid, y_valid);
recordMetrics(monitor, iteration, ValidationRMSE=double(validation_rmse));
rmse_validation(monitor_data_idx) = double(validation_rmse);
% Save min_val_net if new early stop condition is minimal
early_stop_condition = validation_rmse * early_stop_weights(1) + …
MSE * early_stop_weights(2);
if early_stop_condition < min_early_stop_condition
min_early_stop_condition = early_stop_condition;
min_cond_net = net;
% Save the relevant information for min_val_net
min_info.training_rmse = sqrt(MSE);
min_info.validation_rmse = validation_rmse;
min_monitor_data.rmse_train = rmse_train(1:monitor_data_idx);
min_monitor_data.rmse_train_smooth = rmse_train_smooth(1:monitor_data_idx);
min_monitor_data.rmse_validation = rmse_validation(1:monitor_data_idx);
min_monitor_data.iterations_store = iterations_store(1:monitor_data_idx);
% Save the iteration when min_val_net is triggered
recorded_monitor.min_val_iteration = iteration;
end
end
iterations_store(monitor_data_idx) = iteration;
recordMetrics(monitor, iteration, TrainingRMSE=current_rmse_train);
updateInfo(monitor, Epoch=[num2str(epoch) ‘ of ‘ num2str(max_epochs)], Iteration=[num2str(iteration) ‘ of ‘ num2str(num_iterations)], LearnRate=learn_rate);
monitor.Progress = 100 * iteration/num_iterations;
end
end
% Save all data in monitor
recorded_monitor.rmse_train = rmse_train; % Save RMSE for training
recorded_monitor.rmse_train_smooth = rmse_train_smooth; % Save smoothed training RMSE
recorded_monitor.rmse_validation = rmse_validation; % Save validation RMSE
recorded_monitor.iterations_store = iterations_store; % Save the iterations
net = min_cond_net;
info = min_info;
monitor = recorded_monitor;
end
For comparison, I am also implementing custom LSTM and GRU layers to benchmark their relative speeds against my custom MGU and the built-in layers.
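A minimal timing harness for that kind of comparison is sketched below. It assumes Deep Learning Toolbox (and Parallel Computing Toolbox for GPU timing with `gputimeit`). `timeRecurrentLayer` and `lossAndGradients` are hypothetical names. The surrounding network (sequence input, the recurrent layer, a one-output fully connected layer) and the `mse` loss are arbitrary choices. The harness times a single forward and backward pass on fixed-size random data, optionally through `dlaccelerate`, so the recurrent layer is measured in isolation from batching, validation, and monitor updates.

```matlab
function elapsed = timeRecurrentLayer(recLayer, numChannels, miniBatchSize, numTimeSteps, useAcceleration)
% TIMERECURRENTLAYER  Hypothetical benchmarking helper (not a toolbox
% function). Returns the time in seconds of one forward + backward pass
% through a small sequence-to-sequence network built around recLayer.
arguments
    recLayer
    numChannels (1,1) double {mustBePositive, mustBeInteger}
    miniBatchSize (1,1) double {mustBePositive, mustBeInteger}
    numTimeSteps (1,1) double {mustBePositive, mustBeInteger}
    useAcceleration (1,1) logical = false
end

layers = [
    sequenceInputLayer(numChannels)
    recLayer
    fullyConnectedLayer(1)];
net = dlnetwork(layers);

% Fixed-size random data in "CBT" format (channel, batch, time).
X = dlarray(rand(numChannels, miniBatchSize, numTimeSteps, "single"), "CBT");
T = dlarray(rand(1, miniBatchSize, numTimeSteps, "single"), "CBT");
if canUseGPU
    X = gpuArray(X);
    T = gpuArray(T);
    net = dlupdate(@gpuArray, net);
end

lossFcn = @lossAndGradients;
if useAcceleration
    % dlaccelerate traces the function once per distinct input
    % size/type and reuses the cached trace on later calls.
    lossFcn = dlaccelerate(@lossAndGradients);
end
f = @() dlfeval(lossFcn, net, X, T);

if canUseGPU
    elapsed = gputimeit(f, 2);  % waits for GPU work to finish
else
    elapsed = timeit(f, 2);
end
end

function [loss, gradients] = lossAndGradients(net, X, T)
% Forward pass, loss, and gradients with respect to all learnables.
Y = forward(net, X);
loss = mse(Y, T);
gradients = dlgradient(loss, net.Learnables);
end
```

For example, `timeRecurrentLayer(gruLayer(64), 3, 32, 500)` could be compared with `timeRecurrentLayer(mguLayer(64), 3, 32, 500)` and `timeRecurrentLayer(mguLayer(64), 3, 32, 500, true)`; the sizes are arbitrary. `gputimeit` is preferred over `tic`/`toc` on the GPU because GPU calls can return before the computation has finished, and both `timeit` and `gputimeit` run the function several times, which keeps the one-off tracing cost of `dlaccelerate` largely out of the measurement.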
Code
Below are the code snippets for my MGU layer and the custom training loop, which might be helpful for understanding the context:
classdef mguLayer < nnet.layer.Layer & nnet.layer.Formattable
%MGULAYER Minimal Gated Unit Layer
properties
% Layer properties.
NumHiddenUnits
OutputMode
end
properties (Learnable)
% Layer learnable parameters.
InputWeights
RecurrentWeights
Bias
end
properties (State)
% Layer state parameters.
HiddenState
end
methods
function layer = mguLayer(numHiddenUnits,args)
%MGULAYER Minimal Gated Unit Layer
% layer = mguLayer(numHiddenUnits)
% creates an MGU layer with the specified number of
% hidden units.
%
% layer = mguLayer(numHiddenUnits,Name=Value)
% creates an MGU layer and specifies additional
% options using one or more name-value arguments:
%
% Name – Name of the layer, specified as a string.
% The default is "".
%
% OutputMode – Output mode, specified as one of the
% following:
% "sequence" – Output the entire sequence
% of data.
%
% "last" – Output the last time step
% of the data.
% The default is "sequence".
% Parse input arguments.
arguments
numHiddenUnits
args.Name = "";
args.OutputMode = "sequence";
end
layer.NumHiddenUnits = numHiddenUnits;
layer.Name = args.Name;
layer.OutputMode = args.OutputMode;
% Set layer description.
layer.Description = "MGU with " + numHiddenUnits + " hidden units";
end
function layer = initialize(layer,layout)
% layer = initialize(layer,layout) initializes the layer
% learnable and state parameters.
%
% Inputs:
% layer – Layer to initialize.
% layout – Data layout, specified as a
% networkDataLayout object.
%
% Outputs:
% layer – Initialized layer.
numHiddenUnits = layer.NumHiddenUnits;
% Find number of channels.
idx = finddim(layout,"C");
numChannels = layout.Size(idx);
% Initialize input weights.
if isempty(layer.InputWeights)
sz = [2*numHiddenUnits numChannels]; % Only 2 gates
numOut = 2*numHiddenUnits;
numIn = numChannels;
layer.InputWeights = initializeGlorot(sz,numOut,numIn);
end
% Initialize recurrent weights.
if isempty(layer.RecurrentWeights)
sz = [2*numHiddenUnits numHiddenUnits]; % Only 2 gates
layer.RecurrentWeights = initializeOrthogonal(sz);
end
% Initialize bias.
if isempty(layer.Bias)
layer.Bias = initializeBias(numHiddenUnits);
end
% Initialize hidden state.
if isempty(layer.HiddenState)
layer.HiddenState = zeros(numHiddenUnits,1);
end
end
function [Y,hiddenState] = predict(layer,X)
%PREDICT MGU predict function
% [Y,hiddenState] = predict(layer,X) forward
% propagates the data X through the layer and returns the
% layer output Y and the updated hidden state. X
% is a dlarray with format "CBT" and Y is a dlarray with
% format "CB" or "CBT", depending on the layer OutputMode
% property.
% Initialize sequence output.
numHiddenUnits = layer.NumHiddenUnits;
miniBatchSize = size(X,finddim(X,"B"));
numTimeSteps = size(X,finddim(X,"T"));
if layer.OutputMode == "sequence"
Y = zeros(numHiddenUnits,miniBatchSize,numTimeSteps,like=X);
Y = dlarray(Y,"CBT");
end
% Calculate WX + b.
X = stripdims(X);
WX = pagemtimes(layer.InputWeights,X) + layer.Bias; % Input weight * input + bias
% Indices of concatenated weight arrays.
idxf = 1:numHiddenUnits; % Forget gate
idxh = 1+numHiddenUnits:2*numHiddenUnits; % Candidate hidden state
% Initial states.
hiddenState = layer.HiddenState;
% Loop over time steps.
for t = 1:numTimeSteps
% Forget Gate computation
% f_k = sigma( (W_f * u_k + b_f) + R_f * h_{k-1} )
recurrent_term_f = layer.RecurrentWeights(idxf,:) * hiddenState;
ft = sigmoid(WX(idxf,:,t) + recurrent_term_f);
% Candidate hidden state computation
% h_tilde_k = tanh( (W_h_tilde * u_k + b_h_tilde) + R_h_tilde * (f_k .* h_{k-1}) )
modulated_hidden_state_for_candidate = ft .* hiddenState;
recurrent_term_h_tilde = layer.RecurrentWeights(idxh,:) * modulated_hidden_state_for_candidate;
htildet = tanh(WX(idxh,:,t) + recurrent_term_h_tilde);
% Update hidden state
% h_k = (1 - f_k) .* h_{k-1} + f_k .* h_tilde_k
hiddenState = (1 - ft) .* hiddenState + ft .* htildet;
% Update sequence output
if layer.OutputMode == "sequence"
Y(:,:,t) = hiddenState;
end
end
% Last time step output.
if layer.OutputMode == "last"
Y = dlarray(hiddenState,"CB");
end
end
function layer = resetState(layer)
%RESETSTATE Reset layer state
% layer = resetState(layer) resets the state properties of the
% layer.
numHiddenUnits = layer.NumHiddenUnits;
layer.HiddenState = zeros(numHiddenUnits,1);
end
end
end
function weights = initializeGlorot(sz,numOut,numIn,className)
arguments
sz
numOut
numIn
className = 'single'
end
Z = 2*rand(sz,className) - 1;
bound = sqrt(6 / (numIn + numOut));
weights = bound * Z;
weights = dlarray(weights);
end
function parameter = initializeOrthogonal(sz)
Z = randn(sz,'single');
[Q,R] = qr(Z,0);
D = diag(R);
Q = Q * diag(D ./ abs(D));
parameter = dlarray(Q);
end
function bias = initializeBias(numHiddenUnits)
bias = zeros(2*numHiddenUnits,1,'single');
idx = 1:numHiddenUnits;
bias(idx) = 1;
bias = dlarray(bias);
end
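A sketch of the same MGU recurrence as a plain function, with two changes that are often suggested for autodiff-heavy loops. First, the recurrent weight slices (`Rf`, `Rh`) and the per-gate input projections are extracted once before the loop instead of indexing `layer.RecurrentWeights` at every time step; each indexing operation on a traced `dlarray` is recorded and has to be differentiated in the backward pass, so hoisting it removes per-step work. Second, the per-step outputs are collected in a cell array and concatenated once rather than written into a preallocated `dlarray`. `mguForwardSequence` is a hypothetical name, the sigmoid is written out as `1 ./ (1 + exp(-z))` so the sketch needs no toolbox function, and the input projection uses `reshape` plus one matrix product, which gives the same result as `pagemtimes(W,X)` for a 2-D `W`. I have not measured the gain for this layer; it reduces overhead but should not be expected to match the compiled kernels behind the built-in layers.

```matlab
function [Y, h] = mguForwardSequence(W, R, b, X, h0)
% MGUFORWARDSEQUENCE  Hypothetical helper: MGU recurrence over all time
% steps with per-step indexing hoisted out of the loop.
%   W  - 2H-by-C input weights (forget-gate rows first, candidate rows second)
%   R  - 2H-by-H recurrent weights (same row order)
%   b  - 2H-by-1 bias
%   X  - C-by-B-by-T input (numeric array or unformatted dlarray)
%   h0 - H-by-1 or H-by-B initial hidden state
%   Y  - H-by-B-by-T hidden states; h - final hidden state
numHiddenUnits = size(R, 2);
[numChannels, miniBatchSize, numTimeSteps] = size(X);
idxf = 1:numHiddenUnits;                   % forget-gate rows
idxh = numHiddenUnits+1:2*numHiddenUnits;  % candidate rows

% Input projection for every time step at once, then split by gate once.
WX = reshape(W * reshape(X, numChannels, []), ...
    2*numHiddenUnits, miniBatchSize, numTimeSteps) + b;
WXf = WX(idxf, :, :);
WXh = WX(idxh, :, :);

% Slice the recurrent weights once instead of at every time step.
Rf = R(idxf, :);
Rh = R(idxh, :);

h = h0;
steps = cell(1, 1, numTimeSteps);
for t = 1:numTimeSteps
    f = 1 ./ (1 + exp(-(WXf(:, :, t) + Rf * h)));   % forget gate
    hTilde = tanh(WXh(:, :, t) + Rh * (f .* h));     % candidate state
    h = (1 - f) .* h + f .* hTilde;                  % state update
    steps{t} = h;
end
% Concatenate all time steps once along the third (time) dimension.
Y = cat(3, steps{:});
end
```

Inside `predict`, the time loop could then become `[Yseq, hiddenState] = mguForwardSequence(layer.InputWeights, layer.RecurrentWeights, layer.Bias, stripdims(X), layer.HiddenState);` followed by `Y = dlarray(Yseq,"CBT")`, or `Y = dlarray(hiddenState,"CB")` when `OutputMode` is `"last"`. Unlike an LSTM, whose four gates all read h_{k-1} and can share one fused recurrent product, the MGU needs two sequential products per step because the candidate uses `f .* h`, which depends on the first product.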
% Trains a gated RNN (GRNN) such as an LSTM, GRU, or the custom MGU. Returns the network at the minimum of the early-stop criterion (weighted validation RMSE and training MSE).
function [net,info,monitor,net_name] = train_loop(train_dataset, valid_dataset, train_options)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% UNPACK OPTIONS
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
hidden_units = train_options.hidden_units;
dropout_rate = train_options.dropout_rate;
learn_rate = train_options.learn_rate;
max_epochs = train_options.max_epochs;
decay_rate = train_options.decay_rate;
decay_points = train_options.decay_points;
mini_batch_size = train_options.mini_batch_size;
is_visible = train_options.is_visible;
is_verbose = train_options.is_verbose;
early_stop_weights = train_options.early_stop_weights;
toy_system = train_options.toy_system;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% SETTING PARAMETERS
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% General parameters to set, CAN BE CHANGED
validation_frequency = max(5e-3, 1 / max_epochs); % Relative validation checks
window_RMSE = 5; % Smoothing window for RMSE reporting
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% DATA SCRAPING & PREALLOCATIONS
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Scrape and initialize data
num_layers = numel(hidden_units);
num_features = size(train_dataset.inputs_norm, 3);
num_responses = size(train_dataset.outputs_norm, 3);
network_layers = net_maker(hidden_units, num_features, num_responses, dropout_rate);
net_name = generate_net_name(num_layers, toy_system, hidden_units, learn_rate);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% PRE-INITIALIZE LOOP
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Convert the datasets to cell arrays (legacy transformation)
x_train = tensor_to_cell(train_dataset.inputs_norm);
y_train = tensor_to_cell(train_dataset.outputs_norm);
x_valid = tensor_to_cell(valid_dataset.inputs_norm);
y_valid = tensor_to_cell(valid_dataset.outputs_norm);
n_train_seq = size(x_train,2);
% Create network from layers
net = dlnetwork(network_layers);
if canUseGPU, net = dlupdate(@gpuArray, net); end
% Optionally load a previously trained network for parameter initialization
% loaded_data = load("net_resultstestmgu_tank_2L_6_4_1e-02");
% net_data = loaded_data.net_data;
% net = net_data.net;
% Get validation data in dlarray
if ~canUseGPU
x_valid = cellfun(@dlarray, x_valid, 'UniformOutput', false);
y_valid = cellfun(@dlarray, y_valid, 'UniformOutput', false);
else
x_valid = cellfun(@gpuArray, x_valid, 'UniformOutput', false);
y_valid = cellfun(@gpuArray, y_valid, 'UniformOutput', false);
end
% Select custom loss function
custom_loss = @clf;
% Initialize monitor data
num_iterations_per_epoch = floor(n_train_seq / mini_batch_size); % Whole mini-batches only, so iteration counts and preallocation sizes are integers
num_iterations = num_iterations_per_epoch * max_epochs;
validation_steps = ceil(validation_frequency * num_iterations);
monitor = generate_monitor(is_visible);
monitor.Progress = 0;
% Initialize iteration and epoch counters
iteration = 0;
epoch = 0;
window_data = zeros(window_RMSE,1);
% Initialize data for solver
average_grad = [];
average_sqgrad = [];
% Select best execution environment, nice if GPU available
executionEnvironment = "auto";
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
updateInfo(monitor, ExecutionEnvironment="GPU");
else
updateInfo(monitor, ExecutionEnvironment="CPU");
end
% Preallocate arrays to store metrics and iterations
max_iterations = max_epochs * num_iterations_per_epoch;
rmse_train = zeros(max_iterations, 1);
rmse_train_smooth = zeros(max_iterations, 1);
rmse_validation = zeros(max_iterations, 1);
iterations_store = zeros(max_iterations, 1);
% Initialize min validation RMSE and corresponding network, to later save
% min val network in the end
min_early_stop_condition = inf;
min_cond_net = [];
min_info = struct('training_rmse', [], 'validation_rmse', []);
min_monitor_data = struct(...
'rmse_train', [], ...
'rmse_train_smooth', [], ...
'rmse_validation', [], ...
'iterations_store', []);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% CUSTOM LOOP
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
monitor_data_idx = 0; % Index for storing data in monitor_data
while epoch < max_epochs && ~monitor.Stop
epoch = epoch + 1;
if is_verbose, fprintf('%d. ', epoch); end
% Check if the current epoch matches any defined decay points
if ismember(epoch, ceil(decay_points * max_epochs))
learn_rate = learn_rate * decay_rate;
end
% Shuffle training sequences
sequence_shuffle_indices = randperm(n_train_seq);
x_train_sequences_shuffled_this_epoch = x_train(sequence_shuffle_indices);
y_train_sequences_shuffled_this_epoch = y_train(sequence_shuffle_indices);
% Create mini batches
[x_batches_cells_epoch, y_batches_cells_epoch] = create_mini_batches(x_train_sequences_shuffled_this_epoch, y_train_sequences_shuffled_this_epoch, mini_batch_size);
[dlx_batched_this_epoch, dly_batched_this_epoch] = preprocess_mini_batches(x_batches_cells_epoch, y_batches_cells_epoch);
% Determine the number of batches created in this epoch.
% This should ideally match num_iterations_per_epoch if batching is consistent.
current_num_batches_in_epoch = size(dlx_batched_this_epoch, 1);
% Or use num_iterations_per_epoch if it’s reliably fixed
% Shuffle the order of the batches
batch_order_in_epoch_indices = randperm(current_num_batches_in_epoch);
x_train_batch = dlx_batched_this_epoch(batch_order_in_epoch_indices);
y_train_batch = dly_batched_this_epoch(batch_order_in_epoch_indices);
for batch = 1:num_iterations_per_epoch
if monitor.Stop, break; end
iteration = iteration + 1;
monitor_data_idx = monitor_data_idx + 1;
% Evaluate loss function
[MSE, gradients, ~] = dlfeval(custom_loss, net, x_train_batch{batch}, y_train_batch{batch}, is_verbose);
% Update network parameters based on loss
[net, average_grad, average_sqgrad] = adamupdate(net, gradients,average_grad, average_sqgrad, iteration, learn_rate);
current_rmse_train = double(sqrt(MSE));
rmse_train(monitor_data_idx) = current_rmse_train;
if rem(iteration, window_RMSE) == 0 || iteration == 1
if iteration == 1
smooth_rmse = current_rmse_train;
else
smooth_rmse = double(mean(rmse_train(iteration-window_RMSE+1:iteration)));
end
recordMetrics(monitor, iteration, TrainingRMSE_smooth=smooth_rmse);
rmse_train_smooth(monitor_data_idx) = smooth_rmse;
end
if rem(iteration, validation_steps) == 0 || iteration == 1
validation_rmse = rmse_validator(net, x_valid, y_valid);
recordMetrics(monitor, iteration, ValidationRMSE=double(validation_rmse));
rmse_validation(monitor_data_idx) = double(validation_rmse);
% Save min_val_net if new early stop condition is minimal
early_stop_condition = validation_rmse * early_stop_weights(1) + ...
MSE * early_stop_weights(2);
if early_stop_condition < min_early_stop_condition
min_early_stop_condition = early_stop_condition;
min_cond_net = net;
% Save the relevant information for min_val_net
min_info.training_rmse = sqrt(MSE);
min_info.validation_rmse = validation_rmse;
min_monitor_data.rmse_train = rmse_train(1:monitor_data_idx);
min_monitor_data.rmse_train_smooth = rmse_train_smooth(1:monitor_data_idx);
min_monitor_data.rmse_validation = rmse_validation(1:monitor_data_idx);
min_monitor_data.iterations_store = iterations_store(1:monitor_data_idx);
% Save the iteration when min_val_net is triggered
recorded_monitor.min_val_iteration = iteration;
end
end
iterations_store(monitor_data_idx) = iteration;
recordMetrics(monitor, iteration, TrainingRMSE=current_rmse_train);
updateInfo(monitor, Epoch=[num2str(epoch) ' of ' num2str(max_epochs)], Iteration=[num2str(iteration) ' of ' num2str(num_iterations)], LearnRate=learn_rate);
monitor.Progress = 100 * iteration/num_iterations;
end
end
% Save all data in monitor
recorded_monitor.rmse_train = rmse_train; % Save RMSE for training
recorded_monitor.rmse_train_smooth = rmse_train_smooth; % Save smoothed training RMSE
recorded_monitor.rmse_validation = rmse_validation; % Save validation RMSE
recorded_monitor.iterations_store = iterations_store; % Save the iterations
net = min_cond_net;
info = min_info;
monitor = recorded_monitor;
end
Tags: rnn, custom training, lstm, gru, mgu
MATLAB Answers — New Questions