Source code for weaviate.util

"""
Helper functions!
"""

import base64
import datetime
import io
import json
import os
import re
from enum import Enum, EnumMeta
from pathlib import Path
from typing import Union, Sequence, Any, Optional, List, Dict, Generator, Tuple, cast

import requests
import httpx
import uuid as uuid_lib
import validators
from requests.exceptions import JSONDecodeError

from weaviate.exceptions import (
    SchemaValidationError,
    UnexpectedStatusCodeError,
    ResponseCannotBeDecodedError,
    WeaviateInvalidInputError,
)
from weaviate.warnings import _Warnings
from weaviate.types import NUMBER, UUIDS, TIME

PYPI_PACKAGE_URL = "https://pypi.org/pypi/weaviate-client/json"
MAXIMUM_MINOR_VERSION_DELTA = 3  # The maximum delta between minor versions of Weaviate Client that will not trigger an upgrade warning.
MINIMUM_NO_WARNING_VERSION = (
    "v1.16.0"  # The minimum version of Weaviate that will not trigger an upgrade warning.
)
BYTES_PER_CHUNK = 65535  # The number of bytes to read per chunk when encoding files ~ 64kb


# MetaEnum and BaseEnum are required to support `in` statements:
#    'ALL' in ConsistencyLevel == True
#    12345 in ConsistencyLevel == False
class MetaEnum(EnumMeta):
    """
    Enum metaclass that lets the `in` operator work with member names.

    `item in SomeEnum` is True when `item` is the name of a member, or when
    `item` has a `.name` attribute (e.g. it is an enum member) whose value is
    the name of a member.
    """

    def __contains__(cls, item: Any) -> bool:
        # Enum members carry their name in `.name`; a plain string does not,
        # in which case the item itself is treated as the candidate name.
        try:
            candidate = item.name
        except AttributeError:
            candidate = item
        return candidate in cls.__members__
class BaseEnum(Enum, metaclass=MetaEnum):
    """Base class for enums whose `in` checks are handled by `MetaEnum`."""
def image_encoder_b64(image_or_image_path: Union[str, io.BufferedReader]) -> str:
    """
    Base64-encode an image so that it can be sent to Weaviate.

    Parameters
    ----------
    image_or_image_path : str, io.BufferedReader
        Either the path of the image file, or an already opened binary reader.
        A reader passed in is read to its end but is not closed here.

    Returns
    -------
    str
        The base64 encoded content, decoded as UTF-8 text.

    Raises
    ------
    ValueError
        If the argument is a str that does not point to an existing file.
    TypeError
        If the argument is neither a str nor an io.BufferedReader.
    """

    if isinstance(image_or_image_path, io.BufferedReader):
        raw_bytes = image_or_image_path.read()
        return base64.b64encode(raw_bytes).decode("utf-8")

    if not isinstance(image_or_image_path, str):
        raise TypeError(
            '"image_or_image_path" should be a image path or a binary read file'
            " (io.BufferedReader)"
        )

    if not os.path.isfile(image_or_image_path):
        raise ValueError("No file found at location " + image_or_image_path)
    with open(image_or_image_path, "br") as image_file:
        raw_bytes = image_file.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
