diff --git a/selector/methods/distance.py b/selector/methods/distance.py index 4c4c4e75..f3f2c9af 100644 --- a/selector/methods/distance.py +++ b/selector/methods/distance.py @@ -365,6 +365,16 @@ def algorithm(self, x, max_size) -> Union[List, Iterable]: # bv will serve as a mask to discard points within radius r of previously selected points bv = np.zeros(n_samples) candidates = list(range(n_samples)) + + # Initialize min_dists array to store minimum distances from points and selected points + # Initially consider all minimum distances as Infinity + min_dists = np.full(n_samples, np.inf) + + # Calculate distances of points from all initially selected points + for idx in selected: + dists = np.linalg.norm(x - x[idx], axis=1) + min_dists = np.minimum(min_dists, dists) + # determine which points are within radius r of initial point # note: workers=-1 uses all available processors/CPUs index_remove = tree.query_ball_point( @@ -384,13 +394,8 @@ def algorithm(self, x, max_size) -> Union[List, Iterable]: except ValueError: sublist = candidates.compressed() - # create a new kd-tree for nearest neighbor lookup with candidates - new_tree = spatial.KDTree(x[selected]) - # query the kd-tree for nearest neighbors to selected samples - # note: workers=-1 uses all available processors/CPUs - search, _ = new_tree.query(x[sublist], eps=self.eps, p=self.p, workers=-1) - # identify the nearest neighbor with the largest distance from previously selected samples - best_idx = sublist[np.argmax(search)] + # Select and Append the candidate farthest from its nearest selected point + best_idx = sublist[np.argmax(min_dists[sublist])] selected.append(best_idx) count += 1 @@ -398,6 +403,10 @@ def algorithm(self, x, max_size) -> Union[List, Iterable]: # do this if you have reached the maximum number of points selected return selected + # Update min_dists array: calculate distances from newly selected point + new_dists = np.linalg.norm(x - x[best_idx], axis=1) + min_dists = np.minimum(min_dists, new_dists) + # eliminate all samples within radius r of the selected sample index_remove = tree.query_ball_point( x[best_idx], self.r, eps=self.eps, p=self.p, workers=-1