Evaluate classification performance of factor embeddings
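Source: R/classifier_metrics.R, classify_embedding.Rd
Trains a k-nearest-neighbor (k-NN) classifier on factor embeddings and evaluates it on held-out test samples. Returns an fn_classifier_eval object containing overall accuracy, per-class and averaged precision, recall, and F1, a macro one-vs-rest AUC, and the confusion matrix.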
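The train/evaluate procedure described above can be sketched in base R as follows. This is an illustrative assumption, not the package's implementation: knn_sketch is a hypothetical name, the test set size is taken as round(n * test_fraction), only Euclidean distance is used, and vote ties go to the first class in factor-level order.

```r
# Minimal k-NN sketch (hypothetical helper, not the package's implementation).
# Assumes Euclidean distance and majority vote; ties go to the first level.
knn_sketch <- function(embedding, labels, test_fraction = 0.2, k = 5L, seed = NULL) {
  labels <- as.factor(labels)
  n <- nrow(embedding)
  stopifnot(length(labels) == n)
  if (!is.null(seed)) set.seed(seed)  # note: changes the global RNG state

  # Random train/test split
  n_test <- round(n * test_fraction)
  stopifnot(n_test >= 1, n_test < n)
  test_idx <- sort(sample.int(n, n_test))
  train_idx <- setdiff(seq_len(n), test_idx)
  train <- embedding[train_idx, , drop = FALSE]
  test <- embedding[test_idx, , drop = FALSE]
  train_labels <- labels[train_idx]

  # Squared Euclidean distances: rows = test samples, cols = training samples
  d2 <- outer(rowSums(test^2), rowSums(train^2), "+") - 2 * tcrossprod(test, train)

  # Majority vote among the k nearest training samples
  pred <- apply(d2, 1, function(row) {
    nn <- order(row)[seq_len(k)]
    votes <- table(train_labels[nn])  # counts for every level, including zeros
    names(votes)[which.max(votes)]
  })
  pred <- factor(pred, levels = levels(labels))

  list(predictions = pred, test_labels = labels[test_idx],
       accuracy = mean(pred == labels[test_idx]),
       train_idx = train_idx, test_idx = test_idx)
}

# Toy usage: two well-separated 2-D clusters
set.seed(1)
X <- rbind(matrix(rnorm(50, mean = 0), ncol = 2),
           matrix(rnorm(50, mean = 10), ncol = 2))
y <- rep(c("a", "b"), each = 25)
fit <- knn_sketch(X, y, test_fraction = 0.2, k = 3, seed = 42)
fit$accuracy
```

Computing all test-to-train distances at once with tcrossprod keeps the sketch vectorized; for very large training sets a real implementation would likely process test rows in chunks.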
Usage
classify_embedding(
  embedding,
  labels,
  test_fraction = 0.2,
  test_idx = NULL,
  k = 5L,
  seed = NULL,
  distance = c("euclidean", "cosine")
)
Arguments
- embedding
Numeric matrix where rows are samples and columns are features (e.g., t(result$H) for sample embeddings).
- labels
Integer or factor vector of class labels. Length must equal nrow(embedding).
- test_fraction
Fraction of samples held out for testing (default 0.2).
- test_idx
Optional integer vector of test indices. If provided, test_fraction is ignored.
- k
Number of nearest neighbors (default 5).
- seed
Random seed used to make the train/test split reproducible.
- distance
Distance metric: "euclidean" (default) or "cosine".
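The two distance options can be illustrated with a small base R helper. pairwise_dist is a hypothetical name, and the sketch assumes cosine distance is defined as 1 minus cosine similarity; rows with zero norm are not handled and would produce NaN.

```r
# Hypothetical helper: distances from each test row to each training row.
pairwise_dist <- function(test, train, distance = c("euclidean", "cosine")) {
  distance <- match.arg(distance)
  if (distance == "euclidean") {
    d2 <- outer(rowSums(test^2), rowSums(train^2), "+") - 2 * tcrossprod(test, train)
    sqrt(pmax(d2, 0))  # clamp small negatives caused by rounding error
  } else {
    # Cosine distance = 1 - cosine similarity; zero-norm rows would yield NaN
    test_unit <- test / sqrt(rowSums(test^2))
    train_unit <- train / sqrt(rowSums(train^2))
    1 - tcrossprod(test_unit, train_unit)
  }
}

a <- rbind(c(1, 0), c(0, 2))  # two "test" rows
b <- rbind(c(3, 0))           # one "training" row
pairwise_dist(a, b, "euclidean")  # distances 2 and sqrt(13)
pairwise_dist(a, b, "cosine")     # distances 0 and 1
```

For rows scaled to unit length, squared Euclidean distance equals 2 minus 2 times the cosine similarity, so both metrics rank neighbors identically; they differ only when embedding rows have different magnitudes.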
Value
An fn_classifier_eval object with fields:
- accuracy
Overall test accuracy
- per_class
Data frame with per-class precision, recall, F1, support
- macro_precision
Macro-averaged precision
- macro_recall
Macro-averaged recall
- macro_f1
Macro-averaged F1
- weighted_f1
Support-weighted F1
- auc
Macro-averaged one-vs-rest AUC (from neighbor vote fractions)
- confusion
Confusion matrix (rows = true, cols = predicted)
- predictions
Test set predictions
- test_labels
True test labels
- train_idx
Training sample indices
- test_idx
Test sample indices
- k
Number of neighbors used
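One way to compute a macro-averaged one-vs-rest AUC from neighbor vote fractions is sketched below. It assumes the score for class c on a test sample is the fraction of its k neighbors labeled c, computes each binary AUC with the Mann-Whitney rank formula (tied scores count one half), and skips classes with no positive or no negative test samples; the package may handle such cases differently. auc_binary and macro_ovr_auc are hypothetical names.

```r
# Hypothetical helper: binary AUC via the Mann-Whitney U statistic.
auc_binary <- function(scores, is_pos) {
  n_pos <- sum(is_pos)
  n_neg <- sum(!is_pos)
  if (n_pos == 0 || n_neg == 0) return(NA_real_)  # undefined without both groups
  r <- rank(scores)  # ties receive average ranks
  (sum(r[is_pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# votes: n_test x n_classes matrix of neighbor vote fractions (rows sum to 1)
macro_ovr_auc <- function(votes, truth) {
  truth <- factor(truth, levels = colnames(votes))
  aucs <- vapply(colnames(votes), function(cl) {
    auc_binary(votes[, cl], truth == cl)  # class cl versus all others
  }, numeric(1))
  mean(aucs, na.rm = TRUE)
}

# Toy vote fractions for 4 test samples and 2 classes (e.g., k = 5)
votes <- rbind(c(1.0, 0.0), c(0.6, 0.4), c(0.2, 0.8), c(0.0, 1.0))
colnames(votes) <- c("a", "b")
macro_ovr_auc(votes, c("a", "a", "b", "b"))  # scores separate classes perfectly
macro_ovr_auc(votes, c("a", "b", "a", "b"))  # partial overlap
```

Vote fractions take only k + 1 distinct values, so ties are common in practice; average ranks give each tied positive/negative pair half credit.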
Examples
# \donttest{
data(digits)
model <- nmf(digits, 10, maxit = 20, seed = 1, verbose = FALSE)
labels <- attr(digits, "target")
res <- classify_embedding(model$w, labels, test_fraction = 0.2, k = 5, seed = 42)
print(res)
#> Classification Evaluation (5-NN, euclidean distance)
#> Samples: 1438 train, 359 test, 10 classes
#> Accuracy: 0.9499
#> Macro F1: 0.9491 (P=0.9516, R=0.9499)
#> Weighted F1: 0.9494
#> AUC (macro): 0.9949
#>
#> Per-class:
#> class precision recall     f1 support
#>     0    0.9355 1.0000 0.9667      29
#>     1    0.8500 0.9444 0.8947      36
#>     2    0.9070 1.0000 0.9512      39
#>     3    0.9429 0.9429 0.9429      35
#>     4    1.0000 0.9500 0.9744      40
#>     5    1.0000 0.9706 0.9851      34
#>     6    1.0000 1.0000 1.0000      39
#>     7    0.9444 1.0000 0.9714      34
#>     8    0.9643 0.7941 0.8710      34
#>     9    0.9722 0.8974 0.9333      39
# }
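As an arithmetic check, the summary lines in the printed output can be recomputed from the per-class table in base R. This assumes macro metrics are unweighted means over classes and weighted F1 weights each class by its support; because the table is rounded to four decimals, agreement with the printed summary holds only to about that precision.

```r
# Per-class values copied from the printed output above (rounded to 4 decimals)
per_class <- data.frame(
  class     = 0:9,
  precision = c(0.9355, 0.8500, 0.9070, 0.9429, 1.0000, 1.0000, 1.0000, 0.9444, 0.9643, 0.9722),
  recall    = c(1.0000, 0.9444, 1.0000, 0.9429, 0.9500, 0.9706, 1.0000, 1.0000, 0.7941, 0.8974),
  f1        = c(0.9667, 0.8947, 0.9512, 0.9429, 0.9744, 0.9851, 1.0000, 0.9714, 0.8710, 0.9333),
  support   = c(29, 36, 39, 35, 40, 34, 39, 34, 34, 39)
)

macro_precision <- mean(per_class$precision)
macro_recall    <- mean(per_class$recall)
macro_f1        <- mean(per_class$f1)  # unweighted mean of per-class F1
weighted_f1     <- weighted.mean(per_class$f1, per_class$support)
# Correct predictions per class = recall * support (rounded back to counts)
accuracy <- sum(round(per_class$recall * per_class$support)) / sum(per_class$support)

round(c(accuracy = accuracy, macro_precision = macro_precision,
        macro_recall = macro_recall, macro_f1 = macro_f1,
        weighted_f1 = weighted_f1), 4)
```

The printed macro F1 (0.9491) matches the mean of the per-class F1 values rather than the harmonic mean of macro precision and macro recall, which would be roughly 0.951.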