"""
Helper functions!
"""
import base64
import datetime
import io
import os
import re
import uuid as uuid_lib
from pathlib import Path
from typing import Union, Sequence, Any, Optional, List, Dict, Generator, Tuple, cast
import httpx
import validators
from weaviate.exceptions import (
SchemaValidationError,
UnexpectedStatusCodeError,
ResponseCannotBeDecodedError,
WeaviateInvalidInputError,
WeaviateUnsupportedFeatureError,
)
from weaviate.types import NUMBER, UUIDS, TIME
from weaviate.validator import _is_valid, _ExtraTypes
from weaviate.warnings import _Warnings
# URL of the PyPI JSON metadata document for the `weaviate-client` package.
# NOTE(review): presumably fetched to discover the latest client release; the
# code that uses it is not visible here.
PYPI_PACKAGE_URL = "https://pypi.org/pypi/weaviate-client/json"
# The maximum delta between minor versions of Weaviate Client that will not
# trigger an upgrade warning.
MAXIMUM_MINOR_VERSION_DELTA = 3
# The minimum version of Weaviate that will not trigger an upgrade warning.
MINIMUM_NO_WARNING_VERSION = (
    "v1.16.0"
)
# The number of bytes to read per chunk when encoding files (~64 KiB).
# 65535 == 3 * 21845, so a full chunk base64-encodes without any '=' padding.
BYTES_PER_CHUNK = 65535
[docs]
def image_encoder_b64(image_or_image_path: Union[str, io.BufferedReader]) -> str:
    """
    Base64-encode an image so that it can be sent to Weaviate.

    The image can be supplied either as a path on disk or as an already opened
    binary reader. A reader passed in by the caller is read to its end but is
    not closed here.

    Parameters
    ----------
    image_or_image_path : str, io.BufferedReader
        Path to the image file, or a file object opened in binary read mode.

    Returns
    -------
    str
        The base64 representation of the image content.

    Raises
    ------
    ValueError
        If a path is given and no file exists at that location.
    TypeError
        If the argument is neither a str nor an io.BufferedReader.
    """
    if isinstance(image_or_image_path, io.BufferedReader):
        raw_bytes = image_or_image_path.read()
    elif isinstance(image_or_image_path, str):
        if not os.path.isfile(image_or_image_path):
            raise ValueError("No file found at location " + image_or_image_path)
        with open(image_or_image_path, "br") as image_file:
            raw_bytes = image_file.read()
    else:
        raise TypeError(
            '"image_or_image_path" should be a image path or a binary read file'
            " (io.BufferedReader)"
        )

    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")
[docs]
def file_encoder_b64(file_or_file_path: Union[str, Path, io.BufferedReader]) -> str:
    """
    Encode a file in a Weaviate understandable format from an io.BufferedReader binary read file
    or by providing the file path as either a string or a pathlib.Path object.

    If you pass an io.BufferedReader object, it is your responsibility to close it after encoding.

    Parameters
    ----------
    file_or_file_path : str, pathlib.Path, io.BufferedReader
        The binary read file or the path to the file.

    Returns
    -------
    str
        Encoded file.

    Raises
    ------
    ValueError
        If the argument is a path and does not point to an existing file.
    TypeError
        If the argument is of a wrong data type.
    """

    def _encoded_chunks(buffer: io.BufferedReader, chunk_size: int) -> Generator[bytes, Any, Any]:
        # Base64 turns every 3 input bytes into 4 output characters. Separately
        # encoded pieces can therefore only be concatenated when every piece
        # but the last has a length divisible by 3. Bytes beyond the last
        # 3-byte boundary are carried over to the next read so the result is
        # correct for any chunk size and for short reads.
        carry = b""
        while True:
            data = buffer.read(chunk_size)
            if not data:
                break
            data = carry + data
            aligned = len(data) - len(data) % 3
            carry = data[aligned:]
            if aligned:
                yield base64.b64encode(data[:aligned])
        if carry:
            yield base64.b64encode(carry)

    should_close_file = False
    use_buffering = True  # a caller-provided reader has unknown size: always stream it
    file = None
    try:
        if isinstance(file_or_file_path, str):
            if not os.path.isfile(file_or_file_path):
                raise ValueError("No file found at location " + file_or_file_path)
            file = open(file_or_file_path, "br")
            should_close_file = True
            use_buffering = os.path.getsize(file_or_file_path) > BYTES_PER_CHUNK
        elif isinstance(file_or_file_path, Path):
            if not file_or_file_path.is_file():
                raise ValueError("No file found at location " + str(file_or_file_path))
            file = file_or_file_path.open("br")
            should_close_file = True
            use_buffering = file_or_file_path.stat().st_size > BYTES_PER_CHUNK
        elif isinstance(file_or_file_path, io.BufferedReader):
            file = file_or_file_path
        else:
            raise TypeError(
                '"file_or_file_path" should be a file path or a binary read file'
                " (io.BufferedReader)"
            )

        if use_buffering:
            # Join once instead of growing a str with `+=` for every chunk.
            encoded = b"".join(_encoded_chunks(file, BYTES_PER_CHUNK)).decode("utf-8")
        else:
            encoded = base64.b64encode(file.read()).decode("utf-8")
    finally:
        # Only close handles opened here; a caller-provided reader stays open.
        if should_close_file and file is not None:
            file.close()
    return encoded
