Add an SDK for distillation from conversation to goal/persona#20289
Add an SDK for distillation from conversation to goal/persona#20289smoorjani merged 13 commits intomlflow:masterfrom
Conversation
🛠 DevTools 🛠
Install mlflow from this PRFor Databricks, use the following command: |
|
Documentation preview for 8cb0e20 is available at: More info
|
029326d to
23f5016
Compare
23f5016 to
824199b
Compare
| from mlflow.genai.simulators import ConversationSimulator | ||
|
|
||
| # Get existing sessions | ||
| sessions = mlflow.search_sessions(...) # clint: disable=unknown-mlflow-function |
There was a problem hiding this comment.
Are we planning to introduce search_sessions() before merging this PR? (Just want to make sure the example uses real functions; I think search_sessions() is the best function to use for fetching sessions)
There was a problem hiding this comment.
yes! @B-Step62 is reviewing this PR: #20288 (review)
but feel free to take a look as well!
| def generate_seed_conversations( | ||
| sessions: list[Session], | ||
| *, | ||
| model: str | None = None, |
There was a problem hiding this comment.
@TomeHirata If I want to use a Gateway endpoint, what syntax should I use? I tried endpoints:/ and gateway:/, but they both fail as follows:
$ python distill.py
Provider List: https://docs.litellm.ai/docs/providers
Provider List: https://docs.litellm.ai/docs/providers
Traceback (most recent call last):
File "/Users/corey.zumar/mlflowrepos/mlflow3/distill.py", line 14, in <module>
df = generate_seed_conversations(
File "/Users/corey.zumar/mlflowrepos/mlflow3/mlflow/genai/simulators/distillation.py", line 94, in generate_seed_conversations
results = [_distill_goal_and_persona(session, model=model) for session in sessions]
File "/Users/corey.zumar/mlflowrepos/mlflow3/mlflow/genai/simulators/distillation.py", line 94, in <listcomp>
results = [_distill_goal_and_persona(session, model=model) for session in sessions]
File "/Users/corey.zumar/mlflowrepos/mlflow3/mlflow/genai/simulators/distillation.py", line 36, in _distill_goal_and_persona
response = invoke_model_without_tracing(
File "/Users/corey.zumar/mlflowrepos/mlflow3/mlflow/genai/simulators/utils.py", line 90, in invoke_model_without_tracing
response = litellm.completion(**kwargs)
File "/Users/corey.zumar/miniconda3/envs/mlflow/lib/python3.10/site-packages/litellm/utils.py", line 1739, in wrapper
raise e
File "/Users/corey.zumar/miniconda3/envs/mlflow/lib/python3.10/site-packages/litellm/utils.py", line 1560, in wrapper
result = original_function(*args, **kwargs)
File "/Users/corey.zumar/miniconda3/envs/mlflow/lib/python3.10/site-packages/litellm/main.py", line 4205, in completion
raise exception_type(
File "/Users/corey.zumar/miniconda3/envs/mlflow/lib/python3.10/site-packages/litellm/main.py", line 1289, in completion
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
File "/Users/corey.zumar/miniconda3/envs/mlflow/lib/python3.10/site-packages/litellm/litellm_core_utils/get_llm_provider_logic.py", line 476, in get_llm_provider
raise e
File "/Users/corey.zumar/miniconda3/envs/mlflow/lib/python3.10/site-packages/litellm/litellm_core_utils/get_llm_provider_logic.py", line 453, in get_llm_provider
raise litellm.exceptions.BadRequestError( # type: ignore
litellm.exceptions.BadRequestError: litellm.BadRequestError: LLM Provider NOT provided. Pass in the LLM provider you are trying to call. You passed model=gateway/corey-gemini-flash
Pass model as E.g. For 'Huggingface' inference endpoints pass in `completion(model='huggingface/starcoder',..)` Learn more: https://docs.litellm.ai/docs/providers
Can we find a way to make AI Gateway endpoints compatible with all of our OSS APIs that have a model= parameter?
There was a problem hiding this comment.
We have the support for gateway:/ in judge interfaces using litellm adapter. Can we use litellm adapter or add this handling logic in this API too?
mlflow/mlflow/genai/judges/adapters/litellm_adapter.py
Lines 237 to 262 in 5824de3
There was a problem hiding this comment.
Will do in a follow-up PR!
mlflow/genai/simulators/prompts.py
Outdated
| }}""" | ||
| # NB: We include "rationale" to invoke chain-of-thought reasoning for better results. | ||
|
|
||
| DISTILL_GOAL_AND_PERSONA_PROMPT = """Analyze the following conversation and extract the user's \ |
There was a problem hiding this comment.
It seems common for a user to try to achieve multiple goals in the same chat (e.g. I started by asking about the stock market and then reused the same chat window to ask about a loosely related or completely different topic).
I noticed this quite a bit when distilling logs from my own agent.
Can we try to support goal splitting? We should still bias the agent towards continuity (e.g. only split conversations if there's a very obvious change in topic / goal)
There was a problem hiding this comment.
Glad you called this out! Do we have customer evidence of switching goals in chats? I definitely sympathize with this for personal assistants, but I'm not sure I see this being as common in, for instance, a food delivery company's agent where the user wants to know where their food is.
Biggest concern with goal splitting is the accuracy of it. I think it's relatively easy to introduce (e.g., just make the output type of the LLM call a list of test cases), but I'd like to test this beforehand. Let's leave as a follow-up pending customer evidence?
| test_cases = generate_seed_conversations(sessions) | ||
|
|
||
| # Use the generated test cases with ConversationSimulator | ||
| simulator = ConversationSimulator(test_cases=test_cases) |
There was a problem hiding this comment.
@smoorjani I was testing this with my voice assistant. This assistant uses a "thread" concept / abstraction internally to keep track of message histories. For tracing purposes, I associate each thread with a session ID.
During simulation, I don't know whether to create a new session ID in my thread or reuse an existing session, since this information doesn't seem to be provided by the simulator when it calls my predict function. As a result, I naively end up with a new session for each turn. The following screenshot should all be one session:
Can we solve this problem? One idea is to expose the session ID in the predict function so that the caller can pass it through to their trace metadata (e.g. with my voice assistant, I could pass it through to my thread abstraction, which then propagates it into MLflow trace metadata). We should also have docs around this.
# Define jarvis agent's predict function
def jarvis_agent(input: list[dict], **kwargs) -> dict:
"""
Jarvis agent that uses process_query to handle conversations.
Args:
input: List of conversation messages in OpenAI format
Returns:
Dict with the assistant's response
"""
# Create a thread for this conversation
# (I can't really do anything else, since I'm not sure when a new conversation
# should actually start / end <- THIS IS THE PROBLEM)
thread = create_thread(include_briefing=True)
# Get the last user message
last_message = input[-1]
query = "Hey Jarvis, " + last_message["content"]
# Process the query using jarvis
response = process_query(query=query, thread=thread)
# Return in the format expected by the simulator
return {
"choices": [{
"message": {
"role": "assistant",
"content": response.text or "I apologize, but I couldn't process that request."
}
}]
}
test_cases = [{'goal': 'Provide the current major stock market index levels and their percentage change, using the most recent trading session when the market is closed.', 'persona': 'You are a frustrated, impatient user seeking quick market performance information; you communicate in short, emphatic all-caps demands, have basic financial context (e.g., knowing markets are closed on weekends), and focus on percent moves rather than raw index values.'}]
custom_scorer = mlflow.genai.make_judge(
name="correct_terminology",
instructions="Read the {{ conversation }} and evaluate whether the assistant used correct financial terminology in its responses.",
feedback_value_type=bool,
model="openai:/gpt-5-mini",
)
# Create simulator
simulator = ConversationSimulator(
test_cases=test_cases,
max_turns=8, # Maximum conversation length
user_model="openai:/gpt-5-mini",
)
# Evaluate - the simulator generates conversations and evaluates them
results = mlflow.genai.evaluate(
data=simulator,
predict_fn=jarvis_agent,
scorers=[
custom_scorer,
ConversationCompleteness(),
UserFrustration(),
],
)
There was a problem hiding this comment.
Ah ok, we had originally marked stateful agents out of scope, but this is a super neat way to fix that. Let me file a new PR!
mlflow/genai/simulators/prompts.py
Outdated
|
|
||
| Based on this conversation, identify: | ||
|
|
||
| 1. **Goal**: What is the user trying to accomplish? Describe their objective in one clear \ |
There was a problem hiding this comment.
I tried this on a conversation with my voice assistant, where the assistant hallucinated that the stock market is currently open on a Saturday in the first conversation turn. In the second conversation turn, I then corrected the agent by informing it that the market is not currently open (i.e. as a user, I caught the hallucination and reported it).
The resulting distilled conversation's goal field says that the user wants to use the most recent trading session when the market is closed. This is reasonable, but it results in the user explicitly asking to use the most recent trading session in the simulation, which means that the original issue isn't reproduced.
[{'goal': 'Provide the current major stock market index levels and their percentage change, using the most recent trading session when the market is closed.', 'persona': 'You are a frustrated, impatient user seeking quick market performance information; you communicate in short, emphatic all-caps demands, have basic financial context (e.g., knowing markets are closed on weekends), and focus on percent moves rather than raw index values.'}]
I think this pattern of augmenting the user's initial message(s) with additional information that proactively corrects the agent's mistakes may hide a lot of agent bugs.
I then tried modifying the goal myself as follows:
test_cases = [{'goal': 'Provide the current major stock market index levels. If markets are closed, expect the agent to provide information for the most recent trading day without you needing to ask (DO *NOT* ASK EXPLICITLY)', 'persona': 'You are a frustrated, impatient user seeking quick market performance information; you communicate in short, emphatic all-caps demands, have basic financial context (e.g., knowing markets are closed on weekends), and focus on percent moves rather than raw index values.'}]
This reproduces the issue, but it doesn't really seem 100% appropriate to put this in goal. It's almost like a set of conversational guidelines (expectations). Since expectations are a part of the dataset format, should we distill some expectations and then, during simulation, ensure that the user help the assistant "cheat"?
For example, from the original conversation, the following might be reasonable expectations:
{
"expectations": [
"If markets are closed, the agent should indicate this and provide information for the most recent trading day"
]
}
Distillation of expectations may be difficult (and we can / should be conservative here initially), but ensuring that the conversation simulator doesn't accidentally help the agent meet a set of developer-defined expectations seems important.
There was a problem hiding this comment.
Hmm this is quite interesting - I think it's valid that this should not go in the goal. But I'm not sure we should keep it in "expectations" as I see that as a field for what is expected of the agent's interactions with the user, not of the simulator. Maybe another field like guidelines for the conversation! But I would prefer to wait on more feedback like this before pre-emptively adding the field. WDYT?
There was a problem hiding this comment.
Discussed offline - we will move forward with this approach.
915b67d to
5c999e2
Compare
| from mlflow.genai.simulators import ConversationSimulator | ||
|
|
||
| # Get existing sessions | ||
| sessions = mlflow.search_sessions(...) # clint: disable=unknown-mlflow-function |
There was a problem hiding this comment.
Could you remove the clint comment after that PR merges?
mlflow/genai/simulators/utils.py
Outdated
|
|
||
| @contextmanager | ||
| def _delete_trace_if_created(): | ||
| """Delete any trace created within this context to avoid polluting user traces.""" |
There was a problem hiding this comment.
I've seen this pattern elsewhere, so this comment doesn't block this PR. But is there a way we can temporarily disable trace logging, rather than logging traces then deleting them? With the current approach, the user may see these traces, and additionally trace logging + deletion slows down simulation.
There was a problem hiding this comment.
Ah good catch, this should have been moved from simulation (I don't see deleted lines). I had checked this with Yuki as well since that was my original approach, but the reason we cannot disable trace logging is because the method to do it is not thread-safe; either we can parallelize these operations or we can disable thread logging.
AveshCSingh
left a comment
There was a problem hiding this comment.
LGTM! Conversation distillation is going to make it way easier to fix multi-turn issues
Signed-off-by: Samraj Moorjani <samraj.moorjani@databricks.com>
Signed-off-by: Samraj Moorjani <samraj.moorjani@databricks.com>
Signed-off-by: Samraj Moorjani <samraj.moorjani@databricks.com>
Signed-off-by: Samraj Moorjani <samraj.moorjani@databricks.com>
Signed-off-by: Samraj Moorjani <samraj.moorjani@databricks.com>
Signed-off-by: Samraj Moorjani <samraj.moorjani@databricks.com>
Signed-off-by: Samraj Moorjani <samraj.moorjani@databricks.com>
Signed-off-by: Samraj Moorjani <samraj.moorjani@databricks.com>
Signed-off-by: Samraj Moorjani <samraj.moorjani@databricks.com>
c261aac to
4bb0053
Compare
Co-Authored-By: Claude <noreply@anthropic.com> Signed-off-by: Samraj Moorjani <samraj.moorjani@databricks.com>
|
Merging in for the release candidate - will follow-up with PRs for simulation guidelines: |
Related Issues/PRs
#xxxWhat changes are proposed in this pull request?
This PR introduces
How is this PR tested?
Results:
Does this PR require documentation update?
Release Notes
Is this a user-facing change?
Add an SDK to distill conversations into goal/persona to use in the conversation simulator.
What component(s), interfaces, languages, and integrations does this PR affect?
Components
area/tracking: Tracking Service, tracking client APIs, autologgingarea/models: MLmodel format, model serialization/deserialization, flavorsarea/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registryarea/scoring: MLflow Model server, model deployment tools, Spark UDFsarea/evaluation: MLflow model evaluation features, evaluation metrics, and evaluation workflowsarea/gateway: MLflow AI Gateway client APIs, server, and third-party integrationsarea/prompts: MLflow prompt engineering features, prompt templates, and prompt managementarea/tracing: MLflow Tracing features, tracing APIs, and LLM tracing functionalityarea/projects: MLproject format, project running backendsarea/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev serverarea/build: Build and test infrastructure for MLflowarea/docs: MLflow documentation pagesHow should the PR be classified in the release notes? Choose one:
rn/none- No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" sectionrn/breaking-change- The PR will be mentioned in the "Breaking Changes" sectionrn/feature- A new user-facing feature worth mentioning in the release notesrn/bug-fix- A user-facing bug fix worth mentioning in the release notesrn/documentation- A user-facing documentation change worth mentioning in the release notesShould this PR be included in the next patch release?
Yesshould be selected for bug fixes, documentation updates, and other small changes.Noshould be selected for new features and larger changes. If you're unsure about the release classification of this PR, leave this unchecked to let the maintainers decide.What is a minor/patch release?
Bug fixes, doc updates and new features usually go into minor releases.
Bug fixes and doc updates usually go into patch releases.