diff --git a/models/choose_optimizer.py b/models/choose_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1504fa7ae702a278d55a1add628f1e6331ed1741 --- /dev/null +++ b/models/choose_optimizer.py @@ -0,0 +1,5 @@ +import torch.optim + +def optimizer(cfg): + result = {"Adam" : torch.optim.Adam(model.parameters())} + return result[cfg["Optimizer"]]