testing adadelta optimizer
This commit is contained in:
@@ -11,12 +11,12 @@ DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_availab
|
|||||||
|
|
||||||
# Valves
|
# Valves
|
||||||
DIMENSIONS = 1
|
DIMENSIONS = 1
|
||||||
TRAIN_STEPS = 20000
|
TRAIN_STEPS = 25000
|
||||||
TRAIN_BATCHSZ = 16384
|
TRAIN_BATCHSZ = 16384
|
||||||
TRAIN_PROGRESS = 500
|
TRAIN_PROGRESS = 100
|
||||||
BATCH_LOWER = -512.0
|
BATCH_LOWER = -512.0
|
||||||
BATCH_UPPER = 512.0
|
BATCH_UPPER = 512.0
|
||||||
DO_VERBOSE_EARLY_TRAIN = True
|
DO_VERBOSE_EARLY_TRAIN = False
|
||||||
|
|
||||||
def get_torch_info():
|
def get_torch_info():
|
||||||
log.info("PyTorch Version: %s", torch.__version__)
|
log.info("PyTorch Version: %s", torch.__version__)
|
||||||
@@ -86,7 +86,8 @@ def training_entry():
|
|||||||
set_seed(0)
|
set_seed(0)
|
||||||
|
|
||||||
model = PairwiseComparator(d=DIMENSIONS).to(DEVICE)
|
model = PairwiseComparator(d=DIMENSIONS).to(DEVICE)
|
||||||
opt = torch.optim.AdamW(model.parameters(), lr=2e-3)
|
# opt = torch.optim.AdamW(model.parameters(), lr=2e-3)
|
||||||
|
opt = torch.optim.Adadelta(model.parameters(), lr=1.0)
|
||||||
|
|
||||||
# 4) Train
|
# 4) Train
|
||||||
for step in range(TRAIN_STEPS):
|
for step in range(TRAIN_STEPS):
|
||||||
|
|||||||
Reference in New Issue
Block a user