renamed ADAMW_WiDECAY
This commit is contained in:
@@ -20,7 +20,7 @@ DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_availab
|
|||||||
DIMENSIONS = 2
|
DIMENSIONS = 2
|
||||||
HIDDEN_NEURONS = 4
|
HIDDEN_NEURONS = 4
|
||||||
ADAMW_LR = 5e-3
|
ADAMW_LR = 5e-3
|
||||||
ADAMW_WiDECAY = 5e-4
|
ADAMW_DECAY = 5e-4
|
||||||
TRAIN_STEPS = 2000
|
TRAIN_STEPS = 2000
|
||||||
TRAIN_BATCHSZ = 8192
|
TRAIN_BATCHSZ = 8192
|
||||||
TRAIN_PROGRESS = 10
|
TRAIN_PROGRESS = 10
|
||||||
@@ -157,13 +157,13 @@ def training_entry():
|
|||||||
set_seed(0)
|
set_seed(0)
|
||||||
|
|
||||||
model = comp_nn.PairwiseComparator(d=DIMENSIONS, hidden=HIDDEN_NEURONS).to(DEVICE)
|
model = comp_nn.PairwiseComparator(d=DIMENSIONS, hidden=HIDDEN_NEURONS).to(DEVICE)
|
||||||
opt = torch.optim.AdamW(model.parameters(), lr=ADAMW_LR, weight_decay=ADAMW_WiDECAY)
|
opt = torch.optim.AdamW(model.parameters(), lr=ADAMW_LR, weight_decay=ADAMW_DECAY)
|
||||||
|
|
||||||
log.info(f"Using {TRAINING_LOG_PATH} as the logging destination for training...")
|
log.info(f"Using {TRAINING_LOG_PATH} as the logging destination for training...")
|
||||||
with lzma.open(TRAINING_LOG_PATH, mode='wt') as tlog:
|
with lzma.open(TRAINING_LOG_PATH, mode='wt') as tlog:
|
||||||
# training loop
|
# training loop
|
||||||
training_start_time = datetime.datetime.now()
|
training_start_time = datetime.datetime.now()
|
||||||
last_ack = datetime.datetime.now()
|
last_ack = datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
|
||||||
for step in range(TRAIN_STEPS):
|
for step in range(TRAIN_STEPS):
|
||||||
a, b, y = sample_batch(TRAIN_BATCHSZ)
|
a, b, y = sample_batch(TRAIN_BATCHSZ)
|
||||||
@@ -182,9 +182,9 @@ def training_entry():
|
|||||||
tlog.write(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}\n")
|
tlog.write(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}\n")
|
||||||
|
|
||||||
# also print to normal text log occasionally to show some activity.
|
# also print to normal text log occasionally to show some activity.
|
||||||
# every 100 steps check if its been longer than 5 seconds since we've updated the user
|
# every 10 steps check if its been longer than 5 seconds since we've updated the user
|
||||||
if step % 100 == 0:
|
if step % 10 == 0:
|
||||||
if (datetime.datetime.now() - last_ack).total_seconds() > 5:
|
if (datetime.datetime.now(datetime.timezone.utc) - last_ack).total_seconds() > 5:
|
||||||
log.info(f"still training... step={step} of {TRAIN_STEPS}")
|
log.info(f"still training... step={step} of {TRAIN_STEPS}")
|
||||||
last_ack = datetime.datetime.now()
|
last_ack = datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user