diff --git a/model/attention/PSA.py b/model/attention/PSA.py
index ba2a1eb..28eac07 100644
--- a/model/attention/PSA.py
+++ b/model/attention/PSA.py
@@ -4,28 +4,38 @@ from torch.nn import init
 
 
 
-
 class PSA(nn.Module):
 
-    def __init__(self, channel=512,reduction=4,S=4):
+    def __init__(self, channel=512, reduction=4, S=4):
         super().__init__()
-        self.S=S
+        self.S = S
 
-        self.convs=[]
-        for i in range(S):
-            self.convs.append(nn.Conv2d(channel//S,channel//S,kernel_size=2*(i+1)+1,padding=i+1))
+        self.convs = nn.ModuleList(
+            [nn.Conv2d(channel // S, channel // S, kernel_size=2 * (i + 1) + 1, padding=(i + 1)) for i in range(S)])
+        # self.convs=[]
+        # for i in range(S):
+        #     self.convs.append(nn.Conv2d(channel//S,channel//S,kernel_size=2*(i+1)+1,padding=i+1))
 
-        self.se_blocks=[]
-        for i in range(S):
-            self.se_blocks.append(nn.Sequential(
+        self.se_blocks = nn.ModuleList(
+            nn.Sequential(
                 nn.AdaptiveAvgPool2d(1),
-                nn.Conv2d(channel//S, channel // (S*reduction),kernel_size=1, bias=False),
+                nn.Conv2d(channel // S, channel // (S * reduction), kernel_size=1, bias=False),
                 nn.ReLU(inplace=True),
-                nn.Conv2d(channel // (S*reduction), channel//S,kernel_size=1, bias=False),
+                nn.Conv2d(channel // (S * reduction), channel // S, kernel_size=1, bias=False),
                 nn.Sigmoid()
-            ))
-
-        self.softmax=nn.Softmax(dim=1)
+            ) for i in range(S)
+        )
+        # self.se_blocks=[]
+        # for i in range(S):
+        #     self.se_blocks.append(nn.Sequential(
+        #         nn.AdaptiveAvgPool2d(1),
+        #         nn.Conv2d(channel//S, channel // (S*reduction),kernel_size=1, bias=False),
+        #         nn.ReLU(inplace=True),
+        #         nn.Conv2d(channel // (S*reduction), channel//S,kernel_size=1, bias=False),
+        #         nn.Sigmoid()
+        #     ))
+
+        self.softmax = nn.Softmax(dim=1)
 
 
     def init_weights(self):
@@ -68,11 +78,10 @@ def forward(self, x):
 
 
 if __name__ == '__main__':
-    input=torch.randn(50,512,7,7)
-    psa = PSA(channel=512,reduction=8)
-    output=psa(input)
-    a=output.view(-1).sum()
+    device = torch.device('cuda')
+    input = torch.randn(8, 512, 7, 7).to(device)
+    psa = PSA(channel=512, reduction=8).to(device)
+    output = psa(input)
+    a = output.view(-1).sum()
     a.backward()
     print(output.shape)
-
-
\ No newline at end of file
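Note on the change above: in PyTorch, submodules stored in a plain Python list are invisible to the module system. Their parameters do not appear in model.parameters() (so an optimizer never sees them), they are missing from state_dict(), and model.to(device) does not move them. That is why the __main__ smoke test could only be switched to CUDA once self.convs and self.se_blocks became nn.ModuleLists: with plain lists, PSA(...).to(device) would leave the conv and SE weights on the CPU and the forward pass would fail with a device mismatch. A minimal sketch of the difference (the Broken/Fixed class names are illustrative, not part of the patch):

    import torch
    from torch import nn

    class Broken(nn.Module):
        def __init__(self):
            super().__init__()
            # Plain list: the conv is NOT registered as a submodule.
            self.convs = [nn.Conv2d(4, 4, kernel_size=3, padding=1)]

    class Fixed(nn.Module):
        def __init__(self):
            super().__init__()
            # nn.ModuleList registers the conv, so .parameters(), .to()
            # and state_dict() all see it.
            self.convs = nn.ModuleList([nn.Conv2d(4, 4, kernel_size=3, padding=1)])

    print(len(list(Broken().parameters())))  # 0 -- an optimizer would never update the conv
    print(len(list(Fixed().parameters())))   # 2 -- weight and bias are visible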