all files (logs/models/images) go into ./files

using AdamW optim again, AdamW is the go to for small toys and transformers.
refactored NN classes to thier own module under pairwise_comp_nn.py
This commit is contained in:
2025-12-19 09:15:41 -05:00
parent cfcec07b9c
commit 0d6a92823a
3 changed files with 302 additions and 294 deletions

34
pairwise_comp_nn.py Normal file
View File

@@ -0,0 +1,34 @@
import torch
from torch import nn
# 2) Number "embedding" network: R -> R^d
class NumberEmbedder(nn.Module):
def __init__(self, d=8):
super().__init__()
self.net = nn.Sequential(
nn.Linear(1, 16),
nn.ReLU(),
nn.Linear(16, d),
)
def forward(self, x):
return self.net(x)
# 3) Comparator head: takes (ea, eb) -> logit for "a > b"
class PairwiseComparator(nn.Module):
def __init__(self, d=8):
super().__init__()
self.embed = NumberEmbedder(d)
self.head = nn.Sequential(
nn.Linear(2 * d + 1, 16),
nn.ReLU(),
nn.Linear(16, 1),
)
def forward(self, a, b):
ea = self.embed(a)
eb = self.embed(b)
delta_ab = a - b
x = torch.cat([ea, eb, delta_ab], dim=-1)
return self.head(x) # logits