from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
Any,
Coroutine,
Generic,
Literal,
Optional,
Protocol,
TypeVar,
Union,
)
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import TypedDict
from weaviate.outputs.query import QueryReturn
from weaviate_agents.classes.core import Usage
from weaviate_agents.utils import print_ask_mode_response, print_query_agent_response
[docs]
class ComparisonOperator(str, Enum):
    """Operators that a property filter may carry.

    Members subclass ``str``, so each one compares equal to its raw value.
    """

    # Symbolic operators.
    EQUALS = "="
    LESS_THAN = "<"
    GREATER_THAN = ">"
    LESS_EQUAL = "<="
    GREATER_EQUAL = ">="
    NOT_EQUALS = "!="

    # Keyword operators.
    LIKE = "LIKE"
    CONTAINS_ANY = "contains_any"
    CONTAINS_ALL = "contains_all"
class KnownFilterType(str, Enum):
    """Tags for the filter kinds this module has a dedicated model for.

    Members subclass ``str``, so each one compares equal to its raw value.
    """

    # Numeric.
    INTEGER = "integer"
    INTEGER_ARRAY = "integer_array"

    # Text.
    TEXT = "text"
    TEXT_ARRAY = "text_array"

    # Boolean.
    BOOLEAN = "boolean"
    BOOLEAN_ARRAY = "boolean_array"

    # Date. Note that the scalar tag's value is "date_range", not "date".
    DATE = "date_range"
    DATE_ARRAY = "date_array"

    GEO = "geo"
    IS_NULL = "is_null"

    # UUID.
    UUID = "uuid"
    UUID_ARRAY = "uuid_array"
class KnownPropertyFilterBase(BaseModel):
    """Fields common to every filter whose ``filter_type`` is a known tag."""

    # Tag naming the filter kind; the subclasses in this module pin it to a
    # single literal member.
    filter_type: KnownFilterType

    # Name of the property the filter refers to.
    property_name: str
[docs]
class IntegerPropertyFilter(KnownPropertyFilterBase):
    """Numeric property filter: a comparison operator plus a single number."""

    # Fixed tag for this filter kind; excluded from ``repr`` output.
    filter_type: Literal[KnownFilterType.INTEGER] = Field(
        default=KnownFilterType.INTEGER,
        repr=False,
    )

    operator: ComparisonOperator

    # Typed as ``float`` even though the tag is "integer".
    value: float
[docs]
class IntegerArrayPropertyFilter(KnownPropertyFilterBase):
    """Numeric-array property filter: a comparison operator plus a list of numbers."""

    # Fixed tag for this filter kind; excluded from ``repr`` output.
    filter_type: Literal[KnownFilterType.INTEGER_ARRAY] = Field(
        default=KnownFilterType.INTEGER_ARRAY,
        repr=False,
    )

    operator: ComparisonOperator

    # Elements are typed as ``float`` even though the tag is "integer_array".
    value: list[float]
[docs]
class TextPropertyFilter(KnownPropertyFilterBase):
    """Text property filter: an operator (e.g. equality or LIKE) plus one string."""

    # Fixed tag for this filter kind; excluded from ``repr`` output.
    filter_type: Literal[KnownFilterType.TEXT] = Field(
        default=KnownFilterType.TEXT,
        repr=False,
    )

    # Any ``ComparisonOperator`` member is accepted by the model itself.
    operator: ComparisonOperator

    value: str
[docs]
class TextArrayPropertyFilter(KnownPropertyFilterBase):
    """Text-array property filter: an operator (e.g. equality or LIKE) plus strings."""

    # Fixed tag for this filter kind; excluded from ``repr`` output.
    filter_type: Literal[KnownFilterType.TEXT_ARRAY] = Field(
        default=KnownFilterType.TEXT_ARRAY,
        repr=False,
    )

    # Any ``ComparisonOperator`` member is accepted by the model itself.
    operator: ComparisonOperator

    value: list[str]
