"""
ConfigBuilder class definition.
"""
import time
from typing import Dict, Any, cast, TYPE_CHECKING
from requests.exceptions import ConnectionError as RequestsConnectionError
from weaviate.connect import Connection
from weaviate.exceptions import UnexpectedStatusCodeException
from weaviate.util import _capitalize_first_letter, _decode_json_response_dict
# 'Classification' is only referenced in (string) type annotations, so it is
# imported for static type checkers only and never at runtime.
# NOTE(review): presumably this guard also avoids a circular import with
# '.classification' -- confirm against that module.
if TYPE_CHECKING:
    from .classification import Classification
class ConfigBuilder:
    """
    ConfigBuilder class that is used to configure a classification process.

    The 'with_*' methods each record one part of the configuration and return
    the same instance so that calls can be chained; 'do' validates the
    collected configuration and submits it.
    """

    def __init__(self, connection: "Connection", classification: "Classification"):
        """
        Initialize a ConfigBuilder class instance.

        Parameters
        ----------
        connection : weaviate.connect.Connection
            Connection object to an active and running weaviate instance.
        classification : weaviate.classification.Classification
            Classification object to be configured using this ConfigBuilder
            instance.
        """

        self._connection = connection
        self._classification = classification
        # Payload that is sent as the request body when the classification starts.
        self._config: Dict[str, Any] = {}
        # When True, 'do' blocks until the classification is no longer running.
        self._wait_for_completion = False

    def with_type(self, classification_type: str) -> "ConfigBuilder":
        """
        Set classification type.

        Parameters
        ----------
        classification_type : str
            Type of the desired classification.

        Returns
        -------
        ConfigBuilder
            Updated ConfigBuilder.
        """

        self._config["type"] = classification_type
        return self

    def with_k(self, k: int) -> "ConfigBuilder":
        """
        Set k number for the kNN.

        Parameters
        ----------
        k : int
            Number of objects to use to make a classification guess.
            (For kNN)

        Returns
        -------
        ConfigBuilder
            Updated ConfigBuilder.
        """

        # Keep any settings that were configured before; only 'k' is (re)written.
        self._config.setdefault("settings", {})["k"] = k
        return self

    def with_class_name(self, class_name: str) -> "ConfigBuilder":
        """
        What Object type to classify.

        Parameters
        ----------
        class_name : str
            Name of the class to be classified. It is stored after being passed
            through '_capitalize_first_letter'.

        Returns
        -------
        ConfigBuilder
            Updated ConfigBuilder.
        """

        self._config["class"] = _capitalize_first_letter(class_name)
        return self

    def with_classify_properties(self, classify_properties: list) -> "ConfigBuilder":
        """
        Set the classify properties.

        Parameters
        ----------
        classify_properties: list
            A list of properties to classify.

        Returns
        -------
        ConfigBuilder
            Updated ConfigBuilder.
        """

        self._config["classifyProperties"] = classify_properties
        return self

    def with_based_on_properties(self, based_on_properties: list) -> "ConfigBuilder":
        """
        Set properties to build the classification on.

        Parameters
        ----------
        based_on_properties: list
            A list of properties to classify on.

        Returns
        -------
        ConfigBuilder
            Updated ConfigBuilder.
        """

        self._config["basedOnProperties"] = based_on_properties
        return self

    def with_source_where_filter(self, where_filter: dict) -> "ConfigBuilder":
        """
        Set Source 'where' Filter.

        Parameters
        ----------
        where_filter : dict
            Filter to use, as a dict.

        Returns
        -------
        ConfigBuilder
            Updated ConfigBuilder.
        """

        # The three 'where' filters share the "filters" entry of the config.
        self._config.setdefault("filters", {})["sourceWhere"] = where_filter
        return self

    def with_training_set_where_filter(self, where_filter: dict) -> "ConfigBuilder":
        """
        Set Training set 'where' Filter.

        Parameters
        ----------
        where_filter : dict
            Filter to use, as a dict.

        Returns
        -------
        ConfigBuilder
            Updated ConfigBuilder.
        """

        # The three 'where' filters share the "filters" entry of the config.
        self._config.setdefault("filters", {})["trainingSetWhere"] = where_filter
        return self

    def with_target_where_filter(self, where_filter: dict) -> "ConfigBuilder":
        """
        Set Target 'where' Filter.

        Parameters
        ----------
        where_filter : dict
            Filter to use, as a dict.

        Returns
        -------
        ConfigBuilder
            Updated ConfigBuilder.
        """

        # The three 'where' filters share the "filters" entry of the config.
        self._config.setdefault("filters", {})["targetWhere"] = where_filter
        return self

    def with_wait_for_completion(self) -> "ConfigBuilder":
        """
        Wait for completion.

        Makes 'do' block until the started classification is no longer
        reported as running.

        Returns
        -------
        ConfigBuilder
            Updated ConfigBuilder.
        """

        self._wait_for_completion = True
        return self

    def with_settings(self, settings: dict) -> "ConfigBuilder":
        """
        Set settings for the classification. NOTE if you are using 'kNN'
        the value 'k' can be set by this method or by 'with_k'.
        This method keeps previously set 'settings'.

        The passed dict is never modified: its entries are copied into the
        configuration.

        Parameters
        ----------
        settings: dict
            Additional settings to be set/overwritten.

        Returns
        -------
        ConfigBuilder
            Updated ConfigBuilder.
        """

        if "settings" not in self._config:
            # Store a shallow copy so that later 'with_k'/'with_settings' calls
            # do not mutate the caller's dict. Non-dict values are stored as
            # they are, so that '_validate_config' still rejects them with a
            # TypeError when the classification is started.
            self._config["settings"] = (
                dict(settings) if isinstance(settings, dict) else settings
            )
        else:
            for key in settings:
                self._config["settings"][key] = settings[key]
        return self

    def _validate_config(self) -> None:
        """
        Validate the current classification configuration.

        Raises
        ------
        ValueError
            If a mandatory field is not set, or if the type is "knn" and
            'k' is missing from the settings.
        TypeError
            If "settings" is set but is not a dict.
        """

        required_fields = ["type", "class", "basedOnProperties", "classifyProperties"]
        for field in required_fields:
            if field not in self._config:
                raise ValueError(f"{field} is not set for this classification")

        if "settings" in self._config:
            if not isinstance(self._config["settings"], dict):
                raise TypeError('"settings" should be of type dict')

        if self._config["type"] == "knn":
            # At this point "settings" is either absent or a dict.
            if "k" not in self._config.get("settings", {}):
                raise ValueError("k is not set for this classification")

    def _start(self) -> dict:
        """
        Start the classification based on the configuration set.

        The configuration is sent with a POST request to '/classifications'.

        Returns
        -------
        dict
            Classification result.

        Raises
        ------
        requests.ConnectionError
            If the network connection to weaviate fails.
        weaviate.UnexpectedStatusCodeException
            Unexpected error.
        """

        try:
            response = self._connection.post(
                path="/classifications", weaviate_object=self._config
            )
        except RequestsConnectionError as conn_err:
            raise RequestsConnectionError("Classification may not started.") from conn_err

        if response.status_code == 201:
            res = _decode_json_response_dict(response, "Start classification")
            # NOTE(review): presumably the helper's return type is Optional and
            # this assert only narrows it for type checkers -- confirm.
            assert res is not None
            return res
        raise UnexpectedStatusCodeException("Start classification", response)

    def do(self) -> dict:
        """
        Start the classification.

        The configuration is validated first. If 'with_wait_for_completion'
        was called, this method polls the classification every 2 seconds until
        it is no longer running and then returns its current state; otherwise
        the response of the start request is returned right away.

        Returns
        -------
        dict
            Classification result.

        Raises
        ------
        ValueError
            If a mandatory field of the configuration is not set.
        TypeError
            If "settings" is set but is not a dict.
        requests.ConnectionError
            If the network connection to weaviate fails.
        weaviate.UnexpectedStatusCodeException
            Unexpected error.
        """

        self._validate_config()

        response = self._start()
        if not self._wait_for_completion:
            return response

        # wait for completion
        classification_uuid = response["id"]
        while self._classification.is_running(classification_uuid):
            time.sleep(2.0)
        return cast(dict, self._classification.get(classification_uuid))