@AndreiPy13

Как добавить явное условие при обучении модели AI?

Привет! У меня есть код для обучения BERT модели. Но мои обучающие данные не сбалансированы (есть 10 категорий, в некоторых есть 100 примеров, в других - 5), я добавил class_weights в свой код, но это не дало никакого эффекта. Возможно, я могу добавить явные условия?
Например, если есть слово "привет" в тексте, то это категория "добро пожаловать". Но я не могу понять, как это сделать, может кто подскажет?

class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
    logger.info(f"Class Weights: {class_weights}")
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Обучение и сохранение модели
    for epoch in tqdm(range(1, epochs + 1)):
        model.train()
        loss_train_total = 0

        progress_bar = tqdm(dataloader_train, desc='Эпоха {:1d}'.format(epoch), leave=False, disable=False)
        for batch in progress_bar:
            model.zero_grad()

            batch = tuple(b.to(device) for b in batch)

            inputs = {'input_ids': batch[0],
                      'attention_mask': batch[1],
                      'labels': batch[2],
                      }

            outputs = model(**inputs)

            loss = criterion(outputs.logits, inputs['labels'])
            loss_train_total += loss.item()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()

            progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item() / len(batch))})

        # Сохраняем состояние модели BERT
        current_path = Path.cwd()
        torch.save(model.state_dict(), current_path / 'app' / 'api' / 'ai' / 'trained_models' / 'new_models' /
                   f'fine_tuned_{model_name}_BERT_epoch_{epoch}.model')

        logger.info(f'\nЭпоха {epoch}')

        loss_train_avg = loss_train_total / len(dataloader_train)

        val_loss, predictions, true_vals = evaluate(dataloader_validation)
        val_f1 = f1_score_func(predictions, true_vals)
        accuracy_per_class(predictions, true_vals)
        logger.info(f'Функция потерь при обучении: {loss_train_avg}')
        logger.info(f'Функция потерь при валидации: {val_loss}')
        logger.info(f'F1-мера при валидации: {val_f1}')
  • Вопрос задан
  • 127 просмотров
Решения вопроса 1
Maksim_64
@Maksim_64
Data Analyst
Явные условия (детерминистические) это не про машинное обучение. Сама суть машинного обучения это обучение без задания явных инструкций.

Твоя проблема, большая называется "несбалансированные классы". 5 в одном 100 в другом это безнадега, модель по умолчанию имеет большую предрасположенность, что мешает обучению.

Это распространенная проблема и к сожалению простого решения (не имеет). Существуют разные стратегии, как с этим бороться, и надо пробовать, что будет работать.

Советую изучить вот эту статью на эту тему внимательно, и запастись терпением. Вот эта статья с медиума (открывается только через vpn) по крайней мере у меня, там тоже BERT и тоже решается проблема с не сбалансированным классами. Есть код. Но повторю проблема решается разными стратегиями препроцессинга (первая статья).
Ответ написан
Пригласить эксперта
Ваш ответ на вопрос

Войдите, чтобы написать ответ

Похожие вопросы