# Source code for weaviate_agents.query.query_agent

from abc import ABC, abstractmethod
from json import JSONDecodeError
from typing import (
    Any,
    AsyncGenerator,
    Coroutine,
    Generator,
    Generic,
    Literal,
    Optional,
    Union,
    overload,
)

import httpx
from httpx_sse import ServerSentEvent, aconnect_sse, connect_sse
from typing_extensions import deprecated
from weaviate.client import WeaviateAsyncClient, WeaviateClient

from weaviate_agents.base import ClientType, _BaseAgent
from weaviate_agents.query.classes import (
    AskModeResponse,
    ProgressMessage,
    QueryAgentCollectionConfig,
    QueryAgentResponse,
    StreamedTokens,
)
from weaviate_agents.query.classes.request import ChatMessage, ConversationContext
from weaviate_agents.query.search import (
    AsyncQueryAgentSearcher,
    AsyncSearchModeResponse,
    QueryAgentSearcher,
    SearchModeResponse,
)


class _BaseQueryAgent(Generic[ClientType], _BaseAgent[ClientType], ABC):
    """An agent for executing agentic queries against Weaviate.

    This abstract base holds the configuration shared by the synchronous and
    asynchronous Query Agents (collections, system prompt, timeout, endpoint
    URLs) and builds the request body they send to the agents service.

    For more information, see the [Weaviate Agents - Query Agent Docs](https://weaviate.io/developers/agents/query)
    """

    def __init__(
        self,
        client: ClientType,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        agents_host: Union[str, None] = None,
        system_prompt: Union[str, None] = None,
        timeout: Union[int, None] = None,
    ):
        """Initialize the Query Agent.

        Args:
            client: The Weaviate client connected to a Weaviate Cloud cluster.
            collections: The collections to query. Will be overridden if
                collections are passed to an individual query method.
            agents_host: Optional host of the agents service.
            system_prompt: Optional system prompt for the agent.
            timeout: The timeout for the request. Defaults to 60 seconds.
        """
        super().__init__(client, agents_host)

        self._collections = collections
        self._system_prompt = system_prompt

        # Only `None` selects the default, so an explicit value is always kept.
        self._timeout = 60 if timeout is None else timeout
        # The deprecated `run`/`stream` methods call endpoints under /agent;
        # `ask`/`ask_stream`/`search` call endpoints under /query.
        self.agent_url = f"{self._agents_host}/agent"
        self.query_url = f"{self._agents_host}/query"

    def _prepare_request_body(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
        **kwargs,
    ) -> dict:
        """Prepare the request body for a Query Agent request.

        Args:
            query: The natural language query string for the agent, or a list
                of chat messages forming a conversation.
            collections: The collections to query. Will override any collections if passed in the constructor.
            context: Optional previous response from the agent.
            **kwargs: Additional keyword arguments to merge into the request
                body. They are applied after the default keys, so they take
                precedence over them (e.g. over the default `limit`).

        Returns:
            The request body as a dict, sent as the JSON payload by callers.

        Raises:
            ValueError: If no collections were given here or in the constructor.
        """
        collections = collections or self._collections
        if not collections:
            raise ValueError("No collections provided to the query agent.")

        # A plain string is sent as-is; a conversation is wrapped and serialised.
        query_request = (
            query
            if isinstance(query, str)
            else ConversationContext(messages=query).model_dump(mode="json")
        )
        output = {
            "query": query_request,
            "collections": [
                collection
                if isinstance(collection, str)
                else collection.model_dump(mode="json")
                for collection in collections
            ],
            "headers": self._connection.additional_headers,
            "limit": 20,
            "system_prompt": self._system_prompt,
            **kwargs,
        }
        if context is not None:
            output["previous_response"] = context.model_dump(mode="json")
        return output

    @deprecated(
        "QueryAgent.run() is deprecated and will be removed in a future release. "
        "Use QueryAgent.ask() instead."
    )
    @abstractmethod
    def run(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
    ) -> Union[QueryAgentResponse, Coroutine[Any, Any, QueryAgentResponse]]:
        """Run the query agent.

        Deprecated:
            The `run` method is deprecated; use `ask()` instead.
        """

    @abstractmethod
    def ask(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
    ) -> Union[AskModeResponse, Coroutine[Any, Any, AskModeResponse]]:
        """Run the Query Agent ask mode."""

    @overload
    def stream(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
        include_progress: Literal[True] = True,
        include_final_state: Literal[True] = True,
    ) -> Union[
        Generator[
            Union[ProgressMessage, StreamedTokens, QueryAgentResponse], None, None
        ],
        AsyncGenerator[
            Union[ProgressMessage, StreamedTokens, QueryAgentResponse], None
        ],
    ]: ...

    @overload
    def stream(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
        include_progress: Literal[True] = True,
        include_final_state: Literal[False] = False,
    ) -> Union[
        Generator[Union[ProgressMessage, StreamedTokens], None, None],
        AsyncGenerator[Union[ProgressMessage, StreamedTokens], None],
    ]: ...

    @overload
    def stream(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
        include_progress: Literal[False] = False,
        include_final_state: Literal[True] = True,
    ) -> Union[
        Generator[Union[StreamedTokens, QueryAgentResponse], None, None],
        AsyncGenerator[Union[StreamedTokens, QueryAgentResponse], None],
    ]: ...

    @overload
    def stream(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
        include_progress: Literal[False] = False,
        include_final_state: Literal[False] = False,
    ) -> Union[
        Generator[StreamedTokens, None, None],
        AsyncGenerator[StreamedTokens, None],
    ]: ...

    @deprecated(
        "QueryAgent.stream() is deprecated and will be removed in a future release. "
        "Use QueryAgent.ask_stream() instead."
    )
    @abstractmethod
    def stream(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
        include_progress: bool = True,
        include_final_state: bool = True,
    ) -> Union[
        Generator[
            Union[ProgressMessage, StreamedTokens, QueryAgentResponse], None, None
        ],
        AsyncGenerator[
            Union[ProgressMessage, StreamedTokens, QueryAgentResponse], None
        ],
    ]:
        """Stream from the query agent.

        Deprecated:
            The `stream` method is deprecated; use `ask_stream()` instead.
        """

    @overload
    def ask_stream(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        include_progress: Literal[True] = True,
        include_final_state: Literal[True] = True,
    ) -> Union[
        Generator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None, None],
        AsyncGenerator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None],
    ]: ...

    @overload
    def ask_stream(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        include_progress: Literal[True] = True,
        include_final_state: Literal[False] = False,
    ) -> Union[
        Generator[Union[ProgressMessage, StreamedTokens], None, None],
        AsyncGenerator[Union[ProgressMessage, StreamedTokens], None],
    ]: ...

    @overload
    def ask_stream(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        include_progress: Literal[False] = False,
        include_final_state: Literal[True] = True,
    ) -> Union[
        Generator[Union[StreamedTokens, AskModeResponse], None, None],
        AsyncGenerator[Union[StreamedTokens, AskModeResponse], None],
    ]: ...

    @overload
    def ask_stream(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        include_progress: Literal[False] = False,
        include_final_state: Literal[False] = False,
    ) -> Union[
        Generator[StreamedTokens, None, None],
        AsyncGenerator[StreamedTokens, None],
    ]: ...

    @abstractmethod
    def ask_stream(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        include_progress: bool = True,
        include_final_state: bool = True,
    ) -> Union[
        Generator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None, None],
        AsyncGenerator[Union[ProgressMessage, StreamedTokens, AskModeResponse], None],
    ]:
        """Run the Query Agent ask mode and stream the response."""

    @abstractmethod
    def search(
        self,
        query: Union[str, list[ChatMessage]],
        limit: int = 20,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
    ) -> Union[SearchModeResponse, Coroutine[Any, Any, AsyncSearchModeResponse]]:
        """Run the Query Agent search-only mode and return the first page of results."""


class QueryAgent(_BaseQueryAgent[WeaviateClient]):
    """An agent for executing agentic queries against Weaviate.

    For more information, see the [Weaviate Agents - Query Agent Docs](https://weaviate.io/developers/agents/query)
    """

    @deprecated(
        "QueryAgent.run() is deprecated and will be removed in a future release. "
        "Use QueryAgent.ask() instead."
    )
    def run(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
    ) -> QueryAgentResponse:
        """Run the query agent.

        Args:
            query: The natural language query string for the agent.
            collections: The collections to query. Will override any collections if passed in the constructor.
            context: Optional previous response from the agent.

        Returns:
            The agent's `QueryAgentResponse`.

        Raises:
            ValueError: If no collections were provided here or in the constructor.
            Exception: If the service responds with an error status; the
                message is the response body.
        """
        request_body = self._prepare_request_body(
            query=query, collections=collections, context=context
        )
        response = httpx.post(
            self.agent_url + "/query",
            headers=self._headers,
            json=request_body,
            timeout=self._timeout,
        )
        if response.is_error:
            raise Exception(response.text)
        return QueryAgentResponse(**response.json())

    def ask(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
    ) -> AskModeResponse:
        """Run the Query Agent ask mode.

        Args:
            query: The natural language query string, or a list of chat
                messages forming a conversation.
            collections: The collections to query. Overrides any collections
                provided in the constructor when set.

        Returns:
            The agent's `AskModeResponse`.

        Raises:
            ValueError: If no collections were provided here or in the constructor.
            Exception: If the service responds with an error status; the
                message is the response body.
        """
        request_body = self._prepare_request_body(query=query, collections=collections)
        response = httpx.post(
            self.query_url + "/ask",
            headers=self._headers,
            json=request_body,
            timeout=self._timeout,
        )
        if response.is_error:
            raise Exception(response.text)
        return AskModeResponse(**response.json())

    @overload
    def stream(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
        include_progress: Literal[True] = True,
        include_final_state: Literal[True] = True,
    ) -> Generator[
        Union[ProgressMessage, StreamedTokens, QueryAgentResponse], None, None
    ]: ...

    @overload
    def stream(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
        include_progress: Literal[True] = True,
        include_final_state: Literal[False] = False,
    ) -> Generator[Union[ProgressMessage, StreamedTokens], None, None]: ...

    @overload
    def stream(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
        include_progress: Literal[False] = False,
        include_final_state: Literal[True] = True,
    ) -> Generator[Union[StreamedTokens, QueryAgentResponse], None, None]: ...

    @overload
    def stream(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
        include_progress: Literal[False] = False,
        include_final_state: Literal[False] = False,
    ) -> Generator[StreamedTokens, None, None]: ...

    @deprecated(
        "QueryAgent.stream() is deprecated and will be removed in a future release. "
        "Use QueryAgent.ask_stream() instead."
    )
    def stream(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
        include_progress: bool = True,
        include_final_state: bool = True,
    ):
        """Stream from the query agent.

        Deprecated:
            The `stream` method is deprecated; use `ask_stream()` instead.

        Args:
            query: The natural language query string for the agent.
            collections: The collections to query. Will override any collections if passed in the constructor.
            context: Optional previous response from the agent.
            include_progress: Whether to yield `ProgressMessage` events.
            include_final_state: Whether to yield the final `QueryAgentResponse`.

        Yields:
            `ProgressMessage`, `StreamedTokens` and/or `QueryAgentResponse`
            objects, in the order the service sends them.

        Raises:
            ValueError: If no collections were provided here or in the constructor.
            Exception: If the service responds with an error status or sends
                an error / unparseable event.
        """
        request_body = self._prepare_request_body(
            query=query,
            collections=collections,
            context=context,
            include_progress=include_progress,
            include_final_state=include_final_state,
        )
        with httpx.Client() as client:
            # httpx_sse's connect_sse writes SSE headers (Accept, Cache-Control)
            # into the dict passed as `headers`; pass a copy so the agent's
            # shared headers are not modified for later non-streaming requests.
            with connect_sse(
                client=client,
                method="POST",
                url=self.agent_url + "/stream_query",
                json=request_body,
                headers=dict(self._headers),
                timeout=self._timeout,
            ) as events:
                if events.response.is_error:
                    # Streaming responses are not read eagerly; load the body
                    # so the error text is available.
                    events.response.read()
                    raise Exception(events.response.text)
                for sse in events.iter_sse():
                    output = _parse_sse(sse, mode="query")
                    if isinstance(output, ProgressMessage):
                        if include_progress:
                            yield output
                    elif isinstance(output, QueryAgentResponse):
                        if include_final_state:
                            yield output
                    else:
                        yield output

    @overload
    def ask_stream(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        include_progress: Literal[True] = True,
        include_final_state: Literal[True] = True,
    ) -> Generator[
        Union[ProgressMessage, StreamedTokens, AskModeResponse], None, None
    ]: ...

    @overload
    def ask_stream(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        include_progress: Literal[True] = True,
        include_final_state: Literal[False] = False,
    ) -> Generator[Union[ProgressMessage, StreamedTokens], None, None]: ...

    @overload
    def ask_stream(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        include_progress: Literal[False] = False,
        include_final_state: Literal[True] = True,
    ) -> Generator[Union[StreamedTokens, AskModeResponse], None, None]: ...

    @overload
    def ask_stream(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        include_progress: Literal[False] = False,
        include_final_state: Literal[False] = False,
    ) -> Generator[StreamedTokens, None, None]: ...

    def ask_stream(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        include_progress: bool = True,
        include_final_state: bool = True,
    ):
        """Run the Query Agent ask mode and stream the response.

        Args:
            query: The natural language query string, or a list of chat
                messages forming a conversation.
            collections: The collections to query. Overrides any collections
                provided in the constructor when set.
            include_progress: Whether to yield `ProgressMessage` events.
            include_final_state: Whether to yield the final `AskModeResponse`.

        Yields:
            `ProgressMessage`, `StreamedTokens` and/or `AskModeResponse`
            objects, in the order the service sends them.

        Raises:
            ValueError: If no collections were provided here or in the constructor.
            Exception: If the service responds with an error status or sends
                an error / unparseable event.
        """
        request_body = self._prepare_request_body(
            query=query,
            collections=collections,
            include_progress=include_progress,
            include_final_state=include_final_state,
        )
        with httpx.Client() as client:
            # Copy the headers: connect_sse adds SSE headers to the dict it gets.
            with connect_sse(
                client=client,
                method="POST",
                url=self.query_url + "/stream_ask",
                json=request_body,
                headers=dict(self._headers),
                timeout=self._timeout,
            ) as events:
                if events.response.is_error:
                    events.response.read()
                    raise Exception(events.response.text)
                for sse in events.iter_sse():
                    output = _parse_sse(sse, mode="ask")
                    if isinstance(output, ProgressMessage):
                        if include_progress:
                            yield output
                    elif isinstance(output, AskModeResponse):
                        if include_final_state:
                            yield output
                    else:
                        yield output

    def search(
        self,
        query: Union[str, list[ChatMessage]],
        limit: int = 20,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
    ) -> SearchModeResponse:
        """Run the Query Agent search-only mode.

        This method sends the initial search request and returns a
        `SearchModeResponse` containing the first page of results. To paginate,
        use the `SearchModeResponse.next()` method. This reuses the same
        underlying searches to ensure a consistent result set across pages.

        Args:
            query: The natural language query string for the agent.
            limit: The maximum number of results to return for the first page.
            collections: The collections to query. Overrides any collections
                provided in the constructor when set.

        Returns:
            A `SearchModeResponse` for the first page of results. Use
            `response.next(limit=..., offset=...)` to paginate.

        Raises:
            ValueError: If no collections were provided here or in the constructor.
        """
        collections = collections or self._collections
        if not collections:
            raise ValueError("No collections provided to the query agent.")

        searcher = QueryAgentSearcher(
            headers=self._headers,
            connection_headers=self._connection.additional_headers,
            timeout=self._timeout,
            query_url=self.query_url,
            query=query,
            collections=collections,
            system_prompt=self._system_prompt,
        )
        return searcher.run(limit=limit)


