changes to figure sizes and add labels to embedding chart

This commit is contained in:
2025-12-23 13:44:38 -05:00
parent 755161c152
commit cd72cd7052

View File

@@ -53,7 +53,7 @@ def parse_training_log(file_path: str) -> pd.DataFrame:
# TODO: Move plotting into its own file # TODO: Move plotting into its own file
def plt_loss_tstep(df: pd.DataFrame) -> None: def plt_loss_tstep(df: pd.DataFrame) -> None:
# Plot 1: Loss # Plot 1: Loss
plt.figure(figsize=(8, 4)) plt.figure(figsize=(10, 6))
plt.plot(df["step"], df["loss_clamped"]) plt.plot(df["step"], df["loss_clamped"])
plt.yscale("log") plt.yscale("log")
plt.xlabel("Step") plt.xlabel("Step")
@@ -69,7 +69,7 @@ def plt_loss_tstep(df: pd.DataFrame) -> None:
def plt_acc_tstep(df: pd.DataFrame, eps=1e-10) -> None: def plt_acc_tstep(df: pd.DataFrame, eps=1e-10) -> None:
# Plot 2: Accuracy # Plot 2: Accuracy
df["err"] = (1.0 - df["acc"]).clip(lower=eps) df["err"] = (1.0 - df["acc"]).clip(lower=eps)
plt.figure(figsize=(8, 4)) plt.figure(figsize=(10, 6))
plt.plot(df["step"], df["err"]) plt.plot(df["step"], df["err"])
plt.yscale("log") plt.yscale("log")
plt.xlabel("Step") plt.xlabel("Step")
@@ -100,10 +100,13 @@ def plt_embeddings(model: comp_nn.PairwiseComparator) -> None:
embeddings = embeddings.cpu() embeddings = embeddings.cpu()
xs = xs.cpu() xs = xs.cpu()
# Plot 3: x vs h(x)
plt.figure(figsize=(10, 6))
for i in range(embeddings.shape[1]): for i in range(embeddings.shape[1]):
plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}") plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}")
plt.title("x vs h(x)")
plt.legend() plt.xlabel("x [input]")
plt.ylabel("h(x) [embedding]")
plt.savefig(EMBED_CHART_PATH) plt.savefig(EMBED_CHART_PATH)
plt.close() plt.close()