Compare commits

...

13 Commits

Author SHA1 Message Date
9af36e7145 man python stinks... 2025-12-25 14:54:08 +00:00
23fddbe5b9 updated comment for PairwiseComparator 2025-12-25 14:53:35 +00:00
acbccebb2c renamed ADAMW_WiDECAY 2025-12-25 14:48:07 +00:00
edf8d46123 wut 2025-12-25 02:13:23 +00:00
4f500e8b4c write all data on training to training log 2025-12-25 02:07:49 +00:00
921e24b451 embeded # of hidden neurons in model save data
added more valves for options.
will update the console every 5 seconds when training now
2025-12-25 01:52:19 +00:00
d46712ff53 make the defaults more sane for the task
its still overkill
2025-12-25 01:23:54 +00:00
cd72cd7052 changes to figure sizes and add labels to embedding chart 2025-12-23 13:44:38 -05:00
755161c152 unused var 2025-12-23 13:29:52 -05:00
1d70935b64 general clean up & added help text
removed symbolic link calling path
2025-12-23 13:06:07 -05:00
0e2098ceec remove linked files 2025-12-23 12:30:18 -05:00
cfef24921d moved output_graphs into main script and restructured
training log is seperate from normal output & is compressed
slightly adjusted lr
made final test stage test 4xTRAIN_BATCHSZ number of samples
2025-12-23 10:18:02 -05:00
9ea8ef3458 default k=0.5 2025-12-23 10:14:08 -05:00
5 changed files with 196 additions and 101 deletions

View File

@@ -1 +0,0 @@
pairwise_compare.py

View File

@@ -1,37 +0,0 @@
import re
import pandas as pd
import matplotlib.pyplot as plt
text = r"""
"""
pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)")
rows = [(int(s), float(l), float(a)) for s, l, a in pattern.findall(text)]
df = pd.DataFrame(rows, columns=["step", "loss", "acc"]).sort_values("step").reset_index(drop=True)
# Avoid log(0) issues for loss plot by clamping at a tiny positive value
eps = 1e-10
df["loss_clamped"] = df["loss"].clip(lower=eps)
# Plot 1: Loss
plt.figure(figsize=(9, 4.8))
plt.plot(df["step"], df["loss_clamped"])
plt.yscale("log")
plt.xlabel("Step")
plt.ylabel("Loss (log scale)")
plt.title("Training Loss vs Step")
plt.tight_layout()
plt.savefig('./files/training_loss_v_step.png')
plt.show()
# Plot 2: Accuracy
df["err"] = (1.0 - df["acc"]).clip(lower=eps)
plt.figure(figsize=(9, 4.8))
plt.plot(df["step"], df["err"])
plt.yscale("log")
plt.xlabel("Step")
plt.ylabel("Error rate (1 - accuracy) (log scale)")
plt.title("Training Error Rate vs Step")
plt.tight_layout()
plt.savefig('./files/training_error_v_step.png')
plt.show()

View File

@@ -3,7 +3,7 @@ from torch import nn
# 2) Number "embedding" network: R -> R^d # 2) Number "embedding" network: R -> R^d
class NumberEmbedder(nn.Module): class NumberEmbedder(nn.Module):
def __init__(self, d=4, hidden=16): def __init__(self, d=2, hidden=4):
super().__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
nn.Linear(1, hidden), nn.Linear(1, hidden),
@@ -14,9 +14,9 @@ class NumberEmbedder(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
# 3) Comparator head: takes (ea, eb, e) -> logit for "a > b" # MLP Comparator head: takes (ea, eb, e) -> logit for "a > b"
class PairwiseComparator(nn.Module): class PairwiseComparator(nn.Module):
def __init__(self, d=4, hidden=16, k=1.0): def __init__(self, d=2, hidden=4, k=0.5):
super().__init__() super().__init__()
self.log_k = nn.Parameter(torch.tensor([k])) self.log_k = nn.Parameter(torch.tensor([k]))
self.embed = NumberEmbedder(d, hidden) self.embed = NumberEmbedder(d, hidden)
@@ -29,7 +29,7 @@ class PairwiseComparator(nn.Module):
) )
def forward(self, a, b): def forward(self, a, b):
# trying to force antisym here: h(a,b)=h(b,a) # trying to force antisym here: h(a,b)=-h(b,a)
phi = self.head(self.embed(a-b)) phi = self.head(self.embed(a-b))
phi_neg = self.head(self.embed(b-a)) phi_neg = self.head(self.embed(b-a))
logit = phi - phi_neg logit = phi - phi_neg

