import asyncio
import contextvars
import functools
import math
import os
import threading
import time
import uuid as uuid_package
from abc import ABC
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from copy import copy
from dataclasses import dataclass, field
from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union, cast
from pydantic import ValidationError
from typing_extensions import TypeAlias
from weaviate.cluster.types import Node
from weaviate.collections.batch.grpc_batch import _BatchGRPC
from weaviate.collections.batch.rest import _BatchREST
from weaviate.collections.classes.batch import (
BatchObject,
BatchObjectReturn,
BatchReference,
BatchReferenceReturn,
BatchResult,
ErrorObject,
ErrorReference,
Shard,
)
from weaviate.collections.classes.config import ConsistencyLevel
from weaviate.collections.classes.internal import (
ReferenceInput,
ReferenceInputs,
ReferenceToMulti,
)
from weaviate.collections.classes.types import WeaviateProperties
from weaviate.connect import executor
from weaviate.connect.v4 import ConnectionAsync, ConnectionSync
from weaviate.exceptions import (
EmptyResponseException,
WeaviateBatchValidationError,
)
from weaviate.logger import logger
from weaviate.types import UUID, VECTORS
from weaviate.util import _decode_json_response_dict
from weaviate.warnings import _Warnings
# Alias for a list of string-keyed dictionaries.
BatchResponse = List[Dict[str, Any]]

# Type variables used to parametrize `BatchRequest` (queued item type / result type).
TBatchInput = TypeVar("TBatchInput")
TBatchReturn = TypeVar("TBatchReturn")

# Upper bound for the number of concurrent requests dynamic batching may scale up to.
MAX_CONCURRENT_REQUESTS = 10

# Timeout handed to the gRPC object-batch call.
DEFAULT_REQUEST_TIMEOUT = 180

# Number of recent request durations kept when batching through a vectorizer.
CONCURRENT_REQUESTS_DYNAMIC_VECTORIZER = 2

# Request duration the vectorizer-aware dynamic batching steers towards.
BATCH_TIME_TARGET = 10

# Step by which the batch size grows/shrinks with vectorizer batching
# (cohere max batch size is 96).
VECTORIZER_BATCHING_STEP_SIZE = 48

# Overridable through the WEAVIATE_BATCH_MAX_RETRIES environment variable.
# The default is approximately 10m30s of waiting in the worst case, e.g. server scale up event.
MAX_RETRIES = float(os.getenv("WEAVIATE_BATCH_MAX_RETRIES", "9.299"))

# GCP connections have a max lifetime of 180s, leave 20s of buffer as safety.
GCP_STREAM_TIMEOUT = 160
[docs]
class BatchRequest(ABC, Generic[TBatchInput, TBatchReturn]):
    """Abstract base class used as the interface for batch request queues.

    Items are kept in a plain list. The synchronous methods guard it with a
    `threading.Lock`; the `a`-prefixed coroutine variants guard it with a separate
    `asyncio.Lock`.
    """

    def __init__(self) -> None:
        self._items: List[TBatchInput] = []
        self._lock = threading.Lock()
        self._alock = asyncio.Lock()

    def __len__(self) -> int:
        with self._lock:
            size = len(self._items)
        return size

    async def alen(self) -> int:
        """Asynchronously return the number of queued items."""
        async with self._alock:
            size = len(self._items)
        return size

    def add(self, item: TBatchInput) -> None:
        """Append a single item to the end of the queue."""
        with self._lock:
            self._items.append(item)

    async def aadd(self, item: TBatchInput) -> None:
        """Asynchronously append a single item to the end of the queue."""
        async with self._alock:
            self._items.append(item)

    def prepend(self, item: List[TBatchInput]) -> None:
        """Put a list of items in front of everything that is already queued.

        This is intended to be used when objects should be retried, e.g. after a temporary error.
        """
        with self._lock:
            requeued = item + self._items
            self._items = requeued

    async def aprepend(self, item: List[TBatchInput]) -> None:
        """Asynchronously put a list of items in front of everything that is already queued.

        This is intended to be used when objects should be retried, e.g. after a temporary error.
        """
        async with self._alock:
            requeued = item + self._items
            self._items = requeued
Ref = TypeVar("Ref", bound=BatchReference)