[docs]
class BooleanPropertyFilter(KnownPropertyFilterBase):
    """Boolean property filter: an operator plus a single ``bool``."""

    # Fixed tag for this filter kind; excluded from ``repr`` output.
    filter_type: Literal[KnownFilterType.BOOLEAN] = Field(
        default=KnownFilterType.BOOLEAN,
        repr=False,
    )

    # Any ``ComparisonOperator`` member is accepted by the model itself.
    operator: ComparisonOperator

    value: bool
[docs]
class BooleanArrayPropertyFilter(KnownPropertyFilterBase):
    """Boolean-array property filter: an operator plus a list of ``bool`` values."""

    # Fixed tag for this filter kind; excluded from ``repr`` output.
    filter_type: Literal[KnownFilterType.BOOLEAN_ARRAY] = Field(
        default=KnownFilterType.BOOLEAN_ARRAY,
        repr=False,
    )

    # Any ``ComparisonOperator`` member is accepted by the model itself.
    operator: ComparisonOperator

    value: list[bool]
class DateExact(BaseModel):
    """Date filter value that compares against a single timestamp."""

    # Timestamp carried as a string; its format is not constrained here.
    exact_timestamp: str

    operator: ComparisonOperator
class DateRangeFrom(BaseModel):
    """Date filter value that specifies only a ``date_from`` bound."""

    # Bound carried as a string; its format is not constrained here.
    date_from: str

    # Flag stating whether ``date_from`` itself is included.
    inclusive_from: bool
class DateRangeTo(BaseModel):
    """Date filter value that specifies only a ``date_to`` bound."""

    # Bound carried as a string; its format is not constrained here.
    date_to: str

    # Flag stating whether ``date_to`` itself is included.
    inclusive_to: bool
class DateRangeBetween(BaseModel):
    """Date filter value that specifies both a ``date_from`` and a ``date_to`` bound."""

    # Bounds carried as strings; their format is not constrained here.
    date_from: str
    date_to: str

    # Per-bound flags stating whether each bound itself is included.
    inclusive_from: bool
    inclusive_to: bool
[docs]
class DatePropertyFilter(KnownPropertyFilterBase):
    """Datetime property filter.

    Unlike the other scalar filters, this one has no ``operator`` field of its
    own: ``value`` is either an exact-timestamp comparison or one of the range
    shapes.
    """

    # Fixed tag for this filter kind; excluded from ``repr`` output.
    filter_type: Literal[KnownFilterType.DATE] = Field(
        default=KnownFilterType.DATE,
        repr=False,
    )

    value: Union[DateExact, DateRangeFrom, DateRangeTo, DateRangeBetween]
[docs]
class DateArrayPropertyFilter(KnownPropertyFilterBase):
    """Datetime-array property filter: an operator plus a list of date strings."""

    # Fixed tag for this filter kind; excluded from ``repr`` output.
    filter_type: Literal[KnownFilterType.DATE_ARRAY] = Field(
        default=KnownFilterType.DATE_ARRAY,
        repr=False,
    )

    operator: ComparisonOperator

    # Dates carried as strings; their format is not constrained here.
    value: list[str]
[docs]
class GeoPropertyFilter(KnownPropertyFilterBase):
    """Geo-coordinates property filter, described by a point and a maximum distance."""

    # Fixed tag for this filter kind; excluded from ``repr`` output.
    filter_type: Literal[KnownFilterType.GEO] = Field(
        default=KnownFilterType.GEO,
        repr=False,
    )

    # Reference point.
    latitude: float
    longitude: float

    # Distance limit, in metres as the field name states.
    max_distance_meters: float
[docs]
class UUIDPropertyFilter(KnownPropertyFilterBase):
    """UUID property filter: an operator plus a single UUID."""

    # Fixed tag for this filter kind; excluded from ``repr`` output.
    filter_type: Literal[KnownFilterType.UUID] = Field(
        default=KnownFilterType.UUID,
        repr=False,
    )

    # Declared again here with the same type as on the base class.
    property_name: str

    operator: ComparisonOperator
    value: UUID
