from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Coroutine, Generic, Literal, Optional, TypeVar, Union
from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import TypedDict
from weaviate.outputs.query import QueryReturn
from weaviate_agents.classes.core import Usage
from weaviate_agents.utils import print_ask_mode_response, print_query_agent_response
[docs]
class ComparisonOperator(str, Enum):
    """String-valued enum of the comparison operators a property filter may carry."""

    # Symbolic comparison operators.
    EQUALS = "="
    LESS_THAN = "<"
    GREATER_THAN = ">"
    LESS_EQUAL = "<="
    GREATER_EQUAL = ">="
    NOT_EQUALS = "!="

    # Keyword-style operators.
    LIKE = "LIKE"
    CONTAINS_ANY = "contains_any"
    CONTAINS_ALL = "contains_all"
class KnownFilterType(str, Enum):
    """String-valued enum of the ``filter_type`` tags this module recognises."""

    INTEGER = "integer"
    INTEGER_ARRAY = "integer_array"

    TEXT = "text"
    TEXT_ARRAY = "text_array"

    BOOLEAN = "boolean"
    BOOLEAN_ARRAY = "boolean_array"

    # Note: the scalar date member's value is "date_range", not "date".
    DATE = "date_range"
    DATE_ARRAY = "date_array"

    GEO = "geo"
    IS_NULL = "is_null"
class KnownPropertyFilterBase(BaseModel):
    """Shared base for filters with a recognised type: a type tag plus a property name."""

    # Tag identifying the concrete filter; subclasses narrow it to one Literal.
    filter_type: KnownFilterType
    # Name of the property the filter refers to.
    property_name: str
[docs]
class IntegerPropertyFilter(KnownPropertyFilterBase):
    """Compare a numeric property against a single number using an operator."""

    # Tag pinned to the INTEGER filter type; left out of repr().
    filter_type: Literal[KnownFilterType.INTEGER] = Field(
        default=KnownFilterType.INTEGER,
        repr=False,
    )
    operator: ComparisonOperator
    value: float
[docs]
class IntegerArrayPropertyFilter(KnownPropertyFilterBase):
    """Compare a numeric-array property against a list of numbers using an operator."""

    # Tag pinned to the INTEGER_ARRAY filter type; left out of repr().
    filter_type: Literal[KnownFilterType.INTEGER_ARRAY] = Field(
        default=KnownFilterType.INTEGER_ARRAY,
        repr=False,
    )
    operator: ComparisonOperator
    value: list[float]
[docs]
class TextPropertyFilter(KnownPropertyFilterBase):
    """Compare a text property against a single string using an operator."""

    # Tag pinned to the TEXT filter type; left out of repr().
    filter_type: Literal[KnownFilterType.TEXT] = Field(
        default=KnownFilterType.TEXT,
        repr=False,
    )
    operator: ComparisonOperator
    value: str
[docs]
class TextArrayPropertyFilter(KnownPropertyFilterBase):
    """Compare a text-array property against a list of strings using an operator."""

    # Tag pinned to the TEXT_ARRAY filter type; left out of repr().
    filter_type: Literal[KnownFilterType.TEXT_ARRAY] = Field(
        default=KnownFilterType.TEXT_ARRAY,
        repr=False,
    )
    operator: ComparisonOperator
    value: list[str]
[docs]
class BooleanPropertyFilter(KnownPropertyFilterBase):
    """Compare a boolean property against a single boolean using an operator."""

    # Tag pinned to the BOOLEAN filter type; left out of repr().
    filter_type: Literal[KnownFilterType.BOOLEAN] = Field(
        default=KnownFilterType.BOOLEAN,
        repr=False,
    )
    operator: ComparisonOperator
    value: bool
[docs]
class BooleanArrayPropertyFilter(KnownPropertyFilterBase):
    """Compare a boolean-array property against a list of booleans using an operator."""

    # Tag pinned to the BOOLEAN_ARRAY filter type; left out of repr().
    filter_type: Literal[KnownFilterType.BOOLEAN_ARRAY] = Field(
        default=KnownFilterType.BOOLEAN_ARRAY,
        repr=False,
    )
    operator: ComparisonOperator
    value: list[bool]
class DateExact(BaseModel):
    """Date filter value: one timestamp string paired with a comparison operator."""

    # Timestamp kept as a plain string; no format is enforced by this model.
    exact_timestamp: str
    operator: ComparisonOperator
class DateRangeFrom(BaseModel):
    """Date filter value with only a lower bound."""

    # Lower bound, kept as a plain string.
    date_from: str
    # Flag for the lower bound (presumably: whether the bound itself matches).
    inclusive_from: bool
class DateRangeTo(BaseModel):
    """Date filter value with only an upper bound."""

    # Upper bound, kept as a plain string.
    date_to: str
    # Flag for the upper bound (presumably: whether the bound itself matches).
    inclusive_to: bool
class DateRangeBetween(BaseModel):
    """Date filter value with both a lower and an upper bound."""

    # Bounds are kept as plain strings.
    date_from: str
    date_to: str
    # One inclusivity flag per bound.
    inclusive_from: bool
    inclusive_to: bool
[docs]
class DatePropertyFilter(KnownPropertyFilterBase):
    """Filter a datetime property by an exact timestamp or by a date range.

    Unlike the other typed filters there is no top-level ``operator``; the
    comparison details live inside ``value``.
    """

    # Tag pinned to the DATE filter type; left out of repr().
    filter_type: Literal[KnownFilterType.DATE] = Field(
        default=KnownFilterType.DATE,
        repr=False,
    )
    # Exactly one of: exact match, open-ended from/to range, or closed range.
    value: Union[DateExact, DateRangeFrom, DateRangeTo, DateRangeBetween]
[docs]
class DateArrayPropertyFilter(KnownPropertyFilterBase):
    """Compare a datetime-array property against a list of date strings using an operator."""

    # Tag pinned to the DATE_ARRAY filter type; left out of repr().
    filter_type: Literal[KnownFilterType.DATE_ARRAY] = Field(
        default=KnownFilterType.DATE_ARRAY,
        repr=False,
    )
    operator: ComparisonOperator
    # Dates are kept as plain strings.
    value: list[str]
[docs]
class GeoPropertyFilter(KnownPropertyFilterBase):
    """Filter a geo-coordinates property by a point and a maximum distance."""

    # Tag pinned to the GEO filter type; left out of repr().
    filter_type: Literal[KnownFilterType.GEO] = Field(
        default=KnownFilterType.GEO,
        repr=False,
    )
    # Reference point.
    latitude: float
    longitude: float
    # Distance limit, in metres per the field name.
    max_distance_meters: float
[docs]
class IsNullPropertyFilter(KnownPropertyFilterBase):
    """Filter on whether a property is null."""

    # Tag pinned to the IS_NULL filter type; left out of repr().
    filter_type: Literal[KnownFilterType.IS_NULL] = Field(
        default=KnownFilterType.IS_NULL,
        repr=False,
    )
    # The null state to filter on.
    is_null: bool
[docs]
class UnknownPropertyFilter(BaseModel):
    """Catch-all filter for unknown filter types, to preserve future back-compatibility.

    Extra fields are accepted (``extra="allow"``), so the payload of a filter
    this package version does not know is kept on the instance. The declared
    ``filter_type`` field is always stored as ``None``; a ``filter_type`` that
    matches a ``KnownFilterType`` is rejected so that a malformed *known*
    filter is not silently swallowed by this fallback.
    """

    model_config = ConfigDict(extra="allow")

    filter_type: None

    @field_validator("filter_type", mode="before")
    @classmethod
    def ensure_filter_type_unknown(cls, value: Any) -> None:
        """Reject known filter types and normalise every other value to ``None``.

        Args:
            value: The raw ``filter_type`` received, of any type.

        Returns:
            Always ``None`` when the value is not a known filter type.

        Raises:
            ValueError: If ``value`` equals one of the ``KnownFilterType`` members.
        """
        # Compare by equality instead of set membership: hashing an unhashable
        # payload (e.g. a list or dict) would raise TypeError, which escapes
        # pydantic validation instead of being treated as "unknown".
        if any(value == known for known in KnownFilterType):
            raise ValueError(
                f"{value} is an known filter type, but validation failed, "
                "so the response was not as expected. "
                "Try upgrading the weaviate-agents package to a new version."
            )
        return None

    def model_post_init(self, context: Any) -> None:
        """Warn that a filter with an unrecognised type was constructed.

        Args:
            context: Passed in by pydantic; unused here.
        """
        # NOTE(review): the validator above has already replaced filter_type
        # with None, so this message always reports "None" rather than the
        # value that was received.
        warnings.warn(
            f"The filter_type {self.filter_type} wasn't recognised. "
            "Try upgrading the weaviate-agents package to a new version."
        )