[docs]
def image_decoder_b64(encoded_image: str) -> bytes:
    """
    Turn a base64 encoded image, as stored by Weaviate, back into raw bytes.

    Parameters
    ----------
    encoded_image : str
        The base64 text of the image.

    Returns
    -------
    bytes
        The decoded image content.
    """
    encoded_bytes = encoded_image.encode("utf-8")
    return base64.b64decode(encoded_bytes)
[docs]
def file_decoder_b64(encoded_file: str) -> bytes:
    """
    Turn a base64 encoded file, as stored by Weaviate, back into raw bytes.

    Parameters
    ----------
    encoded_file : str
        The base64 text of the file.

    Returns
    -------
    bytes
        The decoded file content. Pass it to whichever library understands the
        actual file format, e.g. PIL for images.
    """
    encoded_bytes = encoded_file.encode("utf-8")
    return base64.b64decode(encoded_bytes)
[docs]
def is_weaviate_object_url(url: str) -> bool:
    """
    Tell whether the input looks like a Weaviate 'beacon', e.g.
    'weaviate://localhost/ClassName/28f3f61b-b524-45e0-9bbe-2c1550bf73d2'.

    After the 'weaviate://' scheme the value must consist of a host, an
    optional class name and a UUID. The host must be 'localhost' or pass
    `validators.domain`.

    Parameters
    ----------
    url : str
        The value to check.

    Returns
    -------
    bool
        True if 'url' is a Weaviate object URL, False otherwise (including when
        'url' is not a str).
    """
    if not (isinstance(url, str) and url.startswith("weaviate://")):
        return False

    # Drop the scheme, then expect "<host>/<uuid>" or "<host>/<class>/<uuid>".
    segments = url[len("weaviate://"):].split("/")
    if not 2 <= len(segments) <= 3:
        return False

    host, uuid_candidate = segments[0], segments[-1]
    if host != "localhost" and not validators.domain(host):
        return False

    try:
        uuid_lib.UUID(uuid_candidate)
    except ValueError:
        return False
    return True
[docs]
def is_object_url(url: str) -> bool:
    """
    Tell whether a URL such as
    'http://localhost:8080/v1/objects/1c9cd584-88fe-5010-83d0-017cb3fcb446' or
    '/v1/objects/1c9cd584-88fe-5010-83d0-017cb3fcb446' points at an object.

    Only the part after '/v1/' is inspected: it must start with 'objects', end
    with a valid UUID and have two or three segments. Host and protocol are
    not validated.

    Parameters
    ----------
    url : str
        The URL to check.

    Returns
    -------
    bool
        True if 'url' is a valid path to an object, False otherwise.
    """
    around_v1 = url.split("/v1/")
    if len(around_v1) != 2:
        # '/v1/' must occur exactly once
        return False

    segments = around_v1[1].split("/")
    if len(segments) not in (2, 3) or segments[0] != "objects":
        return False

    try:
        uuid_lib.UUID(segments[-1])
    except ValueError:
        return False
    return True
