# Source code for weaviate_agents.query.classes.response

from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
    Any,
    Coroutine,
    Generic,
    Literal,
    Optional,
    Protocol,
    TypeVar,
    Union,
)
from uuid import UUID

from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import TypedDict
from weaviate.outputs.query import QueryReturn

from weaviate_agents.classes.core import Usage
from weaviate_agents.utils import print_ask_mode_response, print_query_agent_response


class ComparisonOperator(str, Enum):
    """Operators a property filter can use to compare a property with a value."""

    EQUALS = "="
    LESS_THAN = "<"
    GREATER_THAN = ">"
    LESS_EQUAL = "<="
    GREATER_EQUAL = ">="
    NOT_EQUALS = "!="
    LIKE = "LIKE"
    CONTAINS_ANY = "contains_any"
    CONTAINS_ALL = "contains_all"
class KnownFilterType(str, Enum):
    """The ``filter_type`` values this client version can parse."""

    INTEGER = "integer"
    INTEGER_ARRAY = "integer_array"
    TEXT = "text"
    TEXT_ARRAY = "text_array"
    BOOLEAN = "boolean"
    BOOLEAN_ARRAY = "boolean_array"
    # The value for DATE is "date_range" (not "date").
    DATE = "date_range"
    DATE_ARRAY = "date_array"
    GEO = "geo"
    IS_NULL = "is_null"
    UUID = "uuid"
    UUID_ARRAY = "uuid_array"


class KnownPropertyFilterBase(BaseModel):
    """Fields shared by every recognised property filter."""

    filter_type: KnownFilterType
    property_name: str
class IntegerPropertyFilter(KnownPropertyFilterBase):
    """Filter numeric properties using comparison operators."""

    # Discriminator; hidden from repr because it is implied by the class.
    filter_type: Literal[KnownFilterType.INTEGER] = Field(
        default=KnownFilterType.INTEGER, repr=False
    )
    operator: ComparisonOperator
    # Declared as float, so both int and float payload values validate.
    value: float
class IntegerArrayPropertyFilter(KnownPropertyFilterBase):
    """Filter numeric-array properties using comparison operators."""

    # Discriminator; hidden from repr because it is implied by the class.
    filter_type: Literal[KnownFilterType.INTEGER_ARRAY] = Field(
        default=KnownFilterType.INTEGER_ARRAY, repr=False
    )
    operator: ComparisonOperator
    # Elements are floats, so both int and float payload values validate.
    value: list[float]
class TextPropertyFilter(KnownPropertyFilterBase):
    """Filter text properties using equality or LIKE operators."""

    # Discriminator; hidden from repr because it is implied by the class.
    filter_type: Literal[KnownFilterType.TEXT] = Field(
        default=KnownFilterType.TEXT, repr=False
    )
    operator: ComparisonOperator
    value: str
class TextArrayPropertyFilter(KnownPropertyFilterBase):
    """Filter text-array properties using equality or LIKE operators."""

    # Discriminator; hidden from repr because it is implied by the class.
    filter_type: Literal[KnownFilterType.TEXT_ARRAY] = Field(
        default=KnownFilterType.TEXT_ARRAY, repr=False
    )
    operator: ComparisonOperator
    value: list[str]
class BooleanPropertyFilter(KnownPropertyFilterBase):
    """Filter boolean properties using equality operators."""

    # Discriminator; hidden from repr because it is implied by the class.
    filter_type: Literal[KnownFilterType.BOOLEAN] = Field(
        default=KnownFilterType.BOOLEAN, repr=False
    )
    operator: ComparisonOperator
    value: bool
class BooleanArrayPropertyFilter(KnownPropertyFilterBase):
    """Filter boolean-array properties using equality operators."""

    # Discriminator; hidden from repr because it is implied by the class.
    filter_type: Literal[KnownFilterType.BOOLEAN_ARRAY] = Field(
        default=KnownFilterType.BOOLEAN_ARRAY, repr=False
    )
    operator: ComparisonOperator
    value: list[bool]
