"""
Helper functions!
"""
import base64
import datetime
import io
import json
import os
import re
from enum import Enum, EnumMeta
from pathlib import Path
from typing import Union, Sequence, Any, Optional, List, Dict, Generator, Tuple, cast
import requests
import httpx
import uuid as uuid_lib
import validators
from requests.exceptions import JSONDecodeError
from weaviate.exceptions import (
SchemaValidationError,
UnexpectedStatusCodeError,
ResponseCannotBeDecodedError,
WeaviateInvalidInputError,
)
from weaviate.warnings import _Warnings
from weaviate.types import NUMBER, UUIDS, TIME
# URL of the PyPI JSON metadata document for the `weaviate-client` package.
PYPI_PACKAGE_URL = "https://pypi.org/pypi/weaviate-client/json"
# Maximum number of minor versions the client may lag behind the latest release
# before `is_weaviate_client_too_old` reports that an upgrade warning is due.
MAXIMUM_MINOR_VERSION_DELTA = 3
# Oldest Weaviate server version that does not trigger an upgrade warning
# (compared against the server version in `is_weaviate_too_old`).
MINIMUM_NO_WARNING_VERSION = (
    "v1.16.0"
)
# Number of bytes read per chunk when `file_encoder_b64` encodes a file (~64 KiB).
# 65535 is a multiple of 3, so every full chunk encodes to base64 without padding
# and the encoded chunks can simply be concatenated. Keep it a multiple of 3.
BYTES_PER_CHUNK = 65535
# MetaEnum and BaseEnum are required to support `in` statements:
# 'ALL' in ConsistencyLevel == True
# 12345 in ConsistencyLevel == False
[docs]
class BaseEnum(Enum, metaclass=MetaEnum):
    """
    Common base class for enums that use the `MetaEnum` metaclass.

    NOTE(review): `MetaEnum` is not defined in the visible part of this file. According
    to the comment preceding this class it exists so that membership tests such as
    `'ALL' in ConsistencyLevel` work on the enum class - confirm against its definition.
    """

    pass
[docs]
def image_encoder_b64(image_or_image_path: Union[str, Path, io.BufferedReader]) -> str:
    """
    Encode an image in a Weaviate understandable format from a binary read file or by
    providing the image path.

    If you pass an io.BufferedReader object, it is read to the end but not closed.

    Parameters
    ----------
    image_or_image_path : str, pathlib.Path, io.BufferedReader
        The binary read file or the path to the file.

    Returns
    -------
    str
        The base64 encoded image.

    Raises
    ------
    ValueError
        If the argument is a path that does not point to an existing file.
    TypeError
        If the argument is of a wrong data type.
    """
    if isinstance(image_or_image_path, (str, Path)):
        # Paths are accepted both as plain strings and as pathlib.Path objects,
        # mirroring what `file_encoder_b64` supports.
        if not os.path.isfile(image_or_image_path):
            raise ValueError("No file found at location " + str(image_or_image_path))
        with open(image_or_image_path, "rb") as file:
            content = file.read()
    elif isinstance(image_or_image_path, io.BufferedReader):
        content = image_or_image_path.read()
    else:
        raise TypeError(
            '"image_or_image_path" should be a image path or a binary read file'
            " (io.BufferedReader)"
        )
    return base64.b64encode(content).decode("utf-8")
[docs]
def file_encoder_b64(file_or_file_path: Union[str, Path, io.BufferedReader]) -> str:
    """
    Encode a file in a Weaviate understandable format from an io.BufferedReader binary read
    file or by providing the file path as either a string or a pathlib.Path object.

    If you pass an io.BufferedReader object, it is your responsibility to close it after
    encoding. Files opened by this function are always closed, even if reading fails.

    Parameters
    ----------
    file_or_file_path : str, pathlib.Path, io.BufferedReader
        The binary read file or the path to the file.

    Returns
    -------
    str
        The base64 encoded file.

    Raises
    ------
    ValueError
        If the argument is a path that does not point to an existing file.
    TypeError
        If the argument is of a wrong data type.
    """

    def _encode(buffer: io.BufferedReader, chunked: bool) -> str:
        if not chunked:
            return base64.b64encode(buffer.read()).decode("utf-8")
        # Encoding chunk by chunk is only valid because BYTES_PER_CHUNK is a multiple of 3:
        # every full chunk then encodes without padding, so the pieces can be joined.
        pieces = []
        while True:
            data = buffer.read(BYTES_PER_CHUNK)
            if not data:
                break
            pieces.append(base64.b64encode(data).decode("utf-8"))
        return "".join(pieces)

    if isinstance(file_or_file_path, (str, Path)):
        if not os.path.isfile(file_or_file_path):
            raise ValueError("No file found at location " + str(file_or_file_path))
        # Small files are encoded in one go, larger ones are streamed in chunks.
        chunked = os.path.getsize(file_or_file_path) > BYTES_PER_CHUNK
        # The context manager guarantees the handle is released on errors as well.
        with open(file_or_file_path, "rb") as file:
            return _encode(file, chunked)
    if isinstance(file_or_file_path, io.BufferedReader):
        # The size of a caller supplied reader is unknown, so always stream it.
        return _encode(file_or_file_path, True)
    raise TypeError(
        '"file_or_file_path" should be a file path or a binary read file' " (io.BufferedReader)"
    )
[docs]
def image_decoder_b64(encoded_image: str) -> bytes:
    """
    Turn a Weaviate (base64) encoded image back into its raw bytes.

    Parameters
    ----------
    encoded_image : str
        The base64 encoded image.

    Returns
    -------
    bytes
        The decoded image as a binary string.
    """
    encoded_bytes = encoded_image.encode("utf-8")
    return base64.b64decode(encoded_bytes)
[docs]
def file_decoder_b64(encoded_file: str) -> bytes:
    """
    Turn a Weaviate (base64) encoded file back into its raw bytes.

    Parameters
    ----------
    encoded_file : str
        The base64 encoded file.

    Returns
    -------
    bytes
        The decoded file as a binary string. Feed it to the file handling code of your
        choice to convert it into a specific file type, e.g. PIL for images.
    """
    encoded_bytes = encoded_file.encode("utf-8")
    return base64.b64decode(encoded_bytes)
[docs]
def generate_local_beacon(
    to_uuid: Union[str, uuid_lib.UUID],
    class_name: Optional[str] = None,
) -> dict:
    """
    Build a local beacon for the given uuid and, optionally, class name.

    Parameters
    ----------
    to_uuid : str or uuid.UUID
        The UUID for which to create a local beacon.
    class_name : Optional[str], optional
        The class name of the `to_uuid` object. Used with Weaviate >= 1.14.0.
        For Weaviate < 1.14.0 use None value.

    Returns
    -------
    dict
        The local beacon, i.e. a dict with the single key "beacon".

    Raises
    ------
    TypeError
        If 'to_uuid' is neither a str nor a uuid.UUID.
    ValueError
        If 'to_uuid' is a str that is not a valid UUID.
    """
    if isinstance(to_uuid, uuid_lib.UUID):
        normalized = str(to_uuid)
    elif isinstance(to_uuid, str):
        try:
            # Round-trip through uuid.UUID to validate and normalize the notation.
            normalized = str(uuid_lib.UUID(to_uuid))
        except ValueError:
            raise ValueError("Uuid does not have the proper form") from None
    else:
        raise TypeError("Expected to_object_uuid of type str or uuid.UUID")

    if class_name is not None:
        return {
            "beacon": f"weaviate://localhost/{_capitalize_first_letter(class_name)}/{normalized}"
        }
    return {"beacon": f"weaviate://localhost/{normalized}"}
def _get_dict_from_object(object_: Union[str, dict]) -> dict:
"""
Takes an object that should describe a dict
e.g. a schema or an object and tries to retrieve the dict.
Parameters
----------
object_ : str or dict
The object from which to retrieve the dict.
Can be a python dict, or the path to a json file or a url of a json file.
Returns
-------
dict
The object as a dict.
Raises
------
TypeError
If 'object_' is neither a string nor a dict.
ValueError
If no dict can be retrieved from object.
"""
# check if objects files is url
if object_ is None:
raise TypeError("argument is None")
if isinstance(object_, dict):
# Object is already a dict
return object_
if isinstance(object_, str):
if validators.url(object_):
# Object is URL
response = requests.get(object_)
if response.status_code == 200:
return cast(dict, response.json())
raise ValueError("Could not download file " + object_)
if not os.path.isfile(object_):
# Object is neither file nor URL
raise ValueError("No file found at location " + object_)
# Object is file
with open(object_, "r") as file:
return cast(dict, json.load(file))
raise TypeError(
"Argument is not of the supported types. Supported types are "
"url or file path as string or schema as dict."
)
[docs]
def is_weaviate_object_url(url: str) -> bool:
    """
    Tell whether the input looks like a regular Weaviate 'beacon', e.g.
    'weaviate://localhost/ClassName/28f3f61b-b524-45e0-9bbe-2c1550bf73d2'

    Parameters
    ----------
    url : str
        The URL to be validated.

    Returns
    -------
    bool
        True if the 'url' is a Weaviate object URL.
        False otherwise.
    """
    scheme = "weaviate://"
    if not (isinstance(url, str) and url.startswith(scheme)):
        return False

    # Expected remainder: <host>/<uuid> or <host>/<ClassName>/<uuid>
    parts = url[len(scheme) :].split("/")
    if not 2 <= len(parts) <= 3:
        return False

    host = parts[0]
    if host != "localhost" and not validators.domain(host):
        return False

    try:
        uuid_lib.UUID(parts[-1])
    except ValueError:
        return False
    return True
[docs]
def is_object_url(url: str) -> bool:
    """
    Tell whether an url such as
    'http://localhost:8080/v1/objects/1c9cd584-88fe-5010-83d0-017cb3fcb446'
    or '/v1/objects/1c9cd584-88fe-5010-83d0-017cb3fcb446' references an object.
    Only the path format and the UUID are validated, not the host or the protocol.

    Parameters
    ----------
    url : str
        The URL to be validated.

    Returns
    -------
    bool
        True if the 'url' is a valid path to an object.
        False otherwise.
    """
    around_v1 = url.split("/v1/")
    if len(around_v1) != 2:
        return False

    # Expected path after '/v1/': objects/<uuid> or objects/<ClassName>/<uuid>
    segments = around_v1[1].split("/")
    if len(segments) not in (2, 3) or segments[0] != "objects":
        return False

    try:
        uuid_lib.UUID(segments[-1])
    except ValueError:
        return False
    return True
[docs]
def get_valid_uuid(uuid: Union[str, uuid_lib.UUID]) -> str:
    """
    Validate the input and extract the UUID from it.

    Parameters
    ----------
    uuid : str or uuid.UUID
        The UUID to be validated and extracted.
        Should be in the form of an UUID or in form of an URL (weaviate 'beacon' or 'href').
        E.g.
        'http://localhost:8080/v1/objects/fc7eb129-f138-457f-b727-1b29db191a67'
        or
        'weaviate://localhost/28f3f61b-b524-45e0-9bbe-2c1550bf73d2'
        or
        'fc7eb129-f138-457f-b727-1b29db191a67'

    Returns
    -------
    str
        The extracted UUID.

    Raises
    ------
    TypeError
        If 'uuid' is neither a str nor a uuid.UUID.
    ValueError
        If 'uuid' is not valid or cannot be extracted.
    """
    if isinstance(uuid, uuid_lib.UUID):
        return str(uuid)
    if not isinstance(uuid, str):
        raise TypeError("'uuid' must be of type str or uuid.UUID, but was: " + str(type(uuid)))

    candidate = uuid
    if is_weaviate_object_url(uuid) or is_object_url(uuid):
        # For beacons and hrefs the UUID is the last path segment.
        candidate = uuid.rsplit("/", 1)[-1]

    try:
        return str(uuid_lib.UUID(candidate))
    except ValueError:
        raise ValueError("Not valid 'uuid' or 'uuid' can not be extracted from value") from None
[docs]
def get_vector(vector: Sequence) -> List[float]:
    """
    Get weaviate compatible format of the embedding vector.

    Parameters
    ----------
    vector: Sequence
        The embedding of an object. Used only for class objects that do not have a
        vectorization module. Supported types are `list`, `numpy.ndarray`, `torch.Tensor`,
        `tf.Tensor`, `pd.Series` and `pl.Series`.

    Returns
    -------
    list
        The embedding as a list.

    Raises
    ------
    TypeError
        If 'vector' is not of a supported type.
    """

    def _ensure_list(converted: Any) -> Any:
        # Squeezing a one-element array/tensor yields a 0-d object whose `tolist()` is a
        # bare scalar; wrap it so that callers always receive a list.
        return converted if isinstance(converted, list) else [converted]

    if isinstance(vector, list):
        # if vector is already a list
        return vector
    try:
        # if vector is numpy.ndarray or torch.Tensor
        return _ensure_list(vector.squeeze().tolist())  # type: ignore
    except AttributeError:
        pass
    try:
        # if vector is tf.Tensor or torch.Tensor
        return _ensure_list(vector.numpy().squeeze().tolist())  # type: ignore
    except AttributeError:
        pass
    try:
        # if vector is pd.Series or pl.Series
        return vector.to_list()  # type: ignore
    except AttributeError:
        pass
    raise TypeError(
        "The type of the 'vector' argument is not supported!\n"
        "Supported types are `list`, 'numpy.ndarray`, `torch.Tensor`, `tf.Tensor`, `pd.Series`, and `pl.Series`"
    ) from None
def _get_vector_v4(vector: Sequence) -> List[float]:
    """
    Convert `vector` with `get_vector`, reporting unsupported input as a
    WeaviateInvalidInputError instead of a TypeError.
    """
    try:
        converted = get_vector(vector)
    except TypeError as err:
        raise WeaviateInvalidInputError(
            f"The vector you supplied was malformatted! Vector: {vector}"
        ) from err
    return converted
[docs]
def get_domain_from_weaviate_url(url: str) -> str:
    """
    Extract the domain part of a weaviate URL.

    Parameters
    ----------
    url : str
        The weaviate URL.
        Of this form: 'weaviate://localhost/objects/28f3f61b-b524-45e0-9bbe-2c1550bf73d2'

    Returns
    -------
    str
        The domain.
    """
    # Drop the 'weaviate://' prefix without validating it, then keep everything up to
    # the first '/'.
    without_scheme = url[len("weaviate://") :]
    domain, _, _ = without_scheme.partition("/")
    return domain
def _is_sub_schema(sub_schema: dict, schema: dict) -> bool:
    """
    Check whether one schema is contained in another.

    Parameters
    ----------
    sub_schema : dict
        The smaller schema that should be contained in the 'schema'. Either a full
        schema with a 'classes' key or a single class definition.
    schema : dict
        The schema for which to check if 'sub_schema' is a part of. A missing 'classes'
        key is treated as an empty list of classes.

    Returns
    -------
    bool
        True if 'sub_schema' is a subset of the 'schema'.
        False otherwise.
    """
    available_classes = schema.get("classes", [])
    # A sub schema without a 'classes' key is interpreted as one single class.
    wanted_classes = sub_schema["classes"] if "classes" in sub_schema else [sub_schema]
    return _compare_class_sets(wanted_classes, available_classes)
def _compare_class_sets(sub_set: list, set_: list) -> bool:
"""
Check for a subset in a set of classes.
Parameters
----------
sub_set : list
The smaller set that should be contained in the 'set'.
set_ : list
The set for which to check if 'sub_set' is a part of.
Returns
-------
bool
True is 'sub_set' is a subset of the 'set'.
False otherwise.
"""
for sub_set_class in sub_set:
found = False
for set_class in set_:
if "class" not in sub_set_class:
raise SchemaValidationError(
"The sub schema class/es MUST have a 'class' keyword each!"
)
if _capitalize_first_letter(sub_set_class["class"]) == _capitalize_first_letter(
set_class["class"]
):
if _compare_properties(sub_set_class["properties"], set_class["properties"]):
found = True
break
if not found:
return False
return True
def _compare_properties(sub_set: list, set_: list) -> bool:
"""
Check for a subset in a set of properties.
Parameters
----------
sub_set : list
The smaller set that should be contained in the 'set'.
set_ : list
The set for which to check if 'sub_set' is a part of.
Returns
-------
bool
True is 'sub_set' is a subset of the 'set'.
False otherwise.
"""
for sub_set_property in sub_set:
found = False
for set_property in set_:
if sub_set_property["name"] == set_property["name"]:
found = True
break
if not found:
return False
return True
[docs]
def generate_uuid5(identifier: Any, namespace: Any = "") -> str:
    """
    Generate an UUIDv5. The result is deterministic, so the same identifier and
    namespace always produce the same UUID.

    Parameters
    ----------
    identifier : Any
        The identifier/object that should be used as basis for the UUID.
    namespace : Any, optional
        Allows to namespace the identifier, by default ""

    Returns
    -------
    str
        The UUID as a string.
    """
    # The string forms of namespace and identifier are concatenated (namespace first)
    # and hashed within the DNS namespace.
    name = f"{str(namespace)}{str(identifier)}"
    generated = uuid_lib.uuid5(uuid_lib.NAMESPACE_DNS, name)
    return str(generated)
def _capitalize_first_letter(string: str) -> str:
"""
Capitalize only the first letter of the `string`.
Parameters
----------
string : str
The string to be capitalized.
Returns
-------
str
The capitalized string.
"""
if len(string) == 1:
return string.capitalize()
return string[0].capitalize() + string[1:]
[docs]
def check_batch_result(
    results: Optional[List[Dict[str, Any]]],
) -> None:
    """
    Check batch results for errors and print the errors that are found.

    Parameters
    ----------
    results : list of dict or None
        The Weaviate batch creation return value. Nothing happens for None.
    """
    if results is None:
        return

    for entry in results:
        if "result" not in entry:
            continue
        outcome = entry["result"]
        if "errors" not in outcome:
            continue
        errors = outcome["errors"]
        if "error" in errors:
            print(errors)
def _check_positive_num(
value: Any, arg_name: str, data_type: type, include_zero: bool = False
) -> None:
"""
Check if the `value` of the `arg_name` is a positive number.
Parameters
----------
value : Union[int, float]
The value to check.
arg_name : str
The name of the variable from the original function call. Used for error message.
data_type : type
The data type to check for.
include_zero : bool
Wether zero counts as positive or not. By default False.
Raises
------
TypeError
If the `value` is not of type `data_type`.
ValueError
If the `value` has a non positive value.
"""
if not isinstance(value, data_type) or isinstance(value, bool):
raise TypeError(f"'{arg_name}' must be of type {data_type}.")
if include_zero:
if value < 0: # type: ignore
raise ValueError(f"'{arg_name}' must be positive, i.e. greater or equal to zero (>=0).")
else:
if value <= 0: # type: ignore
raise ValueError(f"'{arg_name}' must be positive, i.e. greater that zero (>0).")
[docs]
def is_weaviate_domain(url: str) -> bool:
    """
    Tell whether `url` contains one of the known Weaviate domains
    ('weaviate.io', 'semi.technology' or 'weaviate.cloud'), ignoring case.
    """
    lowered = url.lower()
    return any(
        domain in lowered for domain in ("weaviate.io", "semi.technology", "weaviate.cloud")
    )
[docs]
def strip_newlines(s: str) -> str:
    """Return `s` with every newline character replaced by a single space."""
    return " ".join(s.split("\n"))
def _sanitize_str(value: str) -> str:
"""
Ensures string is sanitized for GraphQL.
Parameters
----------
value : str
The value to be converted.
Returns
-------
str
The sanitized string.
"""
value = strip_newlines(value)
value = re.sub(r'(?<!\\)"', '\\"', value) # only replaces unescaped double quotes
return f'"{value}"'
[docs]
def parse_version_string(ver_str: str) -> tuple:
    """
    Parse a version string into a (major, minor) tuple.

    Parameters
    ----------
    ver_str : str
        The version string to parse. (e.g. "v1.18.2" or "1.18.0")

    Returns
    -------
    tuple :
        The parsed version as a tuple with len(2). (e.g. (1, 18)) Note: Ignores the patch version.

    Raises
    ------
    ValueError
        If no version can be parsed from the beginning of `ver_str`.
    """
    if "." not in ver_str:
        # A bare major version such as "v1" is read as "v1.0".
        ver_str = ver_str + ".0"

    parsed = re.match(r"v?(\d+)\.(\d+)", ver_str)
    if parsed is None:
        raise ValueError(
            f"Unable to parse a version from the input string: {ver_str}. Is it in the format '(v)x.y.z' (e.g. 'v1.18.2' or '1.18.0')?"
        )
    return tuple(int(number) for number in parsed.groups())
class _ServerVersion:
def __init__(self, major: int, minor: int, patch: int) -> None:
self.major = major
self.minor = minor
self.patch = patch
def __eq__(self, other: object) -> bool:
if not isinstance(other, _ServerVersion):
return NotImplemented
return self.major == other.major and self.minor == other.minor and self.patch == other.patch
def __neq__(self, other: object) -> bool:
return not self.__eq__(other)
def __gt__(self, other: "_ServerVersion") -> bool:
if self.major > other.major:
return True
elif self.major == other.major:
if self.minor > other.minor:
return True
elif self.minor == other.minor:
if self.patch > other.patch:
return True
return False
def __lt__(self, other: "_ServerVersion") -> bool:
return not self.__gt__(other) and not self.__eq__(other)
def __ge__(self, other: "_ServerVersion") -> bool:
return self.__gt__(other) or self.__eq__(other)
def __le__(self, other: "_ServerVersion") -> bool:
return self.__lt__(other) or self.__eq__(other)
def __repr__(self) -> str:
return f"{self.major}.{self.minor}.{self.patch}"
def __str__(self) -> str:
return f"{self.major}.{self.minor}.{self.patch}"
def is_at_least(self, major: int, minor: int, patch: int) -> bool:
return self >= _ServerVersion(major, minor, patch)
def is_lower_than(self, major: int, minor: int, patch: int) -> bool:
return self < _ServerVersion(major, minor, patch)
@classmethod
def from_string(cls, version: str) -> "_ServerVersion":
initial = version
if version == "":
version = "0"
if version.count(".") == 0:
version = version + ".0"
if version.count(".") == 1:
version = version + ".0"
pattern = r"v?(\d+)\.(\d+)\.(\d+)"
match = re.match(pattern, version)
if match:
ver_tup = tuple(map(int, match.groups()))
return cls(major=ver_tup[0], minor=ver_tup[1], patch=ver_tup[2])
else:
raise ValueError(
f"Unable to parse a version from the input string: {initial}. Is it in the format '(v)x.y.z' (e.g. 'v1.18.2' or '1.18.0')?"
)
[docs]
def is_weaviate_too_old(current_version_str: str) -> bool:
    """
    Decide whether the user should be gently nudged to upgrade their Weaviate server.

    Parameters
    ----------
    current_version_str : str
        The version of the Weaviate server that the client is connected to. (e.g. "v1.18.2" or "1.18.0")

    Returns
    -------
    bool :
        True if the user should be nudged to upgrade, i.e. if the server's (major, minor)
        version is lower than the one of MINIMUM_NO_WARNING_VERSION.
    """
    return parse_version_string(current_version_str) < parse_version_string(
        MINIMUM_NO_WARNING_VERSION
    )
[docs]
def is_weaviate_client_too_old(current_version_str: str, latest_version_str: str) -> bool:
    """
    Decide whether the user should be gently nudged to upgrade their Weaviate client.

    Parameters
    ----------
    current_version_str : str
        The version of the Weaviate client that is being used (e.g. "v1.18.2" or "1.18.0")
    latest_version_str : str
        The latest version of the Weaviate client to compare against (e.g. "v1.18.2" or "1.18.0")

    Returns
    -------
    bool :
        True if the user should be nudged to upgrade.
        False if the user is using a valid version or if the version could not be parsed.
    """
    try:
        running = parse_version_string(current_version_str)
        newest_major, newest_minor = parse_version_string(latest_version_str)
    except ValueError:
        return False

    # The oldest acceptable version lags the newest one by at most
    # MAXIMUM_MINOR_VERSION_DELTA minor versions (never going below minor 0).
    oldest_accepted = (newest_major, max(newest_minor - MAXIMUM_MINOR_VERSION_DELTA, 0))
    return oldest_accepted > running
