
Introduce knn-engine and further refactors in aoa and trainDI#167

Open
goergen95 wants to merge 51 commits into HannaMeyer:master from goergen95:distances

Conversation


@goergen95 goergen95 commented Mar 18, 2026

KNN-Engine and Refactors

@HannaMeyer, @Nowosad, @JanLinnenbrink:

The summary below was produced by the Claude Haiku 4.5 model from the
previous PR description and the diff. I reviewed all content and
edited or deleted obviously wrong parts (though I am less confident
about the analysis of the changes in nndm.R/knndm.R).

Summary

This PR introduces a centralized k-nearest neighbor (KNN) distance calculation
engine (knn-engine.R) and extensive refactoring of distance
calculations across the CAST package. The refactor replaces fragmented distance
computation logic with a unified, modular architecture while maintaining backward
compatibility through deprecation warnings. This represents a significant
architectural improvement that simplifies dependency management, improves code
maintainability, and sets the foundation for future parallelization.


1. KNN-Engine Architecture (knn-engine.R)

1.1 Core Design Principles

The new KNN engine is built on a preprocessor-calculator pattern that
separates distance metric preprocessing from pairwise distance computation:

Input Data → Normalization → Metric-Specific Preprocessing → 
Distance Calculation (philentropy) → k-NN Selection → Output
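
The pipeline above can be sketched for the "euclidean" case as follows; this is a minimal illustration, and the function name, normalization choice, and return shape are assumptions, not the actual internals:

```r
# Minimal sketch: normalize with *reference* statistics (reference is
# canonical), then compute the pairwise distance matrix
# (query rows x reference columns).
pairwise_euclidean <- function(reference, query) {
  mu   <- colMeans(reference)
  sdev <- apply(reference, 2, sd)
  ref  <- sweep(sweep(reference, 2, mu), 2, sdev, "/")
  qry  <- sweep(sweep(query,     2, mu), 2, sdev, "/")
  # one row of distances per query observation
  t(apply(qry, 1, function(q) sqrt(colSums((t(ref) - q)^2))))
}
```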

1.2 Key Functions

.distance(reference, query, dist_fun)

  • Purpose: Compute pairwise distance matrix between reference and query observations

  • Preprocessing Routing:

    • "euclidean" → default normalization only
    • "mahalanobis" → covariance pseudo-inversion + linear transformation
    • "gower" → reference-based min/max scaling + factor encoding
  • Key Assumptions:

    1. Reference is canonical: Scaling parameters derived from reference apply to query
    2. Mahalanobis preprocessing is exact: Uses Cholesky decomposition (fallback to eigenvalue decomposition if numerically unstable)
    3. Gower scaling is reference-relative: Not global min-max, only reference min/max used
    4. Null query means self-distances: When query is NULL, self-distances are set to NA
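
Assumption 3 in concrete terms, for a single numeric column (the helper name is illustrative):

```r
# Gower-style scaling using *reference* min/max only; query values
# outside the reference range fall outside [0, 1]
gower_scale <- function(ref_col, qry_col) {
  rng <- range(ref_col)
  (qry_col - rng[1]) / (rng[2] - rng[1])
}
```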

.knndist(reference, query, k, dist_fun, offset)

  • Purpose: Return k nearest neighbor distances with row indices
  • Returns: Matrix of distances + attribute "indices" with corresponding reference row indices
  • Offset Parameter: Allows skipping nearest neighbors (e.g., offset=1 excludes self when query=reference)
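
Given a precomputed distance matrix, the selection step with an offset can be sketched like this (function name and exact return shape are assumptions):

```r
# k nearest distances per query row, skipping the first `offset` neighbors
# (offset = 1 excludes self-matches when query == reference)
knn_select <- function(D, k, offset = 0) {
  idx <- t(apply(D, 1, order))[, (offset + 1):(offset + k), drop = FALSE]
  dists <- matrix(NA_real_, nrow(D), k)
  for (i in seq_len(nrow(D))) dists[i, ] <- D[i, idx[i, ]]
  attr(dists, "indices") <- idx   # reference row indices of the neighbors
  dists
}
```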

1.3 Known Caveats

  1. Gower Distance: The previous implementation via gower::gower_*() applied global min-max normalization; the new implementation scales by the reference only → test expectations changed
  2. Mahalanobis Distance: Centralized calculation introduces small numerical differences (reasons not yet explored)
  3. No Parallelization: Temporarily removed; will be reintroduced via chunking logic with a yet-to-be-determined framework

2. Helper Functions & Data Preparation

2.1 Categorical Variable Handling (aoa-helpers.R)

.prepare_categorical_variables()

  • Converts factors/characters to dummy variables using caret::dummyVars
  • Handles unknown levels in query by dropping (before dummy creation) and setting to 0 in dummy columns
  • Critical: Weights expanded proportionally to dummy variables
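
A base-R illustration of the factor → dummy expansion and the proportional weight expansion (caret::dummyVars plays the encoding role in the PR; this stand-in is a simplification):

```r
df <- data.frame(x = 1:3, f = factor(c("a", "b", "a")))
# one 0/1 column per factor level (no intercept)
dummies <- model.matrix(~ f - 1, data = df)
# a single weight for `f` is replicated across its dummy columns
w_f <- 0.7
w_expanded <- rep(w_f, nlevels(df$f))
```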

2.2 Weight Management

.prepare_weights(weight, model, variables, useWeight)

  • Three-tier weight source hierarchy:
    1. Model-based importance (if model provided)
    2. User-specified weights (if provided)
    3. Defaults to uniform weights (all 1.0)
  • Validation: Ensures weights sum > 0; negative weights clamped to 0
  • Application: Via sweep() - multiplies each column by corresponding weight
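
The validation and application steps condense into one hedged sketch (names assumed):

```r
# clamp negative weights to 0, require a positive sum, then multiply
# each column by its weight via sweep()
apply_weights <- function(x, weights) {
  weights[weights < 0] <- 0
  stopifnot(sum(weights) > 0)
  sweep(x, 2, weights, "*")
}
```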

2.3 Input Validation

.validate_LPD(maxLPD, n_samples)

  • Converts proportion (0-1) to count, or validates count is integer
  • Ensures maxLPD > 1 (minimum 2 neighbors)
  • Error handling: Clear, specific error messages for each validation failure
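
The proportion-versus-count logic might look roughly like this (an assumption-laden sketch, not the package's code):

```r
validate_maxLPD <- function(maxLPD, n_samples) {
  if (maxLPD > 0 && maxLPD < 1) {
    maxLPD <- round(maxLPD * n_samples)   # proportion -> count
  } else if (maxLPD %% 1 != 0) {
    stop("maxLPD must be a proportion in (0, 1) or an integer count")
  }
  if (maxLPD <= 1) stop("maxLPD must be > 1 (at least 2 neighbors)")
  as.integer(maxLPD)
}
```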

.prepare_folds(model, CVtrain, CVtest, useCV)

  • Extracts CV folds from caret models or uses manual specifications
  • Handles vector-to-list conversion (CVtest as factor IDs → fold lists)
  • Returns NULL if useCV=FALSE (message logged)
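
The vector-to-list conversion mentioned above is essentially a split() over fold IDs (illustrative, not the actual helper):

```r
# CVtest given as per-sample fold IDs -> list of index vectors per fold
folds_from_ids <- function(fold_ids) {
  split(seq_along(fold_ids), fold_ids)
}
```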

3. Distance Computation Refactors

3.1 trainDI.R - Training Data Dissimilarity

Previous Approach:

  • Loop over each training sample
  • Calculate distances individually via .knndistfun()
  • Apply CV fold masking per-sample
  • Limited parallelization with fork clusters

New Approach:

  • Chunked processing: Split training data into chunks, process each chunk's distances against full reference
  • Single distance matrix: Compute all distances once, then mask appropriately
  • Separation of concerns: .calc_dist() and .calc_lpd() handle specific metrics

Key Function: .chunked_apply()

.chunked_apply(
  train, folds, dist_fun, chunk_size=1000L, 
  calc_fun, verbose=TRUE, ...
)
  • Processes training data in row-wise chunks
  • Applies calc_fun (e.g., .calc_dist, .calc_lpd) to each chunk
  • Combines results (named lists concatenated per name; vectors flattened)
  • Memory efficiency: chunk_size parameter controls RAM usage
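
For the common vector-result case, the chunking scheme reduces to something like this (simplified; the real .chunked_apply() also handles fold masking and named-list results):

```r
chunked_apply <- function(x, chunk_size, calc_fun, ...) {
  starts <- seq(1L, nrow(x), by = chunk_size)
  out <- lapply(starts, function(s) {
    rows <- s:min(s + chunk_size - 1L, nrow(x))
    calc_fun(x[rows, , drop = FALSE], ...)
  })
  unlist(out, use.names = FALSE)   # vector results are flattened
}
```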

CV Fold Masking via .mask_dist_mat():

Set self-distances to NA
↓
For each sample, find its testing fold
↓
Mask all non-training samples to NA
↓
Return masked distance matrix
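
Under the reading that "non-training samples" for a given sample are the other members of its own test fold, the masking step sketches as follows (hedged; the actual .mask_dist_mat() may differ in detail):

```r
mask_dist_mat <- function(D, folds) {
  diag(D) <- NA           # self-distances
  for (fold in folds) {
    D[fold, fold] <- NA   # distances within each test fold masked
  }
  D
}
```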

3.2 aoa.R - Area of Applicability

Signature Change:

# Old
aoa(newdata, method="L2", parallel=TRUE, cores=4, algorithm="brute")

# New  
aoa(newdata, dist_fun="euclidean", chunk_size=1000L)
# With deprecations for old parameters

Key Refactors:

  1. Parameter harmonization: method → dist_fun (matches geodist, knndm, nndm)
  2. Removed parallelization code: ~200 lines of .process_row() and cluster logic deleted
  3. Simplified categorical handling: Uses .convert_factors_to_dummy() from helpers
  4. Streamlined LPD calculation: Single .knndist() call with maxLPD parameter

4. Cross-Package Refactors

4.1 knndm.R - k-NN Distance Matching

Changes:

  • Removed: FNN dependency (FNN::knn.dist, FNN::knnx.dist)
  • Added: Direct calls to .knndist() and .distance() from knn-engine
  • Benefit: Unified distance calculation; gower/mahalanobis support now consistent
  • Simplification: Removed duplicate distance matrix calculations in knndm_feature() and knndm_geo()

Removed Helper: distclust_distmat() replaced by cv_distances() using .knndist()

4.2 nndm.R - Nearest Neighbor Distance Matching

Changes:

  • Parameter renamed: space → dist_space (consistency with other functions)
  • New parameters: dist_fun="euclidean", scale_vars=TRUE
  • Feature space now supports gower and mahalanobis
  • Removes categorical variable dummy encoding (handled by .distance())

4.3 geodist.R - Geographical Distance

Changes:

  • Removed algorithm parameter (now in knn-engine)
  • Automatic dist_fun detection for geographical coordinates
  • Simplified feature space distance calculation via .knndist()
  • Removed duplicated Gower distance code

5. Dependency Changes

5.1 Removed Dependencies

  • FNN - Fast Nearest Neighbor library (for euclidean only)
  • parallel - Fork cluster approach

5.2 New Dependencies

  • philentropy - Unified distance computation backend (likely via dist_many_many())

5.3 Implications

  • Simpler dependency tree: One distance engine vs scattered FNN calls
  • Behavioral consistency: All functions use same underlying distance algorithms

6. Further Refactors in PR

6.1 aoa-helpers.R (NEW FILE - 243 lines)

Purpose: Consolidate repetitive validation and preprocessing logic

Components:

| Function | Lines | Purpose |
| --- | --- | --- |
| .validate_LPD() | 25 | Input validation for maxLPD parameter |
| .get_categorical_variables() | 3 | Identify factor/character columns |
| .drop_unknown_levels() | 10 | Handle unknown factor levels in query |
| .create_dummy_variables() | 45 | caret::dummyVars wrapper with weight expansion |
| .convert_factors_to_dummy() | 10 | Loop over categorical variables |
| .prepare_categorical_variables() | 12 | Unified categorical preprocessing |
| .prepare_weights() | 28 | Three-tier weight source resolution |
| .check_weights() | 25 | Validation and normalization |
| .apply_weights() | 10 | Multiply columns by weights via sweep() |
| .prepare_folds() | 32 | Extract/validate CV folds |
| .prepare_variables() | 12 | Expand "all" and validate variables |
| .di_threshold() | 7 | Robust threshold via Q3 + 1.5*IQR |
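
The rule behind .di_threshold() is the standard boxplot upper whisker; as a one-liner (function name from the table, body assumed):

```r
# DI threshold as Q3 + 1.5 * IQR of the training DI values
di_threshold <- function(di) {
  q <- quantile(di, probs = c(0.25, 0.75), na.rm = TRUE)
  unname(q[2] + 1.5 * (q[2] - q[1]))
}
```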

Extraction Rationale: These twelve helper functions were previously embedded in aoa.R (lines 25-550) and trainDI.R. Extracting them improves:

  • Reusability: Used by both aoa() and trainDI()
  • Testability: Dedicated test file (test-aoa-helpers.R with 170 lines)
  • Clarity: Single responsibility per function

6.2 caret-helpers.R (NEW FILE - 48 lines)

Purpose: Centralize caret model introspection

| Function | Purpose |
| --- | --- |
| .caret_get_data() | Extract trainingData as data.frame |
| .caret_get_variables() | Get predictor names (excluding .outcome) |
| .caret_get_folds() | Extract CV indices from train$control |
| .caret_get_weights() | Variable importance → weights data.frame |

Why a new file? These were previously scattered across trainDI.R as aoa_get_*() functions. Unifying the naming and extracting them into a separate module signals that they are infrastructure functions, not domain-specific code.

6.3 trainDI.R - Major Refactor (600 → 400 lines)

Structural Changes:

  1. Simplified main function: Removed ~150 lines of nested loops and conditional logic
  2. Extracted chunked operations:
    • Old: Manual loop with progress bar + fold masking + distance calculation per sample
    • New: .chunked_apply() with .calc_dist() and .calc_lpd() callbacks
  3. Replaced helpers: Old aoa_* functions moved to aoa-helpers.R and caret-helpers.R

Before (trainDI main loop):

for(i in seq(nrow(train))) {
  trainDist <- .knndistfun(train[i, ], train, k=1, ...)
  trainDist[i] <- NA
  trainDist_avrg <- append(trainDist_avrg, mean(trainDist, na.rm = TRUE))
  
  # fold masking logic (15 lines)
  whichfold <- which(lapply(...) == TRUE)
  if(length(whichfold) > 1) stop(...)
  if(length(whichfold) != 0) trainDist[...] <- NA
  
  trainDist_min <- append(trainDist_min, min(..., na.rm = TRUE))
  # progress bar
}

After (trainDI refactored):

train_dists <- .chunked_dist(train, folds, dist_fun, 
                             chunk_size=1000L, verbose)
trainDist_avrgmean <- mean(train_dists$trainDist_avrg, na.rm=TRUE)
trainLPD <- .chunked_lpd(train, folds, dist_fun, 
                         train_mean, threshold, chunk_size)

Benefits:

  • 50% code reduction in main function
  • Chunk size configurable (memory/speed tradeoff)
  • Easier to parallelize chunks independently
  • Clear separation: masking logic in .mask_dist_mat(), calculation in .calc_dist()/.calc_lpd()

7. Impact on Package Architecture

7.1 Before: Fragmented Distance Logic

aoa.R          → .knndistfun() [FNN or Mahalanobis branch]
trainDI.R      → .knndistfun() [duplicated]
knndm.R        → FNN::knn.dist / Gower distances
nndm.R         → FNN::knnx.dist / Gower distances  
geodist.R      → FNN functions / gower::gower_* / sf::st_distance
                 + duplicated categorical handling
                 + duplicated scaling logic

7.2 After: Centralized KNN Engine

knn-engine.R   → .distance() [unified preprocessing + philentropy]
               → .knndist()  [k-NN selection]
aoa.R          → .knndist() + helpers
trainDI.R      → .knndist() + chunked processing
knndm.R        → .knndist() / .distance()
nndm.R         → .knndist() / .distance()
geodist.R      → .knndist() / .distance()

aoa-helpers.R  → validation, categorical, weights, folds
caret-helpers.R → model introspection

8. Backward Compatibility & Deprecations

8.1 Deprecated Parameters

All functions show warnings (not errors) for old parameters:

# Old API
aoa(data, method="L2", algorithm="brute", parallel=TRUE, cores=4)

# New API
aoa(data, dist_fun="euclidean", chunk_size=1000L)

# Both work, but old parameters show:
# Warning: The 'method' and 'algorithm' parameters are deprecated. 
#          Please use 'dist_fun' instead.
# Warning: The 'parallel' and 'cores' parameters are deprecated. 
#          Parallelization is currently not implemented.

8.2 Signature Mapping

| Old | New | Note |
| --- | --- | --- |
| method="L2" | dist_fun="euclidean" | Direct rename |
| method="MD" | dist_fun="mahalanobis" | Direct rename |
| parallel=TRUE/FALSE | removed | Will be reintroduced via chunk-level parallelization |
| cores=N | removed | Paired with the parallel deprecation |
| algorithm="brute" | removed | Incorporated into knn-engine internals |

8.3 Test Expectations Updates

Tests that verify numerical results changed for:

  • Gower distances: Reference-only scaling vs global min-max
  • Mahalanobis distances: Centralized calculation differences (small, not fully understood)
  • Parallel tests removed: Marked as deprecated; skip for now

9. Future Development & Maintenance Implications

9.1 Parallelization Strategy (Deferred)

Current Design:

  • Chunking logic in .chunked_apply() ready for parallelization
  • Processing chunks independently (no shared state)
  • Chunk results trivially combinable (list concatenation)

Recommended Approach:

# Pseudo-code for future implementation
# (future_lapply() lives in the future.apply package)
.chunked_apply <- function(chunks, calc_fun, ...) {
  future.apply::future_lapply(chunks, function(chunk) {
    calc_fun(chunk, ...)   # apply calc_fun to each chunk in parallel
  })
}

Benefits of chunking vs old fork approach:

  • Platform-agnostic (not Linux-only)
  • Works with future framework (supports multiple backends)
  • Easy to control parallelization level (per-chunk)

9.2 Numerical Precision

Known Issues:

  • Mahalanobis: "Small numerical differences (exact reasons not explored)"
  • Gower: "Test expectations changed" for global vs reference-relative scaling

9.3 Testing Infrastructure

New Test Files:

  1. test-knn-engine.R (117 lines): Distance metric correctness
  2. test-aoa-helpers.R (170 lines): Validation and preprocessing
  3. test-caret-helpers.R (32 lines): Model introspection
  4. data-fixture.R (29 lines): Shared test setup
  5. Updated: test-aoa.R, test-geodist.R, test-knndm.R, test-nndm.R

Test Coverage Strategy:

  • Input type handling (matrix, data.frame, vector)
  • Missing value masking
  • Categorical level mismatch
  • Chunk size edge cases

Conclusion

This PR represents a major architectural improvement to CAST. By centralizing
k-nearest neighbor distance computation, eliminating 200+ lines of duplicated logic,
and establishing a clean separation between distance metrics and spatial/feature applications,
the package becomes:

  • More maintainable: Single source of truth for each distance metric
  • More testable: Dedicated unit tests for distance engine
  • More extensible: New metrics addable to knn-engine only
  • More portable: Platform-independent (no fork clusters)
  • Future-proof: Chunk-based parallelization ready to implement

@goergen95 goergen95 marked this pull request as ready for review April 15, 2026 10:00