class DateExact(BaseModel):
    """Compare a date property against a single timestamp with an operator."""

    exact_timestamp: str
    operator: ComparisonOperator


class DateRangeFrom(BaseModel):
    """Date range with only a start bound."""

    date_from: str
    inclusive_from: bool


class DateRangeTo(BaseModel):
    """Date range with only an end bound."""

    date_to: str
    inclusive_to: bool


class DateRangeBetween(BaseModel):
    """Date range with both a start and an end bound."""

    date_from: str
    date_to: str
    inclusive_from: bool
    inclusive_to: bool
class DatePropertyFilter(KnownPropertyFilterBase):
    """Filter datetime properties by an exact timestamp or a date range.

    Unlike most filters there is no top-level ``operator``: an exact match
    carries its own operator inside :class:`DateExact`, while the range
    variants describe their bounds and inclusivity instead.
    """

    # Discriminator; hidden from repr because it is implied by the class.
    filter_type: Literal[KnownFilterType.DATE] = Field(
        default=KnownFilterType.DATE, repr=False
    )
    # Member order is left unchanged; pydantic union validation can depend on it.
    value: Union[DateExact, DateRangeFrom, DateRangeTo, DateRangeBetween]
class DateArrayPropertyFilter(KnownPropertyFilterBase):
    """Filter datetime-array properties with a comparison operator."""

    # Discriminator; hidden from repr because it is implied by the class.
    filter_type: Literal[KnownFilterType.DATE_ARRAY] = Field(
        default=KnownFilterType.DATE_ARRAY, repr=False
    )
    operator: ComparisonOperator
    # Dates are carried as strings.
    value: list[str]
class GeoPropertyFilter(KnownPropertyFilterBase):
    """Filter geo-coordinates properties."""

    # Discriminator; hidden from repr because it is implied by the class.
    filter_type: Literal[KnownFilterType.GEO] = Field(
        default=KnownFilterType.GEO, repr=False
    )
    latitude: float
    longitude: float
    max_distance_meters: float
class UUIDPropertyFilter(KnownPropertyFilterBase):
    """Filter UUID properties."""

    # Discriminator; hidden from repr because it is implied by the class.
    # ``property_name`` is inherited from KnownPropertyFilterBase.
    filter_type: Literal[KnownFilterType.UUID] = Field(
        default=KnownFilterType.UUID, repr=False
    )
    operator: ComparisonOperator
    value: UUID
class UUIDArrayPropertyFilter(KnownPropertyFilterBase):
    """Filter UUID array properties."""

    # Discriminator; hidden from repr because it is implied by the class.
    # ``property_name`` is inherited from KnownPropertyFilterBase.
    filter_type: Literal[KnownFilterType.UUID_ARRAY] = Field(
        default=KnownFilterType.UUID_ARRAY, repr=False
    )
    operator: ComparisonOperator
    value: list[UUID]
class IsNullPropertyFilter(KnownPropertyFilterBase):
    """Filter by whether a property is null."""

    # Discriminator; hidden from repr because it is implied by the class.
    filter_type: Literal[KnownFilterType.IS_NULL] = Field(
        default=KnownFilterType.IS_NULL, repr=False
    )
    is_null: bool