[docs]
def get_valid_uuid(uuid: Union[str, uuid_lib.UUID]) -> str:
    """
    Validate the input and return the UUID it carries as a string.

    Parameters
    ----------
    uuid : str or uuid.UUID
        Either a UUID, or a URL (weaviate 'beacon' or 'href') ending in a UUID.
        E.g.
        'http://localhost:8080/v1/objects/fc7eb129-f138-457f-b727-1b29db191a67'
        or
        'weaviate://localhost/28f3f61b-b524-45e0-9bbe-2c1550bf73d2'
        or
        'fc7eb129-f138-457f-b727-1b29db191a67'

    Returns
    -------
    str
        The extracted UUID in its canonical string form.

    Raises
    ------
    TypeError
        If 'uuid' is neither a str nor a uuid.UUID.
    ValueError
        If 'uuid' is not valid or cannot be extracted.
    """
    if isinstance(uuid, uuid_lib.UUID):
        return str(uuid)
    if not isinstance(uuid, str):
        raise TypeError(f"'uuid' must be of type str or uuid.UUID, but was: {str(type(uuid))}")

    candidate = uuid
    if is_weaviate_object_url(uuid) or is_object_url(uuid):
        # For both URL flavours the UUID is the last path segment.
        candidate = uuid.rsplit("/", 1)[-1]

    try:
        parsed = uuid_lib.UUID(candidate)
    except ValueError:
        raise ValueError("Not valid 'uuid' or 'uuid' can not be extracted from value") from None
    return str(parsed)
[docs]
def get_vector(vector: Sequence) -> Sequence[float]:
    """
    Get weaviate compatible format of the embedding vector.

    Parameters
    ----------
    vector : Sequence
        The embedding of an object. Used only for class objects that do not have a vectorization
        module. Supported types are `list`, `numpy.ndarray`, `torch.Tensor`, `tf.Tensor`,
        `pd.Series` and `pl.Series`.

    Returns
    -------
    list
        The embedding as a list. A `list` input is returned unchanged (same object).

    Raises
    ------
    TypeError
        If 'vector' is not of a supported type.
    """
    if isinstance(vector, list):
        # already in the target format
        return vector

    # Conversions are attempted in order; an object lacking the required
    # methods raises AttributeError and the next conversion is tried.
    conversions = (
        lambda v: v.squeeze().tolist(),  # numpy.ndarray or torch.Tensor
        lambda v: v.numpy().squeeze().tolist(),  # tf.Tensor or torch.Tensor
        lambda v: v.to_list(),  # pd.Series or pl.Series
    )
    for convert in conversions:
        try:
            converted = convert(vector)
        except AttributeError:
            continue
        # `squeeze()` collapses a one-element input to zero dimensions, for
        # which `tolist()` yields a bare scalar. Wrap it so that the documented
        # list return type also holds for single-element vectors.
        return converted if isinstance(converted, list) else [converted]

    raise TypeError(
        "The type of the 'vector' argument is not supported!\n"
        "Supported types are `list`, 'numpy.ndarray`, `torch.Tensor`, `tf.Tensor`, `pd.Series`, and `pl.Series`"
    )
def _get_vector_v4(vector: Any) -> Sequence[float]:
    """
    Convert 'vector' with `get_vector`, reporting unsupported input as a
    `WeaviateInvalidInputError` instead of a `TypeError`.

    Parameters
    ----------
    vector : Any
        The embedding to convert.

    Returns
    -------
    Sequence[float]
        Whatever `get_vector` returns for 'vector'.

    Raises
    ------
    WeaviateInvalidInputError
        If `get_vector` rejects 'vector' with a TypeError.
    """
    try:
        converted = get_vector(vector)
    except TypeError as type_error:
        raise WeaviateInvalidInputError(
            f"The vector you supplied was malformatted! Vector: {vector}"
        ) from type_error
    return converted
[docs]
def get_domain_from_weaviate_url(url: str) -> str:
    """
    Extract the domain part of a weaviate URL.

    The first 11 characters (the length of 'weaviate://') are dropped without
    being checked, and everything up to the next '/' is returned.

    Parameters
    ----------
    url : str
        The weaviate URL.
        Of this form: 'weaviate://localhost/objects/28f3f61b-b524-45e0-9bbe-2c1550bf73d2'

    Returns
    -------
    str
        The domain.
    """
    without_scheme = url[11:]
    domain, _, _ = without_scheme.partition("/")
    return domain
