# Source code for pinecone.models.vectors.query_aggregator
"""Query results aggregator for multi-namespace queries."""
from __future__ import annotations
import heapq
from typing import Any
from msgspec import Struct, field
from pinecone.models._mixin import StructDictMixin
from pinecone.models.vectors.responses import QueryResponse
from pinecone.models.vectors.usage import Usage
from pinecone.models.vectors.vector import ScoredVector
class QueryResultsAggregatorInvalidTopKError(ValueError):
    """Raised when an aggregator is configured with a ``top_k`` smaller than 1."""

    def __init__(self, top_k: int) -> None:
        message = f"Invalid top_k value {top_k}. top_k must be at least 1."
        super().__init__(message)
class QueryNamespacesResults(StructDictMixin, Struct, kw_only=True):
    """Combined outcome of one query run against several namespaces.

    Attributes:
        matches (list[ScoredVector]): The overall top-k matches merged from
            every namespace, ordered by relevance for the metric in use.
        usage (Usage): Total aggregated read unit usage across all namespaces.
        ns_usage (dict[str, Usage]): Read unit usage for each namespace,
            keyed by namespace name.
    """

    matches: list[ScoredVector] = field(default_factory=list)
    usage: Usage = field(default_factory=Usage)
    ns_usage: dict[str, Usage] = field(default_factory=dict)

    def __getitem__(self, key: str) -> Any:
        """Look up a field by name, e.g. ``result["matches"]``.

        Raises:
            KeyError: If *key* is not one of this struct's field names.
        """
        if key in self.__struct_fields__:
            return getattr(self, key)
        raise KeyError(key)

    def __contains__(self, key: object) -> bool:
        """Report whether *key* names a field, e.g. ``"matches" in result``."""
        field_names = self.__struct_fields__
        return key in field_names
_VALID_METRICS = frozenset({"cosine", "euclidean", "dotproduct"})
[docs]
class QueryResultsAggregator:
    """Merges per-namespace QueryResponse objects into a single combined result.

    While results are added, only the best ``top_k`` candidates seen so far
    are kept, in a bounded heap. Memory therefore stays O(top_k), and each
    incoming match costs O(log top_k). For cosine/dotproduct metrics, higher
    scores rank first. For euclidean, lower scores rank first. Ties are
    broken by insertion order: the earlier match wins.

    Args:
        metric: Distance metric — one of ``"cosine"``, ``"euclidean"``,
            or ``"dotproduct"``.
        top_k: Maximum number of results to return. Defaults to 10.

    Raises:
        ValueError: If *metric* is not a recognized value or *top_k* < 1.
    """

    __slots__ = (
        "_counter",
        "_finalized",
        "_heap",
        "_is_bigger_better",
        "_metric",
        "_ns_usage",
        "_read_units",
        "_top_k",
    )

    def __init__(self, *, metric: str, top_k: int = 10) -> None:
        if metric not in _VALID_METRICS:
            raise ValueError(
                f"Invalid metric {metric!r}. Must be one of: {', '.join(sorted(_VALID_METRICS))}"
            )
        if top_k < 1:
            raise QueryResultsAggregatorInvalidTopKError(top_k)
        self._metric = metric
        self._top_k = top_k
        # Bounded min-heap of at most top_k entries shaped
        # (priority, -insertion_index, match), where a larger priority means a
        # better match. The weakest retained candidate therefore sits at
        # heap[0]; among equal priorities it is the most recently inserted one.
        # That is the entry to evict when a better match arrives.
        self._heap: list[tuple[float, int, ScoredVector]] = []
        self._counter: int = 0
        self._finalized: bool = False
        self._read_units: int = 0
        self._ns_usage: dict[str, Usage] = {}
        self._is_bigger_better: bool = metric in ("cosine", "dotproduct")

    def add_results(self, namespace: str, response: QueryResponse) -> None:
        """Add results from a single namespace query.

        If the response carries usage, its read units (``None`` counts as 0)
        are added to the running total. The usage object is also recorded as
        this namespace's usage, replacing any earlier entry for the same
        namespace.

        Args:
            namespace: Namespace that was queried.
            response: Query response from that namespace.

        Raises:
            ValueError: If called after :meth:`get_results`.
        """
        if self._finalized:
            raise ValueError("Cannot add results after get_results()")
        if response.usage is not None:
            self._read_units += response.usage.read_units or 0
            self._ns_usage[namespace] = response.usage
        heap = self._heap
        top_k = self._top_k
        for match in response.matches:
            # Normalize so that a larger priority is always better.
            priority = match.score if self._is_bigger_better else -match.score
            # -counter makes earlier insertions rank higher on equal priority.
            # Counters are unique, so tuple comparison never reaches `match`.
            entry = (priority, -self._counter, match)
            self._counter += 1
            if len(heap) < top_k:
                heapq.heappush(heap, entry)
            else:
                # Insert and evict the weakest in one O(log k) step. If the new
                # entry is itself the weakest, including a tie with a
                # later insertion index, it is discarded immediately.
                heapq.heappushpop(heap, entry)

    def get_results(self) -> QueryNamespacesResults:
        """Finalize and return the aggregated results.

        After calling this method, no more results can be added.

        Returns:
            Aggregated query results with the top-k matches across all
            namespaces, best first. ``ns_usage`` is a copy, so mutating the
            returned object does not affect this aggregator.
        """
        self._finalized = True
        # Descending (priority, -insertion_index) gives best first, with earlier
        # insertions first among equal scores. The prefix is unique, so the
        # ScoredVector objects themselves are never compared.
        ranked = sorted(self._heap, reverse=True)
        return QueryNamespacesResults(
            matches=[match for _, _, match in ranked],
            usage=Usage(read_units=self._read_units),
            ns_usage=dict(self._ns_usage),
        )