Задать вопрос
@Denys557

Как обучить акустическую модель?

Недавно я начал разрабатывать ASR, и я начал с акустической модели. Я начал пытаться ее тренировать, но она дает мне совершенно неверный результат и loss становится отрицательными. Что делать? Может я просто начал не то писать?

acousticModel.cpp

#include "acousticModel/acousticModel.h"

SpeechRecognitionModelImpl::SpeechRecognitionModelImpl(int input_size, int hidden_size, int num_classes, int num_layers)
    : lstm(torch::nn::LSTMOptions(input_size, hidden_size).num_layers(num_layers).batch_first(true)),
    fc(hidden_size, num_classes),
    ctc_loss(torch::nn::CTCLoss()) {
    register_module("lstm", lstm);
    register_module("fc", fc);
    register_module("ctc_loss", ctc_loss);
}

torch::Tensor SpeechRecognitionModelImpl::forward(torch::Tensor x) {
    if (x.dim() == 2) {
        x = x.unsqueeze(0);
    }

    x = x.to(torch::kFloat);

    auto lstm_out = lstm->forward(x);
    auto hidden_states = std::get<0>(lstm_out);
    auto output = torch::log_softmax(fc->forward(hidden_states), 2);
    return output;
}


void SpeechRecognitionModelImpl::train(std::vector<torch::Tensor> inputs, std::vector<torch::Tensor> targets,
    std::vector<int> input_lengths, std::vector<int> target_lengths, size_t epochs) {
    std::cout << "-2" << std::endl;
    if (inputs.size() != targets.size() || inputs.size() != input_lengths.size()) {
        throw std::runtime_error("Inputs, targets, and lengths must have the same size");
    }
    torch::optim::Adam opt(parameters(), 0.001);

    for (size_t i = 0; i < inputs.size(); i++) {

        for (size_t epoch = 0; epoch < epochs; epoch++) {
            std::cout << "\nstart epoch" << std::endl;
            auto output = forward(inputs[i]);
            std::cout << "forward" << std::endl;

            output = output.transpose(0, 1);

            std::cout << "transpose" << std::endl;

            auto loss = ctc_loss(
                output,
                targets[i],
                torch::tensor(input_lengths[i], torch::kInt32),
                torch::tensor(target_lengths[i], torch::kInt32)
            );

            std::cout << "ctc_loss" << std::endl;

            opt.zero_grad();
            std::cout << "zero_grad" << std::endl;
            loss.backward();
            std::cout << "backward" << std::endl;
            opt.step();
            std::cout << "step" << std::endl;

            std::cout << "loss: " << loss.item<double>() << std::endl;
            std::cout << "epoch: " << epoch << std::endl << std::endl;
        }
    }

    /*for (size_t epoch = 0; epoch < epochs; ++epoch) {
        double total_loss = 0.0;

        for (size_t i = 0; i < inputs.size(); ++i) {

            std::cout << "1" << std::endl;
            auto output = forward(inputs[i]);
            std::cout << "2" << std::endl;

            output = output.transpose(0, 1);

            std::cout << "3" << std::endl;

            auto loss = ctc_loss(
                output, 
                targets[i], 
                torch::tensor(input_lengths[i], torch::kInt32),
                torch::tensor(target_lengths[i], torch::kInt32)
            );

            std::cout << "4" << std::endl;

            opt.zero_grad();
            std::cout << "5" << std::endl;
            loss.backward();
            std::cout << "6" << std::endl;
            opt.step();
            std::cout << "7" << std::endl;

            std::cout << loss.item<double>() << std::endl;  
            total_loss += loss.item<double>();
        }
         
        std::cout << "Epoch [" << epoch + 1 << "/" << epochs << "], Loss: " << total_loss / inputs.size() << std::endl;
    }*/
}

std::vector<int> SpeechRecognitionModelImpl::decode_greedy(torch::Tensor output) {
    output = output.argmax(2);
    std::vector<int> decoded_sequence;

    int prev = -1;
    for (int t = 0; t < output.size(1); ++t) {
        int current = output[0][t].item<int>();
        if (current != prev && current != 0) {
            decoded_sequence.push_back(current);
        }
        prev = current;
    }
    return decoded_sequence;
}


audio.cpp

#include "audio/audio.h"
#include <sndfile.h>
#include <stdexcept>

std::vector<double> read_audio(const std::string& filename) {
    SF_INFO sfinfo;
    SNDFILE* infile = sf_open(filename.c_str(), SFM_READ, &sfinfo);

    if (!infile) {
        throw std::runtime_error("Unable to open the file: \"" + filename + "\"");
    }

    std::vector<double> audio(sfinfo.frames);
    sf_read_double(infile, audio.data(), sfinfo.frames);
    sf_close(infile);

    return audio;
}


main.cpp

#include <iostream>
#include <vector>
#include <sndfile.h>
#include <filesystem>
#include <fftw3.h>
#include <cmath>
#include "audio/audio.h"
#include "defines.h"
#include <mimalloc.h>
#include "acousticModel/acousticModel.h"
#include "rapidcsv.h"

