# Source code for torchkge.utils.dissimilarities

# -*- coding: utf-8 -*-
"""
Copyright TorchKGE developers
@author: Armand Boschin <aboschin@enst.fr>
"""

from math import pi
from torch import abs, cos, min


def l1_dissimilarity(a, b):
    """Compute dissimilarity between rows of `a` and `b` as :math:`||a-b||_1`.

    Parameters
    ----------
    a: torch.Tensor
        First operand.
    b: torch.Tensor
        Second operand. It must have the same number of dimensions as `a`
        (checked with an ``assert``); the shapes must also be
        broadcast-compatible for the subtraction.

    Returns
    -------
    torch.Tensor
        L1 norm of ``a - b`` computed over the last dimension, so the result
        has one dimension fewer than the inputs.
    """
    assert a.dim() == b.dim()
    difference = a - b
    return difference.norm(p=1, dim=-1)
def l2_dissimilarity(a, b):
    """Compute dissimilarity between rows of `a` and `b` as :math:`||a-b||_2^2`.

    Parameters
    ----------
    a: torch.Tensor
        First operand.
    b: torch.Tensor
        Second operand. It must have the same number of dimensions as `a`
        (checked with an ``assert``); the shapes must also be
        broadcast-compatible for the subtraction.

    Returns
    -------
    torch.Tensor
        Squared Euclidean norm of ``a - b`` over the last dimension.
    """
    assert a.dim() == b.dim()
    euclidean = (a - b).norm(p=2, dim=-1)
    return euclidean ** 2
def l1_torus_dissimilarity(a, b):
    """See `paper by Ebisu et al. <https://www.semanticscholar.org/paper/TorusE%3A-Knowledge-Graph-Embedding-on-a-Lie-Group-Ebisu-Ichise/990334cf76845e2da64d3baa10b0a671e433d4b6>`_
    for details about the definition of this dissimilarity function.

    Computes :math:`2 \\sum_i \\min(\\delta_i, 1 - \\delta_i)` with
    :math:`\\delta_i = \\mathrm{frac}(|a_i - b_i|)`, the L1 distance on the
    torus :math:`\\mathbb{R}^n / \\mathbb{Z}^n`.

    Parameters
    ----------
    a: torch.Tensor
        Points on the torus (floating point). Coordinates need not lie in
        [0, 1); they are implicitly reduced modulo 1.
    b: torch.Tensor
        Same number of dimensions as `a` (checked with an ``assert``);
        broadcast-compatible with `a`.

    Returns
    -------
    torch.Tensor
        Non-negative dissimilarity reduced over the last dimension.
    """
    assert len(a.shape) == len(b.shape)
    # Reduce |a - b| modulo 1 so coordinates outside [0, 1) (e.g. h + r)
    # still wrap around the torus; without it, 1 - |a - b| can become
    # negative. For |a - b| < 1, frac() returns its input unchanged.
    delta = abs(a - b).frac()
    return 2 * min(delta, 1 - delta).sum(dim=-1)
def l2_torus_dissimilarity(a, b):
    """See `paper by Ebisu et al. <https://www.semanticscholar.org/paper/TorusE%3A-Knowledge-Graph-Embedding-on-a-Lie-Group-Ebisu-Ichise/990334cf76845e2da64d3baa10b0a671e433d4b6>`_
    for details about the definition of this dissimilarity function.

    Computes :math:`4 \\sum_i \\min(\\delta_i, 1 - \\delta_i)^2` with
    :math:`\\delta_i = \\mathrm{frac}(|a_i - b_i|)`, i.e. the squared
    L2 distance on the torus :math:`\\mathbb{R}^n / \\mathbb{Z}^n`
    (up to the factor 4 used by the original implementation).

    Parameters
    ----------
    a: torch.Tensor
        Points on the torus (floating point). Coordinates need not lie in
        [0, 1); they are implicitly reduced modulo 1.
    b: torch.Tensor
        Same number of dimensions as `a` (checked with an ``assert``);
        broadcast-compatible with `a`.

    Returns
    -------
    torch.Tensor
        Non-negative dissimilarity reduced over the last dimension.
    """
    assert len(a.shape) == len(b.shape)
    # Reduce |a - b| modulo 1 so coordinates outside [0, 1) wrap around.
    delta = abs(a - b).frac()
    # Square the per-coordinate torus distance *after* taking the minimum:
    # min(d**2, 1 - d**2) is not the squared torus distance (for |d| = 0.6
    # it yields 0.36 instead of 0.4**2 = 0.16).
    torus_dist = min(delta, 1 - delta)
    return 4 * (torus_dist ** 2).sum(dim=-1)
def el2_torus_dissimilarity(a, b):
    """See `paper by Ebisu et al. <https://www.semanticscholar.org/paper/TorusE%3A-Knowledge-Graph-Embedding-on-a-Lie-Group-Ebisu-Ichise/990334cf76845e2da64d3baa10b0a671e433d4b6>`_
    for details about the definition of this dissimilarity function.

    Computes :math:`\\frac{1}{4} \\sum_i 2 (1 - \\cos(2 \\pi (a_i - b_i)))`.

    Parameters
    ----------
    a: torch.Tensor
        Points on the torus (floating point).
    b: torch.Tensor
        Same number of dimensions as `a` (checked with an ``assert``);
        broadcast-compatible with `a`.

    Returns
    -------
    torch.Tensor
        Dissimilarity reduced over the last dimension.
    """
    assert len(a.shape) == len(b.shape)
    # cos(2*pi*x) is even and 1-periodic, so the difference needs neither
    # reduction modulo 1 nor a min(d, 1 - d) step: cos(2*pi*(1 - d)) equals
    # cos(2*pi*d).
    tmp = 2 * (1 - cos(2 * pi * (a - b)))
    return tmp.sum(dim=-1) / 4