output plots and csv data for e(x) over the trained range.
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
# pairwise_compare.py
|
# pairwise_compare.py
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
@@ -21,6 +22,40 @@ BATCH_UPPER = 512.0
|
|||||||
# Presumably gates extra per-batch logging during early training — TODO confirm against the trainer.
DO_VERBOSE_EARLY_TRAIN = False
# On-disk artifact locations (relative to the working directory).
MODEL_PATH = "./files/pwcomp.model"
LOGGING_PATH = "./files/output.log"
# Outputs produced by plt_embeddings(): the per-dimension plot and the raw CSV dump.
EMBED_CHART_PATH = "./files/embedding_chart.png"
EMBEDDINGS_DATA = "./files/embedding_data.csv"
|
|
||||||
|
def plt_embeddings(model: comp_nn.PairwiseComparator):
    """Sweep the trained input range, plot e(x) per embedding dimension, and
    dump the raw values to CSV.

    Parameters
    ----------
    model : comp_nn.PairwiseComparator
        Trained comparator; only its ``embed`` method is used here.

    Side effects: writes ``EMBED_CHART_PATH`` (PNG) and ``EMBEDDINGS_DATA``
    (CSV, one row per x: ``x, e_0, ..., e_{d-1}``), and opens an interactive
    window via ``plt.show()``.
    """
    # Lazy imports: plotting deps are only needed when this entry point runs.
    import matplotlib.pyplot as plt
    import csv

    log.info("Starting embeddings sweep...")

    # Samples for embedding mapping: one input per integer step in the trained range.
    with torch.no_grad():
        xs = torch.arange(
            BATCH_LOWER,
            BATCH_UPPER + 1.0,  # arange's stop is exclusive; +1.0 keeps the upper bound
            1.0,
        ).unsqueeze(1).to(DEVICE)  # shape: (N, 1)

        embeddings = model.embed(xs)  # assumed shape (N, d) — confirm against embed()

    # Move data back to CPU for plotting.
    embeddings = embeddings.cpu()
    xs = xs.cpu()

    x_values = xs.squeeze(1)  # (N,)
    for i in range(embeddings.shape[1]):
        plt.plot(x_values, embeddings[:, i], label=f"dim {i}")
    plt.legend()
    # Save before show(): show() may clear the current figure.
    plt.savefig(EMBED_CHART_PATH)
    plt.show()

    # BUG FIX: the original zipped each x with the entire embedding row as a
    # Python list, so csv.writer emitted the list's repr string as one cell.
    # Flatten each row to [x, e_0, ..., e_{d-1}] — one numeric value per cell.
    csv_data = [[x] + emb for x, emb in zip(x_values.tolist(), embeddings.tolist())]
    with open(file=EMBEDDINGS_DATA, mode="w", newline='') as f:
        csv_file = csv.writer(f)
        csv_file.writerows(csv_data)
||||||
def get_torch_info():
|
def get_torch_info():
|
||||||
log.info("PyTorch Version: %s", torch.__version__)
|
log.info("PyTorch Version: %s", torch.__version__)
|
||||||
@@ -115,9 +150,12 @@ def infer_entry():
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
probs = torch.sigmoid(model(a, b))
|
probs = torch.sigmoid(model(a, b))
|
||||||
|
|
||||||
|
log.info(f"Output probabilities for {pairs.count} pairs")
|
||||||
for (x, y), p in zip(pairs, probs):
|
for (x, y), p in zip(pairs, probs):
|
||||||
log.info(f"P({x} > {y}) = {p.item():.3f}")
|
log.info(f"P({x} > {y}) = {p.item():.3f}")
|
||||||
|
|
||||||
|
plt_embeddings(model)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
|||||||
Reference in New Issue
Block a user