def file_encoder_b64(file_or_file_path: Union[str, Path, io.BufferedReader]) -> str:
    """
    Encode a file in a Weaviate understandable format from an io.BufferedReader binary read file or
    by providing the file path as either a string or a pathlib.Path object.

    If you pass an io.BufferedReader object, it is your responsibility to close it after encoding.
    Files opened by this function are always closed, even if reading or encoding fails.

    Parameters
    ----------
    file_or_file_path : str, pathlib.Path, io.BufferedReader
        The binary read file or the path to the file.

    Returns
    -------
    str
        Encoded file.

    Raises
    ------
    ValueError
        If the argument is a str or pathlib.Path that does not point to an existing file.
    TypeError
        If the argument is of a wrong data type.
    """

    def _chunks(buffer: io.BufferedReader, chunk_size: int) -> Generator[bytes, Any, Any]:
        # Yield successive reads of at most `chunk_size` bytes until the reader is exhausted.
        while True:
            data = buffer.read(chunk_size)
            if not data:
                break
            yield data

    should_close_file = False
    use_buffering = True  # the size of a caller-supplied reader is unknown, so always chunk it

    if isinstance(file_or_file_path, str):
        if not os.path.isfile(file_or_file_path):
            raise ValueError("No file found at location " + file_or_file_path)
        # Determine the size before opening so a failure here cannot leak a handle.
        use_buffering = os.path.getsize(file_or_file_path) > BYTES_PER_CHUNK
        file = open(file_or_file_path, "br")
        should_close_file = True
    elif isinstance(file_or_file_path, Path):
        if not file_or_file_path.is_file():
            raise ValueError("No file found at location " + str(file_or_file_path))
        use_buffering = file_or_file_path.stat().st_size > BYTES_PER_CHUNK
        file = file_or_file_path.open("br")
        should_close_file = True
    elif isinstance(file_or_file_path, io.BufferedReader):
        file = file_or_file_path
    else:
        raise TypeError(
            '"file_or_file_path" should be a file path or a binary read file'
            " (io.BufferedReader)"
        )

    try:
        if use_buffering:
            # Chunks are encoded independently and concatenated. That only yields valid
            # base64 when every chunk except the last has a length that is a multiple
            # of 3 (no padding in the middle), so BYTES_PER_CHUNK must be a multiple of 3.
            encoded = "".join(
                base64.b64encode(chunk).decode("utf-8")
                for chunk in _chunks(file, BYTES_PER_CHUNK)
            )
        else:
            encoded = base64.b64encode(file.read()).decode("utf-8")
    finally:
        # Only close handles opened here; a caller-supplied reader stays open.
        if should_close_file:
            file.close()
    return encoded
def image_decoder_b64(encoded_image: str) -> bytes:
    """
    Turn a base64 encoded image (Weaviate format) back into raw bytes.

    Parameters
    ----------
    encoded_image : str
        The base64 encoded image.

    Returns
    -------
    bytes
        The decoded image content.
    """

    encoded_bytes = encoded_image.encode("utf-8")
    return base64.b64decode(encoded_bytes)
def file_decoder_b64(encoded_file: str) -> bytes:
    """
    Turn a base64 encoded file (Weaviate format) back into raw bytes.

    Parameters
    ----------
    encoded_file : str
        The base64 encoded file.

    Returns
    -------
    bytes
        The decoded file content. Feed it to your own file handling code to
        interpret it as a specific file type, e.g. PIL for images.
    """

    encoded_bytes = encoded_file.encode("utf-8")
    return base64.b64decode(encoded_bytes)
def generate_local_beacon(
    to_uuid: Union[str, uuid_lib.UUID],
    class_name: Optional[str] = None,
) -> dict:
    """
    Build a local beacon (`weaviate://localhost/...`) for the given UUID.

    Parameters
    ----------
    to_uuid : str or uuid.UUID
        The UUID for which to create a local beacon.
    class_name : Optional[str], optional
        The class name of the `to_uuid` object. Used with Weaviate >= 1.14.0.
        For Weaviate < 1.14.0 use None value. When given, its first letter is
        capitalized and it is placed in front of the UUID.

    Returns
    -------
    dict
        A dict with the single key "beacon".

    Raises
    ------
    TypeError
        If 'to_uuid' is neither a str nor a uuid.UUID.
    ValueError
        If 'to_uuid' is a str that is not a valid UUID.
    """

    if isinstance(to_uuid, uuid_lib.UUID):
        uuid = str(to_uuid)
    elif isinstance(to_uuid, str):
        # Round-trip through uuid.UUID to validate and normalize the text form.
        try:
            uuid = str(uuid_lib.UUID(to_uuid))
        except ValueError:
            raise ValueError("Uuid does not have the proper form") from None
    else:
        raise TypeError("Expected to_object_uuid of type str or uuid.UUID")

    if class_name is not None:
        return {"beacon": f"weaviate://localhost/{_capitalize_first_letter(class_name)}/{uuid}"}
    return {"beacon": f"weaviate://localhost/{uuid}"}
def _get_dict_from_object(object_: Union[str, dict]) -> dict: """ Takes an object that should describe a dict e.g. a schema or an object and tries to retrieve the dict. Parameters ---------- object_ : str or dict The object from which to retrieve the dict. Can be a python dict, or the path to a json file or a url of a json file. Returns ------- dict The object as a dict. Raises ------ TypeError If 'object_' is neither a string nor a dict. ValueError If no dict can be retrieved from object. """ # check if objects files is url if object_ is None: raise TypeError("argument is None") if isinstance(object_, dict): # Object is already a dict return object_ if isinstance(object_, str): if validators.url(object_): # Object is URL response = requests.get(object_) if response.status_code == 200: return cast(dict, response.json()) raise ValueError("Could not download file " + object_) if not os.path.isfile(object_): # Object is neither file nor URL raise ValueError("No file found at location " + object_) # Object is file with open(object_, "r") as file: return cast(dict, json.load(file)) raise TypeError( "Argument is not of the supported types. Supported types are " "url or file path as string or schema as dict." )
def is_weaviate_object_url(url: str) -> bool:
    """
    Tell whether the input looks like a Weaviate 'beacon', e.g.
    'weaviate://localhost/ClassName/28f3f61b-b524-45e0-9bbe-2c1550bf73d2'

    The part after 'weaviate://' must consist of a host ('localhost' or a valid
    domain), optionally one more path segment, and finally a valid UUID.

    Parameters
    ----------
    url : str
        The URL to be validated.

    Returns
    -------
    bool
        True if the 'url' is a Weaviate object URL.
        False otherwise (including when 'url' is not a str).
    """

    scheme = "weaviate://"
    if not isinstance(url, str) or not url.startswith(scheme):
        return False

    parts = url[len(scheme) :].split("/")
    if len(parts) not in (2, 3):
        return False

    host = parts[0]
    if host != "localhost" and not validators.domain(host):
        return False

    try:
        uuid_lib.UUID(parts[-1])
    except ValueError:
        return False
    return True
def is_object_url(url: str) -> bool:
    """
    Tell whether an url such as
    'http://localhost:8080/v1/objects/1c9cd584-88fe-5010-83d0-017cb3fcb446'
    or
    '/v1/objects/1c9cd584-88fe-5010-83d0-017cb3fcb446'
    references an object. Only the path format and the UUID are validated,
    not the host or the protocol.

    Parameters
    ----------
    url : str
        The URL to be validated.

    Returns
    -------
    bool
        True if the 'url' is a valid path to an object.
        False otherwise.
    """

    around_v1 = url.split("/v1/")
    if len(around_v1) != 2:
        return False

    # Expect 'objects/<uuid>' or 'objects/<one segment>/<uuid>' after '/v1/'.
    segments = around_v1[1].split("/")
    if len(segments) not in (2, 3) or segments[0] != "objects":
        return False

    try:
        uuid_lib.UUID(segments[-1])
    except ValueError:
        return False
    return True
def get_valid_uuid(uuid: Union[str, uuid_lib.UUID]) -> str:
    """
    Validate the input and return the UUID it contains as a string.

    Parameters
    ----------
    uuid : str or uuid.UUID
        The UUID to be validated and extracted.
        Should be in the form of an UUID or in form of an URL (weaviate 'beacon' or 'href').
        E.g.
        'http://localhost:8080/v1/objects/fc7eb129-f138-457f-b727-1b29db191a67'
        or
        'weaviate://localhost/28f3f61b-b524-45e0-9bbe-2c1550bf73d2'
        or
        'fc7eb129-f138-457f-b727-1b29db191a67'

    Returns
    -------
    str
        The extracted UUID.

    Raises
    ------
    TypeError
        If 'uuid' is neither a str nor a uuid.UUID.
    ValueError
        If 'uuid' is not valid or cannot be extracted.
    """

    if isinstance(uuid, uuid_lib.UUID):
        return str(uuid)
    if not isinstance(uuid, str):
        raise TypeError("'uuid' must be of type str or uuid.UUID, but was: " + str(type(uuid)))

    candidate = uuid
    if is_weaviate_object_url(uuid) or is_object_url(uuid):
        # For beacons and hrefs the UUID is the last path segment.
        candidate = uuid.rsplit("/", 1)[-1]

    try:
        return str(uuid_lib.UUID(candidate))
    except ValueError:
        raise ValueError("Not valid 'uuid' or 'uuid' can not be extracted from value") from None
