Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1351,7 +1351,7 @@ def validate_include(include: Include, dissalowed: Optional[Include] = None) ->
def validate_n_results(n_results: int) -> int:
"""Validates n_results to ensure it is a positive Integer. Since hnswlib does not allow n_results to be negative."""
# Check Number of requested results
if not isinstance(n_results, int):
if not isinstance(n_results, int) or isinstance(n_results, bool):
raise ValueError(
f"Expected requested number of results to be a int, got {n_results}"
)
Expand Down
14 changes: 13 additions & 1 deletion chromadb/test/api/test_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import pytest
from typing import List, cast, Dict, Any
from chromadb.api.types import Documents, Image, Document, Embeddings
from chromadb.api.types import (
Documents,
Image,
Document,
Embeddings,
validate_n_results,
)
from chromadb.utils.embedding_functions import (
EmbeddingFunction,
register_embedding_function,
Expand Down Expand Up @@ -103,3 +109,9 @@ def __call__(self, input: Documents) -> Embeddings:
from chromadb.api.types import normalize_embeddings

normalize_embeddings(result)


@pytest.mark.parametrize("n_results", [True, False])
def test_validate_n_results_rejects_bool(n_results: bool) -> None:
with pytest.raises(ValueError, match="Expected requested number of results"):
validate_n_results(n_results)
Loading