diff --git a/pairwise_comp_nn.py b/pairwise_comp_nn.py index 6bfff5f..84e5299 100644 --- a/pairwise_comp_nn.py +++ b/pairwise_comp_nn.py @@ -29,7 +29,7 @@ class PairwiseComparator(nn.Module): ) def forward(self, a, b): - # trying to force antisym here: h(a,b)=−h(b,a) + # trying to force antisym here: h(a,b)=-h(b,a) phi = self.head(self.embed(a-b)) phi_neg = self.head(self.embed(b-a)) logit = phi - phi_neg