output plots and csv data for e(x) over the trained range.

This commit is contained in:
2025-12-22 12:23:59 -05:00
parent c3fbc44a34
commit 997303028e

View File

@@ -2,6 +2,7 @@
# pairwise_compare.py # pairwise_compare.py
import logging import logging
import random import random
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
@@ -21,6 +22,40 @@ BATCH_UPPER = 512.0
DO_VERBOSE_EARLY_TRAIN = False DO_VERBOSE_EARLY_TRAIN = False
MODEL_PATH = "./files/pwcomp.model" MODEL_PATH = "./files/pwcomp.model"
LOGGING_PATH = "./files/output.log" LOGGING_PATH = "./files/output.log"
EMBED_CHART_PATH = "./files/embedding_chart.png"
EMBEDDINGS_DATA = "./files/embedding_data.csv"
def plt_embeddings(model: comp_nn.PairwiseComparator):
import matplotlib.pyplot as plt
import csv
log.info("Starting embeddings sweep...")
# samples for embedding mapping
with torch.no_grad():
xs = torch.arange(
BATCH_LOWER,
BATCH_UPPER + 1.0,
1.0,
).unsqueeze(1).to(DEVICE) # shape: (N, 1)
embeddings = model.embed(xs) # shape: (N, d)
# move data back to CPU for plotting
embeddings = embeddings.cpu()
xs = xs.cpu()
for i in range(embeddings.shape[1]):
plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}")
plt.legend()
plt.savefig(EMBED_CHART_PATH)
plt.show()
csv_data = list(zip(xs.squeeze().tolist(), embeddings.tolist()))
with open(file=EMBEDDINGS_DATA, mode="w", newline='') as f:
csv_file = csv.writer(f)
csv_file.writerows(csv_data)
def get_torch_info(): def get_torch_info():
log.info("PyTorch Version: %s", torch.__version__) log.info("PyTorch Version: %s", torch.__version__)
@@ -115,9 +150,12 @@ def infer_entry():
with torch.no_grad(): with torch.no_grad():
probs = torch.sigmoid(model(a, b)) probs = torch.sigmoid(model(a, b))
log.info(f"Output probabilities for {pairs.count} pairs")
for (x, y), p in zip(pairs, probs): for (x, y), p in zip(pairs, probs):
log.info(f"P({x} > {y}) = {p.item():.3f}") log.info(f"P({x} > {y}) = {p.item():.3f}")
plt_embeddings(model)
if __name__ == '__main__': if __name__ == '__main__':
import sys import sys
import os import os