all files (logs/models/images) go into ./files
using AdamW optim again, AdamW is the go to for small toys and transformers. refactored NN classes to thier own module under pairwise_comp_nn.py
This commit is contained in:
@@ -6,6 +6,8 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
import pairwise_comp_nn as comp_nn
|
||||
|
||||
# early pytorch device setup
|
||||
DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else "cpu"
|
||||
|
||||
@@ -17,6 +19,8 @@ TRAIN_PROGRESS = 100
|
||||
BATCH_LOWER = -512.0
|
||||
BATCH_UPPER = 512.0
|
||||
DO_VERBOSE_EARLY_TRAIN = False
|
||||
MODEL_PATH = "./files/pwcomp.model"
|
||||
LOGGING_PATH = "./files/output.log"
|
||||
|
||||
def get_torch_info():
|
||||
log.info("PyTorch Version: %s", torch.__version__)
|
||||
@@ -47,47 +51,14 @@ def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER):
|
||||
# eq = (a == b).float()
|
||||
# y = gt + 0.5 * eq
|
||||
return a, b, y
|
||||
|
||||
# 2) Number "embedding" network: R -> R^d
|
||||
class NumberEmbedder(nn.Module):
|
||||
def __init__(self, d=8):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(1, 16),
|
||||
nn.ReLU(),
|
||||
nn.Linear(16, d),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
# 3) Comparator head: takes (ea, eb) -> logit for "a > b"
|
||||
class PairwiseComparator(nn.Module):
|
||||
def __init__(self, d=8):
|
||||
super().__init__()
|
||||
self.embed = NumberEmbedder(d)
|
||||
self.head = nn.Sequential(
|
||||
nn.Linear(2 * d + 1, 16),
|
||||
nn.ReLU(),
|
||||
nn.Linear(16, 1),
|
||||
)
|
||||
|
||||
def forward(self, a, b):
|
||||
ea = self.embed(a)
|
||||
eb = self.embed(b)
|
||||
delta_ab = a - b
|
||||
x = torch.cat([ea, eb, delta_ab], dim=-1)
|
||||
|
||||
return self.head(x) # logits
|
||||
|
||||
def training_entry():
|
||||
# all prng seeds to 0 for deterministic outputs durring testing
|
||||
# the seed should initialized normally otherwise
|
||||
set_seed(0)
|
||||
|
||||
model = PairwiseComparator(d=DIMENSIONS).to(DEVICE)
|
||||
# opt = torch.optim.AdamW(model.parameters(), lr=2e-3)
|
||||
opt = torch.optim.Adadelta(model.parameters(), lr=1.0)
|
||||
model = comp_nn.PairwiseComparator(d=DIMENSIONS).to(DEVICE)
|
||||
opt = torch.optim.AdamW(model.parameters(), lr=9e-4, weight_decay=1e-3)
|
||||
|
||||
# 4) Train
|
||||
for step in range(TRAIN_STEPS):
|
||||
@@ -112,7 +83,7 @@ def training_entry():
|
||||
acc = (pred == y).float().mean().item()
|
||||
log.info(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}")
|
||||
|
||||
# 5) Quick test: evaluate accuracy on fresh pairs
|
||||
# 5) Quick test: evaluate final model accuracy on fresh pairs
|
||||
with torch.no_grad():
|
||||
a, b, y = sample_batch(TRAIN_BATCHSZ)
|
||||
a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)
|
||||
@@ -122,13 +93,13 @@ def training_entry():
|
||||
acc = (pred == y).float().mean().item()
|
||||
log.info(f"Final test acc: {acc} errors: {errors}")
|
||||
|
||||
# embed model depth into the model serialization
|
||||
torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, "model.pth")
|
||||
# embed model dimensions into the model serialization
|
||||
torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, MODEL_PATH)
|
||||
log.info("Saved PyTorch Model State to model.pth")
|
||||
|
||||
def infer_entry():
|
||||
model_ckpt = torch.load("model.pth", map_location=DEVICE)
|
||||
model = PairwiseComparator(d=model_ckpt["d"]).to(DEVICE)
|
||||
model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
|
||||
model = comp_nn.PairwiseComparator(d=model_ckpt["d"]).to(DEVICE)
|
||||
model.load_state_dict(model_ckpt["state_dict"])
|
||||
model.eval()
|
||||
|
||||
@@ -139,7 +110,7 @@ def infer_entry():
|
||||
b = torch.tensor([[p[1]] for p in pairs], dtype=torch.float32, device=DEVICE)
|
||||
|
||||
# sanity check before inference
|
||||
log.info(f"a.device: {a.device} model.device: {next(model.parameters()).device}")
|
||||
log.debug(f"a.device: {a.device} model.device: {next(model.parameters()).device}")
|
||||
|
||||
with torch.no_grad():
|
||||
probs = torch.sigmoid(model(a, b))
|
||||
@@ -152,10 +123,13 @@ if __name__ == '__main__':
|
||||
import os
|
||||
import datetime
|
||||
|
||||
if not os.path.exists("./files/"):
|
||||
os.mkdir("./files")
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logging.basicConfig(filename='pairwise_compare.log', level=logging.INFO)
|
||||
logging.basicConfig(filename=LOGGING_PATH, level=logging.INFO)
|
||||
|
||||
log.info(f"Log opened {datetime.datetime.now()}")
|
||||
|
||||
get_torch_info()
|
||||
|
||||
name = os.path.basename(sys.argv[0])
|
||||
|
||||
Reference in New Issue
Block a user