pinecone.data.query_results_aggregator

  1from typing import List, Tuple, Optional, Any, Dict, Literal
  2import json
  3import heapq
  4from pinecone.core.openapi.data.models import Usage
  5from pinecone.core.openapi.data.models import QueryResponse as OpenAPIQueryResponse
  6
  7from dataclasses import dataclass, asdict
  8
  9
 10@dataclass
 11class ScoredVectorWithNamespace:
 12    namespace: str
 13    score: float
 14    id: str
 15    values: List[float]
 16    sparse_values: dict
 17    metadata: dict
 18
 19    def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]):
 20        json_vector = aggregate_results_heap_tuple[2]
 21        self.namespace = aggregate_results_heap_tuple[3]
 22        self.id = json_vector.get("id")  # type: ignore
 23        self.score = json_vector.get("score")  # type: ignore
 24        self.values = json_vector.get("values")  # type: ignore
 25        self.sparse_values = json_vector.get("sparse_values", None)  # type: ignore
 26        self.metadata = json_vector.get("metadata", None)  # type: ignore
 27
 28    def __getitem__(self, key):
 29        if hasattr(self, key):
 30            return getattr(self, key)
 31        else:
 32            raise KeyError(f"'{key}' not found in ScoredVectorWithNamespace")
 33
 34    def get(self, key, default=None):
 35        return getattr(self, key, default)
 36
 37    def __repr__(self):
 38        return json.dumps(self._truncate(asdict(self)), indent=4)
 39
 40    def __json__(self):
 41        return self._truncate(asdict(self))
 42
 43    def _truncate(self, obj, max_items=2):
 44        """
 45        Recursively traverse and truncate lists that exceed max_items length.
 46        Only display the "... X more" message if at least 2 elements are hidden.
 47        """
 48        if obj is None:
 49            return None  # Skip None values
 50        elif isinstance(obj, list):
 51            filtered_list = [self._truncate(i, max_items) for i in obj if i is not None]
 52            if len(filtered_list) > max_items:
 53                # Show the truncation message only if more than 1 item is hidden
 54                remaining_items = len(filtered_list) - max_items
 55                if remaining_items > 1:
 56                    return filtered_list[:max_items] + [f"... {remaining_items} more"]
 57                else:
 58                    # If only 1 item remains, show it
 59                    return filtered_list
 60            return filtered_list
 61        elif isinstance(obj, dict):
 62            # Recursively process dictionaries, omitting None values
 63            return {k: self._truncate(v, max_items) for k, v in obj.items() if v is not None}
 64        return obj
 65
 66
 67@dataclass
 68class QueryNamespacesResults:
 69    usage: Usage
 70    matches: List[ScoredVectorWithNamespace]
 71
 72    def __getitem__(self, key):
 73        if hasattr(self, key):
 74            return getattr(self, key)
 75        else:
 76            raise KeyError(f"'{key}' not found in QueryNamespacesResults")
 77
 78    def get(self, key, default=None):
 79        return getattr(self, key, default)
 80
 81    def __repr__(self):
 82        return json.dumps(
 83            {
 84                "usage": self.usage.to_dict(),
 85                "matches": [match.__json__() for match in self.matches],
 86            },
 87            indent=4,
 88        )
 89
 90
 91class QueryResultsAggregatorInvalidTopKError(Exception):
 92    def __init__(self, top_k: int):
 93        super().__init__(f"Invalid top_k value {top_k}. top_k must be at least 1.")
 94
 95
 96class QueryResultsAggregator:
 97    def __init__(self, top_k: int, metric: Literal["cosine", "euclidean", "dotproduct"]):
 98        if top_k < 1:
 99            raise QueryResultsAggregatorInvalidTopKError(top_k)
