diff --git a/chromadb/api/types.py b/chromadb/api/types.py index ef5aab75926..e0845c6a88a 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -1351,7 +1351,7 @@ def validate_include(include: Include, dissalowed: Optional[Include] = None) -> def validate_n_results(n_results: int) -> int: """Validates n_results to ensure it is a positive Integer. Since hnswlib does not allow n_results to be negative.""" # Check Number of requested results - if not isinstance(n_results, int): + if not isinstance(n_results, int) or isinstance(n_results, bool): raise ValueError( f"Expected requested number of results to be a int, got {n_results}" ) diff --git a/chromadb/test/api/test_types.py b/chromadb/test/api/test_types.py index 56f3a75d683..aa387d8496f 100644 --- a/chromadb/test/api/test_types.py +++ b/chromadb/test/api/test_types.py @@ -1,6 +1,12 @@ import pytest from typing import List, cast, Dict, Any -from chromadb.api.types import Documents, Image, Document, Embeddings +from chromadb.api.types import ( + Documents, + Image, + Document, + Embeddings, + validate_n_results, +) from chromadb.utils.embedding_functions import ( EmbeddingFunction, register_embedding_function, @@ -103,3 +109,9 @@ def __call__(self, input: Documents) -> Embeddings: from chromadb.api.types import normalize_embeddings normalize_embeddings(result) + + +@pytest.mark.parametrize("n_results", [True, False]) +def test_validate_n_results_rejects_bool(n_results: bool) -> None: + with pytest.raises(ValueError, match="Expected requested number of results"): + validate_n_results(n_results)