from typing import Optional
from uuid import UUID
import httpx
from weaviate.classes.config import DataType
from weaviate.client import WeaviateClient
from weaviate.collections.classes.filters import _Filters
from weaviate_agents.base import _BaseAgent
from weaviate_agents.personalization.classes import (
GetObjectsRequest,
Persona,
PersonaInteraction,
PersonaInteractionResponse,
PersonalizationAgentGetObjectsResponse,
PersonalizationRequest,
)
from weaviate_agents.personalization.query import PersonalizedQuery
class PersonalizationAgent(_BaseAgent[WeaviateClient]):
    """An agent for personalizing search results and queries based on persona interactions.

    Warning:
        Weaviate Agents - Personalization Agent is an early stage alpha product. The API is
        subject to breaking changes. Please ensure you are using the latest version of the
        client.

    For more information, see the
    [Weaviate Agents - Personalization Agent Docs](https://weaviate.io/developers/agents/personalization)
    """

    def __init__(
        self,
        client: WeaviateClient,
        reference_collection: str,
        agents_host: Optional[str] = None,
        vector_name: Optional[str] = None,
        timeout: Optional[int] = None,
    ):
        """Store the agent configuration.

        This constructor does not call the service's initialization endpoint itself;
        use :meth:`create` or :meth:`connect` to obtain an initialized agent.

        Args:
            client: The Weaviate client
            reference_collection: The name of the collection to personalize
            agents_host: Optional host URL for the agents service
            vector_name: Optional name of the vector field to use
            timeout: Optional timeout passed to every HTTP request made by this agent.
                ``None`` is forwarded to httpx as-is (httpx treats it as "no timeout").
        """
        # NOTE(review): _BaseAgent presumably sets self._connection, self._agents_host and
        # self._headers, which the methods below read -- confirm in weaviate_agents.base.
        super().__init__(client, agents_host)
        self._reference_collection = reference_collection
        self._vector_name = vector_name
        self._route = "/personalization"
        self._timeout = timeout

    @classmethod
    def create(
        cls,
        client: WeaviateClient,
        reference_collection: str,
        user_properties: Optional[dict[str, DataType]] = None,
        agents_host: Optional[str] = None,
        vector_name: Optional[str] = None,
        timeout: Optional[int] = None,
    ) -> "PersonalizationAgent":
        """Create a new Personalization Agent for a collection.

        Args:
            client: The Weaviate client
            reference_collection: The name of the collection to personalize
            user_properties: Optional dictionary of user properties and their data types
            agents_host: Optional host URL for the agents service
            vector_name: Optional name of the vector field to use
            timeout: Optional timeout for the initialization request and for every
                subsequent request made by the returned agent

        Returns:
            A new instance of the Personalization Agent

        Raises:
            Exception: If the agents service responds with an error status.
        """
        # Pass the timeout to the constructor as well; otherwise only the
        # initialization call below would honour it and all later requests
        # would run with self._timeout = None.
        agent = cls(
            client=client,
            reference_collection=reference_collection,
            agents_host=agents_host,
            vector_name=vector_name,
            timeout=timeout,
        )
        agent._initialize(
            reference_collection,
            create=True,
            user_properties=user_properties,
            vector_name=vector_name,
            timeout=timeout,
        )
        return agent

    @classmethod
    def connect(
        cls,
        client: WeaviateClient,
        reference_collection: str,
        agents_host: Optional[str] = None,
        vector_name: Optional[str] = None,
        timeout: Optional[int] = None,
    ) -> "PersonalizationAgent":
        """Connect to an existing Personalization Agent for a collection.

        Args:
            client: The Weaviate client
            reference_collection: The name of the collection to connect to
            agents_host: Optional host URL for the agents service
            vector_name: Optional name of the vector field to use
            timeout: Optional timeout for the connection request and for every
                subsequent request made by the returned agent

        Returns:
            An instance of the Personalization Agent

        Raises:
            Exception: If the agents service responds with an error status.
        """
        # See create(): the timeout must be stored on the instance too.
        agent = cls(
            client=client,
            reference_collection=reference_collection,
            agents_host=agents_host,
            vector_name=vector_name,
            timeout=timeout,
        )
        agent._initialize(
            reference_collection, create=False, vector_name=vector_name, timeout=timeout
        )
        return agent

    def _initialize(
        self,
        reference_collection: str,
        create: bool = False,
        user_properties: Optional[dict[str, DataType]] = None,
        vector_name: Optional[str] = None,
        timeout: Optional[int] = None,
    ) -> None:
        """Initialize the agent with the given reference collection and user properties.

        Args:
            reference_collection: The name of the collection to personalize
            create: Whether to create a new personalization agent
            user_properties: Optional dictionary of user properties and their data types
            vector_name: Optional name of the vector field to use
            timeout: Optional timeout for the request

        Raises:
            Exception: If the agents service responds with an error status.
        """
        request_data = {
            "collection_name": reference_collection,
            "headers": self._connection.additional_headers,
            "persona_properties": user_properties or {},
            "item_collection_vector_name": vector_name,
            "create": create,
        }
        response = httpx.post(
            f"{self._agents_host}{self._route}/",
            headers=self._headers,
            json=request_data,
            timeout=timeout,
        )
        if response.is_error:
            raise Exception(
                f"Failed to initialize personalization agent: {response.text}"
            )

    def _personalization_request_payload(self) -> dict:
        """Build the ``personalization_request`` payload identifying this agent.

        Every per-agent endpoint below receives the same description: the
        reference collection name, the connection's additional headers, the
        vector name, and ``create=False`` (these calls never request creation).
        A new dict is built on each call, so headers are read at request time.
        """
        return {
            "collection_name": self._reference_collection,
            "headers": self._connection.additional_headers,
            "item_collection_vector_name": self._vector_name,
            "create": False,
        }

    def add_persona(self, persona: Persona) -> None:
        """Add a persona to the Personalization Agent's persona collection.

        Args:
            persona: The persona to add. The persona must have a persona_id and properties
                that match the user properties defined when the Personalization Agent was
                created.

        Raises:
            Exception: If the agents service responds with an error status.
        """
        request_data = {
            "persona": persona.model_dump(mode="json"),
            "personalization_request": self._personalization_request_payload(),
        }
        response = httpx.post(
            f"{self._agents_host}{self._route}/persona",
            headers=self._headers,
            json=request_data,
            timeout=self._timeout,
        )
        if response.is_error:
            raise Exception(f"Failed to add persona: {response.text}")

    def update_persona(self, persona: Persona) -> None:
        """Update an existing persona in the Personalization Agent's persona collection.

        Args:
            persona: The persona to update. The persona must have a persona_id and properties
                that match the user properties defined when the Personalization Agent was
                created.

        Raises:
            Exception: If the agents service responds with an error status.
        """
        request_data = {
            "persona": persona.model_dump(mode="json"),
            "personalization_request": self._personalization_request_payload(),
        }
        response = httpx.put(
            f"{self._agents_host}{self._route}/persona",
            headers=self._headers,
            json=request_data,
            timeout=self._timeout,
        )
        if response.is_error:
            raise Exception(f"Failed to update persona: {response.text}")

    def get_persona(self, persona_id: UUID) -> Persona:
        """Get a persona by persona_id from the Personalization Agent's persona collection.

        Args:
            persona_id: The ID of the persona to retrieve

        Returns:
            The retrieved persona

        Raises:
            Exception: If the agents service responds with an error status.
        """
        response = httpx.post(
            f"{self._agents_host}{self._route}/persona/{str(persona_id)}",
            headers=self._headers,
            json=self._personalization_request_payload(),
            timeout=self._timeout,
        )
        if response.is_error:
            raise Exception(f"Failed to get persona: {response.text}")
        return Persona(**response.json())

    def delete_persona(self, persona_id: UUID) -> None:
        """Delete a persona by persona_id from the Personalization Agent's persona collection.

        Args:
            persona_id: The ID of the persona to delete

        Raises:
            Exception: If the agents service responds with an error status.
        """
        response = httpx.post(
            f"{self._agents_host}{self._route}/persona/delete/{str(persona_id)}",
            headers=self._headers,
            json=self._personalization_request_payload(),
            timeout=self._timeout,
        )
        if response.is_error:
            raise Exception(f"Failed to delete persona: {response.text}")

    def has_persona(self, persona_id: UUID) -> bool:
        """Check if a persona exists in the Personalization Agent's persona collection.

        Args:
            persona_id: The ID of the persona to check

        Returns:
            True if the persona exists, False otherwise

        Raises:
            Exception: If the agents service responds with an error status.
        """
        response = httpx.post(
            f"{self._agents_host}{self._route}/persona/{str(persona_id)}/exists",
            headers=self._headers,
            json=self._personalization_request_payload(),
            timeout=self._timeout,
        )
        if response.is_error:
            raise Exception(f"Failed to check persona existence: {response.text}")
        return response.json()["exists"]

    def add_interactions(
        self,
        interactions: list[PersonaInteraction],
        create_persona_if_not_exists: bool = True,
        remove_previous_interactions: bool = False,
    ) -> None:
        """Add interactions for personas to the Personalization Agent.

        Args:
            interactions: List of interactions to add. Each interaction can specify
                `replace_previous_interactions=True` to replace that specific
                item's interaction history.
            create_persona_if_not_exists: Whether to create personas that don't exist yet
            remove_previous_interactions: Whether to remove previous interactions for all
                items in the current batch. Setting this to True is equivalent to setting
                `replace_previous_interactions=True` for every interaction in the batch.
                Use with caution as it affects all items in the current batch.

        Raises:
            Exception: If the agents service responds with an error status.
        """
        request_data = {
            "interactions_request": {
                "interactions": [
                    interaction.model_dump(mode="json") for interaction in interactions
                ],
                "create_persona_if_not_exists": create_persona_if_not_exists,
                "remove_previous_interactions": remove_previous_interactions,
            },
            "personalization_request": self._personalization_request_payload(),
        }
        response = httpx.post(
            f"{self._agents_host}{self._route}/interactions",
            headers=self._headers,
            json=request_data,
            timeout=self._timeout,
        )
        if response.is_error:
            raise Exception(f"Failed to add interactions: {response.text}")

    def get_interactions(
        self, persona_id: UUID, interaction_type: str
    ) -> list[PersonaInteractionResponse]:
        """Get interactions for a specific persona filtered by interaction type.

        Args:
            persona_id: The ID of the persona to get interactions for
            interaction_type: The type of interaction to filter by (e.g. "positive", "negative")

        Returns:
            List of matching interactions for the persona

        Raises:
            Exception: If the agents service responds with an error status.
        """
        request_data = {
            "interaction_request": {
                "persona_id": str(persona_id),
                "interaction_type": interaction_type,
            },
            "personalization_request": self._personalization_request_payload(),
        }
        response = httpx.post(
            f"{self._agents_host}{self._route}/interactions/get",
            headers=self._headers,
            json=request_data,
            timeout=self._timeout,
        )
        if response.is_error:
            raise Exception(f"Failed to get interactions: {response.text}")
        return [
            PersonaInteractionResponse(**interaction) for interaction in response.json()
        ]

    def get_objects(
        self,
        persona_id: UUID,
        limit: int = 10,
        recent_interactions_count: int = 100,
        exclude_interacted_items: bool = True,
        decay_rate: float = 0.1,
        exclude_items: Optional[list[str]] = None,
        use_agent_ranking: bool = True,
        explain_results: bool = True,
        instruction: Optional[str] = None,
        filters: Optional[_Filters] = None,
    ) -> PersonalizationAgentGetObjectsResponse:
        """Get Personalized objects for a specific persona.

        Args:
            persona_id: The ID of the persona to get objects for
            limit: The maximum number of objects to return
            recent_interactions_count: The number of recent interactions to consider
            exclude_interacted_items: Whether to exclude items that have been interacted with
            decay_rate: The decay rate for the personalization algorithm
            exclude_items: List of items to exclude from the results. ``None`` (the
                default) means no items are excluded.
            use_agent_ranking: Whether to use agent ranking for the results
            explain_results: Whether to explain the results
            instruction: Optional instruction to guide the personalization process
            filters: Optional filters to apply to the results

        Returns:
            The personalized objects returned by the agents service

        Raises:
            Exception: If the agents service responds with an error status.
        """
        # A fresh list per call avoids the shared-mutable-default pitfall of `= []`.
        if exclude_items is None:
            exclude_items = []
        objects_request = GetObjectsRequest(
            persona_id=persona_id,
            limit=limit,
            recent_interactions_count=recent_interactions_count,
            exclude_interacted_items=exclude_interacted_items,
            decay_rate=decay_rate,
            exclude_items=exclude_items,
            use_agent_ranking=use_agent_ranking,
            explain_results=explain_results,
            instruction=instruction,
            filters=filters,
        )
        request_data = {
            "objects_request": objects_request.model_dump(mode="json"),
            "personalization_request": self._personalization_request_payload(),
        }
        response = httpx.post(
            f"{self._agents_host}{self._route}/objects",
            headers=self._headers,
            json=request_data,
            timeout=self._timeout,
        )
        if response.is_error:
            raise Exception(f"Failed to get objects: {response.text}")
        return PersonalizationAgentGetObjectsResponse(**response.json())

    @classmethod
    def exists(
        cls,
        client: WeaviateClient,
        reference_collection: str,
        agents_host: Optional[str] = None,
        timeout: Optional[int] = None,
    ) -> bool:
        """Check if a persona collection exists for a given reference collection.

        Args:
            client: The Weaviate client
            reference_collection: The name of the collection to check
            agents_host: Optional host URL for the agents service
            timeout: Optional timeout for the request

        Returns:
            True if the persona collection exists, False otherwise

        Raises:
            Exception: If the agents service responds with an error status.
        """
        # A throwaway instance is built only to reuse the host/headers/route
        # resolution done by the constructor; no initialization request is sent.
        base_agent = cls(
            client, reference_collection, agents_host=agents_host, timeout=timeout
        )
        response = httpx.get(
            f"{base_agent._agents_host}{base_agent._route}/exists/{reference_collection}",
            headers=base_agent._headers,
            timeout=timeout,
        )
        if response.is_error:
            raise Exception(
                f"Failed to check if persona collection exists: {response.text}"
            )
        return response.json()["persona_collection_exists"]

    def query(
        self,
        persona_id: UUID,
        strength: float = 0.5,
        overfetch_factor: float = 1.5,
        recent_interactions_count: int = 100,
        decay_rate: float = 0.1,
    ) -> PersonalizedQuery:
        """Build a personalized query object for a persona.

        No request is sent here; the returned :class:`PersonalizedQuery` carries this
        agent's host, headers, timeout and collection description so it can issue
        requests later.

        Args:
            persona_id: The ID of the persona to personalize results for
            strength: Forwarded to ``PersonalizedQuery`` (see its documentation for semantics)
            overfetch_factor: Forwarded to ``PersonalizedQuery``
            recent_interactions_count: Forwarded to ``PersonalizedQuery``; presumably the
                number of recent interactions to consider, as in :meth:`get_objects`
            decay_rate: Forwarded to ``PersonalizedQuery``

        Returns:
            A ``PersonalizedQuery`` bound to this agent and persona
        """
        personalization_request = PersonalizationRequest(
            collection_name=self._reference_collection,
            headers=self._connection.additional_headers,
            item_collection_vector_name=self._vector_name,
            create=False,
        )
        return PersonalizedQuery(
            agents_host=self._agents_host,
            headers=self._headers,
            persona_id=persona_id,
            personalization_request=personalization_request,
            timeout=self._timeout,
            strength=strength,
            overfetch_factor=overfetch_factor,
            recent_interactions_count=recent_interactions_count,
            decay_rate=decay_rate,
        )