class AsyncQueryAgent(_BaseQueryAgent[WeaviateAsyncClient]):
    """An agent for executing agentic queries against Weaviate.

    For more information, see the [Weaviate Agents - Query Agent Docs](https://weaviate.io/developers/agents/query)
    """

    @deprecated(
        "AsyncQueryAgent.run() is deprecated and will be removed in a future release. "
        "Use AsyncQueryAgent.ask() instead."
    )
    async def run(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
    ) -> QueryAgentResponse:
        """Run the query agent.

        Args:
            query: The natural language query string for the agent.
            collections: The collections to query. Will override any collections if passed in the constructor.
            context: Optional previous response from the agent.

        Returns:
            The agent's `QueryAgentResponse`.

        Raises:
            ValueError: If no collections were provided here or in the constructor.
            Exception: If the service responds with an error status; the
                message is the response body.
        """
        request_body = self._prepare_request_body(
            query=query, collections=collections, context=context
        )
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.agent_url + "/query",
                headers=self._headers,
                json=request_body,
                timeout=self._timeout,
            )
            if response.is_error:
                raise Exception(response.text)
            return QueryAgentResponse(**response.json())

    async def ask(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
    ) -> AskModeResponse:
        """Run the Query Agent ask mode.

        Args:
            query: The natural language query string, or a list of chat
                messages forming a conversation.
            collections: The collections to query. Overrides any collections
                provided in the constructor when set.

        Returns:
            The agent's `AskModeResponse`.

        Raises:
            ValueError: If no collections were provided here or in the constructor.
            Exception: If the service responds with an error status; the
                message is the response body.
        """
        request_body = self._prepare_request_body(query=query, collections=collections)
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.query_url + "/ask",
                headers=self._headers,
                json=request_body,
                timeout=self._timeout,
            )
            if response.is_error:
                raise Exception(response.text)
            return AskModeResponse(**response.json())

    @overload
    def stream(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
        include_progress: Literal[True] = True,
        include_final_state: Literal[True] = True,
    ) -> AsyncGenerator[
        Union[ProgressMessage, StreamedTokens, QueryAgentResponse], None
    ]: ...

    @overload
    def stream(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
        include_progress: Literal[True] = True,
        include_final_state: Literal[False] = False,
    ) -> AsyncGenerator[Union[ProgressMessage, StreamedTokens], None]: ...

    @overload
    def stream(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
        include_progress: Literal[False] = False,
        include_final_state: Literal[True] = True,
    ) -> AsyncGenerator[Union[StreamedTokens, QueryAgentResponse], None]: ...

    @overload
    def stream(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
        include_progress: Literal[False] = False,
        include_final_state: Literal[False] = False,
    ) -> AsyncGenerator[StreamedTokens, None]: ...

    @deprecated(
        "AsyncQueryAgent.stream() is deprecated and will be removed in a future release. "
        "Use AsyncQueryAgent.ask_stream() instead."
    )
    async def stream(
        self,
        query: str,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        context: Optional[QueryAgentResponse] = None,
        include_progress: bool = True,
        include_final_state: bool = True,
    ):
        """Stream from the query agent.

        Deprecated:
            The `stream` method is deprecated; use `ask_stream()` instead.

        Args:
            query: The natural language query string for the agent.
            collections: The collections to query. Will override any collections if passed in the constructor.
            context: Optional previous response from the agent.
            include_progress: Whether to yield `ProgressMessage` events.
            include_final_state: Whether to yield the final `QueryAgentResponse`.

        Yields:
            `ProgressMessage`, `StreamedTokens` and/or `QueryAgentResponse`
            objects, in the order the service sends them.

        Raises:
            ValueError: If no collections were provided here or in the constructor.
            Exception: If the service responds with an error status or sends
                an error / unparseable event.
        """
        request_body = self._prepare_request_body(
            query=query,
            collections=collections,
            context=context,
            include_progress=include_progress,
            include_final_state=include_final_state,
        )
        async with httpx.AsyncClient() as client:
            # httpx_sse's aconnect_sse writes SSE headers (Accept, Cache-Control)
            # into the dict passed as `headers`; pass a copy so the agent's
            # shared headers are not modified for later non-streaming requests.
            async with aconnect_sse(
                client=client,
                method="POST",
                url=self.agent_url + "/stream_query",
                json=request_body,
                headers=dict(self._headers),
                timeout=self._timeout,
            ) as events:
                if events.response.is_error:
                    # Streaming responses are not read eagerly; load the body
                    # so the error text is available.
                    await events.response.aread()
                    raise Exception(events.response.text)
                async for sse in events.aiter_sse():
                    output = _parse_sse(sse, mode="query")
                    if isinstance(output, ProgressMessage):
                        if include_progress:
                            yield output
                    elif isinstance(output, QueryAgentResponse):
                        if include_final_state:
                            yield output
                    else:
                        yield output

    @overload
    def ask_stream(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        include_progress: Literal[True] = True,
        include_final_state: Literal[True] = True,
    ) -> AsyncGenerator[
        Union[ProgressMessage, StreamedTokens, AskModeResponse], None
    ]: ...

    @overload
    def ask_stream(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        include_progress: Literal[True] = True,
        include_final_state: Literal[False] = False,
    ) -> AsyncGenerator[Union[ProgressMessage, StreamedTokens], None]: ...

    @overload
    def ask_stream(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        include_progress: Literal[False] = False,
        include_final_state: Literal[True] = True,
    ) -> AsyncGenerator[Union[StreamedTokens, AskModeResponse], None]: ...

    @overload
    def ask_stream(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        include_progress: Literal[False] = False,
        include_final_state: Literal[False] = False,
    ) -> AsyncGenerator[StreamedTokens, None]: ...

    async def ask_stream(
        self,
        query: Union[str, list[ChatMessage]],
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
        include_progress: bool = True,
        include_final_state: bool = True,
    ):
        """Run the Query Agent ask mode and stream the response.

        Args:
            query: The natural language query string, or a list of chat
                messages forming a conversation.
            collections: The collections to query. Overrides any collections
                provided in the constructor when set.
            include_progress: Whether to yield `ProgressMessage` events.
            include_final_state: Whether to yield the final `AskModeResponse`.

        Yields:
            `ProgressMessage`, `StreamedTokens` and/or `AskModeResponse`
            objects, in the order the service sends them.

        Raises:
            ValueError: If no collections were provided here or in the constructor.
            Exception: If the service responds with an error status or sends
                an error / unparseable event.
        """
        request_body = self._prepare_request_body(
            query=query,
            collections=collections,
            include_progress=include_progress,
            include_final_state=include_final_state,
        )
        async with httpx.AsyncClient() as client:
            # Copy the headers: aconnect_sse adds SSE headers to the dict it gets.
            async with aconnect_sse(
                client=client,
                method="POST",
                url=self.query_url + "/stream_ask",
                json=request_body,
                headers=dict(self._headers),
                timeout=self._timeout,
            ) as events:
                if events.response.is_error:
                    await events.response.aread()
                    raise Exception(events.response.text)
                async for sse in events.aiter_sse():
                    output = _parse_sse(sse, mode="ask")
                    if isinstance(output, ProgressMessage):
                        if include_progress:
                            yield output
                    elif isinstance(output, AskModeResponse):
                        if include_final_state:
                            yield output
                    else:
                        yield output

    async def search(
        self,
        query: Union[str, list[ChatMessage]],
        limit: int = 20,
        collections: Union[list[Union[str, QueryAgentCollectionConfig]], None] = None,
    ) -> AsyncSearchModeResponse:
        """Run the Query Agent search-only mode.

        This method sends the initial search request and returns an
        `AsyncSearchModeResponse` containing the first page of results. To
        paginate, use the `AsyncSearchModeResponse.next()` method. This reuses
        the same underlying searches to ensure a consistent result set across
        pages.

        Args:
            query: The natural language query string for the agent.
            limit: The maximum number of results to return for the first page.
            collections: The collections to query. Overrides any collections
                provided in the constructor when set.

        Returns:
            An `AsyncSearchModeResponse` for the first page of results. Use
            `await response.next(limit=..., offset=...)` to paginate.

        Raises:
            ValueError: If no collections were provided here or in the constructor.
        """
        collections = collections or self._collections
        if not collections:
            raise ValueError("No collections provided to the query agent.")

        searcher = AsyncQueryAgentSearcher(
            headers=self._headers,
            connection_headers=self._connection.additional_headers,
            timeout=self._timeout,
            query_url=self.query_url,
            query=query,
            collections=collections,
            system_prompt=self._system_prompt,
        )
        return await searcher.run(limit=limit)