100
101        if metric in ["dotproduct", "cosine"]:
102            self.is_bigger_better = True
103        elif metric in ["euclidean"]:
104            self.is_bigger_better = False
105        else:
106            raise ValueError(
107                f"Cannot merge results for unknown similarity metric {metric}. Supported metrics are 'dotproduct', 'cosine', and 'euclidean'."
108            )
109
110        self.top_k = top_k
111        self.usage_read_units = 0
112        self.heap: List[Tuple[float, int, object, str]] = []
113        self.insertion_counter = 0
114        self.read = False
115        self.final_results: Optional[QueryNamespacesResults] = None
116
117    def _bigger_better_heap_item(self, match, ns):
118        # This 4-tuple is used to ensure that the heap is sorted by score followed by
119        # insertion order. The insertion order is used to break any ties in the score.
120        return (match.get("score"), -self.insertion_counter, match, ns)
121
122    def _smaller_better_heap_item(self, match, ns):
123        return (-match.get("score"), -self.insertion_counter, match, ns)
124
125    def _process_matches(self, matches, ns, heap_item_fn):
126        for match in matches:
127            self.insertion_counter += 1
128            if len(self.heap) < self.top_k:
129                heapq.heappush(self.heap, heap_item_fn(match, ns))
130            else:
131                # Assume we have dotproduct scores sorted in descending order
132                if self.is_bigger_better and match["score"] < self.heap[0][0]:
133                    # No further matches can improve the top-K heap
134                    break
135                elif not self.is_bigger_better and match["score"] > -self.heap[0][0]:
136                    # No further matches can improve the top-K heap
137                    break
138                heapq.heappushpop(self.heap, heap_item_fn(match, ns))
139
140    def add_results(self, results: Dict[str, Any]):
141        if self.read:
142            # This is mainly just to sanity check in test cases which get quite confusing
143            # if you read results twice due to the heap being emptied when constructing
144            # the ordered results.
145            raise ValueError("Results have already been read. Cannot add more results.")
146
147        matches = results.get("matches", [])
148        ns: str = results.get("namespace", "")
149        if isinstance(results, OpenAPIQueryResponse):
150            self.usage_read_units += results.usage.read_units
151        else:
152            self.usage_read_units += results.get("usage", {}).get("readUnits", 0)
153
154        if len(matches) == 0:
155            return
156
157        if self.is_bigger_better:
158            self._process_matches(matches, ns, self._bigger_better_heap_item)
159        else:
160            self._process_matches(matches, ns, self._smaller_better_heap_item)
161
162    def get_results(self) -> QueryNamespacesResults:
163        if self.read:
164            if self.final_results is not None:
165                return self.final_results
166            else:
167                # I don't think this branch can ever actually be reached, but the type checker disagrees
168                raise ValueError("Results have already been read. Cannot get results again.")
169        self.read = True
170
171        self.final_results = QueryNamespacesResults(
172            usage=Usage(read_units=self.usage_read_units),
173            matches=[
174                ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
175            ][::-1],
176        )
177        return self.final_results
@dataclass
class ScoredVectorWithNamespace:
11@dataclass
12class ScoredVectorWithNamespace:
13    namespace: str
14    score: float
15    id: str
16    values: List[float]
17    sparse_values: dict
18    metadata: dict
19
20    def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]):
21        json_vector = aggregate_results_heap_tuple[2]
22        self.namespace = aggregate_results_heap_tuple[3]
23        self.id = json_vector.get("id")  # type: ignore
24        self.score = json_vector.get("score")  # type: ignore
25        self.values = json_vector.get("values")  # type: ignore
26        self.sparse_values = json_vector.get("sparse_values", None)  # type: ignore
27        self.metadata = json_vector.get("metadata", None)  # type: ignore
28
29    def __getitem__(self, key):
30        if hasattr(self, key):
31            return getattr(self, key)
32        else:
33            raise KeyError(f"'{key}' not found in ScoredVectorWithNamespace")
34
35    def get(self, key, default=None):
36        return getattr(self, key, default)
37
38    def __repr__(self):
39        return json.dumps(self._truncate(asdict(self)), indent=4)
40
41    def __json__(self):
42        return self._truncate(asdict(self))
43
44    def _truncate(self, obj, max_items=2):
45        """
46        Recursively traverse and truncate lists that exceed max_items length.
47        Only display the "... X more" message if at least 2 elements are hidden.
48        """
49        if obj is None:
50            return None  # Skip None values
51        elif isinstance(obj, list):
52            filtered_list = [self._truncate(i, max_items) for i in obj if i is not None]
53            if len(filtered_list) > max_items:
54                # Show the truncation message only if more than 1 item is hidden
55                remaining_items = len(filtered_list) - max_items
56                if remaining_items > 1:
57                    return filtered_list[:max_items] + [f"... {remaining_items} more"]
58                else:
59                    # If only 1 item remains, show it
60                    return filtered_list
61            return filtered_list
62        elif isinstance(obj, dict):
63            # Recursively process dictionaries, omitting None values
64            return {k: self._truncate(v, max_items) for k, v in obj.items() if v is not None}
65        return obj
ScoredVectorWithNamespace(aggregate_results_heap_tuple: Tuple[float, int, object, str])
20    def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]):
21        json_vector = aggregate_results_heap_tuple[2]
22        self.namespace = aggregate_results_heap_tuple[3]
23        self.id = json_vector.get("id")  # type: ignore
24        self.score = json_vector.get("score")  # type: ignore
25        self.values = json_vector.get("values")  # type: ignore
26        self.sparse_values = json_vector.get("sparse_values", None)  # type: ignore
27        self.metadata = json_vector.get("metadata", None)  # type: ignore
namespace: str
score: float
id: str
values: List[float]
sparse_values: dict
metadata: dict
def get(self, key, default=None):
35    def get(self, key, default=None):
36        return getattr(self, key, default)
@dataclass
class QueryNamespacesResults:
68@dataclass
69class QueryNamespacesResults:
70    usage: Usage
71    matches: List[ScoredVectorWithNamespace]
72
73    def __getitem__(self, key):
74        if hasattr(self, key):
75            return getattr(self, key)
76        else:
77            raise KeyError(f"'{key}' not found in QueryNamespacesResults")
78
79    def get(self, key, default=None):
80        return getattr(self, key, default)
81
82    def __repr__(self):
83        return json.dumps(
84            {
85                "usage": self.usage.to_dict(),
86                "matches": [match.__json__() for match in self.matches],
87            },
88            indent=4,
89        )
QueryNamespacesResults( usage: pinecone.core.openapi.data.model.usage.Usage, matches: List[ScoredVectorWithNamespace])
usage: pinecone.core.openapi.data.model.usage.Usage
matches: List[ScoredVectorWithNamespace]
def get(self, key, default=None):
79    def get(self, key, default=None):
80        return getattr(self, key, default)
class QueryResultsAggregatorInvalidTopKError(builtins.Exception):
92class QueryResultsAggregatorInvalidTopKError(Exception):
93    def __init__(self, top_k: int):
94        super().__init__(f"Invalid top_k value {top_k}. top_k must be at least 1.")