[docs]
class UUIDArrayPropertyFilter(KnownPropertyFilterBase):
    """UUID-array property filter: an operator plus a list of UUIDs."""

    # Fixed tag for this filter kind; excluded from ``repr`` output.
    filter_type: Literal[KnownFilterType.UUID_ARRAY] = Field(
        default=KnownFilterType.UUID_ARRAY,
        repr=False,
    )

    # Declared again here with the same type as on the base class.
    property_name: str

    operator: ComparisonOperator
    value: list[UUID]
[docs]
class IsNullPropertyFilter(KnownPropertyFilterBase):
    """Filter on whether a property is null."""

    # Fixed tag for this filter kind; excluded from ``repr`` output.
    filter_type: Literal[KnownFilterType.IS_NULL] = Field(
        default=KnownFilterType.IS_NULL,
        repr=False,
    )

    # The null state being filtered on.
    is_null: bool
[docs]
class UnknownPropertyFilter(BaseModel):
    """Catch-all filter for unknown filter types, to preserve future back-compatibility.

    Keys other than ``filter_type`` are kept on the instance (``extra="allow"``),
    and a warning is emitted every time an instance is initialised.
    """

    model_config = ConfigDict(extra="allow")

    # Always ``None`` once validated: the validator below discards the raw tag.
    filter_type: None

    @field_validator("filter_type", mode="before")
    @classmethod
    def ensure_filter_type_unknown(cls, value: Any) -> None:
        """Reject known filter tags and normalise every other tag to ``None``.

        Args:
            value: Raw ``filter_type`` from the input, before any coercion.

        Returns:
            Always ``None``.

        Raises:
            ValueError: If ``value`` is one of the ``KnownFilterType`` values,
                i.e. the tag is known but this catch-all is being validated
                anyway.
        """
        try:
            is_known = value in set(KnownFilterType)
        except TypeError:
            # Unhashable input (e.g. a list or dict) cannot be a known tag.
            # Treat it as unknown instead of letting the TypeError escape:
            # pydantic only converts ValueError/AssertionError raised by a
            # validator into a validation error.
            is_known = False

        if is_known:
            raise ValueError(
                f"{value} is an known filter type, but validation failed, "
                "so the response was not as expected. "
                "Try upgrading the weaviate-agents package to a new version."
            )
        return None

    def model_post_init(self, context: Any) -> None:
        """Warn that the payload used a filter type this version does not recognise."""
        # NOTE(review): ``self.filter_type`` is already ``None`` at this point
        # (the validator above always returns ``None``), so the message cannot
        # show the tag that was actually received.
        warnings.warn(
            f"The filter_type {self.filter_type} wasn't recognised. "
            "Try upgrading the weaviate-agents package to a new version."
        )
# Every filter shape a response may contain. Member order is significant to
# the declaration and is kept exactly as is; the catch-all comes last.
PropertyFilter = Union[
    # Numeric
    IntegerPropertyFilter,
    IntegerArrayPropertyFilter,
    # Text
    TextPropertyFilter,
    TextArrayPropertyFilter,
    # Boolean
    BooleanPropertyFilter,
    BooleanArrayPropertyFilter,
    # Date
    DatePropertyFilter,
    DateArrayPropertyFilter,
    # Geo / null-state
    GeoPropertyFilter,
    IsNullPropertyFilter,
    # UUID
    UUIDPropertyFilter,
    UUIDArrayPropertyFilter,
    # Fallback for tags without a dedicated model
    UnknownPropertyFilter,
]
[docs]
class QueryResult(BaseModel):
    """Queries issued for a search, together with their property filters."""

    # Individual entries may be ``None``.
    queries: list[Optional[str]]

    # A list of filter groups; defaults to no groups.
    # NOTE(review): presumably one group per entry in ``queries`` — confirm.
    filters: list[list[PropertyFilter]] = []

    filter_operators: Literal["AND", "OR"]
[docs]
class NumericMetrics(str, Enum):
    """Metrics selectable for a numeric property aggregation.

    The ``MAX`` and ``MIN`` members carry the long-form values "MAXIMUM" and
    "MINIMUM".
    """

    COUNT = "COUNT"

    MAX = "MAXIMUM"
    MEAN = "MEAN"
    MEDIAN = "MEDIAN"
    MIN = "MINIMUM"
    MODE = "MODE"
    SUM = "SUM"

    TYPE = "TYPE"