def get_vector(vector: Sequence) -> List[float]:
    """
    Get weaviate compatible format of the embedding vector.

    Parameters
    ----------
    vector: Sequence
        The embedding of an object. Used only for class objects that do not have a
        vectorization module. Supported types are `list`, `numpy.ndarray`, `torch.Tensor`,
        `tf.Tensor`, `pd.Series` and `pl.Series`.

    Returns
    -------
    list
        The embedding as a list. A `list` input is returned as is (same object).

    Raises
    ------
    TypeError
        If 'vector' is not of a supported type.
    """

    def _ensure_list(converted: Any) -> list:
        # squeeze() collapses a one-element vector to a zero-dimensional value,
        # for which tolist() yields a bare scalar instead of a list. Wrap it so
        # the function always returns a list.
        return converted if isinstance(converted, list) else [converted]

    if isinstance(vector, list):
        # if vector is already a list
        return vector
    try:
        # if vector is numpy.ndarray or torch.Tensor
        return _ensure_list(vector.squeeze().tolist())  # type: ignore
    except AttributeError:
        pass
    try:
        # if vector is tf.Tensor or torch.Tensor
        return _ensure_list(vector.numpy().squeeze().tolist())  # type: ignore
    except AttributeError:
        pass
    try:
        # if vector is pd.Series or pl.Series
        return vector.to_list()  # type: ignore
    except AttributeError:
        pass

    raise TypeError(
        "The type of the 'vector' argument is not supported!\n"
        "Supported types are `list`, 'numpy.ndarray`, `torch.Tensor`, `tf.Tensor`, `pd.Series`, and `pl.Series`"
    ) from None
def _get_vector_v4(vector: Sequence) -> List[float]:
    """
    Convert `vector` with `get_vector`, reporting an unsupported input as a
    `WeaviateInvalidInputError` (chained to the original `TypeError`).
    """
    try:
        as_list = get_vector(vector)
    except TypeError as e:
        raise WeaviateInvalidInputError(
            f"The vector you supplied was malformatted! Vector: {vector}"
        ) from e
    return as_list
def get_domain_from_weaviate_url(url: str) -> str:
    """
    Extract the domain part of a weaviate URL.

    The first 11 characters (the length of 'weaviate://') are dropped without
    being checked, and everything up to the next '/' is returned.

    Parameters
    ----------
    url : str
        The weaviate URL.
        Of this form: 'weaviate://localhost/objects/28f3f61b-b524-45e0-9bbe-2c1550bf73d2'

    Returns
    -------
    str
        The domain.
    """

    without_scheme = url[11:]
    domain, _, _ = without_scheme.partition("/")
    return domain
def _is_sub_schema(sub_schema: dict, schema: dict) -> bool: """ Check for a subset in a schema. Parameters ---------- sub_schema : dict The smaller schema that should be contained in the 'schema'. schema : dict The schema for which to check if 'sub_schema' is a part of. Must have the 'classes' key. Returns ------- bool True is 'sub_schema' is a subset of the 'schema'. False otherwise. """ schema_classes = schema.get("classes", []) if "classes" in sub_schema: sub_schema_classes = sub_schema["classes"] else: sub_schema_classes = [sub_schema] return _compare_class_sets(sub_schema_classes, schema_classes) def _compare_class_sets(sub_set: list, set_: list) -> bool: """ Check for a subset in a set of classes. Parameters ---------- sub_set : list The smaller set that should be contained in the 'set'. set_ : list The set for which to check if 'sub_set' is a part of. Returns ------- bool True is 'sub_set' is a subset of the 'set'. False otherwise. """ for sub_set_class in sub_set: found = False for set_class in set_: if "class" not in sub_set_class: raise SchemaValidationError( "The sub schema class/es MUST have a 'class' keyword each!" ) if _capitalize_first_letter(sub_set_class["class"]) == _capitalize_first_letter( set_class["class"] ): if _compare_properties(sub_set_class["properties"], set_class["properties"]): found = True break if not found: return False return True def _compare_properties(sub_set: list, set_: list) -> bool: """ Check for a subset in a set of properties. Parameters ---------- sub_set : list The smaller set that should be contained in the 'set'. set_ : list The set for which to check if 'sub_set' is a part of. Returns ------- bool True is 'sub_set' is a subset of the 'set'. False otherwise. """ for sub_set_property in sub_set: found = False for set_property in set_: if sub_set_property["name"] == set_property["name"]: found = True break if not found: return False return True
def generate_uuid5(identifier: Any, namespace: Any = "") -> str:
    """
    Generate an UUIDv5 deterministically from an identifier and a namespace,
    so the same inputs always give the same UUID.

    The UUID is derived, within the DNS namespace, from the string form of
    `namespace` directly followed by the string form of `identifier`.

    Parameters
    ----------
    identifier : Any
        The identifier/object that should be used as basis for the UUID.
    namespace : Any, optional
        Allows to namespace the identifier, by default ""

    Returns
    -------
    str
        The UUID as a string.
    """

    name = f"{str(namespace)}{str(identifier)}"
    generated = uuid_lib.uuid5(uuid_lib.NAMESPACE_DNS, name)
    return str(generated)
