@aleksandrpikul

Почему ResNet дает слишком хорошие результаты?

Пытаюсь протестировать ResNet:
def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class ASDataset(Dataset):
    def __init__(self, client_file: str, imposter_file: str, transforms = None):
        with open(client_file, "r") as f:
            client_files = f.read().splitlines()
        with open(imposter_file, "r") as f:
            imposter_files = f.read().splitlines()
        self.labels = torch.cat((torch.ones(len(client_files)), \
            torch.zeros(len(imposter_files))))
        self.imgs = client_files + imposter_files
        self.transforms = transforms

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img_name = self.imgs[idx]
        img = Image.open(img_name)
        label = self.labels[idx]
        if self.transforms:
            img = self.transforms(img)
        return img, label


class Block(nn.Module):
    def __init__(self, num_layers, in_channels, out_channels, identity_downsample=None, stride=1):
        assert num_layers in [18, 34, 50, 101, 152], "should be a a valid architecture"
        super(Block, self).__init__()
        self.num_layers = num_layers
        if self.num_layers > 34:
            self.expansion = 4
        else:
            self.expansion = 1
        # ResNet50, 101, and 152 include additional layer of 1x1 kernels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(out_channels)
        if self.num_layers > 34:
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        else:
            # for ResNet18 and 34, connect input directly to (3x3) kernel (skip first (1x1))
            self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample

    def forward(self, x):
        identity = x
        if self.num_layers > 34:
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)

        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        x += identity
        x = self.relu(x)
        return x


class ResNet(nn.Module):
    def __init__(self, num_layers, block, image_channels, num_classes, **kwargs):
        assert num_layers in [18, 34, 50, 101, 152], f'ResNet{num_layers}: Unknown architecture! Number of layers has ' \
                                                     f'to be 18, 34, 50, 101, or 152 '
        super(ResNet, self).__init__()
        if num_layers < 50:
            self.expansion = 1
        else:
            self.expansion = 4
        if num_layers == 18:
            layers = [2, 2, 2, 2]
        elif num_layers == 34 or num_layers == 50:
            layers = [3, 4, 6, 3]
        elif num_layers == 101:
            layers = [3, 4, 23, 3]
        else:
            layers = [3, 8, 36, 3]
        self.in_channels = 64
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # ResNetLayers
        self.layer1 = self.make_layers(num_layers, block, layers[0], intermediate_channels=64, stride=1)
        self.layer2 = self.make_layers(num_layers, block, layers[1], intermediate_channels=128, stride=2)
        self.layer3 = self.make_layers(num_layers, block, layers[2], intermediate_channels=256, stride=2)
        self.layer4 = self.make_layers(num_layers, block, layers[3], intermediate_channels=512, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * self.expansion, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
        return torch.sigmoid(x)

    def make_layers(self, num_layers, block, num_residual_blocks, intermediate_channels, stride):
        layers = []

        identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels, intermediate_channels*self.expansion, kernel_size=1, stride=stride),
                                            nn.BatchNorm2d(intermediate_channels*self.expansion))
        layers.append(block(num_layers, self.in_channels, intermediate_channels, identity_downsample, stride))
        self.in_channels = intermediate_channels * self.expansion # 256
        for i in range(num_residual_blocks - 1):
            layers.append(block(num_layers, self.in_channels, intermediate_channels)) # 256 -> 64, 64*4 (256) again
        return nn.Sequential(*layers)


def ResNet18(img_channels=3, num_classes=1):
    return ResNet(18, Block, img_channels, num_classes)


preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

torch.manual_seed(42)
kwargs = {"nfeat":64, "nhid":64, "nclass":1, "nheads":49, "dropout":0.6, "alpha":0.01}
model = ResNet18()

test_dataset = ASDataset(client_file="raw/client_train_raw.txt", imposter_file="raw/imposter_train_raw.txt", \
    transforms=preprocess)
train_dataset = ASDataset(client_file="raw/client_test_raw.txt", imposter_file="raw/imposter_test_raw.txt", \
    transforms=preprocess)
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=True)

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=15e-4, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.3)

EPOCHS = 10
device = "cpu"


Но получаю AUC = 100, EER = 0
Разве такое может быть?
Есть подозрения, что неверно расписан сам ResNet или как-то не так его использую.
  • Вопрос задан
  • 94 просмотра
Пригласить эксперта
Ответы на вопрос 1
Maksim_64
@Maksim_64
Data Analyst
На практике такие метрики могут означать лишь одно вы тренируете и тестируете на одних и тех же данных. Первое место для проверки
test_dataset = ASDataset(client_file="raw/client_train_raw.txt", imposter_file="raw/imposter_train_raw.txt", \
    transforms=preprocess)
train_dataset = ASDataset(client_file="raw/client_test_raw.txt", imposter_file="raw/imposter_test_raw.txt", \
    transforms=preprocess)
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=True)

второе место для проверки естественно сам класс DataLoader.
ну и третье сами файлы что бы по запарке они не содержали одно и тоже содержимое.

В каком случае возможны такие метрики ну например данные на которых вы учитесь в них входная переменная это температура в цельсии а то что нужно "предсказать" температура в фаренгейтах. Ваши тестовые данные по структуре такие же но сам датасет алгоритм никогда не видел. На тех данных что алгоритм тренировался он благополучно выучит школьную формулу перевода из цельсии в фаренгейты и справится со 100 точностью. По простой причине в данная проблема состоит только из детерминистической составляющей, т.е вариативность отсутствует. (Данный пример специально примитивен. Это может и Unsupervised Learning это может задача где на вход "features" пойдут десятки переменных и.т.д лишь бы отсутствовала вариативность). Подобные примеры как этот будут выдавать такие метрики.
Ответ написан
Комментировать
Ваш ответ на вопрос

Войдите, чтобы написать ответ

Войти через центр авторизации
Похожие вопросы