changes to figure sizes and add labels to embedding chart

This commit is contained in:
2025-12-23 13:44:38 -05:00
parent 755161c152
commit cd72cd7052

View File

@@ -53,7 +53,7 @@ def parse_training_log(file_path: str) -> pd.DataFrame:
# TODO: Move plotting into its own file
def plt_loss_tstep(df: pd.DataFrame) -> None:
# Plot 1: Loss
plt.figure(figsize=(8, 4))
plt.figure(figsize=(10, 6))
plt.plot(df["step"], df["loss_clamped"])
plt.yscale("log")
plt.xlabel("Step")
@@ -69,7 +69,7 @@ def plt_loss_tstep(df: pd.DataFrame) -> None:
def plt_acc_tstep(df: pd.DataFrame, eps=1e-10) -> None:
# Plot 2: Accuracy
df["err"] = (1.0 - df["acc"]).clip(lower=eps)
plt.figure(figsize=(8, 4))
plt.figure(figsize=(10, 6))
plt.plot(df["step"], df["err"])
plt.yscale("log")
plt.xlabel("Step")
@@ -99,11 +99,14 @@ def plt_embeddings(model: comp_nn.PairwiseComparator) -> None:
# move data back to CPU for plotting
embeddings = embeddings.cpu()
xs = xs.cpu()
# Plot 3: x vs h(x)
plt.figure(figsize=(10, 6))
for i in range(embeddings.shape[1]):
plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}")
plt.legend()
plt.title("x vs h(x)")
plt.xlabel("x [input]")
plt.ylabel("h(x) [embedding]")
plt.savefig(EMBED_CHART_PATH)
plt.close()