class UnknownPropertyFilter(BaseModel):
    """Catch-all filter for unknown filter types, to preserve future back-compatibility.

    Extra payload keys are kept on the instance (``extra="allow"``), while
    ``filter_type`` itself is always normalised to ``None``.
    """

    model_config = ConfigDict(extra="allow")

    filter_type: None

    @field_validator("filter_type", mode="before")
    @classmethod
    def ensure_filter_type_unknown(cls, value: Any) -> None:
        """Reject known filter types and warn about unrecognised ones.

        Args:
            value: The raw ``filter_type`` taken from the payload.

        Returns:
            ``None``; the raw value is not stored on the model.

        Raises:
            ValueError: If ``value`` is a known filter type. This class only
                sees a known type when the matching known filter failed to
                validate, i.e. the payload did not have the expected shape.
        """
        known_values = {member.value for member in KnownFilterType}
        # Only strings can name a known type. The isinstance guard also keeps
        # unhashable payload values (e.g. a list) from raising TypeError, which
        # pydantic would not convert into a ValidationError.
        if isinstance(value, str) and value in known_values:
            raise ValueError(
                f"{value} is a known filter type, but validation failed, "
                "so the response was not as expected. "
                "Try upgrading the weaviate-agents package to a new version."
            )
        # Warn here, while the original value is still available. After this
        # validator returns, ``filter_type`` is None, so a warning emitted
        # later (e.g. from model_post_init) could only ever report "None".
        warnings.warn(
            f"The filter_type {value} wasn't recognised. "
            "Try upgrading the weaviate-agents package to a new version."
        )
        return None
# Every filter shape a response may contain. The member order is kept as-is
# because pydantic's union validation can depend on it; UnknownPropertyFilter
# comes last as the catch-all for filter types this client does not know.
PropertyFilter = Union[
    IntegerPropertyFilter,
    IntegerArrayPropertyFilter,
    TextPropertyFilter,
    TextArrayPropertyFilter,
    BooleanPropertyFilter,
    BooleanArrayPropertyFilter,
    DatePropertyFilter,
    DateArrayPropertyFilter,
    GeoPropertyFilter,
    IsNullPropertyFilter,
    UUIDPropertyFilter,
    UUIDArrayPropertyFilter,
    UnknownPropertyFilter,
]
class QueryResult(BaseModel):
    """Search queries together with the filters applied to them."""

    # Entries may be None.
    queries: list[Union[str, None]]
    # Nested filter lists (presumably one inner list per query; unconfirmed).
    # Pydantic copies mutable defaults per instance, so ``[]`` is not shared.
    filters: list[list[PropertyFilter]] = []
    # How the filters are combined.
    filter_operators: Literal["AND", "OR"]
class NumericMetrics(str, Enum):
    """Metrics that can be computed when aggregating a numeric property."""

    COUNT = "COUNT"
    # MAX/MIN use the full words as their values.
    MAX = "MAXIMUM"
    MEAN = "MEAN"
    MEDIAN = "MEDIAN"
    MIN = "MINIMUM"
    MODE = "MODE"
    SUM = "SUM"
    TYPE = "TYPE"
class TextMetrics(str, Enum):
    """Metrics that can be computed when aggregating a text property."""

    COUNT = "COUNT"
    TYPE = "TYPE"
    TOP_OCCURRENCES = "TOP_OCCURRENCES"
class BooleanMetrics(str, Enum):
    """Metrics that can be computed when aggregating a boolean property."""

    COUNT = "COUNT"
    TYPE = "TYPE"
    TOTAL_TRUE = "TOTAL_TRUE"
    TOTAL_FALSE = "TOTAL_FALSE"
    PERCENTAGE_TRUE = "PERCENTAGE_TRUE"
    PERCENTAGE_FALSE = "PERCENTAGE_FALSE"
class DateMetrics(str, Enum):
    """Metrics that can be computed when aggregating a datetime property."""

    COUNT = "COUNT"
    # MAX/MIN use the full words as their values.
    MAX = "MAXIMUM"
    MEDIAN = "MEDIAN"
    MIN = "MINIMUM"
    MODE = "MODE"
class KnownAggregationType(str, Enum):
    """The ``aggregation_type`` values this client version can parse."""

    INTEGER = "integer"
    TEXT = "text"
    BOOLEAN = "boolean"
    DATE = "date"


class KnownPropertyAggregationBase(BaseModel):
    """Fields shared by every recognised property aggregation."""

    aggregation_type: KnownAggregationType
    property_name: str
class IntegerPropertyAggregation(KnownPropertyAggregationBase):
    """Aggregate numeric properties using statistical functions."""

    # Discriminator; hidden from repr because it is implied by the class.
    aggregation_type: Literal[KnownAggregationType.INTEGER] = Field(
        default=KnownAggregationType.INTEGER, repr=False
    )
    metrics: NumericMetrics
