Source code for torchkge.models.bilinear

# -*- coding: utf-8 -*-
"""
Copyright TorchKGE developers
@author: Armand Boschin <aboschin@enst.fr>
"""

from torch import matmul, cat
from torch.nn.functional import normalize

from ..models.interfaces import BilinearModel
from ..utils import init_embedding


class RESCALModel(BilinearModel):
    """RESCAL bilinear model (Nickel et al., 2011).

    Every relation owns a full ``emb_dim x emb_dim`` matrix :math:`M_r` and a
    triplet is scored with :math:`h^T \\cdot M_r \\cdot t`. The original paper
    optimizes with Alternating Least Squares (ALS); this implementation is
    intended for iterative gradient descent instead.

    This class inherits from the
    :class:`torchkge.models.interfaces.BilinearModel` interface. It then has
    its attributes as well.

    References
    ----------
    * Maximilian Nickel, Volker Tresp, and Hans-Peter Kriegel.
      `A Three-way Model for Collective Learning on Multi-relational Data.
      <https://dl.acm.org/citation.cfm?id=3104584>`_
      In Proceedings of the 28th International Conference on Machine
      Learning, 2011.

    Parameters
    ----------
    emb_dim: int
        Dimension of embedding space.
    n_entities: int
        Number of entities in the current data set.
    n_relations: int
        Number of relations in the current data set.

    Attributes
    ----------
    ent_emb: torch.nn.Embedding, shape: (n_ent, emb_dim)
        Embeddings of the entities, initialized with Xavier uniform
        distribution and then normalized.
    rel_mat: torch.nn.Embedding, shape: (n_rel, emb_dim x emb_dim)
        Flattened matrices of the relations, initialized with Xavier uniform
        distribution.

    """

    def __init__(self, emb_dim, n_entities, n_relations):
        super().__init__(emb_dim, n_entities, n_relations)

        # One vector per entity, one flattened square matrix per relation.
        self.ent_emb = init_embedding(self.n_ent, self.emb_dim)
        self.rel_mat = init_embedding(self.n_rel, self.emb_dim * self.emb_dim)

        # Entity vectors start on the unit L2 sphere.
        self.normalize_parameters()

    def scoring_function(self, h_idx, t_idx, r_idx):
        """Score triplets with :math:`h^T \\cdot M_r \\cdot t`.

        Head and tail embeddings are L2-normalized before scoring. See
        referenced paper for more details on the score. See
        torchkge.models.interfaces.Models for more details on the API.

        """
        dim = self.emb_dim
        heads = normalize(self.ent_emb(h_idx), p=2, dim=1)
        tails = normalize(self.ent_emb(t_idx), p=2, dim=1)
        matrices = self.rel_mat(r_idx).view(-1, dim, dim)

        # (b, 1, d) @ (b, d, d) -> (b, 1, d), flattened back to (b, d).
        projected = matmul(heads.view(-1, 1, dim), matrices).view(-1, dim)
        return (projected * tails).sum(dim=1)

    def normalize_parameters(self):
        """Rescale every entity embedding to unit L2 norm, in place.

        As explained in the original paper. This method should be called at
        the end of each training epoch and at the end of training as well.

        """
        weight = self.ent_emb.weight
        weight.data = normalize(weight.data, p=2, dim=1)

    def get_embeddings(self):
        """Return the embeddings of entities and matrices of relations.

        Entity embeddings are re-normalized before being returned.

        Returns
        -------
        ent_emb: torch.Tensor, shape: (n_ent, emb_dim), dtype: torch.float
            Embeddings of entities.
        rel_mat: torch.Tensor, shape: (n_rel, emb_dim, emb_dim),
            dtype: torch.float
            Matrices of relations.

        """
        self.normalize_parameters()
        entities = self.ent_emb.weight.data
        relations = self.rel_mat.weight.data.view(
            -1, self.emb_dim, self.emb_dim)
        return entities, relations

    def inference_scoring_function(self, h, t, r):
        """Link prediction evaluation helper function.

        The completion case is picked from the rank of the arguments: a
        3-dimensional `h` means head completion, a 3-dimensional `t` means
        tail completion and a 4-dimensional `r` means relation completion.
        Nothing is returned when none of these patterns matches. See
        torchkge.models.interfaces.Models for more details of the API.

        """
        b_size = h.shape[0]
        dim = self.emb_dim

        if h.dim() == 3:
            # Head completion: h holds the candidate entities.
            assert t.dim() == 2 and r.dim() == 3
            tr = matmul(r, t.view(b_size, dim, 1)).view(b_size, 1, dim)
            return (h * tr).sum(dim=2)

        if t.dim() == 3:
            # Tail completion: t holds the candidate entities.
            assert h.dim() == 2 and r.dim() == 3
            hr = matmul(h.view(b_size, 1, dim), r).view(b_size, 1, dim)
            return (hr * t).sum(dim=2)

        if r.dim() == 4:
            # Relation completion: r holds the candidate matrices.
            assert h.dim() == 2 and t.dim() == 2
            heads = h.view(b_size, 1, 1, dim)
            tails = t.view(b_size, 1, dim)
            hr = matmul(heads, r).view(b_size, self.n_rel, dim)
            return (hr * tails).sum(dim=2)

        return None

    def inference_prepare_candidates(self, h_idx, t_idx, r_idx,
                                     entities=True):
        """Link prediction evaluation helper function.

        Look up the embeddings of the given triplets and build the batch of
        candidates: all entities when `entities` is true, all relation
        matrices otherwise. The output will be fed to the
        `inference_scoring_function` method. See
        torchkge.models.interfaces.Models for more details on the API.

        """
        b_size = h_idx.shape[0]
        dim = self.emb_dim

        h_emb = self.ent_emb(h_idx)
        t_emb = self.ent_emb(t_idx)
        r_mat = self.rel_mat(r_idx).view(-1, dim, dim)

        if entities:
            pool = self.ent_emb.weight.data
            shape = (self.n_ent, dim)
        else:
            pool = self.rel_mat.weight.data
            shape = (self.n_rel, dim, dim)

        # Broadcast the single candidate pool over the batch without copying.
        candidates = pool.view(1, *shape).expand(b_size, *shape)

        return h_emb, t_emb, r_mat, candidates