def _get_valid_timeout_config(
timeout_config: Union[Tuple[NUMBER, NUMBER], NUMBER, None]
) -> Tuple[NUMBER, NUMBER]:
"""
Validate and return TimeOut configuration.
Parameters
----------
timeout_config : tuple(NUMBERS, NUMBERS) or NUMBERS or None, optional
Set the timeout configuration for all requests to the Weaviate server. It can be a
number or, a tuple of two numbers: (connect timeout, read timeout).
If only one number is passed then both connect and read timeout will be set to
that value.
Raises
------
TypeError
If arguments are of a wrong data type.
ValueError
If 'timeout_config' is not a tuple of 2.
ValueError
If 'timeout_config' is/contains negative number/s.
"""
def check_number(num: Union[NUMBER, Tuple[NUMBER, NUMBER], None]) -> bool:
return isinstance(num, float) or isinstance(num, int)
if (isinstance(timeout_config, float) or isinstance(timeout_config, int)) and not isinstance(
timeout_config, bool
):
assert timeout_config is not None
if timeout_config <= 0.0:
raise ValueError("'timeout_config' cannot be non-positive number/s!")
return timeout_config, timeout_config
if not isinstance(timeout_config, tuple):
raise TypeError("'timeout_config' should be a (or tuple of) positive number/s!")
if len(timeout_config) != 2:
raise ValueError("'timeout_config' must be of length 2!")
if not (check_number(timeout_config[0]) and check_number(timeout_config[1])) or (
isinstance(timeout_config[0], bool) and isinstance(timeout_config[1], bool)
):
raise TypeError("'timeout_config' must be tuple of numbers")
if timeout_config[0] <= 0.0 or timeout_config[1] <= 0.0:
raise ValueError("'timeout_config' cannot be non-positive number/s!")
return timeout_config
def _type_request_response(json_response: Any) -> Optional[Dict[str, Any]]:
if json_response is None:
return None
assert isinstance(json_response, dict)
return json_response
def _to_beacons(uuids: UUIDS, to_class: str = "") -> List[Dict[str, str]]:
if isinstance(uuids, uuid_lib.UUID) or isinstance(
uuids, str
): # replace with isinstance(uuids, UUID) in 3.10
uuids = [uuids]
if len(to_class) > 0:
to_class = to_class + "/"
return [{"beacon": f"weaviate://localhost/{to_class}{uuid_to}"} for uuid_to in uuids]
def _decode_json_response_dict(
response: Union[httpx.Response, requests.Response], location: str
) -> Optional[Dict[str, Any]]:
if response is None:
return None
if 200 <= response.status_code < 300:
try:
json_response = cast(Dict[str, Any], response.json())
return json_response
except JSONDecodeError:
raise ResponseCannotBeDecodedError(location, response)
raise UnexpectedStatusCodeError(location, response)
def _decode_json_response_list(
response: Union[httpx.Response, requests.Response], location: str
) -> Optional[List[Dict[str, Any]]]:
if response is None:
return None
if 200 <= response.status_code < 300:
try:
json_response = response.json()
return cast(list, json_response)
except JSONDecodeError:
raise ResponseCannotBeDecodedError(location, response)
raise UnexpectedStatusCodeError(location, response)
def _datetime_to_string(value: TIME) -> str:
if value.tzinfo is None:
_Warnings.datetime_insertion_with_no_specified_timezone(value)
value = value.replace(tzinfo=datetime.timezone.utc)
return value.isoformat(sep="T", timespec="microseconds")
def _datetime_from_weaviate_str(string: str) -> datetime.datetime:
try:
return datetime.datetime.strptime(
"".join(string.rsplit(":", 1) if string[-1] != "Z" else string),
"%Y-%m-%dT%H:%M:%S.%f%z",
)
except ValueError: # if the string does not have microseconds
return datetime.datetime.strptime(
"".join(string.rsplit(":", 1) if string[-1] != "Z" else string),
"%Y-%m-%dT%H:%M:%S%z",
)