# Any property filter: the typed filters first, with UnknownPropertyFilter
# last as the catch-all. Member order is kept exactly as declared.
PropertyFilter = Union[
    IntegerPropertyFilter,
    IntegerArrayPropertyFilter,
    TextPropertyFilter,
    TextArrayPropertyFilter,
    BooleanPropertyFilter,
    BooleanArrayPropertyFilter,
    DatePropertyFilter,
    DateArrayPropertyFilter,
    GeoPropertyFilter,
    IsNullPropertyFilter,
    UnknownPropertyFilter,
]
[docs]
class QueryResult(BaseModel):
    """Search description: query strings, nested filter lists and how filters combine."""

    # Each entry is a query string or None.
    queries: list[Union[str, None]]
    # A list of filter lists; defaults to empty.
    filters: list[list[PropertyFilter]] = []
    # Either "AND" or "OR".
    filter_operators: Literal["AND", "OR"]
[docs]
class NumericMetrics(str, Enum):
    """String-valued enum of the metrics available for numeric aggregations."""

    COUNT = "COUNT"
    # MAX/MIN member names map to the longer "MAXIMUM"/"MINIMUM" values.
    MAX = "MAXIMUM"
    MEAN = "MEAN"
    MEDIAN = "MEDIAN"
    MIN = "MINIMUM"
    MODE = "MODE"
    SUM = "SUM"
    TYPE = "TYPE"
[docs]
class TextMetrics(str, Enum):
    """String-valued enum of the metrics available for text aggregations."""

    COUNT = "COUNT"
    TYPE = "TYPE"
    TOP_OCCURRENCES = "TOP_OCCURRENCES"
[docs]
class BooleanMetrics(str, Enum):
    """String-valued enum of the metrics available for boolean aggregations."""

    COUNT = "COUNT"
    TYPE = "TYPE"

    # Totals per truth value.
    TOTAL_TRUE = "TOTAL_TRUE"
    TOTAL_FALSE = "TOTAL_FALSE"

    # Percentages per truth value.
    PERCENTAGE_TRUE = "PERCENTAGE_TRUE"
    PERCENTAGE_FALSE = "PERCENTAGE_FALSE"
[docs]
class DateMetrics(str, Enum):
    """String-valued enum of the metrics available for date aggregations."""

    COUNT = "COUNT"
    # MAX/MIN member names map to the longer "MAXIMUM"/"MINIMUM" values.
    MAX = "MAXIMUM"
    MEDIAN = "MEDIAN"
    MIN = "MINIMUM"
    MODE = "MODE"
class KnownAggregationType(str, Enum):
    """String-valued enum of the ``aggregation_type`` tags this module recognises."""

    INTEGER = "integer"
    TEXT = "text"
    BOOLEAN = "boolean"
    DATE = "date"
class KnownPropertyAggregationBase(BaseModel):
    """Shared base for aggregations with a recognised type: a type tag plus a property name."""

    # Tag identifying the concrete aggregation; subclasses narrow it to one Literal.
    aggregation_type: KnownAggregationType
    # Name of the property the aggregation refers to.
    property_name: str
[docs]
class IntegerPropertyAggregation(KnownPropertyAggregationBase):
    """Aggregate a numeric property with one ``NumericMetrics`` metric."""

    # Tag pinned to the INTEGER aggregation type; left out of repr().
    aggregation_type: Literal[KnownAggregationType.INTEGER] = Field(
        default=KnownAggregationType.INTEGER,
        repr=False,
    )
    metrics: NumericMetrics
[docs]
class TextPropertyAggregation(KnownPropertyAggregationBase):
    """Aggregate a text property with one ``TextMetrics`` metric."""

    # Tag pinned to the TEXT aggregation type; left out of repr().
    aggregation_type: Literal[KnownAggregationType.TEXT] = Field(
        default=KnownAggregationType.TEXT,
        repr=False,
    )
    metrics: TextMetrics
    # Optional limit; presumably only relevant to the TOP_OCCURRENCES metric.
    top_occurrences_limit: Optional[int] = None
[docs]
class BooleanPropertyAggregation(KnownPropertyAggregationBase):
    """Aggregate a boolean property with one ``BooleanMetrics`` metric."""

    # Tag pinned to the BOOLEAN aggregation type; left out of repr().
    aggregation_type: Literal[KnownAggregationType.BOOLEAN] = Field(
        default=KnownAggregationType.BOOLEAN,
        repr=False,
    )
    metrics: BooleanMetrics
[docs]
class DatePropertyAggregation(KnownPropertyAggregationBase):
    """Aggregate a datetime property with one ``DateMetrics`` metric."""

    # Tag pinned to the DATE aggregation type; left out of repr().
    aggregation_type: Literal[KnownAggregationType.DATE] = Field(
        default=KnownAggregationType.DATE,
        repr=False,
    )
    metrics: DateMetrics
