diff --git a/pairwise_compare.py b/pairwise_compare.py
index 83f30d8..b96041b 100755
--- a/pairwise_compare.py
+++ b/pairwise_compare.py
@@ -2,6 +2,7 @@
 # pairwise_compare.py
 import logging
 import random
+
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -21,6 +22,57 @@ BATCH_UPPER = 512.0
 DO_VERBOSE_EARLY_TRAIN = False
 MODEL_PATH = "./files/pwcomp.model"
 LOGGING_PATH = "./files/output.log"
+EMBED_CHART_PATH = "./files/embedding_chart.png"
+EMBEDDINGS_DATA = "./files/embedding_data.csv"
+
+
+def plt_embeddings(model: comp_nn.PairwiseComparator):
+    """Plot the model's embeddings over the batch range and export them.
+
+    Embeds every value from BATCH_LOWER up to (but excluding)
+    BATCH_UPPER + 1.0 in steps of 1.0 with ``model.embed``, draws one line
+    per embedding dimension, saves the chart to EMBED_CHART_PATH, shows it,
+    and writes the raw values to EMBEDDINGS_DATA as CSV.
+
+    Args:
+        model: Comparator whose ``embed`` method maps an (N, 1) tensor to an
+            (N, d) tensor of embeddings.
+    """
+    import csv
+
+    import matplotlib.pyplot as plt
+
+    log.info("Starting embeddings sweep...")
+    # samples for embedding mapping
+    with torch.no_grad():
+        xs = torch.arange(
+            BATCH_LOWER,
+            BATCH_UPPER + 1.0,
+            1.0,
+        ).unsqueeze(1).to(DEVICE)  # shape: (N, 1)
+
+        embeddings = model.embed(xs)  # shape: (N, d)
+
+    # move data back to CPU for plotting
+    embeddings = embeddings.cpu()
+    xs = xs.cpu().squeeze(1)  # shape: (N,); squeeze(1) keeps N == 1 as 1-D
+
+    # Draw on a dedicated figure so repeated calls do not pile lines onto
+    # whatever pyplot figure happens to be current.
+    fig, ax = plt.subplots()
+    for i in range(embeddings.shape[1]):
+        ax.plot(xs, embeddings[:, i], label=f"dim {i}")
+    ax.legend()
+    fig.savefig(EMBED_CHART_PATH)
+    plt.show()
+    plt.close(fig)
+
+    # One flat row per sample: x followed by each embedding component
+    # (zipping x with the whole row would write the row as one quoted cell).
+    rows = ([x, *row] for x, row in zip(xs.tolist(), embeddings.tolist()))
+    with open(file=EMBEDDINGS_DATA, mode="w", newline="") as f:
+        csv.writer(f).writerows(rows)
+
 
 def get_torch_info():
     log.info("PyTorch Version: %s", torch.__version__)
@@ -115,9 +167,12 @@ def infer_entry():
     with torch.no_grad():
         probs = torch.sigmoid(model(a, b))
 
+    log.info(f"Output probabilities for {len(pairs)} pairs")
     for (x, y), p in zip(pairs, probs):
         log.info(f"P({x} > {y}) = {p.item():.3f}")
 
+    plt_embeddings(model)
+
 if __name__ == '__main__':
     import sys
     import os