added epsilon for equality check

major layout changes in the network
This commit is contained in:
2025-12-22 21:47:16 -05:00
parent 997303028e
commit 6e31865a84
3 changed files with 30 additions and 281 deletions

View File

@@ -13,12 +13,12 @@ import pairwise_comp_nn as comp_nn
DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else "cpu"
# Valves
DIMENSIONS = 1
TRAIN_STEPS = 10000
DIMENSIONS = 2
TRAIN_STEPS = 5000
TRAIN_BATCHSZ = 8192
TRAIN_PROGRESS = 100
BATCH_LOWER = -512.0
BATCH_UPPER = 512.0
TRAIN_PROGRESS = 10
BATCH_LOWER = -100.0
BATCH_UPPER = 100.0
DO_VERBOSE_EARLY_TRAIN = False
MODEL_PATH = "./files/pwcomp.model"
LOGGING_PATH = "./files/output.log"
@@ -46,17 +46,15 @@ def plt_embeddings(model: comp_nn.PairwiseComparator):
for i in range(embeddings.shape[1]):
plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}")
plt.legend()
plt.savefig(EMBED_CHART_PATH)
plt.show()
plt.legend()
plt.savefig(EMBED_CHART_PATH)
#plt.show()
csv_data = list(zip(xs.squeeze().tolist(), embeddings.tolist()))
with open(file=EMBEDDINGS_DATA, mode="w", newline='') as f:
csv_file = csv.writer(f)
csv_file.writerows(csv_data)
def get_torch_info():
log.info("PyTorch Version: %s", torch.__version__)
log.info("HIP Version: %s", torch.version.hip)
@@ -78,13 +76,10 @@ def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER):
a = (high - low) * torch.rand(batch_size, 1) + low
b = (high - low) * torch.rand(batch_size, 1) + low
# train for if a > b
y = (a > b).float()
epsi = 1e-4
y = torch.where(a > b + epsi, 1.0,
torch.where(a < b - epsi, 0.0, 0.5))
# removed but left for my notes; it seems training for equality hurts classifying results that are ~eq
# when trained only on "if a > b => y", the model produces more accurate results when classifying whether things are equal (~.5 prob).
# eq = (a == b).float()
# y = gt + 0.5 * eq
return a, b, y
def training_entry():
@@ -150,7 +145,7 @@ def infer_entry():
with torch.no_grad():
probs = torch.sigmoid(model(a, b))
log.info(f"Output probabilities for {pairs.count} pairs")
log.info(f"Output probabilities for {pairs.__len__()} pairs")
for (x, y), p in zip(pairs, probs):
log.info(f"P({x} > {y}) = {p.item():.3f}")