default k=0.5

This commit is contained in:
2025-12-23 10:14:08 -05:00
parent 6e31865a84
commit 9ea8ef3458

View File

@@ -16,7 +16,7 @@ class NumberEmbedder(nn.Module):
# 3) Comparator head: takes (ea, eb, e) -> logit for "a > b"
class PairwiseComparator(nn.Module):
def __init__(self, d=4, hidden=16, k=1.0):
def __init__(self, d=4, hidden=16, k=0.5):
super().__init__()
self.log_k = nn.Parameter(torch.tensor([k]))
self.embed = NumberEmbedder(d, hidden)