|
|
|
|
@@ -18,12 +18,14 @@ DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_availab
|
|
|
|
|
|
|
|
|
|
# Valves
|
|
|
|
|
DIMENSIONS = 2
|
|
|
|
|
TRAIN_STEPS = 5000
|
|
|
|
|
HIDDEN_NEURONS = 4
|
|
|
|
|
ADAMW_LR = 5e-3
|
|
|
|
|
ADAMW_DECAY = 5e-4
|
|
|
|
|
TRAIN_STEPS = 2000
|
|
|
|
|
TRAIN_BATCHSZ = 8192
|
|
|
|
|
TRAIN_PROGRESS = 10
|
|
|
|
|
BATCH_LOWER = -100.0
|
|
|
|
|
BATCH_UPPER = 100.0
|
|
|
|
|
DO_VERBOSE_EARLY_TRAIN = False
|
|
|
|
|
|
|
|
|
|
# Files
|
|
|
|
|
MODEL_PATH = "./files/pwcomp.model"
|
|
|
|
|
@@ -40,7 +42,6 @@ def parse_training_log(file_path: str) -> pd.DataFrame:
|
|
|
|
|
with lzma.open(file_path, mode='rt') as f:
|
|
|
|
|
text = f.read()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)")
|
|
|
|
|
rows = [(int(s), float(l), float(a)) for s, l, a in pattern.findall(text)]
|
|
|
|
|
df = pd.DataFrame(rows, columns=["step", "loss", "acc"]).sort_values("step").reset_index(drop=True)
|
|
|
|
|
@@ -54,7 +55,7 @@ def parse_training_log(file_path: str) -> pd.DataFrame:
|
|
|
|
|
# TODO: Move plotting into its own file
|
|
|
|
|
def plt_loss_tstep(df: pd.DataFrame) -> None:
|
|
|
|
|
# Plot 1: Loss
|
|
|
|
|
plt.figure(figsize=(8, 4))
|
|
|
|
|
plt.figure(figsize=(10, 6))
|
|
|
|
|
plt.plot(df["step"], df["loss_clamped"])
|
|
|
|
|
plt.yscale("log")
|
|
|
|
|
plt.xlabel("Step")
|
|
|
|
|
@@ -70,7 +71,7 @@ def plt_loss_tstep(df: pd.DataFrame) -> None:
|
|
|
|
|
def plt_acc_tstep(df: pd.DataFrame, eps=1e-10) -> None:
|
|
|
|
|
# Plot 2: Accuracy
|
|
|
|
|
df["err"] = (1.0 - df["acc"]).clip(lower=eps)
|
|
|
|
|
plt.figure(figsize=(8, 4))
|
|
|
|
|
plt.figure(figsize=(10, 6))
|
|
|
|
|
plt.plot(df["step"], df["err"])
|
|
|
|
|
plt.yscale("log")
|
|
|
|
|
plt.xlabel("Step")
|
|
|
|
|
@@ -100,10 +101,14 @@ def plt_embeddings(model: comp_nn.PairwiseComparator) -> None:
|
|
|
|
|
# move data back to CPU for plotting
|
|
|
|
|
embeddings = embeddings.cpu()
|
|
|
|
|
xs = xs.cpu()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Plot 3: x vs h(x)
|
|
|
|
|
plt.figure(figsize=(10, 6))
|
|
|
|
|
for i in range(embeddings.shape[1]):
|
|
|
|
|
plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}")
|
|
|
|
|
|
|
|
|
|
plt.title("x vs h(x)")
|
|
|
|
|
plt.xlabel("x [input]")
|
|
|
|
|
plt.ylabel("h(x) [embedding]")
|
|
|
|
|
plt.legend()
|
|
|
|
|
plt.savefig(EMBED_CHART_PATH)
|
|
|
|
|
plt.close()
|
|
|
|
|
@@ -151,13 +156,15 @@ def training_entry():
|
|
|
|
|
# the seed should initialized normally otherwise
|
|
|
|
|
set_seed(0)
|
|
|
|
|
|
|
|
|
|
model = comp_nn.PairwiseComparator(d=DIMENSIONS).to(DEVICE)
|
|
|
|
|
opt = torch.optim.AdamW(model.parameters(), lr=8e-4, weight_decay=1e-3)
|
|
|
|
|
model = comp_nn.PairwiseComparator(d=DIMENSIONS, hidden=HIDDEN_NEURONS).to(DEVICE)
|
|
|
|
|
opt = torch.optim.AdamW(model.parameters(), lr=ADAMW_LR, weight_decay=ADAMW_DECAY)
|
|
|
|
|
|
|
|
|
|
log.info(f"Using {TRAINING_LOG_PATH} as the logging destination for training...")
|
|
|
|
|
with lzma.open(TRAINING_LOG_PATH, mode='wt') as tlog:
|
|
|
|
|
# training loop
|
|
|
|
|
training_start_time = datetime.datetime.now()
|
|
|
|
|
last_ack = datetime.datetime.now()
|
|
|
|
|
|
|
|
|
|
for step in range(TRAIN_STEPS):
|
|
|
|
|
a, b, y = sample_batch(TRAIN_BATCHSZ)
|
|
|
|
|
a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)
|
|
|
|
|
@@ -169,15 +176,17 @@ def training_entry():
|
|
|
|
|
loss_fn.backward()
|
|
|
|
|
opt.step()
|
|
|
|
|
|
|
|
|
|
if step % TRAIN_PROGRESS == 0:
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
pred = (torch.sigmoid(logits) > 0.5).float()
|
|
|
|
|
acc = (pred == y).float().mean().item()
|
|
|
|
|
tlog.write(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}\n")
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
pred = (torch.sigmoid(logits) > 0.5).float()
|
|
|
|
|
acc = (pred == y).float().mean().item()
|
|
|
|
|
tlog.write(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}\n")
|
|
|
|
|
|
|
|
|
|
# also print to normal text log occasionally to show some activity.
|
|
|
|
|
if step % 2500 == 0:
|
|
|
|
|
# also print to normal text log occasionally to show some activity.
|
|
|
|
|
# every 10 steps check if its been longer than 5 seconds since we've updated the user
|
|
|
|
|
if step % 10 == 0:
|
|
|
|
|
if (datetime.datetime.now() - last_ack).total_seconds() > 5:
|
|
|
|
|
log.info(f"still training... step={step} of {TRAIN_STEPS}")
|
|
|
|
|
last_ack = datetime.datetime.now()
|
|
|
|
|
|
|
|
|
|
training_end_time = datetime.datetime.now()
|
|
|
|
|
log.info(f"Training steps complete. Start time: {training_start_time} End time: {training_end_time}")
|
|
|
|
|
@@ -193,14 +202,14 @@ def training_entry():
|
|
|
|
|
log.info(f"Final test acc: {acc} errors: {errors}")
|
|
|
|
|
|
|
|
|
|
# embed model dimensions into the model serialization
|
|
|
|
|
torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, MODEL_PATH)
|
|
|
|
|
torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS, "h": HIDDEN_NEURONS}, MODEL_PATH)
|
|
|
|
|
log.info(f"Saved PyTorch Model State to {MODEL_PATH}")
|
|
|
|
|
|
|
|
|
|
def infer_entry():
|
|
|
|
|
get_torch_info()
|
|
|
|
|
|
|
|
|
|
model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
|
|
|
|
|
model = comp_nn.PairwiseComparator(d=model_ckpt["d"]).to(DEVICE)
|
|
|
|
|
model = comp_nn.PairwiseComparator(d=model_ckpt["d"], hidden=model_ckpt["h"]).to(DEVICE)
|
|
|
|
|
model.load_state_dict(model_ckpt["state_dict"])
|
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
|
|
@@ -226,7 +235,7 @@ def graphs_entry():
|
|
|
|
|
get_torch_info()
|
|
|
|
|
|
|
|
|
|
model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
|
|
|
|
|
model = comp_nn.PairwiseComparator(d=model_ckpt["d"]).to(DEVICE)
|
|
|
|
|
model = comp_nn.PairwiseComparator(d=model_ckpt["d"], hidden=model_ckpt["h"]).to(DEVICE)
|
|
|
|
|
model.load_state_dict(model_ckpt["state_dict"])
|
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
|
|
|