changes to figure sizes and add labels to embedding chart
This commit is contained in:
@@ -53,7 +53,7 @@ def parse_training_log(file_path: str) -> pd.DataFrame:
|
|||||||
# TODO: Move plotting into its own file
|
# TODO: Move plotting into its own file
|
||||||
def plt_loss_tstep(df: pd.DataFrame) -> None:
|
def plt_loss_tstep(df: pd.DataFrame) -> None:
|
||||||
# Plot 1: Loss
|
# Plot 1: Loss
|
||||||
plt.figure(figsize=(8, 4))
|
plt.figure(figsize=(10, 6))
|
||||||
plt.plot(df["step"], df["loss_clamped"])
|
plt.plot(df["step"], df["loss_clamped"])
|
||||||
plt.yscale("log")
|
plt.yscale("log")
|
||||||
plt.xlabel("Step")
|
plt.xlabel("Step")
|
||||||
@@ -69,7 +69,7 @@ def plt_loss_tstep(df: pd.DataFrame) -> None:
|
|||||||
def plt_acc_tstep(df: pd.DataFrame, eps=1e-10) -> None:
|
def plt_acc_tstep(df: pd.DataFrame, eps=1e-10) -> None:
|
||||||
# Plot 2: Accuracy
|
# Plot 2: Accuracy
|
||||||
df["err"] = (1.0 - df["acc"]).clip(lower=eps)
|
df["err"] = (1.0 - df["acc"]).clip(lower=eps)
|
||||||
plt.figure(figsize=(8, 4))
|
plt.figure(figsize=(10, 6))
|
||||||
plt.plot(df["step"], df["err"])
|
plt.plot(df["step"], df["err"])
|
||||||
plt.yscale("log")
|
plt.yscale("log")
|
||||||
plt.xlabel("Step")
|
plt.xlabel("Step")
|
||||||
@@ -99,11 +99,14 @@ def plt_embeddings(model: comp_nn.PairwiseComparator) -> None:
|
|||||||
# move data back to CPU for plotting
|
# move data back to CPU for plotting
|
||||||
embeddings = embeddings.cpu()
|
embeddings = embeddings.cpu()
|
||||||
xs = xs.cpu()
|
xs = xs.cpu()
|
||||||
|
|
||||||
|
# Plot 3: x vs h(x)
|
||||||
|
plt.figure(figsize=(10, 6))
|
||||||
for i in range(embeddings.shape[1]):
|
for i in range(embeddings.shape[1]):
|
||||||
plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}")
|
plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}")
|
||||||
|
plt.title("x vs h(x)")
|
||||||
plt.legend()
|
plt.xlabel("x [input]")
|
||||||
|
plt.ylabel("h(x) [embedding]")
|
||||||
plt.savefig(EMBED_CHART_PATH)
|
plt.savefig(EMBED_CHART_PATH)
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user