class TextPropertyAggregation(KnownPropertyAggregationBase):
    """Aggregate text properties using frequency analysis."""

    # Discriminator; hidden from repr because it is implied by the class.
    aggregation_type: Literal[KnownAggregationType.TEXT] = Field(
        default=KnownAggregationType.TEXT, repr=False
    )
    metrics: TextMetrics
    # Optional; presumably only relevant with TextMetrics.TOP_OCCURRENCES.
    top_occurrences_limit: Optional[int] = None
class BooleanPropertyAggregation(KnownPropertyAggregationBase):
    """Aggregate boolean properties using statistical functions."""

    # Discriminator; hidden from repr because it is implied by the class.
    aggregation_type: Literal[KnownAggregationType.BOOLEAN] = Field(
        default=KnownAggregationType.BOOLEAN, repr=False
    )
    metrics: BooleanMetrics
class DatePropertyAggregation(KnownPropertyAggregationBase):
    """Aggregate datetime properties using statistical functions."""

    # Discriminator; hidden from repr because it is implied by the class.
    aggregation_type: Literal[KnownAggregationType.DATE] = Field(
        default=KnownAggregationType.DATE, repr=False
    )
    metrics: DateMetrics
class UnknownPropertyAggregation(BaseModel):
    """Catch-all aggregation for unknown aggregation types, to preserve future back-compatibility.

    Extra payload keys are kept on the instance (``extra="allow"``), while
    ``aggregation_type`` itself is always normalised to ``None``.
    """

    model_config = ConfigDict(extra="allow")

    aggregation_type: None

    # The method name mirrors UnknownPropertyFilter's validator and is kept
    # unchanged for compatibility, even though it validates aggregation_type.
    @field_validator("aggregation_type", mode="before")
    @classmethod
    def ensure_filter_type_unknown(cls, value: Any) -> None:
        """Reject known aggregation types and warn about unrecognised ones.

        Args:
            value: The raw ``aggregation_type`` taken from the payload.

        Returns:
            ``None``; the raw value is not stored on the model.

        Raises:
            ValueError: If ``value`` is a known aggregation type. This class
                only sees a known type when the matching known aggregation
                failed to validate, i.e. the payload did not have the expected
                shape.
        """
        known_values = {member.value for member in KnownAggregationType}
        # Only strings can name a known type. The isinstance guard also keeps
        # unhashable payload values (e.g. a list) from raising TypeError, which
        # pydantic would not convert into a ValidationError.
        if isinstance(value, str) and value in known_values:
            raise ValueError(
                f"{value} is a known aggregation type, but validation failed, "
                "so the response was not as expected. "
                "Try upgrading the weaviate-agents package to a new version."
            )
        # Warn here, while the original value is still available. After this
        # validator returns, ``aggregation_type`` is None, so a warning emitted
        # later (e.g. from model_post_init) could only ever report "None".
        warnings.warn(
            f"The aggregation_type {value} wasn't recognised. "
            "Try upgrading the weaviate-agents package to a new version."
        )
        return None
# Every aggregation shape a response may contain. The member order is kept
# as-is because pydantic's union validation can depend on it;
# UnknownPropertyAggregation comes last as the catch-all for unknown types.
PropertyAggregation = Union[
    IntegerPropertyAggregation,
    TextPropertyAggregation,
    BooleanPropertyAggregation,
    DatePropertyAggregation,
    UnknownPropertyAggregation,
]
class AggregationResult(BaseModel):
    """The aggregations to be performed on a collection in a vector database.

    They should be based on the original user query and can include multiple
    aggregations across different properties and metrics.
    """

    search_query: Optional[str] = None
    groupby_property: Optional[str] = None
    aggregations: list[PropertyAggregation]
    # Pydantic copies mutable defaults per instance, so ``[]`` is not shared.
    filters: list[PropertyFilter] = []
