diff --git a/pairwise_compare.py b/pairwise_compare.py index b542b96..99eb1e2 100755 --- a/pairwise_compare.py +++ b/pairwise_compare.py @@ -5,7 +5,6 @@ import random import lzma import torch -from torch import nn from torch.nn import functional as F import re @@ -32,7 +31,10 @@ LOGGING_PATH = "./files/output.log" EMBED_CHART_PATH = "./files/embedding_chart.png" EMBEDDINGS_DATA_PATH = "./files/embedding_data.csv" TRAINING_LOG_PATH = "./files/training.log.xz" +LOSS_CHART_PATH = "./files/training_loss_v_step.png" +ACC_CHART_PATH = "./files/training_error_v_step.png" +# TODO: Move plotting into its own file def parse_training_log(file_path: str) -> pd.DataFrame: text: str = "" with lzma.open(file_path, mode='rt') as f: @@ -49,6 +51,7 @@ def parse_training_log(file_path: str) -> pd.DataFrame: return df +# TODO: Move plotting into its own file def plt_loss_tstep(df: pd.DataFrame) -> None: # Plot 1: Loss plt.figure(figsize=(8, 4)) @@ -58,11 +61,12 @@ def plt_loss_tstep(df: pd.DataFrame) -> None: plt.ylabel("Loss (log scale)") plt.title("Training Loss vs Step") plt.tight_layout() - plt.savefig('./files/training_loss_v_step.png') + plt.savefig(LOSS_CHART_PATH) plt.close() return None +# TODO: Move plotting into its own file def plt_acc_tstep(df: pd.DataFrame, eps=1e-10) -> None: # Plot 2: Accuracy df["err"] = (1.0 - df["acc"]).clip(lower=eps) @@ -73,11 +77,12 @@ def plt_acc_tstep(df: pd.DataFrame, eps=1e-10) -> None: plt.ylabel("Error rate (1 - accuracy) (log scale)") plt.title("Training Error Rate vs Step") plt.tight_layout() - plt.savefig('./files/training_error_v_step.png') + plt.savefig(ACC_CHART_PATH) plt.close() return None +# TODO: Move plotting into its own file def plt_embeddings(model: comp_nn.PairwiseComparator) -> None: import csv @@ -127,7 +132,8 @@ def set_seed(seed: int) -> None: if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) -# Data: pairs (a, b) with label y = 1 if a > b else 0 -> (a,b,y) +# pairs (a, b) with label y = 1 if a > b else 0 -> (a,b,y) +# uses epsi to select the window in which a == b for equality training def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER, epsi=1e-4) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: a = (high - low) * torch.rand(batch_size, 1) + low b = (high - low) * torch.rand(batch_size, 1) + low @@ -139,6 +145,8 @@ def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER, epsi=1e-4) return a, b, y def training_entry(): + get_torch_info() + # all prng seeds to 0 for deterministic outputs durring testing # the seed should initialized normally otherwise set_seed(0) @@ -189,6 +197,8 @@ def training_entry(): log.info(f"Saved PyTorch Model State to {MODEL_PATH}") def infer_entry(): + get_torch_info() + model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE) model = comp_nn.PairwiseComparator(d=model_ckpt["d"]).to(DEVICE) model.load_state_dict(model_ckpt["state_dict"]) @@ -210,46 +220,95 @@ def infer_entry(): for (x, y), p in zip(pairs, probs): log.info(f"P({x} > {y}) = {p.item():.3f}") + + +def graphs_entry(): + get_torch_info() + + model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE) + model = comp_nn.PairwiseComparator(d=model_ckpt["d"]).to(DEVICE) + model.load_state_dict(model_ckpt["state_dict"]) + model.eval() + plt_embeddings(model) + data = parse_training_log(TRAINING_LOG_PATH) + plt_loss_tstep(data) + plt_acc_tstep(data) + +help_text = r""" +pairwise_compare.py — tiny pairwise "a > b?" neural comparator + +USAGE + python3 pairwise_compare.py train + Train a PairwiseComparator on synthetic (a,b) pairs sampled uniformly from + [BATCH_LOWER, BATCH_UPPER]. Labels are: + 1.0 if a > b + epsi + 0.0 if a < b - epsi + 0.5 otherwise (near-equality window) + Writes training metrics to: + ./files/training.log.xz + Saves the trained model checkpoint to: + ./files/pwcomp.model + + python3 pairwise_compare.py infer + Load ./files/pwcomp.model and run inference on a built-in list of test pairs. + Prints probabilities as: + P(a > b) = sigmoid(model(a,b)) + + python3 pairwise_compare.py graphs + Load ./files/pwcomp.model and generate plots + exports: + ./files/embedding_chart.png (embed(x) vs x for each embedding dimension) + ./files/embedding_data.csv (x and embedding vectors) + ./files/training_loss_v_step.png + ./files/training_error_v_step.png (1 - acc, log scale) + Requires that ./files/training.log.xz exists (i.e., you ran "train" first). + +FILES + ./files/output.log General runtime log (info/errors) + ./files/pwcomp.model Torch checkpoint: {"state_dict": ..., "d": DIMENSIONS} + ./files/training.log.xz step/loss/acc trace used for plots + +NOTES + - DEVICE is chosen via torch.accelerator if available, else CPU. + - Hyperparameters are controlled by the "Valves" constants near the top. +""" + if __name__ == '__main__': import sys import os import datetime + # TODO: tidy up the paths to files and checking if the directory exists if not os.path.exists("./files/"): os.mkdir("./files") - log = logging.getLogger(__name__) - logging.basicConfig(filename=LOGGING_PATH, level=logging.INFO) + logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler(LOGGING_PATH), + logging.StreamHandler(stream=sys.stdout) + ]) - log.info(f"Log opened {datetime.datetime.now()}") - get_torch_info() + log = logging.getLogger(__name__) + log.info(f"Log file {LOGGING_PATH} opened {datetime.datetime.now()}") - name = os.path.basename(sys.argv[0]) - if name == 'train.py': - training_entry() - elif name == 'infer.py': - infer_entry() - else: - # alt call patern - # python3 pairwise_compare.py train - # python3 pairwise_compare.py infer - # python3 pairwise_compare.py graphs - if len(sys.argv) > 1: - mode = sys.argv[1].strip().lower() - if mode == "train": + # python3 pairwise_compare.py train + # python3 pairwise_compare.py infer + # python3 pairwise_compare.py graphs + if len(sys.argv) > 1: + match sys.argv[1].strip().lower(): + case "train": training_entry() - elif mode == "infer": + case "infer": infer_entry() - elif mode == "graphs": - data = parse_training_log(TRAINING_LOG_PATH) - plt_loss_tstep(data) - plt_acc_tstep(data) - else: + case "graphs": + graphs_entry() + case "help": + log.info(help_text) + case mode: log.error(f"Unknown operation: {mode}") - log.error("Invalid call syntax, call script as \"train.py\" or \"infer.py\" or as pairwise_compare.py where mode is \"train\" or \"infer\"") - else: - log.error("Not enough arguments passed to script; call as train.py or infer.py or as pairwise_compare.py where mode is \"train\" or \"infer\"") + log.error("valid options are one of [\"train\", \"infer\", \"graphs\", \"help\"]") + log.info(help_text) log.info(f"Log closed {datetime.datetime.now()}") \ No newline at end of file