pinecone.data.query_results_aggregator

  1from typing import List, Tuple, Optional, Any, Dict
  2import json
  3import heapq
  4from pinecone.core.openapi.data.models import Usage
  5from pinecone.core.openapi.data.models import QueryResponse as OpenAPIQueryResponse
  6
  7from dataclasses import dataclass, asdict
  8
  9
 10@dataclass
 11class ScoredVectorWithNamespace:
 12    namespace: str
 13    score: float
 14    id: str
 15    values: List[float]
 16    sparse_values: dict
 17    metadata: dict
 18
 19    def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]):
 20        json_vector = aggregate_results_heap_tuple[2]
 21        self.namespace = aggregate_results_heap_tuple[3]
 22        self.id = json_vector.get("id")  # type: ignore
 23        self.score = json_vector.get("score")  # type: ignore
 24        self.values = json_vector.get("values")  # type: ignore
 25        self.sparse_values = json_vector.get("sparse_values", None)  # type: ignore
 26        self.metadata = json_vector.get("metadata", None)  # type: ignore
 27
 28    def __getitem__(self, key):
 29        if hasattr(self, key):
 30            return getattr(self, key)
 31        else:
 32            raise KeyError(f"'{key}' not found in ScoredVectorWithNamespace")
 33
 34    def get(self, key, default=None):
 35        return getattr(self, key, default)
 36
 37    def __repr__(self):
 38        return json.dumps(self._truncate(asdict(self)), indent=4)
 39
 40    def __json__(self):
 41        return self._truncate(asdict(self))
 42
 43    def _truncate(self, obj, max_items=2):
 44        """
 45        Recursively traverse and truncate lists that exceed max_items length.
 46        Only display the "... X more" message if at least 2 elements are hidden.
 47        """
 48        if obj is None:
 49            return None  # Skip None values
 50        elif isinstance(obj, list):
 51            filtered_list = [self._truncate(i, max_items) for i in obj if i is not None]
 52            if len(filtered_list) > max_items:
 53                # Show the truncation message only if more than 1 item is hidden
 54                remaining_items = len(filtered_list) - max_items
 55                if remaining_items > 1:
 56                    return filtered_list[:max_items] + [f"... {remaining_items} more"]
 57                else:
 58                    # If only 1 item remains, show it
 59                    return filtered_list
 60            return filtered_list
 61        elif isinstance(obj, dict):
 62            # Recursively process dictionaries, omitting None values
 63            return {k: self._truncate(v, max_items) for k, v in obj.items() if v is not None}
 64        return obj
 65
 66
 67@dataclass
 68class QueryNamespacesResults:
 69    usage: Usage
 70    matches: List[ScoredVectorWithNamespace]
 71
 72    def __getitem__(self, key):
 73        if hasattr(self, key):
 74            return getattr(self, key)
 75        else:
 76            raise KeyError(f"'{key}' not found in QueryNamespacesResults")
 77
 78    def get(self, key, default=None):
 79        return getattr(self, key, default)
 80
 81    def __repr__(self):
 82        return json.dumps(
 83            {
 84                "usage": self.usage.to_dict(),
 85                "matches": [match.__json__() for match in self.matches],
 86            },
 87            indent=4,
 88        )
 89
 90
 91class QueryResultsAggregregatorNotEnoughResultsError(Exception):
 92    def __init__(self):
 93        super().__init__(
 94            "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores."
 95        )
 96
 97
 98class QueryResultsAggregatorInvalidTopKError(Exception):
 99    def __init__(self, top_k: int):