class AggregationResultWithCollection(AggregationResult):
    """An :class:`AggregationResult` labelled with the collection it targets."""

    collection: str
class QueryResultWithCollection(QueryResult):
    """A :class:`QueryResult` labelled with the collection it targets."""

    collection: str
class Source(BaseModel):
    """Reference to an object, identified by its id and collection."""

    object_id: str
    collection: str
class QueryAgentResponse(BaseModel):
    """Final-state response of a Query Agent run."""

    output_type: Literal["final_state"] = "final_state"
    original_query: str
    collection_names: list[str]
    # Nested lists of the searches and aggregations that were run.
    searches: list[list[QueryResultWithCollection]]
    aggregations: list[list[AggregationResultWithCollection]]
    usage: Usage
    total_time: float
    is_partial_answer: bool
    missing_information: list[str]
    final_answer: str
    sources: list[Source]

    def display(self) -> None:
        """Display a pretty-printed summary of the QueryAgentResponse object."""
        print_query_agent_response(self)
class ModelUnitUsage(BaseModel):
    """Model-unit usage reported with a response."""

    model_units: int
    usage_in_plan: bool
    remaining_plan_requests: int
class FilterAndOr(BaseModel):
    """Boolean combination of filters; entries may be nested combinations."""

    combine: Literal["AND", "OR"]
    # Self-referential: an entry is either a single property filter or
    # another FilterAndOr.
    filters: list[Union[PropertyFilter, FilterAndOr]]
class QuerySort(BaseModel):
    """Sort specification, optionally chained with a tie-breaking sort."""

    property_name: str
    order: Literal["ascending", "descending"]
    # Required field (no default); None means there is no further tie-breaker.
    tie_break: Union[QuerySort, None]
class QueryResultWithCollectionNormalized(BaseModel):
    """A single search against one collection, with its filter tree and options."""

    query: Union[str, None]
    # A single filter, a combined filter tree, or no filter at all.
    filters: Union[PropertyFilter, FilterAndOr, None]
    collection: str
    sort_property: Union[QuerySort, None] = None
    uuid_value: Union[UUID, None] = None
class AggregationResultWithCollectionNormalized(BaseModel):
    """A single aggregation against one collection, with its filter tree."""

    groupby_property: Union[str, None]
    aggregation: PropertyAggregation
    # A single filter, a combined filter tree, or no filter at all.
    filters: Union[PropertyFilter, FilterAndOr, None]
    collection: str
class SuggestedQuery(BaseModel):
    """One suggested query string."""

    query: str
class SuggestQueryResponse(BaseModel):
    """Suggested queries together with usage and timing information."""

    queries: list[SuggestedQuery]
    collection_count: int
    usage: ModelUnitUsage
    total_time: float
class AskModeResponse(BaseModel):
    """Final-state response of the Query Agent ask mode."""

    output_type: Literal["final_state"] = "final_state"
    searches: list[QueryResultWithCollectionNormalized]
    aggregations: list[AggregationResultWithCollectionNormalized]
    usage: ModelUnitUsage
    total_time: float
    # The following fields are required but may be None.
    is_partial_answer: Union[bool, None]
    missing_information: Union[list[str], None]
    final_answer: str
    sources: Union[list[Source], None]

    def display(self) -> None:
        """Display a pretty-printed summary of the AskModeResponse object."""
        print_ask_mode_response(self)
class ResearchModeResponse(BaseModel):
    """Final-state response of the research mode, with its ask-mode sub-queries."""

    output_type: Literal["final_state"] = "final_state"
    final_answer: str
    usage: ModelUnitUsage
    queries: list[AskModeResponse]
    total_time: float
class QueryWithCollection(TypedDict):
    """A query string paired with the name of the collection it targets."""

    query: str
    collection: str
class ProgressDetails(TypedDict, total=False):
    """Optional structured details attached to a progress message.

    ``total=False`` makes every key optional, so an empty dict is valid.
    """

    queries: list[QueryWithCollection]
