Files
mltoys/pairwise_compare.py
Elaina Claus cfef24921d moved output_graphs into main script and restructured
training log is separate from normal output & is compressed
slightly adjusted lr
made final test stage test 4xTRAIN_BATCHSZ number of samples
2025-12-23 10:18:02 -05:00

255 lines
8.6 KiB
Python
Executable File

#!/usr/bin/env python3
# pairwise_compare.py
import logging
import random
import lzma
import torch
from torch import nn
from torch.nn import functional as F
import re
import pandas as pd
import matplotlib.pyplot as plt
import pairwise_comp_nn as comp_nn
# early pytorch device setup
# Use the current accelerator when torch reports one, otherwise fall back to "cpu".
DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else "cpu"
# Valves
DIMENSIONS = 2  # passed as `d` when training_entry() builds comp_nn.PairwiseComparator
TRAIN_STEPS = 5000  # number of iterations of the training loop
TRAIN_BATCHSZ = 8192  # pairs sampled per training step; the final test uses 4x this
TRAIN_PROGRESS = 10  # a line is written to the training log every this many steps
BATCH_LOWER = -100.0  # lower bound for sampled values and for the embeddings sweep
BATCH_UPPER = 100.0  # upper bound for sampled values; the sweep runs to BATCH_UPPER + 1.0
DO_VERBOSE_EARLY_TRAIN = False  # NOTE(review): not referenced anywhere in the visible code
# Files
MODEL_PATH = "./files/pwcomp.model"  # checkpoint written by training_entry(), read by infer_entry()
LOGGING_PATH = "./files/output.log"  # destination file for logging.basicConfig
EMBED_CHART_PATH = "./files/embedding_chart.png"  # chart saved by plt_embeddings()
EMBEDDINGS_DATA_PATH = "./files/embedding_data.csv"  # CSV saved by plt_embeddings()
TRAINING_LOG_PATH = "./files/training.log.xz"  # xz-compressed per-step training log
def parse_training_log(file_path: str) -> pd.DataFrame:
    """Load an xz-compressed training log into a DataFrame.

    Every ``step=<int> loss=<float> acc=<float>`` record found in the file
    becomes one row; text that does not match the pattern is ignored.

    Args:
        file_path: Path to the lzma/xz-compressed text log.

    Returns:
        A frame sorted by ``step`` with columns ``step``, ``loss``, ``acc``
        and ``loss_clamped`` (``loss`` floored at 1e-10).
    """
    with lzma.open(file_path, mode='rt') as handle:
        contents = handle.read()

    record_re = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)")
    records = [
        (int(match.group(1)), float(match.group(2)), float(match.group(3)))
        for match in record_re.finditer(contents)
    ]

    frame = pd.DataFrame(records, columns=["step", "loss", "acc"])
    frame = frame.sort_values("step").reset_index(drop=True)

    # A loss of exactly 0 cannot be drawn on a log axis, so floor it at a
    # tiny positive value for plotting purposes.
    floor = 1e-10
    frame["loss_clamped"] = frame["loss"].clip(lower=floor)
    return frame