class ReferencesBatchRequest(BatchRequest[Ref, BatchReferenceReturn]):
    """Collect Weaviate-object references to add them in one request to Weaviate."""

    def __pop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[Ref]:
        """Remove and return up to `pop_amount` references, keeping queue order.

        A reference is skipped (and stays queued) while its source uuid or its target
        uuid is contained in `uuid_lookup`. The caller must hold the matching lock.

        The queue is walked exactly once instead of calling `list.pop(i)` repeatedly,
        which would shift the list tail for every removed element.
        """
        popped: List[Ref] = []
        kept: List[Ref] = []
        for pos, ref in enumerate(self._items):
            if len(popped) >= pop_amount:
                # quota reached: everything from here on stays queued untouched
                kept.extend(self._items[pos:])
                break
            blocked = ref.from_object_uuid in uuid_lookup or (
                ref.to_object_uuid is not None and ref.to_object_uuid in uuid_lookup
            )
            if blocked:
                kept.append(ref)
            else:
                popped.append(ref)
        # replace the content in place so the list object itself stays the same
        self._items[:] = kept
        return popped

    def pop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[Ref]:
        """Pop the given number of items from the BatchRequest queue.

        Args:
            pop_amount: Maximum number of references to return.
            uuid_lookup: Uuids that must not be referenced yet; references touching
                one of them are left in the queue.

        Returns:
            A list of items from the BatchRequest.
        """
        with self._lock:
            return self.__pop_items(pop_amount, uuid_lookup)

    async def apop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[Ref]:
        """Asynchronously pop the given number of items from the BatchRequest queue.

        Args:
            pop_amount: Maximum number of references to return.
            uuid_lookup: Uuids that must not be referenced yet; references touching
                one of them are left in the queue.

        Returns:
            A list of items from the BatchRequest.
        """
        async with self._alock:
            return self.__pop_items(pop_amount, uuid_lookup)

    def __head(self) -> Optional[Ref]:
        """Return the first queued reference or None; the caller must hold the matching lock."""
        if len(self._items) > 0:
            return self._items[0]
        return None

    def head(self) -> Optional[Ref]:
        """Get the first item from the BatchRequest queue without removing it.

        Returns:
            The first item from the BatchRequest or None if the queue is empty.
        """
        with self._lock:
            return self.__head()

    async def ahead(self) -> Optional[Ref]:
        """Asynchronously get the first item from the BatchRequest queue without removing it.

        Returns:
            The first item from the BatchRequest or None if the queue is empty.
        """
        async with self._alock:
            return self.__head()
Obj = TypeVar("Obj", bound=BatchObject)


class ObjectsBatchRequest(Generic[Obj], BatchRequest[Obj, BatchObjectReturn]):
    """Collect objects for one batch request to weaviate."""

    def __pop_items(self, pop_amount: int) -> List[Obj]:
        """Remove and return the first `pop_amount` objects; the caller must hold the matching lock."""
        if pop_amount < len(self._items):
            # only a prefix is taken, the remainder becomes the new queue
            taken = copy(self._items[:pop_amount])
            self._items = self._items[pop_amount:]
            return taken
        # everything is requested: hand out a copy and empty the queue in place
        taken = copy(self._items)
        self._items.clear()
        return taken

    def pop_items(self, pop_amount: int) -> List[Obj]:
        """Pop the given number of items from the BatchRequest queue.

        Returns:
            A list of items from the BatchRequest.
        """
        with self._lock:
            taken = self.__pop_items(pop_amount)
        return taken

    async def apop_items(self, pop_amount: int) -> List[Obj]:
        """Asynchronously pop the given number of items from the BatchRequest queue.

        Returns:
            A list of items from the BatchRequest.
        """
        async with self._alock:
            taken = self.__pop_items(pop_amount)
        return taken

    def __head(self) -> Optional[Obj]:
        """Return the first queued object or None; the caller must hold the matching lock."""
        return self._items[0] if len(self._items) > 0 else None

    def head(self) -> Optional[Obj]:
        """Get the first item from the BatchRequest queue without removing it.

        Returns:
            The first item from the BatchRequest or None if the queue is empty.
        """
        with self._lock:
            first = self.__head()
        return first

    async def ahead(self) -> Optional[Obj]:
        """Asynchronously get the first item from the BatchRequest queue without removing it.

        Returns:
            The first item from the BatchRequest or None if the queue is empty.
        """
        async with self._alock:
            first = self.__head()
        return first
[docs]
@dataclass
class _BatchDataWrapper:
    """Container bundling everything a batch run reports back."""

    # accumulated object/reference results of all sent batches
    results: BatchResult = field(default_factory=BatchResult)
    # errors of objects that could not be imported
    failed_objects: List[ErrorObject] = field(default_factory=list)
    # errors of references that could not be imported
    failed_references: List[ErrorReference] = field(default_factory=list)
    # collection/tenant combinations objects have been added for
    imported_shards: Set[Shard] = field(default_factory=set)
[docs]
@dataclass
class _DynamicBatching:
    """Marker selecting dynamic batching; it carries no settings of its own."""
[docs]
@dataclass
class _FixedSizeBatching:
    """Settings for batching with a fixed batch size and a fixed number of concurrent requests."""

    # number of objects per batch
    batch_size: int
    # number of batches that may be in flight at the same time
    concurrent_requests: int
[docs]
@dataclass
class _RateLimitedBatching:
    """Settings for batching that is throttled to a per-minute budget."""

    # budget per minute the batch size and send interval are derived from
    requests_per_minute: int
