All files (logs/models/images) go into ./files

Using the AdamW optimizer again; AdamW is the go-to for small toy models and transformers.
Refactored the NN classes into their own module, pairwise_comp_nn.py.
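A minimal sketch of what the optimizer swap amounts to (lr and weight_decay are taken from the diff below; the model here is a hypothetical stand-in):

import torch
from torch import nn

model = nn.Linear(2, 1)  # stand-in for the comparator in pairwise_comp_nn.py

# AdamW decouples weight decay from the gradient update, so the decay
# strength is not rescaled by Adam's per-parameter step sizes
opt = torch.optim.AdamW(model.parameters(), lr=9e-4, weight_decay=1e-3)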
This commit is contained in:
2025-12-19 09:15:41 -05:00
parent cfcec07b9c
commit 0d6a92823a
3 changed files with 302 additions and 294 deletions


@@ -6,6 +6,8 @@ import torch
 from torch import nn
 from torch.nn import functional as F
+import pairwise_comp_nn as comp_nn
+
 # early pytorch device setup
 DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else "cpu"
@@ -17,6 +19,8 @@ TRAIN_PROGRESS = 100
 BATCH_LOWER = -512.0
 BATCH_UPPER = 512.0
 DO_VERBOSE_EARLY_TRAIN = False
+MODEL_PATH = "./files/pwcomp.model"
+LOGGING_PATH = "./files/output.log"

 def get_torch_info():
     log.info("PyTorch Version: %s", torch.__version__)
@@ -47,47 +51,14 @@ def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER):
     # eq = (a == b).float()
     # y = gt + 0.5 * eq
     return a, b, y

-# 2) Number "embedding" network: R -> R^d
-class NumberEmbedder(nn.Module):
-    def __init__(self, d=8):
-        super().__init__()
-        self.net = nn.Sequential(
-            nn.Linear(1, 16),
-            nn.ReLU(),
-            nn.Linear(16, d),
-        )
-
-    def forward(self, x):
-        return self.net(x)
-
-# 3) Comparator head: takes (ea, eb) -> logit for "a > b"
-class PairwiseComparator(nn.Module):
-    def __init__(self, d=8):
-        super().__init__()
-        self.embed = NumberEmbedder(d)
-        self.head = nn.Sequential(
-            nn.Linear(2 * d + 1, 16),
-            nn.ReLU(),
-            nn.Linear(16, 1),
-        )
-
-    def forward(self, a, b):
-        ea = self.embed(a)
-        eb = self.embed(b)
-        delta_ab = a - b
-        x = torch.cat([ea, eb, delta_ab], dim=-1)
-        return self.head(x)  # logits
 def training_entry():
     # all prng seeds to 0 for deterministic outputs during testing
     # the seed should be initialized normally otherwise
     set_seed(0)
-    model = PairwiseComparator(d=DIMENSIONS).to(DEVICE)
-    # opt = torch.optim.AdamW(model.parameters(), lr=2e-3)
-    opt = torch.optim.Adadelta(model.parameters(), lr=1.0)
+    model = comp_nn.PairwiseComparator(d=DIMENSIONS).to(DEVICE)
+    opt = torch.optim.AdamW(model.parameters(), lr=9e-4, weight_decay=1e-3)

     # 4) Train
     for step in range(TRAIN_STEPS):
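The classes removed above presumably land in pairwise_comp_nn.py more or less verbatim; a sketch of that module under that assumption:

# pairwise_comp_nn.py -- sketch, assuming the classes moved unchanged
import torch
from torch import nn

# 2) Number "embedding" network: R -> R^d
class NumberEmbedder(nn.Module):
    def __init__(self, d=8):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, d))

    def forward(self, x):
        return self.net(x)

# 3) Comparator head: takes (ea, eb) -> logit for "a > b"
class PairwiseComparator(nn.Module):
    def __init__(self, d=8):
        super().__init__()
        self.embed = NumberEmbedder(d)
        self.head = nn.Sequential(nn.Linear(2 * d + 1, 16), nn.ReLU(), nn.Linear(16, 1))

    def forward(self, a, b):
        ea, eb = self.embed(a), self.embed(b)
        x = torch.cat([ea, eb, a - b], dim=-1)  # both embeddings plus the raw difference
        return self.head(x)  # logit for "a > b"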
@@ -112,7 +83,7 @@ def training_entry():
            acc = (pred == y).float().mean().item()
            log.info(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}")

-    # 5) Quick test: evaluate accuracy on fresh pairs
+    # 5) Quick test: evaluate final model accuracy on fresh pairs
     with torch.no_grad():
         a, b, y = sample_batch(TRAIN_BATCHSZ)
         a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)
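For context, pred above is presumably the logits thresholded at zero, which is the same as sigmoid(logit) > 0.5; a sketch, with logits and y as assumed names:

logits = model(a, b)                     # raw comparator outputs
pred = (logits > 0).float()              # equivalent to torch.sigmoid(logits) > 0.5
acc = (pred == y).float().mean().item()  # fraction of pairs ordered correctly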
@@ -122,13 +93,13 @@ def training_entry():
         acc = (pred == y).float().mean().item()
         log.info(f"Final test acc: {acc} errors: {errors}")

-    # embed model depth into the model serialization
-    torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, "model.pth")
+    # embed model dimensions into the model serialization
+    torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, MODEL_PATH)
     log.info("Saved PyTorch Model State to model.pth")

 def infer_entry():
-    model_ckpt = torch.load("model.pth", map_location=DEVICE)
-    model = PairwiseComparator(d=model_ckpt["d"]).to(DEVICE)
+    model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
+    model = comp_nn.PairwiseComparator(d=model_ckpt["d"]).to(DEVICE)
     model.load_state_dict(model_ckpt["state_dict"])
     model.eval()
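Bundling the embedding width d into the checkpoint means infer_entry() can rebuild the exact architecture without hardcoding it; the round trip, as the diff has it:

# save: store the hyperparameter needed to reconstruct the module
torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, MODEL_PATH)

# load: rebuild with the stored width, then restore the weights
ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
model = comp_nn.PairwiseComparator(d=ckpt["d"]).to(DEVICE)
model.load_state_dict(ckpt["state_dict"])
model.eval()  # switch off training-mode behavior before inference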
@@ -139,7 +110,7 @@ def infer_entry():
     b = torch.tensor([[p[1]] for p in pairs], dtype=torch.float32, device=DEVICE)

     # sanity check before inference
-    log.info(f"a.device: {a.device} model.device: {next(model.parameters()).device}")
+    log.debug(f"a.device: {a.device} model.device: {next(model.parameters()).device}")

     with torch.no_grad():
         probs = torch.sigmoid(model(a, b))
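Usage looks roughly like this, reading each output row as P(a > b); the pairs values are assumed examples, since the real ones are elided from the hunk:

pairs = [(3.0, -7.5), (100.0, 250.0)]  # assumed example inputs
a = torch.tensor([[p[0]] for p in pairs], dtype=torch.float32, device=DEVICE)
b = torch.tensor([[p[1]] for p in pairs], dtype=torch.float32, device=DEVICE)
with torch.no_grad():
    probs = torch.sigmoid(model(a, b))  # each row is the model's P(a > b)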
@@ -152,10 +123,13 @@ if __name__ == '__main__':
     import os
     import datetime

+    if not os.path.exists("./files/"):
+        os.mkdir("./files")
+
     log = logging.getLogger(__name__)
-    logging.basicConfig(filename='pairwise_compare.log', level=logging.INFO)
+    logging.basicConfig(filename=LOGGING_PATH, level=logging.INFO)
     log.info(f"Log opened {datetime.datetime.now()}")

     get_torch_info()
     name = os.path.basename(sys.argv[0])
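Since every artifact now lives under ./files, the directory setup can be condensed; a sketch using os.makedirs, which unlike os.mkdir also tolerates an existing directory and nested paths:

import logging
import os

os.makedirs("./files", exist_ok=True)  # no need for the exists() check
logging.basicConfig(filename="./files/output.log", level=logging.INFO)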