diff --git a/pairwise_comp_nn.py b/pairwise_comp_nn.py index f33df5f..e03184d 100644 --- a/pairwise_comp_nn.py +++ b/pairwise_comp_nn.py @@ -16,7 +16,7 @@ class NumberEmbedder(nn.Module): # 3) Comparator head: takes (ea, eb, e) -> logit for "a > b" class PairwiseComparator(nn.Module): - def __init__(self, d=4, hidden=16, k=1.0): + def __init__(self, d=4, hidden=16, k=0.5): super().__init__() self.log_k = nn.Parameter(torch.tensor([k])) self.embed = NumberEmbedder(d, hidden)