Compare commits
2 Commits
acbccebb2c
...
9af36e7145
| Author | SHA1 | Date | |
|---|---|---|---|
| 9af36e7145 | |||
| 23fddbe5b9 |
@@ -14,7 +14,7 @@ class NumberEmbedder(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# 3) Comparator head: takes (ea, eb, e) -> logit for "a > b"
|
||||
# MLP Comparator head: takes (ea, eb, e) -> logit for "a > b"
|
||||
class PairwiseComparator(nn.Module):
|
||||
def __init__(self, d=2, hidden=4, k=0.5):
|
||||
super().__init__()
|
||||
|
||||
@@ -163,7 +163,7 @@ def training_entry():
|
||||
with lzma.open(TRAINING_LOG_PATH, mode='wt') as tlog:
|
||||
# training loop
|
||||
training_start_time = datetime.datetime.now()
|
||||
last_ack = datetime.datetime.now(datetime.timezone.utc)
|
||||
last_ack = datetime.datetime.now()
|
||||
|
||||
for step in range(TRAIN_STEPS):
|
||||
a, b, y = sample_batch(TRAIN_BATCHSZ)
|
||||
@@ -184,7 +184,7 @@ def training_entry():
|
||||
# also print to normal text log occasionally to show some activity.
|
||||
# every 10 steps check if its been longer than 5 seconds since we've updated the user
|
||||
if step % 10 == 0:
|
||||
if (datetime.datetime.now(datetime.timezone.utc) - last_ack).total_seconds() > 5:
|
||||
if (datetime.datetime.now() - last_ack).total_seconds() > 5:
|
||||
log.info(f"still training... step={step} of {TRAIN_STEPS}")
|
||||
last_ack = datetime.datetime.now()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user