Compare commits
2 Commits
cfef24921d
...
1d70935b64
| Author | SHA1 | Date | |
|---|---|---|---|
| 1d70935b64 | |||
| 0e2098ceec |
@@ -5,7 +5,6 @@ import random
|
||||
import lzma
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
import re
|
||||
@@ -32,7 +31,10 @@ LOGGING_PATH = "./files/output.log"
|
||||
EMBED_CHART_PATH = "./files/embedding_chart.png"
|
||||
EMBEDDINGS_DATA_PATH = "./files/embedding_data.csv"
|
||||
TRAINING_LOG_PATH = "./files/training.log.xz"
|
||||
LOSS_CHART_PATH = "./files/training_loss_v_step.png"
|
||||
ACC_CHART_PATH = "./files/training_error_v_step.png"
|
||||
|
||||
# TODO: Move plotting into its own file
|
||||
def parse_training_log(file_path: str) -> pd.DataFrame:
|
||||
text: str = ""
|
||||
with lzma.open(file_path, mode='rt') as f:
|
||||
@@ -49,6 +51,7 @@ def parse_training_log(file_path: str) -> pd.DataFrame:
|
||||
|
||||
return df
|
||||
|
||||
# TODO: Move plotting into its own file
|
||||
def plt_loss_tstep(df: pd.DataFrame) -> None:
|
||||
# Plot 1: Loss
|
||||
plt.figure(figsize=(8, 4))
|
||||
@@ -58,11 +61,12 @@ def plt_loss_tstep(df: pd.DataFrame) -> None:
|
||||
plt.ylabel("Loss (log scale)")
|
||||
plt.title("Training Loss vs Step")
|
||||
plt.tight_layout()
|
||||
plt.savefig('./files/training_loss_v_step.png')
|
||||
plt.savefig(LOSS_CHART_PATH)
|
||||
plt.close()
|
||||
|
||||
return None
|
||||
|
||||
# TODO: Move plotting into its own file
|
||||
def plt_acc_tstep(df: pd.DataFrame, eps=1e-10) -> None:
|
||||
# Plot 2: Accuracy
|
||||
df["err"] = (1.0 - df["acc"]).clip(lower=eps)
|
||||
@@ -73,11 +77,12 @@ def plt_acc_tstep(df: pd.DataFrame, eps=1e-10) -> None:
|
||||
plt.ylabel("Error rate (1 - accuracy) (log scale)")
|
||||
plt.title("Training Error Rate vs Step")
|
||||
plt.tight_layout()
|
||||
plt.savefig('./files/training_error_v_step.png')
|
||||
plt.savefig(ACC_CHART_PATH)
|
||||
plt.close()
|
||||
|
||||
return None
|
||||
|
||||
# TODO: Move plotting into its own file
|
||||
def plt_embeddings(model: comp_nn.PairwiseComparator) -> None:
|
||||
import csv
|
||||
|
||||
@@ -127,7 +132,8 @@ def set_seed(seed: int) -> None:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
# Data: pairs (a, b) with label y = 1 if a > b else 0 -> (a,b,y)
|
||||
# pairs (a, b) with label y = 1 if a > b else 0 -> (a,b,y)
|
||||
# uses epsi to select the window in which a == b for equality training
|
||||
def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER, epsi=1e-4) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
a = (high - low) * torch.rand(batch_size, 1) + low
|
||||
b = (high - low) * torch.rand(batch_size, 1) + low
|
||||
@@ -139,6 +145,8 @@ def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER, epsi=1e-4)
|
||||
return a, b, y
|
||||
|
||||
def training_entry():
|
||||
get_torch_info()
|
||||
|
||||
# all prng seeds to 0 for deterministic outputs durring testing
|
||||
# the seed should initialized normally otherwise
|
||||
set_seed(0)
|
||||
@@ -189,6 +197,8 @@ def training_entry():
|
||||
log.info(f"Saved PyTorch Model State to {MODEL_PATH}")
|
||||
|
||||
def infer_entry():
|
||||
get_torch_info()
|
||||
|
||||
model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
|
||||
model = comp_nn.PairwiseComparator(d=model_ckpt["d"]).to(DEVICE)
|
||||
model.load_state_dict(model_ckpt["state_dict"])
|
||||
@@ -210,46 +220,95 @@ def infer_entry():
|
||||
for (x, y), p in zip(pairs, probs):
|
||||
log.info(f"P({x} > {y}) = {p.item():.3f}")
|
||||
|
||||
|
||||
|
||||
def graphs_entry():
|
||||
get_torch_info()
|
||||
|
||||
model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
|
||||
model = comp_nn.PairwiseComparator(d=model_ckpt["d"]).to(DEVICE)
|
||||
model.load_state_dict(model_ckpt["state_dict"])
|
||||
model.eval()
|
||||
|
||||
plt_embeddings(model)
|
||||
|
||||
data = parse_training_log(TRAINING_LOG_PATH)
|
||||
plt_loss_tstep(data)
|
||||
plt_acc_tstep(data)
|
||||
|
||||
help_text = r"""
|
||||
pairwise_compare.py — tiny pairwise "a > b?" neural comparator
|
||||
|
||||
USAGE
|
||||
python3 pairwise_compare.py train
|
||||
Train a PairwiseComparator on synthetic (a,b) pairs sampled uniformly from
|
||||
[BATCH_LOWER, BATCH_UPPER]. Labels are:
|
||||
1.0 if a > b + epsi
|
||||
0.0 if a < b - epsi
|
||||
0.5 otherwise (near-equality window)
|
||||
Writes training metrics to:
|
||||
./files/training.log.xz
|
||||
Saves the trained model checkpoint to:
|
||||
./files/pwcomp.model
|
||||
|
||||
python3 pairwise_compare.py infer
|
||||
Load ./files/pwcomp.model and run inference on a built-in list of test pairs.
|
||||
Prints probabilities as:
|
||||
P(a > b) = sigmoid(model(a,b))
|
||||
|
||||
python3 pairwise_compare.py graphs
|
||||
Load ./files/pwcomp.model and generate plots + exports:
|
||||
./files/embedding_chart.png (embed(x) vs x for each embedding dimension)
|
||||
./files/embedding_data.csv (x and embedding vectors)
|
||||
./files/training_loss_v_step.png
|
||||
./files/training_error_v_step.png (1 - acc, log scale)
|
||||
Requires that ./files/training.log.xz exists (i.e., you ran "train" first).
|
||||
|
||||
FILES
|
||||
./files/output.log General runtime log (info/errors)
|
||||
./files/pwcomp.model Torch checkpoint: {"state_dict": ..., "d": DIMENSIONS}
|
||||
./files/training.log.xz step/loss/acc trace used for plots
|
||||
|
||||
NOTES
|
||||
- DEVICE is chosen via torch.accelerator if available, else CPU.
|
||||
- Hyperparameters are controlled by the "Valves" constants near the top.
|
||||
"""
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
import os
|
||||
import datetime
|
||||
|
||||
# TODO: tidy up the paths to files and checking if the directory exists
|
||||
if not os.path.exists("./files/"):
|
||||
os.mkdir("./files")
|
||||
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(LOGGING_PATH),
|
||||
logging.StreamHandler(stream=sys.stdout)
|
||||
])
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logging.basicConfig(filename=LOGGING_PATH, level=logging.INFO)
|
||||
log.info(f"Log file {LOGGING_PATH} opened {datetime.datetime.now()}")
|
||||
|
||||
log.info(f"Log opened {datetime.datetime.now()}")
|
||||
get_torch_info()
|
||||
|
||||
name = os.path.basename(sys.argv[0])
|
||||
if name == 'train.py':
|
||||
training_entry()
|
||||
elif name == 'infer.py':
|
||||
infer_entry()
|
||||
else:
|
||||
# alt call patern
|
||||
# python3 pairwise_compare.py train
|
||||
# python3 pairwise_compare.py infer
|
||||
# python3 pairwise_compare.py graphs
|
||||
if len(sys.argv) > 1:
|
||||
mode = sys.argv[1].strip().lower()
|
||||
if mode == "train":
|
||||
# python3 pairwise_compare.py train
|
||||
# python3 pairwise_compare.py infer
|
||||
# python3 pairwise_compare.py graphs
|
||||
if len(sys.argv) > 1:
|
||||
match sys.argv[1].strip().lower():
|
||||
case "train":
|
||||
training_entry()
|
||||
elif mode == "infer":
|
||||
case "infer":
|
||||
infer_entry()
|
||||
elif mode == "graphs":
|
||||
data = parse_training_log(TRAINING_LOG_PATH)
|
||||
plt_loss_tstep(data)
|
||||
plt_acc_tstep(data)
|
||||
else:
|
||||
case "graphs":
|
||||
graphs_entry()
|
||||
case "help":
|
||||
log.info(help_text)
|
||||
case mode:
|
||||
log.error(f"Unknown operation: {mode}")
|
||||
log.error("Invalid call syntax, call script as \"train.py\" or \"infer.py\" or as pairwise_compare.py <mode> where mode is \"train\" or \"infer\"")
|
||||
else:
|
||||
log.error("Not enough arguments passed to script; call as train.py or infer.py or as pairwise_compare.py <mode> where mode is \"train\" or \"infer\"")
|
||||
log.error("valid options are one of [\"train\", \"infer\", \"graphs\", \"help\"]")
|
||||
log.info(help_text)
|
||||
|
||||
log.info(f"Log closed {datetime.datetime.now()}")
|
||||
Reference in New Issue
Block a user