import torch.nn as nn
import torch

class RMSLELoss(nn.Module):
    """Root mean squared logarithmic error (RMSLE) loss.

    Computes ``sqrt(mean((log(1 + pred) - log(1 + actual)) ** 2))`` with the
    mean taken over every element (``nn.MSELoss`` default reduction).
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, pred, actual):
        """Return the RMSLE between ``pred`` and ``actual`` as a scalar tensor.

        Args:
            pred: Tensor of predictions. Values must be greater than -1 for
                the logarithm to be finite.
            actual: Tensor of targets. If its shape already equals
                ``pred.shape`` it is compared elementwise. Otherwise it must
                have at least two dimensions and is viewed as
                ``(actual.shape[0], actual.shape[1], 1)`` so it lines up with
                a prediction that carries a trailing singleton axis.

        Returns:
            A zero-dimensional tensor holding the loss.
        """
        if actual.shape == pred.shape:
            # Shapes already agree: adding a trailing axis here would make
            # MSELoss broadcast into a cross-product instead of comparing
            # matching elements.
            target = actual
        else:
            target = actual.view(actual.shape[0], actual.shape[1], 1)

        # log1p keeps precision for values near zero, where computing
        # log(x + 1) first rounds x + 1 to exactly 1.
        return torch.sqrt(self.mse(torch.log1p(pred), torch.log1p(target)))