def _capitalize_first_letter(string: str) -> str: """ Capitalize only the first letter of the `string`. Parameters ---------- string : str The string to be capitalized. Returns ------- str The capitalized string. """ if len(string) == 1: return string.capitalize() return string[0].capitalize() + string[1:]
def check_batch_result(
    results: Optional[List[Dict[str, Any]]],
) -> None:
    """
    Print the errors found in the results of a batch.

    For every entry whose "result" contains "errors" with an "error" key, the
    "errors" value is printed. Nothing is returned or raised for such entries.

    Parameters
    ----------
    results : list of dict or None
        The Weaviate batch creation return value. None is accepted and ignored.
    """

    if results is None:
        return

    for entry in results:
        if "result" not in entry:
            continue
        outcome = entry["result"]
        if "errors" not in outcome:
            continue
        errors = outcome["errors"]
        if "error" in errors:
            print(errors)
def _check_positive_num( value: Any, arg_name: str, data_type: type, include_zero: bool = False ) -> None: """ Check if the `value` of the `arg_name` is a positive number. Parameters ---------- value : Union[int, float] The value to check. arg_name : str The name of the variable from the original function call. Used for error message. data_type : type The data type to check for. include_zero : bool Wether zero counts as positive or not. By default False. Raises ------ TypeError If the `value` is not of type `data_type`. ValueError If the `value` has a non positive value. """ if not isinstance(value, data_type) or isinstance(value, bool): raise TypeError(f"'{arg_name}' must be of type {data_type}.") if include_zero: if value < 0: # type: ignore raise ValueError(f"'{arg_name}' must be positive, i.e. greater or equal to zero (>=0).") else: if value <= 0: # type: ignore raise ValueError(f"'{arg_name}' must be positive, i.e. greater that zero (>0).")
def is_weaviate_domain(url: str) -> bool:
    """
    Tell whether `url` contains one of the known Weaviate domain names.

    The check is a case-insensitive substring test.
    """
    lowered = url.lower()
    return any(
        domain in lowered for domain in ("weaviate.io", "semi.technology", "weaviate.cloud")
    )
def strip_newlines(s: str) -> str:
    """Return `s` with every newline character replaced by a single space."""
    return " ".join(s.split("\n"))
def _sanitize_str(value: str) -> str:
    """
    Prepare a string for use in a GraphQL query.

    Newlines are replaced by spaces, every double quote that is not directly
    preceded by a backslash is escaped, and the result is wrapped in double
    quotes.

    Parameters
    ----------
    value : str
        The value to be converted.

    Returns
    -------
    str
        The sanitized string.
    """
    single_line = strip_newlines(value)
    # The negative lookbehind skips double quotes that are already escaped.
    escaped = re.sub(r'(?<!\\)"', '\\"', single_line)
    return f'"{escaped}"'
