import torch.nn as nn
import torch


class RMSLELoss(nn.Module):
    """Root mean squared logarithmic error loss.

    Computes ``sqrt(mean((log(1 + pred) - log(1 + actual)) ** 2) + eps)``.

    ``actual`` is reshaped to ``(actual.shape[0], actual.shape[1], 1)`` before
    the comparison. It must therefore have at least two dimensions and exactly
    ``shape[0] * shape[1]`` elements, which matches a ``pred`` of shape
    ``(N, T, 1)``.

    NOTE(review): if ``pred`` has any other shape, ``nn.MSELoss`` broadcasts it
    against the 3-D target (and warns), which silently changes the loss. Confirm
    that callers pass ``(N, T, 1)`` predictions.

    All values of ``pred`` and ``actual`` must be greater than -1. Otherwise the
    logarithm yields NaN.
    """

    def __init__(self, eps: float = 0.0):
        """Create the loss.

        Args:
            eps: Added to the mean squared log error before the square root.
                The default ``0.0`` gives plain RMSLE, unchanged from the
                original behavior. A small positive value (e.g. ``1e-8``)
                keeps gradients finite when the error is exactly zero, because
                the derivative of ``sqrt`` is infinite at 0 and would otherwise
                turn into NaN during backpropagation.
        """
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps

    def forward(self, pred: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
        """Return the scalar RMSLE between ``pred`` and ``actual``.

        Args:
            pred: Predictions; expected shape ``(N, T, 1)`` (see class note).
            actual: Targets with ``N * T`` elements, where ``N`` and ``T`` are
                ``actual.shape[0]`` and ``actual.shape[1]``.

        Returns:
            A 0-dim tensor holding the loss.
        """
        # reshape (unlike view) also works for non-contiguous targets.
        resized_actual = actual.reshape(actual.shape[0], actual.shape[1], 1)
        # log1p(x) == log(1 + x), but it is more accurate for small x.
        msle = self.mse(torch.log1p(pred), torch.log1p(resized_actual))
        return torch.sqrt(msle + self.eps)