Привет! У меня есть код для обучения BERT модели. Но мои обучающие данные не сбалансированы (есть 10 категорий, в некоторых есть 100 примеров, в других - 5), я добавил class_weights в свой код, но это не дало никакого эффекта. Возможно, я могу добавить явные условия?
Например, если есть слово "привет" в тексте, то это категория "добро пожаловать". Но я не могу понять, как это сделать, может кто подскажет?
class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
logger.info(f"Class Weights: {class_weights}")
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
# Обучение и сохранение модели
for epoch in tqdm(range(1, epochs + 1)):
model.train()
loss_train_total = 0
progress_bar = tqdm(dataloader_train, desc='Эпоха {:1d}'.format(epoch), leave=False, disable=False)
for batch in progress_bar:
model.zero_grad()
batch = tuple(b.to(device) for b in batch)
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'labels': batch[2],
}
outputs = model(**inputs)
loss = criterion(outputs.logits, inputs['labels'])
loss_train_total += loss.item()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item() / len(batch))})
# Сохраняем состояние модели BERT
current_path = Path.cwd()
torch.save(model.state_dict(), current_path / 'app' / 'api' / 'ai' / 'trained_models' / 'new_models' /
f'fine_tuned_{model_name}_BERT_epoch_{epoch}.model')
logger.info(f'\nЭпоха {epoch}')
loss_train_avg = loss_train_total / len(dataloader_train)
val_loss, predictions, true_vals = evaluate(dataloader_validation)
val_f1 = f1_score_func(predictions, true_vals)
accuracy_per_class(predictions, true_vals)
logger.info(f'Функция потерь при обучении: {loss_train_avg}')
logger.info(f'Функция потерь при валидации: {val_loss}')
logger.info(f'F1-мера при валидации: {val_f1}')