[docs]
class UnknownPropertyAggregation(BaseModel):
    """Catch-all aggregation for unknown aggregation types, to preserve future back-compatibility.

    Extra fields are accepted (``extra="allow"``), so the payload of an
    aggregation this package version does not know is kept on the instance.
    The declared ``aggregation_type`` field is always stored as ``None``; an
    ``aggregation_type`` that matches a ``KnownAggregationType`` is rejected so
    that a malformed *known* aggregation is not silently swallowed by this
    fallback.
    """

    model_config = ConfigDict(extra="allow")

    aggregation_type: None

    # The method name mentions "filter_type" but it validates aggregation_type;
    # the name is kept unchanged for compatibility.
    @field_validator("aggregation_type", mode="before")
    @classmethod
    def ensure_filter_type_unknown(cls, value: Any) -> None:
        """Reject known aggregation types and normalise every other value to ``None``.

        Args:
            value: The raw ``aggregation_type`` received, of any type.

        Returns:
            Always ``None`` when the value is not a known aggregation type.

        Raises:
            ValueError: If ``value`` equals one of the ``KnownAggregationType`` members.
        """
        # Compare by equality instead of set membership: hashing an unhashable
        # payload (e.g. a list or dict) would raise TypeError, which escapes
        # pydantic validation instead of being treated as "unknown".
        if any(value == known for known in KnownAggregationType):
            raise ValueError(
                f"{value} is an known aggregation type, but validation failed, "
                "so the response was not as expected. "
                "Try upgrading the weaviate-agents package to a new version."
            )
        return None

    def model_post_init(self, context: Any) -> None:
        """Warn that an aggregation with an unrecognised type was constructed.

        Args:
            context: Passed in by pydantic; unused here.
        """
        # NOTE(review): the validator above has already replaced
        # aggregation_type with None, so this message always reports "None"
        # rather than the value that was received.
        warnings.warn(
            f"The aggregation_type {self.aggregation_type} wasn't recognised. "
            "Try upgrading the weaviate-agents package to a new version."
        )
# Any property aggregation: the typed aggregations first, with
# UnknownPropertyAggregation last as the catch-all. Member order is kept
# exactly as declared.
PropertyAggregation = Union[
    IntegerPropertyAggregation,
    TextPropertyAggregation,
    BooleanPropertyAggregation,
    DatePropertyAggregation,
    UnknownPropertyAggregation,
]
[docs]
class AggregationResult(BaseModel):
    """Aggregations to perform on a collection in a vector database.

    They are meant to be derived from the original user query and may combine
    several aggregations over different properties and metrics.
    """

    # Optional search query and group-by property; both default to None.
    search_query: Optional[str] = None
    groupby_property: Optional[str] = None
    # Required list of aggregations.
    aggregations: list[PropertyAggregation]
    # Flat list of filters; defaults to empty.
    filters: list[PropertyFilter] = []
[docs]
class AggregationResultWithCollection(AggregationResult):
    """An ``AggregationResult`` extended with a collection name."""

    collection: str
[docs]
class QueryResultWithCollection(QueryResult):
    """A ``QueryResult`` extended with a collection name."""

    collection: str
[docs]
class Source(BaseModel):
    """Pairs an object id with a collection name."""

    object_id: str
    collection: str
[docs]
class QueryAgentResponse(BaseModel):
    """Final-state response model: searches, aggregations, usage, answer and sources."""

    # Tag fixed to "final_state".
    output_type: Literal["final_state"] = "final_state"

    original_query: str
    collection_names: list[str]

    # Nested lists of search and aggregation descriptions.
    searches: list[list[QueryResultWithCollection]]
    aggregations: list[list[AggregationResultWithCollection]]

    usage: Usage
    total_time: float

    is_partial_answer: bool
    missing_information: list[str]
    final_answer: str
    sources: list[Source]

    def display(self) -> None:
        """Pretty-print a summary of this response via ``print_query_agent_response``."""
        print_query_agent_response(self)
[docs]
class ModelUnitUsage(BaseModel):
    """Usage figures: model units, an in-plan flag and remaining plan requests."""

    model_units: int
    usage_in_plan: bool
    remaining_plan_requests: int
[docs]
class FilterAndOr(BaseModel):
    """Recursive filter group: child filters combined with "AND" or "OR"."""

    # How the children are combined.
    combine: Literal["AND", "OR"]
    # Children are plain property filters or nested groups of this same type.
    filters: list[Union[PropertyFilter, FilterAndOr]]
[docs]
class QuerySort(BaseModel):
    """Sort specification on one property, with an optional nested tie-breaker."""

    property_name: str
    order: Literal["ascending", "descending"]
    # Required field; either another QuerySort or None.
    tie_break: Union[QuerySort, None]
