added epsilon for equality check
major layout changes in the network
@@ -3,32 +3,35 @@ from torch import nn
 # 2) Number "embedding" network: R -> R^d
 class NumberEmbedder(nn.Module):
-    def __init__(self, d=8):
+    def __init__(self, d=4, hidden=16):
         super().__init__()
         self.net = nn.Sequential(
-            nn.Linear(1, 16),
+            nn.Linear(1, hidden),
             nn.ReLU(),
-            nn.Linear(16, d),
+            nn.Linear(hidden, d),
         )

     def forward(self, x):
         return self.net(x)

-# 3) Comparator head: takes (ea, eb) -> logit for "a > b"
+# 3) Comparator head: takes (ea, eb, e) -> logit for "a > b"
 class PairwiseComparator(nn.Module):
-    def __init__(self, d=8):
+    def __init__(self, d=4, hidden=16, k=1.0):
         super().__init__()
-        self.embed = NumberEmbedder(d)
+        self.log_k = nn.Parameter(torch.tensor([k]))
+        self.embed = NumberEmbedder(d, hidden)
         self.head = nn.Sequential(
-            nn.Linear(2 * d + 1, 16),
+            nn.Linear(d, hidden),
             nn.ReLU(),
-            nn.Linear(16, 1),
+            nn.Linear(hidden, hidden),
+            nn.ReLU(),
+            nn.Linear(hidden, 1),
         )

     def forward(self, a, b):
-        ea = self.embed(a)
-        eb = self.embed(b)
-        delta_ab = a - b
-        x = torch.cat([ea, eb, delta_ab], dim=-1)
+        # trying to force antisym here: h(a,b)=−h(b,a)
+        phi = self.head(self.embed(a-b))
+        phi_neg = self.head(self.embed(b-a))
+        logit = phi - phi_neg

-        return self.head(x) # logits
+        return (self.log_k ** 2) * logit
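
Below is a standalone sketch of the two classes as they stand after this commit, plus a quick check of the antisymmetry that the in-code comment aims for: swapping the arguments should flip the sign of the logit. The `import torch` line, the `__main__` guard, and the test tensors `a`/`b` are assumptions added for illustration and are not part of the commit.

import torch
from torch import nn


class NumberEmbedder(nn.Module):
    """R -> R^d embedding of a single scalar (new version from this commit)."""

    def __init__(self, d=4, hidden=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d),
        )

    def forward(self, x):
        return self.net(x)


class PairwiseComparator(nn.Module):
    """Scores "a > b" from embeddings of the differences a - b and b - a."""

    def __init__(self, d=4, hidden=16, k=1.0):
        super().__init__()
        self.log_k = nn.Parameter(torch.tensor([k]))
        self.embed = NumberEmbedder(d, hidden)
        self.head = nn.Sequential(
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, a, b):
        # h(a, b) = phi(a - b) - phi(b - a), so h(a, b) = -h(b, a) by construction
        phi = self.head(self.embed(a - b))
        phi_neg = self.head(self.embed(b - a))
        logit = phi - phi_neg
        return (self.log_k ** 2) * logit


if __name__ == "__main__":
    # Hypothetical test values, shape (batch, 1) to match nn.Linear(1, hidden).
    model = PairwiseComparator()
    a = torch.tensor([[2.0], [0.5]])
    b = torch.tensor([[1.0], [3.0]])
    # Swapping the inputs flips the sign of the logit, even with untrained weights.
    assert torch.allclose(model(a, b), -model(b, a), atol=1e-6)

Because both calls go through the same embedder and head, phi(a - b) - phi(b - a) is exactly antisymmetric in (a, b); the (self.log_k ** 2) factor only rescales the logit by a learned non-negative gain.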