37 lines
1.1 KiB
Python
37 lines
1.1 KiB
Python
import torch
|
||
from torch import nn
|
||
|
||
# 2) Number "embedding" network: R -> R^d
class NumberEmbedder(nn.Module):
    """Map a scalar input to a d-dimensional embedding via a tiny MLP.

    The input is expected to carry a trailing feature dimension of 1
    (e.g. shape ``(batch, 1)``); the output replaces that dimension
    with ``d``.
    """

    def __init__(self, d=2, hidden=4):
        super().__init__()
        # 1 -> hidden -> d with a single ReLU in between.  The layers
        # are listed first and unpacked so the architecture reads as a
        # flat pipeline.
        layers = [
            nn.Linear(1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Embed ``x`` (shape ``(..., 1)``) into R^d."""
        return self.net(x)
# 3) Comparator head: takes (ea, eb, e) -> logit for "a > b"
class PairwiseComparator(nn.Module):
    """Score scalar pairs: ``forward(a, b)`` returns a logit for "a > b".

    Antisymmetry ``h(a, b) == -h(b, a)`` holds by construction: the logit
    is the difference of the same network applied to ``a - b`` and
    ``b - a``.
    """

    def __init__(self, d=2, hidden=4, k=0.5):
        super().__init__()
        # Learnable output scale.  NOTE(review): despite the name, this
        # parameter is used as ``log_k ** 2`` in forward, not
        # ``exp(log_k)`` — the "log" in the name looks like a misnomer.
        # Confirm intent before renaming; the attribute name is baked
        # into saved state_dicts.
        self.log_k = nn.Parameter(torch.tensor([k]))
        # Shared embedder applied to the signed difference of the pair.
        self.embed = NumberEmbedder(d, hidden)
        # Scoring MLP: R^d -> R.
        self.head = nn.Sequential(
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, a, b):
        """Return an antisymmetric logit for "a > b".

        ``a`` and ``b`` are expected with a trailing feature dimension
        of 1 (to match the embedder's ``nn.Linear(1, hidden)`` input).
        """
        score_fwd = self.head(self.embed(a - b))
        score_rev = self.head(self.embed(b - a))
        # Taking the difference makes h(a, b) = -h(b, a) exact.
        raw_logit = score_fwd - score_rev
        # Non-negative learned gain (see NOTE on ``log_k`` above).
        return (self.log_k ** 2) * raw_logit