Renamed ADAMW_WiDECAY to ADAMW_DECAY

This commit is contained in:
2025-12-25 14:48:07 +00:00
parent edf8d46123
commit acbccebb2c

View File

@@ -20,7 +20,7 @@ DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_availab
DIMENSIONS = 2
HIDDEN_NEURONS = 4
ADAMW_LR = 5e-3
ADAMW_WiDECAY = 5e-4
ADAMW_DECAY = 5e-4
TRAIN_STEPS = 2000
TRAIN_BATCHSZ = 8192
TRAIN_PROGRESS = 10
@@ -157,13 +157,13 @@ def training_entry():
set_seed(0)
model = comp_nn.PairwiseComparator(d=DIMENSIONS, hidden=HIDDEN_NEURONS).to(DEVICE)
opt = torch.optim.AdamW(model.parameters(), lr=ADAMW_LR, weight_decay=ADAMW_WiDECAY)
opt = torch.optim.AdamW(model.parameters(), lr=ADAMW_LR, weight_decay=ADAMW_DECAY)
log.info(f"Using {TRAINING_LOG_PATH} as the logging destination for training...")
with lzma.open(TRAINING_LOG_PATH, mode='wt') as tlog:
# training loop
training_start_time = datetime.datetime.now()
last_ack = datetime.datetime.now()
last_ack = datetime.datetime.now(datetime.timezone.utc)
for step in range(TRAIN_STEPS):
a, b, y = sample_batch(TRAIN_BATCHSZ)
@@ -182,9 +182,9 @@ def training_entry():
tlog.write(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}\n")
# also print to normal text log occasionally to show some activity.
# every 100 steps, check if it's been longer than 5 seconds since we've updated the user
if step % 100 == 0:
if (datetime.datetime.now() - last_ack).total_seconds() > 5:
# every 10 steps, check if it's been longer than 5 seconds since we've updated the user
if step % 10 == 0:
if (datetime.datetime.now(datetime.timezone.utc) - last_ack).total_seconds() > 5:
log.info(f"still training... step={step} of {TRAIN_STEPS}")
last_ack = datetime.datetime.now()