# -*- coding: utf-8 -*-
"""
Copyright TorchKGE developers
@author: Armand Boschin <aboschin@enst.fr>
This module contains functions implementing methods explained in `this
paper <https://arxiv.org/pdf/2003.08001.pdf>`__ by Akrami et al.
"""
from itertools import combinations
from torch import cat
from tqdm.autonotebook import tqdm
def concat_kgs(kg_tr, kg_val, kg_te):
    """Concatenate the index tensors of the three splits.

    The `head_idx`, `tail_idx` and `relations` attributes of the three
    knowledge graphs are each concatenated with `torch.cat`, in the order
    train, validation, test.

    Parameters
    ----------
    kg_tr: torchkge.data_structures.KnowledgeGraph
        Train set
    kg_val: torchkge.data_structures.KnowledgeGraph
        Validation set
    kg_te: torchkge.data_structures.KnowledgeGraph
        Test set

    Returns
    -------
    h: torch.Tensor
        Concatenated head indices.
    t: torch.Tensor
        Concatenated tail indices.
    r: torch.Tensor
        Concatenated relation indices.

    """
    splits = (kg_tr, kg_val, kg_te)
    heads = cat(tuple(kg.head_idx for kg in splits))
    tails = cat(tuple(kg.tail_idx for kg in splits))
    rels = cat(tuple(kg.relations for kg in splits))
    return heads, tails, rels
def get_pairs(kg, r, type='ht'):
    """Return the set of entity pairs linked by relation `r` in `kg`.

    Parameters
    ----------
    kg: torchkge.data_structures.KnowledgeGraph
        Object exposing `head_idx`, `tail_idx` and `relations` tensors.
    r: int
        Index of the relation whose facts are selected.
    type: str
        Either 'ht' to get (head, tail) pairs or 'th' to get
        (tail, head) pairs.

    Returns
    -------
    pairs: set
        Set of tuples of entity indices, ordered according to `type`.

    Raises
    ------
    ValueError
        If `type` is neither 'ht' nor 'th'.

    """
    # Explicit check rather than `assert`: asserts are removed when Python
    # runs with -O, which would silently treat any value as 'th'.
    if type not in ('ht', 'th'):
        raise ValueError(
            "type should be either 'ht' or 'th', got {!r}".format(type))

    mask = (kg.relations == r)
    # tolist() converts each index tensor in one call instead of calling
    # .item() on every element.
    heads = kg.head_idx[mask].tolist()
    tails = kg.tail_idx[mask].tolist()

    if type == 'ht':
        return set(zip(heads, tails))
    return set(zip(tails, heads))
def count_triplets(kg1, kg2, duplicates, rev_duplicates):
    """Count the facts of kg2 that have a (reverse) duplicate in kg1.

    For each pair of relations (r1, r2), the (head, tail) pairs of one
    relation in kg2 are intersected with the pairs of the other relation
    in kg1 (reversed to (tail, head) for reverse duplicates). Both
    directions (r1 in kg2 vs r2 in kg1, and r2 in kg2 vs r1 in kg1) are
    counted and the sizes of the intersections are summed.

    Parameters
    ----------
    kg1: torchkge.data_structures.KnowledgeGraph
    kg2: torchkge.data_structures.KnowledgeGraph
    duplicates: list
        List returned by torchkge.utils.data_redundancy.duplicates.
    rev_duplicates: list
        List returned by torchkge.utils.data_redundancy.duplicates.

    Returns
    -------
    n_duplicates: int
        Number of triplets in kg2 that have their duplicate triplet
        in kg1
    n_rev_duplicates: int
        Number of triplets in kg2 that have their reverse duplicate
        triplet in kg1.

    """
    def _n_shared(rel_kg1, rel_kg2, kg1_type):
        # Number of distinct (head, tail) pairs of `rel_kg2` in kg2 that
        # also appear among the pairs of `rel_kg1` in kg1, the latter
        # being read as 'ht' or 'th' pairs depending on `kg1_type`.
        in_kg1 = get_pairs(kg1, rel_kg1, type=kg1_type)
        in_kg2 = get_pairs(kg2, rel_kg2, type='ht')
        return len(in_kg2.intersection(in_kg1))

    n_duplicates = sum(
        _n_shared(r2, r1, 'ht') + _n_shared(r1, r2, 'ht')
        for r1, r2 in duplicates)

    n_rev_duplicates = sum(
        _n_shared(r2, r1, 'th') + _n_shared(r1, r2, 'th')
        for r1, r2 in rev_duplicates)

    return n_duplicates, n_rev_duplicates
