output plots and csv data for e(x) over the trained range.
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
# pairwise_compare.py
|
||||
import logging
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
@@ -21,6 +22,40 @@ BATCH_UPPER = 512.0
|
||||
DO_VERBOSE_EARLY_TRAIN = False
|
||||
MODEL_PATH = "./files/pwcomp.model"
|
||||
LOGGING_PATH = "./files/output.log"
|
||||
EMBED_CHART_PATH = "./files/embedding_chart.png"
|
||||
EMBEDDINGS_DATA = "./files/embedding_data.csv"
|
||||
|
||||
def plt_embeddings(model: comp_nn.PairwiseComparator):
|
||||
import matplotlib.pyplot as plt
|
||||
import csv
|
||||
|
||||
log.info("Starting embeddings sweep...")
|
||||
# samples for embedding mapping
|
||||
with torch.no_grad():
|
||||
xs = torch.arange(
|
||||
BATCH_LOWER,
|
||||
BATCH_UPPER + 1.0,
|
||||
1.0,
|
||||
).unsqueeze(1).to(DEVICE) # shape: (N, 1)
|
||||
|
||||
embeddings = model.embed(xs) # shape: (N, d)
|
||||
|
||||
# move data back to CPU for plotting
|
||||
embeddings = embeddings.cpu()
|
||||
xs = xs.cpu()
|
||||
|
||||
for i in range(embeddings.shape[1]):
|
||||
plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}")
|
||||
plt.legend()
|
||||
plt.savefig(EMBED_CHART_PATH)
|
||||
plt.show()
|
||||
|
||||
csv_data = list(zip(xs.squeeze().tolist(), embeddings.tolist()))
|
||||
with open(file=EMBEDDINGS_DATA, mode="w", newline='') as f:
|
||||
csv_file = csv.writer(f)
|
||||
csv_file.writerows(csv_data)
|
||||
|
||||
|
||||
|
||||
def get_torch_info():
|
||||
log.info("PyTorch Version: %s", torch.__version__)
|
||||
@@ -115,9 +150,12 @@ def infer_entry():
|
||||
with torch.no_grad():
|
||||
probs = torch.sigmoid(model(a, b))
|
||||
|
||||
log.info(f"Output probabilities for {pairs.count} pairs")
|
||||
for (x, y), p in zip(pairs, probs):
|
||||
log.info(f"P({x} > {y}) = {p.item():.3f}")
|
||||
|
||||
plt_embeddings(model)
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
import os
|
||||
|
||||
Reference in New Issue
Block a user