100        super().__init__(
101            f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2."
102        )
103
104
class QueryResultsAggregator:
    """Merges match lists from multiple namespace queries into one top-k ranking.

    Feed each namespace's query response to ``add_results`` and call
    ``get_results`` once at the end. The interpretation of scores depends on
    the index's similarity metric — only for dotproduct is a higher score
    better — so the metric is inferred from the score ordering of the first
    response that contains at least two matches.
    """

    def __init__(self, top_k: int):
        """
        :param top_k: Maximum number of aggregated matches to keep. Must be at
            least 2 so the similarity metric can be inferred from score order.
        :raises QueryResultsAggregatorInvalidTopKError: if top_k < 2.
        """
        if top_k < 2:
            raise QueryResultsAggregatorInvalidTopKError(top_k)
        self.top_k = top_k
        self.usage_read_units = 0
        # Min-heap keyed on (score_key, -insertion_counter). The score is kept
        # as-is for dotproduct and negated otherwise, so the heap root is
        # always the current "worst" of the kept top-k items.
        self.heap: List[Tuple[float, int, object, str]] = []
        self.insertion_counter = 0
        self.is_dotproduct = None
        self.read = False
        self.final_results: Optional[QueryNamespacesResults] = None

    def _is_dotproduct_index(self, matches):
        # The interpretation of the score depends on the similarity metric used.
        # Unlike other index types, in indexes configured for dotproduct,
        # a higher score is better. We have to infer this is the case by inspecting
        # the order of the scores in the results.
        for i in range(1, len(matches)):
            if matches[i].get("score") > matches[i - 1].get("score"):  # Found an increase
                return False
        return True

    def _dotproduct_heap_item(self, match, ns):
        # Higher score is better; -insertion_counter breaks score ties.
        return (match.get("score"), -self.insertion_counter, match, ns)

    def _non_dotproduct_heap_item(self, match, ns):
        # Lower score is better, so negate it for the min-heap ordering.
        return (-match.get("score"), -self.insertion_counter, match, ns)

    def _process_matches(self, matches, ns, heap_item_fn):
        """Push matches from one response into the bounded top-k heap."""
        for match in matches:
            self.insertion_counter += 1
            if len(self.heap) < self.top_k:
                heapq.heappush(self.heap, heap_item_fn(match, ns))
            else:
                # Responses are assumed sorted best-first, so once a match
                # cannot beat the worst kept item, none of the rest can either.
                # Note: use .get("score") for consistency with the accessor
                # used everywhere else on these match objects.
                if self.is_dotproduct and match.get("score") < self.heap[0][0]:
                    # No further matches can improve the top-K heap
                    break
                elif not self.is_dotproduct and match.get("score") > -self.heap[0][0]:
                    # No further matches can improve the top-K heap
                    break
                heapq.heappushpop(self.heap, heap_item_fn(match, ns))

    def add_results(self, results: Dict[str, Any]):
        """Accumulate one namespace's query response into the aggregate.

        :param results: A query response; either an ``OpenAPIQueryResponse``
            or a plain dict with "matches", "namespace" and "usage" keys.
        :raises ValueError: if called after ``get_results``.
        :raises QueryResultsAggregregatorNotEnoughResultsError: if the metric
            has not been inferred yet and this response has only one match.
        """
        if self.read:
            # This is mainly just to sanity check in test cases which get quite confusing
            # if you read results twice due to the heap being emptied when constructing
            # the ordered results.
            raise ValueError("Results have already been read. Cannot add more results.")

        matches = results.get("matches", [])
        ns: str = results.get("namespace", "")
        if isinstance(results, OpenAPIQueryResponse):
            self.usage_read_units += results.usage.read_units
        else:
            self.usage_read_units += results.get("usage", {}).get("readUnits", 0)

        if len(matches) == 0:
            return

        if self.is_dotproduct is None:
            if len(matches) == 1:
                # This condition should match the second time we add results containing
                # only one match. We need at least two matches in a single response in order
                # to infer the similarity metric
                raise QueryResultsAggregregatorNotEnoughResultsError()
            self.is_dotproduct = self._is_dotproduct_index(matches)

        if self.is_dotproduct:
            self._process_matches(matches, ns, self._dotproduct_heap_item)
        else:
            self._process_matches(matches, ns, self._non_dotproduct_heap_item)

    def get_results(self) -> QueryNamespacesResults:
        """Drain the heap and return the aggregated, best-first results.

        Idempotent: subsequent calls return the cached result. Popping the
        min-heap yields worst-to-best order, hence the final ``[::-1]``.
        """
        if self.read:
            if self.final_results is not None:
                return self.final_results
            else:
                # I don't think this branch can ever actually be reached, but the type checker disagrees
                raise ValueError("Results have already been read. Cannot get results again.")
        self.read = True

        self.final_results = QueryNamespacesResults(
            usage=Usage(read_units=self.usage_read_units),
            matches=[
                ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
            ][::-1],
        )
        return self.final_results
