#!/usr/bin/env python3
# pairwise_compare.py
import logging
import random
import lzma
import datetime
import re

import torch
from torch.nn import functional as F
import pandas as pd
import matplotlib.pyplot as plt

import pairwise_comp_nn as comp_nn

# early pytorch device setup
DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else "cpu"

# module-level logger; handlers are configured in __main__
log = logging.getLogger(__name__)

# Valves
DIMENSIONS = 2
HIDDEN_NEURONS = 4
ADAMW_LR = 5e-3
ADAMW_DECAY = 5e-4
TRAIN_STEPS = 2000
TRAIN_BATCHSZ = 8192
TRAIN_PROGRESS = 10
BATCH_LOWER = -100.0
BATCH_UPPER = 100.0

# Files
MODEL_PATH = "./files/pwcomp.model"
LOGGING_PATH = "./files/output.log"
EMBED_CHART_PATH = "./files/embedding_chart.png"
EMBEDDINGS_DATA_PATH = "./files/embedding_data.csv"
TRAINING_LOG_PATH = "./files/training.log.xz"
LOSS_CHART_PATH = "./files/training_loss_v_step.png"
ACC_CHART_PATH = "./files/training_error_v_step.png"


# TODO: Move plotting into its own file
def parse_training_log(file_path: str) -> pd.DataFrame:
    text: str = ""
    with lzma.open(file_path, mode='rt') as f:
        text = f.read()

    pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)")
    rows = [(int(s), float(l), float(a)) for s, l, a in pattern.findall(text)]
    df = pd.DataFrame(rows, columns=["step", "loss", "acc"]).sort_values("step").reset_index(drop=True)

    # Avoid log(0) issues for loss plot by clamping at a tiny positive value
    eps = 1e-10
    df["loss_clamped"] = df["loss"].clip(lower=eps)
    return df


# TODO: Move plotting into its own file
def plt_loss_tstep(df: pd.DataFrame) -> None:
    # Plot 1: Loss
    plt.figure(figsize=(10, 6))
    plt.plot(df["step"], df["loss_clamped"])
    plt.yscale("log")
    plt.xlabel("Step")
    plt.ylabel("Loss (log scale)")
    plt.title("Training Loss vs Step")
    plt.tight_layout()
    plt.savefig(LOSS_CHART_PATH)
    plt.close()
    return None


# TODO: Move plotting into its own file
def plt_acc_tstep(df: pd.DataFrame, eps=1e-10) -> None:
    # Plot 2: Accuracy
    df["err"] = (1.0 - df["acc"]).clip(lower=eps)
    plt.figure(figsize=(10, 6))
    plt.plot(df["step"], df["err"])
    plt.yscale("log")
    plt.xlabel("Step")
    plt.ylabel("Error rate (1 - accuracy) (log scale)")
    plt.title("Training Error Rate vs Step")
    plt.tight_layout()
    plt.savefig(ACC_CHART_PATH)
    plt.close()
    return None


# TODO: Move plotting into its own file
def plt_embeddings(model: comp_nn.PairwiseComparator) -> None:
    import csv

    log.info("Starting embeddings sweep...")

    # samples for embedding mapping
    with torch.no_grad():
        xs = torch.arange(
            BATCH_LOWER,
            BATCH_UPPER + 1.0,
            0.1,
        ).unsqueeze(1).to(DEVICE)      # shape: (N, 1)
        embeddings = model.embed(xs)   # shape: (N, d)

    # move data back to CPU for plotting
    embeddings = embeddings.cpu()
    xs = xs.cpu()

    # Plot 3: x vs h(x)
    plt.figure(figsize=(10, 6))
    for i in range(embeddings.shape[1]):
        plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}")
    plt.title("x vs h(x)")
    plt.xlabel("x [input]")
    plt.ylabel("h(x) [embedding]")
    plt.legend()
    plt.savefig(EMBED_CHART_PATH)
    plt.close()

    # save all our embeddings data to csv, one flat row per sample: x, h_0(x), ..., h_{d-1}(x)
    csv_data = [[x, *emb] for x, emb in zip(xs.squeeze().tolist(), embeddings.tolist())]
    with open(file=EMBEDDINGS_DATA_PATH, mode="w", newline='') as f:
        csv_file = csv.writer(f)
        csv_file.writerows(csv_data)
    return None


def get_torch_info() -> None:
    log.info("PyTorch Version: %s", torch.__version__)
    log.info("HIP Version: %s", torch.version.hip)
    log.info("CUDA support: %s", torch.cuda.is_available())
    if torch.cuda.is_available():
        log.info("CUDA device detected: %s", torch.cuda.get_device_name(0))
    log.info("Using %s compute mode", DEVICE)


def set_seed(seed: int) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
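
# NOTE: pairwise_comp_nn.PairwiseComparator is defined elsewhere (pairwise_comp_nn.py).
# Based purely on how this script uses it (constructor kwargs d/hidden,
# forward(a, b) -> logits of shape (N, 1), embed(x) -> embeddings of shape (N, d)),
# a minimal compatible implementation might look like the sketch below. This is only
# here to document the assumed interface; the real class may differ.
#
#   class PairwiseComparator(torch.nn.Module):
#       def __init__(self, d: int, hidden: int):
#           super().__init__()
#           self.encoder = torch.nn.Sequential(
#               torch.nn.Linear(1, hidden), torch.nn.Tanh(),
#               torch.nn.Linear(hidden, d),
#           )
#           self.head = torch.nn.Linear(d, 1)
#
#       def embed(self, x):       # (N, 1) -> (N, d)
#           return self.encoder(x)
#
#       def forward(self, a, b):  # two (N, 1) inputs -> (N, 1) logits for "a > b?"
#           return self.head(self.embed(a) - self.embed(b))
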
# pairs (a, b) with label y = 1.0 if a > b else 0.0 -> (a, b, y)
# uses epsi to select the window in which a == b (label 0.5) for equality training
def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER, epsi=1e-4) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    a = (high - low) * torch.rand(batch_size, 1) + low
    b = (high - low) * torch.rand(batch_size, 1) + low
    y = torch.where(a > b + epsi, 1.0, torch.where(a < b - epsi, 0.0, 0.5))
    return a, b, y
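
# Illustrative sanity check for sample_batch (not executed anywhere in this script):
#   a, b, y = sample_batch(4)
#   each of a, b, y has shape (4, 1); y is 1.0 where a > b + epsi, 0.0 where
#   a < b - epsi, and 0.5 inside the near-equality window |a - b| <= epsi.
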
def training_entry():
    get_torch_info()

    # all prng seeds to 0 for deterministic outputs during testing;
    # the seed should be initialized normally otherwise
    set_seed(0)

    model = comp_nn.PairwiseComparator(d=DIMENSIONS, hidden=HIDDEN_NEURONS).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=ADAMW_LR, weight_decay=ADAMW_DECAY)

    log.info(f"Using {TRAINING_LOG_PATH} as the logging destination for training...")
    with lzma.open(TRAINING_LOG_PATH, mode='wt') as tlog:
        # training loop
        training_start_time = datetime.datetime.now()
        last_ack = datetime.datetime.now(datetime.timezone.utc)
        for step in range(TRAIN_STEPS):
            a, b, y = sample_batch(TRAIN_BATCHSZ)
            a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)

            logits = model(a, b)
            loss = F.binary_cross_entropy_with_logits(logits, y)

            opt.zero_grad()
            loss.backward()
            opt.step()

            with torch.no_grad():
                pred = (torch.sigmoid(logits) > 0.5).float()
                acc = (pred == y).float().mean().item()
                tlog.write(f"step={step:5d} loss={loss.item():.7f} acc={acc:.7f}\n")

            # also print to the normal text log occasionally to show some activity:
            # every 10 steps, check whether it has been more than 5 seconds since the last update
            if step % 10 == 0:
                if (datetime.datetime.now(datetime.timezone.utc) - last_ack).total_seconds() > 5:
                    log.info(f"still training... step={step} of {TRAIN_STEPS}")
                    last_ack = datetime.datetime.now(datetime.timezone.utc)

    training_end_time = datetime.datetime.now()
    log.info(f"Training steps complete. Start time: {training_start_time} End time: {training_end_time}")

    # evaluate final model accuracy on fresh pairs
    with torch.no_grad():
        a, b, y = sample_batch(TRAIN_BATCHSZ * 4)
        a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)
        logits = model(a, b)
        pred = (torch.sigmoid(logits) > 0.5).float()
        errors = (pred != y).sum().item()
        acc = (pred == y).float().mean().item()
    log.info(f"Final test acc: {acc} errors: {errors}")

    # embed model dimensions into the model serialization
    torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS, "h": HIDDEN_NEURONS}, MODEL_PATH)
    log.info(f"Saved PyTorch Model State to {MODEL_PATH}")


def infer_entry():
    get_torch_info()

    model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
    model = comp_nn.PairwiseComparator(d=model_ckpt["d"], hidden=model_ckpt["h"]).to(DEVICE)
    model.load_state_dict(model_ckpt["state_dict"])
    model.eval()

    # sample pairs
    pairs = [(1, 2), (10, 3), (5, 5), (10, 35), (-64, 11), (300, 162), (2, 0),
             (2, 1), (3, 1), (4, 1), (3, 10), (30, 1), (0, 0), (-162, 237),
             (10, 20), (100, 30), (50, 50), (100, 350), (-640, 110), (30, -420),
             (200, 0), (92, 5), (30, 17), (42, 10), (30, 100), (30, 1),
             (0, 400), (-42, -42)]
    a = torch.tensor([[p[0]] for p in pairs], dtype=torch.float32, device=DEVICE)
    b = torch.tensor([[p[1]] for p in pairs], dtype=torch.float32, device=DEVICE)

    # sanity check before inference
    log.debug(f"a.device: {a.device} model.device: {next(model.parameters()).device}")

    with torch.no_grad():
        probs = torch.sigmoid(model(a, b))

    log.info(f"Output probabilities for {len(pairs)} pairs")
    for (x, y), p in zip(pairs, probs):
        log.info(f"P({x} > {y}) = {p.item():.3f}")


def graphs_entry():
    get_torch_info()

    model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
    model = comp_nn.PairwiseComparator(d=model_ckpt["d"], hidden=model_ckpt["h"]).to(DEVICE)
    model.load_state_dict(model_ckpt["state_dict"])
    model.eval()

    plt_embeddings(model)

    data = parse_training_log(TRAINING_LOG_PATH)
    plt_loss_tstep(data)
    plt_acc_tstep(data)


help_text = r"""
pairwise_compare.py — tiny pairwise "a > b?" neural comparator

USAGE
  python3 pairwise_compare.py train
      Train a PairwiseComparator on synthetic (a,b) pairs sampled uniformly
      from [BATCH_LOWER, BATCH_UPPER]. Labels are:
          1.0 if a > b + epsi
          0.0 if a < b - epsi
          0.5 otherwise (near-equality window)
      Writes training metrics to:            ./files/training.log.xz
      Saves the trained model checkpoint to: ./files/pwcomp.model

  python3 pairwise_compare.py infer
      Load ./files/pwcomp.model and run inference on a built-in list of test
      pairs. Prints probabilities as: P(a > b) = sigmoid(model(a,b))

  python3 pairwise_compare.py graphs
      Load ./files/pwcomp.model and generate plots + exports:
          ./files/embedding_chart.png        (embed(x) vs x for each embedding dimension)
          ./files/embedding_data.csv         (x and embedding vectors)
          ./files/training_loss_v_step.png
          ./files/training_error_v_step.png  (1 - acc, log scale)
      Requires that ./files/training.log.xz exists (i.e., you ran "train" first).

FILES
  ./files/output.log       General runtime log (info/errors)
  ./files/pwcomp.model     Torch checkpoint: {"state_dict": ..., "d": DIMENSIONS, "h": HIDDEN_NEURONS}
  ./files/training.log.xz  step/loss/acc trace used for plots

NOTES
  - DEVICE is chosen via torch.accelerator if available, else CPU.
  - Hyperparameters are controlled by the "Valves" constants near the top.
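  - Each record in ./files/training.log.xz has the form
    "step=<step> loss=<loss> acc=<acc>", one line per training step.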
""" if __name__ == '__main__': import sys import os import datetime # TODO: tidy up the paths to files and checking if the directory exists if not os.path.exists("./files/"): os.mkdir("./files") logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler(LOGGING_PATH), logging.StreamHandler(stream=sys.stdout) ]) log = logging.getLogger(__name__) log.info(f"Log file {LOGGING_PATH} opened {datetime.datetime.now()}") # python3 pairwise_compare.py train # python3 pairwise_compare.py infer # python3 pairwise_compare.py graphs if len(sys.argv) > 1: match sys.argv[1].strip().lower(): case "train": training_entry() case "infer": infer_entry() case "graphs": graphs_entry() case "help": log.info(help_text) case mode: log.error(f"Unknown operation: {mode}") log.error("valid options are one of [\"train\", \"infer\", \"graphs\", \"help\"]") log.info(help_text) log.info(f"Log closed {datetime.datetime.now()}")