Train a single triplet embedding model
A lower-level interface to the embedding pipeline. Use this function when
you want to manage the train/test split in R rather than relying on the
sampleSet column in your data, or when you want to train an
embedding on a single subset of responses.
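If you manage the split yourself, a simple random split produces the two matrices this function needs. This is a minimal sketch. It assumes a data frame whose head, winner and loser columns already hold zero-based integer indices. split_triplets and its test_frac and seed arguments are hypothetical and are not part of the package.

```r
# Hypothetical helper (not part of the package): randomly split a triplet
# data frame into the X_train / X_test matrices that train_embedding() expects.
# Assumes head, winner and loser already hold zero-based integer indices.
split_triplets <- function(triplets, test_frac = 0.2, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)  # makes the split reproducible
  cols <- c("head", "winner", "loser")
  n <- nrow(triplets)
  # Draw the row numbers that go into the test set
  test_rows <- sample.int(n, size = round(test_frac * n))
  is_test <- seq_len(n) %in% test_rows
  list(
    X_train = as.matrix(triplets[!is_test, cols]),
    X_test  = as.matrix(triplets[is_test, cols])
  )
}

# Usage (not run):
# split <- split_triplets(triplets, test_frac = 0.2, seed = 1)
# out <- train_embedding(split$X_train, split$X_test)
```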
Usage
train_embedding(
X_train,
X_test,
d = 5L,
max_epochs = 50000L,
tolerance = 1e-04,
tol_window = 10000L,
print_every = 100L,
device = NULL,
random_state = NULL
)
Arguments
- X_train
Integer matrix of shape \(n_{\text{triplets}} \times 3\). Columns must be head, winner, loser, in that order, with zero-based integer item indices. Create it with as.matrix(df[, c("head", "winner", "loser")]).
- X_test
Integer matrix in the same format as X_train, used for computing validation metrics and triggering early stopping.
- d
Number of embedding dimensions. Default 5.
- max_epochs
Maximum number of training epochs. Default 50000.
- tolerance
Loss tolerance for early stopping. Default 1e-4.
- tol_window
Number of epochs the test loss must remain within tolerance of the best before training halts early. Default 10000.
- print_every
Print a progress line every print_every epochs. Default 100. Increase it to reduce console output; set it to max_epochs to suppress mid-training output entirely.
- device
PyTorch device string, or NULL (default) to auto-select: a CUDA GPU if available, then Apple MPS, then CPU. Pass "cpu" to force CPU even on a GPU machine.
- random_state
Integer seed passed to both NumPy and PyTorch before training begins. NULL (default) leaves the global random state unchanged. Set this when you need reproducible embeddings, e.g. when comparing multiple random restarts.
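Because random_state seeds both NumPy and PyTorch, you can rerun a set of restarts exactly. The sketch below assumes you want to keep the restart with the lowest returned loss. best_of_restarts is a hypothetical helper. Its fit argument is any function that takes a seed and returns a list with a loss element, such as a wrapper around train_embedding().

```r
# Hypothetical helper (not part of the package): run one fit per seed and
# keep the fit with the lowest `loss`. `fit` must take a seed and return a
# list containing a numeric `loss` element.
best_of_restarts <- function(fit, seeds) {
  fits <- lapply(seeds, fit)
  losses <- vapply(fits, function(f) as.numeric(f$loss), numeric(1))
  best <- which.min(losses)
  # Record which seed produced the kept fit so it can be reproduced later
  c(fits[[best]], list(seed = seeds[[best]]))
}

# Usage with train_embedding() (not run):
# best <- best_of_restarts(
#   function(s) train_embedding(X_train, X_test, random_state = s),
#   seeds = 1:5
# )
```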
Value
A named list with five elements:
- embedding
Numeric matrix of shape \(n_{\text{items}} \times d\) containing the learned positions. Rows correspond to items in index order, so row 1 holds item 0.
- loss
Best test loss achieved during training.
- epoch
Epoch number at which training stopped.
- counter
Number of epochs since the last meaningful improvement at the point training stopped.
- history
Data frame with one row per epoch and columns epoch, train_loss, test_loss, train_acc, test_acc.
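You can use the history data frame to find the epoch with the lowest recorded test loss. That epoch should correspond to the returned loss. This is a minimal sketch that assumes the documented column names. summarise_history is a hypothetical helper, not a package function.

```r
# Hypothetical helper (not part of the package): summarise the `history`
# data frame returned by train_embedding(). Assumes the documented columns
# epoch, train_loss, test_loss, train_acc and test_acc.
summarise_history <- function(history) {
  best <- which.min(history$test_loss)  # row with the lowest test loss
  last <- nrow(history)                 # row of the final recorded epoch
  list(
    best_epoch     = history$epoch[[best]],
    best_test_loss = history$test_loss[[best]],
    best_test_acc  = history$test_acc[[best]],
    final_epoch    = history$epoch[[last]]
  )
}

# Usage (not run): summarise_history(out$history)
```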
Details
For processing all workers in a dataset at once, see
run_embeddings and run_embeddings_from_list.
Early stopping
Training runs for up to max_epochs passes through the training data.
It stops early if the test loss is within tolerance of the best
observed test loss for more than tol_window consecutive epochs.
The embedding that achieved the best test loss during training is returned,
not necessarily the final-epoch embedding.
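The stopping rule can be written as a short loop over per-epoch test losses. This is an illustrative sketch of one reading of the description above, not the package's implementation. It assumes that a "meaningful improvement" is a drop of more than tolerance below the best loss so far, and that the best loss is still tracked when it improves by less than that. early_stop is a hypothetical name. Its best_epoch element marks the epoch whose embedding would be returned.

```r
# Illustrative early-stopping rule (an assumed reading of the documented
# behaviour, not the package's internal code). `test_loss` is a numeric
# vector with one test loss per epoch.
early_stop <- function(test_loss, tolerance = 1e-4, tol_window = 10000L) {
  best <- Inf                 # best test loss seen so far
  best_epoch <- NA_integer_   # epoch whose embedding would be kept
  counter <- 0L               # epochs since the last meaningful improvement
  for (epoch in seq_along(test_loss)) {
    loss <- test_loss[[epoch]]
    if (loss < best - tolerance) {
      counter <- 0L           # meaningful improvement: reset the window
    } else {
      counter <- counter + 1L
    }
    if (loss < best) {        # always remember the best loss and its epoch
      best <- loss
      best_epoch <- epoch
    }
    if (counter > tol_window) {
      return(list(epoch = epoch, loss = best, best_epoch = best_epoch,
                  counter = counter, early_stop = TRUE))
    }
  }
  list(epoch = length(test_loss), loss = best, best_epoch = best_epoch,
       counter = counter, early_stop = FALSE)
}
```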
Item indices
Items are identified by zero-based integer indices. If your data uses
one-based indices (as is typical in R), subtract 1 from the
head, winner, and loser columns before passing them
to this function.
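The conversion from one-based to zero-based indices can be done in one step when building the matrix. This is a minimal sketch. to_zero_based_matrix is a hypothetical helper, and its check for negative results is an added safeguard against subtracting 1 twice.

```r
# Hypothetical helper (not part of the package): convert one-based
# head/winner/loser columns to the zero-based integer matrix that
# train_embedding() expects.
to_zero_based_matrix <- function(df, cols = c("head", "winner", "loser")) {
  m <- as.matrix(df[, cols])
  storage.mode(m) <- "integer"  # coerce whole-number doubles to integers
  m <- m - 1L                   # shift every index down by one
  if (any(m < 0L, na.rm = TRUE)) {
    stop("Found an index below 1; were the indices already zero-based?")
  }
  m
}

# Usage (not run): X_train <- to_zero_based_matrix(train_df)
```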
Progress output
Training progress is printed to the console every print_every
epochs, showing epoch number, train loss, test loss, train accuracy, and
test accuracy. A final line is printed when training stops, labelled
[early stop] if stopping was triggered before max_epochs.
Examples
if (FALSE) { # \dontrun{
triplets <- read.csv("triplets.csv")
# head, winner and loser must already hold zero-based item indices
is_train <- triplets$sampleSet == "train"
X_train <- as.matrix(triplets[is_train, c("head", "winner", "loser")])
X_test <- as.matrix(triplets[!is_train, c("head", "winner", "loser")])
out <- train_embedding(X_train, X_test, d = 5L, max_epochs = 50000L)
dim(out$embedding) # n_items x 5
cat("Best loss:", out$loss, "\n")
cat("Stopped at epoch:", out$epoch, "\n")
# Per-epoch training curve
head(out$history)
plot(out$history$epoch, out$history$test_loss, type = "l",
xlab = "Epoch", ylab = "Test loss")
} # }