@dataclass
class ScoredVectorWithNamespace:
11@dataclass
12class ScoredVectorWithNamespace:
13    namespace: str
14    score: float
15    id: str
16    values: List[float]
17    sparse_values: dict
18    metadata: dict
19
20    def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]):
21        json_vector = aggregate_results_heap_tuple[2]
22        self.namespace = aggregate_results_heap_tuple[3]
23        self.id = json_vector.get("id")  # type: ignore
24        self.score = json_vector.get("score")  # type: ignore
25        self.values = json_vector.get("values")  # type: ignore
26        self.sparse_values = json_vector.get("sparse_values", None)  # type: ignore
27        self.metadata = json_vector.get("metadata", None)  # type: ignore
28
29    def __getitem__(self, key):
30        if hasattr(self, key):
31            return getattr(self, key)
32        else:
33            raise KeyError(f"'{key}' not found in ScoredVectorWithNamespace")
34
35    def get(self, key, default=None):
36        return getattr(self, key, default)
37
38    def __repr__(self):
39        return json.dumps(self._truncate(asdict(self)), indent=4)
40
41    def __json__(self):
42        return self._truncate(asdict(self))
43
44    def _truncate(self, obj, max_items=2):
45        """
46        Recursively traverse and truncate lists that exceed max_items length.
47        Only display the "... X more" message if at least 2 elements are hidden.
48        """
49        if obj is None:
50            return None  # Skip None values
51        elif isinstance(obj, list):
52            filtered_list = [self._truncate(i, max_items) for i in obj if i is not None]
53            if len(filtered_list) > max_items:
54                # Show the truncation message only if more than 1 item is hidden
55                remaining_items = len(filtered_list) - max_items
56                if remaining_items > 1:
57                    return filtered_list[:max_items] + [f"... {remaining_items} more"]
58                else:
59                    # If only 1 item remains, show it
60                    return filtered_list
61            return filtered_list
62        elif isinstance(obj, dict):
63            # Recursively process dictionaries, omitting None values
64            return {k: self._truncate(v, max_items) for k, v in obj.items() if v is not None}
65        return obj
ScoredVectorWithNamespace(aggregate_results_heap_tuple: Tuple[float, int, object, str])
20    def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]):
21        json_vector = aggregate_results_heap_tuple[2]
22        self.namespace = aggregate_results_heap_tuple[3]
23        self.id = json_vector.get("id")  # type: ignore
24        self.score = json_vector.get("score")  # type: ignore
25        self.values = json_vector.get("values")  # type: ignore
26        self.sparse_values = json_vector.get("sparse_values", None)  # type: ignore
27        self.metadata = json_vector.get("metadata", None)  # type: ignore
namespace: str
score: float
id: str
values: List[float]
sparse_values: dict
metadata: dict
def get(self, key, default=None):
35    def get(self, key, default=None):
36        return getattr(self, key, default)
@dataclass
class QueryNamespacesResults:
68@dataclass
69class QueryNamespacesResults:
70    usage: Usage
71    matches: List[ScoredVectorWithNamespace]
72
73    def __getitem__(self, key):
74        if hasattr(self, key):
75            return getattr(self, key)
76        else:
77            raise KeyError(f"'{key}' not found in QueryNamespacesResults")
78
79    def get(self, key, default=None):
80        return getattr(self, key, default)
81
82    def __repr__(self):
83        return json.dumps(
84            {
85                "usage": self.usage.to_dict(),
86                "matches": [match.__json__() for match in self.matches],
87            },
88            indent=4,
89        )
QueryNamespacesResults( usage: pinecone.core.openapi.data.model.usage.Usage, matches: List[ScoredVectorWithNamespace])
usage: pinecone.core.openapi.data.model.usage.Usage
matches: List[ScoredVectorWithNamespace]
def get(self, key, default=None):
79    def get(self, key, default=None):
80        return getattr(self, key, default)
class QueryResultsAggregregatorNotEnoughResultsError(builtins.Exception):
92class QueryResultsAggregregatorNotEnoughResultsError(Exception):
93    def __init__(self):
94        super().__init__(
95            "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores."
96        )

Common base class for all non-exit exceptions.

Inherited Members
builtins.BaseException
with_traceback
add_note
args
class QueryResultsAggregatorInvalidTopKError(builtins.Exception):
 99class QueryResultsAggregatorInvalidTopKError(Exception):
100    def __init__(self, top_k: int):
101        super().__init__(
102            f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2."
103        )

Common base class for all non-exit exceptions.

