Compare commits

...

2 Commits

Author SHA1 Message Date
9af36e7145 man python stinks... 2025-12-25 14:54:08 +00:00
23fddbe5b9 updated comment for PairwiseComparator 2025-12-25 14:53:35 +00:00
2 changed files with 3 additions and 3 deletions

View File

@@ -14,7 +14,7 @@ class NumberEmbedder(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
# 3) Comparator head: takes (ea, eb, e) -> logit for "a > b" # MLP Comparator head: takes (ea, eb, e) -> logit for "a > b"
class PairwiseComparator(nn.Module): class PairwiseComparator(nn.Module):
def __init__(self, d=2, hidden=4, k=0.5): def __init__(self, d=2, hidden=4, k=0.5):
super().__init__() super().__init__()

View File

@@ -163,7 +163,7 @@ def training_entry():
with lzma.open(TRAINING_LOG_PATH, mode='wt') as tlog: with lzma.open(TRAINING_LOG_PATH, mode='wt') as tlog:
# training loop # training loop
training_start_time = datetime.datetime.now() training_start_time = datetime.datetime.now()
last_ack = datetime.datetime.now(datetime.timezone.utc) last_ack = datetime.datetime.now()
for step in range(TRAIN_STEPS): for step in range(TRAIN_STEPS):
a, b, y = sample_batch(TRAIN_BATCHSZ) a, b, y = sample_batch(TRAIN_BATCHSZ)
@@ -184,7 +184,7 @@ def training_entry():
# also print to normal text log occasionally to show some activity. # also print to normal text log occasionally to show some activity.
# every 10 steps check if its been longer than 5 seconds since we've updated the user # every 10 steps check if its been longer than 5 seconds since we've updated the user
if step % 10 == 0: if step % 10 == 0:
if (datetime.datetime.now(datetime.timezone.utc) - last_ack).total_seconds() > 5: if (datetime.datetime.now() - last_ack).total_seconds() > 5:
log.info(f"still training... step={step} of {TRAIN_STEPS}") log.info(f"still training... step={step} of {TRAIN_STEPS}")
last_ack = datetime.datetime.now() last_ack = datetime.datetime.now()