Ajouter des traces à vos agents
Important
Cette fonctionnalité est disponible en préversion publique.
Cet article explique comment ajouter des traces à vos agents à l’aide des API Fluent et MLflowClient mises à disposition avec MLflow Tracing.
Remarque
Pour obtenir des exemples de référence API et de code pour le suivi MLflow, consultez la documentation sur MLflow.
Spécifications
- MLflow 2.13.1
Utiliser la journalisation pour ajouter des traces à vos agents
Si vous utilisez une bibliothèque GenAI prenant en charge le suivi (par exemple, LangChain, LlamaIndex ou OpenAI), vous pouvez activer la journalisation automatique MLflow pour l’intégration de la bibliothèque pour activer le suivi.
Par exemple, utilisez mlflow.langchain.autolog()
pour ajouter automatiquement des traces à votre agent LangChain.
Remarque
À partir de Databricks Runtime 15.4 LTS ML, le suivi MLflow est activé par défaut dans les notebooks. Pour désactiver le suivi, par exemple avec LangChain, vous pouvez exécuter mlflow.langchain.autolog(log_traces=False)
dans votre notebook.
mlflow.langchain.autolog()
MLflow prend en charge des bibliothèques supplémentaires pour la mise en file d’attente automatique de traces. Consultez la documentation MLflow Tracing pour obtenir la liste complète des bibliothèques intégrées.
Utiliser les API Fluent pour ajouter manuellement des traces à votre agent
Voici un exemple rapide qui utilise les API Fluentmlflow.trace
et mlflow.start_span
pour ajouter des traces au quickstart-agent
. Il est recommandé pour les modèles PyFunc.
import mlflow
from mlflow.deployments import get_deploy_client
class QAChain(mlflow.pyfunc.PythonModel):
def __init__(self):
self.client = get_deploy_client("databricks")
@mlflow.trace(name="quickstart-agent")
def predict(self, model_input, system_prompt, params):
messages = [
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": model_input[0]["query"]
}
]
traced_predict = mlflow.trace(self.client.predict)
output = traced_predict(
endpoint=params["model_name"],
inputs={
"temperature": params["temperature"],
"max_tokens": params["max_tokens"],
"messages": messages,
},
)
with mlflow.start_span(name="_final_answer") as span:
# Initiate another span generation
span.set_inputs({"query": model_input[0]["query"]})
answer = output["choices"][0]["message"]["content"]
span.set_outputs({"generated_text": answer})
# Attributes computed at runtime can be set using the set_attributes() method.
span.set_attributes({
"model_name": params["model_name"],
"prompt_tokens": output["usage"]["prompt_tokens"],
"completion_tokens": output["usage"]["completion_tokens"],
"total_tokens": output["usage"]["total_tokens"]
})
return answer
Effectuer une inférence
Une fois que vous avez instrumenté votre code, vous pouvez exécuter votre fonction comme vous le feriez normalement. L’exemple suivant poursuit l’exemple avec la fonction predict()
dans la section précédente. Les traces s’affichent automatiquement lorsque vous exécutez la méthode d’appel. predict()
SYSTEM_PROMPT = """
You are an assistant for Databricks users. You are answering python, coding, SQL, data engineering, spark, data science, DW and platform, API or infrastructure administration question related to Databricks. If the question is not related to one of these topics, kindly decline to answer. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible. Use the following pieces of context to answer the question at the end:
"""
model = QAChain()
prediction = model.predict(
[
{"query": "What is in MLflow 5.0"},
],
SYSTEM_PROMPT,
{
# Using Databricks Foundation Model for easier testing, feel free to replace it.
"model_name": "databricks-dbrx-instruct",
"temperature": 0.1,
"max_tokens": 1000,
}
)
API Fluent
Les API Fluent dans MLflow construisent automatiquement la hiérarchie de trace en fonction de l’emplacement et de l’exécution du code. Les sections suivantes décrivent les tâches prises en charge à l’aide des API MLflow Tracing Fluent.
Décorer votre fonction
Vous pouvez décorer votre fonction avec le décorateur @mlflow.trace
pour créer une portée pour l’étendue de la fonction décorée. La portée démarre lorsque la fonction est appelée et se termine lorsqu’elle est retournée. MLflow enregistre automatiquement l’entrée et la sortie de la fonction, ainsi que les exceptions levées à partir de la fonction. Par exemple, l’exécution du code suivant crée une portée portant le nom « my_function », en capturant les arguments d’entrée x et y, ainsi que la sortie de la fonction.
@mlflow.trace(name="agent", span_type="TYPE", attributes={"key": "value"})
def my_function(x, y):
return x + y
Utiliser le gestionnaire de contexte de suivi
Si vous souhaitez créer une portée pour un bloc de code arbitraire, pas seulement une fonction, vous pouvez utiliser mlflow.start_span()
comme gestionnaire de contexte qui encapsule le bloc de code. La portée commence lorsque le contexte est entré et se termine lorsque le contexte est arrêté. Les entrées et sorties de portée doivent être fournies manuellement via des méthodes setter de l’objet de portée qui sont générées à partir du gestionnaire de contexte.
with mlflow.start_span("my_span") as span:
span.set_inputs({"x": x, "y": y})
result = x + y
span.set_outputs(result)
span.set_attribute("key", "value")
Envelopper une fonction externe
La fonction mlflow.trace
peut être utilisée comme wrapper pour tracer une fonction de votre choix. Elle est utile lorsque vous souhaitez suivre les fonctions importées à partir de bibliothèques externes. Elle génère la même portée que celle que vous obtiendriez en décorant cette fonction.
from sklearn.metrics import accuracy_score
y_pred = [0, 2, 1, 3]
y_true = [0, 1, 2, 3]
traced_accuracy_score = mlflow.trace(accuracy_score)
traced_accuracy_score(y_true, y_pred)
API du client MLflow
MlflowClient
expose des API granulaires et thread safe pour démarrer et terminer les traces, gérer les portées et définir des champs de portée. Il fournit un contrôle total du cycle de vie et de la structure de trace. Ces API sont utiles lorsque les API Fluent ne sont pas suffisantes pour vos besoins, comme les applications multithreads et les rappels.
Les étapes suivantes permettent de créer une trace complète à l’aide du client MLflow.
Créez une instance de MLflowClient par
client = MlflowClient()
.Démarrez une trace à l’aide de la méthode
client.start_trace()
. Cette opération lance le contexte de trace et démarre une étendue racine absolue et retourne un objet d’étendue racine. Cette méthode doit être exécutée avant l’APIstart_span()
.- Définissez vos attributs, entrées et sorties pour la trace dans
client.start_trace()
.
Remarque
Il n’existe pas d’équivalent à la méthode
start_trace()
dans les API Fluent. Cela est dû au fait que les API Fluent initialisent automatiquement le contexte de trace et déterminent s’il s’agit de la portée racine en fonction de l’état managé.- Définissez vos attributs, entrées et sorties pour la trace dans
L’API start_trace() retourne une portée. Obtenez l’ID de requête, un identificateur unique de la trace également appelé
trace_id
, et l’ID de la portée retournée à l’aide despan.request_id
etspan.span_id
.Démarrez une portée enfant en utilisant
client.start_span(request_id, parent_id=span_id)
pour définir vos attributs, entrées et sorties pour la portée.- Avec cette méthode,
request_id
etparent_id
sont nécessaires pour associer la portée à la bonne position dans la hiérarchie des traces. Elle retourne un autre objet de portée.
- Avec cette méthode,
Terminez la portée enfant en appelant
client.end_span(request_id, span_id)
.Répétez 3 à 5 pour tous les enfants que vous souhaitez créer.
Une fois toutes les portées enfants terminées, appelez
client.end_trace(request_id)
pour fermer toute la trace et l’enregistrer.
from mlflow.client import MlflowClient
mlflow_client = MlflowClient()
root_span = mlflow_client.start_trace(
name="simple-rag-agent",
inputs={
"query": "Demo",
"model_name": "DBRX",
"temperature": 0,
"max_tokens": 200
}
)
request_id = root_span.request_id
# Retrieve documents that are similar to the query
similarity_search_input = dict(query_text="demo", num_results=3)
span_ss = mlflow_client.start_span(
"search",
# Specify request_id and parent_id to create the span at the right position in the trace
request_id=request_id,
parent_id=root_span.span_id,
inputs=similarity_search_input
)
retrieved = ["Test Result"]
# Span has to be ended explicitly
mlflow_client.end_span(request_id, span_id=span_ss.span_id, outputs=retrieved)
root_span.end_trace(request_id, outputs={"output": retrieved})