def _is_sub_schema(sub_schema: dict, schema: dict) -> bool:
    """
    Check whether 'sub_schema' is contained in 'schema'.

    Parameters
    ----------
    sub_schema : dict
        The smaller schema. Either a dict with a 'classes' key, or a single
        class definition, which is then treated as a one-class schema.
    schema : dict
        The schema to search in. Its classes are read from the 'classes' key;
        a missing key is treated as an empty list.

    Returns
    -------
    bool
        True if 'sub_schema' is a subset of 'schema', False otherwise.
    """
    available_classes = schema.get("classes", [])
    wanted_classes = sub_schema["classes"] if "classes" in sub_schema else [sub_schema]
    return _compare_class_sets(wanted_classes, available_classes)
def _compare_class_sets(sub_set: list, set_: list) -> bool:
    """
    Check whether every class of 'sub_set' has a counterpart in 'set_'.

    Two classes match when their names are equal after capitalizing the first
    letter and the properties of the 'sub_set' class are a subset of the
    properties of the 'set_' class.

    Parameters
    ----------
    sub_set : list
        The classes that should be contained in 'set_'.
    set_ : list
        The classes to search in.

    Returns
    -------
    bool
        True if 'sub_set' is a subset of 'set_', False otherwise.

    Raises
    ------
    SchemaValidationError
        If a 'sub_set' class without a 'class' key is compared against a class
        of 'set_'.
    """
    for wanted in sub_set:
        for available in set_:
            if "class" not in wanted:
                raise SchemaValidationError(
                    "The sub schema class/es MUST have a 'class' keyword each!"
                )
            same_name = _capitalize_first_letter(wanted["class"]) == _capitalize_first_letter(
                available["class"]
            )
            if same_name and _compare_properties(wanted["properties"], available["properties"]):
                break
        else:
            # no class of 'set_' matched (or 'set_' is empty)
            return False
    return True
def _compare_properties(sub_set: list, set_: list) -> bool:
    """
    Check whether every property of 'sub_set' also appears in 'set_'.

    Properties are matched by their 'name' entry only.

    Parameters
    ----------
    sub_set : list
        The properties that should be contained in 'set_'.
    set_ : list
        The properties to search in.

    Returns
    -------
    bool
        True if 'sub_set' is a subset of 'set_', False otherwise.
    """
    return all(
        any(wanted["name"] == available["name"] for available in set_)
        for wanted in sub_set
    )
[docs]
def generate_uuid5(identifier: Any, namespace: Any = "") -> str:
    """
    Derive a UUIDv5 from an identifier, so that the same identifier and
    namespace always produce the same UUID.

    The UUID is computed in the DNS namespace from the string form of
    'namespace' immediately followed by the string form of 'identifier'.

    Parameters
    ----------
    identifier : Any
        The identifier/object the UUID is based on.
    namespace : Any, optional
        Prefix used to namespace the identifier, by default ""

    Returns
    -------
    str
        The UUID as a string.
    """
    seed = "".join((str(namespace), str(identifier)))
    generated = uuid_lib.uuid5(uuid_lib.NAMESPACE_DNS, seed)
    return str(generated)
def _capitalize_first_letter(string: str) -> str:
    """
    Capitalize only the first letter of the `string`.

    The remaining characters are kept exactly as they are. An empty string is
    returned unchanged.

    Parameters
    ----------
    string : str
        The string to be capitalized.

    Returns
    -------
    str
        The capitalized string.
    """
    # Slicing instead of indexing keeps the empty string from raising IndexError.
    return string[:1].capitalize() + string[1:]
[docs]
def check_batch_result(
    results: Optional[List[Dict[str, Any]]],
) -> None:
    """
    Print the errors found in batch results.

    For every entry whose 'result' contains an 'errors' value holding an
    'error' key, that 'errors' value is printed to standard output.

    Parameters
    ----------
    results : list of dict, optional
        The Weaviate batch creation return value. Nothing happens for None.
    """
    if results is None:
        return

    for entry in results:
        if "result" not in entry or "errors" not in entry["result"]:
            continue
        errors = entry["result"]["errors"]
        if "error" in errors:
            print(errors)