def duplicates(kg_tr, kg_val, kg_te, theta1=0.8, theta2=0.8,
               verbose=False, counts=False, reverses=None):
    """Return the duplicate and reverse duplicate relations as explained
    in paper by Akrami et al.

    Two relations r1 and r2 are reported as duplicates when the share of
    facts of r1 whose (head, tail) pair also appears with r2 exceeds
    `theta1` and the share of facts of r2 whose pair also appears with
    r1 exceeds `theta2`. Reverse duplicates are found the same way after
    swapping head and tail for r2. Facts from the three sets are pooled
    before computing these ratios.

    References
    ----------
    * Farahnaz Akrami, Mohammed Samiul Saeef, Qingheng Zhang.
      `Realistic Re-evaluation of Knowledge Graph Completion Methods:
      An Experimental Study. <https://arxiv.org/pdf/2003.08001.pdf>`_
      SIGMOD’20, June 14–19, 2020, Portland, OR, USA

    Parameters
    ----------
    kg_tr: torchkge.data_structures.KnowledgeGraph
        Train set
    kg_val: torchkge.data_structures.KnowledgeGraph
        Validation set
    kg_te: torchkge.data_structures.KnowledgeGraph
        Test set
    theta1: float
        First threshold (see paper).
    theta2: float
        Second threshold (see paper).
    verbose: bool
        If True, progress messages and the number of relations found are
        printed.
    counts: bool
        Should the triplets involving (reverse) duplicate relations be
        counted in all sets.
    reverses: list
        List of known reverse relations. A pair (r1, r2) present in this
        list is not tested for reverse duplication.

    Returns
    -------
    duplicates: list
        List of pairs giving duplicate relations.
    rev_duplicates: list
        List of pairs giving reverse duplicate relations.

    """
    def _percent(count, total):
        # Integer percentage; an empty set is reported as 0% instead of
        # raising ZeroDivisionError.
        return int(count / total * 100) if total else 0

    if verbose:
        print('Computing Ts')

    if reverses is None:
        reverses = []

    T = dict()      # relation -> set of (head, tail) pairs
    T_inv = dict()  # relation -> set of (tail, head) pairs
    lengths = dict()  # relation -> number of facts over the three sets

    h, t, r = concat_kgs(kg_tr, kg_val, kg_te)

    for r_ in tqdm(range(kg_tr.n_rel)):
        mask = (r == r_)
        lengths[r_] = mask.sum().item()

        heads = h[mask].tolist()
        tails = t[mask].tolist()
        T[r_] = set(zip(heads, tails))
        T_inv[r_] = set(zip(tails, heads))

    if verbose:
        print('Finding duplicate relations')

    dupl_rels = []
    rev_dupl_rels = []

    # Iterate over every pair of relations of the graph. The number of
    # relations is read from the train set (as for the loop above) rather
    # than hard-coded, so that any dataset is handled.
    iter_ = list(combinations(range(kg_tr.n_rel), 2))

    for r1, r2 in tqdm(iter_):
        if lengths[r1] == 0 or lengths[r2] == 0:
            # A relation without any fact cannot be a duplicate and the
            # ratios below would divide by zero.
            continue

        n_common = len(T[r1].intersection(T[r2]))
        if (n_common / lengths[r1] > theta1
                and n_common / lengths[r2] > theta2):
            dupl_rels.append((r1, r2))

        if (r1, r2) not in reverses:
            n_common = len(T[r1].intersection(T_inv[r2]))
            if (n_common / lengths[r1] > theta1
                    and n_common / lengths[r2] > theta2):
                rev_dupl_rels.append((r1, r2))

    if verbose:
        print('Duplicate relations: {}'.format(len(dupl_rels)))
        print('Reverse duplicate relations: '
              '{}\n'.format(len(rev_dupl_rels)))

    if counts:
        dupl, rev = count_triplets(kg_tr, kg_tr, dupl_rels, rev_dupl_rels)
        print('{} train triplets have duplicate in train set '
              '({}%)'.format(dupl, _percent(dupl, len(kg_tr))))
        print('{} train triplets have reverse duplicate in train set '
              '({}%)\n'.format(rev, _percent(rev, len(kg_tr))))

        dupl, rev = count_triplets(kg_tr, kg_te, dupl_rels, rev_dupl_rels)
        print('{} test triplets have duplicate in train set '
              '({}%)'.format(dupl, _percent(dupl, len(kg_te))))
        print('{} test triplets have reverse duplicate in train set '
              '({}%)\n'.format(rev, _percent(rev, len(kg_te))))

        dupl, rev = count_triplets(kg_te, kg_te, dupl_rels, rev_dupl_rels)
        print('{} test triplets have duplicate in test set '
              '({}%)'.format(dupl, _percent(dupl, len(kg_te))))
        print('{} test triplets have reverse duplicate in test set '
              '({}%)\n'.format(rev, _percent(rev, len(kg_te))))

    return dupl_rels, rev_dupl_rels
def cartesian_product_relations(kg_tr, kg_val, kg_te, theta=0.8):
    """Return the cartesian product relations as explained in paper by
    Akrami et al.

    A relation is selected when its number of facts, divided by the
    product of its number of distinct heads and its number of distinct
    tails, is greater than `theta`. Facts from the three sets are pooled
    before computing this ratio.

    References
    ----------
    * Farahnaz Akrami, Mohammed Samiul Saeef, Qingheng Zhang.
      `Realistic Re-evaluation of Knowledge Graph Completion Methods: An
      Experimental Study. <https://arxiv.org/pdf/2003.08001.pdf>`_
      SIGMOD’20, June 14–19, 2020, Portland, OR, USA

    Parameters
    ----------
    kg_tr: torchkge.data_structures.KnowledgeGraph
        Train set
    kg_val: torchkge.data_structures.KnowledgeGraph
        Validation set
    kg_te: torchkge.data_structures.KnowledgeGraph
        Test set
    theta: float
        Threshold used to compute the cartesian product relations.

    Returns
    -------
    selected_relations: list
        List of relations index that are cartesian product relations
        (see paper for details).

    """
    selected_relations = []

    h, t, r = concat_kgs(kg_tr, kg_val, kg_te)

    for r_ in tqdm(range(kg_tr.n_rel)):
        mask = (r == r_)
        n_facts = mask.sum().item()
        if n_facts == 0:
            # No fact for this relation: it has no head nor tail, so the
            # ratio below would divide by zero. It is simply not selected.
            continue

        subjects = set(h[mask].tolist())
        objects = set(t[mask].tolist())

        if n_facts / (len(subjects) * len(objects)) > theta:
            selected_relations.append(r_)

    return selected_relations