# pairwise_compare.py

import logging
import random
import lzma

import torch
from torch import nn
from torch.nn import functional as F

import re
import pandas as pd
import matplotlib.pyplot as plt

import pairwise_comp_nn as comp_nn

# early pytorch device setup
DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else "cpu"
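# torch.accelerator is a recent PyTorch API; the CPU fallback keeps the script usable on machines without an accelerator
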
# Valves
DIMENSIONS = 2
HIDDEN_NEURONS = 4
ADAMW_LR = 5e-3
ADAMW_DECAY = 5e-4
TRAIN_STEPS = 2000
TRAIN_BATCHSZ = 8192
TRAIN_PROGRESS = 10
BATCH_LOWER = -100.0
BATCH_UPPER = 100.0
DO_VERBOSE_EARLY_TRAIN = False

# Files
MODEL_PATH = "./files/pwcomp.model"
LOGGING_PATH = "./files/output.log"
EMBED_CHART_PATH = "./files/embedding_chart.png"
EMBEDDINGS_DATA_PATH = "./files/embedding_data.csv"
TRAINING_LOG_PATH = "./files/training.log.xz"
LOSS_CHART_PATH = "./files/training_loss_v_step.png"
ACC_CHART_PATH = "./files/training_error_v_step.png"


# TODO: Move plotting into its own file
def parse_training_log(file_path: str) -> pd.DataFrame:
    text: str = ""
    with lzma.open(file_path, mode='rt') as f:
        text = f.read()

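    # each record has the format written by training_entry,
    # e.g. "step=  100 loss=0.0123456 acc=0.9876543"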
    pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)")
    rows = [(int(s), float(l), float(a)) for s, l, a in pattern.findall(text)]
    df = pd.DataFrame(rows, columns=["step", "loss", "acc"]).sort_values("step").reset_index(drop=True)

    # Avoid log(0) issues for loss plot by clamping at a tiny positive value
    eps = 1e-10
    df["loss_clamped"] = df["loss"].clip(lower=eps)

    return df


# TODO: Move plotting into its own file
def plt_loss_tstep(df: pd.DataFrame) -> None:
    # Plot 1: Loss
    plt.figure(figsize=(10, 6))
    plt.plot(df["step"], df["loss_clamped"])
    plt.yscale("log")
    plt.xlabel("Step")
    plt.ylabel("Loss (log scale)")
    plt.title("Training Loss vs Step")
    plt.tight_layout()
    plt.savefig(LOSS_CHART_PATH)
    plt.close()

    return None


# TODO: Move plotting into its own file
def plt_acc_tstep(df: pd.DataFrame, eps=1e-10) -> None:
    # Plot 2: Accuracy
    df["err"] = (1.0 - df["acc"]).clip(lower=eps)
    plt.figure(figsize=(10, 6))
    plt.plot(df["step"], df["err"])
    plt.yscale("log")
    plt.xlabel("Step")
    plt.ylabel("Error rate (1 - accuracy) (log scale)")
    plt.title("Training Error Rate vs Step")
    plt.tight_layout()
    plt.savefig(ACC_CHART_PATH)
    plt.close()

    return None


# TODO: Move plotting into its own file
def plt_embeddings(model: comp_nn.PairwiseComparator) -> None:
    import csv

    log.info("Starting embeddings sweep...")

    xs = torch.arange(
        BATCH_LOWER,
        BATCH_UPPER + 1.0,
        0.1,
    ).unsqueeze(1).to(DEVICE)  # shape: (N, 1)

    embeddings = model.embed(xs)  # shape: (N, d)

    # detach so the tensors can be handed to matplotlib and csv outside autograd
    embeddings = embeddings.detach().cpu()
    xs = xs.cpu()

    # Plot 3: x vs h(x)
    plt.figure(figsize=(10, 6))
    for i in range(embeddings.shape[1]):
        plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}")

    plt.title("x vs h(x)")
    plt.xlabel("x [input]")
    plt.ylabel("h(x) [embedding]")
    plt.legend()
    plt.savefig(EMBED_CHART_PATH)
    #plt.show()
    plt.close()

    # save all our embeddings data to csv
    csv_data = list(zip(xs.squeeze().tolist(), embeddings.tolist()))
    with open(file=EMBEDDINGS_DATA_PATH, mode="w", newline='') as f:
        csv_file = csv.writer(f)
        csv_file.writerows(csv_data)

    return None


def get_torch_info() -> None:
    log.info("PyTorch Version: %s", torch.__version__)
    log.info("HIP Version: %s", torch.version.hip)
    log.info("CUDA support: %s", torch.cuda.is_available())
    log.info("Using %s compute mode", DEVICE)


def set_seed(seed: int) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# pairs (a, b) with label y = 1 if a > b else 0 -> (a,b,y)
# uses epsi to select the window in which a == b for equality training
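# returns three (batch_size, 1) float tensors: a, b, and the label y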
def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER, epsi=1e-4) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    a = (high - low) * torch.rand(batch_size, 1) + low
    b = (high - low) * torch.rand(batch_size, 1) + low

    # label scheme documented in help_text: 1.0 when a > b + epsi, 0.0 when a < b - epsi,
    # and 0.5 inside the near-equality window
    y = torch.full_like(a, 0.5)
    y[a > b + epsi] = 1.0
    y[a < b - epsi] = 0.0

    return a, b, y


def training_entry():
    get_torch_info()

    # all prng seeds to 0 for deterministic outputs during testing
    # the seed should be initialized normally otherwise
    set_seed(0)

    model = comp_nn.PairwiseComparator(d=DIMENSIONS, hidden=HIDDEN_NEURONS).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=ADAMW_LR, weight_decay=ADAMW_DECAY)
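    # the comparator is trained as a binary classifier: BCE-with-logits on the question "is a > b?"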

    log.info(f"Using {TRAINING_LOG_PATH} as the logging destination for training...")
    with lzma.open(TRAINING_LOG_PATH, mode='wt') as tlog:
        # training loop
        training_start_time = datetime.datetime.now()
        last_ack = datetime.datetime.now()

        for step in range(TRAIN_STEPS):
            a, b, y = sample_batch(TRAIN_BATCHSZ)
            a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)

            logits = model(a, b)
            loss_fn = F.binary_cross_entropy_with_logits(logits, y)

            opt.zero_grad()
            loss_fn.backward()
            opt.step()

            if step <= TRAIN_PROGRESS and DO_VERBOSE_EARLY_TRAIN is True:
                with torch.no_grad():
                    pred = (torch.sigmoid(logits) > 0.5).float()
                    acc = (pred == y).float().mean().item()
                    log.info(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}")
            elif step % TRAIN_PROGRESS == 0:
                with torch.no_grad():
                    pred = (torch.sigmoid(logits) > 0.5).float()
                    acc = (pred == y).float().mean().item()
                    log.info(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}")
                    tlog.write(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}\n")

            # also print to normal text log occasionally to show some activity.
            # every 10 steps check if it's been longer than 5 seconds since we've updated the user
            if step % 10 == 0:
                if (datetime.datetime.now() - last_ack).total_seconds() > 5:
                    log.info(f"still training... step={step} of {TRAIN_STEPS}")
                    last_ack = datetime.datetime.now()

    training_end_time = datetime.datetime.now()
    log.info(f"Training steps complete. Start time: {training_start_time} End time: {training_end_time}")

    # evaluate final model accuracy on fresh pairs
    with torch.no_grad():
        a, b, y = sample_batch(TRAIN_BATCHSZ*4)
        a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)
        logits = model(a, b)
        pred = (torch.sigmoid(logits) > 0.5).float()
        # mirrors the in-loop metric; errors counts mismatched predictions
        acc = (pred == y).float().mean().item()
        errors = (pred != y).sum().item()
    log.info(f"Final test acc: {acc} errors: {errors}")

    # embed model dimensions into the model serialization
    torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS, "h": HIDDEN_NEURONS}, MODEL_PATH)
    log.info(f"Saved PyTorch Model State to {MODEL_PATH}")


def infer_entry():
    get_torch_info()

    model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
    model = comp_nn.PairwiseComparator(d=model_ckpt["d"], hidden=model_ckpt["h"]).to(DEVICE)
    model.load_state_dict(model_ckpt["state_dict"])
    model.eval()
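
    # hypothetical test pairs; per the help text the reported value is P(a > b) = sigmoid(model(a, b))
    pairs = [(1.0, 2.0), (2.0, 1.0), (-50.0, 50.0), (100.0, 99.5), (0.0, 0.0)]
    with torch.no_grad():
        a = torch.tensor([[float(x)] for x, _ in pairs], device=DEVICE)
        b = torch.tensor([[float(y)] for _, y in pairs], device=DEVICE)
        probs = torch.sigmoid(model(a, b))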
    for (x, y), p in zip(pairs, probs):
        log.info(f"P({x} > {y}) = {p.item():.3f}")


def graphs_entry():
    get_torch_info()

    model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
    model = comp_nn.PairwiseComparator(d=model_ckpt["d"], hidden=model_ckpt["h"]).to(DEVICE)
    model.load_state_dict(model_ckpt["state_dict"])
    model.eval()

    plt_embeddings(model)

    data = parse_training_log(TRAINING_LOG_PATH)
    plt_loss_tstep(data)
    plt_acc_tstep(data)


help_text = r"""
pairwise_compare.py — tiny pairwise "a > b?" neural comparator

USAGE
    python3 pairwise_compare.py train
        Train a PairwiseComparator on synthetic (a,b) pairs sampled uniformly from
        [BATCH_LOWER, BATCH_UPPER]. Labels are:
            1.0 if a > b + epsi
            0.0 if a < b - epsi
            0.5 otherwise (near-equality window)
        Writes training metrics to:
            ./files/training.log.xz
        Saves the trained model checkpoint to:
            ./files/pwcomp.model

    python3 pairwise_compare.py infer
        Load ./files/pwcomp.model and run inference on a built-in list of test pairs.
        Prints probabilities as:
            P(a > b) = sigmoid(model(a,b))

    python3 pairwise_compare.py graphs
        Load ./files/pwcomp.model and generate plots + exports:
            ./files/embedding_chart.png       (embed(x) vs x for each embedding dimension)
            ./files/embedding_data.csv        (x and embedding vectors)
            ./files/training_loss_v_step.png
            ./files/training_error_v_step.png (1 - acc, log scale)
        Requires that ./files/training.log.xz exists (i.e., you ran "train" first).

FILES
    ./files/output.log       General runtime log (info/errors)
    ./files/pwcomp.model     Torch checkpoint: {"state_dict": ..., "d": DIMENSIONS, "h": HIDDEN_NEURONS}
    ./files/training.log.xz  step/loss/acc trace used for plots

NOTES
    - DEVICE is chosen via torch.accelerator if available, else CPU.
    - Hyperparameters are controlled by the "Valves" constants near the top.
"""


if __name__ == '__main__':
    import sys
    import os
    import datetime

    # TODO: tidy up the file paths and check whether the output directory exists
    if not os.path.exists("./files/"):
        os.mkdir("./files")

    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[
                            logging.FileHandler(LOGGING_PATH),
                            logging.StreamHandler(stream=sys.stdout)
                        ])
    log = logging.getLogger(__name__)
    log.info(f"Log opened {datetime.datetime.now()}")

    get_torch_info()

    name = os.path.basename(sys.argv[0])
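    # dispatch on the invoked script name: a copy or symlink named train.py or infer.py selects the mode directly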
    if name == 'train.py':
        training_entry()
    elif name == 'infer.py':
        infer_entry()
    else:
        # alt call pattern
        # python3 pairwise_compare.py train
        # python3 pairwise_compare.py infer
        # python3 pairwise_compare.py graphs
        if len(sys.argv) > 1:
            match sys.argv[1].strip().lower():
                case "train":
                    training_entry()
                case "infer":
                    infer_entry()
                case "graphs":
                    graphs_entry()
                case "help":
                    log.info(help_text)
                case mode:
                    log.error(f"Unknown operation: {mode}")
        else:
            log.error("Not enough arguments passed to script; call as train.py or infer.py, or as pairwise_compare.py <mode>")
            log.error("valid options are one of [\"train\", \"infer\", \"graphs\", \"help\"]")
            log.info(help_text)

    log.info(f"Log closed {datetime.datetime.now()}")