Performances and scikit-learn: Pairwise Distances Reductions

Extra notes on technical details, benchmarks and further work.

Published on the: 21.12.2021
Last modified on the: 16.01.2022
Estimated reading time: ~ 10 min.

Following up on this initial post on the design of Pairwise Distances Reductions, we give here more details on the design and on the experimental results used to assess performance.

Performance changes

KNeighborsMixin.kneighbors is a good proxy for assessing performance because it is backed by the proposed PairwiseDistancesArgKmin and FastEuclideanPairwiseDistancesArgKmin.

In what follows, experiments are run through this interface to assess two aspects: hardware scalability and computational efficiency.

Hardware scalability

This is the hardware scalability of kneighbors in scikit-learn 1.0:

Scalability of argkmin reductions on main

This is the hardware scalability of kneighbors as proposed in sklearn#21462:

Scalability of argkmin reductions using the proposed PairwiseDistancesReductionArgKmin

The plateau after 64 cores can be explained by Amdahl's law [1]: as the number of threads grows, the execution time of the parallel portion of the algorithm becomes negligible compared to that of its sequential portion. The total execution time therefore approaches a lower bound, the execution time of the sequential part, and the speed-up stops increasing. Moreover, the small drop in speed-up at 128 threads can be explained by the overhead of setting up threads, which becomes non-negligible compared to the actual computations made in each thread.
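As a rough illustration, Amdahl's law gives the speed-up on \(n\) threads of a program whose parallel fraction is \(p\) as \(S(n) = 1 / ((1 - p) + p / n)\), which is bounded by \(1 / (1 - p)\). The sketch below evaluates this formula. The value \(p = 0.99\) is hypothetical and was not fitted to the benchmarks above. The per-thread overhead term is not part of Amdahl's law: it is an assumed cost, added only to show how thread set-up can make the speed-up drop.

```python
def amdahl_speedup(p: float, n_threads: int, overhead_per_thread: float = 0.0) -> float:
    """Speed-up on n_threads for a program whose parallel fraction is p.

    overhead_per_thread is NOT part of Amdahl's law: it is a hypothetical
    cost (as a fraction of the single-thread run time) paid for each extra
    thread, used here only to illustrate a drop in speed-up.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    if n_threads < 1:
        raise ValueError("n_threads must be at least 1")
    # Normalised run time: sequential part + parallel part split over threads.
    run_time = (1.0 - p) + p / n_threads + overhead_per_thread * (n_threads - 1)
    return 1.0 / run_time


if __name__ == "__main__":
    p = 0.99  # hypothetical parallel fraction
    for n in (1, 8, 16, 32, 64, 128):
        ideal = amdahl_speedup(p, n)
        with_overhead = amdahl_speedup(p, n, overhead_per_thread=2e-4)
        print(f"{n:4d} threads: {ideal:6.1f} (Amdahl)  {with_overhead:6.1f} (with overhead)")
```

With these assumed values, doubling from 64 to 128 threads gives much less than a 2x gain even without overhead, and the overhead term alone is enough to make the speed-up decrease.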

Computational efficiency of FastEuclideanPairwiseDistancesArgKmin

On GNU/Linux distributions, perf(1) comes in handy to introspect a program's execution in detail [2].

Here, we inspect where CPU cycles are spent, as well as L3 cache misses and L3 cache hits, using the following script on a machine with 20 physical cores [3]:


import os

import numpy as np
from sklearn.neighbors import NearestNeighbors

if __name__ == "__main__":

    n_train = 100_000
    n_test = 100_000
    n_features = 30

    rng = np.random.RandomState(0)

    # We persist datasets on disk so that, on subsequent runs,
    # `perf(1)` solely introspects the events for the core
    # of the computations: `kneighbors`.

    X_train_file = "X_train.npy"
    X_test_file = "X_test.npy"

    if os.path.exists(X_train_file):
        X_train = np.load(X_train_file)
    else:
        X_train = rng.rand(n_train, n_features)
        np.save(X_train_file, X_train)

    if os.path.exists(X_test_file):
        X_test = np.load(X_test_file)
    else:
        X_test = rng.rand(n_test, n_features)
        np.save(X_test_file, X_test)

    est = NearestNeighbors(n_neighbors=10, algorithm="brute").fit(X=X_train)

    # FastEuclideanPairwiseDistancesArgKmin will be used under the hood.
    neigh_dist, neigh_ind = est.kneighbors(X_test)

And the following call to perf-record(1) [4]:

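# Events recorded:
#   cycles                           CPU cycles
#   mem_load_uops_retired.llc_miss   L3 cache misses
#   mem_load_uops_retired.llc_hit    L3 cache hits
# The script above is assumed to be saved as knn.py (hypothetical name).
perf record -e cycles,mem_load_uops_retired.llc_miss,mem_load_uops_retired.llc_hit \
    python knn.py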

This dumps a binary file (perf.data by default), which can be explored using perf-report(1):

# --hierarchical: show the overhead hierarchically
# --inline: show inlined functions for callgraph addresses
perf report --hierarchical --inline

On CPU cycles

This is the report for the cycles events.

Samples: 543K of event 'cycles:u', Event count (approx.): 335205056539

-  100.00%        python                                                       
   -   68.07%                                   
          57.45%        [.] dgemm_kernel_SANDYBRIDGE                           
           4.51%        [.] dgemm_beta_SANDYBRIDGE                             
           3.33%        [.] dgemm_incopy_SANDYBRIDGE                           
           2.59%        [.] dgemm_oncopy_SANDYBRIDGE                           
           0.09%        [.] dgemm_tn                                           
           0.04%        [.] blas_thread_server                                 
           0.01%        [.] dgemm_                                             
           0.01%        [.] ddot_kernel_8                                      
           0.01%        [.] blas_memory_free                                   
           0.01%        [.] blas_memory_alloc                                  
           0.00%        [.] dgemm_small_matrix_permit_SANDYBRIDGE              
           0.00%        [.] dot_compute                                        
           0.00%        [.] ddot_k_SANDYBRIDGE                                 
           0.00%        [.] ddot_                                              
   -   22.17%        _pairwise_distances_reduction.cpython-39-x86_64-linux-gnu.
          22.16%        [.] __pyx_f_7sklearn_7metrics_29_pairwise_distances_red
           0.00%        [.] __pyx_memoryview_slice_memviewslice                
           0.00%        [.] __pyx_f_7sklearn_7metrics_29_pairwise_distances_red
           0.00%        [.] __pyx_f_7sklearn_7metrics_29_pairwise_distances_red
           0.00%        [.] __pyx_f_7sklearn_7metrics_29_pairwise_distances_red
           0.00%        [.] __pyx_f_7sklearn_7metrics_29_pairwise_distances_red
   -    9.25%                      
           9.24%        [.] __pyx_fuse_1__pyx_f_7sklearn_5utils_5_heap_heap_pus
           0.01%        [.] __pyx_fuse_1__pyx_f_7sklearn_5utils_5_heap_simultan
   +    0.20%        python3.9                                                 
   -    0.15%                                          
           0.15%        [.] do_wait                                            
           0.00%        [.] gomp_barrier_wait_end                              
           0.00%        [.] gomp_thread_start                                  
           0.00%        [.] gomp_team_barrier_wait_end                         
           0.00%        [.] futex_wake                                         

Most of the CPU cycles (about 68%) are spent in GEMM. Most of the remaining ones are used to iterate over the chunks of the distance matrix (about 22%) and to push values and indices onto the max-heaps (about 9%).
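To illustrate what pushing onto the max-heaps means, here is a minimal pure-Python sketch of a bounded max-heap. The heap keeps the k smallest distances of a row together with their indices. It uses the standard heapq module, with negated values, as a stand-in. scikit-learn's actual heaps are written in Cython, and the function names below are hypothetical.

```python
import heapq


def push_bounded(heap, k, value, index):
    """Push (value, index) on a max-heap keeping only the k smallest values.

    heapq provides min-heaps, so values are stored negated: heap[0] then holds
    the largest retained value, i.e. the current k-th smallest one.
    """
    if len(heap) < k:
        heapq.heappush(heap, (-value, index))
    elif value < -heap[0][0]:
        # Better than the worst retained value: replace the heap's root.
        heapq.heapreplace(heap, (-value, index))


def argkmin_row(distances_row, k):
    """Return the k smallest distances of a row and their indices, sorted."""
    heap = []
    for j, d in enumerate(distances_row):
        push_bounded(heap, k, d, j)
    pairs = sorted((-neg_d, j) for neg_d, j in heap)
    return [d for d, _ in pairs], [j for _, j in pairs]


if __name__ == "__main__":
    print(argkmin_row([5.0, 1.0, 4.0, 2.0, 3.0], k=3))  # ([1.0, 2.0, 3.0], [1, 3, 4])
```

Each push costs O(log k), and most candidates are rejected by a single comparison with the root once the heap is full. This is why this part stays cheap compared to GEMM.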

Note that the OpenMP parallelisation (set up via Cython) and the CPython interpreter come with negligible overhead: they account for about 0.15% and 0.20% of the cycles, respectively.

For readers who like getting into the details, we can look at the kind of CPU instructions being executed in dgemm_kernel_SANDYBRIDGE [5], the critical region:

Samples: 543K of event 'cycles:u', 4000 Hz, Event count (approx.): 335205056539
  0.94         vmulpd       %ymm1,%ymm3,%ymm7
  0.50         vpermilpd    $0x5,%ymm2,%ymm3
  0.52         vaddpd       %ymm14,%ymm6,%ymm14
  1.11         vaddpd       %ymm12,%ymm7,%ymm12
  1.55         vmulpd       %ymm0,%ymm4,%ymm6
  0.25         vmulpd       %ymm0,%ymm5,%ymm7
  0.51         vmovapd      0xc0(%rdi),%ymm0
  1.81         vaddpd       %ymm11,%ymm6,%ymm11
  1.65         vaddpd       %ymm9,%ymm7,%ymm9
  0.71         vmulpd       %ymm1,%ymm4,%ymm6
  0.33         vmulpd       %ymm1,%ymm5,%ymm7
  0.77         vaddpd       %ymm10,%ymm6,%ymm10
  2.08         vaddpd       %ymm8,%ymm7,%ymm8
  0.86         vmovapd      0xe0(%rdi),%ymm1
  0.85         vmulpd       %ymm0,%ymm2,%ymm6
  0.85         vperm2f128   $0x3,%ymm2,%ymm2,%ymm4
  0.97         vmulpd       %ymm0,%ymm3,%ymm7
  0.85         vperm2f128   $0x3,%ymm3,%ymm3,%ymm5
  0.22         add          $0x100,%rdi
  0.38         vaddpd       %ymm15,%ymm6,%ymm15
  1.62         vaddpd       %ymm13,%ymm7,%ymm13
  1.12         prefetcht0   0x2c0(%rdi)
  0.23         vmulpd       %ymm1,%ymm2,%ymm6
  0.80         vmovapd      (%rsi),%ymm2

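Most of these instructions are AVX SIMD instructions operating on 256-bit ymm registers, each holding four float64 values. vmulpd and vaddpd respectively multiply and add packed doubles, vmovapd loads them from memory, and vpermilpd and vperm2f128 permute them within and across 128-bit lanes.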

Readers interested in how those instructions are used can have a look at OpenBLAS/kernel/x86_64/dgemm_kernel_4x8_sandy.S. This file uses a set of macros to define the computations at a high level in assembly.
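To give an intuition of what vmulpd and vaddpd do, the following NumPy sketch emulates a dot product computed with a single 4-lane accumulator, mirroring the four float64 values held by a ymm register. It is only an illustration: the real kernel keeps several accumulators in registers (ymm8 to ymm15 in the listing) and updates a whole block of the output matrix per iteration. Sandy Bridge has no fused multiply-add (FMA) instructions, which is why multiplications and additions appear as separate vmulpd and vaddpd instructions.

```python
import numpy as np

LANES = 4  # a 256-bit ymm register holds 4 float64 values


def simd_style_dot(x: np.ndarray, y: np.ndarray) -> float:
    """Dot product computed 4 lanes at a time, mimicking AVX (illustration only)."""
    n = x.shape[0]
    n_vec = n - n % LANES            # elements handled by full 4-lane vectors
    acc = np.zeros(LANES)            # plays the role of one ymm accumulator
    for i in range(0, n_vec, LANES):
        prod = x[i:i + LANES] * y[i:i + LANES]  # like vmulpd: 4 products at once
        acc = acc + prod                        # like vaddpd: 4 additions at once
    total = float(acc.sum())                    # final horizontal reduction
    # Scalar tail for the n % LANES leftover elements.
    total += float(np.dot(x[n_vec:], y[n_vec:]))
    return total


if __name__ == "__main__":
    print(simd_style_dot(np.arange(10.0), np.ones(10)))
```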

On L3 cache hits and L3 cache misses

One can inspect the report of the mem_load_uops_retired.llc_miss events for cache misses:

Samples: 88  of event 'mem_load_uops_retired.llc_miss:u', Event count (approx.):
543K cycles:u                                                                  
-  100.00%        python                                                       
   -   82.95%                                   
          81.82%        [.] dgemm_incopy_SANDYBRIDGE                           
           1.14%        [.] dgemm_kernel_SANDYBRIDGE                           
   +    7.95%        [unknown]                                                 
   -    6.82%        _pairwise_distances_reduction.cpython-39-x86_64-linux-gnu.
           6.82%        [.] __pyx_f_7sklearn_7metrics_29_pairwise_distances_red
   +    2.27%        python3.9                                                 

One can inspect the report of the mem_load_uops_retired.llc_hit events for cache hits:

Samples: 984  of event 'mem_load_uops_retired.llc_hit:u', Event count (approx.):
543K cycles:u                                                                  
-  100.00%        python                                                       
   -   66.26%                                   
          31.00%        [.] dgemm_kernel_SANDYBRIDGE                           
          19.21%        [.] dgemm_incopy_SANDYBRIDGE                           
          10.16%        [.] dgemm_oncopy_SANDYBRIDGE                           
           3.66%        [.] dgemm_tn                                           
           1.12%        [.] blas_memory_alloc                                  
           0.51%        [.] dgemm_                                             
           0.41%        [.] blas_memory_free                                   
           0.20%        [.] dgemm_beta_SANDYBRIDGE                             
   +   16.36%        [unknown]                                                 
   +    8.84%        python3.9                                                 
   -    5.08%        _pairwise_distances_reduction.cpython-39-x86_64-linux-gnu.
           4.98%        [.] __pyx_f_7sklearn_7metrics_29_pairwise_distances_red
           0.10%        [.] __pyx_memoryview_slice_memviewslice                
   +    1.83%                      
   +    0.71%                                        
   +    0.30%                                                
   +    0.30%                                              
   +    0.20%               

The L3 cache hits and misses happen exactly where we expect them to: in the critical region computing chunks of the distance matrix with GEMM.
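The following NumPy sketch mimics the structure of this critical region under simplifying assumptions. It processes X by chunks and computes each chunk of squared Euclidean distances through the expansion \(\|x\|^2 - 2\,x \cdot y + \|y\|^2\). The X_chunk @ Y.T product in this expansion is delegated by NumPy to a BLAS GEMM. The sketch then keeps the k smallest values of each row. Unlike the proposed implementation, it is single-threaded and uses np.argpartition instead of max-heaps. The function name and the default chunk_size are hypothetical.

```python
import numpy as np


def chunked_euclidean_argkmin(X, Y, k, chunk_size=256):
    """k nearest rows of Y (Euclidean) for each row of X, one chunk of X at a time."""
    n_X = X.shape[0]
    distances = np.empty((n_X, k), dtype=np.float64)
    indices = np.empty((n_X, k), dtype=np.intp)
    Y_sq_norms = np.einsum("ij,ij->i", Y, Y)  # computed once, reused by all chunks
    for start in range(0, n_X, chunk_size):
        stop = min(start + chunk_size, n_X)
        X_chunk = X[start:stop]
        # ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2: the X_chunk @ Y.T product is the GEMM.
        sq = -2.0 * (X_chunk @ Y.T)
        sq += np.einsum("ij,ij->i", X_chunk, X_chunk)[:, None]
        sq += Y_sq_norms[None, :]
        np.maximum(sq, 0.0, out=sq)  # clip negatives caused by round-off
        # Select the k smallest values per row (unordered), then sort them.
        part = np.argpartition(sq, k - 1, axis=1)[:, :k]
        part_sq = np.take_along_axis(sq, part, axis=1)
        order = np.argsort(part_sq, axis=1)
        indices[start:stop] = np.take_along_axis(part, order, axis=1)
        distances[start:stop] = np.sqrt(np.take_along_axis(part_sq, order, axis=1))
    return distances, indices
```

Only a chunk_size x n_samples_Y block of distances exists at any time, which is what allows the working set to stay in cache when chunks are small enough.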

In the critical region, roughly one sampled load out of ten that reaches the L3 cache misses it [6]. This shows that the data structures used to compute the chunks of the distance matrix generally stay in the L3 caches, as intended [7].
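The one-in-ten estimate can be reproduced with a back-of-the-envelope computation from the sample counts of the reports above. The sketch also estimates the memory touched by one chunk, using a simplified, assumed model: a thread only needs its chunk of X, its chunk of Y and the corresponding float64 chunk of distances. The 256 x 256 chunk sizes are hypothetical example values. The L3 cache is typically shared by the cores of a socket, so per-thread footprints add up when comparing them with the L3 size of your machine.

```python
# Sampled event counts from the perf reports above (footnote 6).
LLC_HITS = 984
LLC_MISSES = 88


def miss_ratio(misses: float, hits: float) -> float:
    """Fraction of sampled L3 accesses that missed the L3 cache."""
    return misses / (misses + hits)


# Share of the samples falling in the GEMM (BLAS) block of each report.
region_misses = round(0.8295 * LLC_MISSES)   # about 73 samples
region_hits = round(0.6626 * LLC_HITS)       # about 652 samples


def chunk_footprint_bytes(chunk_x, chunk_y, n_features, itemsize=8):
    """Bytes for a chunk of X, a chunk of Y and their distance chunk.

    Simplified, hypothetical model: ignores heaps, GEMM packing buffers, etc.
    """
    return itemsize * (chunk_x * n_features + chunk_y * n_features + chunk_x * chunk_y)


if __name__ == "__main__":
    print(f"overall miss ratio: {miss_ratio(LLC_MISSES, LLC_HITS):.3f}")
    print(f"GEMM-region ratio:  {miss_ratio(region_misses, region_hits):.3f}")
    # Hypothetical chunk sizes; 30 features as in the script above.
    print(chunk_footprint_bytes(256, 256, 30) / 2**20, "MiB per thread")
```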


To sum up what we have just covered:

  • The computations scale almost linearly with the number of threads, up to the limits predicted by Amdahl's law.
  • Data movements between RAM and the L3 caches are minimized.
  • SIMD instructions are effectively used in the critical region.

Hence, the parallel execution of the algorithm is efficient [8].

Further work: some food for thought

Disclaimer: Parts of this subsequent work were initially present in sklearn#20254 but were removed to make this PR a bit smaller.

Further work would address the remaining requirements:

  • Support for 32-bit DatasetsPairs
  • Support for the remaining DatasetsPairs among the \(\{\text{sparse}, \text{dense}\}^2\) combinations, i.e.:
    • sparse \(\mathbf{X}\) and dense \(\mathbf{Y}\)
    • dense \(\mathbf{X}\) and sparse \(\mathbf{Y}\)
    • sparse \(\mathbf{X}\) and sparse \(\mathbf{Y}\)
  • Implementation of adapted operations for each reduction (radius neighborhood, threshold filtering, cumulative sum, etc.)

The first point can be addressed using Tempita, to extend the existing interfaces, which support 64-bit data, to 32-bit data as well [9].
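To give an idea of the approach, here is a sketch that expands a single template into a 64-bit variant and a 32-bit variant. It uses Python's standard string.Template as a stand-in for Tempita, whose syntax differs. The generated Cython snippet and its class names are hypothetical.

```python
from string import Template

# Hypothetical Cython snippet, parametrised on the floating point type.
CLASS_TEMPLATE = Template(
    "cdef class PairwiseDistancesArgKmin${bits}:\n"
    "    cdef const ${c_type}[:, ::1] X\n"
    "    cdef const ${c_type}[:, ::1] Y\n"
)

# (suffix, C type) pairs: float64 is a C double, float32 a C float.
DTYPES = [("64", "double"), ("32", "float")]


def expand(template: Template, dtypes) -> str:
    """Generate one specialised copy of the template per dtype."""
    return "\n".join(template.substitute(bits=bits, c_type=c_type) for bits, c_type in dtypes)


if __name__ == "__main__":
    print(expand(CLASS_TEMPLATE, DTYPES))
```

The generated source is then compiled like any hand-written Cython module, so a single template can be maintained for both precisions.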

The second point can be addressed by implementing two new DatasetsPairs: since distance metrics are symmetric, we only need to implement either the dense-sparse case or the sparse-dense case, in addition to the sparse-sparse case.
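The symmetry argument can be sketched as follows. A sparse-dense pair computes Euclidean distances from a CSR representation of X using \(\|x - y\|^2 = \|y\|^2 + \sum_{k \in \mathrm{nnz}(x)} (x_k^2 - 2\,x_k y_k)\). A wrapper then serves dense-sparse queries by swapping the indices. The class names mimic the post's DatasetsPair vocabulary but are hypothetical. They are written in plain Python and NumPy for readability rather than Cython.

```python
import numpy as np


def dense_to_csr(A):
    """Build CSR arrays (data, indices, indptr) from a dense 2D array."""
    data, indices, indptr = [], [], [0]
    for row in A:
        nz = np.flatnonzero(row)
        indices.extend(nz.tolist())
        data.extend(row[nz].tolist())
        indptr.append(len(indices))
    return (np.asarray(data, dtype=np.float64),
            np.asarray(indices, dtype=np.intp),
            np.asarray(indptr, dtype=np.intp))


class SparseDenseDatasetsPair:
    """Hypothetical pair: sparse X (CSR arrays) and dense Y, Euclidean distance."""

    def __init__(self, X_data, X_indices, X_indptr, Y):
        self.X_data, self.X_indices, self.X_indptr = X_data, X_indices, X_indptr
        self.Y = np.asarray(Y, dtype=np.float64)
        self.Y_sq_norms = np.einsum("ij,ij->i", self.Y, self.Y)

    def n_samples_X(self):
        return len(self.X_indptr) - 1

    def n_samples_Y(self):
        return self.Y.shape[0]

    def dist(self, i, j):
        start, end = self.X_indptr[i], self.X_indptr[i + 1]
        cols, vals = self.X_indices[start:end], self.X_data[start:end]
        # Only the non-zeros of x_i contribute beyond ||y_j||^2.
        sq = self.Y_sq_norms[j] + np.sum(vals * vals - 2.0 * vals * self.Y[j, cols])
        return float(np.sqrt(max(sq, 0.0)))


class SwappedDatasetsPair:
    """Serves (X, Y) queries from a (Y, X) pair, valid since d(x, y) = d(y, x)."""

    def __init__(self, pair):
        self.pair = pair

    def n_samples_X(self):
        return self.pair.n_samples_Y()

    def n_samples_Y(self):
        return self.pair.n_samples_X()

    def dist(self, i, j):
        return self.pair.dist(j, i)
```

For a dense X and a sparse Y, one would build SwappedDatasetsPair(SparseDenseDatasetsPair(*dense_to_csr(Y), X)). With this approach, no separate dense-sparse kernel has to be written.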

As for the third and last point, many things can be imagined. The reductions for radius neighborhood queries can easily be implemented using resizable buffers, as provided by std::vector, with some adaptation to return them safely as NumPy arrays: this has been tested and works. However, this approach suffers from concurrent reallocations in threads, namely when the vectors' buffers are being reallocated. These concurrent reallocations cause performance drops because calls to malloc(3), which is used under the hood to reallocate the buffers, lock by default in the standard libraries' implementations. A possible solution to alleviate this problem would be to use another implementation of malloc(3), such as mimalloc's.
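Below is a minimal single-threaded sketch of a radius neighborhood reduction with one resizable buffer per query. Python lists stand in for std::vector, so the sketch does not exhibit the concurrent reallocation issue described above. The final conversion step mirrors the idea of returning each buffer as its own NumPy array inside an object array. The function name and the chunk size are hypothetical.

```python
import numpy as np


def radius_neighbors_sketch(X, Y, radius, chunk_size=64):
    """Distances and indices of the rows of Y within `radius` of each row of X."""
    n_X = X.shape[0]
    # One resizable buffer per query: Python lists stand in for std::vector.
    ind_buffers = [[] for _ in range(n_X)]
    dist_buffers = [[] for _ in range(n_X)]
    Y_sq_norms = np.einsum("ij,ij->i", Y, Y)
    for start in range(0, n_X, chunk_size):
        X_chunk = X[start:start + chunk_size]
        # Chunk of squared Euclidean distances, with a GEMM for X_chunk @ Y.T.
        sq = (np.einsum("ij,ij->i", X_chunk, X_chunk)[:, None]
              - 2.0 * (X_chunk @ Y.T) + Y_sq_norms[None, :])
        np.maximum(sq, 0.0, out=sq)
        rows, cols = np.nonzero(sq <= radius ** 2)
        for r, c in zip(rows, cols):
            # The number of neighbors is unknown in advance: buffers grow.
            ind_buffers[start + r].append(c)
            dist_buffers[start + r].append(np.sqrt(sq[r, c]))
    # Return each buffer as its own NumPy array, stored in object arrays.
    neigh_ind = np.empty(n_X, dtype=object)
    neigh_dist = np.empty(n_X, dtype=object)
    for i in range(n_X):
        neigh_ind[i] = np.asarray(ind_buffers[i], dtype=np.intp)
        neigh_dist[i] = np.asarray(dist_buffers[i], dtype=np.float64)
    return neigh_dist, neigh_ind
```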

Furthermore, similar patterns using Gram matrices of positive definite kernels [10] instead of distance matrices exist for Gaussian Processes and Support Vector Machines, and could be optimised in the same way.
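For instance, a Gaussian (RBF) kernel \(k(x, y) = \exp(-\gamma \|x - y\|^2)\) reuses the same squared Euclidean distances, so the chunked GEMM pattern applies with a different per-chunk operation. The sketch below, whose function names are hypothetical, reduces each chunk of the Gram matrix to its row sums. The choice of reduction is only illustrative.

```python
import numpy as np


def rbf_gram_chunk(X_chunk, Y, gamma):
    """Chunk of the Gram matrix K[i, j] = exp(-gamma * ||x_i - y_j||^2)."""
    sq = (np.einsum("ij,ij->i", X_chunk, X_chunk)[:, None]
          - 2.0 * (X_chunk @ Y.T)
          + np.einsum("ij,ij->i", Y, Y)[None, :])
    np.maximum(sq, 0.0, out=sq)  # clip round-off negatives
    return np.exp(-gamma * sq)


def rbf_row_sums(X, Y, gamma, chunk_size=64):
    """Reduce each chunk of the Gram matrix to row sums, never storing it whole."""
    out = np.empty(X.shape[0])
    for start in range(0, X.shape[0], chunk_size):
        chunk = rbf_gram_chunk(X[start:start + chunk_size], Y, gamma)
        out[start:start + chunk_size] = chunk.sum(axis=1)
    return out
```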


  1. Gene M. Amdahl. 1967. Validity of the single processor approach to achieving large scale computing capabilities. In Proceedings of the April 18-20, 1967, spring joint computer conference (AFIPS '67 (Spring)). Association for Computing Machinery, New York, NY, USA, 483–485. DOI:
  2. If you are using another OS, perf(1) won’t be usable. Still, you should be able to perform similar inspections using dtrace.
  3. The CPUs used are: Intel(R) Xeon(R) CPU E5-2660 v2 @ 2.20GHz
  4. You might need to adapt the events because they change from one architecture to another. See perf-list(1).
  5. Unmangling dgemm_kernel_SANDYBRIDGE: this is the core (kernel) of the float64/double (d) implementation of GEMM for the Sandy Bridge architecture.
  6. This is a rough estimation based on the number of sampled events, namely 984 for L3 cache hits and 88 for L3 cache misses.
  7. For maximum performance, one can tune \(\text{chunk\_size}\) for the L3 cache size of the machine they use.
  8. If this can be made more efficient, feel free to propose on the dedicated PR!
  9. Cython does not support templates, but Tempita covers most of the cases that need them.
  10. Hofmann, Thomas and Schölkopf, Bernhard and Smola, Alexander J., Kernel methods in machine learning. DOI: