pinecone.control.request_factory

  1import logging
  2from typing import Optional, Dict, Any, Union
  3from enum import Enum
  4
  5
  6from pinecone.utils import convert_enum_to_string
  7from pinecone.core.openapi.db_control.models import (
  8    CreateCollectionRequest,
  9    CreateIndexForModelRequest,
 10    CreateIndexForModelRequestEmbed,
 11    CreateIndexRequest,
 12    ConfigureIndexRequest,
 13    ConfigureIndexRequestSpec,
 14    ConfigureIndexRequestSpecPod,
 15    DeletionProtection as DeletionProtectionModel,
 16    IndexSpec,
 17    IndexTags,
 18    ServerlessSpec as ServerlessSpecModel,
 19    PodSpec as PodSpecModel,
 20    PodSpecMetadataConfig,
 21)
 22from pinecone.models import ServerlessSpec, PodSpec, IndexModel, IndexEmbed
 23from pinecone.utils import parse_non_empty_args
 24
 25from pinecone.enums import (
 26    Metric,
 27    VectorType,
 28    DeletionProtection,
 29    PodType,
 30    CloudProvider,
 31    AwsRegion,
 32    GcpRegion,
 33    AzureRegion,
 34)
 35from .types import CreateIndexForModelEmbedTypedDict
 36
 37
 38logger = logging.getLogger(__name__)
 39""" @private """
 40
 41
 42class PineconeDBControlRequestFactory:
 43    """
 44    @private
 45
 46    This class facilitates translating user inputs into request objects.
 47    """
 48
 49    @staticmethod
 50    def __parse_tags(tags: Optional[Dict[str, str]]) -> IndexTags:
 51        if tags is None:
 52            return IndexTags()
 53        else:
 54            return IndexTags(**tags)
 55
 56    @staticmethod
 57    def __parse_deletion_protection(
 58        deletion_protection: Union[DeletionProtection, str],
 59    ) -> DeletionProtectionModel:
 60        deletion_protection = convert_enum_to_string(deletion_protection)
 61        if deletion_protection in ["enabled", "disabled"]:
 62            return DeletionProtectionModel(deletion_protection)
 63        else:
 64            raise ValueError("deletion_protection must be either 'enabled' or 'disabled'")
 65
 66    @staticmethod
 67    def __parse_index_spec(spec: Union[Dict, ServerlessSpec, PodSpec]) -> IndexSpec:
 68        if isinstance(spec, dict):
 69            if "serverless" in spec:
 70                spec["serverless"]["cloud"] = convert_enum_to_string(spec["serverless"]["cloud"])
 71                spec["serverless"]["region"] = convert_enum_to_string(spec["serverless"]["region"])
 72
 73                index_spec = IndexSpec(serverless=ServerlessSpecModel(**spec["serverless"]))
 74            elif "pod" in spec:
 75                spec["pod"]["environment"] = convert_enum_to_string(spec["pod"]["environment"])
 76                args_dict = parse_non_empty_args(
 77                    [
 78                        ("environment", spec["pod"].get("environment")),
 79                        ("metadata_config", spec["pod"].get("metadata_config")),
 80                        ("replicas", spec["pod"].get("replicas")),
 81                        ("shards", spec["pod"].get("shards")),
 82                        ("pods", spec["pod"].get("pods")),
 83                        ("source_collection", spec["pod"].get("source_collection")),
 84                    ]
 85                )
 86                if args_dict.get("metadata_config"):
 87                    args_dict["metadata_config"] = PodSpecMetadataConfig(
 88                        indexed=args_dict["metadata_config"].get("indexed", None)
 89                    )
 90                index_spec = IndexSpec(pod=PodSpecModel(**args_dict))
 91            else:
 92                raise ValueError("spec must contain either 'serverless' or 'pod' key")
 93        elif isinstance(spec, ServerlessSpec):
 94            index_spec = IndexSpec(
 95                serverless=ServerlessSpecModel(cloud=spec.cloud, region=spec.region)
 96            )
 97        elif isinstance(spec, PodSpec):
 98            args_dict = parse_non_empty_args(
 99                [
100                    ("replicas", spec.replicas),
101                    ("shards", spec.shards),
102                    ("pods", spec.pods),
103                    ("source_collection", spec.source_collection),
104                ]
105            )
106            if spec.metadata_config:
107                args_dict["metadata_config"] = PodSpecMetadataConfig(
108                    indexed=spec.metadata_config.get("indexed", None)
109                )
110
111            index_spec = IndexSpec(
112                pod=PodSpecModel(environment=spec.environment, pod_type=spec.pod_type, **args_dict)
113            )
114        else:
115            raise TypeError("spec must be of type dict, ServerlessSpec, or PodSpec")
116
117        return index_spec
118
119    @staticmethod
120    def create_index_request(
121        name: str,
122        spec: Union[Dict, ServerlessSpec, PodSpec],
123        dimension: Optional[int] = None,
124        metric: Optional[Union[Metric, str]] = Metric.COSINE,
125        deletion_protection: Optional[Union[DeletionProtection, str]] = DeletionProtection.DISABLED,
126        vector_type: Optional[Union[VectorType, str]] = VectorType.DENSE,
127        tags: Optional[Dict[str, str]] = None,
128    ) -> CreateIndexRequest:
129        if metric is not None:
130            metric = convert_enum_to_string(metric)
131        if vector_type is not None:
132            vector_type = convert_enum_to_string(vector_type)
133        if deletion_protection is not None:
134            dp = PineconeDBControlRequestFactory.__parse_deletion_protection(deletion_protection)
135        else:
136            dp = None
137
138        tags_obj = PineconeDBControlRequestFactory.__parse_tags(tags)
139        index_spec = PineconeDBControlRequestFactory.__parse_index_spec(spec)
140
141        if vector_type == VectorType.SPARSE.value and dimension is not None:
142            raise ValueError("dimension should not be specified for sparse indexes")
143
144        args = parse_non_empty_args(
145            [
146                ("name", name),
147                ("dimension", dimension),
148                ("metric", metric),
149                ("spec", index_spec),
150                ("deletion_protection", dp),
151                ("vector_type", vector_type),
152                ("tags", tags_obj),
153            ]
154        )
155
156        return CreateIndexRequest(**args)
157
158    @staticmethod
159    def create_index_for_model_request(
160        name: str,
161        cloud: Union[CloudProvider, str],
162        region: Union[AwsRegion, GcpRegion, AzureRegion, str],
163        embed: Union[IndexEmbed, CreateIndexForModelEmbedTypedDict],
164        tags: Optional[Dict[str, str]] = None,
165        deletion_protection: Optional[Union[DeletionProtection, str]] = DeletionProtection.DISABLED,
166    ) -> CreateIndexForModelRequest:
167        cloud = convert_enum_to_string(cloud)
168        region = convert_enum_to_string(region)
169        if deletion_protection is not None:
170            dp = PineconeDBControlRequestFactory.__parse_deletion_protection(deletion_protection)
171        else:
172            dp = None
173        tags_obj = PineconeDBControlRequestFactory.__parse_tags(tags)
174
175        if isinstance(embed, IndexEmbed):
176            parsed_embed = embed.as_dict()
177        else:
178            # if dict, we need to parse enum values, if any, to string
179            # and verify required fields are present
180            required_fields = ["model", "field_map"]
181            for field in required_fields:
182                if field not in embed:
183                    raise ValueError(f"{field} is required in embed")
184            parsed_embed = {}
185            for key, value in embed.items():
186                if isinstance(value, Enum):
187                    parsed_embed[key] = convert_enum_to_string(value)
188                else:
189                    parsed_embed[key] = value
190
191        args = parse_non_empty_args(
192            [
193                ("name", name),
194                ("cloud", cloud),
195                ("region", region),
196                ("embed", CreateIndexForModelRequestEmbed(**parsed_embed)),
197                ("deletion_protection", dp),
198                ("tags", tags_obj),
199            ]
200        )
201
202        return CreateIndexForModelRequest(**args)
203
204    @staticmethod
205    def configure_index_request(
206        description: IndexModel,
207        replicas: Optional[int] = None,
208        pod_type: Optional[Union[PodType, str]] = None,
209        deletion_protection: Optional[Union[DeletionProtection, str]] = None,
210        tags: Optional[Dict[str, str]] = None,
211    ):
212        if deletion_protection is None:
213            dp = DeletionProtectionModel(description.deletion_protection)
214        elif isinstance(deletion_protection, DeletionProtection):
215            dp = DeletionProtectionModel(deletion_protection.value)
216        elif deletion_protection in ["enabled", "disabled"]:
217            dp = DeletionProtectionModel(deletion_protection)
218        else:
219            raise ValueError("deletion_protection must be either 'enabled' or 'disabled'")
220
221        fetched_tags = description.tags
222        if fetched_tags is None:
223            starting_tags = {}
224        else:
225            starting_tags = fetched_tags.to_dict()
226
227        if tags is None:
228            # Do not modify tags if none are provided
229            tags = starting_tags
230        else:
231            # Merge existing tags with new tags
232            tags = {**starting_tags, **tags}
233
234        pod_config_args: Dict[str, Any] = {}
235        if pod_type:
236            new_pod_type = convert_enum_to_string(pod_type)
237            pod_config_args.update(pod_type=new_pod_type)
238        if replicas:
239            pod_config_args.update(replicas=replicas)
240
241        if pod_config_args != {}:
242            spec = ConfigureIndexRequestSpec(pod=ConfigureIndexRequestSpecPod(**pod_config_args))
243            req = ConfigureIndexRequest(deletion_protection=dp, spec=spec, tags=IndexTags(**tags))
244        else:
245            req = ConfigureIndexRequest(deletion_protection=dp, tags=IndexTags(**tags))
246
247        return req
248
249    @staticmethod
250    def create_collection_request(name: str, source: str) -> CreateCollectionRequest:
251        return CreateCollectionRequest(name=name, source=source)