Common base class for all non-exit exceptions.

QueryResultsAggregatorInvalidTopKError(top_k: int)
93    def __init__(self, top_k: int):
94        super().__init__(f"Invalid top_k value {top_k}. top_k must be at least 1.")
Inherited Members
builtins.BaseException
with_traceback
add_note
args
class QueryResultsAggregator:
 97class QueryResultsAggregator:
 98    def __init__(self, top_k: int, metric: Literal["cosine", "euclidean", "dotproduct"]):
 99        if top_k < 1:
100            raise QueryResultsAggregatorInvalidTopKError(top_k)
101
102        if metric in ["dotproduct", "cosine"]:
103            self.is_bigger_better = True
104        elif metric in ["euclidean"]:
105            self.is_bigger_better = False
106        else:
107            raise ValueError(
108                f"Cannot merge results for unknown similarity metric {metric}. Supported metrics are 'dotproduct', 'cosine', and 'euclidean'."
109            )
110
111        self.top_k = top_k
112        self.usage_read_units = 0
113        self.heap: List[Tuple[float, int, object, str]] = []
114        self.insertion_counter = 0
115        self.read = False
116        self.final_results: Optional[QueryNamespacesResults] = None
117
118    def _bigger_better_heap_item(self, match, ns):
119        # This 4-tuple is used to ensure that the heap is sorted by score followed by
120        # insertion order. The insertion order is used to break any ties in the score.
121        return (match.get("score"), -self.insertion_counter, match, ns)
122
123    def _smaller_better_heap_item(self, match, ns):
124        return (-match.get("score"), -self.insertion_counter, match, ns)
125
126    def _process_matches(self, matches, ns, heap_item_fn):
127        for match in matches:
128            self.insertion_counter += 1
129            if len(self.heap) < self.top_k:
130                heapq.heappush(self.heap, heap_item_fn(match, ns))
131            else:
132                # Assume we have dotproduct scores sorted in descending order
133                if self.is_bigger_better and match["score"] < self.heap[0][0]:
134                    # No further matches can improve the top-K heap
135                    break
136                elif not self.is_bigger_better and match["score"] > -self.heap[0][0]:
137                    # No further matches can improve the top-K heap
138                    break
139                heapq.heappushpop(self.heap, heap_item_fn(match, ns))
140
141    def add_results(self, results: Dict[str, Any]):
142        if self.read:
143            # This is mainly just to sanity check in test cases which get quite confusing
144            # if you read results twice due to the heap being emptied when constructing
145            # the ordered results.
146            raise ValueError("Results have already been read. Cannot add more results.")
147
148        matches = results.get("matches", [])
149        ns: str = results.get("namespace", "")
150        if isinstance(results, OpenAPIQueryResponse):
151            self.usage_read_units += results.usage.read_units
152        else:
153            self.usage_read_units += results.get("usage", {}).get("readUnits", 0)
154
155        if len(matches) == 0:
156            return
157
158        if self.is_bigger_better:
159            self._process_matches(matches, ns, self._bigger_better_heap_item)
160        else:
161            self._process_matches(matches, ns, self._smaller_better_heap_item)
162
163    def get_results(self) -> QueryNamespacesResults:
164        if self.read:
165            if self.final_results is not None:
166                return self.final_results
167            else:
168                # I don't think this branch can ever actually be reached, but the type checker disagrees
169                raise ValueError("Results have already been read. Cannot get results again.")
170        self.read = True
171
172        self.final_results = QueryNamespacesResults(
173            usage=Usage(read_units=self.usage_read_units),
174            matches=[
175                ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
176            ][::-1],
177        )
178        return self.final_results
QueryResultsAggregator(top_k: int, metric: Literal['cosine', 'euclidean', 'dotproduct'])
 98    def __init__(self, top_k: int, metric: Literal["cosine", "euclidean", "dotproduct"]):
 99        if top_k < 1:
