1import logging
2from typing import Optional, Dict, Any, Union
3from enum import Enum
4
5
6from pinecone.utils import convert_enum_to_string
7from pinecone.core.openapi.db_control.models import (
8 CreateCollectionRequest,
9 CreateIndexForModelRequest,
10 CreateIndexForModelRequestEmbed,
11 CreateIndexRequest,
12 ConfigureIndexRequest,
13 ConfigureIndexRequestSpec,
14 ConfigureIndexRequestSpecPod,
15 DeletionProtection as DeletionProtectionModel,
16 IndexSpec,
17 IndexTags,
18 ServerlessSpec as ServerlessSpecModel,
19 PodSpec as PodSpecModel,
20 PodSpecMetadataConfig,
21)
22from pinecone.models import ServerlessSpec, PodSpec, IndexModel, IndexEmbed
23from pinecone.utils import parse_non_empty_args
24
25from pinecone.enums import (
26 Metric,
27 VectorType,
28 DeletionProtection,
29 PodType,
30 CloudProvider,
31 AwsRegion,
32 GcpRegion,
33 AzureRegion,
34)
35from .types import CreateIndexForModelEmbedTypedDict
36
37
# Module-level logger, named after this module via ``__name__``.
# NOTE(review): the bare string below is presumably a doc-generator marker that
# hides this name from generated API docs — confirm against the docs tooling.
logger = logging.getLogger(__name__)
""" @private """
40
41
class PineconeDBControlRequestFactory:
    """
    @private

    This class facilitates translating user inputs into request objects.

    Every method is a ``staticmethod``; the class only serves as a namespace.
    Enum arguments are converted to strings with ``convert_enum_to_string``
    before they are handed to the generated OpenAPI request models.
    """

    @staticmethod
    def __parse_tags(tags: Optional[Dict[str, str]]) -> IndexTags:
        """Wrap a plain tag mapping in an ``IndexTags`` model.

        Args:
            tags: Tag names mapped to tag values, or ``None``.

        Returns:
            ``IndexTags`` built from ``tags``; an ``IndexTags`` constructed with
            no arguments when ``tags`` is ``None``.
        """
        if tags is None:
            return IndexTags()
        return IndexTags(**tags)

    @staticmethod
    def __parse_deletion_protection(
        deletion_protection: Union[DeletionProtection, str],
    ) -> DeletionProtectionModel:
        """Validate a deletion-protection setting and wrap it in its request model.

        Args:
            deletion_protection: A ``DeletionProtection`` enum member or a string.

        Returns:
            A ``DeletionProtectionModel`` built from the string form of the input.

        Raises:
            ValueError: If the string form is neither ``"enabled"`` nor ``"disabled"``.
        """
        value = convert_enum_to_string(deletion_protection)
        if value not in ("enabled", "disabled"):
            raise ValueError("deletion_protection must be either 'enabled' or 'disabled'")
        return DeletionProtectionModel(value)

    @staticmethod
    def __parse_index_spec(spec: Union[Dict, ServerlessSpec, PodSpec]) -> IndexSpec:
        """Translate a user-supplied index spec into an ``IndexSpec`` request model.

        The caller's ``spec`` is never modified: enum values are converted on
        local copies rather than written back into the input dict.

        Args:
            spec: Either a dict with a ``"serverless"`` or ``"pod"`` key, a
                ``ServerlessSpec``, or a ``PodSpec``.

        Returns:
            An ``IndexSpec`` wrapping a ``ServerlessSpecModel`` or a ``PodSpecModel``.

        Raises:
            ValueError: If ``spec`` is a dict with neither a ``"serverless"`` nor
                a ``"pod"`` key.
            TypeError: If ``spec`` is not a dict, ``ServerlessSpec``, or ``PodSpec``.
            KeyError: If a dict spec lacks ``"cloud"``/``"region"`` (serverless)
                or ``"environment"`` (pod).
        """
        if isinstance(spec, dict):
            if "serverless" in spec:
                # Shallow copy so the enum-to-string conversion does not leak
                # back into the dict the caller passed in.
                serverless_args = dict(spec["serverless"])
                serverless_args["cloud"] = convert_enum_to_string(serverless_args["cloud"])
                serverless_args["region"] = convert_enum_to_string(serverless_args["region"])
                return IndexSpec(serverless=ServerlessSpecModel(**serverless_args))

            if "pod" in spec:
                pod_spec = spec["pod"]
                environment = convert_enum_to_string(pod_spec["environment"])

                # Forward pod_type just like the PodSpec branch below does;
                # without it a pod_type given in a dict spec would be dropped.
                pod_type = pod_spec.get("pod_type")
                if pod_type is not None:
                    pod_type = convert_enum_to_string(pod_type)

                # parse_non_empty_args is expected to leave out entries whose
                # value is None, so absent optional keys are simply not sent.
                args_dict = parse_non_empty_args(
                    [
                        ("environment", environment),
                        ("pod_type", pod_type),
                        ("metadata_config", pod_spec.get("metadata_config")),
                        ("replicas", pod_spec.get("replicas")),
                        ("shards", pod_spec.get("shards")),
                        ("pods", pod_spec.get("pods")),
                        ("source_collection", pod_spec.get("source_collection")),
                    ]
                )
                if args_dict.get("metadata_config"):
                    args_dict["metadata_config"] = PodSpecMetadataConfig(
                        indexed=args_dict["metadata_config"].get("indexed", None)
                    )
                return IndexSpec(pod=PodSpecModel(**args_dict))

            raise ValueError("spec must contain either 'serverless' or 'pod' key")

        if isinstance(spec, ServerlessSpec):
            return IndexSpec(serverless=ServerlessSpecModel(cloud=spec.cloud, region=spec.region))

        if isinstance(spec, PodSpec):
            args_dict = parse_non_empty_args(
                [
                    ("replicas", spec.replicas),
                    ("shards", spec.shards),
                    ("pods", spec.pods),
                    ("source_collection", spec.source_collection),
                ]
            )
            if spec.metadata_config:
                args_dict["metadata_config"] = PodSpecMetadataConfig(
                    indexed=spec.metadata_config.get("indexed", None)
                )
            return IndexSpec(
                pod=PodSpecModel(environment=spec.environment, pod_type=spec.pod_type, **args_dict)
            )

        raise TypeError("spec must be of type dict, ServerlessSpec, or PodSpec")

    @staticmethod
    def create_index_request(
        name: str,
        spec: Union[Dict, ServerlessSpec, PodSpec],
        dimension: Optional[int] = None,
        metric: Optional[Union[Metric, str]] = Metric.COSINE,
        deletion_protection: Optional[Union[DeletionProtection, str]] = DeletionProtection.DISABLED,
        vector_type: Optional[Union[VectorType, str]] = VectorType.DENSE,
        tags: Optional[Dict[str, str]] = None,
    ) -> CreateIndexRequest:
        """Build a ``CreateIndexRequest`` from user-facing arguments.

        Args:
            name: Name of the index.
            spec: Deployment spec; see ``__parse_index_spec`` for accepted forms.
            dimension: Vector dimension. Must be ``None`` for sparse indexes.
            metric: Metric as an enum member or string; ``None`` leaves it unset.
            deletion_protection: ``"enabled"``/``"disabled"`` or the matching enum
                member; ``None`` leaves it unset.
            vector_type: Vector type as an enum member or string; ``None`` leaves
                it unset.
            tags: Optional tag mapping.

        Returns:
            The populated ``CreateIndexRequest``.

        Raises:
            ValueError: If ``deletion_protection`` is invalid, if a dict ``spec``
                has neither expected key, or if ``dimension`` is given for a
                sparse index.
            TypeError: If ``spec`` is of an unsupported type.
        """
        if metric is not None:
            metric = convert_enum_to_string(metric)
        if vector_type is not None:
            vector_type = convert_enum_to_string(vector_type)

        dp = None
        if deletion_protection is not None:
            dp = PineconeDBControlRequestFactory.__parse_deletion_protection(deletion_protection)

        tags_obj = PineconeDBControlRequestFactory.__parse_tags(tags)
        index_spec = PineconeDBControlRequestFactory.__parse_index_spec(spec)

        if vector_type == VectorType.SPARSE.value and dimension is not None:
            raise ValueError("dimension should not be specified for sparse indexes")

        args = parse_non_empty_args(
            [
                ("name", name),
                ("dimension", dimension),
                ("metric", metric),
                ("spec", index_spec),
                ("deletion_protection", dp),
                ("vector_type", vector_type),
                ("tags", tags_obj),
            ]
        )

        return CreateIndexRequest(**args)

    @staticmethod
    def create_index_for_model_request(
        name: str,
        cloud: Union[CloudProvider, str],
        region: Union[AwsRegion, GcpRegion, AzureRegion, str],
        embed: Union[IndexEmbed, CreateIndexForModelEmbedTypedDict],
        tags: Optional[Dict[str, str]] = None,
        deletion_protection: Optional[Union[DeletionProtection, str]] = DeletionProtection.DISABLED,
    ) -> CreateIndexForModelRequest:
        """Build a ``CreateIndexForModelRequest`` from user-facing arguments.

        Args:
            name: Name of the index.
            cloud: Cloud provider as an enum member or string.
            region: Region as an enum member or string.
            embed: Embedding configuration, either an ``IndexEmbed`` or a dict
                that contains at least ``"model"`` and ``"field_map"``.
            tags: Optional tag mapping.
            deletion_protection: ``"enabled"``/``"disabled"`` or the matching enum
                member; ``None`` leaves it unset.

        Returns:
            The populated ``CreateIndexForModelRequest``.

        Raises:
            ValueError: If ``deletion_protection`` is invalid or a dict ``embed``
                lacks a required field.
        """
        cloud = convert_enum_to_string(cloud)
        region = convert_enum_to_string(region)

        dp = None
        if deletion_protection is not None:
            dp = PineconeDBControlRequestFactory.__parse_deletion_protection(deletion_protection)
        tags_obj = PineconeDBControlRequestFactory.__parse_tags(tags)

        if isinstance(embed, IndexEmbed):
            parsed_embed = embed.as_dict()
        else:
            # if dict, we need to parse enum values, if any, to string
            # and verify required fields are present
            for field in ("model", "field_map"):
                if field not in embed:
                    raise ValueError(f"{field} is required in embed")
            parsed_embed = {
                key: convert_enum_to_string(value) if isinstance(value, Enum) else value
                for key, value in embed.items()
            }

        args = parse_non_empty_args(
            [
                ("name", name),
                ("cloud", cloud),
                ("region", region),
                ("embed", CreateIndexForModelRequestEmbed(**parsed_embed)),
                ("deletion_protection", dp),
                ("tags", tags_obj),
            ]
        )

        return CreateIndexForModelRequest(**args)

    @staticmethod
    def configure_index_request(
        description: IndexModel,
        replicas: Optional[int] = None,
        pod_type: Optional[Union[PodType, str]] = None,
        deletion_protection: Optional[Union[DeletionProtection, str]] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> ConfigureIndexRequest:
        """Build a ``ConfigureIndexRequest`` relative to an index's current state.

        Args:
            description: Current description of the index; its
                ``deletion_protection`` and ``tags`` are used as the baseline.
            replicas: New replica count. A falsy value (``None`` or ``0``) leaves
                it out of the request.
            pod_type: New pod type as an enum member or string. A falsy value
                leaves it out of the request.
            deletion_protection: New setting; ``None`` re-sends the value found on
                ``description``.
            tags: Tags to merge over the existing ones (new values win on key
                clashes); ``None`` re-sends the existing tags unchanged.

        Returns:
            The populated ``ConfigureIndexRequest``. A pod ``spec`` is attached
            only when ``replicas`` and/or ``pod_type`` were supplied.

        Raises:
            ValueError: If ``deletion_protection`` is neither ``"enabled"`` nor
                ``"disabled"``.
        """
        if deletion_protection is None:
            dp = DeletionProtectionModel(description.deletion_protection)
        else:
            # Same validation the create_* builders use.
            dp = PineconeDBControlRequestFactory.__parse_deletion_protection(deletion_protection)

        fetched_tags = description.tags
        starting_tags = {} if fetched_tags is None else fetched_tags.to_dict()

        if tags is None:
            # Do not modify tags if none are provided
            merged_tags = starting_tags
        else:
            # Merge existing tags with new tags
            merged_tags = {**starting_tags, **tags}

        pod_config_args: Dict[str, Any] = {}
        if pod_type:
            pod_config_args["pod_type"] = convert_enum_to_string(pod_type)
        if replicas:
            pod_config_args["replicas"] = replicas

        if pod_config_args:
            pod_spec = ConfigureIndexRequestSpec(pod=ConfigureIndexRequestSpecPod(**pod_config_args))
            return ConfigureIndexRequest(
                deletion_protection=dp, spec=pod_spec, tags=IndexTags(**merged_tags)
            )

        return ConfigureIndexRequest(deletion_protection=dp, tags=IndexTags(**merged_tags))

    @staticmethod
    def create_collection_request(name: str, source: str) -> CreateCollectionRequest:
        """Build a ``CreateCollectionRequest``.

        Args:
            name: Name of the collection.
            source: Passed through as the request's ``source`` field.

        Returns:
            The populated ``CreateCollectionRequest``.
        """
        return CreateCollectionRequest(name=name, source=source)