Jatin-tec
Add application file
65d7391
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
#
# All rights reserved.
# This work should only be used for nonprofit purposes.
#
# By downloading and/or using any of these files, you implicitly agree to all the
# terms of the license, as specified in the document LICENSE.txt
# (included in this package) and online at
# http://www.grip.unina.it/download/LICENSE_OPEN.txt
"""
Created in September 2022
@author: davide.cozzolino
"""
import torch
import torch.nn.functional as F
def weighted_statistics_pooling(x, log_w=None):
    """Pool a feature map into four weighted per-channel statistics.

    Every dimension after the channel dimension is flattened into a single
    "position" axis. For each batch item and channel, four statistics are
    then computed over that axis using softmax-normalised weights:

    * ``x_min`` -- soft minimum, ``-logsumexp(log_w - x)``
    * ``x_max`` -- soft maximum, ``logsumexp(log_w + x)``
    * ``x_avg`` -- weighted mean, ``sum(w * x)``
    * ``x_msq`` -- weighted mean of squares, ``sum(w * x * x)``

    Args:
        x: Tensor whose first two dimensions are batch ``B`` and channels
            ``C``; all remaining dimensions are flattened. Non-contiguous
            tensors (e.g. transposed or sliced views) are accepted.
        log_w: Optional unnormalised log-weights whose first two dimensions
            are ``B`` and ``1`` and whose remaining dimensions flatten to
            the same number of positions as ``x``. They are normalised
            with a log-softmax over the positions. When ``None``, all
            positions are weighted equally.

    Returns:
        Tensor of shape ``(B, 4 * C)`` holding ``x_min``, ``x_max``,
        ``x_avg`` and ``x_msq`` concatenated along dimension 1, in that
        order.

    Raises:
        AssertionError: If ``log_w`` is given and its shape is not
            compatible with ``x`` as described above.
    """
    b = x.shape[0]
    c = x.shape[1]
    # reshape (rather than view) so that non-contiguous inputs are handled;
    # it still returns a view whenever the memory layout allows it.
    x = x.reshape(b, c, -1)

    if log_w is None:
        # Constant log-weights become uniform weights after the log-softmax.
        log_w = torch.zeros((b, 1, x.shape[-1]), device=x.device)
    else:
        assert log_w.shape[0] == b, "log_w batch size must match x"
        assert log_w.shape[1] == 1, "log_w must have a single channel"
        log_w = log_w.reshape(b, 1, -1)
        assert log_w.shape[-1] == x.shape[-1], \
            "log_w must have as many positions as x"

    # Normalise in the log domain so that exp(log_w) sums to 1 per item.
    log_w = F.log_softmax(log_w, dim=-1)

    # Smooth, weight-aware surrogates of min/max, computed stably in the
    # log domain; log_w broadcasts over the channel dimension.
    x_min = -torch.logsumexp(log_w - x, dim=-1)
    x_max = torch.logsumexp(log_w + x, dim=-1)

    w = torch.exp(log_w)
    x_avg = torch.sum(w * x, dim=-1)
    x_msq = torch.sum(w * x * x, dim=-1)

    return torch.cat((x_min, x_max, x_avg, x_msq), dim=1)