def plt_loss_tstep(df: pd.DataFrame) -> None:
    """Save a log-scale chart of training loss against step.

    Reads the ``step`` and ``loss_clamped`` columns of *df* and writes the
    figure to ``./files/training_loss_v_step.png``.
    """
    # Plot 1: Loss
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(df["step"], df["loss_clamped"])
    ax.set_yscale("log")
    ax.set_xlabel("Step")
    ax.set_ylabel("Loss (log scale)")
    ax.set_title("Training Loss vs Step")
    fig.tight_layout()
    fig.savefig('./files/training_loss_v_step.png')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def plt_acc_tstep(df: pd.DataFrame, eps: float = 1e-10) -> None:
    """Save a log-scale chart of training error rate (1 - accuracy) against step.

    Reads the ``step`` and ``acc`` columns of *df* and writes the figure to
    ``./files/training_error_v_step.png``. *df* itself is left unmodified.

    Args:
        df: Frame holding at least the ``step`` and ``acc`` columns.
        eps: Floor applied to the error rate so that a perfect accuracy of
            1.0 can still be drawn on the log axis.
    """
    # Plot 2: Accuracy
    # Keep the error rate in a local Series rather than writing a new "err"
    # column into the caller's DataFrame.
    err = (1.0 - df["acc"]).clip(lower=eps)
    plt.figure(figsize=(8, 4))
    plt.plot(df["step"], err)
    plt.yscale("log")
    plt.xlabel("Step")
    plt.ylabel("Error rate (1 - accuracy) (log scale)")
    plt.title("Training Error Rate vs Step")
    plt.tight_layout()
    plt.savefig('./files/training_error_v_step.png')
    plt.close()
    return None
def plt_embeddings(model: comp_nn.PairwiseComparator) -> None:
    """Chart and export the model's embedding of a sweep of scalar inputs.

    Inputs run from BATCH_LOWER up to (but excluding) BATCH_UPPER + 1.0 in
    steps of 0.1. One line per embedding dimension is drawn and saved to
    EMBED_CHART_PATH, and the raw values are written to EMBEDDINGS_DATA_PATH
    as CSV rows of the form ``x, dim0, dim1, ...`` (no header row).

    Relies on the module-level ``log`` logger that the ``__main__`` block
    creates.

    Args:
        model: Comparator whose ``embed`` method maps an (N, 1) tensor to an
            (N, d) tensor.
    """
    import csv

    log.info("Starting embeddings sweep...")
    # samples for embedding mapping
    with torch.no_grad():
        xs = torch.arange(
            BATCH_LOWER,
            BATCH_UPPER + 1.0,
            0.1,
        ).unsqueeze(1).to(DEVICE)  # shape: (N, 1)
        embeddings = model.embed(xs)  # shape: (N, d)

    # move data back to CPU for plotting
    embeddings = embeddings.cpu()
    x_values = xs.cpu().squeeze(1)

    # Draw on a dedicated figure instead of whatever figure pyplot currently
    # considers active, so stale artists can never leak into this chart.
    fig, ax = plt.subplots()
    for dim in range(embeddings.shape[1]):
        ax.plot(x_values, embeddings[:, dim], label=f"dim {dim}")
    ax.legend()
    fig.savefig(EMBED_CHART_PATH)
    plt.close(fig)

    # save all our embeddings data to csv
    # Flatten each row to [x, e0, e1, ...]; writing (x, [e0, e1, ...]) would
    # put the whole embedding into one cell as a stringified Python list.
    rows = [
        [x, *vector]
        for x, vector in zip(x_values.tolist(), embeddings.tolist())
    ]
    with open(file=EMBEDDINGS_DATA_PATH, mode="w", newline='') as f:
        csv.writer(f).writerows(rows)
    return None
def get_torch_info() -> None:
    """Write the PyTorch version, backend details and chosen device to the log.

    Relies on the module-level ``log`` logger that the ``__main__`` block
    creates.
    """
    cuda_ready = torch.cuda.is_available()
    report = (
        ("PyTorch Version: %s", torch.__version__),
        ("HIP Version: %s", torch.version.hip),
        ("CUDA support: %s", cuda_ready),
    )
    for template, value in report:
        log.info(template, value)

    # The device name is only queried when CUDA is actually usable.
    if cuda_ready:
        log.info("CUDA device detected: %s", torch.cuda.get_device_name(0))
    log.info("Using %s compute mode", DEVICE)
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and PyTorch with the same value.

    The CUDA generators are seeded as well when CUDA is available.

    Args:
        seed: Value handed to every seeding function.
    """
    seeders = [random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)
# Data: pairs (a, b) with label y = 1 if a > b, 0 if a < b, 0.5 if within epsi -> (a,b,y)
def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER, epsi=1e-4) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Draw a random batch of scalar pairs together with comparison labels.

    Args:
        batch_size: Number of pairs to draw.
        low: Lower bound of the uniform sampling range.
        high: Upper bound of the uniform sampling range.
        epsi: Tolerance around equality; pairs that differ by no more than
            this are labelled as ties.

    Returns:
        ``(a, b, y)``, each of shape ``(batch_size, 1)``. ``y`` is 1.0 where
        ``a > b + epsi``, 0.0 where ``a < b - epsi`` and 0.5 otherwise.
    """
    span = high - low
    a = span * torch.rand(batch_size, 1) + low
    b = span * torch.rand(batch_size, 1) + low
    # Use the caller-supplied tolerance. It used to be overwritten with a
    # hard-coded 1e-4 here, which made the parameter a no-op.
    y = torch.where(a > b + epsi, 1.0,
                    torch.where(a < b - epsi, 0.0, 0.5))
    return a, b, y