def _check_positive_num(
    value: Any, arg_name: str, data_type: type, include_zero: bool = False
) -> None:
    """
    Validate that `value` is a positive number of type `data_type`.

    Parameters
    ----------
    value : Union[int, float]
        The value to check.
    arg_name : str
        The name of the variable from the original function call. Used for error message.
    data_type : type
        The data type to check for.
    include_zero : bool
        Whether zero counts as positive or not. By default False.

    Raises
    ------
    TypeError
        If the `value` is not of type `data_type`, or is a bool.
    ValueError
        If the `value` has a non positive value.
    """
    # bool is rejected explicitly because it is a subclass of int
    if not isinstance(value, data_type) or isinstance(value, bool):
        raise TypeError(f"'{arg_name}' must be of type {data_type}.")

    if include_zero:
        out_of_range = value < 0  # type: ignore
        complaint = f"'{arg_name}' must be positive, i.e. greater or equal to zero (>=0)."
    else:
        out_of_range = value <= 0  # type: ignore
        complaint = f"'{arg_name}' must be positive, i.e. greater that zero (>0)."

    if out_of_range:
        raise ValueError(complaint)
[docs]
def is_weaviate_domain(url: str) -> bool:
    """
    Tell whether 'url' contains one of the Weaviate domains.

    The check is a case-insensitive substring search for 'weaviate.io',
    'semi.technology' or 'weaviate.cloud'.

    Parameters
    ----------
    url : str
        The URL to inspect.

    Returns
    -------
    bool
        True if any of the domains occurs in 'url', False otherwise.
    """
    lowered = url.lower()
    known_domains = ("weaviate.io", "semi.technology", "weaviate.cloud")
    return any(domain in lowered for domain in known_domains)
[docs]
def strip_newlines(s: str) -> str:
    """
    Replace every newline character in 's' with a single space.

    Parameters
    ----------
    s : str
        The text to flatten.

    Returns
    -------
    str
        The text with each '\\n' turned into ' '.
    """
    lines = s.split("\n")
    return " ".join(lines)
def _sanitize_str(value: str) -> str:
    """
    Prepare a string for use as a GraphQL string literal.

    Newlines are replaced by spaces, every double quote that is not already
    escaped gets a backslash in front of it, and the result is wrapped in
    double quotes.

    Parameters
    ----------
    value : str
        The value to be converted.

    Returns
    -------
    str
        The sanitized string.
    """
    flattened = strip_newlines(value)
    # A quote counts as unescaped when it is preceded by an even number of
    # backslashes (including none); those backslashes are kept via group 1.
    # Escaping only such quotes avoids permitting query injection.
    unescaped_quote = r'(?<!\\)((?:\\{2})*)"'
    escaped = re.sub(unescaped_quote, r"\1\"", flattened)
    return f'"{escaped}"'
[docs]
def parse_version_string(ver_str: str) -> tuple:
    """
    Parse the major and minor number out of a version string.

    A string without any '.' gets '.0' appended before parsing.

    Parameters
    ----------
    ver_str : str
        The version string to parse. (e.g. "v1.18.2" or "1.18.0")

    Returns
    -------
    tuple :
        The parsed version as a tuple with len(2). (e.g. (1, 18)) Note: Ignores the patch version.

    Raises
    ------
    ValueError
        If the string does not start with an optional 'v' followed by '<digits>.<digits>'.
    """
    if "." not in ver_str:
        ver_str = ver_str + ".0"

    found = re.match(r"v?(\d+)\.(\d+)", ver_str)
    if found is None:
        raise ValueError(
            f"Unable to parse a version from the input string: {ver_str}. Is it in the format '(v)x.y.z' (e.g. 'v1.18.2' or '1.18.0')?"
        )

    major, minor = found.groups()
    return int(major), int(minor)
class _ServerVersion:
    """
    Version of a Weaviate server as major, minor and patch numbers.

    Instances compare by value: equality and ordering treat the version as the
    tuple (major, minor, patch). Comparing against an object that is not a
    `_ServerVersion` yields `NotImplemented`, i.e. `==` is False, `!=` is True
    and ordering operators raise TypeError.
    """

    def __init__(self, major: int, minor: int, patch: int) -> None:
        self.major = major
        self.minor = minor
        self.patch = patch

    def _as_tuple(self) -> Tuple[int, int, int]:
        """Return the version as (major, minor, patch) for comparisons."""
        return self.major, self.minor, self.patch

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _ServerVersion):
            return NotImplemented
        return self._as_tuple() == other._as_tuple()

    def __ne__(self, other: object) -> bool:
        if not isinstance(other, _ServerVersion):
            return NotImplemented
        return self._as_tuple() != other._as_tuple()

    def __neq__(self, other: object) -> bool:
        # `__neq__` is not a special method name in Python (`!=` uses `__ne__`).
        # It is kept only so that explicit calls keep working.
        return self != other

    def __gt__(self, other: object) -> bool:
        if not isinstance(other, _ServerVersion):
            return NotImplemented
        return self._as_tuple() > other._as_tuple()

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, _ServerVersion):
            return NotImplemented
        return self._as_tuple() < other._as_tuple()

    def __ge__(self, other: object) -> bool:
        if not isinstance(other, _ServerVersion):
            return NotImplemented
        return self._as_tuple() >= other._as_tuple()

    def __le__(self, other: object) -> bool:
        if not isinstance(other, _ServerVersion):
            return NotImplemented
        return self._as_tuple() <= other._as_tuple()

    def __repr__(self) -> str:
        return f"{self.major}.{self.minor}.{self.patch}"

    def __str__(self) -> str:
        return f"{self.major}.{self.minor}.{self.patch}"

    def is_at_least(self, major: int, minor: int, patch: int) -> bool:
        """Return True if this version is equal to or newer than the given one."""
        return self >= _ServerVersion(major, minor, patch)

    def is_lower_than(self, major: int, minor: int, patch: int) -> bool:
        """Return True if this version is older than the given one."""
        return self < _ServerVersion(major, minor, patch)

    @classmethod
    def from_string(cls, version: str) -> "_ServerVersion":
        """
        Build a version from a string such as 'v1.18.2' or '1.18.0'.

        Missing parts are filled with zeros ('' -> 0.0.0, '1' -> 1.0.0,
        '1.25' -> 1.25.0). Anything following the patch number is ignored.

        Raises
        ------
        ValueError
            If the string does not start with an optional 'v' followed by
            dot-separated numbers.
        """
        initial = version
        if version == "":
            version = "0"
        # pad to three dot-separated parts
        version += ".0" * max(0, 2 - version.count("."))

        match = re.match(r"v?(\d+)\.(\d+)\.(\d+)", version)
        if match is None:
            raise ValueError(
                f"Unable to parse a version from the input string: {initial}. Is it in the format '(v)x.y.z' (e.g. 'v1.18.2' or '1.18.0')?"
            )
        major, minor, patch = map(int, match.groups())
        return cls(major=major, minor=minor, patch=patch)

    def check_is_at_least_1_25_0(self, feature: str) -> None:
        """Raise WeaviateUnsupportedFeatureError for `feature` if this version is older than 1.25.0."""
        if self < _ServerVersion(1, 25, 0):
            raise WeaviateUnsupportedFeatureError(feature, str(self), "1.25.0")

    @property
    def supports_tenants_get_grpc(self) -> bool:
        """True if this version is 1.25.0 or newer."""
        return self >= _ServerVersion(1, 25, 0)
[docs]
def is_weaviate_too_old(current_version_str: str) -> bool:
    """
    Decide whether the user should be gently nudged to upgrade the Weaviate server.

    Only the major and minor numbers are compared against
    `MINIMUM_NO_WARNING_VERSION`.

    Parameters
    ----------
    current_version_str : str
        The version of the Weaviate server that the client is connected to. (e.g. "v1.18.2" or "1.18.0")

    Returns
    -------
    bool :
        True if the user should be nudged to upgrade.
    """
    running = parse_version_string(current_version_str)
    oldest_without_warning = parse_version_string(MINIMUM_NO_WARNING_VERSION)
    return running < oldest_without_warning
[docs]
def is_weaviate_client_too_old(current_version_str: str, latest_version_str: str) -> bool:
    """
    Decide whether the user should be gently nudged to upgrade the Weaviate client.

    The client counts as too old when its (major, minor) version is lower than
    the latest (major, minor) with `MAXIMUM_MINOR_VERSION_DELTA` subtracted
    from the minor number (never going below zero).

    Parameters
    ----------
    current_version_str : str
        The version of the Weaviate client that is being used (e.g. "v1.18.2" or "1.18.0")
    latest_version_str : str
        The latest version of the Weaviate client to compare against (e.g. "v1.18.2" or "1.18.0")

    Returns
    -------
    bool :
        True if the user should be nudged to upgrade.
        False if the user is using a valid version or if the version could not be parsed.
    """
    try:
        running = parse_version_string(current_version_str)
        newest_major, newest_minor = parse_version_string(latest_version_str)
    except ValueError:
        # an unparsable version never triggers the nudge
        return False

    oldest_accepted = (newest_major, max(newest_minor - MAXIMUM_MINOR_VERSION_DELTA, 0))
    return running < oldest_accepted
def _get_valid_timeout_config(
    timeout_config: Union[Tuple[NUMBER, NUMBER], NUMBER, None]
) -> Tuple[NUMBER, NUMBER]:
    """
    Validate and return TimeOut configuration.

    Parameters
    ----------
    timeout_config : tuple(NUMBERS, NUMBERS) or NUMBERS or None, optional
        Set the timeout configuration for all requests to the Weaviate server. It can be a
        number or, a tuple of two numbers: (connect timeout, read timeout).
        If only one number is passed then both connect and read timeout will be set to
        that value.

    Returns
    -------
    tuple(NUMBERS, NUMBERS)
        The (connect timeout, read timeout) pair.

    Raises
    ------
    TypeError
        If arguments are of a wrong data type (this includes None and bool values).
    ValueError
        If 'timeout_config' is not a tuple of 2.
    ValueError
        If 'timeout_config' is/contains non-positive number/s.
    """

    def is_number(num: Any) -> bool:
        # bool is a subclass of int but is not a meaningful timeout
        return isinstance(num, (int, float)) and not isinstance(num, bool)

    if is_number(timeout_config):
        if timeout_config <= 0.0:
            raise ValueError("'timeout_config' cannot be non-positive number/s!")
        return timeout_config, timeout_config

    if not isinstance(timeout_config, tuple):
        raise TypeError("'timeout_config' should be a (or tuple of) positive number/s!")
    if len(timeout_config) != 2:
        raise ValueError("'timeout_config' must be of length 2!")
    # Reject the tuple if ANY member is not a real number; a single bool such
    # as (True, 5) must not pass validation.
    if not all(is_number(item) for item in timeout_config):
        raise TypeError("'timeout_config' must be tuple of numbers")
    if any(item <= 0.0 for item in timeout_config):
        raise ValueError("'timeout_config' cannot be non-positive number/s!")
    return timeout_config
def _type_request_response(json_response: Any) -> Optional[Dict[str, Any]]:
    """
    Return 'json_response' unchanged, typed as an optional dict.

    None is passed through; any other value is asserted to be a dict.
    """
    if json_response is not None:
        assert isinstance(json_response, dict)
    return json_response
def _to_beacons(uuids: UUIDS, to_class: str = "") -> List[Dict[str, str]]:
    """
    Build one beacon dict per UUID.

    A single UUID (str or uuid.UUID) is treated as a one-element list. Each
    beacon has the form 'weaviate://localhost/<to_class>/<uuid>', or
    'weaviate://localhost/<uuid>' when 'to_class' is empty.
    """
    # replace with isinstance(uuids, UUID) in 3.10
    if isinstance(uuids, (uuid_lib.UUID, str)):
        uuids = [uuids]

    class_segment = to_class + "/" if len(to_class) > 0 else to_class
    return [
        {"beacon": f"weaviate://localhost/{class_segment}{target}"} for target in uuids
    ]