#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

torch::Tensor string_to_tensor(const std::string& str) {
    std::vector<double> data;

    for (auto& c : str) {
        double x = static_cast<double>(c) / 128.0;
        data.push_back(x);
    }
    return torch::tensor(data, torch::kFloat32);
}

std::string tensor_to_string(const torch::Tensor& tensor) {
    std::string result;

    auto normalized_values = tensor.contiguous().data_ptr<float>();
    auto num_elements = tensor.size(0);

    for (size_t i = 0; i < num_elements; i++) {
        char c = static_cast<char>(normalized_values[i] * 128.0);
        result.push_back(c);
    }

    return result;
}

torch::Tensor calculate_spectrogram(const std::vector<double>& audio) {
    int num_frames = (audio.size() - WINDOW_SIZE) / HOP_SIZE + 1;

    auto spectrogram = torch::zeros({ num_frames, WINDOW_SIZE / 2 + 1 }, torch::kDouble);

    fftw_complex* fft_out = fftw_alloc_complex(WINDOW_SIZE);
    fftw_plan fft_plan = fftw_plan_dft_r2c_1d(WINDOW_SIZE, nullptr, fft_out, FFTW_ESTIMATE);

    for (int i = 0; i < num_frames; ++i) {
        std::vector<double> window(WINDOW_SIZE);
        int start = i * HOP_SIZE;

        for (int j = 0; j < WINDOW_SIZE; ++j) {
            if (start + j < audio.size()) {
                window[j] = audio[start + j] * 0.5 * (1 - cos(2 * M_PI * j / (WINDOW_SIZE - 1))); 
            }
            else {
                window[j] = 0.0;
            }
        }

        fftw_execute_dft_r2c(fft_plan, window.data(), fft_out);

        for (int k = 0; k < WINDOW_SIZE / 2 + 1; ++k) {
            spectrogram[i][k] = std::log1p(std::sqrt(fft_out[k][0] * fft_out[k][0] + fft_out[k][1] * fft_out[k][1]));
        }
    }

    fftw_destroy_plan(fft_plan);
    fftw_free(fft_out);

    return spectrogram;
}

std::pair<std::vector<torch::Tensor>, std::vector<torch::Tensor>> get_train_data(const std::filesystem::path& path) {
   
    if (!std::filesystem::exists(path) || !std::filesystem::is_directory(path)) {
        throw std::runtime_error(path.string() + " invalid path");
    }

    std::cout << "-7" << std::endl;

    std::pair<std::vector<torch::Tensor>, std::vector<torch::Tensor>> data;

    rapidcsv::Document doc("data/validated.tsv", rapidcsv::LabelParams(), rapidcsv::SeparatorParams('\t'));
    auto path_column = doc.GetColumn<std::string>("path");
    auto sentence_column = doc.GetColumn<std::string>("sentence");

    std::cout << "-6" << std::endl;

    if (path_column.size() != sentence_column.size()) {
        throw std::out_of_range("path column size not equal sentence column size");
    }

    for (size_t i = 0; i < path_column.size(); i++) {
        for (const auto& entry : std::filesystem::directory_iterator(path)) {
            if (entry.is_regular_file() && entry.path().filename() == path_column[i]) {
                
                std::string sentence = sentence_column[i];

                data.first.push_back(calculate_spectrogram(read_audio(path.string() + "/" + path_column[i])));
                data.second.push_back(string_to_tensor(sentence));
                std::cout << path_column[i] << " " << sentence << std::endl;

                if (data.first.size() >= 1) {
                    return data;
                }
            }
        }
    }


    return data;
}

int main(int argc, char* argv[]) {
    mi_version();
    try {
        std::cout << "-10" << std::endl;
        int input_size = WINDOW_SIZE / 2 + 1;
        int hidden_size = 128;
        int num_classes = 30;
        int num_layers = 2;

        std::shared_ptr<SpeechRecognitionModelImpl> model = std::make_shared<SpeechRecognitionModelImpl>(input_size, hidden_size, num_classes, num_layers);

        torch::load(model, "nn/nn2.pt");

        auto data = get_train_data("data/clips");

        std::vector<int> input_lengths, target_lengths;
        for (const auto& input : data.first) input_lengths.push_back(input.size(0));
        for (const auto& target : data.second) target_lengths.push_back(target.size(0));

        int epochs = 10;

        if (argc == 2) {
            epochs = std::stoi(std::string(argv[1]));
            std::cout << "Epochs = " << epochs << std::endl;
        }

        model->train(data.first, data.second, input_lengths, target_lengths, epochs);

        torch::save(model, "nn/nn2.pt");

        std::cout << tensor_to_string(model->forward(calculate_spectrogram(read_audio("data/clips/common_voice_en_41047776.mp3"))));
    }
    catch (const std::exception& ex) {
        std::cout << ex.what() << std::endl;
    }

    return 0;
}
  • Вопрос задан
  • 134 просмотра
Подписаться 2 Простой 2 комментария
Пригласить эксперта
Ваш ответ на вопрос

Войдите, чтобы написать ответ

Похожие вопросы