@aleksandrpikul

Как правильно продублировать слой GAT?

Есть вот такой код, здесь используется ResNet и GAT. Мне нужно продублировать слой GAT. Как я понимаю, мне надо еще раз прописать строку x = self.gat(x, self.adj), но я не понимаю где именно

class AntispoofModel(nn.Module):
    def __init__(self, device="cpu", **kwargs):
        super().__init__()
        resnet = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
        self.resnet = nn.Sequential(*[i for i in list(resnet.children())[:-2]]).to(device)
        for ch in self.resnet.children():
            for param in ch.parameters():
                param.requires_grad = False
        self.gat = GAT(**kwargs).to(device)
        self.device = device
        self.adj = torch.tensor(grid_to_graph(3, 3, return_as=np.ndarray)).to(device)
        
    def forward(self, x):
        x = self.resnet(x.to(self.device))
        x = nn.functional.avg_pool2d(x, 2)
        x = x.view(-1, 9, 512)
        #adj = torch.stack([self.adj for i in range(x.shape[0])]).to(self.device)
        x = self.gat(x, self.adj)
        return torch.sigmoid(x)
  • Вопрос задан
  • 126 просмотров
Пригласить эксперта
Ваш ответ на вопрос

Войдите, чтобы написать ответ

Похожие вопросы