[docs]
class TextMetrics(str, Enum):
    """Metrics selectable for a text property aggregation."""

    COUNT = "COUNT"
    TYPE = "TYPE"

    TOP_OCCURRENCES = "TOP_OCCURRENCES"
[docs]
class BooleanMetrics(str, Enum):
    """Metrics selectable for a boolean property aggregation."""

    COUNT = "COUNT"
    TYPE = "TYPE"

    # Totals per truth value.
    TOTAL_TRUE = "TOTAL_TRUE"
    TOTAL_FALSE = "TOTAL_FALSE"

    # Percentages per truth value.
    PERCENTAGE_TRUE = "PERCENTAGE_TRUE"
    PERCENTAGE_FALSE = "PERCENTAGE_FALSE"
[docs]
class DateMetrics(str, Enum):
    """Metrics selectable for a datetime property aggregation.

    The ``MAX`` and ``MIN`` members carry the long-form values "MAXIMUM" and
    "MINIMUM".
    """

    COUNT = "COUNT"

    MAX = "MAXIMUM"
    MEDIAN = "MEDIAN"
    MIN = "MINIMUM"
    MODE = "MODE"
class KnownAggregationType(str, Enum):
    """Tags for the aggregation kinds this module has a dedicated model for.

    Members subclass ``str``, so each one compares equal to its raw value.
    """

    INTEGER = "integer"
    TEXT = "text"
    BOOLEAN = "boolean"
    DATE = "date"
class KnownPropertyAggregationBase(BaseModel):
    """Fields common to every aggregation whose ``aggregation_type`` is a known tag."""

    # Tag naming the aggregation kind; the subclasses in this module pin it to
    # a single literal member.
    aggregation_type: KnownAggregationType

    # Name of the property being aggregated.
    property_name: str
[docs]
class IntegerPropertyAggregation(KnownPropertyAggregationBase):
    """Numeric property aggregation, selecting one statistical metric."""

    # Fixed tag for this aggregation kind; excluded from ``repr`` output.
    aggregation_type: Literal[KnownAggregationType.INTEGER] = Field(
        default=KnownAggregationType.INTEGER,
        repr=False,
    )

    # A single metric, despite the plural field name.
    metrics: NumericMetrics
[docs]
class TextPropertyAggregation(KnownPropertyAggregationBase):
    """Text property aggregation, selecting one frequency-analysis metric."""

    # Fixed tag for this aggregation kind; excluded from ``repr`` output.
    aggregation_type: Literal[KnownAggregationType.TEXT] = Field(
        default=KnownAggregationType.TEXT,
        repr=False,
    )

    # A single metric, despite the plural field name.
    metrics: TextMetrics

    # Optional limit; unset (``None``) by default.
    # NOTE(review): presumably only meaningful with TOP_OCCURRENCES — confirm.
    top_occurrences_limit: Union[int, None] = None
[docs]
class BooleanPropertyAggregation(KnownPropertyAggregationBase):
    """Boolean property aggregation, selecting one statistical metric."""

    # Fixed tag for this aggregation kind; excluded from ``repr`` output.
    aggregation_type: Literal[KnownAggregationType.BOOLEAN] = Field(
        default=KnownAggregationType.BOOLEAN,
        repr=False,
    )

    # A single metric, despite the plural field name.
    metrics: BooleanMetrics
[docs]
class DatePropertyAggregation(KnownPropertyAggregationBase):
    """Datetime property aggregation, selecting one statistical metric."""

    # Fixed tag for this aggregation kind; excluded from ``repr`` output.
    aggregation_type: Literal[KnownAggregationType.DATE] = Field(
        default=KnownAggregationType.DATE,
        repr=False,
    )

    # A single metric, despite the plural field name.
    metrics: DateMetrics