QueryResultsAggregatorInvalidTopKError(top_k: int)
100    def __init__(self, top_k: int):
101        super().__init__(
102            f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2."
103        )
Inherited Members
builtins.BaseException
with_traceback
add_note
args
class QueryResultsAggregator:
106class QueryResultsAggregator:
107    def __init__(self, top_k: int):
108        if top_k < 2:
109            raise QueryResultsAggregatorInvalidTopKError(top_k)
110        self.top_k = top_k
111        self.usage_read_units = 0
112        self.heap: List[Tuple[float, int, object, str]] = []
113        self.insertion_counter = 0
114        self.is_dotproduct = None
115        self.read = False
116        self.final_results: Optional[QueryNamespacesResults] = None
117
118    def _is_dotproduct_index(self, matches):
119        # The interpretation of the score depends on the similar metric used.
120        # Unlike other index types, in indexes configured for dotproduct,
121        # a higher score is better. We have to infer this is the case by inspecting
122        # the order of the scores in the results.
123        for i in range(1, len(matches)):
124            if matches[i].get("score") > matches[i - 1].get("score"):  # Found an increase
125                return False
126        return True
127
128    def _dotproduct_heap_item(self, match, ns):
129        return (match.get("score"), -self.insertion_counter, match, ns)
130
131    def _non_dotproduct_heap_item(self, match, ns):
132        return (-match.get("score"), -self.insertion_counter, match, ns)
133
134    def _process_matches(self, matches, ns, heap_item_fn):
135        for match in matches:
136            self.insertion_counter += 1
137            if len(self.heap) < self.top_k:
138                heapq.heappush(self.heap, heap_item_fn(match, ns))
139            else:
140                # Assume we have dotproduct scores sorted in descending order
141                if self.is_dotproduct and match["score"] < self.heap[0][0]:
142                    # No further matches can improve the top-K heap
143                    break
144                elif not self.is_dotproduct and match["score"] > -self.heap[0][0]:
145                    # No further matches can improve the top-K heap
146                    break
147                heapq.heappushpop(self.heap, heap_item_fn(match, ns))
148
149    def add_results(self, results: Dict[str, Any]):
150        if self.read:
151            # This is mainly just to sanity check in test cases which get quite confusing
152            # if you read results twice due to the heap being emptied when constructing
153            # the ordered results.
154            raise ValueError("Results have already been read. Cannot add more results.")
155
156        matches = results.get("matches", [])
157        ns: str = results.get("namespace", "")
158        if isinstance(results, OpenAPIQueryResponse):
159            self.usage_read_units += results.usage.read_units
160        else:
161            self.usage_read_units += results.get("usage", {}).get("readUnits", 0)
162
163        if len(matches) == 0:
164            return
165
166        if self.is_dotproduct is None:
167            if len(matches) == 1:
168                # This condition should match the second time we add results containing
169                # only one match. We need at least two matches in a single response in order
170                # to infer the similarity metric
171                raise QueryResultsAggregregatorNotEnoughResultsError()
172            self.is_dotproduct = self._is_dotproduct_index(matches)
173
174        if self.is_dotproduct:
175            self._process_matches(matches, ns, self._dotproduct_heap_item)
176        else:
177            self._process_matches(matches, ns, self._non_dotproduct_heap_item)
178
179    def get_results(self) -> QueryNamespacesResults:
180        if self.read:
181            if self.final_results is not None:
182                return self.final_results
183            else:
184                # I don't think this branch can ever actually be reached, but the type checker disagrees
185                raise ValueError("Results have already been read. Cannot get results again.")
186        self.read = True
187
188        self.final_results = QueryNamespacesResults(
189            usage=Usage(read_units=self.usage_read_units),
190            matches=[
191                ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
192            ][::-1],
193        )
194        return self.final_results
QueryResultsAggregator(top_k: int)
107    def __init__(self, top_k: int):
108        if top_k < 2:
109            raise QueryResultsAggregatorInvalidTopKError(top_k)
110        self.top_k = top_k
111        self.usage_read_units = 0
112        self.heap: List[Tuple[float, int, object, str]] = []
113        self.insertion_counter = 0
114        self.is_dotproduct = None
115        self.read = False
116        self.final_results: Optional[QueryNamespacesResults] = None
top_k
usage_read_units
heap: List[Tuple[float, int, object, str]]
insertion_counter
is_dotproduct
read
final_results: Optional[QueryNamespacesResults]
def add_results(self, results: Dict[str, Any]):
149    def add_results(self, results: Dict[str, Any]):
150        if self.read:
151            # This is mainly just to sanity check in test cases which get quite confusing
152            # if you read results twice due to the heap being emptied when constructing
153            # the ordered results.
154            raise ValueError("Results have already been read. Cannot add more results.")
155
156        matches = results.get("matches", [])
157        ns: str = results.get("namespace", "")
158        if isinstance(results, OpenAPIQueryResponse):
159            self.usage_read_units += results.usage.read_units
160        else:
161            self.usage_read_units += results.get("usage", {}).get("readUnits", 0)
162
163        if len(matches) == 0:
164            return
165
166        if self.is_dotproduct is None:
167            if len(matches) == 1:
168                # This condition should match the second time we add results containing
169                # only one match. We need at least two matches in a single response in order
170                # to infer the similarity metric
171                raise QueryResultsAggregregatorNotEnoughResultsError()
172            self.is_dotproduct = self._is_dotproduct_index(matches)
173
174        if self.is_dotproduct:
175            self._process_matches(matches, ns, self._dotproduct_heap_item)
176        else:
177            self._process_matches(matches, ns, self._non_dotproduct_heap_item)
def get_results(self) -> QueryNamespacesResults:
179    def get_results(self) -> QueryNamespacesResults:
180        if self.read:
181            if self.final_results is not None:
182                return self.final_results
183            else:
184                # I don't think this branch can ever actually be reached, but the type checker disagrees
185                raise ValueError("Results have already been read. Cannot get results again.")
186        self.read = True
187
188        self.final_results = QueryNamespacesResults(
189            usage=Usage(read_units=self.usage_read_units),
190            matches=[
191                ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
192            ][::-1],
193        )
194        return self.final_results