@AndreiPy13

Как добавить явное условие при обучении модели AI?

Привет! У меня есть код для обучения BERT модели. Но мои обучающие данные не сбалансированы (есть 10 категорий, в некоторых есть 100 примеров, в других - 5), я добавил class_weights в свой код, но это не дало никакого эффекта. Возможно, я могу добавить явные условия?
Например, если есть слово "привет" в тексте, то это категория "добро пожаловать". Но я не могу понять, как это сделать, может кто подскажет?

class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
    logger.info(f"Class Weights: {class_weights}")
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Обучение и сохранение модели
    for epoch in tqdm(range(1, epochs + 1)):
        model.train()
        loss_train_total = 0

        progress_bar = tqdm(dataloader_train, desc='Эпоха {:1d}'.format(epoch), leave=False, disable=False)
        for batch in progress_bar:
            model.zero_grad()

            batch = tuple(b.to(device) for b in batch)

            inputs = {'input_ids': batch[0],
                      'attention_mask': batch[1],
                      'labels': batch[2],
                      }

            outputs = model(**inputs)

            loss = criterion(outputs.logits, inputs['labels'])
            loss_train_total += loss.item()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()

            progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item() / len(batch))})

        # Сохраняем состояние модели BERT
        current_path = Path.cwd()
        torch.save(model.state_dict(), current_path / 'app' / 'api' / 'ai' / 'trained_models' / 'new_models' /
                   f'fine_tuned_{model_name}_BERT_epoch_{epoch}.model')

        logger.info(f'\nЭпоха {epoch}')

        loss_train_avg = loss_train_total / len(dataloader_train)

        val_loss, predictions, true_vals = evaluate(dataloader_validation)
        val_f1 = f1_score_func(predictions, true_vals)
        accuracy_per_class(predictions, true_vals)
        logger.info(f'Функция потерь при обучении: {loss_train_avg}')
        logger.info(f'Функция потерь при валидации: {val_loss}')
        logger.info(f'F1-мера при валидации: {val_f1}')
  • Вопрос задан
  • 127 просмотров
Решения вопроса 1
Maksim_64
@Maksim_64
Data Analyst
Явные условия (детерминистические) это не про машинное обучение. Сама суть машинного обучения это обучение без задания явных инструкций.

Твоя проблема, большая называется "несбалансированные классы". 5 в одном 100 в другом это безнадега, модель по умолчанию имеет большую предрасположенность, что мешает обучению.

Это распространенная проблема и к сожалению простого решения (не имеет). Существуют разные стратегии, как с этим бороться, и надо пробовать, что будет работать.

Советую изучить вот эту статью на эту тему внимательно, и запастись терпением. Вот эта статья с медиума (открывается только через vpn) по крайней мере у меня, там тоже BERT и тоже решается проблема с не сбалансированным классами. Есть код. Но повторю проблема решается разными стратегиями препроцессинга (первая статья).
Ответ написан
Пригласить эксперта
Ваш ответ на вопрос

Войдите, чтобы написать ответ

Войти через центр авторизации
Похожие вопросы