#!/usr/bin/env python3
# pairwise_compare.py

import logging
import random
import torch
from torch import nn
from torch.nn import functional as F

# module-level logger so the helper functions can log even when this file is imported
log = logging.getLogger(__name__)

# early pytorch device setup
DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else "cpu"

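# torch.accelerator is only present in fairly recent PyTorch releases; on older
# versions a fallback such as DEVICE = "cuda" if torch.cuda.is_available() else "cpu" would be needed
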
# Valves (tunable hyperparameters)
DIMENSIONS = 1
TRAIN_STEPS = 25000
TRAIN_BATCHSZ = 16384
TRAIN_PROGRESS = 100
BATCH_LOWER = -512.0
BATCH_UPPER = 512.0
DO_VERBOSE_EARLY_TRAIN = False

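# DIMENSIONS is the embedding width d, TRAIN_PROGRESS the logging interval in steps,
# and BATCH_LOWER/BATCH_UPPER the sampling range for training pairs
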
def get_torch_info():
    log.info("PyTorch Version: %s", torch.__version__)
    log.info("HIP Version: %s", torch.version.hip)
    log.info("CUDA support: %s", torch.cuda.is_available())

    if torch.cuda.is_available():
        log.info("CUDA device detected: %s", torch.cuda.get_device_name(0))

    log.info("Using %s compute mode", DEVICE)

def set_seed(seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

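# seeding the Python and torch RNGs is enough for this toy problem; fully deterministic
# runs would also need deterministic kernels (e.g. torch.use_deterministic_algorithms(True))
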
# 1) Data: pairs (a, b) with label y = 1 if a > b else 0
def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER):
    a = (high - low) * torch.rand(batch_size, 1) + low
    b = (high - low) * torch.rand(batch_size, 1) + low

    # label 1.0 when a > b, else 0.0
    y = (a > b).float()

    # removed but left for my notes; training on an explicit equality label seems to hurt
    # classification of near-equal inputs. when trained only on "a > b => y", the model
    # classifies near-equal inputs more accurately (probability ~0.5).
    # eq = (a == b).float()
    # y = gt + 0.5 * eq
    return a, b, y

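# example: a, b, y = sample_batch(4) returns three (4, 1) float tensors,
# with y[i] == 1.0 exactly when a[i] > b[i]
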
# 2) Number "embedding" network: R -> R^d
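# A small MLP, Linear(1, 16) -> ReLU -> Linear(16, d), maps each scalar to a
# d-dimensional vector; with DIMENSIONS = 1 it learns a single feature per number.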
class NumberEmbedder(nn.Module):
    def __init__(self, d=8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU(),
            nn.Linear(16, d),
        )

    def forward(self, x):
        return self.net(x)

# 3) Comparator head: takes (ea, eb) -> logit for "a > b"
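# The head sees both embeddings concatenated with the raw difference a - b, hence the
# 2 * d + 1 input features; the single output is the logit for "a > b".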
class PairwiseComparator(nn.Module):
    def __init__(self, d=8):
        super().__init__()
        self.embed = NumberEmbedder(d)
        self.head = nn.Sequential(
            nn.Linear(2 * d + 1, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        )

    def forward(self, a, b):
        ea = self.embed(a)
        eb = self.embed(b)
        delta_ab = a - b
        x = torch.cat([ea, eb, delta_ab], dim=-1)

        return self.head(x)  # logits

def training_entry():
    # all prng seeds are set to 0 for deterministic outputs during testing;
    # the seed should be initialized normally otherwise
    set_seed(0)

    model = PairwiseComparator(d=DIMENSIONS).to(DEVICE)
    # opt = torch.optim.AdamW(model.parameters(), lr=2e-3)
    opt = torch.optim.Adadelta(model.parameters(), lr=1.0)

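    # Adadelta runs at lr=1.0 (its default); the commented-out AdamW line is left as an
    # alternative optimizer choice, not a requirement of the approach
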
    # 4) Train
    for step in range(TRAIN_STEPS):
        a, b, y = sample_batch(TRAIN_BATCHSZ)
        a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)

        logits = model(a, b)
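        # binary_cross_entropy_with_logits folds the sigmoid into the loss, so the
        # model outputs raw logits and the computation stays numerically stable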
        loss = F.binary_cross_entropy_with_logits(logits, y)

        opt.zero_grad()
        loss.backward()
        opt.step()

        if (DO_VERBOSE_EARLY_TRAIN and step <= TRAIN_PROGRESS) or step % TRAIN_PROGRESS == 0:
            with torch.no_grad():
                pred = (torch.sigmoid(logits) > 0.5).float()
                acc = (pred == y).float().mean().item()
                log.info(f"step={step:5d} loss={loss.item():.7f} acc={acc:.7f}")

    # 5) Quick test: evaluate accuracy on fresh pairs
    with torch.no_grad():
        a, b, y = sample_batch(TRAIN_BATCHSZ)
        a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)
        logits = model(a, b)
        pred = (torch.sigmoid(logits) > 0.5).float()
        errors = (pred != y).sum().item()
        acc = (pred == y).float().mean().item()
        log.info(f"Final test acc: {acc} errors: {errors}")

    # store the embedding dimension alongside the weights so inference can rebuild the same architecture
    torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, "model.pth")
    log.info("Saved PyTorch Model State to model.pth")

def infer_entry():
    model_ckpt = torch.load("model.pth", map_location=DEVICE)
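    # map_location lets a checkpoint saved on one device (e.g. GPU) be loaded onto another (e.g. CPU)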
    model = PairwiseComparator(d=model_ckpt["d"]).to(DEVICE)
    model.load_state_dict(model_ckpt["state_dict"])
    model.eval()

    # sample pairs
    pairs = [(1, 2), (10, 3), (5, 5), (10, 35), (-64, 11), (300, 162), (2, 0), (2, 1), (3, 1), (4, 1), (3, 10), (30, 1), (0, 0), (-162, 237),
             (10, 20), (100, 30), (50, 50), (100, 350), (-640, 110), (30, -420), (200, 0), (92, 5), (30, 17), (42, 10), (30, 100), (30, 1), (0, 400), (-42, -42)]
    a = torch.tensor([[p[0]] for p in pairs], dtype=torch.float32, device=DEVICE)
    b = torch.tensor([[p[1]] for p in pairs], dtype=torch.float32, device=DEVICE)

    # sanity check before inference
    log.info(f"a.device: {a.device} model.device: {next(model.parameters()).device}")

    with torch.no_grad():
        probs = torch.sigmoid(model(a, b))

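    # probabilities near 1.0 mean the model is confident the first number is larger,
    # near 0.0 that it is smaller; ties such as (5, 5) tend to sit near 0.5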
    for (x, y), p in zip(pairs, probs):
        log.info(f"P({x} > {y}) = {p.item():.3f}")

if __name__ == '__main__':
    import sys
    import os
    import datetime

    logging.basicConfig(filename='pairwise_compare.log', level=logging.INFO)
    log.info(f"Log opened {datetime.datetime.now()}")

    get_torch_info()

    name = os.path.basename(sys.argv[0])
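    # dispatch on the invoked filename so the script can also be run via symlinks or
    # renamed copies called train.py / infer.py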
    if name == 'train.py':
        training_entry()
    elif name == 'infer.py':
        infer_entry()
    else:
        # alt call pattern:
        #   python3 pairwise_compare.py train
        #   python3 pairwise_compare.py infer
        if len(sys.argv) > 1:
            mode = sys.argv[1].strip().lower()
            if mode == "train":
                training_entry()
            elif mode == "infer":
                infer_entry()
            else:
                log.error(f"Unknown operation: {mode}")
                log.error("Invalid call syntax: run as \"train.py\" or \"infer.py\", or as pairwise_compare.py <mode> where mode is \"train\" or \"infer\"")
        else:
            log.error("Not enough arguments passed to the script; run as train.py or infer.py, or as pairwise_compare.py <mode> where mode is \"train\" or \"infer\"")

    log.info(f"Log closed {datetime.datetime.now()}")