class DistMultModel(BilinearModel):
    """DistMult bilinear model (Yang et al., 2014).

    Relations are restricted to diagonal matrices, so each relation is stored
    as a single vector and a triplet is scored with
    :math:`h^T \\cdot diag(r) \\cdot t`.

    This class inherits from the
    :class:`torchkge.models.interfaces.BilinearModel` interface. It then has
    its attributes as well.

    References
    ----------
    * Bishan Yang, Wen-tau Yih, Xiaodong He, Jianfeng Gao, and Li Deng.
      `Embedding Entities and Relations for Learning and Inference in
      Knowledge Bases. <https://arxiv.org/abs/1412.6575>`_
      arXiv :1412.6575 [cs], December 2014.

    Parameters
    ----------
    emb_dim: int
        Dimension of embedding space.
    n_entities: int
        Number of entities in the current data set.
    n_relations: int
        Number of relations in the current data set.

    Attributes
    ----------
    ent_emb: torch.nn.Embedding, shape: (n_ent, emb_dim)
        Embeddings of the entities, initialized with Xavier uniform
        distribution and then normalized.
    rel_emb: torch.nn.Embedding, shape: (n_rel, emb_dim)
        Embeddings of the relations, initialized with Xavier uniform
        distribution.

    """

    def __init__(self, emb_dim, n_entities, n_relations):
        super().__init__(emb_dim, n_entities, n_relations)

        self.ent_emb = init_embedding(self.n_ent, self.emb_dim)
        self.rel_emb = init_embedding(self.n_rel, self.emb_dim)

        # Entity vectors start on the unit L2 sphere.
        self.normalize_parameters()

    def scoring_function(self, h_idx, t_idx, r_idx):
        """Score triplets with :math:`h^T \\cdot diag(r) \\cdot t`.

        Head and tail embeddings are L2-normalized before scoring. See
        referenced paper for more details on the score. See
        torchkge.models.interfaces.Models for more details on the API.

        """
        heads = normalize(self.ent_emb(h_idx), p=2, dim=1)
        tails = normalize(self.ent_emb(t_idx), p=2, dim=1)
        relations = self.rel_emb(r_idx)
        return (heads * relations * tails).sum(dim=1)

    def normalize_parameters(self):
        """Rescale every entity embedding to unit L2 norm, in place.

        As explained in the original paper. This method should be called at
        the end of each training epoch and at the end of training as well.

        """
        weight = self.ent_emb.weight
        weight.data = normalize(weight.data, p=2, dim=1)

    def get_embeddings(self):
        """Return the embeddings of entities and relations.

        Entity embeddings are re-normalized before being returned.

        Returns
        -------
        ent_emb: torch.Tensor, shape: (n_ent, emb_dim), dtype: torch.float
            Embeddings of entities.
        rel_emb: torch.Tensor, shape: (n_rel, emb_dim), dtype: torch.float
            Embeddings of relations.

        """
        self.normalize_parameters()
        entities = self.ent_emb.weight.data
        relations = self.rel_emb.weight.data
        return entities, relations

    def inference_scoring_function(self, h, t, r):
        """Link prediction evaluation helper function.

        Whichever argument is 3-dimensional holds the candidates: `t` for
        tail completion, `h` for head completion, `r` for relation
        prediction (checked in that order). Nothing is returned when none of
        them is 3-dimensional. See torchkge.models.interfaces.Models for more
        details on the API.

        """
        b_size = h.shape[0]
        dim = self.emb_dim

        if t.dim() == 3:
            # Tail completion.
            assert h.dim() == 2 and r.dim() == 2
            hr = (h * r).view(b_size, 1, dim)
            return (hr * t).sum(dim=2)

        if h.dim() == 3:
            # Head completion.
            assert t.dim() == 2 and r.dim() == 2
            rt = (r * t).view(b_size, 1, dim)
            return (h * rt).sum(dim=2)

        if r.dim() == 3:
            # Relation prediction; hr is (b_size, n_rel, emb_dim) when r
            # carries one row per relation.
            assert h.dim() == 2 and t.dim() == 2
            hr = h.view(b_size, 1, dim) * r
            return (hr * t.view(b_size, 1, dim)).sum(dim=2)

        return None

    def inference_prepare_candidates(self, h_idx, t_idx, r_idx,
                                     entities=True):
        """Link prediction evaluation helper function.

        Look up the embeddings of the given triplets and build the batch of
        candidates: all entities when `entities` is true, all relations
        otherwise. The output will be fed to the
        `inference_scoring_function` method. See
        torchkge.models.interfaces.Models for more details on the API.

        """
        b_size = h_idx.shape[0]
        dim = self.emb_dim

        h = self.ent_emb(h_idx)
        t = self.ent_emb(t_idx)
        r = self.rel_emb(r_idx)

        if entities:
            table, n_cand = self.ent_emb, self.n_ent
        else:
            table, n_cand = self.rel_emb, self.n_rel

        # Broadcast the single candidate pool over the batch without copying.
        candidates = table.weight.data.view(1, n_cand, dim)
        candidates = candidates.expand(b_size, n_cand, dim)

        return h, t, r, candidates
