pinecone.data.query_results_aggregator
from typing import List, Tuple, Optional, Any, Dict
import json
import heapq
from pinecone.core.openapi.data.models import Usage
from pinecone.core.openapi.data.models import QueryResponse as OpenAPIQueryResponse

from dataclasses import dataclass, asdict


@dataclass
class ScoredVectorWithNamespace:
    namespace: str
    score: float
    id: str
    values: List[float]
    sparse_values: dict
    metadata: dict

    def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]):
        json_vector = aggregate_results_heap_tuple[2]
        self.namespace = aggregate_results_heap_tuple[3]
        self.id = json_vector.get("id")  # type: ignore
        self.score = json_vector.get("score")  # type: ignore
        self.values = json_vector.get("values")  # type: ignore
        self.sparse_values = json_vector.get("sparse_values", None)  # type: ignore
        self.metadata = json_vector.get("metadata", None)  # type: ignore

    def __getitem__(self, key):
        if hasattr(self, key):
            return getattr(self, key)
        else:
            raise KeyError(f"'{key}' not found in ScoredVectorWithNamespace")

    def get(self, key, default=None):
        return getattr(self, key, default)

    def __repr__(self):
        return json.dumps(self._truncate(asdict(self)), indent=4)

    def __json__(self):
        return self._truncate(asdict(self))

    def _truncate(self, obj, max_items=2):
        """
        Recursively traverse and truncate lists that exceed max_items in length.
        Only display the "... X more" message if at least 2 elements are hidden.
        """
        if obj is None:
            return None  # Skip None values
        elif isinstance(obj, list):
            filtered_list = [self._truncate(i, max_items) for i in obj if i is not None]
            if len(filtered_list) > max_items:
                # Show the truncation message only if more than 1 item is hidden
                remaining_items = len(filtered_list) - max_items
                if remaining_items > 1:
                    return filtered_list[:max_items] + [f"... {remaining_items} more"]
                else:
                    # If only 1 item would be hidden, show it instead of the message
                    return filtered_list
            return filtered_list
        elif isinstance(obj, dict):
            # Recursively process dictionaries, omitting None values
            return {k: self._truncate(v, max_items) for k, v in obj.items() if v is not None}
        return obj


@dataclass
class QueryNamespacesResults:
    usage: Usage
    matches: List[ScoredVectorWithNamespace]

    def __getitem__(self, key):
        if hasattr(self, key):
            return getattr(self, key)
        else:
            raise KeyError(f"'{key}' not found in QueryNamespacesResults")

    def get(self, key, default=None):
        return getattr(self, key, default)

    def __repr__(self):
        return json.dumps(
            {
                "usage": self.usage.to_dict(),
                "matches": [match.__json__() for match in self.matches],
            },
            indent=4,
        )


class QueryResultsAggregregatorNotEnoughResultsError(Exception):
    def __init__(self):
        super().__init__(
            "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores."
        )


class QueryResultsAggregatorInvalidTopKError(Exception):
    def __init__(self, top_k: int):
        super().__init__(
            f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2."
        )


class QueryResultsAggregator:
    def __init__(self, top_k: int):
        if top_k < 2:
            raise QueryResultsAggregatorInvalidTopKError(top_k)
        self.top_k = top_k
        self.usage_read_units = 0
        self.heap: List[Tuple[float, int, object, str]] = []
        self.insertion_counter = 0
        self.is_dotproduct = None
        self.read = False
        self.final_results: Optional[QueryNamespacesResults] = None

    def _is_dotproduct_index(self, matches):
        # The interpretation of the score depends on the similarity metric used.
        # Unlike other index types, in indexes configured for dotproduct,
        # a higher score is better. We have to infer this is the case by inspecting
        # the order of the scores in the results.
        for i in range(1, len(matches)):
            if matches[i].get("score") > matches[i - 1].get("score"):  # Found an increase
                return False
        return True

    def _dotproduct_heap_item(self, match, ns):
        return (match.get("score"), -self.insertion_counter, match, ns)

    def _non_dotproduct_heap_item(self, match, ns):
        return (-match.get("score"), -self.insertion_counter, match, ns)

    def _process_matches(self, matches, ns, heap_item_fn):
        for match in matches:
            self.insertion_counter += 1
            if len(self.heap) < self.top_k:
                heapq.heappush(self.heap, heap_item_fn(match, ns))
            else:
                # Within a single response, scores are sorted best-first:
                # descending for dotproduct, ascending otherwise.
                if self.is_dotproduct and match["score"] < self.heap[0][0]:
                    # No further matches can improve the top-K heap
                    break
                elif not self.is_dotproduct and match["score"] > -self.heap[0][0]:
                    # No further matches can improve the top-K heap
                    break
                heapq.heappushpop(self.heap, heap_item_fn(match, ns))

    def add_results(self, results: Dict[str, Any]):
        if self.read:
            # This is mainly a sanity check for test cases, which get quite confusing
            # if results are read twice, since the heap is emptied when constructing
            # the ordered results.
            raise ValueError("Results have already been read. Cannot add more results.")

        matches = results.get("matches", [])
        ns: str = results.get("namespace", "")
        if isinstance(results, OpenAPIQueryResponse):
            self.usage_read_units += results.usage.read_units
        else:
            self.usage_read_units += results.get("usage", {}).get("readUnits", 0)

        if len(matches) == 0:
            return

        if self.is_dotproduct is None:
            if len(matches) == 1:
                # We need at least two matches in a single response in order
                # to infer the similarity metric from the score ordering.
                raise QueryResultsAggregregatorNotEnoughResultsError()
            self.is_dotproduct = self._is_dotproduct_index(matches)

        if self.is_dotproduct:
            self._process_matches(matches, ns, self._dotproduct_heap_item)
        else:
            self._process_matches(matches, ns, self._non_dotproduct_heap_item)

    def get_results(self) -> QueryNamespacesResults:
        if self.read:
            if self.final_results is not None:
                return self.final_results
            else:
                # This branch should be unreachable, but it satisfies the type checker
                raise ValueError("Results have already been read. Cannot get results again.")
        self.read = True

        self.final_results = QueryNamespacesResults(
            usage=Usage(read_units=self.usage_read_units),
            matches=[
                ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
            ][::-1],
        )
        return self.final_results
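A minimal usage sketch (not part of the module source; the namespaces, ids, scores, and readUnits are made up). The aggregator accepts plain response dicts as shown here, and also OpenAPI QueryResponse objects:

aggregator = QueryResultsAggregator(top_k=3)

aggregator.add_results({
    "namespace": "ns1",
    "usage": {"readUnits": 5},
    "matches": [
        {"id": "a", "score": 0.9},
        {"id": "b", "score": 0.8},
    ],
})
aggregator.add_results({
    "namespace": "ns2",
    "usage": {"readUnits": 5},
    "matches": [
        {"id": "c", "score": 0.95},
        {"id": "d", "score": 0.7},
    ],
})

results = aggregator.get_results()
print(results.usage.read_units)         # 10
print([m.id for m in results.matches])  # ['c', 'a', 'b']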
@dataclass
class ScoredVectorWithNamespace:
@dataclass
class ScoredVectorWithNamespace:
    namespace: str
    score: float
    id: str
    values: List[float]
    sparse_values: dict
    metadata: dict

    def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]):
        json_vector = aggregate_results_heap_tuple[2]
        self.namespace = aggregate_results_heap_tuple[3]
        self.id = json_vector.get("id")  # type: ignore
        self.score = json_vector.get("score")  # type: ignore
        self.values = json_vector.get("values")  # type: ignore
        self.sparse_values = json_vector.get("sparse_values", None)  # type: ignore
        self.metadata = json_vector.get("metadata", None)  # type: ignore

    def __getitem__(self, key):
        if hasattr(self, key):
            return getattr(self, key)
        else:
            raise KeyError(f"'{key}' not found in ScoredVectorWithNamespace")

    def get(self, key, default=None):
        return getattr(self, key, default)

    def __repr__(self):
        return json.dumps(self._truncate(asdict(self)), indent=4)

    def __json__(self):
        return self._truncate(asdict(self))

    def _truncate(self, obj, max_items=2):
        """
        Recursively traverse and truncate lists that exceed max_items in length.
        Only display the "... X more" message if at least 2 elements are hidden.
        """
        if obj is None:
            return None  # Skip None values
        elif isinstance(obj, list):
            filtered_list = [self._truncate(i, max_items) for i in obj if i is not None]
            if len(filtered_list) > max_items:
                # Show the truncation message only if more than 1 item is hidden
                remaining_items = len(filtered_list) - max_items
                if remaining_items > 1:
                    return filtered_list[:max_items] + [f"... {remaining_items} more"]
                else:
                    # If only 1 item would be hidden, show it instead of the message
                    return filtered_list
            return filtered_list
        elif isinstance(obj, dict):
            # Recursively process dictionaries, omitting None values
            return {k: self._truncate(v, max_items) for k, v in obj.items() if v is not None}
        return obj
ScoredVectorWithNamespace(aggregate_results_heap_tuple: Tuple[float, int, object, str])
    def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]):
        json_vector = aggregate_results_heap_tuple[2]
        self.namespace = aggregate_results_heap_tuple[3]
        self.id = json_vector.get("id")  # type: ignore
        self.score = json_vector.get("score")  # type: ignore
        self.values = json_vector.get("values")  # type: ignore
        self.sparse_values = json_vector.get("sparse_values", None)  # type: ignore
        self.metadata = json_vector.get("metadata", None)  # type: ignore
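Because the class defines __getitem__ and get, a match can be read attribute-style or dict-style. A small sketch (the heap tuple and its contents are fabricated; normally these objects are produced by QueryResultsAggregator.get_results):

match = ScoredVectorWithNamespace((0.9, -1, {"id": "a", "score": 0.9, "values": [0.1, 0.2]}, "ns1"))
print(match.id)                  # 'a' (attribute access)
print(match["namespace"])        # 'ns1' (dict-style access via __getitem__)
print(match.get("metadata"))     # None; this vector carried no metadata
print(match.get("missing", 0))   # 0; get() falls back to the supplied default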
@dataclass
class QueryNamespacesResults:
@dataclass
class QueryNamespacesResults:
    usage: Usage
    matches: List[ScoredVectorWithNamespace]

    def __getitem__(self, key):
        if hasattr(self, key):
            return getattr(self, key)
        else:
            raise KeyError(f"'{key}' not found in QueryNamespacesResults")

    def get(self, key, default=None):
        return getattr(self, key, default)

    def __repr__(self):
        return json.dumps(
            {
                "usage": self.usage.to_dict(),
                "matches": [match.__json__() for match in self.matches],
            },
            indent=4,
        )
QueryNamespacesResults(usage: pinecone.core.openapi.data.model.usage.Usage, matches: List[ScoredVectorWithNamespace])
matches: List[ScoredVectorWithNamespace]
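A sketch of reading the aggregated results (continuing a hypothetical aggregator that has already had responses added, as in the earlier sketch):

results = aggregator.get_results()
print(results.usage.read_units)  # total readUnits summed across all added responses
print(results["matches"][0].id)  # dict-style access works here as well
print(results)                   # __repr__ renders JSON; long value lists are truncated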
class QueryResultsAggregregatorNotEnoughResultsError(builtins.Exception):
class QueryResultsAggregregatorNotEnoughResultsError(Exception):
    def __init__(self):
        super().__init__(
            "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores."
        )
Raised when aggregation cannot proceed because a query response contained only a single match; at least two matches are needed to infer the similarity metric from the score ordering.
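A sketch of when this is raised (fabricated response): the first response seen by the aggregator has a single match, so the score ordering cannot reveal the metric:

aggregator = QueryResultsAggregator(top_k=2)
try:
    aggregator.add_results({"namespace": "ns1", "matches": [{"id": "a", "score": 0.5}]})
except QueryResultsAggregregatorNotEnoughResultsError as e:
    print(e)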
class QueryResultsAggregatorInvalidTopKError(builtins.Exception):
class QueryResultsAggregatorInvalidTopKError(Exception):
    def __init__(self, top_k: int):
        super().__init__(
            f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2."
        )
Raised when QueryResultsAggregator is constructed with a top_k of less than 2.
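A sketch of when this is raised:

try:
    QueryResultsAggregator(top_k=1)
except QueryResultsAggregatorInvalidTopKError as e:
    print(e)  # Invalid top_k value 1. To aggregate results from multiple queries the top_k must be at least 2.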
class QueryResultsAggregator:
class QueryResultsAggregator:
    def __init__(self, top_k: int):
        if top_k < 2:
            raise QueryResultsAggregatorInvalidTopKError(top_k)
        self.top_k = top_k
        self.usage_read_units = 0
        self.heap: List[Tuple[float, int, object, str]] = []
        self.insertion_counter = 0
        self.is_dotproduct = None
        self.read = False
        self.final_results: Optional[QueryNamespacesResults] = None

    def _is_dotproduct_index(self, matches):
        # The interpretation of the score depends on the similarity metric used.
        # Unlike other index types, in indexes configured for dotproduct,
        # a higher score is better. We have to infer this is the case by inspecting
        # the order of the scores in the results.
        for i in range(1, len(matches)):
            if matches[i].get("score") > matches[i - 1].get("score"):  # Found an increase
                return False
        return True

    def _dotproduct_heap_item(self, match, ns):
        return (match.get("score"), -self.insertion_counter, match, ns)

    def _non_dotproduct_heap_item(self, match, ns):
        return (-match.get("score"), -self.insertion_counter, match, ns)

    def _process_matches(self, matches, ns, heap_item_fn):
        for match in matches:
            self.insertion_counter += 1
            if len(self.heap) < self.top_k:
                heapq.heappush(self.heap, heap_item_fn(match, ns))
            else:
                # Within a single response, scores are sorted best-first:
                # descending for dotproduct, ascending otherwise.
                if self.is_dotproduct and match["score"] < self.heap[0][0]:
                    # No further matches can improve the top-K heap
                    break
                elif not self.is_dotproduct and match["score"] > -self.heap[0][0]:
                    # No further matches can improve the top-K heap
                    break
                heapq.heappushpop(self.heap, heap_item_fn(match, ns))

    def add_results(self, results: Dict[str, Any]):
        if self.read:
            # This is mainly a sanity check for test cases, which get quite confusing
            # if results are read twice, since the heap is emptied when constructing
            # the ordered results.
            raise ValueError("Results have already been read. Cannot add more results.")

        matches = results.get("matches", [])
        ns: str = results.get("namespace", "")
        if isinstance(results, OpenAPIQueryResponse):
            self.usage_read_units += results.usage.read_units
        else:
            self.usage_read_units += results.get("usage", {}).get("readUnits", 0)

        if len(matches) == 0:
            return

        if self.is_dotproduct is None:
            if len(matches) == 1:
                # We need at least two matches in a single response in order
                # to infer the similarity metric from the score ordering.
                raise QueryResultsAggregregatorNotEnoughResultsError()
            self.is_dotproduct = self._is_dotproduct_index(matches)

        if self.is_dotproduct:
            self._process_matches(matches, ns, self._dotproduct_heap_item)
        else:
            self._process_matches(matches, ns, self._non_dotproduct_heap_item)

    def get_results(self) -> QueryNamespacesResults:
        if self.read:
            if self.final_results is not None:
                return self.final_results
            else:
                # This branch should be unreachable, but it satisfies the type checker
                raise ValueError("Results have already been read. Cannot get results again.")
        self.read = True

        self.final_results = QueryNamespacesResults(
            usage=Usage(read_units=self.usage_read_units),
            matches=[
                ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
            ][::-1],
        )
        return self.final_results
QueryResultsAggregator(top_k: int)
    def __init__(self, top_k: int):
        if top_k < 2:
            raise QueryResultsAggregatorInvalidTopKError(top_k)
        self.top_k = top_k
        self.usage_read_units = 0
        self.heap: List[Tuple[float, int, object, str]] = []
        self.insertion_counter = 0
        self.is_dotproduct = None
        self.read = False
        self.final_results: Optional[QueryNamespacesResults] = None
final_results: Optional[QueryNamespacesResults]
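The sign trick behind _dotproduct_heap_item and _non_dotproduct_heap_item can be seen in isolation. A standalone sketch with made-up scores and a min-heap of size 2 (not library code):

import heapq

# Dotproduct: higher is better. Push raw scores; heap[0] is the worst item kept.
best = []
for score in [0.9, 0.8, 0.95, 0.7]:
    if len(best) < 2:
        heapq.heappush(best, score)
    elif score > best[0]:
        heapq.heappushpop(best, score)
print(sorted(best, reverse=True))  # [0.95, 0.9]

# Distance metrics: lower is better. Push negated scores, then negate back.
best = []
for score in [0.1, 0.2, 0.05, 0.3]:
    if len(best) < 2:
        heapq.heappush(best, -score)
    elif -score > best[0]:
        heapq.heappushpop(best, -score)
print(sorted(-s for s in best))  # [0.05, 0.1]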
def add_results(self, results: Dict[str, Any]):
    def add_results(self, results: Dict[str, Any]):
        if self.read:
            # This is mainly a sanity check for test cases, which get quite confusing
            # if results are read twice, since the heap is emptied when constructing
            # the ordered results.
            raise ValueError("Results have already been read. Cannot add more results.")

        matches = results.get("matches", [])
        ns: str = results.get("namespace", "")
        if isinstance(results, OpenAPIQueryResponse):
            self.usage_read_units += results.usage.read_units
        else:
            self.usage_read_units += results.get("usage", {}).get("readUnits", 0)

        if len(matches) == 0:
            return

        if self.is_dotproduct is None:
            if len(matches) == 1:
                # We need at least two matches in a single response in order
                # to infer the similarity metric from the score ordering.
                raise QueryResultsAggregregatorNotEnoughResultsError()
            self.is_dotproduct = self._is_dotproduct_index(matches)

        if self.is_dotproduct:
            self._process_matches(matches, ns, self._dotproduct_heap_item)
        else:
            self._process_matches(matches, ns, self._non_dotproduct_heap_item)
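A sketch of the metric inference (fabricated scores): the first response's scores ascend, so the aggregator treats lower scores as better for the rest of the merge:

aggregator = QueryResultsAggregator(top_k=2)
aggregator.add_results({
    "namespace": "ns1",
    "matches": [{"id": "a", "score": 0.2}, {"id": "b", "score": 0.5}],  # ascending => not dotproduct
})
aggregator.add_results({
    "namespace": "ns2",
    "matches": [{"id": "c", "score": 0.1}, {"id": "d", "score": 0.9}],
})
print([m.id for m in aggregator.get_results().matches])  # ['c', 'a']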
def get_results(self) -> QueryNamespacesResults:

    def get_results(self) -> QueryNamespacesResults:
        if self.read:
            if self.final_results is not None:
                return self.final_results
            else:
                # This branch should be unreachable, but it satisfies the type checker
                raise ValueError("Results have already been read. Cannot get results again.")
        self.read = True

        self.final_results = QueryNamespacesResults(
            usage=Usage(read_units=self.usage_read_units),
            matches=[
                ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
            ][::-1],
        )
        return self.final_results
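A sketch of the one-shot read behavior (continuing any aggregator that already holds results): the heap is consumed on the first call, so later calls return the cached object, and further add_results calls fail:

first = aggregator.get_results()
again = aggregator.get_results()  # same cached QueryNamespacesResults object
assert first is again
try:
    aggregator.add_results({"namespace": "ns3", "matches": []})
except ValueError as e:
    print(e)  # Results have already been read. Cannot add more results.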