def _decode_json_response_dict(response: httpx.Response, location: str) -> Optional[Dict[str, Any]]:
    """
    Decode the JSON body of a successful response as a dict.

    Parameters
    ----------
    response : httpx.Response
        The response to decode. None is passed through as None.
    location : str
        Description of the request, handed to the raised exceptions.

    Returns
    -------
    dict or None
        The decoded body, or None if 'response' is None.

    Raises
    ------
    ResponseCannotBeDecodedError
        If the status code is 2xx but the body cannot be decoded as JSON.
    UnexpectedStatusCodeError
        If the status code is not 2xx.
    """
    if response is None:
        return None
    if not 200 <= response.status_code < 300:
        raise UnexpectedStatusCodeError(location, response)
    try:
        return cast(Dict[str, Any], response.json())
    except (httpx.DecodingError, ValueError) as err:
        # `Response.json()` parses with the stdlib json module, so a malformed
        # body surfaces as json.JSONDecodeError (a ValueError subclass);
        # httpx.DecodingError covers a malformed content encoding.
        raise ResponseCannotBeDecodedError(location, response) from err
def _decode_json_response_list(
    response: httpx.Response, location: str
) -> Optional[List[Dict[str, Any]]]:
    """
    Decode the JSON body of a successful response as a list of dicts.

    Parameters
    ----------
    response : httpx.Response
        The response to decode. None is passed through as None.
    location : str
        Description of the request, handed to the raised exceptions.

    Returns
    -------
    list or None
        The decoded body, or None if 'response' is None.

    Raises
    ------
    ResponseCannotBeDecodedError
        If the status code is 2xx but the body cannot be decoded as JSON.
    UnexpectedStatusCodeError
        If the status code is not 2xx.
    """
    if response is None:
        return None
    if not 200 <= response.status_code < 300:
        raise UnexpectedStatusCodeError(location, response)
    try:
        return cast(list, response.json())
    except (httpx.DecodingError, ValueError) as err:
        # `Response.json()` parses with the stdlib json module, so a malformed
        # body surfaces as json.JSONDecodeError (a ValueError subclass);
        # httpx.DecodingError covers a malformed content encoding.
        raise ResponseCannotBeDecodedError(location, response) from err
def _datetime_to_string(value: TIME) -> str:
    """
    Format a datetime as an ISO 8601 string with microsecond precision.

    A value without timezone information triggers the
    `datetime_insertion_with_no_specified_timezone` warning and is then
    treated as UTC.
    """
    aware = value
    if aware.tzinfo is None:
        _Warnings.datetime_insertion_with_no_specified_timezone(value)
        aware = value.replace(tzinfo=datetime.timezone.utc)
    return aware.isoformat(sep="T", timespec="microseconds")
def _datetime_from_weaviate_str(string: str) -> datetime.datetime:
    """
    Parse a timestamp string into a timezone-aware datetime.

    Timestamps with and without fractional seconds are accepted. Unless the
    string ends in 'Z', its last ':' is removed before parsing, which turns an
    offset such as '+01:00' into '+0100'.
    """
    if string[-1] == "Z":
        normalized = string
    else:
        normalized = "".join(string.rsplit(":", 1))

    with_fraction = "%Y-%m-%dT%H:%M:%S.%f%z"
    without_fraction = "%Y-%m-%dT%H:%M:%S%z"
    try:
        return datetime.datetime.strptime(normalized, with_fraction)
    except ValueError:  # if the string does not have microseconds
        return datetime.datetime.strptime(normalized, without_fraction)
def __is_list_type(inputs: Any) -> bool:
    """
    Tell whether 'inputs' is a non-empty list-like value.

    Values without a length and empty values are never list-like. Otherwise
    the value is list-like if `_is_valid` accepts it for `List` or for one of
    the TF, PANDAS, NUMPY or POLARS extra types.
    """
    try:
        is_empty = len(inputs) == 0
    except TypeError:
        return False
    if is_empty:
        return False

    accepted_types = (
        List,
        _ExtraTypes.TF,
        _ExtraTypes.PANDAS,
        _ExtraTypes.NUMPY,
        _ExtraTypes.POLARS,
    )
    for accepted in accepted_types:
        if _is_valid(accepted, inputs):
            return True
    return False
def _is_1d_vector(inputs: Any) -> bool:
    """
    Tell whether 'inputs' is a non-empty one-dimensional vector.

    The value must be list-like according to `__is_list_type`, while its first
    element must not be.
    """
    try:
        has_items = len(inputs) != 0
    except TypeError:
        return False
    if not has_items:
        return False

    if not __is_list_type(inputs):
        return False
    # 2D vectors are not 1D vectors
    return not __is_list_type(inputs[0])