def parse_version_string(ver_str: str) -> tuple:
    """
    Parse a version string into a (major, minor) tuple.

    Parameters
    ----------
    ver_str : str
        The version string to parse. (e.g. "v1.18.2" or "1.18.0")
        A string without any "." is treated as a major version with minor 0.

    Returns
    -------
    tuple :
        The parsed version as a tuple with len(2). (e.g. (1, 18))
        Note: Ignores the patch version.

    Raises
    ------
    ValueError
        If no version can be parsed from the start of the string.
    """
    if "." not in ver_str:
        ver_str = ver_str + ".0"

    parsed = re.match(r"v?(\d+)\.(\d+)", ver_str)
    if parsed is None:
        raise ValueError(
            f"Unable to parse a version from the input string: {ver_str}. Is it in the format '(v)x.y.z' (e.g. 'v1.18.2' or '1.18.0')?"
        )
    return tuple(int(part) for part in parsed.groups())
class _ServerVersion: def __init__(self, major: int, minor: int, patch: int) -> None: self.major = major self.minor = minor self.patch = patch def __eq__(self, other: object) -> bool: if not isinstance(other, _ServerVersion): return NotImplemented return self.major == other.major and self.minor == other.minor and self.patch == other.patch def __neq__(self, other: object) -> bool: return not self.__eq__(other) def __gt__(self, other: "_ServerVersion") -> bool: if self.major > other.major: return True elif self.major == other.major: if self.minor > other.minor: return True elif self.minor == other.minor: if self.patch > other.patch: return True return False def __lt__(self, other: "_ServerVersion") -> bool: return not self.__gt__(other) and not self.__eq__(other) def __ge__(self, other: "_ServerVersion") -> bool: return self.__gt__(other) or self.__eq__(other) def __le__(self, other: "_ServerVersion") -> bool: return self.__lt__(other) or self.__eq__(other) def __repr__(self) -> str: return f"{self.major}.{self.minor}.{self.patch}" def __str__(self) -> str: return f"{self.major}.{self.minor}.{self.patch}" def is_at_least(self, major: int, minor: int, patch: int) -> bool: return self >= _ServerVersion(major, minor, patch) def is_lower_than(self, major: int, minor: int, patch: int) -> bool: return self < _ServerVersion(major, minor, patch) @classmethod def from_string(cls, version: str) -> "_ServerVersion": initial = version if version == "": version = "0" if version.count(".") == 0: version = version + ".0" if version.count(".") == 1: version = version + ".0" pattern = r"v?(\d+)\.(\d+)\.(\d+)" match = re.match(pattern, version) if match: ver_tup = tuple(map(int, match.groups())) return cls(major=ver_tup[0], minor=ver_tup[1], patch=ver_tup[2]) else: raise ValueError( f"Unable to parse a version from the input string: {initial}. Is it in the format '(v)x.y.z' (e.g. 'v1.18.2' or '1.18.0')?" )
def is_weaviate_too_old(current_version_str: str) -> bool:
    """
    Tell whether the user should be gently nudged to upgrade their Weaviate server version.

    Only the major and minor parts of the versions are compared.

    Parameters
    ----------
    current_version_str : str
        The version of the Weaviate server that the client is connected to. (e.g. "v1.18.2" or "1.18.0")

    Returns
    -------
    bool :
        True if the server version is lower than `MINIMUM_NO_WARNING_VERSION`,
        i.e. the user should be nudged to upgrade.
    """

    server_version = parse_version_string(current_version_str)
    oldest_unwarned_version = parse_version_string(MINIMUM_NO_WARNING_VERSION)
    return server_version < oldest_unwarned_version
def is_weaviate_client_too_old(current_version_str: str, latest_version_str: str) -> bool:
    """
    Tell whether the user should be gently nudged to upgrade their Weaviate client version.

    The client counts as too old when its (major, minor) version is lower than
    the latest version with `MAXIMUM_MINOR_VERSION_DELTA` subtracted from the
    minor part (but never going below minor 0).

    Parameters
    ----------
    current_version_str : str
        The version of the Weaviate client that is being used (e.g. "v1.18.2" or "1.18.0")
    latest_version_str : str
        The latest version of the Weaviate client to compare against (e.g. "v1.18.2" or "1.18.0")

    Returns
    -------
    bool :
        True if the user should be nudged to upgrade.
        False if the user is using a valid version or if the version could not be parsed.
    """

    try:
        installed = parse_version_string(current_version_str)
        latest_major, latest_minor = parse_version_string(latest_version_str)
    except ValueError:
        # An unparsable version never triggers the nudge.
        return False

    oldest_accepted_minor = max(latest_minor - MAXIMUM_MINOR_VERSION_DELTA, 0)
    return installed < (latest_major, oldest_accepted_minor)