@overload
def _parse_sse(
    sse: ServerSentEvent, mode: Literal["query"]
) -> Union[ProgressMessage, StreamedTokens, QueryAgentResponse]: ...


@overload
def _parse_sse(
    sse: ServerSentEvent, mode: Literal["ask"]
) -> Union[ProgressMessage, StreamedTokens, AskModeResponse]: ...


def _parse_sse(
    sse: ServerSentEvent, mode: Literal["query", "ask"]
) -> Union[ProgressMessage, StreamedTokens, QueryAgentResponse, AskModeResponse]:
    """Convert one server-sent event from a streaming endpoint into a model.

    Args:
        sse: The event received from the agents service.
        mode: `"query"` when parsing the deprecated `stream` endpoint (the final
            state is a `QueryAgentResponse`), `"ask"` when parsing `ask_stream`
            (the final state is an `AskModeResponse`).

    Returns:
        A `ProgressMessage`, `StreamedTokens`, or the final-state response model.

    Raises:
        Exception: If the event data is not valid JSON, the event is an `error`
            event, or the event type is not recognised.
        ValueError: If a `final_state` event arrives with an unsupported `mode`.
    """
    try:
        data = sse.json()
    except JSONDecodeError as err:
        raise Exception(
            f"Unable to decode response: {sse.event=}, {sse.data=}"
        ) from err

    if sse.event == "error":
        # Prefer the server's "error" field, but fall back to the whole payload
        # so a differently-shaped error event does not surface as a
        # KeyError/TypeError that hides the server's message.
        message = data.get("error", data) if isinstance(data, dict) else data
        raise Exception(str(message))
    if sse.event == "progress_message":
        return ProgressMessage.model_validate(data)
    if sse.event == "streamed_tokens":
        return StreamedTokens.model_validate(data)
    if sse.event == "final_state":
        if mode == "query":
            return QueryAgentResponse.model_validate(data)
        if mode == "ask":
            return AskModeResponse.model_validate(data)
        # Previously fell through and returned None silently.
        raise ValueError(f"Unsupported mode for final_state event: {mode!r}")
    raise Exception(
        f"Unrecognised event type in response: {sse.event=}, {sse.data=}"
    )