Есть вот такой код, здесь используется ResNet и GAT. Мне нужно продублировать слой GAT. Как я понимаю, мне надо еще раз прописать строку
x = self.gat(x, self.adj)
, но я не понимаю где именно
class AntispoofModel(nn.Module):
def __init__(self, device="cpu", **kwargs):
super().__init__()
resnet = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
self.resnet = nn.Sequential(*[i for i in list(resnet.children())[:-2]]).to(device)
for ch in self.resnet.children():
for param in ch.parameters():
param.requires_grad = False
self.gat = GAT(**kwargs).to(device)
self.device = device
self.adj = torch.tensor(grid_to_graph(3, 3, return_as=np.ndarray)).to(device)
def forward(self, x):
x = self.resnet(x.to(self.device))
x = nn.functional.avg_pool2d(x, 2)
x = x.view(-1, 9, 512)
#adj = torch.stack([self.adj for i in range(x.shape[0])]).to(self.device)
x = self.gat(x, self.adj)
return torch.sigmoid(x)