class HolEModel(BilinearModel):
    """Implementation of HolE model detailed in 2015 paper by Nickel et al..

    Each relation is stored as a vector `r`; scoring expands it into the
    rolling (circulant) matrix :math:`M_r` built by `get_rolling_matrix` and
    evaluates :math:`h^T \\cdot M_r \\cdot t`.

    This class inherits from the
    :class:`torchkge.models.interfaces.BilinearModel` interface. It then has
    its attributes as well.

    References
    ----------
    * Maximilian Nickel, Lorenzo Rosasco, and Tomaso Poggio.
      `Holographic Embeddings of Knowledge Graphs.
      <https://arxiv.org/abs/1510.04935>`_
      arXiv :1510.04935 [cs, stat], October 2015.

    Parameters
    ----------
    emb_dim: int
        Dimension of embedding space.
    n_entities: int
        Number of entities in the current data set.
    n_relations: int
        Number of relations in the current data set.

    Attributes
    ----------
    ent_emb: torch.nn.Embedding, shape: (n_ent, emb_dim)
        Embeddings of the entities, initialized with Xavier uniform
        distribution and then normalized.
    rel_emb: torch.nn.Embedding, shape: (n_rel, emb_dim)
        Embeddings of the relations, initialized with Xavier uniform
        distribution.

    """

    def __init__(self, emb_dim, n_entities, n_relations):
        super().__init__(emb_dim, n_entities, n_relations)

        self.ent_emb = init_embedding(self.n_ent, self.emb_dim)
        self.rel_emb = init_embedding(self.n_rel, self.emb_dim)

        # Entity vectors start on the unit L2 sphere.
        self.normalize_parameters()

    def scoring_function(self, h_idx, t_idx, r_idx):
        """Compute the scoring function for the triplets given as argument:
        :math:`h^T \\cdot M_r \\cdot t` where :math:`M_r` is the rolling
        matrix built from the relation embedding `r`.

        Head and tail embeddings are L2-normalized before scoring. See
        referenced paper for more details on the score. See
        torchkge.models.interfaces.Models for more details on the API.

        """
        h = normalize(self.ent_emb(h_idx), p=2, dim=1)
        t = normalize(self.ent_emb(t_idx), p=2, dim=1)
        r = self.get_rolling_matrix(self.rel_emb(r_idx))

        # (b, 1, d) @ (b, d, d) -> (b, 1, d)
        hr = matmul(h.view(-1, 1, self.emb_dim), r)
        return (hr.view(-1, self.emb_dim) * t).sum(dim=1)

    @staticmethod
    def get_rolling_matrix(x):
        """Build a rolling matrix.

        Row `i` of each output matrix is the input vector cyclically shifted
        by `i` positions.

        Parameters
        ----------
        x: torch.Tensor, shape: (b_size, dim)

        Returns
        -------
        mat: torch.Tensor, shape: (b_size, dim, dim)
            Rolling matrix such that mat[i,j] = x[j - i mod(dim)]

        """
        b_size, dim = x.shape
        x = x.view(b_size, 1, dim)
        return cat([x.roll(i, dims=2) for i in range(dim)], dim=1)

    def normalize_parameters(self):
        """Normalize the entity embeddings, as explained in original paper.

        This method should be called at the end of each training epoch and
        at the end of training as well.

        """
        self.ent_emb.weight.data = normalize(self.ent_emb.weight.data,
                                             p=2, dim=1)

    def get_embeddings(self):
        """Return the embeddings of entities and relations.

        Entity embeddings are re-normalized before being returned.

        Returns
        -------
        ent_emb: torch.Tensor, shape: (n_ent, emb_dim), dtype: torch.float
            Embeddings of entities.
        rel_emb: torch.Tensor, shape: (n_rel, emb_dim), dtype: torch.float
            Embeddings of relations.

        """
        self.normalize_parameters()
        return self.ent_emb.weight.data, self.rel_emb.weight.data

    def inference_scoring_function(self, h, t, r):
        """Link prediction evaluation helper function.

        The completion case is picked from the rank of the arguments: a
        3-dimensional `t` means tail completion, a 3-dimensional `h` means
        head completion and a 4-dimensional `r` means relation completion
        (checked in that order). Nothing is returned when none of these
        patterns matches. See torchkge.models.interfaces.Models for more
        details on the API.

        """
        b_size = h.shape[0]

        if len(t.shape) == 3:
            assert len(h.shape) == 2 and len(r.shape) == 3
            # this is the tail completion case in link prediction
            h = h.view(b_size, 1, self.emb_dim)
            hr = matmul(h, r).view(b_size, self.emb_dim, 1)
            # candidates are moved to the last axis so the sum over dim 1
            # contracts the embedding dimension
            return (hr * t.transpose(1, 2)).sum(dim=1)

        elif len(h.shape) == 3:
            assert len(t.shape) == 2 and len(r.shape) == 3
            # this is the head completion case in link prediction
            t = t.view(b_size, self.emb_dim, 1)
            return (h.transpose(1, 2) * matmul(r, t)).sum(dim=1)

        elif len(r.shape) == 4:
            assert len(h.shape) == 2 and len(t.shape) == 2
            # this is the relation completion case in link prediction
            h = h.view(b_size, 1, 1, self.emb_dim)
            t = t.view(b_size, 1, self.emb_dim)
            hr = matmul(h, r).view(b_size, self.n_rel, self.emb_dim)
            return (hr * t).sum(dim=2)

    def inference_prepare_candidates(self, h_idx, t_idx, r_idx,
                                     entities=True):
        """Link prediction evaluation helper function.

        Get entities embeddings and relations rolling matrices. The output
        will be fed to the `inference_scoring_function` method. See
        torchkge.models.interfaces.Models for more details on the API.

        The third returned value is always the rolling matrices of the
        relations in `r_idx`, shape (b_size, emb_dim, emb_dim), whatever the
        value of `entities`.

        """
        b_size = h_idx.shape[0]

        h_emb = self.ent_emb(h_idx)
        t_emb = self.ent_emb(t_idx)
        r_mat = self.get_rolling_matrix(self.rel_emb(r_idx))

        if entities:
            candidates = self.ent_emb.weight.data.view(1, self.n_ent,
                                                       self.emb_dim)
            candidates = candidates.expand(b_size, self.n_ent, self.emb_dim)
        else:
            # Kept in its own variable so the batch's own matrices in
            # `r_mat` are not overwritten by the full candidate pool.
            # TODO: do not recompute for each batch
            all_r_mat = self.get_rolling_matrix(self.rel_emb.weight.data)
            candidates = all_r_mat.view(1, self.n_rel,
                                        self.emb_dim, self.emb_dim)
            candidates = candidates.expand(b_size, self.n_rel,
                                           self.emb_dim, self.emb_dim)

        return h_emb, t_emb, r_mat, candidates
