Create a vector search retriever tool
Important
This feature is in Public Preview.
Learn how to use Mosaic AI Agent Framework to create retrievers. A retriever is a type of agent tool that finds and returns relevant documents using a Vector Search index. Retrievers are a core component of RAG (Retrieval Augmented Generation) applications.
Requirements
- MLflow Document is only available on MLflow version 2.14.0 and above.
- An existing Vector Search index.
PyFunc retriever example
The following example uses databricks-vectorsearch to create a basic retriever that performs a Vector Search similarity search with filters. It uses MLflow decorators to enable agent tracing.
The retriever function should return a list of Document objects. Use the metadata field in the Document class to add attributes to the returned document, such as doc_uri and similarity_score.
Use the following code in the agent module or agent notebook.
```python
import mlflow
from dataclasses import asdict
from typing import Any, Dict, List, Optional

from databricks.vector_search.client import VectorSearchClient
from mlflow.entities import Document


class VectorSearchRetriever:
    """
    Class using Databricks Vector Search to retrieve relevant documents.
    """

    def __init__(self):
        self.vector_search_client = VectorSearchClient(disable_notice=True)
        # TODO: Replace this with the list of column names to return in the result when querying Vector Search
        self.columns = ["chunk_id", "text_column", "doc_uri"]
        self.vector_search_index = self.vector_search_client.get_index(
            index_name="catalog.schema.chunked_docs_index"
        )
        mlflow.models.set_retriever_schema(
            name="vector_search",
            primary_key="chunk_id",
            text_column="text_column",
            doc_uri="doc_uri",
        )

    @mlflow.trace(span_type="RETRIEVER", name="vector_search")
    def __call__(
        self,
        query: str,
        filters: Optional[Dict[str, Any]] = None,
        score_threshold: Optional[float] = None,
    ) -> List[Document]:
        """
        Performs vector search to retrieve relevant chunks.

        Args:
            query: Search query.
            filters: Optional filters to apply to the search. Filters must follow the Databricks Vector Search filter spec.
            score_threshold: Optional minimum similarity score. Results scoring below it are dropped.

        Returns:
            List of retrieved Documents.
        """
        results = self.vector_search_index.similarity_search(
            query_text=query,
            columns=self.columns,
            filters=filters,
            num_results=5,
            query_type="ann",
        )
        documents = self.convert_vector_search_to_documents(results, score_threshold)
        # Return each Document as a dictionary
        return [asdict(doc) for doc in documents]

    @mlflow.trace(span_type="PARSER")
    def convert_vector_search_to_documents(self, vs_results, score_threshold) -> List[Document]:
        docs = []
        # Column names are listed in the response manifest, in the same order as each row's values
        column_names = [column["name"] for column in vs_results.get("manifest", {}).get("columns", [])]
        result_row_count = vs_results.get("result", {}).get("row_count", 0)
        if result_row_count > 0:
            for item in vs_results["result"]["data_array"]:
                metadata = {}
                # The similarity score is the last value in each row
                score = item[-1]
                # Skip rows below the threshold; with no threshold, every row is kept
                if score_threshold is not None and score < score_threshold:
                    continue
                metadata["similarity_score"] = score
                for i, field in enumerate(item[:-1]):
                    metadata[column_names[i]] = field
                # Put the chunk text into page_content; the remaining columns stay in metadata
                page_content = metadata.pop("text_column", None)
                if page_content:
                    docs.append(Document(page_content=page_content, metadata=metadata))
        return docs
```
To invoke the retriever, run the following Python code. You can optionally include Vector Search filters in the request to filter results.

```python
retriever = VectorSearchRetriever()
query = "What is Databricks?"
filters = {"text_column LIKE": "Databricks"}
results = retriever(query, filters=filters, score_threshold=0.1)
```
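To check how the parser turns a raw similarity_search response into documents without a workspace, you can run equivalent logic on a hand-built response. The sketch below is a standalone copy of the column-to-metadata mapping that returns plain dictionaries instead of mlflow.entities.Document, so it runs without MLflow or Databricks. The sample response, the parse_rows helper, and the assumption that the score arrives as a final column named "score" are illustrative; the response shape (manifest columns, data_array rows with the score last) is taken from the parser above. Threshold handling is left out here.

```python
from typing import Any, Dict, List

# Hypothetical response in the shape the parser above reads:
# column names in manifest["columns"], rows in result["data_array"], score last.
SAMPLE_RESPONSE: Dict[str, Any] = {
    "manifest": {
        "columns": [
            {"name": "chunk_id"},
            {"name": "text_column"},
            {"name": "doc_uri"},
            {"name": "score"},
        ]
    },
    "result": {
        "row_count": 2,
        "data_array": [
            ["c1", "Databricks is a data and AI platform.", "docs/intro.md", 0.82],
            ["c2", "", "docs/empty.md", 0.40],
        ],
    },
}


def parse_rows(vs_results: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Hypothetical helper: map each row to {"page_content", "metadata"} like the retriever's parser."""
    column_names = [c["name"] for c in vs_results.get("manifest", {}).get("columns", [])]
    docs = []
    if vs_results.get("result", {}).get("row_count", 0) > 0:
        for row in vs_results["result"]["data_array"]:
            # The score is the last value; the other values line up with column_names
            metadata = {"similarity_score": row[-1]}
            metadata.update(zip(column_names, row[:-1]))
            page_content = metadata.pop("text_column", None)
            # Rows with empty text are dropped, as in the original parser
            if page_content:
                docs.append({"page_content": page_content, "metadata": metadata})
    return docs


docs = parse_rows(SAMPLE_RESPONSE)
```

Only the first row survives: the second has an empty text_column, so it produces no document, and the score column never appears under its own name because it is stored as similarity_score.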
Set retriever schema
To ensure that retrievers are traced properly, call mlflow.models.set_retriever_schema when you define your agent in code. Use set_retriever_schema to map the column names in the returned table to MLflow’s expected fields, such as primary_key, text_column, and doc_uri.
```python
import mlflow

# Define the retriever's schema by providing your column names
mlflow.models.set_retriever_schema(
    name="vector_search",
    primary_key="chunk_id",
    text_column="text_column",
    doc_uri="doc_uri",
    # other_columns=["column1", "column2"],
)
```
Note
The doc_uri column is especially important when evaluating the retriever’s performance. doc_uri is the main identifier for documents returned by the retriever, allowing you to compare them against ground truth evaluation sets. See Evaluation sets.
You can also specify additional columns in your retriever’s schema by providing a list of column names with the other_columns field.
If you have multiple retrievers, you can define multiple schemas by using unique names for each retriever schema.
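Conceptually, each named schema is a lookup from MLflow's expected fields to one retriever's column names, and the unique name keeps two retrievers' mappings apart. The toy sketch below models that idea in plain Python; RETRIEVER_SCHEMAS, resolve_fields, and the faq_search retriever with its columns are hypothetical and are not part of MLflow. In a real agent you would instead call mlflow.models.set_retriever_schema once per retriever, each with its own name.

```python
from typing import Any, Dict

# Hypothetical registry: one schema per retriever, keyed by a unique name.
# Each schema maps MLflow's expected fields to that retriever's column names.
RETRIEVER_SCHEMAS: Dict[str, Dict[str, str]] = {
    "vector_search": {"primary_key": "chunk_id", "text_column": "text_column", "doc_uri": "doc_uri"},
    "faq_search": {"primary_key": "faq_id", "text_column": "answer", "doc_uri": "source_url"},
}


def resolve_fields(schema_name: str, row: Dict[str, Any]) -> Dict[str, Any]:
    """Look up a row's primary key, text, and document URI through the named schema."""
    schema = RETRIEVER_SCHEMAS[schema_name]
    return {field: row[column] for field, column in schema.items()}


faq_row = {"faq_id": "q7", "answer": "Use a Vector Search index.", "source_url": "faq/q7"}
resolved = resolve_fields("faq_search", faq_row)
```

Because the two schemas are keyed by different names, the faq_search columns never collide with the vector_search columns, even though both map to the same MLflow fields.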
Trace the retriever
MLflow tracing adds observability by capturing detailed information about your agent’s execution. It records the inputs, outputs, and metadata of each intermediate step of a request, so you can pinpoint the source of bugs and unexpected behavior.
This example uses the @mlflow.trace decorator to trace the retriever and the parser. For other ways to set up tracing, see MLflow Tracing for agents.
The decorator creates a span that starts when the function is invoked and ends when it returns. MLflow automatically records the function’s input and output and any exceptions raised from it.
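To make the span lifecycle concrete, the sketch below is a deliberately simplified tracing decorator written in plain Python. It is not MLflow's implementation and does not use MLflow; toy_trace, SPANS, and fake_retriever are hypothetical names. It only shows the behavior described above: the span opens when the function is called, closes when it returns or raises, and stores the inputs, the output, and any exception.

```python
import functools
import time
from typing import Any, Callable, Dict, List, Optional

# Hypothetical in-memory store for recorded spans (MLflow keeps real spans in a trace)
SPANS: List[Dict[str, Any]] = []


def toy_trace(span_type: str, name: Optional[str] = None) -> Callable:
    """Simplified stand-in for a tracing decorator; not MLflow's implementation."""
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            span = {
                "name": name or func.__name__,
                "span_type": span_type,
                "inputs": {"args": args, "kwargs": kwargs},
                "start": time.perf_counter(),  # the span starts when the function is invoked
            }
            try:
                result = func(*args, **kwargs)
                span["outputs"] = result
                return result
            except Exception as exc:
                span["exception"] = repr(exc)  # record the error, then re-raise it
                raise
            finally:
                span["end"] = time.perf_counter()  # the span ends on return or on error
                SPANS.append(span)
        return wrapper
    return decorator


@toy_trace(span_type="RETRIEVER", name="vector_search")
def fake_retriever(query: str) -> List[Dict[str, Any]]:
    if not query:
        raise ValueError("empty query")
    return [{"page_content": f"result for {query}", "metadata": {}}]


fake_retriever("What is Databricks?")
```

The try/finally structure is what guarantees that a span is recorded even when the wrapped function fails, which is why a failed retrieval still shows up in a trace.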
Note
LangChain, LlamaIndex, and OpenAI library users can use MLflow autologging instead of manually defining traces with the decorator. See Use autologging to add traces to your agents.
```python
...
    @mlflow.trace(span_type="RETRIEVER", name="vector_search")
    def __call__(self, query: str) -> List[Document]:
        ...
```
To ensure downstream applications such as Agent Evaluation and the AI Playground render the retriever trace correctly, make sure the decorator meets the following requirements:
- Use span_type="RETRIEVER" and ensure the function returns a List[Document] object. See Retriever spans.
- The trace name and the retriever_schema name must match to configure the trace correctly.
Filter Vector Search results
You can limit the search scope to a subset of data using a Vector Search filter.
The filters parameter in VectorSearchRetriever defines the filter conditions using the Databricks Vector Search filter specification.
```python
filters = {"text_column LIKE": "Databricks"}
```
Inside the __call__ method, the filters dictionary is passed directly to the similarity_search function:
```python
results = self.vector_search_index.similarity_search(
    query_text=query,
    columns=self.columns,
    filters=filters,
    num_results=5,
    query_type="ann",
)
```
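When a request needs several optional conditions, it can help to assemble the filters dictionary in one place. The sketch below is a hypothetical helper, build_filters, that composes conditions against the example columns. It assumes our reading of the Databricks Vector Search filter spec for standard endpoints: a list value matches any listed value, a "column NOT" key excludes a value, and multiple keys are combined with AND. Verify these operators against the filter documentation before relying on them.

```python
from typing import Any, Dict, List, Optional


def build_filters(
    keyword: Optional[str] = None,
    doc_uris: Optional[List[str]] = None,
    exclude_chunk_id: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Hypothetical helper: compose a similarity_search filters dict from optional conditions."""
    filters: Dict[str, Any] = {}
    if keyword:
        filters["text_column LIKE"] = keyword  # LIKE condition, as in the example above
    if doc_uris:
        filters["doc_uri"] = doc_uris  # assumed: a list value matches any of the listed values
    if exclude_chunk_id:
        filters["chunk_id NOT"] = exclude_chunk_id  # assumed: NOT excludes this value
    # An empty dict becomes None, which the retriever treats as "no filter"
    return filters or None
```

For example, retriever(query, filters=build_filters(keyword="Databricks"), score_threshold=0.1) sends the same filter as the example above, while build_filters() with no arguments sends no filter at all.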
After initial filtering, the score_threshold parameter provides additional filtering by setting a minimum similarity score.
```python
if score >= score_threshold:
    metadata["similarity_score"] = score
```
The final result includes only documents that satisfy both the filters and the score_threshold conditions.
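Because the threshold is applied in Python after similarity_search has already returned at most num_results rows, it can only shrink that list; it never brings in additional results. The minimal sketch below isolates that step. apply_score_threshold and the sample scores are hypothetical, and treating a missing threshold as "keep everything" is an assumption about the intended default.

```python
from typing import List, Optional, Tuple


def apply_score_threshold(
    rows: List[Tuple[str, float]], score_threshold: Optional[float]
) -> List[Tuple[str, float]]:
    """Keep (chunk_id, score) pairs whose score meets the threshold; None keeps every row."""
    if score_threshold is None:
        return list(rows)
    return [(chunk_id, score) for chunk_id, score in rows if score >= score_threshold]


# Hypothetical scores for results that already passed the Vector Search filters
rows = [("c1", 0.82), ("c2", 0.10), ("c3", 0.05)]
kept = apply_score_threshold(rows, 0.1)
```

With a threshold of 0.1, c1 and c2 are kept (the comparison is inclusive) and c3 is dropped; a threshold higher than every score returns an empty list.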
Retriever example applications
See the genai-cookbook GitHub repository for AI agent examples that use retrievers:
- LangChain retriever example: Demo: Mosaic AI Agent Framework and Agent Evaluation
- PyFunc retriever example: Agent app sample code