updated comment for PairwiseComparator

This commit is contained in:
2025-12-25 14:53:35 +00:00
parent acbccebb2c
commit 23fddbe5b9

View File

@@ -14,7 +14,7 @@ class NumberEmbedder(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
# 3) Comparator head: takes (ea, eb, e) -> logit for "a > b" # MLP Comparator head: takes (ea, eb, e) -> logit for "a > b"
class PairwiseComparator(nn.Module): class PairwiseComparator(nn.Module):
def __init__(self, d=2, hidden=4, k=0.5): def __init__(self, d=2, hidden=4, k=0.5):
super().__init__() super().__init__()