def _get_valid_timeout_config( timeout_config: Union[Tuple[NUMBER, NUMBER], NUMBER, None] ) -> Tuple[NUMBER, NUMBER]: """ Validate and return TimeOut configuration. Parameters ---------- timeout_config : tuple(NUMBERS, NUMBERS) or NUMBERS or None, optional Set the timeout configuration for all requests to the Weaviate server. It can be a number or, a tuple of two numbers: (connect timeout, read timeout). If only one number is passed then both connect and read timeout will be set to that value. Raises ------ TypeError If arguments are of a wrong data type. ValueError If 'timeout_config' is not a tuple of 2. ValueError If 'timeout_config' is/contains negative number/s. """ def check_number(num: Union[NUMBER, Tuple[NUMBER, NUMBER], None]) -> bool: return isinstance(num, float) or isinstance(num, int) if (isinstance(timeout_config, float) or isinstance(timeout_config, int)) and not isinstance( timeout_config, bool ): assert timeout_config is not None if timeout_config <= 0.0: raise ValueError("'timeout_config' cannot be non-positive number/s!") return timeout_config, timeout_config if not isinstance(timeout_config, tuple): raise TypeError("'timeout_config' should be a (or tuple of) positive number/s!") if len(timeout_config) != 2: raise ValueError("'timeout_config' must be of length 2!") if not (check_number(timeout_config[0]) and check_number(timeout_config[1])) or ( isinstance(timeout_config[0], bool) and isinstance(timeout_config[1], bool) ): raise TypeError("'timeout_config' must be tuple of numbers") if timeout_config[0] <= 0.0 or timeout_config[1] <= 0.0: raise ValueError("'timeout_config' cannot be non-positive number/s!") return timeout_config def _type_request_response(json_response: Any) -> Optional[Dict[str, Any]]: if json_response is None: return None assert isinstance(json_response, dict) return json_response def _to_beacons(uuids: UUIDS, to_class: str = "") -> List[Dict[str, str]]: if isinstance(uuids, uuid_lib.UUID) or isinstance( uuids, str ): # replace 
with isinstance(uuids, UUID) in 3.10 uuids = [uuids] if len(to_class) > 0: to_class = to_class + "/" return [{"beacon": f"weaviate://localhost/{to_class}{uuid_to}"} for uuid_to in uuids] def _decode_json_response_dict( response: Union[httpx.Response, requests.Response], location: str ) -> Optional[Dict[str, Any]]: if response is None: return None if 200 <= response.status_code < 300: try: json_response = cast(Dict[str, Any], response.json()) return json_response except JSONDecodeError: raise ResponseCannotBeDecodedError(location, response) raise UnexpectedStatusCodeError(location, response) def _decode_json_response_list( response: Union[httpx.Response, requests.Response], location: str ) -> Optional[List[Dict[str, Any]]]: if response is None: return None if 200 <= response.status_code < 300: try: json_response = response.json() return cast(list, json_response) except JSONDecodeError: raise ResponseCannotBeDecodedError(location, response) raise UnexpectedStatusCodeError(location, response) def _datetime_to_string(value: TIME) -> str: if value.tzinfo is None: _Warnings.datetime_insertion_with_no_specified_timezone(value) value = value.replace(tzinfo=datetime.timezone.utc) return value.isoformat(sep="T", timespec="microseconds") def _datetime_from_weaviate_str(string: str) -> datetime.datetime: try: return datetime.datetime.strptime( "".join(string.rsplit(":", 1) if string[-1] != "Z" else string), "%Y-%m-%dT%H:%M:%S.%f%z", ) except ValueError: # if the string does not have microseconds return datetime.datetime.strptime( "".join(string.rsplit(":", 1) if string[-1] != "Z" else string), "%Y-%m-%dT%H:%M:%S%z", )