Задать вопрос
@aleksandrpikul

Как обучить модель на видео данных?

Всем привет.
Обучал модель на изображениях следующим образом:
class ASDataset(Dataset):
    def __init__(self, client_file: str, imposter_file: str, transforms=None):
        with open(client_file, "r") as f:
            client_files = f.read().splitlines()
        with open(imposter_file, "r") as f:
            imposter_files = f.read().splitlines()
        self.labels = torch.cat((torch.ones(len(client_files)), torch.zeros(len(imposter_files))))
        self.imgs = client_files + imposter_files
        self.transforms = transforms

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img_name = self.imgs[idx]
        img = Image.open(img_name)
        label = self.labels[idx]
        if self.transforms:
            img = self.transforms(img)
        return img, label

train_dataset = ASDataset(client_file="/kaggle/input/nuaaaa/raw/client_train_raw.txt", imposter_file="/kaggle/input/nuaaaa/raw/imposter_train_raw.txt", transforms=preprocess)
val_dataset = ASDataset(client_file="/kaggle/input/nuaaaa/raw/client_test_raw.txt", imposter_file="/kaggle/input/nuaaaa/raw/imposter_test_raw.txt", transforms=preprocess)



# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet18WithAttention().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
num_epochs = 20

def train(epoch):
    running_loss = 0.0 
    running_acc = 0.0
    model.train()
    for i, (inputs, labels) in (pbar := tqdm(enumerate(train_loader), total=len(train_loader))):
        pbar.set_postfix(**msg)
        optimizer.zero_grad()
        outputs = model(inputs.to(device))
        loss = criterion(
            outputs, labels.unsqueeze(-1).to(device)
        )
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_acc += get_accuracy(outputs, labels.to(device))
        if not i % 10 and i > 0:
            msg.update({"train_epoch": epoch + 1, "loss": running_loss / i, "acc": (running_acc / i).item()})
            pbar.set_postfix(**msg)


def validate():
    val_loss = 0.0  
    val_acc = 0.0
    model.eval()
    for i, (inputs, labels) in enumerate(val_loader):
        with torch.no_grad():
            outputs = model(inputs.to(device))
        loss = criterion(
            outputs, labels.unsqueeze(-1).to(device)  # Считаем loss
        )
        val_loss += loss.item()
        val_acc += get_accuracy(outputs, labels.to(device))
    val_loss /= len(val_loader)
    val_acc /= len(val_loader)
    scheduler.step()
    msg.update({"val_loss": val_loss, "val_acc": val_acc.item()})
    print(f"val_loss: {val_loss}, val_acc: {val_acc.item()}")
    return val_acc.item()
    
msg = OrderedDict({"train_epoch": None, "loss": None, "acc": None, \
                "val_loss": None, "val_acc": None})


Теперь пробую обучить на видео данных, и как я понял, нужно изменить class ASDataset, но не совсем понимаю, как это сделать. Пробовал вот так, и еще несколько похожих вариантов:
class VideoDataset(Dataset):
    def __init__(self, video_file: str, label: int, transforms=None):
        self.video = cv2.VideoCapture(video_file)
        self.label = label
        self.transforms = transforms

    def __len__(self):
        return int(self.video.get(cv2.CAP_PROP_FRAME_COUNT))

    def __getitem__(self, idx):
        self.video.set(cv2.CAP_PROP_POS_FRAMES, idx)
        success, frame = self.video.read()
        if not success:
            raise ValueError("Failed to read frame")
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # Convert to RGB
        if self.transforms:
            frame = self.transforms(frame)
        return frame, self.label


Скажите, пожалуйста, что нужно изменить, чтобы видео данные считывались правильно? Сами данные находятся в папке, в текстовом файле прописаны пути к файлам. Например, для файла imposter_train_raw.txt:
/kaggle/input/dfdcdfdc/DFDCDFDC/dbnygxtwek.mp4
/kaggle/input/dfdcdfdc/DFDCDFDC/dbtbbhakdv.mp4
/kaggle/input/dfdcdfdc/DFDCDFDC/ddepeddixj.mp4
/kaggle/input/dfdcdfdc/DFDCDFDC/dhcndnuwta.mp4
/kaggle/input/dfdcdfdc/DFDCDFDC/dhxctgyoqj.mp4
/kaggle/input/dfdcdfdc/DFDCDFDC/djxdyjopjd.mp4
/kaggle/input/dfdcdfdc/DFDCDFDC/dkuayagnmc.mp4
  • Вопрос задан
  • 100 просмотров
Подписаться 1 Средний 4 комментария
Пригласить эксперта
Ваш ответ на вопрос

Войдите, чтобы написать ответ

Похожие вопросы