class ProgressMessage(BaseModel):
    """A progress update, identified by ``output_type="progress_message"``."""

    output_type: Literal["progress_message"] = "progress_message"
    stage: str
    message: str
    # Pydantic copies mutable defaults per instance, so ``{}`` is not shared.
    details: ProgressDetails = {}
class StreamedTokens(BaseModel):
    """An incremental chunk of streamed answer text (``delta``)."""

    output_type: Literal["streamed_tokens"] = "streamed_tokens"
    delta: str
class StreamedThoughts(BaseModel):
    """An incremental chunk of streamed thought text (``delta``)."""

    output_type: Literal["streamed_thoughts"] = "streamed_thoughts"
    delta: str
# ``typing.Self`` is only available from Python 3.11, and older versions are
# still supported. A TypeVar bound to the base class stands in for it when
# annotating the abstract ``next()`` method, so subclasses get precise return
# types. This workaround is suggested in https://peps.python.org/pep-0673/
SearchModeResponseT = TypeVar("SearchModeResponseT", bound="SearchModeResponseBase")
# Type parameter for the ``_searcher`` object a search-mode response holds.
SearcherT = TypeVar("SearcherT")
class SearchModeResponseBase(BaseModel, ABC, Generic[SearcherT]):
    """Common fields of the sync and async search-mode responses.

    ``_searcher`` is a pydantic private attribute (leading underscore), so it is
    not a validated field; its type is given by the ``SearcherT`` parameter.
    """

    searches: Optional[list[QueryResultWithCollectionNormalized]] = None
    usage: ModelUnitUsage
    total_time: float
    search_results: QueryReturn
    _searcher: SearcherT

    @abstractmethod
    def next(
        self: SearchModeResponseT, limit: int = 20, offset: int = 0
    ) -> Union[SearchModeResponseT, Coroutine[Any, Any, SearchModeResponseT]]:
        """Fetch a page of results; implementations return it or a coroutine for it."""
class _SyncSearcher(Protocol):
    """Structural type for a searcher whose ``run`` returns a page directly."""

    def run(self, limit: int = 20, offset: int = 0) -> SearchModeResponse: ...


class _AsyncSearcher(Protocol):
    """Structural type for a searcher whose ``run`` must be awaited for a page."""

    async def run(
        self, limit: int = 20, offset: int = 0
    ) -> AsyncSearchModeResponse: ...
class SearchModeResponse(SearchModeResponseBase[_SyncSearcher]):
    """Response for the Query Agent search-only mode.

    Holds the search results, the usage, and the underlying searches that were
    performed. To page through the results set, call :meth:`next` with
    different ``limit`` / ``offset`` values. Every call performs the same
    underlying searches, so the pages form one consistent results set.
    """

    def next(self, limit: int = 20, offset: int = 0) -> SearchModeResponse:
        """Return a page of the search-only results for ``limit`` / ``offset``.

        Args:
            limit: The maximum number of results to return. Defaults to 20.
            offset: The offset to start from. Defaults to 0, i.e. retrieval
                begins from the first object in the results set.

        Returns:
            The requested ``SearchModeResponse`` page.
        """
        page = self._searcher.run(limit=limit, offset=offset)
        return page
class AsyncSearchModeResponse(SearchModeResponseBase[_AsyncSearcher]):
    """Response for the Query Agent search-only mode (async).

    Holds the search results, the usage, and the underlying searches that were
    performed. To page through the results set, await :meth:`next` with
    different ``limit`` / ``offset`` values. Every call performs the same
    underlying searches, so the pages form one consistent results set.
    """

    async def next(self, limit: int = 20, offset: int = 0) -> AsyncSearchModeResponse:
        """Return a page of the search-only results for ``limit`` / ``offset``.

        Args:
            limit: The maximum number of results to return. Defaults to 20.
            offset: The offset to start from. Defaults to 0, i.e. retrieval
                begins from the first object in the results set.

        Returns:
            The requested ``AsyncSearchModeResponse`` page.
        """
        page = await self._searcher.run(limit=limit, offset=offset)
        return page