added epsilon for equality check
major layout changes in the network
This commit is contained in:
@@ -13,12 +13,12 @@ import pairwise_comp_nn as comp_nn
|
||||
DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else "cpu"
|
||||
|
||||
# Valves
|
||||
DIMENSIONS = 1
|
||||
TRAIN_STEPS = 10000
|
||||
DIMENSIONS = 2
|
||||
TRAIN_STEPS = 5000
|
||||
TRAIN_BATCHSZ = 8192
|
||||
TRAIN_PROGRESS = 100
|
||||
BATCH_LOWER = -512.0
|
||||
BATCH_UPPER = 512.0
|
||||
TRAIN_PROGRESS = 10
|
||||
BATCH_LOWER = -100.0
|
||||
BATCH_UPPER = 100.0
|
||||
DO_VERBOSE_EARLY_TRAIN = False
|
||||
MODEL_PATH = "./files/pwcomp.model"
|
||||
LOGGING_PATH = "./files/output.log"
|
||||
@@ -46,17 +46,15 @@ def plt_embeddings(model: comp_nn.PairwiseComparator):
|
||||
|
||||
for i in range(embeddings.shape[1]):
|
||||
plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}")
|
||||
plt.legend()
|
||||
plt.savefig(EMBED_CHART_PATH)
|
||||
plt.show()
|
||||
|
||||
|
||||
plt.legend()
|
||||
plt.savefig(EMBED_CHART_PATH)
|
||||
#plt.show()
|
||||
csv_data = list(zip(xs.squeeze().tolist(), embeddings.tolist()))
|
||||
with open(file=EMBEDDINGS_DATA, mode="w", newline='') as f:
|
||||
csv_file = csv.writer(f)
|
||||
csv_file.writerows(csv_data)
|
||||
|
||||
|
||||
|
||||
def get_torch_info():
|
||||
log.info("PyTorch Version: %s", torch.__version__)
|
||||
log.info("HIP Version: %s", torch.version.hip)
|
||||
@@ -78,13 +76,10 @@ def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER):
|
||||
a = (high - low) * torch.rand(batch_size, 1) + low
|
||||
b = (high - low) * torch.rand(batch_size, 1) + low
|
||||
|
||||
# train for if a > b
|
||||
y = (a > b).float()
|
||||
epsi = 1e-4
|
||||
y = torch.where(a > b + epsi, 1.0,
|
||||
torch.where(a < b - epsi, 0.0, 0.5))
|
||||
|
||||
# removed but left for my notes; it seems training for equality hurts classifying results that are ~eq
|
||||
# when trained only on "if a > b => y", the model produces more accurate results when classifying if things are equal (~.5 prob).
|
||||
# eq = (a == b).float()
|
||||
# y = gt + 0.5 * eq
|
||||
return a, b, y
|
||||
|
||||
def training_entry():
|
||||
@@ -150,7 +145,7 @@ def infer_entry():
|
||||
with torch.no_grad():
|
||||
probs = torch.sigmoid(model(a, b))
|
||||
|
||||
log.info(f"Output probabilities for {pairs.count} pairs")
|
||||
log.info(f"Output probabilities for {pairs.__len__()} pairs")
|
||||
for (x, y), p in zip(pairs, probs):
|
||||
log.info(f"P({x} > {y}) = {p.item():.3f}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user