class ComplExModel(BilinearModel):
    """ComplEx model (Trouillon et al., 2016).

    Entities and relations are embedded as complex vectors, stored here as
    separate real and imaginary embedding tables. A triplet is scored with
    the real part of the Hermitian product
    :math:`\\Re(h^T \\cdot diag(r) \\cdot \\bar{t})`.

    This class inherits from the
    :class:`torchkge.models.interfaces.BilinearModel` interface. It then has
    its attributes as well.

    References
    ----------
    * Théo Trouillon, Johannes Welbl, Sebastian Riedel, Éric Gaussier, and
      Guillaume Bouchard.
      `Complex Embeddings for Simple Link Prediction.
      <https://arxiv.org/abs/1606.06357>`_
      arXiv :1606.06357 [cs, stat], June 2016.

    Parameters
    ----------
    emb_dim: int
        Dimension of embedding space.
    n_entities: int
        Number of entities in the current data set.
    n_relations: int
        Number of relations in the current data set.

    Attributes
    ----------
    re_ent_emb: torch.nn.Embedding, shape: (n_ent, emb_dim)
        Real part of the entities complex embeddings. Initialized with Xavier
        uniform distribution.
    im_ent_emb: torch.nn.Embedding, shape: (n_ent, emb_dim)
        Imaginary part of the entities complex embeddings. Initialized with
        Xavier uniform distribution.
    re_rel_emb: torch.nn.Embedding, shape: (n_rel, emb_dim)
        Real part of the relations complex embeddings. Initialized with Xavier
        uniform distribution.
    im_rel_emb: torch.nn.Embedding, shape: (n_rel, emb_dim)
        Imaginary part of the relations complex embeddings. Initialized with
        Xavier uniform distribution.

    """

    def __init__(self, emb_dim, n_entities, n_relations):
        super().__init__(emb_dim, n_entities, n_relations)

        # Entity tables (real, imaginary).
        self.re_ent_emb = init_embedding(self.n_ent, self.emb_dim)
        self.im_ent_emb = init_embedding(self.n_ent, self.emb_dim)

        # Relation tables (real, imaginary).
        self.re_rel_emb = init_embedding(self.n_rel, self.emb_dim)
        self.im_rel_emb = init_embedding(self.n_rel, self.emb_dim)

    def scoring_function(self, h_idx, t_idx, r_idx):
        """Compute the real part of the Hermitian product
        :math:`\\Re(h^T \\cdot diag(r) \\cdot \\bar{t})` for each sample of
        the batch.

        See referenced paper for more details on the score. See
        torchkge.models.interfaces.Models for more details on the API.

        """
        re_h = self.re_ent_emb(h_idx)
        im_h = self.im_ent_emb(h_idx)
        re_t = self.re_ent_emb(t_idx)
        im_t = self.im_ent_emb(t_idx)
        re_r = self.re_rel_emb(r_idx)
        im_r = self.im_rel_emb(r_idx)

        # Real and imaginary parts of r * conj(t).
        re_rt = re_r * re_t + im_r * im_t
        im_rt = re_r * im_t - im_r * re_t

        return (re_h * re_rt + im_h * im_rt).sum(dim=1)

    def normalize_parameters(self):
        """Do nothing: according to original paper, the embeddings should
        not be normalized.

        """
        return None

    def get_embeddings(self):
        """Return the embeddings of entities and relations.

        Returns
        -------
        re_ent_emb: torch.Tensor, shape: (n_ent, emb_dim), dtype: torch.float
            Real part of embeddings of entities.
        im_ent_emb: torch.Tensor, shape: (n_ent, emb_dim), dtype: torch.float
            Imaginary part of embeddings of entities.
        re_rel_emb: torch.Tensor, shape: (n_rel, emb_dim), dtype: torch.float
            Real part of embeddings of relations.
        im_rel_emb: torch.Tensor, shape: (n_rel, emb_dim), dtype: torch.float
            Imaginary part of embeddings of relations.

        """
        self.normalize_parameters()
        tables = (self.re_ent_emb, self.im_ent_emb,
                  self.re_rel_emb, self.im_rel_emb)
        return tuple(table.weight.data for table in tables)

    def inference_scoring_function(self, h, t, r):
        """Link prediction evaluation helper function.

        `h`, `t` and `r` are each indexable as (real part, imaginary part).
        Whichever real part is 3-dimensional holds the candidates: `t` for
        tail completion, `h` for head completion, `r` for relation
        prediction (checked in that order). Nothing is returned when none of
        them is 3-dimensional. See torchkge.models.interfaces.Models for more
        details on the API.

        """
        re_h = h[0]
        im_h = h[1]
        re_t = t[0]
        im_t = t[1]
        re_r = r[0]
        im_r = r[1]

        b_size = re_h.shape[0]
        dim = self.emb_dim

        if re_t.dim() == 3:
            # Tail completion: fold h and r together first.
            assert re_h.dim() == 2 and re_r.dim() == 2
            re_hr = (re_h * re_r - im_h * im_r).view(b_size, 1, dim)
            im_hr = (re_h * im_r + im_h * re_r).view(b_size, 1, dim)
            return (re_hr * re_t + im_hr * im_t).sum(dim=2)

        if re_h.dim() == 3:
            # Head completion: fold r and conj(t) together first.
            assert re_t.dim() == 2 and re_r.dim() == 2
            re_rt = (re_r * re_t + im_r * im_t).view(b_size, 1, dim)
            im_rt = (re_r * im_t - im_r * re_t).view(b_size, 1, dim)
            return (re_h * re_rt + im_h * im_rt).sum(dim=2)

        if re_r.dim() == 3:
            # Relation prediction: fold h and conj(t) together first.
            assert re_h.dim() == 2 and re_t.dim() == 2
            re_ht = (re_h * re_t + im_h * im_t).view(b_size, 1, dim)
            im_ht = (re_h * im_t - im_h * re_t).view(b_size, 1, dim)
            return (re_ht * re_r + im_ht * im_r).sum(dim=2)

        return None

    def inference_prepare_candidates(self, h_idx, t_idx, r_idx,
                                     entities=True):
        """Link prediction evaluation helper function.

        Look up the embeddings of the given triplets and build the batch of
        candidates: all entities when `entities` is true, all relations
        otherwise. Every returned element is a (real, imaginary) pair. The
        output will be fed to the `inference_scoring_function` method. See
        torchkge.models.interfaces.Models for more details on the API.

        """
        b_size = h_idx.shape[0]
        dim = self.emb_dim

        re_h, im_h = self.re_ent_emb(h_idx), self.im_ent_emb(h_idx)
        re_t, im_t = self.re_ent_emb(t_idx), self.im_ent_emb(t_idx)
        re_r, im_r = self.re_rel_emb(r_idx), self.im_rel_emb(r_idx)

        if entities:
            re_table, im_table = self.re_ent_emb, self.im_ent_emb
            n_cand = self.n_ent
        else:
            re_table, im_table = self.re_rel_emb, self.im_rel_emb
            n_cand = self.n_rel

        # Broadcast each candidate pool over the batch without copying.
        re_candidates = re_table.weight.data.view(1, n_cand, dim)
        re_candidates = re_candidates.expand(b_size, n_cand, dim)
        im_candidates = im_table.weight.data.view(1, n_cand, dim)
        im_candidates = im_candidates.expand(b_size, n_cand, dim)

        return (re_h, im_h), (re_t, im_t), (re_r, im_r), \
            (re_candidates, im_candidates)
