# model.py — model definitions, model factory and checkpoint helper (author: Yandi).
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

# Model definitions
import torch.nn as nn

class LinearRegression(nn.Module):
    """Single-output linear layer followed by a ReLU.

    The layer's bias flag is read from ``cfg["LinearRegression"]["Bias"]``.

    Args:
        cfg: Parsed configuration mapping containing a ``"LinearRegression"``
            section with a ``"Bias"`` entry.
        input_size: Number of input features (size of the last dimension of
            the input passed to ``forward``).
    """

    def __init__(self, cfg, input_size):
        super().__init__()
        self.input_size = input_size
        self.bias = cfg["LinearRegression"]["Bias"]
        self.regressor = nn.Linear(input_size, 1, bias=self.bias)
        # NOTE(review): ReLU clamps every prediction to >= 0; presumably the
        # regression target is non-negative — confirm this is intended.
        self.activate = nn.ReLU()

    def forward(self, x):
        """Return ReLU-clamped predictions reshaped to ``(x.shape[0], -1)``."""
        # Flatten everything after the batch dimension; for a 2-D input this
        # yields shape (batch, 1).
        y = self.regressor(x).view((x.shape[0], -1))
        return self.activate(y)

def build_model(cfg, input_size):
    """Instantiate the model class named by ``cfg["Model"]["Name"]``.

    The name is resolved against this module's globals and must refer to an
    ``nn.Module`` subclass. This replaces the previous ``eval`` of an
    f-string, which executed arbitrary expressions taken from the config file.

    Args:
        cfg: Parsed configuration mapping; it is also forwarded to the model
            constructor.
        input_size: Number of input features, forwarded to the constructor.

    Returns:
        A new instance of the named model class, built as ``cls(cfg, input_size)``.

    Raises:
        ValueError: If the name does not refer to an ``nn.Module`` subclass
            defined in (or imported into) this module.
    """
    name = cfg["Model"]["Name"]
    model_cls = globals().get(name)
    # Only accept real module classes, so the config cannot reach arbitrary
    # callables (functions, imported modules, non-model classes).
    if not (isinstance(model_cls, type) and issubclass(model_cls, nn.Module)):
        raise ValueError(f"Unknown model name {name!r} in cfg['Model']['Name']")
    return model_cls(cfg, input_size)

class ModelCheckpoint:
    """Save a model's ``state_dict`` whenever a new lowest loss is reported.

    Args:
        filepath: Destination passed to ``torch.save``; it is overwritten on
            every improvement.
        model: Module whose ``state_dict()`` is saved.
    """

    def __init__(self, filepath, model):
        # Best (lowest) loss seen so far, as a float; None until the first
        # non-NaN update.
        self.min_loss = None
        self.filepath = filepath
        self.model = model

    def update(self, loss):
        """Record ``loss`` and save the model if it is the lowest seen so far.

        Args:
            loss: A Python number or a single-element tensor.

        NaN losses are ignored. Storing a NaN would make every later
        ``loss < min_loss`` comparison False and stop checkpointing for good.
        The loss is stored as a plain float, so a grad-carrying tensor's
        autograd graph is not kept alive by this object.
        """
        import math

        value = float(loss)
        if math.isnan(value):
            return
        if self.min_loss is None or value < self.min_loss:
            print("Saving a better model")
            torch.save(self.model.state_dict(), self.filepath)
            self.min_loss = value

if __name__ == "__main__":
    import yaml

    # safe_load only builds plain YAML types (no arbitrary Python objects).
    # A bare yaml.load() without a Loader raises TypeError on PyYAML >= 6.
    with open("config.yml", "r", encoding="utf-8") as config_file:
        cfg = yaml.safe_load(config_file)
    print(cfg['Model']['Name'])