# Source code for weaviate_agents.query.search

from __future__ import annotations

from typing import Any, Optional, Union

import httpx

from weaviate_agents.query.classes.collection import QueryAgentCollectionConfig
from weaviate_agents.query.classes.request import (
    ChatMessage,
    ConversationContext,
    SearchModeExecutionRequest,
    SearchModeGenerationRequest,
)
from weaviate_agents.query.classes.response import (
    AsyncSearchModeResponse,
    QueryResultWithCollectionNormalized,
    SearchModeResponse,
)

# Public names of this module: the two searcher classes defined below, plus the
# response classes imported from `weaviate_agents.query.classes.response`.
__all__ = [
    "AsyncQueryAgentSearcher",
    "AsyncSearchModeResponse",
    "QueryAgentSearcher",
    "SearchModeResponse",
]


class _BaseQueryAgentSearcher:
    """Shared state and request-body construction for the search-only searchers.

    Holds the connection settings and the search definition, plus the searches
    remembered from an earlier response (``_cached_searches``), which decide
    what kind of request body is built.
    """

    def __init__(
        self,
        headers: dict[str, Any],
        connection_headers: dict[str, str],
        timeout: int,
        query_url: str,
        query: Union[str, list[ChatMessage]],
        collections: list[Union[str, QueryAgentCollectionConfig]],
        system_prompt: Optional[str],
        diversity_weight: Optional[float] = None,
    ):
        """Store the searcher configuration.

        Args:
            headers: Stored as ``self.headers``; subclasses send these as the
                HTTP headers of the request.
            connection_headers: Placed in the ``headers`` field of the request body.
            timeout: Stored as ``self.timeout``; subclasses pass it to httpx.
            query_url: Base URL that subclasses append ``/search_only`` to.
            query: Either a plain query string or a list of chat messages.
            collections: Collection names or collection configs to search.
            system_prompt: Optional system prompt, sent only with generation requests.
            diversity_weight: Optional value forwarded unchanged in every request body.
        """
        # Transport-level settings.
        self.headers = headers
        self.connection_headers = connection_headers
        self.timeout = timeout
        self.query_url = query_url

        # Search definition.
        self.query = query
        self.collections = collections
        self.system_prompt = system_prompt
        self.diversity_weight = diversity_weight

        # Nothing is cached until a subclass stores searches from a response.
        self._cached_searches: Optional[list[QueryResultWithCollectionNormalized]] = (
            None
        )

    def _get_request_body(self, limit: int, offset: int) -> dict[str, Any]:
        """Build the JSON-serialisable request body for one page of results.

        Args:
            limit: Value for the ``limit`` field of the request.
            offset: Value for the ``offset`` field of the request.

        Returns:
            The ``model_dump(mode="json")`` output of a
            ``SearchModeGenerationRequest`` when no searches are cached, or of a
            ``SearchModeExecutionRequest`` carrying the cached searches otherwise.
        """
        # A list of chat messages is wrapped in a ConversationContext; a plain
        # string is sent as-is.
        if isinstance(self.query, str):
            original_query = self.query
        else:
            original_query = ConversationContext(messages=self.query)

        # Fields common to both request kinds.
        shared_fields: dict[str, Any] = {
            "headers": self.connection_headers,
            "original_query": original_query,
            "collections": self.collections,
            "limit": limit,
            "offset": offset,
            "diversity_weight": self.diversity_weight,
        }

        if self._cached_searches is not None:
            # Reuse the previously returned searches (no system prompt is sent).
            execution_request = SearchModeExecutionRequest(
                searches=self._cached_searches,
                **shared_fields,
            )
            return execution_request.model_dump(mode="json")

        generation_request = SearchModeGenerationRequest(
            system_prompt=self.system_prompt,
            **shared_fields,
        )
        return generation_request.model_dump(mode="json")


