Как зафиксировать значение ошибки для каждой эпохи в MATLAB?

1
3

всем! У меня возникла проблема с моим проектом Deep Learning в задаче регрессии изображения, и я хотел бы узнать, как я могу получить значение RMSE в каждой эпохе выполнения моей модели и экспортировать его в архив .xlsx.

Однако из-за природы цикла for обучение выполняется с той же скоростью, что и определенное количество эпох. Поэтому я хотел бы узнать, есть ли функция, которая позволяет мне, например, получить массив данных выполнения.

Мотивация в том, что я хотел бы использовать эти данные для другого проекта.

Я хочу экспортировать данные в архив .xlsx или .csv, но я хочу узнать, как это получить.

rmseValues = zeros(numEpochs, 1);
for epoch = 1:numEpochs
    [net, info] = trainNetwork(XTrain, YTrain, layers, options);
    YPred = predict(net, XTrain);
    rmseValues(epoch) = calculateRMSE(YTrain, YPred);
end
rmseTable = array2table(rmseValues, 'VariableNames', {'RMSE'});
disp(rmseTable)
Эрнст
Вопрос задан16 июля 2024 г.

1 Ответ

2

Фиксация значения ошибки для каждой эпохи в MATLAB

В MATLAB вы можете отслеживать значение ошибки для каждой эпохи обучения нейронной сети с помощью функции trainNetwork и ее опции 'Plots' или путем ручного сбора значений ошибки в цикле обучения.

Метод 1: Использование опции ‘Plots’

Функция trainNetwork предоставляет возможность отслеживать различные метрики обучения, включая ошибку, в процессе тренировки. 

Пример:

% Загрузка данных
data = load('mydata.mat');
X = data.X;
Y = data.Y;

% Определение архитектуры сети
layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(5, 20)
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer
];

% Обучение сети
options = trainingOptions('sgdm','Plots','training-progress');
net = trainNetwork(X,Y,layers,options);

Объяснение кода:

  • trainingOptions('sgdm','Plots','training-progress') устанавливает опцию 'Plots' в  'training-progress', что отображает график ошибки в режиме реального времени.
  • График будет показывать значение ошибки для каждой эпохи обучения.

Метод 2: Ручное отслеживание ошибки

Если вы хотите получить более точный контроль над процессом сбора данных об ошибках, можно использовать цикл обучения и ручно собирать значения ошибки.

Пример:

% Загрузка данных
data = load('mydata.mat');
X = data.X;
Y = data.Y;

% Определение архитектуры сети
layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(5, 20)
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer
];

% Обучение сети
epochs = 10;
errors = zeros(epochs,1);
for i = 1:epochs
    % Обучение одной эпохи
    [net,trainInfo] = trainNetwork(X,Y,layers,options);
    % Сохранение ошибки текущей эпохи
    errors(i) = trainInfo.TrainingLoss;
end

% Визуализация значений ошибки
plot(1:epochs, errors);
xlabel('Эпоха');
ylabel('Ошибка');
title('Значение ошибки за эпохи');

Объяснение кода:

  • errors = zeros(epochs,1) создает вектор для хранения значений ошибки.
  • trainInfo.TrainingLoss содержит значение ошибки обучения после каждой эпохи.
  • plot(1:epochs, errors) показывает график значений ошибки в зависимости от номера эпохи.

Дополнительные возможности

  • Изменение метрики ошибки: Вы можете использовать другие метрики ошибки в зависимости от вашей задачи. 
  • Сохранение данных об ошибке: Вы можете сохранить полученные значения ошибки в файл для позднейшего анализа.
  • Использование  Performance структуры:Функция  trainNetwork также возвращает структуру  Performance, которая содержит дополнительные метрики обучения, включая точность, отзыв, F1-оценку и др.

 

Михей
Ответ получен14 сентября 2024 г.

Ваш ответ

Загрузить файл.