def training_entry() -> None:
    """Train the pairwise comparator, evaluate it once and save a checkpoint.

    Runs TRAIN_STEPS optimisation steps on freshly sampled batches, writing a
    ``step=... loss=... acc=...`` line to the xz-compressed file at
    TRAINING_LOG_PATH every TRAIN_PROGRESS steps. Afterwards the model is
    scored on 4 * TRAIN_BATCHSZ new pairs and stored at MODEL_PATH together
    with its dimension count.

    Relies on the module-level ``log`` logger that the ``__main__`` block
    creates.
    """
    # Imported at function scope so the timing code does not depend on the
    # import that only happens inside the ``__main__`` guard.
    import datetime

    # All PRNG seeds are fixed to 0 for deterministic outputs during testing;
    # the seed should be initialized normally otherwise.
    set_seed(0)
    model = comp_nn.PairwiseComparator(d=DIMENSIONS).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=8e-4, weight_decay=1e-3)
    log.info(f"Using {TRAINING_LOG_PATH} as the logging destination for training...")
    with lzma.open(TRAINING_LOG_PATH, mode='wt') as tlog:
        # training loop
        training_start_time = datetime.datetime.now()
        for step in range(TRAIN_STEPS):
            a, b, y = sample_batch(TRAIN_BATCHSZ)
            a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)
            logits = model(a, b)
            loss = F.binary_cross_entropy_with_logits(logits, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            if step % TRAIN_PROGRESS == 0:
                # Accuracy is taken from the logits computed before this
                # step's parameter update.
                with torch.no_grad():
                    pred = (torch.sigmoid(logits) > 0.5).float()
                    acc = (pred == y).float().mean().item()
                tlog.write(f"step={step:5d} loss={loss.item():.7f} acc={acc:.7f}\n")
                # also print to normal text log occasionally to show some activity.
                if step % 2500 == 0:
                    log.info(f"still training... step={step} of {TRAIN_STEPS}")
        training_end_time = datetime.datetime.now()
    log.info(f"Training steps complete. Start time: {training_start_time} End time: {training_end_time}")

    # evaluate final model accuracy on fresh pairs
    # Switch to evaluation mode first, as infer_entry() does, so the reported
    # accuracy reflects inference behaviour.
    model.eval()
    with torch.no_grad():
        a, b, y = sample_batch(TRAIN_BATCHSZ*4)
        a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)
        logits = model(a, b)
        pred = (torch.sigmoid(logits) > 0.5).float()
        # NOTE: tie labels (0.5) can never equal a 0/1 prediction, so any tie
        # in the batch is counted as an error here.
        errors = (pred != y).sum().item()
        acc = (pred == y).float().mean().item()
    log.info(f"Final test acc: {acc} errors: {errors}")

    # embed model dimensions into the model serialization
    torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, MODEL_PATH)
    log.info(f"Saved PyTorch Model State to {MODEL_PATH}")
