Spaces:
Sleeping
Sleeping
| # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
| # Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA'). | |
| # | |
| # All rights reserved. | |
| # This work should only be used for nonprofit purposes. | |
| # | |
| # By downloading and/or using any of these files, you implicitly agree to all the | |
| # terms of the license, as specified in the document LICENSE.txt | |
| # (included in this package) and online at | |
| # http://www.grip.unina.it/download/LICENSE_OPEN.txt | |
| """ | |
| Created in September 2022 | |
| @author: davide.cozzolino | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| def weighted_statistics_pooling(x, log_w=None): | |
| b = x.shape[0] | |
| c = x.shape[1] | |
| x = x.view(b,c,-1) | |
| if log_w is None: | |
| log_w = torch.zeros((b,1,x.shape[-1]), device=x.device) | |
| else: | |
| assert log_w.shape[0]==b | |
| assert log_w.shape[1]==1 | |
| log_w = log_w.view(b,1,-1) | |
| assert log_w.shape[-1]==x.shape[-1] | |
| log_w = F.log_softmax(log_w, dim=-1) | |
| x_min = -torch.logsumexp(log_w-x, dim=-1) | |
| x_max = torch.logsumexp(log_w+x, dim=-1) | |
| w = torch.exp(log_w) | |
| x_avg = torch.sum(w*x , dim=-1) | |
| x_msq = torch.sum(w*x*x, dim=-1) | |
| x = torch.cat((x_min, x_max, x_avg, x_msq), dim=1) | |
| return x |