[docs]
@dataclass
class _ServerSideBatching:
    """Settings for server-side batching."""

    # requested concurrency (presumably the number of parallel streams; verify against the consumer)
    concurrency: int
# Union of all batching strategy configurations a batch can be created with.
_BatchMode: TypeAlias = Union[
    _DynamicBatching,
    _FixedSizeBatching,
    _RateLimitedBatching,
    _ServerSideBatching,
]
[docs]
class _BatchBase:
    """Queue objects/references and send them to Weaviate from background threads.

    `_add_object` and `_add_reference` put items into internal queues. A scheduler
    thread (`BgBatchScheduler`) pops items from the queues and submits the actual
    requests to the given executor, while a second thread (`BgDynamicBatchRate`)
    adapts batch size and concurrency when dynamic batching is used.
    """

    def __init__(
        self,
        connection: ConnectionSync,
        consistency_level: Optional[ConsistencyLevel],
        results: _BatchDataWrapper,
        batch_mode: _BatchMode,
        executor: ThreadPoolExecutor,
        vectorizer_batching: bool,
        objects: Optional[ObjectsBatchRequest[BatchObject]] = None,
        references: Optional[ReferencesBatchRequest[BatchReference]] = None,
    ) -> None:
        """Set up queues, batching parameters and start the background threads.

        Args:
            connection: Connection used for all requests.
            consistency_level: Consistency level handed to the gRPC and REST batch helpers.
            results: Wrapper the final results are copied into on `_shutdown`.
            batch_mode: Batching strategy configuration.
            executor: Thread pool the individual batch requests are submitted to.
            vectorizer_batching: Whether the dynamic mode should use the vectorizer-aware scaling.
            objects: Optional pre-existing object queue.
            references: Optional pre-existing reference queue.
        """
        # The queues define `__len__`, so an empty queue is falsy. Compare against None
        # explicitly, otherwise an empty queue passed in by the caller would be replaced.
        self.__batch_objects = (
            objects if objects is not None else ObjectsBatchRequest[BatchObject]()
        )
        self.__batch_references = (
            references if references is not None else ReferencesBatchRequest[BatchReference]()
        )

        self.__connection = connection
        self.__consistency_level: Optional[ConsistencyLevel] = consistency_level
        self.__vectorizer_batching = vectorizer_batching

        self.__batch_grpc = _BatchGRPC(
            connection._weaviate_version, self.__consistency_level, connection._grpc_max_msg_size
        )
        self.__batch_rest = _BatchREST(self.__consistency_level)

        # lookup table for objects that are currently being processed - is used to not send references from objects that have not been added yet
        self.__uuid_lookup: Set[str] = set()

        # we do not want that users can access the results directly as they are not thread-safe
        self.__results_for_wrapper_backup = results
        self.__results_for_wrapper = _BatchDataWrapper()

        self.__cluster = _ClusterBatch(self.__connection)

        self.__batching_mode: _BatchMode = batch_mode
        self.__max_batch_size: int = 1000

        self.__executor = executor
        self.__objs_count = 0
        self.__refs_count = 0
        self.__objs_logs_count = 0
        self.__refs_logs_count = 0

        if isinstance(self.__batching_mode, _FixedSizeBatching):
            self.__recommended_num_objects = self.__batching_mode.batch_size
            self.__concurrent_requests = self.__batching_mode.concurrent_requests
        elif isinstance(self.__batching_mode, _RateLimitedBatching):
            # Batch with rate limiting should never send more than the given amount of objects per minute.
            # We could send all objects in a single batch every 60 seconds but that could cause problems with too large requests. Therefore, we
            # limit the size of a batch to self.__max_batch_size and send multiple batches of equal size and send them in equally space in time.
            # Example:
            #  3000 objects, 1000/min -> 3 batches of 1000 objects, send every 20 seconds
            self.__concurrent_requests = (
                self.__batching_mode.requests_per_minute + self.__max_batch_size
            ) // self.__max_batch_size
            self.__recommended_num_objects = (
                self.__batching_mode.requests_per_minute // self.__concurrent_requests
            )
        elif isinstance(self.__batching_mode, _DynamicBatching) and not self.__vectorizer_batching:
            self.__recommended_num_objects = 10
            self.__concurrent_requests = 2
        else:
            assert isinstance(self.__batching_mode, _DynamicBatching) and self.__vectorizer_batching
            self.__recommended_num_objects = VECTORIZER_BATCHING_STEP_SIZE
            self.__concurrent_requests = 2

        self.__dynamic_batching_sleep_time: int = 0
        self._batch_send: bool = False

        self.__recommended_num_refs: int = 50

        self.__active_requests = 0

        # dynamic batching
        self.__time_last_scale_up: float = 0
        self.__rate_queue: deque = deque(maxlen=50)  # 5s with 0.1s refresh rate
        self.__took_queue: deque = deque(maxlen=CONCURRENT_REQUESTS_DYNAMIC_VECTORIZER)

        # fixed rate batching
        self.__time_stamp_last_request: float = 0
        # do 62 secs to give us some buffer to the "per-minute" calculation
        self.__fix_rate_batching_base_time = 62

        self.__active_requests_lock = threading.Lock()
        self.__uuid_lookup_lock = threading.Lock()
        self.__results_lock = threading.Lock()

        # Must be initialised BEFORE the threads start: a thread that fails right away
        # stores its exception here and a later reset to None would silently drop it.
        self.__bg_thread_exception: Optional[Exception] = None
        self.__bg_threads = self.__start_bg_threads()

    @property
    def number_errors(self) -> int:
        """Return the number of errors in the batch."""
        return len(self.__results_for_wrapper.failed_objects) + len(
            self.__results_for_wrapper.failed_references
        )

    def _shutdown(self) -> None:
        """Shutdown the current batch and wait for all requests to be finished."""
        self.flush()

        # we are done, shut bg threads down and end the event loop
        self.__shut_background_thread_down.set()
        while self.__bg_threads.is_alive():
            time.sleep(0.01)

        # copy the results to the public results
        self.__results_for_wrapper_backup.results = self.__results_for_wrapper.results
        self.__results_for_wrapper_backup.failed_objects = self.__results_for_wrapper.failed_objects
        self.__results_for_wrapper_backup.failed_references = (
            self.__results_for_wrapper.failed_references
        )
        self.__results_for_wrapper_backup.imported_shards = (
            self.__results_for_wrapper.imported_shards
        )

    def __batch_send(self) -> None:
        """Scheduler loop: pop queued items and submit them to the executor until shutdown is requested."""
        refresh_time: float = 0.01
        while (
            self.__shut_background_thread_down is not None
            and not self.__shut_background_thread_down.is_set()
        ):
            if isinstance(self.__batching_mode, _RateLimitedBatching):
                # spread the requests evenly over the (slightly enlarged) minute
                if (
                    time.time() - self.__time_stamp_last_request
                    < self.__fix_rate_batching_base_time // self.__concurrent_requests
                ):
                    time.sleep(1)
                    continue
                refresh_time = 0
            elif isinstance(self.__batching_mode, _DynamicBatching) and self.__vectorizer_batching:
                # the dynamic rate loop may have asked for a pause between requests
                if self.__dynamic_batching_sleep_time > 0:
                    if (
                        time.time() - self.__time_stamp_last_request
                        < self.__dynamic_batching_sleep_time
                    ):
                        time.sleep(1)
                        continue

            if (
                self.__active_requests < self.__concurrent_requests
                and len(self.__batch_objects) + len(self.__batch_references) > 0
            ):
                self.__time_stamp_last_request = time.time()

                self._batch_send = True
                with self.__active_requests_lock:
                    self.__active_requests += 1

                start = time.time()
                while (len_o := len(self.__batch_objects)) < self.__recommended_num_objects and (
                    len_r := len(self.__batch_references)
                ) < self.__recommended_num_refs:
                    # wait for more objects to be added up to the recommended number
                    time.sleep(0.01)
                    if (
                        self.__shut_background_thread_down is not None
                        and self.__shut_background_thread_down.is_set()
                    ):
                        # shutdown was requested, exit the loop
                        break
                    if time.time() - start >= 1 and (
                        len_o == len(self.__batch_objects) or len_r == len(self.__batch_references)
                    ):
                        # no new objects were added in the last second, exit the loop
                        break

                objs = self.__batch_objects.pop_items(self.__recommended_num_objects)
                refs = self.__batch_references.pop_items(
                    self.__recommended_num_refs,
                    uuid_lookup=self.__uuid_lookup,
                )
                # do not block the thread - the results are written to a central (locked) list and we want to have multiple concurrent batch-requests
                ctx = contextvars.copy_context()
                self.__executor.submit(
                    ctx.run,
                    functools.partial(
                        self.__send_batch,
                        objs,
                        refs,
                        readd_rate_limit=isinstance(self.__batching_mode, _RateLimitedBatching),
                    ),
                )

            time.sleep(refresh_time)

    def __dynamic_batch_rate_loop(self) -> None:
        """Re-evaluate the dynamic batching parameters once per second until shutdown.

        Returns immediately (ending the thread) as soon as the batching mode is not dynamic.
        """
        refresh_time = 1
        while (
            self.__shut_background_thread_down is not None
            and not self.__shut_background_thread_down.is_set()
        ):
            if not isinstance(self.__batching_mode, _DynamicBatching):
                return

            try:
                self.__dynamic_batching()
            except Exception as e:
                # a failed evaluation must not kill the thread, simply try again next round
                logger.debug(repr(e))
            time.sleep(refresh_time)

    def __start_bg_threads(self) -> threading.Thread:
        """Start the dynamic-rate thread and the batch scheduler thread.

        Both are daemon threads. An exception escaping either of them is stored in
        `__bg_thread_exception` so it can be re-raised in the caller's thread.

        Returns:
            The scheduler thread (`BgBatchScheduler`); its liveness is what the rest
            of the class monitors.
        """
        self.__shut_background_thread_down = threading.Event()

        def dynamic_batch_rate_wrapper() -> None:
            try:
                self.__dynamic_batch_rate_loop()
            except Exception as e:
                self.__bg_thread_exception = e

        demonDynamic = threading.Thread(
            target=dynamic_batch_rate_wrapper,
            daemon=True,
            name="BgDynamicBatchRate",
        )
        demonDynamic.start()

        def batch_send_wrapper() -> None:
            try:
                self.__batch_send()
            except Exception as e:
                logger.error(e)
                self.__bg_thread_exception = e

        demonBatchSend = threading.Thread(
            target=batch_send_wrapper,
            daemon=True,
            name="BgBatchScheduler",
        )
        demonBatchSend.start()

        return demonBatchSend

    def __dynamic_batching(self) -> None:
        """Adapt batch size and concurrency to the batch statistics of the first reported node."""
        status = self.__cluster.get_nodes_status()
        if not status:
            # node status could not be fetched this round - keep the current settings
            return
        if "batchStats" not in status[0] or "queueLength" not in status[0]["batchStats"]:
            # async indexing - just send a lot
            self.__batching_mode = _FixedSizeBatching(1000, 10)
            self.__recommended_num_objects = 1000
            self.__concurrent_requests = 10
            return

        rate: int = status[0]["batchStats"]["ratePerSecond"]
        rate_per_worker = rate / self.__concurrent_requests

        batch_length = status[0]["batchStats"]["queueLength"]

        self.__rate_queue.append(rate)

        if self.__vectorizer_batching:
            # slow vectorizer, we want to send larger batches that can take a bit longer, but fewer of them. We might need to sleep
            if len(self.__took_queue) > 0 and self._batch_send:
                max_took = max(self.__took_queue)
                self.__dynamic_batching_sleep_time = 0
                if max_took > 2 * BATCH_TIME_TARGET:
                    # far too slow: fall back to the smallest configuration
                    self.__concurrent_requests = 1
                    self.__recommended_num_objects = VECTORIZER_BATCHING_STEP_SIZE
                elif max_took > BATCH_TIME_TARGET:
                    current_step = self.__recommended_num_objects // VECTORIZER_BATCHING_STEP_SIZE

                    if self.__concurrent_requests > 1:
                        self.__concurrent_requests -= 1
                    elif current_step > 1:
                        self.__recommended_num_objects = VECTORIZER_BATCHING_STEP_SIZE * (
                            current_step - 1
                        )
                    else:
                        # cannot scale down, sleep a bit
                        self.__dynamic_batching_sleep_time = max_took - BATCH_TIME_TARGET

                elif max_took < 3 * BATCH_TIME_TARGET // 4:
                    # NOTE(review): the sleep time was reset to 0 a few lines above, so the
                    # first branch below looks unreachable - confirm the intended order.
                    if self.__dynamic_batching_sleep_time > 0:
                        self.__dynamic_batching_sleep_time = 0
                    elif self.__concurrent_requests < 3:
                        self.__concurrent_requests += 1
                    else:
                        current_step = (
                            self.__recommended_num_objects // VECTORIZER_BATCHING_STEP_SIZE
                        )
                        self.__recommended_num_objects = VECTORIZER_BATCHING_STEP_SIZE * (
                            current_step + 1
                        )

                self._batch_send = False
        else:
            if batch_length == 0:  # scale up if queue is empty
                self.__recommended_num_objects = min(
                    self.__recommended_num_objects + 50,
                    self.__max_batch_size,
                )

                if (
                    self.__max_batch_size == self.__recommended_num_objects
                    and len(self.__batch_objects) > self.__recommended_num_objects
                    and time.time() - self.__time_last_scale_up > 1
                    and self.__concurrent_requests < MAX_CONCURRENT_REQUESTS
                ):
                    self.__concurrent_requests += 1
                    self.__time_last_scale_up = time.time()

            else:
                if rate == 0:
                    # queue/rate ratio is undefined without a measured rate; keep the
                    # current settings instead of failing with a ZeroDivisionError
                    return
                ratio = batch_length / rate
                if 2.1 > ratio > 1.9:  # ideal, send exactly as many objects as weaviate can process
                    self.__recommended_num_objects = math.floor(rate_per_worker)
                elif ratio <= 1.9:  # we can send more
                    self.__recommended_num_objects = math.floor(
                        min(
                            self.__recommended_num_objects * 1.5,
                            rate_per_worker * 2 / ratio,
                        )
                    )

                    if self.__max_batch_size == self.__recommended_num_objects:
                        self.__concurrent_requests += 1

                elif ratio < 10:  # too high, scale down
                    self.__recommended_num_objects = math.floor(rate_per_worker * 2 / ratio)

                    if self.__recommended_num_objects < 100 and self.__concurrent_requests > 2:
                        self.__concurrent_requests -= 1

                else:  # way too high, stop sending new batches
                    self.__recommended_num_objects = 0
                    self.__concurrent_requests = 2

    def __send_batch(
        self,
        objs: List[BatchObject],
        refs: List[BatchReference],
        readd_rate_limit: bool,
    ) -> None:
        """Send one batch of objects (gRPC) and references (REST) and record the outcome.

        Objects that failed with a known temporary vectorizer error are put back at the
        front of the queue (up to a retry limit). Runs inside the executor.

        Args:
            objs: Objects to send; may be empty.
            refs: References to send; may be empty.
            readd_rate_limit: True when rate limited batching is active; the back-off is
                then applied through the scheduler's timestamp instead of sleeping here.
        """
        try:
            if (n_objs := len(objs)) > 0:
                start = time.time()
                try:
                    response_obj = executor.result(
                        self.__batch_grpc.objects(
                            connection=self.__connection,
                            objects=[obj._to_internal() for obj in objs],
                            timeout=DEFAULT_REQUEST_TIMEOUT,
                            max_retries=MAX_RETRIES,
                        )
                    )
                    if response_obj.has_errors:
                        logger.error(
                            {
                                "message": f"Failed to send {len(response_obj.errors)} in a batch of {len(objs)}",
                                "errors": {err.message for err in response_obj.errors.values()},
                            }
                        )
                except Exception as e:
                    # the whole request failed: report every object of the batch as failed
                    errors_obj = {
                        idx: ErrorObject(message=repr(e), object_=obj)
                        for idx, obj in enumerate(objs)
                    }
                    logger.error(
                        {
                            "message": f"Failed to send all objects in a batch of {len(objs)}",
                            "error": repr(e),
                        }
                    )
                    response_obj = BatchObjectReturn(
                        _all_responses=list(errors_obj.values()),
                        elapsed_seconds=time.time() - start,
                        errors=errors_obj,
                        has_errors=True,
                    )

                readded_uuids = set()
                readded_objects: List[int] = []
                highest_retry_count = 0
                for i, err in response_obj.errors.items():
                    if (
                        (
                            "support@cohere.com" in err.message
                            and (
                                "rate limit" in err.message
                                or "500 error: internal server error" in err.message
                            )
                        )
                        or (
                            "OpenAI" in err.message
                            and (
                                "Rate limit reached" in err.message
                                or "on tokens per min (TPM)" in err.message
                                or "503 error: Service Unavailable." in err.message
                                or "500 error: The server had an error while processing your request."
                                in err.message
                            )
                        )
                        or ("failed with status: 503 error" in err.message)  # huggingface
                    ):
                        if err.object_.retry_count > highest_retry_count:
                            highest_retry_count = err.object_.retry_count

                        if err.object_.retry_count > 5:
                            continue  # too many retries, give up
                        err.object_.retry_count += 1
                        readded_objects.append(i)

                if len(readded_objects) > 0:
                    # set for the membership tests below, the list keeps the original order
                    readded_idx = set(readded_objects)
                    _Warnings.batch_rate_limit_reached(
                        response_obj.errors[readded_objects[0]].message,
                        self.__fix_rate_batching_base_time * (highest_retry_count + 1),
                    )

                    readd_objects = [
                        err.object_ for i, err in response_obj.errors.items() if i in readded_idx
                    ]
                    readded_uuids = {obj.uuid for obj in readd_objects}

                    self.__batch_objects.prepend(readd_objects)

                    # the re-queued objects are neither successes nor (final) failures
                    new_errors = {
                        i: err for i, err in response_obj.errors.items() if i not in readded_idx
                    }
                    response_obj = BatchObjectReturn(
                        uuids={
                            i: uid for i, uid in response_obj.uuids.items() if i not in readded_idx
                        },
                        errors=new_errors,
                        has_errors=len(new_errors) > 0,
                        _all_responses=[
                            err
                            for i, err in enumerate(response_obj.all_responses)
                            if i not in readded_idx
                        ],
                        elapsed_seconds=response_obj.elapsed_seconds,
                    )
                    if readd_rate_limit:
                        # for rate limited batching the timing is handled by the outer loop => no sleep here
                        self.__time_stamp_last_request = (
                            time.time()
                            + self.__fix_rate_batching_base_time * (highest_retry_count + 1)
                        )  # skip a full minute to recover from the rate limit
                        self.__fix_rate_batching_base_time += (
                            1  # increase the base time as the current one is too low
                        )
                    else:
                        # sleep a bit to recover from the rate limit in other cases
                        time.sleep(2**highest_retry_count)

                # objects that are done (not re-queued) no longer block references
                with self.__uuid_lookup_lock:
                    self.__uuid_lookup.difference_update(
                        str(obj.uuid) for obj in objs if obj.uuid not in readded_uuids
                    )

                if (n_obj_errs := len(response_obj.errors)) > 0 and self.__objs_logs_count < 30:
                    logger.error(
                        {
                            "message": f"Failed to send {n_obj_errs} objects in a batch of {n_objs}. Please inspect client.batch.failed_objects or collection.batch.failed_objects for the failed objects.",
                        }
                    )
                    self.__objs_logs_count += 1
                    if self.__objs_logs_count == 30:
                        # announce exactly once that the log limit has been reached
                        logger.error(
                            {
                                "message": "There have been more than 30 failed object batches. Further errors will not be logged.",
                            }
                        )
                with self.__results_lock:
                    self.__results_for_wrapper.results.objs += response_obj
                    self.__results_for_wrapper.failed_objects.extend(response_obj.errors.values())
                self.__took_queue.append(time.time() - start)

            if (n_refs := len(refs)) > 0:
                start = time.time()
                try:
                    response_ref = executor.result(
                        self.__batch_rest.references(
                            connection=self.__connection,
                            references=[ref._to_internal() for ref in refs],
                        )
                    )
                except Exception as e:
                    # the whole request failed: report every reference of the batch as failed
                    errors_ref = {
                        idx: ErrorReference(message=repr(e), reference=ref)
                        for idx, ref in enumerate(refs)
                    }
                    response_ref = BatchReferenceReturn(
                        elapsed_seconds=time.time() - start,
                        errors=errors_ref,
                        has_errors=True,
                    )
                if (n_ref_errs := len(response_ref.errors)) > 0 and self.__refs_logs_count < 30:
                    logger.error(
                        {
                            "message": f"Failed to send {n_ref_errs} references in a batch of {n_refs}. Please inspect client.batch.failed_references or collection.batch.failed_references for the failed references.",
                            "errors": response_ref.errors,
                        }
                    )
                    self.__refs_logs_count += 1
                    if self.__refs_logs_count == 30:
                        # announce exactly once that the log limit has been reached
                        logger.error(
                            {
                                "message": "There have been more than 30 failed reference batches. Further errors will not be logged.",
                            }
                        )
                with self.__results_lock:
                    self.__results_for_wrapper.results.refs += response_ref
                    self.__results_for_wrapper.failed_references.extend(
                        response_ref.errors.values()
                    )
        finally:
            # always release the request slot - otherwise an unexpected error in this
            # executor job would leave the counter stuck and `flush` waiting forever
            with self.__active_requests_lock:
                self.__active_requests -= 1

    def flush(self) -> None:
        """Flush the batch queue and wait for all requests to be finished.

        Raises:
            Exception: The exception that stopped the scheduler thread (or a generic
                one) if that thread is no longer alive.
        """
        # bg thread is sending objs+refs automatically, so simply wait for everything to be done
        while (
            self.__active_requests > 0
            or len(self.__batch_objects) > 0
            or len(self.__batch_references) > 0
        ):
            time.sleep(0.01)
            self.__check_bg_threads_alive()

    def _add_object(
        self,
        collection: str,
        properties: Optional[WeaviateProperties] = None,
        references: Optional[ReferenceInputs] = None,
        uuid: Optional[UUID] = None,
        vector: Optional[VECTORS] = None,
        tenant: Optional[str] = None,
    ) -> UUID:
        """Validate an object and queue it for sending.

        Blocks while the queue holds at least twice the recommended batch size or while
        sending is paused (recommended batch size of 0).

        Returns:
            The uuid of the queued object.

        Raises:
            WeaviateBatchValidationError: If the object does not pass validation.
        """
        self.__check_bg_threads_alive()
        try:
            batch_object = BatchObject(
                collection=collection,
                properties=properties,
                references=references,
                uuid=uuid,
                vector=vector,
                tenant=tenant,
                index=self.__objs_count,
            )
            self.__objs_count += 1
            self.__results_for_wrapper.imported_shards.add(
                Shard(collection=collection, tenant=tenant)
            )
        except ValidationError as e:
            raise WeaviateBatchValidationError(repr(e)) from e
        # same lock the sender threads use when they remove finished uuids
        with self.__uuid_lookup_lock:
            self.__uuid_lookup.add(str(batch_object.uuid))
        self.__batch_objects.add(batch_object)

        # block if queue gets too long or weaviate is overloaded - reading files is faster them sending them so we do
        # not need a long queue
        while (
            self.__recommended_num_objects == 0
            or len(self.__batch_objects) >= self.__recommended_num_objects * 2
        ):
            self.__check_bg_threads_alive()
            time.sleep(0.01)

        assert batch_object.uuid is not None
        return batch_object.uuid

    def _add_reference(
        self,
        from_object_uuid: UUID,
        from_object_collection: str,
        from_property_name: str,
        to: ReferenceInput,
        tenant: Optional[str] = None,
    ) -> None:
        """Validate one or more references from a single object and queue them for sending.

        `to` may be a single uuid, an iterable of uuids or a `ReferenceToMulti`; one
        queued reference is created per target uuid.

        Raises:
            WeaviateBatchValidationError: If a reference does not pass validation.
        """
        self.__check_bg_threads_alive()
        if isinstance(to, ReferenceToMulti):
            to_strs: Union[List[str], List[UUID]] = to.uuids_str
        elif isinstance(to, str) or isinstance(to, uuid_package.UUID):
            to_strs = [to]
        else:
            to_strs = list(to)

        for uid in to_strs:
            try:
                batch_reference = BatchReference(
                    from_object_collection=from_object_collection,
                    from_object_uuid=from_object_uuid,
                    from_property_name=from_property_name,
                    to_object_collection=(
                        to.target_collection if isinstance(to, ReferenceToMulti) else None
                    ),
                    to_object_uuid=uid,
                    tenant=tenant,
                    index=self.__refs_count,
                )
                self.__refs_count += 1
            except ValidationError as e:
                raise WeaviateBatchValidationError(repr(e)) from e
            self.__batch_references.add(batch_reference)

            # block if queue gets too long or weaviate is overloaded
            while self.__recommended_num_objects == 0:
                time.sleep(0.01)  # block if weaviate is overloaded, also do not send any refs
                self.__check_bg_threads_alive()

    def __check_bg_threads_alive(self) -> None:
        """Raise the stored background exception (or a generic one) if the scheduler thread died."""
        if self.__bg_threads.is_alive():
            return

        raise self.__bg_thread_exception or Exception("Batch thread died unexpectedly")
