Model Training

Here are three examples of models being trained on FB15k.

Simplest training

This is the Python code needed to train TransE without any wrapper. The script shows how the different parts of TorchKGE fit together:

from torch import cuda
from torch.optim import Adam

from torchkge.models import TransEModel
from torchkge.sampling import BernoulliNegativeSampler
from torchkge.utils import MarginLoss, DataLoader
from torchkge.utils.datasets import load_fb15k

from tqdm.autonotebook import tqdm

# Load dataset
kg_train, _, _ = load_fb15k()

# Define some hyper-parameters for training
emb_dim = 100
lr = 0.0004
n_epochs = 1000
b_size = 32768
margin = 0.5

# Define the model and criterion
model = TransEModel(emb_dim, kg_train.n_ent, kg_train.n_rel, dissimilarity_type='L2')
criterion = MarginLoss(margin)

# Move everything to CUDA if available
if cuda.is_available():
    cuda.empty_cache()
    model.cuda()
    criterion.cuda()

# Define the torch optimizer to be used
optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5)

# The Bernoulli sampler corrupts either the head or the tail of each triplet,
# with relation-dependent probabilities (Wang et al., 2014)
sampler = BernoulliNegativeSampler(kg_train)
dataloader = DataLoader(kg_train, batch_size=b_size, use_cuda='all')

iterator = tqdm(range(n_epochs), unit='epoch')
for epoch in iterator:
    running_loss = 0.0
    for i, batch in enumerate(dataloader):
        h, t, r = batch[0], batch[1], batch[2]
        n_h, n_t = sampler.corrupt_batch(h, t, r)

        optimizer.zero_grad()

        # forward + backward + optimize
        pos, neg = model(h, t, r, n_h, n_t)
        loss = criterion(pos, neg)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    iterator.set_description(
        'Epoch {} | mean loss: {:.5f}'.format(epoch + 1,
                                              running_loss / len(dataloader)))

model.normalize_parameters()
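
The script above discards the validation and test splits when loading FB15k. To sanity-check the trained embeddings on link prediction, here is a minimal sketch; it assumes the test split is kept at loading time (kg_train, _, kg_test = load_fb15k()):

from torchkge.evaluation import LinkPredictionEvaluator

# Evaluate the trained model on the held-out test split
evaluator = LinkPredictionEvaluator(model, kg_test)
evaluator.evaluate(b_size=256)
evaluator.print_results()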

Shortest training

TorchKGE also provides simple utility wrappers for model training. Here is an example of how to use them:

from torch.optim import Adam

from torchkge.evaluation import LinkPredictionEvaluator
from torchkge.models import TransEModel
from torchkge.utils.datasets import load_fb15k
from torchkge.utils import Trainer, MarginLoss


def main():
    # Define some hyper-parameters for training
    emb_dim = 100
    lr = 0.0004
    margin = 0.5
    n_epochs = 1000
    batch_size = 32768

    # Load dataset
    kg_train, kg_val, kg_test = load_fb15k()

    # Define the model and criterion
    model = TransEModel(emb_dim, kg_train.n_ent, kg_train.n_rel,
                        dissimilarity_type='L2')
    criterion = MarginLoss(margin)
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    trainer = Trainer(model, criterion, kg_train, n_epochs, batch_size,
                      optimizer=optimizer, sampling_type='bern', use_cuda='all')

    trainer.run()

    evaluator = LinkPredictionEvaluator(model, kg_test)
    evaluator.evaluate(b_size=200)
    evaluator.print_results()


if __name__ == "__main__":
    main()
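
Once training is over, the learned embeddings can be pulled out of the model for downstream use. A short sketch, assuming get_embeddings() follows TorchKGE's translation-model interface and returns the entity and relation embedding tensors:

# at the end of main(), after trainer.run()
ent_emb, rel_emb = model.get_embeddings()
print(ent_emb.shape)  # (kg_train.n_ent, emb_dim)
print(rel_emb.shape)  # (kg_train.n_rel, emb_dim)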

Training with Ignite

TorchKGE can be used along with the PyTorch Ignite library, which makes it easy to include early stopping in the training process. Here is an example script training a TransE model on FB15k on GPU, with early stopping on the validation MRR:

import torch
from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping
from ignite.metrics import RunningAverage
from torch.optim import Adam

from torchkge.evaluation import LinkPredictionEvaluator
from torchkge.models import TransEModel
from torchkge.sampling import BernoulliNegativeSampler
from torchkge.utils import MarginLoss, DataLoader
from torchkge.utils.datasets import load_fb15k


def process_batch(engine, batch):
    # one optimization step on a batch of triplets; model, sampler, optimizer
    # and criterion are defined at module level below
    h, t, r = batch[0], batch[1], batch[2]
    n_h, n_t = sampler.corrupt_batch(h, t, r)

    optimizer.zero_grad()

    pos, neg = model(h, t, r, n_h, n_t)
    loss = criterion(pos, neg)
    loss.backward()
    optimizer.step()

    return loss.item()


def linkprediction_evaluation(engine):
    # score function for early stopping: returns the filtered validation MRR
    model.normalize_parameters()

    loss = engine.state.output

    # compute the validation MRR every eval_epoch epochs
    if engine.state.epoch % eval_epoch == 0:
        evaluator = LinkPredictionEvaluator(model, kg_val)
        evaluator.evaluate(b_size=256, verbose=False)
        val_mrr = evaluator.mrr()[1]  # filtered MRR
    else:
        val_mrr = 0

    print('Epoch {} | Train loss: {}, Validation MRR: {}'.format(
        engine.state.epoch, loss, val_mrr))

    try:
        # keep track of the best validation MRR seen so far
        if engine.state.best_mrr < val_mrr:
            engine.state.best_mrr = val_mrr
        return val_mrr

    except AttributeError as e:
        # best_mrr does not exist yet on the first epoch; initialize it
        if engine.state.epoch == 1:
            engine.state.best_mrr = val_mrr
            return val_mrr
        else:
            raise e

device = torch.device('cuda')

eval_epoch = 20  # run link prediction evaluation every 20 epochs
max_epochs = 1000
patience = 40
batch_size = 32768
emb_dim = 100
lr = 0.0004
margin = 0.5

kg_train, kg_val, kg_test = load_fb15k()

# Define the model, optimizer and criterion
model = TransEModel(emb_dim, kg_train.n_ent, kg_train.n_rel,
                    dissimilarity_type='L2')
model.to(device)

optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5)
criterion = MarginLoss(margin)
sampler = BernoulliNegativeSampler(kg_train, kg_val=kg_val, kg_test=kg_test)

# Define the engine
trainer = Engine(process_batch)

# Attach a running average of the batch loss to the trainer
RunningAverage(output_transform=lambda x: x).attach(trainer, 'margin')

# Add early stopping
handler = EarlyStopping(patience=patience,
                        score_function=linkprediction_evaluation,
                        trainer=trainer)
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)

# Training
train_iterator = DataLoader(kg_train, batch_size, use_cuda='all')
trainer.run(train_iterator,
            epoch_length=len(train_iterator),
            max_epochs=max_epochs)

print('Best score {:.3f} at epoch {}'.format(handler.best_score,
                                             trainer.state.epoch - handler.patience))