Run the full embedding pipeline for all workers
run_embeddings.Rd
Reads triplet comparison data from input_file, trains a separate
embedding model with early stopping for each worker, and then trains a
combined group-level embedding across all workers. Output CSV files are
written to output_dir and the results are also returned as R data
frames.
Usage
run_embeddings(
input_file,
additional_data_file,
output_dir,
d = 5L,
max_epochs = 50000L,
tolerance = 1e-04,
tol_window = 10000L,
seed = 222L,
device = NULL
)
Arguments
- input_file
Path to the CSV file containing all triplets (see Input format below).
- additional_data_file
Path to a CSV file with item metadata to append to the embedding output (e.g. image filenames listed in alphabetical order). The number of rows should match the number of unique items.
- output_dir
Path to the directory where output CSV files will be saved. Created automatically if it does not already exist.
- d
Number of embedding dimensions. Default 5.
- max_epochs
Maximum number of training epochs. Default 50000.
- tolerance
Loss tolerance for early stopping. Default 1e-4.
- tol_window
Epochs without improvement before early stopping triggers. Default 10000.
- seed
Integer random seed for reproducibility. Default 222.
- device
PyTorch device string, or NULL (default) to auto-select: CUDA GPU if available, then Apple MPS, then CPU. Pass "cpu" to force CPU even on a GPU machine.
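The interaction between tolerance and tol_window can be sketched in plain R. This is only an illustration under stated assumptions, not the package's training loop: it assumes an epoch counts as an improvement when the loss drops below the best value so far by more than tolerance, and that training stops once tol_window consecutive epochs pass without such an improvement. The helper early_stop_epoch and its return fields are hypothetical, and the documentation does not say which loss (train or test) is monitored.

```r
# Hypothetical helper, not part of the package: given a vector of per-epoch
# losses, report where a patience-style early-stopping rule would halt.
early_stop_epoch <- function(losses, tolerance = 1e-4, tol_window = 10000L) {
  lowest_loss <- Inf
  counter <- 0L
  stopped_at <- 0L
  for (epoch in seq_along(losses)) {
    stopped_at <- epoch
    if (lowest_loss - losses[epoch] > tolerance) {
      # Assumed rule: an improvement larger than `tolerance` resets the counter.
      lowest_loss <- losses[epoch]
      counter <- 0L
    } else {
      counter <- counter + 1L
    }
    # Assumed rule: stop after `tol_window` epochs without improvement.
    if (counter >= tol_window) break
  }
  list(lowest_loss = lowest_loss, stopped_at = stopped_at, counter = counter)
}

# Toy loss curve that plateaus at 0.4; a short window is used so it triggers.
res <- early_stop_epoch(c(1.0, 0.5, 0.4, 0.4, 0.4, 0.4),
                        tolerance = 1e-4, tol_window = 3L)
str(res)
```

Under these assumptions, a larger tol_window makes training more patient with plateaus, while a larger tolerance makes small loss decreases count as "no improvement".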
Value
A named list with two elements:
- history
Data frame with one row per worker (plus one for the group model) containing: worker_id, lowest_loss, epoch, counter_from_last_update, n_train_triplets, n_test_triplets.
- embeddings
Data frame of all embeddings concatenated, with dimension columns (dim_0, dim_1, …), a worker_id column, and any columns from additional_data_file.
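As a rough illustration of working with the returned list, the sketch below builds a mock result with the documented column names and then pulls out per-worker embedding matrices. All values are made up, the filename column stands in for whatever additional_data_file provides, and the group-model row is omitted because its worker_id label is not specified here.

```r
# Mock object with the documented shape; every value is invented.
res <- list(
  history = data.frame(
    worker_id = c("w1", "w2"),
    lowest_loss = c(0.41, 0.38),
    epoch = c(21200L, 18950L),
    counter_from_last_update = c(10000L, 10000L),
    n_train_triplets = c(80L, 80L),
    n_test_triplets = c(20L, 20L)
  ),
  embeddings = data.frame(
    dim_0 = c(0.1, -0.3, 0.2, 0.5),
    dim_1 = c(0.7, 0.2, -0.1, 0.4),
    worker_id = rep(c("w1", "w2"), each = 2),
    filename = rep(c("a.png", "b.png"), times = 2)  # hypothetical metadata column
  )
)

# Select the dimension columns by name, whatever d was used.
dim_cols <- grep("^dim_[0-9]+$", names(res$embeddings), value = TRUE)

# One data frame per worker, then a numeric matrix for a single worker.
by_worker <- split(res$embeddings, res$embeddings$worker_id)
w1_matrix <- as.matrix(by_worker[["w1"]][, dim_cols])

# Look up the training summary for the same worker.
w1_history <- res$history[res$history$worker_id == "w1", ]
```

Matching dimension columns by the dim_ prefix avoids hard-coding d, so the same code works for any embedding size.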
Details
For a higher-level interface that accepts triplet data already loaded into R
as a named list (the format returned by get.combined), see
run_embeddings_from_list.
Input format
input_file must be a CSV with at least the following columns:
- worker_id
Identifier for the respondent.
- head
Zero-based integer index of the reference item.
- winner
Zero-based integer index of the item judged closer to head.
- loser
Zero-based integer index of the item judged further from head.
- sampleSet
Either "train" or "test", used to split data for early stopping.
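The column layout above can be illustrated with a small synthetic input. The sketch below writes a toy triplet CSV and a matching metadata CSV to temporary files, then calls run_embeddings only if the function is available in the session. The worker IDs, item filenames, and the reduced training settings are made up for illustration, and a dataset this small may be too little for the real pipeline to train on.

```r
set.seed(1)
n_items <- 4L
n_rows <- 20L

# Each row draws three distinct items and shifts them to zero-based indices.
picks <- t(replicate(n_rows, sample.int(n_items, 3L) - 1L))

triplets <- data.frame(
  worker_id = rep(c("w1", "w2"), each = n_rows / 2),
  head = picks[, 1],
  winner = picks[, 2],
  loser = picks[, 3],
  # Four training rows for every test row, for each worker.
  sampleSet = rep(c("train", "train", "train", "train", "test"),
                  times = n_rows / 5)
)

# Hypothetical item metadata: one row per unique item, in index order.
items <- data.frame(filename = sprintf("item_%d.png", seq_len(n_items) - 1L))

input_file <- tempfile(fileext = ".csv")
additional_data_file <- tempfile(fileext = ".csv")
write.csv(triplets, input_file, row.names = FALSE)
write.csv(items, additional_data_file, row.names = FALSE)

# Only run the pipeline when the package providing run_embeddings is loaded.
if (exists("run_embeddings", mode = "function")) {
  res <- run_embeddings(
    input_file = input_file,
    additional_data_file = additional_data_file,
    output_dir = file.path(tempdir(), "embeddings_out"),
    d = 2L,
    max_epochs = 500L,
    tol_window = 100L,
    device = "cpu"
  )
  str(res$history)
}
```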