class QueryAgentSearcher(_BaseQueryAgentSearcher):
    """A configured searcher for the Query Agent search-only mode.

    This configured search can be run using the `run` method. This allows you to
    paginate through the results set multiple times with different `limit` /
    `offset` values. This will result in the same underlying searches being
    performed each time, resulting in a consistent results set across pages.

    For more information, see the
    [Weaviate Agents - Query Agent Docs](https://weaviate.io/developers/agents/query)
    """

    def _handle_response(self, response: httpx.Response) -> SearchModeResponse:
        """Parse a `/search_only` response and remember its searches.

        Args:
            response: The HTTP response returned by the search-only endpoint.

        Returns:
            The response body parsed into a `SearchModeResponse`, with this
            searcher attached as its `_searcher`.

        Raises:
            Exception: If `response.is_error` is true; the message is the
                response text.
        """
        if response.is_error:
            raise Exception(response.text)

        parsed_response = SearchModeResponse(**response.json())
        if parsed_response.searches:
            # Once searches are cached, `_get_request_body` builds an execution
            # request that reuses them instead of a generation request.
            self._cached_searches = parsed_response.searches
        # NOTE(review): presumably lets the response request further pages via
        # this searcher; the response class is not visible here — confirm.
        parsed_response._searcher = self
        return parsed_response

    def run(self, limit: int = 20, offset: int = 0) -> SearchModeResponse:
        """Run the search-only agent with the given `limit` and `offset` values.

        Calling this method multiple times on the same QueryAgentSearcher
        instance will result in the same underlying searches being performed
        each time, allowing you to paginate over a consistent results set.

        Args:
            limit: The maximum number of results to return. If not specified,
                this defaults to 20.
            offset: The offset to start from. If not specified, the retrieval
                begins from the first object in the results set.

        Returns:
            A `SearchModeResponse` object containing the results of the search,
            the usage, and the underlying searches performed.

        Raises:
            Exception: If the endpoint returns an error response.
        """
        request_body = self._get_request_body(limit, offset)
        response = httpx.post(
            self.query_url + "/search_only",
            headers=self.headers,
            json=request_body,
            timeout=self.timeout,
        )
        return self._handle_response(response)


class AsyncQueryAgentSearcher(_BaseQueryAgentSearcher):
    """A configured async searcher for the Query Agent search-only mode.

    This configured search can be run using the `run` method. This allows you to
    paginate through the results set multiple times with different `limit` /
    `offset` values. This will result in the same underlying searches being
    performed each time, resulting in a consistent results set across pages.

    For more information, see the
    [Weaviate Agents - Query Agent Docs](https://weaviate.io/developers/agents/query)
    """

    def _handle_response(self, response: httpx.Response) -> AsyncSearchModeResponse:
        """Parse a `/search_only` response and remember its searches.

        Args:
            response: The HTTP response returned by the search-only endpoint.

        Returns:
            The response body parsed into an `AsyncSearchModeResponse`, with
            this searcher attached as its `_searcher`.

        Raises:
            Exception: If `response.is_error` is true; the message is the
                response text.
        """
        if response.is_error:
            raise Exception(response.text)

        parsed_response = AsyncSearchModeResponse(**response.json())
        if parsed_response.searches:
            # Once searches are cached, `_get_request_body` builds an execution
            # request that reuses them instead of a generation request.
            self._cached_searches = parsed_response.searches
        # NOTE(review): presumably lets the response request further pages via
        # this searcher; the response class is not visible here — confirm.
        parsed_response._searcher = self
        return parsed_response

    async def run(self, limit: int = 20, offset: int = 0) -> AsyncSearchModeResponse:
        """Run the search-only agent with the given `limit` and `offset` values.

        Calling this method multiple times on the same AsyncQueryAgentSearcher
        instance will result in the same underlying searches being performed
        each time, allowing you to paginate over a consistent results set.

        Args:
            limit: The maximum number of results to return. If not specified,
                this defaults to 20.
            offset: The offset to start from. If not specified, the retrieval
                begins from the first object in the results set.

        Returns:
            An `AsyncSearchModeResponse` object containing the results of the
            search, the usage, and the underlying searches performed.

        Raises:
            Exception: If the endpoint returns an error response.
        """
        request_body = self._get_request_body(limit, offset)
        # A fresh client is opened for this single request and closed on exit.
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.query_url + "/search_only",
                headers=self.headers,
                json=request_body,
                timeout=self.timeout,
            )
        return self._handle_response(response)