[docs]
class UnknownPropertyAggregation(BaseModel):
    """Catch-all aggregation for unknown aggregation types, to preserve future back-compatibility.

    Keys other than ``aggregation_type`` are kept on the instance
    (``extra="allow"``), and a warning is emitted every time an instance is
    initialised.
    """

    model_config = ConfigDict(extra="allow")

    # Always ``None`` once validated: the validator below discards the raw tag.
    aggregation_type: None

    @field_validator("aggregation_type", mode="before")
    @classmethod
    def ensure_filter_type_unknown(cls, value: Any) -> None:
        """Reject known aggregation tags and normalise every other tag to ``None``.

        Despite its name, this validator checks ``aggregation_type``.

        Args:
            value: Raw ``aggregation_type`` from the input, before any coercion.

        Returns:
            Always ``None``.

        Raises:
            ValueError: If ``value`` is one of the ``KnownAggregationType``
                values, i.e. the tag is known but this catch-all is being
                validated anyway.
        """
        try:
            is_known = value in set(KnownAggregationType)
        except TypeError:
            # Unhashable input (e.g. a list or dict) cannot be a known tag.
            # Treat it as unknown instead of letting the TypeError escape:
            # pydantic only converts ValueError/AssertionError raised by a
            # validator into a validation error.
            is_known = False

        if is_known:
            raise ValueError(
                f"{value} is an known aggregation type, but validation failed, "
                "so the response was not as expected. "
                "Try upgrading the weaviate-agents package to a new version."
            )
        return None

    def model_post_init(self, context: Any) -> None:
        """Warn that the payload used an aggregation type this version does not recognise."""
        # NOTE(review): ``self.aggregation_type`` is already ``None`` at this
        # point (the validator above always returns ``None``), so the message
        # cannot show the tag that was actually received.
        warnings.warn(
            f"The aggregation_type {self.aggregation_type} wasn't recognised. "
            "Try upgrading the weaviate-agents package to a new version."
        )
# Every aggregation shape a response may contain. Member order is kept exactly
# as is; the catch-all comes last.
PropertyAggregation = Union[
    # Models with a known ``aggregation_type`` tag
    IntegerPropertyAggregation,
    TextPropertyAggregation,
    BooleanPropertyAggregation,
    DatePropertyAggregation,
    # Fallback for tags without a dedicated model
    UnknownPropertyAggregation,
]
[docs]
class AggregationResult(BaseModel):
    """Aggregations to perform on a collection in a vector database.

    The aggregations should be derived from the original user query, and may
    span several properties and metrics.
    """

    # Optional; unset (``None``) by default.
    search_query: Union[str, None] = None
    groupby_property: Union[str, None] = None

    aggregations: list[PropertyAggregation]

    # Defaults to no filters.
    filters: list[PropertyFilter] = []
[docs]
class AggregationResultWithCollection(AggregationResult):
    """An ``AggregationResult`` that also carries a collection name."""

    # Collection name, as a plain string.
    collection: str
[docs]
class QueryResultWithCollection(QueryResult):
    """A ``QueryResult`` that also carries a collection name."""

    # Collection name, as a plain string.
    collection: str
[docs]
class Source(BaseModel):
    """Reference to a source object: an object id plus a collection name."""

    # Typed as a plain string rather than ``UUID``.
    object_id: str

    collection: str
[docs]
class QueryAgentResponse(BaseModel):
    """Final-state response of a query agent run."""

    # Discriminating tag; always "final_state" for this model.
    output_type: Literal["final_state"] = "final_state"

    original_query: str
    collection_names: list[str]

    # Nested lists of the searches and aggregations in the response.
    searches: list[list[QueryResultWithCollection]]
    aggregations: list[list[AggregationResultWithCollection]]

    usage: Usage
    total_time: float

    is_partial_answer: bool
    missing_information: list[str]
    final_answer: str
    sources: list[Source]

    def display(self) -> None:
        """Pretty-print a summary of this response.

        Delegates to ``print_query_agent_response``.
        """
        print_query_agent_response(self)
[docs]
class ModelUnitUsage(BaseModel):
    """Usage information expressed in model units."""

    model_units: int

    # Plan-related accounting fields.
    usage_in_plan: bool
    remaining_plan_requests: int
