From 921e24b451bbc8045db2168ed6dfdaaa8e6d67a7 Mon Sep 17 00:00:00 2001 From: Elaina Claus Date: Thu, 25 Dec 2025 01:52:19 +0000 Subject: [PATCH] embedded # of hidden neurons in model save data added more valves for options. will update the console every 5 seconds when training now --- pairwise_compare.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/pairwise_compare.py b/pairwise_compare.py index 2101a22..754cbd8 100755 --- a/pairwise_compare.py +++ b/pairwise_compare.py @@ -18,7 +18,10 @@ DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_availab # Valves DIMENSIONS = 2 -TRAIN_STEPS = 5000 +HIDDEN_NEURONS = 4 +ADAMW_LR = 5e-3 +ADAMW_WiDECAY = 5e-4 +TRAIN_STEPS = 2000 TRAIN_BATCHSZ = 8192 TRAIN_PROGRESS = 10 BATCH_LOWER = -100.0 @@ -39,7 +42,6 @@ def parse_training_log(file_path: str) -> pd.DataFrame: with lzma.open(file_path, mode='rt') as f: text = f.read() - pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)") rows = [(int(s), float(l), float(a)) for s, l, a in pattern.findall(text)] df = pd.DataFrame(rows, columns=["step", "loss", "acc"]).sort_values("step").reset_index(drop=True) @@ -107,6 +109,7 @@ def plt_embeddings(model: comp_nn.PairwiseComparator) -> None: plt.title("x vs h(x)") plt.xlabel("x [input]") plt.ylabel("h(x) [embedding]") + plt.legend() plt.savefig(EMBED_CHART_PATH) plt.close() @@ -153,13 +156,15 @@ def training_entry(): # the seed should initialized normally otherwise set_seed(0) - model = comp_nn.PairwiseComparator(d=DIMENSIONS).to(DEVICE) - opt = torch.optim.AdamW(model.parameters(), lr=8e-4, weight_decay=1e-3) + model = comp_nn.PairwiseComparator(d=DIMENSIONS, hidden=HIDDEN_NEURONS).to(DEVICE) + opt = torch.optim.AdamW(model.parameters(), lr=ADAMW_LR, weight_decay=ADAMW_WiDECAY) log.info(f"Using {TRAINING_LOG_PATH} as the logging destination for training...") with lzma.open(TRAINING_LOG_PATH, mode='wt') as tlog: # training loop training_start_time = 
datetime.datetime.now() + last_ack = datetime.datetime.now() + for step in range(TRAIN_STEPS): a, b, y = sample_batch(TRAIN_BATCHSZ) a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE) @@ -178,8 +183,11 @@ def training_entry(): tlog.write(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}\n") # also print to normal text log occasionally to show some activity. - if step % 2500 == 0: - log.info(f"still training... step={step} of {TRAIN_STEPS}") + # every 100 steps check if its been longer than 5 seconds since we've updated the user + if step % 100 == 0: + if (datetime.datetime.now() - last_ack).total_seconds() > 5: + log.info(f"still training... step={step} of {TRAIN_STEPS}") + last_ack = datetime.datetime.now() training_end_time = datetime.datetime.now() log.info(f"Training steps complete. Start time: {training_start_time} End time: {training_end_time}") @@ -195,14 +203,14 @@ def training_entry(): log.info(f"Final test acc: {acc} errors: {errors}") # embed model dimensions into the model serialization - torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, MODEL_PATH) + torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS, "h": HIDDEN_NEURONS}, MODEL_PATH) log.info(f"Saved PyTorch Model State to {MODEL_PATH}") def infer_entry(): get_torch_info() model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE) - model = comp_nn.PairwiseComparator(d=model_ckpt["d"]).to(DEVICE) + model = comp_nn.PairwiseComparator(d=model_ckpt["d"], hidden=model_ckpt["h"]).to(DEVICE) model.load_state_dict(model_ckpt["state_dict"]) model.eval() @@ -228,7 +236,7 @@ def graphs_entry(): get_torch_info() model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE) - model = comp_nn.PairwiseComparator(d=model_ckpt["d"]).to(DEVICE) + model = comp_nn.PairwiseComparator(d=model_ckpt["d"], hidden=model_ckpt["h"]).to(DEVICE) model.load_state_dict(model_ckpt["state_dict"]) model.eval()