100            raise QueryResultsAggregatorInvalidTopKError(top_k)
101
102        if metric in ["dotproduct", "cosine"]:
103            self.is_bigger_better = True
104        elif metric in ["euclidean"]:
105            self.is_bigger_better = False
106        else:
107            raise ValueError(
108                f"Cannot merge results for unknown similarity metric {metric}. Supported metrics are 'dotproduct', 'cosine', and 'euclidean'."
109            )
110
111        self.top_k = top_k
112        self.usage_read_units = 0
113        self.heap: List[Tuple[float, int, object, str]] = []
114        self.insertion_counter = 0
115        self.read = False
116        self.final_results: Optional[QueryNamespacesResults] = None
top_k
usage_read_units
heap: List[Tuple[float, int, object, str]]
insertion_counter
read
final_results: Optional[QueryNamespacesResults]
def add_results(self, results: Dict[str, Any]):
141    def add_results(self, results: Dict[str, Any]):
142        if self.read:
143            # This is mainly just to sanity check in test cases which get quite confusing
144            # if you read results twice due to the heap being emptied when constructing
145            # the ordered results.
146            raise ValueError("Results have already been read. Cannot add more results.")
147
148        matches = results.get("matches", [])
149        ns: str = results.get("namespace", "")
150        if isinstance(results, OpenAPIQueryResponse):
151            self.usage_read_units += results.usage.read_units
152        else:
153            self.usage_read_units += results.get("usage", {}).get("readUnits", 0)
154
155        if len(matches) == 0:
156            return
157
158        if self.is_bigger_better:
159            self._process_matches(matches, ns, self._bigger_better_heap_item)
160        else:
161            self._process_matches(matches, ns, self._smaller_better_heap_item)
def get_results(self) -> QueryNamespacesResults:
163    def get_results(self) -> QueryNamespacesResults:
164        if self.read:
165            if self.final_results is not None:
166                return self.final_results
167            else:
168                # I don't think this branch can ever actually be reached, but the type checker disagrees
169                raise ValueError("Results have already been read. Cannot get results again.")
170        self.read = True
171
172        self.final_results = QueryNamespacesResults(
173            usage=Usage(read_units=self.usage_read_units),
174            matches=[
175                ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
176            ][::-1],
177        )
178        return self.final_results