embeded # of hidden neurons in model save data

added more valves for options.
will update the console every 5 seconds when training now
This commit is contained in:
2025-12-25 01:52:19 +00:00
parent d46712ff53
commit 921e24b451

View File

@@ -18,7 +18,10 @@ DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_availab
# Valves # Valves
DIMENSIONS = 2 DIMENSIONS = 2
TRAIN_STEPS = 5000 HIDDEN_NEURONS = 4
ADAMW_LR = 5e-3
ADAMW_WiDECAY = 5e-4
TRAIN_STEPS = 2000
TRAIN_BATCHSZ = 8192 TRAIN_BATCHSZ = 8192
TRAIN_PROGRESS = 10 TRAIN_PROGRESS = 10
BATCH_LOWER = -100.0 BATCH_LOWER = -100.0
@@ -39,7 +42,6 @@ def parse_training_log(file_path: str) -> pd.DataFrame:
with lzma.open(file_path, mode='rt') as f: with lzma.open(file_path, mode='rt') as f:
text = f.read() text = f.read()
pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)") pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)")
rows = [(int(s), float(l), float(a)) for s, l, a in pattern.findall(text)] rows = [(int(s), float(l), float(a)) for s, l, a in pattern.findall(text)]
df = pd.DataFrame(rows, columns=["step", "loss", "acc"]).sort_values("step").reset_index(drop=True) df = pd.DataFrame(rows, columns=["step", "loss", "acc"]).sort_values("step").reset_index(drop=True)
@@ -107,6 +109,7 @@ def plt_embeddings(model: comp_nn.PairwiseComparator) -> None:
plt.title("x vs h(x)") plt.title("x vs h(x)")
plt.xlabel("x [input]") plt.xlabel("x [input]")
plt.ylabel("h(x) [embedding]") plt.ylabel("h(x) [embedding]")
plt.legend()
plt.savefig(EMBED_CHART_PATH) plt.savefig(EMBED_CHART_PATH)
plt.close() plt.close()
@@ -153,13 +156,15 @@ def training_entry():
# the seed should initialized normally otherwise # the seed should initialized normally otherwise
set_seed(0) set_seed(0)
model = comp_nn.PairwiseComparator(d=DIMENSIONS).to(DEVICE) model = comp_nn.PairwiseComparator(d=DIMENSIONS, hidden=HIDDEN_NEURONS).to(DEVICE)
opt = torch.optim.AdamW(model.parameters(), lr=8e-4, weight_decay=1e-3) opt = torch.optim.AdamW(model.parameters(), lr=ADAMW_LR, weight_decay=ADAMW_WiDECAY)
log.info(f"Using {TRAINING_LOG_PATH} as the logging destination for training...") log.info(f"Using {TRAINING_LOG_PATH} as the logging destination for training...")
with lzma.open(TRAINING_LOG_PATH, mode='wt') as tlog: with lzma.open(TRAINING_LOG_PATH, mode='wt') as tlog:
# training loop # training loop
training_start_time = datetime.datetime.now() training_start_time = datetime.datetime.now()
last_ack = datetime.datetime.now()
for step in range(TRAIN_STEPS): for step in range(TRAIN_STEPS):
a, b, y = sample_batch(TRAIN_BATCHSZ) a, b, y = sample_batch(TRAIN_BATCHSZ)
a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE) a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)
@@ -178,8 +183,11 @@ def training_entry():
tlog.write(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}\n") tlog.write(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}\n")
# also print to normal text log occasionally to show some activity. # also print to normal text log occasionally to show some activity.
if step % 2500 == 0: # every 100 steps check if its been longer than 5 seconds since we've updated the user
if step % 100 == 0:
if (datetime.datetime.now() - last_ack).total_seconds() > 5:
log.info(f"still training... step={step} of {TRAIN_STEPS}") log.info(f"still training... step={step} of {TRAIN_STEPS}")
last_ack = datetime.datetime.now()
training_end_time = datetime.datetime.now() training_end_time = datetime.datetime.now()
log.info(f"Training steps complete. Start time: {training_start_time} End time: {training_end_time}") log.info(f"Training steps complete. Start time: {training_start_time} End time: {training_end_time}")
@@ -195,14 +203,14 @@ def training_entry():
log.info(f"Final test acc: {acc} errors: {errors}") log.info(f"Final test acc: {acc} errors: {errors}")
# embed model dimensions into the model serialization # embed model dimensions into the model serialization
torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, MODEL_PATH) torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS, "h": HIDDEN_NEURONS}, MODEL_PATH)
log.info(f"Saved PyTorch Model State to {MODEL_PATH}") log.info(f"Saved PyTorch Model State to {MODEL_PATH}")
def infer_entry(): def infer_entry():
get_torch_info() get_torch_info()
model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE) model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
model = comp_nn.PairwiseComparator(d=model_ckpt["d"]).to(DEVICE) model = comp_nn.PairwiseComparator(d=model_ckpt["d"], hidden=model_ckpt["h"]).to(DEVICE)
model.load_state_dict(model_ckpt["state_dict"]) model.load_state_dict(model_ckpt["state_dict"])
model.eval() model.eval()
@@ -228,7 +236,7 @@ def graphs_entry():
get_torch_info() get_torch_info()
model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE) model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
model = comp_nn.PairwiseComparator(d=model_ckpt["d"]).to(DEVICE) model = comp_nn.PairwiseComparator(d=model_ckpt["d"], hidden=model_ckpt["h"]).to(DEVICE)
model.load_state_dict(model_ckpt["state_dict"]) model.load_state_dict(model_ckpt["state_dict"])
model.eval() model.eval()