View File

@@ -2,11 +2,15 @@
# pairwise_compare.py # pairwise_compare.py
import logging import logging
import random import random
import lzma
import torch import torch
from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
import re
import pandas as pd
import matplotlib.pyplot as plt
import pairwise_comp_nn as comp_nn import pairwise_comp_nn as comp_nn
# early pytorch device setup # early pytorch device setup
@@ -14,19 +18,73 @@ DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_availab
# Valves # Valves
DIMENSIONS = 2 DIMENSIONS = 2
TRAIN_STEPS = 5000 HIDDEN_NEURONS = 4
ADAMW_LR = 5e-3
ADAMW_DECAY = 5e-4
TRAIN_STEPS = 2000
TRAIN_BATCHSZ = 8192 TRAIN_BATCHSZ = 8192
TRAIN_PROGRESS = 10 TRAIN_PROGRESS = 10
BATCH_LOWER = -100.0 BATCH_LOWER = -100.0
BATCH_UPPER = 100.0 BATCH_UPPER = 100.0
DO_VERBOSE_EARLY_TRAIN = False
# Files
MODEL_PATH = "./files/pwcomp.model" MODEL_PATH = "./files/pwcomp.model"
LOGGING_PATH = "./files/output.log" LOGGING_PATH = "./files/output.log"
EMBED_CHART_PATH = "./files/embedding_chart.png" EMBED_CHART_PATH = "./files/embedding_chart.png"
EMBEDDINGS_DATA = "./files/embedding_data.csv" EMBEDDINGS_DATA_PATH = "./files/embedding_data.csv"
TRAINING_LOG_PATH = "./files/training.log.xz"
LOSS_CHART_PATH = "./files/training_loss_v_step.png"
ACC_CHART_PATH = "./files/training_error_v_step.png"
def plt_embeddings(model: comp_nn.PairwiseComparator): # TODO: Move plotting into its own file
import matplotlib.pyplot as plt def parse_training_log(file_path: str) -> pd.DataFrame:
text: str = ""
with lzma.open(file_path, mode='rt') as f:
text = f.read()
pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)")
rows = [(int(s), float(l), float(a)) for s, l, a in pattern.findall(text)]
df = pd.DataFrame(rows, columns=["step", "loss", "acc"]).sort_values("step").reset_index(drop=True)
# Avoid log(0) issues for loss plot by clamping at a tiny positive value
eps = 1e-10
df["loss_clamped"] = df["loss"].clip(lower=eps)
return df
# TODO: Move plotting into its own file
def plt_loss_tstep(df: pd.DataFrame) -> None:
# Plot 1: Loss
plt.figure(figsize=(10, 6))
plt.plot(df["step"], df["loss_clamped"])
plt.yscale("log")
plt.xlabel("Step")
plt.ylabel("Loss (log scale)")
plt.title("Training Loss vs Step")
plt.tight_layout()
plt.savefig(LOSS_CHART_PATH)
plt.close()
return None
# TODO: Move plotting into its own file
def plt_acc_tstep(df: pd.DataFrame, eps=1e-10) -> None:
# Plot 2: Accuracy
df["err"] = (1.0 - df["acc"]).clip(lower=eps)
plt.figure(figsize=(10, 6))
plt.plot(df["step"], df["err"])
plt.yscale("log")
plt.xlabel("Step")
plt.ylabel("Error rate (1 - accuracy) (log scale)")
plt.title("Training Error Rate vs Step")
plt.tight_layout()
plt.savefig(ACC_CHART_PATH)
plt.close()
return None
# TODO: Move plotting into its own file
def plt_embeddings(model: comp_nn.PairwiseComparator) -> None:
import csv import csv
log.info("Starting embeddings sweep...") log.info("Starting embeddings sweep...")
@@ -35,7 +93,7 @@ def plt_embeddings(model: comp_nn.PairwiseComparator):
xs = torch.arange( xs = torch.arange(
BATCH_LOWER, BATCH_LOWER,
BATCH_UPPER + 1.0, BATCH_UPPER + 1.0,
1.0, 0.1,
).unsqueeze(1).to(DEVICE) # shape: (N, 1) ).unsqueeze(1).to(DEVICE) # shape: (N, 1)
embeddings = model.embed(xs) # shape: (N, d) embeddings = model.embed(xs) # shape: (N, d)
@@ -43,19 +101,27 @@ def plt_embeddings(model: comp_nn.PairwiseComparator):
# move data back to CPU for plotting # move data back to CPU for plotting
embeddings = embeddings.cpu() embeddings = embeddings.cpu()
xs = xs.cpu() xs = xs.cpu()
# Plot 3: x vs h(x)
plt.figure(figsize=(10, 6))
for i in range(embeddings.shape[1]): for i in range(embeddings.shape[1]):
plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}") plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}")
plt.title("x vs h(x)")
plt.xlabel("x [input]")
plt.ylabel("h(x) [embedding]")
plt.legend() plt.legend()
plt.savefig(EMBED_CHART_PATH) plt.savefig(EMBED_CHART_PATH)
#plt.show() plt.close()
# save all our embeddings data to csv
csv_data = list(zip(xs.squeeze().tolist(), embeddings.tolist())) csv_data = list(zip(xs.squeeze().tolist(), embeddings.tolist()))
with open(file=EMBEDDINGS_DATA, mode="w", newline='') as f: with open(file=EMBEDDINGS_DATA_PATH, mode="w", newline='') as f:
csv_file = csv.writer(f) csv_file = csv.writer(f)
csv_file.writerows(csv_data) csv_file.writerows(csv_data)
return None
def get_torch_info(): def get_torch_info() -> None:
log.info("PyTorch Version: %s", torch.__version__) log.info("PyTorch Version: %s", torch.__version__)
log.info("HIP Version: %s", torch.version.hip) log.info("HIP Version: %s", torch.version.hip)
log.info("CUDA support: %s", torch.cuda.is_available()) log.info("CUDA support: %s", torch.cuda.is_available())
@@ -65,14 +131,15 @@ def get_torch_info():
log.info("Using %s compute mode", DEVICE) log.info("Using %s compute mode", DEVICE)
def set_seed(seed: int): def set_seed(seed: int) -> None:
random.seed(seed) random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
# 1) Data: pairs (a, b) with label y = 1 if a > b else 0 # pairs (a, b) with label y = 1 if a > b else 0 -> (a,b,y)
def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER): # uses epsi to select the window in which a == b for equality training
def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER, epsi=1e-4) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
a = (high - low) * torch.rand(batch_size, 1) + low a = (high - low) * torch.rand(batch_size, 1) + low
b = (high - low) * torch.rand(batch_size, 1) + low b = (high - low) * torch.rand(batch_size, 1) + low
@@ -83,39 +150,50 @@ def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER):
return a, b, y return a, b, y
def training_entry(): def training_entry():
get_torch_info()
# all prng seeds to 0 for deterministic outputs durring testing # all prng seeds to 0 for deterministic outputs durring testing
# the seed should initialized normally otherwise # the seed should initialized normally otherwise
set_seed(0) set_seed(0)
model = comp_nn.PairwiseComparator(d=DIMENSIONS).to(DEVICE) model = comp_nn.PairwiseComparator(d=DIMENSIONS, hidden=HIDDEN_NEURONS).to(DEVICE)
opt = torch.optim.AdamW(model.parameters(), lr=9e-4, weight_decay=1e-3) opt = torch.optim.AdamW(model.parameters(), lr=ADAMW_LR, weight_decay=ADAMW_DECAY)
# 4) Train log.info(f"Using {TRAINING_LOG_PATH} as the logging destination for training...")
for step in range(TRAIN_STEPS): with lzma.open(TRAINING_LOG_PATH, mode='wt') as tlog:
a, b, y = sample_batch(TRAIN_BATCHSZ) # training loop
a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE) training_start_time = datetime.datetime.now()
last_ack = datetime.datetime.now()
logits = model(a, b) for step in range(TRAIN_STEPS):
loss_fn = F.binary_cross_entropy_with_logits(logits, y) a, b, y = sample_batch(TRAIN_BATCHSZ)
a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)
opt.zero_grad() logits = model(a, b)
loss_fn.backward() loss_fn = F.binary_cross_entropy_with_logits(logits, y)
opt.step()
opt.zero_grad()
loss_fn.backward()
opt.step()
if step <= TRAIN_PROGRESS and DO_VERBOSE_EARLY_TRAIN is True:
with torch.no_grad(): with torch.no_grad():
pred = (torch.sigmoid(logits) > 0.5).float() pred = (torch.sigmoid(logits) > 0.5).float()
acc = (pred == y).float().mean().item() acc = (pred == y).float().mean().item()
log.info(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}") tlog.write(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}\n")
elif step % TRAIN_PROGRESS == 0:
with torch.no_grad():
pred = (torch.sigmoid(logits) > 0.5).float()
acc = (pred == y).float().mean().item()
log.info(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}")
# 5) Quick test: evaluate final model accuracy on fresh pairs # also print to normal text log occasionally to show some activity.
# every 10 steps check if its been longer than 5 seconds since we've updated the user
if step % 10 == 0:
if (datetime.datetime.now() - last_ack).total_seconds() > 5:
log.info(f"still training... step={step} of {TRAIN_STEPS}")
last_ack = datetime.datetime.now()
training_end_time = datetime.datetime.now()
log.info(f"Training steps complete. Start time: {training_start_time} End time: {training_end_time}")
# evaluate final model accuracy on fresh pairs
with torch.no_grad(): with torch.no_grad():
a, b, y = sample_batch(TRAIN_BATCHSZ) a, b, y = sample_batch(TRAIN_BATCHSZ*4)
a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE) a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)
logits = model(a, b) logits = model(a, b)
pred = (torch.sigmoid(logits) > 0.5).float() pred = (torch.sigmoid(logits) > 0.5).float()
@@ -124,12 +202,14 @@ def training_entry():
log.info(f"Final test acc: {acc} errors: {errors}") log.info(f"Final test acc: {acc} errors: {errors}")
# embed model dimensions into the model serialization # embed model dimensions into the model serialization
torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, MODEL_PATH) torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS, "h": HIDDEN_NEURONS}, MODEL_PATH)
log.info(f"Saved PyTorch Model State to {MODEL_PATH}") log.info(f"Saved PyTorch Model State to {MODEL_PATH}")
def infer_entry(): def infer_entry():
get_torch_info()
model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE) model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
model = comp_nn.PairwiseComparator(d=model_ckpt["d"]).to(DEVICE) model = comp_nn.PairwiseComparator(d=model_ckpt["d"], hidden=model_ckpt["h"]).to(DEVICE)
model.load_state_dict(model_ckpt["state_dict"]) model.load_state_dict(model_ckpt["state_dict"])
model.eval() model.eval()
@@ -149,41 +229,95 @@ def infer_entry():
for (x, y), p in zip(pairs, probs): for (x, y), p in zip(pairs, probs):
log.info(f"P({x} > {y}) = {p.item():.3f}") log.info(f"P({x} > {y}) = {p.item():.3f}")
def graphs_entry():
get_torch_info()
model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
model = comp_nn.PairwiseComparator(d=model_ckpt["d"], hidden=model_ckpt["h"]).to(DEVICE)
model.load_state_dict(model_ckpt["state_dict"])
model.eval()
plt_embeddings(model) plt_embeddings(model)
data = parse_training_log(TRAINING_LOG_PATH)
plt_loss_tstep(data)
plt_acc_tstep(data)
help_text = r"""
pairwise_compare.py — tiny pairwise "a > b?" neural comparator
USAGE
python3 pairwise_compare.py train
Train a PairwiseComparator on synthetic (a,b) pairs sampled uniformly from
[BATCH_LOWER, BATCH_UPPER]. Labels are:
1.0 if a > b + epsi
0.0 if a < b - epsi
0.5 otherwise (near-equality window)
Writes training metrics to:
./files/training.log.xz
Saves the trained model checkpoint to:
./files/pwcomp.model
python3 pairwise_compare.py infer
Load ./files/pwcomp.model and run inference on a built-in list of test pairs.
Prints probabilities as:
P(a > b) = sigmoid(model(a,b))
python3 pairwise_compare.py graphs
Load ./files/pwcomp.model and generate plots + exports:
./files/embedding_chart.png (embed(x) vs x for each embedding dimension)
./files/embedding_data.csv (x and embedding vectors)
./files/training_loss_v_step.png
./files/training_error_v_step.png (1 - acc, log scale)
Requires that ./files/training.log.xz exists (i.e., you ran "train" first).
FILES
./files/output.log General runtime log (info/errors)
./files/pwcomp.model Torch checkpoint: {"state_dict": ..., "d": DIMENSIONS}
./files/training.log.xz step/loss/acc trace used for plots
NOTES
- DEVICE is chosen via torch.accelerator if available, else CPU.
- Hyperparameters are controlled by the "Valves" constants near the top.
"""
if __name__ == '__main__': if __name__ == '__main__':
import sys import sys
import os import os
import datetime import datetime
# TODO: tidy up the paths to files and checking if the directory exists
if not os.path.exists("./files/"): if not os.path.exists("./files/"):
os.mkdir("./files") os.mkdir("./files")
log = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO,
logging.basicConfig(filename=LOGGING_PATH, level=logging.INFO) format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(LOGGING_PATH),
logging.StreamHandler(stream=sys.stdout)
])
log.info(f"Log opened {datetime.datetime.now()}") log = logging.getLogger(__name__)
get_torch_info() log.info(f"Log file {LOGGING_PATH} opened {datetime.datetime.now()}")
name = os.path.basename(sys.argv[0]) # python3 pairwise_compare.py train
if name == 'train.py': # python3 pairwise_compare.py infer
training_entry() # python3 pairwise_compare.py graphs
elif name == 'infer.py': if len(sys.argv) > 1:
infer_entry() match sys.argv[1].strip().lower():
else: case "train":
# alt call patern
# python3 pairwise_compare.py train
# python3 pairwise_compare.py infer
if len(sys.argv) > 1:
mode = sys.argv[1].strip().lower()
if mode == "train":
training_entry() training_entry()
elif mode == "infer": case "infer":
infer_entry() infer_entry()
else: case "graphs":
graphs_entry()
case "help":
log.info(help_text)
case mode:
log.error(f"Unknown operation: {mode}") log.error(f"Unknown operation: {mode}")
log.error("Invalid call syntax, call script as \"train.py\" or \"infer.py\" or as pairwise_compare.py <mode> where mode is \"train\" or \"infer\"") log.error("valid options are one of [\"train\", \"infer\", \"graphs\", \"help\"]")
else: log.info(help_text)
log.error("Not enough arguments passed to script; call as train.py or infer.py or as pairwise_compare.py <mode> where mode is \"train\" or \"infer\"")
log.info(f"Log closed {datetime.datetime.now()}") log.info(f"Log closed {datetime.datetime.now()}")

View File

@@ -1 +0,0 @@
pairwise_compare.py