[docs]
class _BgThreads:
    """Pair of background threads (`loop` and `recv`) that are started lazily and watched together."""

    def __init__(self, loop: threading.Thread, recv: threading.Thread):
        self.loop = loop
        self.recv = recv
        self.__started_recv = False
        self.__started_loop = False

    def start_recv(self) -> None:
        """Start the recv thread; repeated calls are no-ops."""
        if not self.__started_recv:
            self.recv.start()
            self.__started_recv = True

    def start_loop(self) -> None:
        """Start the loop thread; repeated calls are no-ops."""
        if not self.__started_loop:
            self.loop.start()
            self.__started_loop = True

    def is_alive(self) -> bool:
        """Check if the background threads are still alive."""
        return self.loop_alive() and self.recv_alive()

    def loop_alive(self) -> bool:
        """Check if the loop background thread is still alive."""
        if self.__started_loop:
            return self.loop.is_alive()
        return True  # not started yet so considered alive

    def recv_alive(self) -> bool:
        """Check if the recv background thread is still alive."""
        if self.__started_recv:
            return self.recv.is_alive()
        return True  # not started yet so considered alive

    def join(self) -> None:
        """Join the background threads that have been started.

        `Thread.join` raises `RuntimeError` for a thread that was never started, so
        threads without an `ident` (i.e. not started yet) are skipped.
        """
        for thread in (self.loop, self.recv):
            if thread.ident is not None:
                thread.join()
[docs]
class _ClusterBatch:
    """Synchronous helper that reads the node status used by the batching logic."""

    def __init__(self, connection: ConnectionSync):
        self._connection = connection

    def get_nodes_status(
        self,
    ) -> List[Node]:
        """Fetch the `/nodes` endpoint and return its `nodes` entry.

        Returns:
            The list of nodes, or an empty list when the request fails or the response
            has no `nodes` entry.
        """
        try:
            response = executor.result(self._connection.get(path="/nodes"))
        except Exception:
            # best effort: a failed request is reported as "no nodes"
            return []

        payload = _decode_json_response_dict(response, "Nodes status")
        assert payload is not None

        node_list = payload.get("nodes")
        return [] if node_list is None else cast(List[Node], node_list)

    def get_number_of_nodes(self) -> int:
        """Return how many nodes `get_nodes_status` reports."""
        nodes = self.get_nodes_status()
        return len(nodes)
[docs]
class _ClusterBatchAsync:
    """Asynchronous helper that reads the node status used by the batching logic."""

    def __init__(self, connection: ConnectionAsync):
        self._connection = connection

    async def get_nodes_status(
        self,
    ) -> List[Node]:
        """Fetch the `/nodes` endpoint and return its `nodes` entry.

        Returns:
            The list of nodes, or an empty list when the request itself fails.

        Raises:
            EmptyResponseException: If the response has no `nodes` entry or an empty one.
        """
        try:
            response = await executor.aresult(self._connection.get(path="/nodes"))
        except Exception:
            # best effort: a failed request is reported as "no nodes"
            return []

        payload = _decode_json_response_dict(response, "Nodes status")
        assert payload is not None

        node_list = payload.get("nodes")
        if node_list is not None and node_list != []:
            return cast(List[Node], node_list)
        raise EmptyResponseException("Nodes status response returned empty")

    async def get_number_of_nodes(self) -> int:
        """Return how many nodes `get_nodes_status` reports."""
        nodes = await self.get_nodes_status()
        return len(nodes)