[docs]
class QueryResultWithCollectionNormalized(BaseModel):
    """Single search against one collection: query, filter tree and optional sort."""

    # Query string, or None.
    query: Union[str, None]
    # A single filter, a combined AND/OR group, or None.
    filters: Union[PropertyFilter, FilterAndOr, None]
    collection: str
    # Optional sort specification; defaults to None.
    sort_property: Union[QuerySort, None] = None
[docs]
class AggregationResultWithCollectionNormalized(BaseModel):
    """Single aggregation against one collection, with optional group-by and filters."""

    # Group-by property name, or None.
    groupby_property: Union[str, None]
    # Exactly one aggregation.
    aggregation: PropertyAggregation
    # A single filter, a combined AND/OR group, or None.
    filters: Union[PropertyFilter, FilterAndOr, None]
    collection: str
[docs]
class AskModeResponse(BaseModel):
    """Final-state ask-mode response: normalized searches/aggregations, usage and answer."""

    # Tag fixed to "final_state".
    output_type: Literal["final_state"] = "final_state"

    searches: list[QueryResultWithCollectionNormalized]
    aggregations: list[AggregationResultWithCollectionNormalized]

    usage: ModelUnitUsage
    total_time: float

    # These fields are required but may be None.
    is_partial_answer: Union[bool, None]
    missing_information: Union[list[str], None]

    final_answer: str
    # Required but may be None.
    sources: Union[list[Source], None]

    def display(self) -> None:
        """Pretty-print a summary of this response via ``print_ask_mode_response``."""
        print_ask_mode_response(self)
[docs]
class ResearchModeResponse(BaseModel):
    """Final-state research-mode response: an answer plus a list of ask-mode responses."""

    # Tag fixed to "final_state".
    output_type: Literal["final_state"] = "final_state"
    final_answer: str
    usage: ModelUnitUsage
    # Each entry is a full AskModeResponse.
    queries: list[AskModeResponse]
    total_time: float
[docs]
class QueryWithCollection(TypedDict):
    """Typed dict with a ``query`` string and a ``collection`` name."""

    query: str
    collection: str
[docs]
class ProgressDetails(TypedDict, total=False):
    """Typed dict of optional progress details (``total=False``: every key may be absent)."""

    queries: list[QueryWithCollection]
[docs]
class ProgressMessage(BaseModel):
    """Progress update: a stage name, a message and optional details."""

    # Tag fixed to "progress_message".
    output_type: Literal["progress_message"] = "progress_message"
    stage: str
    message: str
    # Defaults to an empty dict.
    details: ProgressDetails = {}
[docs]
class StreamedTokens(BaseModel):
    """Streamed chunk tagged "streamed_tokens", carrying a text ``delta``."""

    # Tag fixed to "streamed_tokens".
    output_type: Literal["streamed_tokens"] = "streamed_tokens"
    delta: str
[docs]
class StreamedThoughts(BaseModel):
    """Streamed chunk tagged "streamed_thoughts", carrying a text ``delta``."""

    # Tag fixed to "streamed_thoughts".
    output_type: Literal["streamed_thoughts"] = "streamed_thoughts"
    delta: str
# This TypeVar is a workaround for not being able to use Self to
# type hint the next() abstract method (so that types on subclasses
# are properly represented), because Python versions < 3.11 are supported.
# Suggested in https://peps.python.org/pep-0673/
SearchModeResponseT = TypeVar("SearchModeResponseT", bound="SearchModeResponseBase")
# Unbounded type parameter for the searcher held by SearchModeResponseBase.
SearcherT = TypeVar("SearcherT")
[docs]
class SearchModeResponseBase(BaseModel, ABC, Generic[SearcherT]):
    """Abstract pydantic base for search-mode responses, generic over the searcher type.

    Concrete subclasses must implement ``next``.
    """

    # Optional list of normalized searches; defaults to None.
    searches: Optional[list[QueryResultWithCollectionNormalized]] = None
    usage: ModelUnitUsage
    total_time: float
    search_results: QueryReturn

    # Underscore-prefixed, so pydantic treats it as a private attribute rather
    # than a model field; its type is the class's generic parameter.
    _searcher: SearcherT

    @abstractmethod
    def next(
        self: SearchModeResponseT,
        limit: int = 20,
        offset: int = 0,
    ) -> Union[SearchModeResponseT, Coroutine[Any, Any, SearchModeResponseT]]:
        """Abstract hook to be implemented by subclasses.

        Per the annotation, implementations return a response of the same type
        as ``self``, either directly or as a coroutine resolving to one.
        Presumably ``limit`` and ``offset`` select the next page of results;
        confirm against the concrete subclasses.
        """
        ...