Source code for torchkge.models.deep

# -*- coding: utf-8 -*-
"""
Copyright TorchKGE developers
@author: Armand Boschin <aboschin@enst.fr>
"""

from torch import nn, cat

from ..models.interfaces import Model
from ..utils import init_embedding


class ConvKBModel(Model):
    """Implementation of ConvKB model detailed in 2018 paper by Nguyen et al..

    This class inherits from the
    :class:`torchkge.models.interfaces.Model` interface. It then has its
    attributes as well.

    A triplet is scored by stacking the embeddings of its head, relation and
    tail into a 3-channel signal of length `emb_dim`, applying `n_filters`
    convolution filters of width 1 followed by a ReLU, flattening the
    feature maps and feeding them to a linear layer with two outputs
    followed by a softmax. The score is the second softmax component.

    References
    ----------
    * Nguyen, D. Q., Nguyen, T. D., Nguyen, D. Q., and Phung, D.
      `A Novel Embedding Model for Knowledge Base Completion Based on
      Convolutional Neural Network. <https://arxiv.org/abs/1712.02121>`_
      In Proceedings of the 2018 Conference of the North American Chapter of
      the Association for Computational Linguistics: Human Language
      Technologies (2018), vol. 2, pp. 327–333.

    Parameters
    ----------
    emb_dim: int
        Dimension of embedding space.
    n_filters: int
        Number of filters used for convolution.
    n_entities: int
        Number of entities in the current data set.
    n_relations: int
        Number of relations in the current data set.

    Attributes
    ----------
    ent_emb: torch.nn.Embedding, shape: (n_ent, emb_dim)
        Embeddings of the entities, created by `init_embedding` (Xavier
        uniform initialization according to the original documentation).
        No normalization is applied by this class.
    rel_emb: torch.nn.Embedding, shape: (n_rel, emb_dim)
        Embeddings of the relations, created by `init_embedding` (Xavier
        uniform initialization according to the original documentation).
    convlayer: torch.nn.Sequential
        `Conv1d(3, n_filters, kernel_size=1, stride=1)` followed by `ReLU`.
    output: torch.nn.Sequential
        `Linear(emb_dim * n_filters, 2)` followed by `Softmax(dim=1)`.

    """

    def __init__(self, emb_dim, n_filters, n_entities, n_relations):
        super().__init__(n_entities, n_relations)
        self.emb_dim = emb_dim

        # `self.n_ent` and `self.n_rel` are expected to be set by the parent
        # constructor called above.
        self.ent_emb = init_embedding(self.n_ent, self.emb_dim)
        self.rel_emb = init_embedding(self.n_rel, self.emb_dim)

        # The three input channels are (head, relation, tail); a kernel of
        # width 1 mixes them independently at each embedding coordinate.
        self.convlayer = nn.Sequential(
            nn.Conv1d(3, n_filters, 1, stride=1),
            nn.ReLU(),
        )
        self.output = nn.Sequential(
            nn.Linear(emb_dim * n_filters, 2),
            nn.Softmax(dim=1),
        )

    def _score_stacked(self, stacked):
        """Score rows of stacked (head, relation, tail) embeddings.

        Parameters
        ----------
        stacked: torch.Tensor, shape: (n, 3, emb_dim)
            One stacked triplet per row.

        Returns
        -------
        scores: torch.Tensor, shape: (n,)
            Second component of the two-way softmax for each row.

        """
        feature_maps = self.convlayer(stacked)  # (n, n_filters, emb_dim)
        # `flatten` is used rather than `reshape(n, -1)` because -1 cannot be
        # inferred when the tensor holds zero elements (empty batch).
        probabilities = self.output(feature_maps.flatten(start_dim=1))
        return probabilities[:, 1]

    def scoring_function(self, h_idx, t_idx, r_idx):
        """Compute the scoring function for the triplets given as argument:
        by applying convolutions to the concatenation of the embeddings. See
        referenced paper for more details on the score. See
        torchkge.models.interfaces.Models for more details on the API.

        Parameters
        ----------
        h_idx, t_idx, r_idx: torch.Tensor
            Indices of heads, tails and relations. The first dimension is the
            batch size.

        Returns
        -------
        scores: torch.Tensor, shape: (b_size,)
            Score of each triplet of the batch.

        """
        b_size = h_idx.shape[0]

        # Explicit sizes (instead of -1) keep the reshape well-defined even
        # when the batch is empty.
        h = self.ent_emb(h_idx).view(b_size, 1, self.emb_dim)
        t = self.ent_emb(t_idx).view(b_size, 1, self.emb_dim)
        r = self.rel_emb(r_idx).view(b_size, 1, self.emb_dim)

        # Channel order is (head, relation, tail).
        return self._score_stacked(cat((h, r, t), dim=1))

    def normalize_parameters(self):
        """Do nothing: this model does not normalize its embeddings.

        The method is kept so that the model satisfies the common model
        interface and can safely be called at the end of each training epoch
        and at the end of training. Earlier documentation stated that entity
        embeddings were normalized here; the implementation has always been
        a no-op.

        """

    def get_embeddings(self):
        """Return the embeddings of entities and relations.

        `normalize_parameters` is called first (currently a no-op). The
        returned tensors are the `weight.data` of the embedding layers, not
        copies.

        Returns
        -------
        ent_emb: torch.Tensor, shape: (n_ent, emb_dim), dtype: torch.float
            Embeddings of entities.
        rel_emb: torch.Tensor, shape: (n_rel, emb_dim), dtype: torch.float
            Embeddings of relations.

        """
        self.normalize_parameters()
        return self.ent_emb.weight.data, self.rel_emb.weight.data

    def inference_scoring_function(self, h, t, r):
        """Link prediction evaluation helper function. See
        torchkge.models.interfaces.Models for more details on the API.

        Exactly one of `h`, `t`, `r` must be a 4-dimensional tensor of
        candidates of shape (b_size, n_candidates, 1, emb_dim); the other two
        must be 2-dimensional tensors of shape (b_size, emb_dim). The two
        known embeddings are broadcast against every candidate and each
        resulting triplet is scored.

        Returns
        -------
        scores: torch.Tensor, shape: (b_size, n_candidates)
            Score of each candidate for each element of the batch.

        Raises
        ------
        AssertionError
            If the ranks of `h`, `t`, `r` are not a permutation of (2, 2, 4).

        """
        b_size = h.shape[0]
        emb_dim = self.emb_dim

        ranks = (len(h.shape), len(t.shape), len(r.shape))
        if sorted(ranks) != [2, 2, 4]:
            # Raised explicitly (rather than through `assert`) so that the
            # check is not stripped under `python -O`; the exception type is
            # kept for backward compatibility.
            raise AssertionError(
                "Expected exactly one 4-dimensional tensor and two "
                "2-dimensional tensors among (h, t, r), got ranks "
                "{}.".format(ranks)
            )

        candidates = next(x for x in (h, t, r) if len(x.shape) == 4)
        n_candidates = candidates.shape[1]

        def per_candidate(x):
            # The candidate tensor is used as is; known embeddings are
            # repeated (without copy) once per candidate.
            if x is candidates:
                return x
            return x.view(b_size, 1, 1, emb_dim).expand(
                b_size, n_candidates, 1, emb_dim)

        # Channel order is (head, relation, tail), as in `scoring_function`.
        stacked = cat(
            (per_candidate(h), per_candidate(r), per_candidate(t)), dim=2)
        stacked = stacked.reshape(b_size * n_candidates, 3, emb_dim)

        return self._score_stacked(stacked).reshape(b_size, n_candidates)

    def inference_prepare_candidates(self, h_idx, t_idx, r_idx, entities=True):
        """Link prediction evaluation helper function. Get entities embeddings
        and relations embeddings. The output will be fed to the
        `inference_scoring_function` method. See
        torchkge.models.interfaces.Models for more details on the API.

        Parameters
        ----------
        h_idx, t_idx, r_idx: torch.Tensor
            Indices of heads, tails and relations. The first dimension is the
            batch size.
        entities: bool
            If True the candidates are all entity embeddings, otherwise they
            are all relation embeddings.

        Returns
        -------
        h, t, r: torch.Tensor
            Embeddings looked up for `h_idx`, `t_idx` and `r_idx`.
        candidates: torch.Tensor, shape: (b_size, n_ent or n_rel, 1, emb_dim)
            Candidate embeddings repeated (without copy) for each element of
            the batch.

        """
        b_size = h_idx.shape[0]

        h = self.ent_emb(h_idx)
        t = self.ent_emb(t_idx)
        r = self.rel_emb(r_idx)

        if entities:
            table, n_candidates = self.ent_emb, self.n_ent
        else:
            table, n_candidates = self.rel_emb, self.n_rel

        # `weight.data` is used so that the candidates are detached from the
        # autograd graph; `expand` shares storage across the batch dimension.
        candidates = table.weight.data.view(
            1, n_candidates, 1, self.emb_dim)
        candidates = candidates.expand(
            b_size, n_candidates, 1, self.emb_dim)

        return h, t, r, candidates