import datetime
import os
import time
from abc import ABC, abstractmethod
from typing import Dict, Literal, Tuple, TypeVar, Union, cast, overload
from urllib.parse import urlparse

import grpc  # type: ignore
from grpc import Channel, ssl_channel_credentials
from grpc.aio import Channel as AsyncChannel  # type: ignore

from pydantic import BaseModel, field_validator, model_validator

from weaviate.config import Proxies
from weaviate.types import NUMBER

# A decoded JSON document: either an object (dict) or an array (list).
JSONPayload = Union[dict, list]

# ~100 MB (104,858,000 bytes); needs to be synchronized with the gRPC server.
MAX_GRPC_MESSAGE_LENGTH = 104858000

# Channel options applied to every gRPC channel: lift both the send and the
# receive message-size limits to MAX_GRPC_MESSAGE_LENGTH.
GRPC_DEFAULT_OPTIONS = [
    ("grpc.max_send_message_length", MAX_GRPC_MESSAGE_LENGTH),
    ("grpc.max_receive_message_length", MAX_GRPC_MESSAGE_LENGTH),
]

class ProtocolParams(BaseModel):
    """Connection settings for a single protocol endpoint.

    Attributes are validated by pydantic on construction:

    * ``host`` must be a non-empty string.
    * ``port`` must lie in the inclusive range 0..65535.
    * ``secure`` states whether the endpoint is reached over a secure
      connection.
    """

    host: str
    port: int
    secure: bool

    @field_validator("host")
    def _check_host(cls, v: str) -> str:
        """Reject an empty host name; return the value unchanged otherwise."""
        if not v:
            raise ValueError("host must not be empty")
        return v

    @field_validator("port")
    def _check_port(cls, v: int) -> int:
        """Reject ports outside 0..65535; return the value unchanged otherwise."""
        if not 0 <= v <= 65535:
            raise ValueError("port must be between 0 and 65535")
        return v
T = TypeVar("T", bound="ConnectionParams")


class ConnectionParams(BaseModel):
    """Pair of HTTP and gRPC endpoint settings describing one connection.

    Besides holding the two `ProtocolParams`, the class offers alternate
    constructors (`from_url`, `from_params`) and private helpers that derive
    the HTTP base URL and build the gRPC channel.
    """

    http: ProtocolParams
    grpc: ProtocolParams

    @classmethod
    def from_url(cls, url: str, grpc_port: int, grpc_secure: bool = False) -> "ConnectionParams":
        """Build the parameters from an HTTP(S) URL plus a gRPC port.

        Parameters
        ----------
        url : str
            URL of the HTTP endpoint. Only the 'http' and 'https' schemes are accepted.
            When the URL carries no port, 443 is used for 'https' and 80 for 'http'.
        grpc_port : int
            Port of the gRPC endpoint; the gRPC host is taken from `url`.
        grpc_secure : bool
            Whether gRPC should use a secure channel. An 'https' URL forces a secure
            gRPC channel regardless of this flag.

        Returns
        -------
        ConnectionParams
            The validated connection parameters.

        Raises
        ------
        ValueError
            If the URL scheme is neither 'http' nor 'https'.
        """
        parsed_url = urlparse(url)
        if parsed_url.scheme not in ["http", "https"]:
            raise ValueError(f"Unsupported scheme: {parsed_url.scheme}")

        is_https = parsed_url.scheme == "https"
        if parsed_url.port is None:
            port = 443 if is_https else 80
        else:
            port = parsed_url.port

        # NOTE: `hostname` is None for URLs without a network location; the cast only
        # satisfies the type checker, the ProtocolParams validation still rejects it.
        hostname = cast(str, parsed_url.hostname)
        return cls(
            http=ProtocolParams(
                host=hostname,
                port=port,
                secure=is_https,
            ),
            grpc=ProtocolParams(
                host=hostname,
                port=grpc_port,
                secure=grpc_secure or is_https,
            ),
        )

    @classmethod
    def from_params(
        cls,
        http_host: str,
        http_port: int,
        http_secure: bool,
        grpc_host: str,
        grpc_port: int,
        grpc_secure: bool,
    ) -> "ConnectionParams":
        """Build the parameters from explicit host/port/secure values for both protocols."""
        return cls(
            http=ProtocolParams(
                host=http_host,
                port=http_port,
                secure=http_secure,
            ),
            grpc=ProtocolParams(
                host=grpc_host,
                port=grpc_port,
                secure=grpc_secure,
            ),
        )

    @model_validator(mode="after")
    def _check_port_collision(self: T) -> T:
        """Reject configurations where HTTP and gRPC share both host and port."""
        if self.http.host == self.grpc.host and self.http.port == self.grpc.port:
            raise ValueError("http.port and grpc.port must be different if using the same host")
        return self

    @property
    def _grpc_address(self) -> Tuple[str, int]:
        """The gRPC endpoint as a ``(host, port)`` tuple."""
        return (self.grpc.host, self.grpc.port)

    @property
    def _grpc_target(self) -> str:
        """The gRPC endpoint as a ``host:port`` target string."""
        return f"{self.grpc.host}:{self.grpc.port}"

    @overload
    def _grpc_channel(self, async_channel: Literal[False], proxies: Dict[str, str]) -> Channel:
        ...

    @overload
    def _grpc_channel(self, async_channel: Literal[True], proxies: Dict[str, str]) -> AsyncChannel:
        ...

    def _grpc_channel(
        self, async_channel: bool, proxies: Dict[str, str]
    ) -> Union[Channel, AsyncChannel]:
        """Create a gRPC channel to `_grpc_target`.

        Parameters
        ----------
        async_channel : bool
            If True the channel is created through `grpc.aio`, otherwise through `grpc`.
        proxies : Dict[str, str]
            Proxy mapping; only the 'grpc' entry is used and, when present, is passed
            to the channel as the 'grpc.http_proxy' option.

        Returns
        -------
        Channel or AsyncChannel
            A secure channel (default SSL credentials) when `grpc.secure` is set,
            an insecure channel otherwise.
        """
        if async_channel:
            import_path = grpc.aio
        else:
            import_path = grpc

        if (p := proxies.get("grpc")) is not None:
            # Build a new list so the shared default options are never mutated.
            options: list = [*GRPC_DEFAULT_OPTIONS, ("grpc.http_proxy", p)]
        else:
            options = GRPC_DEFAULT_OPTIONS

        if self.grpc.secure:
            return import_path.secure_channel(
                target=self._grpc_target,
                credentials=ssl_channel_credentials(),
                options=options,
            )
        else:
            return import_path.insecure_channel(
                target=self._grpc_target,
                options=options,
            )

    @property
    def _http_scheme(self) -> str:
        """'https' when the HTTP endpoint is secure, 'http' otherwise."""
        return "https" if self.http.secure else "http"

    @property
    def _http_url(self) -> str:
        """Base URL of the HTTP endpoint in the form ``scheme://host:port``."""
        return f"{self._http_scheme}://{self.http.host}:{self.http.port}"
class _ConnectionBase(ABC):
    """Abstract interface a connection must implement."""

    @abstractmethod
    def get_current_bearer_token(self) -> str:
        """Return the bearer token currently in use."""
        raise NotImplementedError

    @abstractmethod
    def get_proxies(self) -> dict:
        """Return the proxies used by the connection as a dict."""
        raise NotImplementedError


def _get_proxies(proxies: Union[dict, str, Proxies, None], trust_env: bool) -> dict:
    """
    Get proxies as dict, compatible with 'requests' library.
    NOTE: 'proxies' has priority over 'trust_env', i.e. if 'proxies' is NOT None, 'trust_env'
        is ignored.

    Parameters
    ----------
    proxies : dict, str, Proxies or None
        The proxies to use for requests. If it is a dict it should follow the 'requests'
        library format (see the 'Proxies' section of the 'requests' documentation) and is
        returned as is. If it is a URL (str), a dict will be constructed with 'http',
        'https' and 'grpc' all pointing to that URL. If it is a `Proxies` instance, its
        non-None fields are dumped to a dict. If None, 'trust_env' decides.
    trust_env : bool
        If True, the proxies will be read from ENV VARs (upper-case or lower-case spelling):
        HTTP_PROXY/HTTPS_PROXY/GRPC_PROXY. The upper-case variable wins when both are set.
        NOTE: It is ignored if 'proxies' is NOT None.

    Returns
    -------
    dict
        A dictionary with proxies, either set from 'proxies' or read from ENV VARs.
        Empty if no proxies are configured.

    Raises
    ------
    TypeError
        If 'proxies' is not None and is neither a dict, a str, nor a `Proxies` instance.
    """
    if proxies is not None:
        if isinstance(proxies, str):
            return {
                "http": proxies,
                "https": proxies,
                "grpc": proxies,
            }
        if isinstance(proxies, dict):
            return proxies
        if isinstance(proxies, Proxies):
            return proxies.model_dump(exclude_none=True)
        raise TypeError(
            "If 'proxies' is not None, it must be of type dict, str, or wvc.init.Proxies. "
            f"Given type: {type(proxies)}."
        )

    if not trust_env:
        return {}

    env_proxies = {}
    for scheme in ("http", "https", "grpc"):
        var_name = f"{scheme}_proxy"
        # Upper-case spelling takes precedence; unset or empty variables are skipped.
        value = os.environ.get(var_name.upper()) or os.environ.get(var_name)
        if value:
            env_proxies[scheme] = value
    return env_proxies


def _get_epoch_time() -> int:
    """
    Get the current epoch time as an integer.

    Returns
    -------
    int
        Current epoch time in seconds, rounded to the nearest integer.
    """
    # time.time() is already seconds since the epoch. Feeding a naive UTC timetuple to
    # time.mktime() would treat it as local time and skew the result by the UTC offset.
    return round(time.time())