[docs]
class FilterAndOr(BaseModel):
    """Group of filters combined with AND or OR.

    The model is recursive: each element of ``filters`` is either a single
    property filter or another ``FilterAndOr`` group.
    """

    combine: Literal["AND", "OR"]

    filters: list[Union[PropertyFilter, FilterAndOr]]
[docs]
class QuerySort(BaseModel):
    """Sort specification: a property, a direction, and an optional tie-break sort.

    The model is recursive through ``tie_break``.
    """

    property_name: str
    order: Literal["ascending", "descending"]

    # May be ``None``, but has no default, so the key must still be supplied.
    tie_break: Optional[QuerySort]
[docs]
class QueryResultWithCollectionNormalized(BaseModel):
    """A single search: one query, one filter tree, one collection."""

    # May be ``None``, but has no default, so the key must still be supplied.
    query: Optional[str]

    # A single filter, an AND/OR group, or ``None``; no default.
    filters: Optional[Union[PropertyFilter, FilterAndOr]]

    collection: str

    # Optional extras; both unset (``None``) by default.
    sort_property: Optional[QuerySort] = None
    uuid_value: Optional[UUID] = None
[docs]
class AggregationResultWithCollectionNormalized(BaseModel):
    """A single aggregation: one aggregation, one filter tree, one collection."""

    # May be ``None``, but has no default, so the key must still be supplied.
    groupby_property: Optional[str]

    aggregation: PropertyAggregation

    # A single filter, an AND/OR group, or ``None``; no default.
    filters: Optional[Union[PropertyFilter, FilterAndOr]]

    collection: str
[docs]
class SuggestedQuery(BaseModel):
    """One suggested query."""

    # The query text.
    query: str
[docs]
class SuggestQueryResponse(BaseModel):
    """Response carrying a list of suggested queries plus usage and timing."""

    queries: list[SuggestedQuery]
    collection_count: int

    usage: ModelUnitUsage
    total_time: float
[docs]
class AskModeResponse(BaseModel):
    """Final-state response for ask mode."""

    # Discriminating tag; always "final_state" for this model.
    output_type: Literal["final_state"] = "final_state"

    searches: list[QueryResultWithCollectionNormalized]
    aggregations: list[AggregationResultWithCollectionNormalized]

    usage: ModelUnitUsage
    total_time: float

    # These three may be ``None``, but have no default, so the keys must still
    # be supplied.
    is_partial_answer: Optional[bool]
    missing_information: Optional[list[str]]
    final_answer: str
    sources: Optional[list[Source]]

    def display(self) -> None:
        """Pretty-print a summary of this response.

        Delegates to ``print_ask_mode_response``.
        """
        print_ask_mode_response(self)
[docs]
class ResearchModeResponse(BaseModel):
    """Final-state response for research mode."""

    # Discriminating tag; always "final_state" for this model.
    output_type: Literal["final_state"] = "final_state"

    final_answer: str
    usage: ModelUnitUsage

    # Each entry is a complete ``AskModeResponse``.
    queries: list[AskModeResponse]

    total_time: float
[docs]
class QueryWithCollection(TypedDict):
    """Dict pairing a query string with a collection name; both keys are required."""

    query: str
    collection: str
[docs]
class ProgressDetails(TypedDict, total=False):
    """Optional detail payload of a progress message.

    Declared with ``total=False``, so every key may be omitted.
    """

    queries: list[QueryWithCollection]
[docs]
class ProgressMessage(BaseModel):
    """Progress update, identified by ``output_type == "progress_message"``."""

    # Discriminating tag; always "progress_message" for this model.
    output_type: Literal["progress_message"] = "progress_message"

    stage: str
    message: str

    # Extra detail; defaults to an empty dict.
    details: ProgressDetails = {}
[docs]
class StreamedTokens(BaseModel):
    """Streamed token chunk, identified by ``output_type == "streamed_tokens"``."""

    # Discriminating tag; always "streamed_tokens" for this model.
    output_type: Literal["streamed_tokens"] = "streamed_tokens"

    # The text carried by this chunk.
    delta: str
[docs]
class StreamedThoughts(BaseModel):
    """Streamed thoughts chunk, identified by ``output_type == "streamed_thoughts"``."""

    # Discriminating tag; always "streamed_thoughts" for this model.
    output_type: Literal["streamed_thoughts"] = "streamed_thoughts"

    # The text carried by this chunk.
    delta: str
# Workaround for not being able to use ``Self`` to annotate the abstract
# ``next()`` method, because Python versions below 3.11 are supported.
# A TypeVar bound to the base class is used for ``self`` and the return type
# instead, so that the types on subclasses are represented properly.
# Suggested in https://peps.python.org/pep-0673/
SearchModeResponseT = TypeVar("SearchModeResponseT", bound="SearchModeResponseBase")
# Type of the searcher object held by a search-mode response; unbounded.
SearcherT = TypeVar("SearcherT")
[docs]
class SearchModeResponseBase(BaseModel, ABC, Generic[SearcherT]):
    """Abstract base for search-only responses, generic over the searcher type.

    Concrete subclasses implement ``next`` to fetch another page.
    """

    # Optional; unset (``None``) by default.
    searches: Union[list[QueryResultWithCollectionNormalized], None] = None

    usage: ModelUnitUsage
    total_time: float
    search_results: QueryReturn

    # Searcher that ``next`` delegates to in the concrete subclasses.
    _searcher: SearcherT

    @abstractmethod
    def next(
        self: SearchModeResponseT,
        limit: int = 20,
        offset: int = 0,
    ) -> Union[SearchModeResponseT, Coroutine[Any, Any, SearchModeResponseT]]:
        """Return another page of results, or a coroutine that resolves to one.

        Args:
            limit: Maximum number of results to return; defaults to 20.
            offset: Offset to start from; defaults to 0.
        """
class _SyncSearcher(Protocol):
    """Structural type for an object with a blocking ``run(limit, offset)`` method."""

    def run(
        self,
        limit: int = 20,
        offset: int = 0,
    ) -> SearchModeResponse:
        ...
class _AsyncSearcher(Protocol):
    """Structural type for an object with an awaitable ``run(limit, offset)`` method."""

    async def run(
        self,
        limit: int = 20,
        offset: int = 0,
    ) -> AsyncSearchModeResponse:
        ...
[docs]
class SearchModeResponse(SearchModeResponseBase[_SyncSearcher]):
    """Response for the Query Agent search-only mode.

    Holds the search results, the usage, and the underlying searches that were
    performed. To page through the results set, call ``next`` on this response
    with different ``limit`` / ``offset`` values. Each call performs the same
    underlying searches, so the results set stays consistent across pages.
    """

    def next(self, limit: int = 20, offset: int = 0) -> SearchModeResponse:
        """Paginate the search-only results with the given ``limit`` and ``offset`` values.

        Args:
            limit: The maximum number of results to return. Defaults to 20.
            offset: The offset to start from. By default, retrieval begins
                from the first object in the results set.

        Returns:
            The next ``SearchModeResponse`` page.
        """
        searcher = self._searcher
        return searcher.run(limit=limit, offset=offset)
[docs]
class AsyncSearchModeResponse(SearchModeResponseBase[_AsyncSearcher]):
    """Response for the Query Agent search-only mode (async).

    Holds the search results, the usage, and the underlying searches that were
    performed. To page through the results set, await ``next`` on this response
    with different ``limit`` / ``offset`` values. Each call performs the same
    underlying searches, so the results set stays consistent across pages.
    """

    async def next(self, limit: int = 20, offset: int = 0) -> AsyncSearchModeResponse:
        """Paginate the search-only results with the given ``limit`` and ``offset`` values.

        Args:
            limit: The maximum number of results to return. Defaults to 20.
            offset: The offset to start from. By default, retrieval begins
                from the first object in the results set.

        Returns:
            The next ``AsyncSearchModeResponse`` page.
        """
        searcher = self._searcher
        next_page = await searcher.run(limit=limit, offset=offset)
        return next_page