class AnalogyModel(BilinearModel):
    """Implementation of ANALOGY model detailed in 2017 paper by Liu et al..

    According to their remark in the implementation details, the number of
    scalars on the diagonal of each relation-specific matrix is by default set
    to be half the embedding dimension.

    This class inherits from the
    :class:`torchkge.models.interfaces.BilinearModel` interface. It then has
    its attributes as well.

    References
    ----------
    * Hanxiao Liu, Yuexin Wu, and Yiming Yang.
      `Analogical Inference for Multi-Relational Embeddings.
      <https://arxiv.org/abs/1705.02426>`_
      arXiv :1705.02426 [cs], May 2017.

    Parameters
    ----------
    emb_dim: int
        Dimension of embedding space.
    n_entities: int
        Number of entities in the current data set.
    n_relations: int
        Number of relations in the current data set.
    scalar_share: float
        Share of the diagonal elements of the relation-specific matrices to be
        scalars. By default it is set to half according to the original paper.

    Attributes
    ----------
    scalar_dim: int
        Number of diagonal elements of the relation-specific matrices to be
        scalars, computed as `int(emb_dim * scalar_share)`. By default it is
        half the embedding dimension according to the original paper.
    complex_dim: int
        Number of 2x2 matrices on the diagonals of relation-specific matrices,
        computed as `emb_dim - scalar_dim`.
    sc_ent_emb: torch.nn.Embedding, shape: (n_ent, scalar_dim)
        Part of the entities embeddings associated to the scalar part of the
        relation specific matrices. Initialized with Xavier uniform
        distribution.
    re_ent_emb: torch.nn.Embedding, shape: (n_ent, complex_dim)
        Real part of the entities complex embeddings. Initialized with Xavier
        uniform distribution. As explained in the authors' paper, almost
        diagonal matrices can be seen as complex matrices.
    im_ent_emb: torch.nn.Embedding, shape: (n_ent, complex_dim)
        Imaginary part of the entities complex embeddings. Initialized with
        Xavier uniform distribution. As explained in the authors' paper,
        almost diagonal matrices can be seen as complex matrices.
    sc_rel_emb: torch.nn.Embedding, shape: (n_rel, scalar_dim)
        Part of the relations embeddings associated to the scalar part of the
        relation specific matrices. Initialized with Xavier uniform
        distribution.
    re_rel_emb: torch.nn.Embedding, shape: (n_rel, complex_dim)
        Real part of the relations complex embeddings. Initialized with Xavier
        uniform distribution. As explained in the authors' paper, almost
        diagonal matrices can be seen as complex matrices.
    im_rel_emb: torch.nn.Embedding, shape: (n_rel, complex_dim)
        Imaginary part of the relations complex embeddings. Initialized with
        Xavier uniform distribution. As explained in the authors' paper,
        almost diagonal matrices can be seen as complex matrices.

    """

    def __init__(self, emb_dim, n_entities, n_relations, scalar_share=0.5):
        super().__init__(emb_dim, n_entities, n_relations)

        # Split the embedding dimension between the scalar and complex parts.
        self.scalar_dim = int(self.emb_dim * scalar_share)
        self.complex_dim = int(self.emb_dim - self.scalar_dim)

        self.sc_ent_emb = init_embedding(self.n_ent, self.scalar_dim)
        self.re_ent_emb = init_embedding(self.n_ent, self.complex_dim)
        self.im_ent_emb = init_embedding(self.n_ent, self.complex_dim)

        self.sc_rel_emb = init_embedding(self.n_rel, self.scalar_dim)
        self.re_rel_emb = init_embedding(self.n_rel, self.complex_dim)
        self.im_rel_emb = init_embedding(self.n_rel, self.complex_dim)

    def scoring_function(self, h_idx, t_idx, r_idx):
        """Compute the scoring function for the triplets given as argument:
        :math:`h_{sc}^T \\cdot diag(r_{sc}) \\cdot t_{sc} +
        \\Re(h_{compl}^T \\cdot diag(r_{compl}) \\cdot \\bar{t}_{compl})`.

        See referenced paper for more details on the score. See
        torchkge.models.interfaces.Models for more details on the API.

        """
        sc_h = self.sc_ent_emb(h_idx)
        re_h = self.re_ent_emb(h_idx)
        im_h = self.im_ent_emb(h_idx)

        sc_t = self.sc_ent_emb(t_idx)
        re_t = self.re_ent_emb(t_idx)
        im_t = self.im_ent_emb(t_idx)

        sc_r = self.sc_rel_emb(r_idx)
        re_r = self.re_rel_emb(r_idx)
        im_r = self.im_rel_emb(r_idx)

        # The two parts may have different widths, so each is reduced on its
        # own before being added.
        scalar_part = (sc_h * sc_r * sc_t).sum(dim=1)
        complex_part = (re_h * (re_r * re_t + im_r * im_t)
                        + im_h * (re_r * im_t - im_r * re_t)).sum(dim=1)
        return scalar_part + complex_part

    def normalize_parameters(self):
        """According to original paper, the embeddings should not be
        normalized.

        """
        pass

    def get_embeddings(self):
        """Return the embeddings of entities and relations.

        Returns
        -------
        sc_ent_emb: torch.Tensor, shape: (n_ent, scalar_dim),
            dtype: torch.float
            Scalar part of embeddings of entities.
        re_ent_emb: torch.Tensor, shape: (n_ent, complex_dim),
            dtype: torch.float
            Real part of embeddings of entities.
        im_ent_emb: torch.Tensor, shape: (n_ent, complex_dim),
            dtype: torch.float
            Imaginary part of embeddings of entities.
        sc_rel_emb: torch.Tensor, shape: (n_rel, scalar_dim),
            dtype: torch.float
            Scalar part of embeddings of relations.
        re_rel_emb: torch.Tensor, shape: (n_rel, complex_dim),
            dtype: torch.float
            Real part of embeddings of relations.
        im_rel_emb: torch.Tensor, shape: (n_rel, complex_dim),
            dtype: torch.float
            Imaginary part of embeddings of relations.

        """
        self.normalize_parameters()
        return self.sc_ent_emb.weight.data, self.re_ent_emb.weight.data, \
            self.im_ent_emb.weight.data, self.sc_rel_emb.weight.data, \
            self.re_rel_emb.weight.data, self.im_rel_emb.weight.data

    def inference_scoring_function(self, h, t, r):
        """Link prediction evaluation helper function.

        `h`, `t` and `r` are each indexable as (scalar part, real part,
        imaginary part). Whichever real part is 3-dimensional holds the
        candidates: `t` for tail completion, `h` for head completion, `r`
        for relation prediction (checked in that order). Nothing is returned
        when none of them is 3-dimensional.

        The scalar and complex contributions are summed over their own last
        dimension before being added, exactly as in `scoring_function`, so
        the result is correct when `scalar_dim` and `complex_dim` differ.
        See torchkge.models.interfaces.Models for more details on the API.

        """
        sc_h, re_h, im_h = h[0], h[1], h[2]
        sc_t, re_t, im_t = t[0], t[1], t[2]
        sc_r, re_r, im_r = r[0], r[1], r[2]
        b_size = re_h.shape[0]

        if len(re_t.shape) == 3:
            assert len(re_h.shape) == 2 and len(re_r.shape) == 2
            # this is the tail completion case in link prediction
            sc_hr = (sc_h * sc_r).view(b_size, 1, self.scalar_dim)
            re_hr = (re_h * re_r
                     - im_h * im_r).view(b_size, 1, self.complex_dim)
            im_hr = (re_h * im_r
                     + im_h * re_r).view(b_size, 1, self.complex_dim)
            return ((sc_hr * sc_t).sum(dim=2)
                    + (re_hr * re_t + im_hr * im_t).sum(dim=2))

        elif len(re_h.shape) == 3:
            assert len(re_t.shape) == 2 and len(re_r.shape) == 2
            # this is the head completion case in link prediction
            sc_rt = (sc_r * sc_t).view(b_size, 1, self.scalar_dim)
            re_rt = (re_r * re_t
                     + im_r * im_t).view(b_size, 1, self.complex_dim)
            im_rt = (re_r * im_t
                     - im_r * re_t).view(b_size, 1, self.complex_dim)
            return ((sc_h * sc_rt).sum(dim=2)
                    + (re_h * re_rt + im_h * im_rt).sum(dim=2))

        elif len(re_r.shape) == 3:
            assert len(re_h.shape) == 2 and len(re_t.shape) == 2
            # this is the relation prediction case
            sc_ht = (sc_h * sc_t).view(b_size, 1, self.scalar_dim)
            re_ht = (re_h * re_t
                     + im_h * im_t).view(b_size, 1, self.complex_dim)
            im_ht = (re_h * im_t
                     - im_h * re_t).view(b_size, 1, self.complex_dim)
            return ((sc_r * sc_ht).sum(dim=2)
                    + (re_r * re_ht + im_r * im_ht).sum(dim=2))

    def inference_prepare_candidates(self, h_idx, t_idx, r_idx,
                                     entities=True):
        """Link prediction evaluation helper function.

        Get entities embeddings and relations embeddings, plus the batch of
        candidates: all entities when `entities` is true, all relations
        otherwise. Every returned element is a (scalar, real, imaginary)
        triple. The output will be fed to the `inference_scoring_function`
        method. See torchkge.models.interfaces.Models for more details on
        the API.

        """
        b_size = h_idx.shape[0]

        sc_h = self.sc_ent_emb(h_idx)
        re_h = self.re_ent_emb(h_idx)
        im_h = self.im_ent_emb(h_idx)

        sc_t = self.sc_ent_emb(t_idx)
        re_t = self.re_ent_emb(t_idx)
        im_t = self.im_ent_emb(t_idx)

        sc_r = self.sc_rel_emb(r_idx)
        re_r = self.re_rel_emb(r_idx)
        im_r = self.im_rel_emb(r_idx)

        if entities:
            sc_table, re_table, im_table = \
                self.sc_ent_emb, self.re_ent_emb, self.im_ent_emb
            n_cand = self.n_ent
        else:
            sc_table, re_table, im_table = \
                self.sc_rel_emb, self.re_rel_emb, self.im_rel_emb
            n_cand = self.n_rel

        # Broadcast each candidate pool over the batch without copying.
        sc_candidates = sc_table.weight.data
        sc_candidates = sc_candidates.view(1, n_cand, self.scalar_dim)
        sc_candidates = sc_candidates.expand(b_size, n_cand, self.scalar_dim)

        re_candidates = re_table.weight.data
        re_candidates = re_candidates.view(1, n_cand, self.complex_dim)
        re_candidates = re_candidates.expand(b_size, n_cand,
                                             self.complex_dim)

        im_candidates = im_table.weight.data
        im_candidates = im_candidates.view(1, n_cand, self.complex_dim)
        im_candidates = im_candidates.expand(b_size, n_cand,
                                             self.complex_dim)

        return (sc_h, re_h, im_h), \
            (sc_t, re_t, im_t), \
            (sc_r, re_r, im_r), \
            (sc_candidates, re_candidates, im_candidates)