def infer_entry() -> None:
    """Load the saved comparator, log P(a > b) for sample pairs, then chart embeddings.

    The checkpoint at MODEL_PATH is expected to be the dict written by
    training_entry(): ``{"state_dict": ..., "d": ...}``.

    Relies on the module-level ``log`` logger that the ``__main__`` block
    creates.
    """
    # The checkpoint only holds tensors and an int, so restrict unpickling to
    # plain weights rather than allowing arbitrary pickled objects.
    model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    model = comp_nn.PairwiseComparator(d=model_ckpt["d"]).to(DEVICE)
    model.load_state_dict(model_ckpt["state_dict"])
    model.eval()

    # sample pairs
    pairs = [(1, 2), (10, 3), (5, 5), (10, 35), (-64, 11), (300, 162), (2, 0), (2, 1), (3, 1), (4, 1), (3, 10), (30, 1), (0, 0), (-162, 237),
             (10, 20), (100, 30), (50, 50), (100, 350), (-640, 110), (30, -420), (200, 0), (92, 5), (30, 17), (42, 10), (30, 100), (30, 1), (0, 400), (-42, -42)]
    a = torch.tensor([[lhs] for lhs, _ in pairs], dtype=torch.float32, device=DEVICE)
    b = torch.tensor([[rhs] for _, rhs in pairs], dtype=torch.float32, device=DEVICE)

    # sanity check before inference
    log.debug(f"a.device: {a.device} model.device: {next(model.parameters()).device}")
    with torch.no_grad():
        probs = torch.sigmoid(model(a, b))

    log.info(f"Output probabilities for {len(pairs)} pairs")
    for (lhs, rhs), prob in zip(pairs, probs):
        log.info(f"P({lhs} > {rhs}) = {prob.item():.3f}")

    plt_embeddings(model)
if __name__ == '__main__':
    # Script entry point. The operation is chosen either by the name the
    # script was invoked under (train.py / infer.py) or by the first
    # command-line argument (train / infer / graphs).
    import sys
    import os
    import datetime

    # Every output file lives under ./files/. makedirs(exist_ok=True) avoids
    # the check-then-create race of testing for the directory first.
    os.makedirs("./files", exist_ok=True)

    # `log` must stay a module-level name: the functions above use it.
    log = logging.getLogger(__name__)
    logging.basicConfig(filename=LOGGING_PATH, level=logging.INFO)
    log.info(f"Log opened {datetime.datetime.now()}")
    get_torch_info()

    def _make_graphs() -> None:
        # Render the loss and error-rate charts from the saved training log.
        data = parse_training_log(TRAINING_LOG_PATH)
        plt_loss_tstep(data)
        plt_acc_tstep(data)

    modes = {
        "train": training_entry,
        "infer": infer_entry,
        "graphs": _make_graphs,
    }
    exit_code = 0

    name = os.path.basename(sys.argv[0])
    if name == 'train.py':
        training_entry()
    elif name == 'infer.py':
        infer_entry()
    elif len(sys.argv) > 1:
        # alt call pattern
        # python3 pairwise_compare.py train
        # python3 pairwise_compare.py infer
        # python3 pairwise_compare.py graphs
        mode = sys.argv[1].strip().lower()
        action = modes.get(mode)
        if action is not None:
            action()
        else:
            log.error(f"Unknown operation: {mode}")
            log.error("Invalid call syntax, call script as \"train.py\" or \"infer.py\" or as pairwise_compare.py <mode> where mode is \"train\", \"infer\" or \"graphs\"")
            exit_code = 2
    else:
        log.error("Not enough arguments passed to script; call as train.py or infer.py or as pairwise_compare.py <mode> where mode is \"train\", \"infer\" or \"graphs\"")
        exit_code = 2

    log.info(f"Log closed {datetime.datetime.now()}")
    # Report invalid usage to the calling shell instead of exiting with 0.
    if exit_code:
        sys.exit(exit_code)