Add an SDK for distillation from conversation to goal/persona #20289

Merged
smoorjani merged 13 commits into mlflow:master from smoorjani:gwt-distillation
Feb 3, 2026

Conversation

@smoorjani
Collaborator

@smoorjani smoorjani commented Jan 23, 2026

Related Issues/PRs

#xxx

What changes are proposed in this pull request?

This PR introduces

How is this PR tested?

  • Existing unit/integration tests
  • New unit/integration tests
  • Manual tests
import mlflow
from mlflow.genai.simulators import generate_seed_conversations

mlflow.set_experiment("distillation-test")

# Create two conversation sessions with traces
sessions = []
for i, (user_msg, assistant_msg) in enumerate([
    ("How do I track experiments?", "Use mlflow.start_run() and mlflow.log_metric()."),
    ("My model deployment is failing!", "Check your model signature matches the input format."),
]):
    session_id = f"session-{i}"
    # Log one single-turn trace tagged with this session's ID
    with mlflow.start_span(name="turn") as span:
        mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id})
        span.set_inputs(user_msg)
        span.set_outputs(assistant_msg)

    # Fetch the traces that belong to this session
    traces = mlflow.search_traces(
        filter_string=f"metadata.`mlflow.trace.session` = '{session_id}'",
        return_type="list",
    )
    sessions.append(traces)

# Generate seed conversations (same function name as in the reviewed diff)
df = generate_seed_conversations(sessions)
print(df.to_string())

Results:

  goal: Learn how to track machine learning experiments using MLflow.
  persona: You are a technical user with some familiarity in machine learning...
  ────────────────────────────────────────
  goal: Debug a model deployment failure
  persona: You are a user with some technical knowledge who communicates concisely...

Does this PR require documentation update?

  • No. You can skip the rest of this section.
  • Yes. I've updated:
    • Examples
    • API references
    • Instructions

Release Notes

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.

Add an SDK to distill conversations into goal/persona to use in the conversation simulator.

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/tracking: Tracking Service, tracking client APIs, autologging
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/scoring: MLflow Model server, model deployment tools, Spark UDFs
  • area/evaluation: MLflow model evaluation features, evaluation metrics, and evaluation workflows
  • area/gateway: MLflow AI Gateway client APIs, server, and third-party integrations
  • area/prompts: MLflow prompt engineering features, prompt templates, and prompt management
  • area/tracing: MLflow Tracing features, tracing APIs, and LLM tracing functionality
  • area/projects: MLproject format, project running backends
  • area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
  • area/build: Build and test infrastructure for MLflow
  • area/docs: MLflow documentation pages

How should the PR be classified in the release notes? Choose one:

  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

Should this PR be included in the next patch release?

Yes should be selected for bug fixes, documentation updates, and other small changes. No should be selected for new features and larger changes. If you're unsure about the release classification of this PR, leave this unchecked to let the maintainers decide.

What is a minor/patch release?
  • Minor release: a release that increments the second part of the version number (e.g., 1.2.0 -> 1.3.0).
    Bug fixes, doc updates and new features usually go into minor releases.
  • Patch release: a release that increments the third part of the version number (e.g., 1.2.0 -> 1.2.1).
    Bug fixes and doc updates usually go into patch releases.
  • Yes (this PR will be cherry-picked and included in the next patch release)
  • No (this PR will be included in the next minor release)

@github-actions
Contributor

🛠 DevTools 🛠

Install mlflow from this PR

# mlflow
pip install git+https://github.com/mlflow/mlflow.git@refs/pull/20289/merge
# mlflow-skinny
pip install git+https://github.com/mlflow/mlflow.git@refs/pull/20289/merge#subdirectory=libs/skinny

For Databricks, use the following command:

%sh curl -LsSf https://raw.githubusercontent.com/mlflow/mlflow/HEAD/dev/install-skinny.sh | sh -s pull/20289/merge

@github-actions github-actions bot added the area/evaluation and rn/feature labels Jan 23, 2026
@github-actions
Contributor

github-actions bot commented Jan 23, 2026

Documentation preview for 8cb0e20 is available at:

More info
  • Ignore this comment if this PR does not change the documentation.
  • The preview is updated when a new commit is pushed to this PR.
  • This comment was created by this workflow run.
  • The documentation was built by this workflow run.

from mlflow.genai.simulators import ConversationSimulator

# Get existing sessions
sessions = mlflow.search_sessions(...) # clint: disable=unknown-mlflow-function
Collaborator

@dbczumar dbczumar Jan 24, 2026

Are we planning to introduce search_sessions() before merging this PR? (Just want to make sure the example uses real functions; I think search_sessions() is the best function to use for fetching sessions)

Collaborator Author

Yes! @B-Step62 is reviewing the PR that adds it (#20288), but feel free to take a look as well!

def generate_seed_conversations(
    sessions: list[Session],
    *,
    model: str | None = None,
Collaborator

@TomeHirata If I want to use a Gateway endpoint, what syntax should I use? I tried endpoints:/ and gateway:/, but they both fail as follows:

$ python distill.py

Provider List: https://docs.litellm.ai/docs/providers


Provider List: https://docs.litellm.ai/docs/providers

Traceback (most recent call last):
  File "/Users/corey.zumar/mlflowrepos/mlflow3/distill.py", line 14, in <module>
    df = generate_seed_conversations(
  File "/Users/corey.zumar/mlflowrepos/mlflow3/mlflow/genai/simulators/distillation.py", line 94, in generate_seed_conversations
    results = [_distill_goal_and_persona(session, model=model) for session in sessions]
  File "/Users/corey.zumar/mlflowrepos/mlflow3/mlflow/genai/simulators/distillation.py", line 94, in <listcomp>
    results = [_distill_goal_and_persona(session, model=model) for session in sessions]
  File "/Users/corey.zumar/mlflowrepos/mlflow3/mlflow/genai/simulators/distillation.py", line 36, in _distill_goal_and_persona
    response = invoke_model_without_tracing(
  File "/Users/corey.zumar/mlflowrepos/mlflow3/mlflow/genai/simulators/utils.py", line 90, in invoke_model_without_tracing
    response = litellm.completion(**kwargs)
  File "/Users/corey.zumar/miniconda3/envs/mlflow/lib/python3.10/site-packages/litellm/utils.py", line 1739, in wrapper
    raise e
  File "/Users/corey.zumar/miniconda3/envs/mlflow/lib/python3.10/site-packages/litellm/utils.py", line 1560, in wrapper
    result = original_function(*args, **kwargs)
  File "/Users/corey.zumar/miniconda3/envs/mlflow/lib/python3.10/site-packages/litellm/main.py", line 4205, in completion
    raise exception_type(
  File "/Users/corey.zumar/miniconda3/envs/mlflow/lib/python3.10/site-packages/litellm/main.py", line 1289, in completion
    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
  File "/Users/corey.zumar/miniconda3/envs/mlflow/lib/python3.10/site-packages/litellm/litellm_core_utils/get_llm_provider_logic.py", line 476, in get_llm_provider
    raise e
  File "/Users/corey.zumar/miniconda3/envs/mlflow/lib/python3.10/site-packages/litellm/litellm_core_utils/get_llm_provider_logic.py", line 453, in get_llm_provider
    raise litellm.exceptions.BadRequestError(  # type: ignore
litellm.exceptions.BadRequestError: litellm.BadRequestError: LLM Provider NOT provided. Pass in the LLM provider you are trying to call. You passed model=gateway/corey-gemini-flash
 Pass model as E.g. For 'Huggingface' inference endpoints pass in `completion(model='huggingface/starcoder',..)` Learn more: https://docs.litellm.ai/docs/providers

Can we find a way to make AI Gateway endpoints compatible with all of our OSS APIs that have a model= parameter?

Collaborator

@TomeHirata TomeHirata Jan 25, 2026

We already support gateway:/ in the judge interfaces through the LiteLLM adapter. Can we use the LiteLLM adapter here, or add this handling logic to this API too?

if provider == "gateway":
    # MLFLOW_GATEWAY_URI takes precedence over tracking URI for gateway routing.
    # This is needed for async job workers: the job infrastructure passes the HTTP
    # tracking URI (e.g., http://127.0.0.1:5000) to workers, but _get_tracking_store()
    # overwrites MLFLOW_TRACKING_URI with the backend store URI (e.g., sqlite://).
    # Job workers set MLFLOW_GATEWAY_URI to preserve the HTTP URI for gateway calls.
    tracking_uri = MLFLOW_GATEWAY_URI.get() or get_tracking_uri()
    # Validate that tracking URI is a valid HTTP(S) URL for gateway
    if not is_http_uri(tracking_uri):
        raise MlflowException(
            f"Gateway provider requires an HTTP(S) tracking URI, but got: '{tracking_uri}'. "
            "The gateway provider routes requests through the MLflow tracking server. "
            "Please set MLFLOW_TRACKING_URI to a valid HTTP(S) URL "
            "(e.g., 'http://localhost:5000' or 'https://your-mlflow-server.com')."
        )
    api_base = append_to_uri_path(tracking_uri, "gateway/mlflow/v1/")
    # Use openai/ prefix for LiteLLM to use OpenAI-compatible format.
    # LiteLLM strips the prefix, so gateway receives model_name as the endpoint.
    model = f"openai/{model_name}"
    # LiteLLM requires api_key to be set when using custom api_base, otherwise it
    # raises AuthenticationError looking for OPENAI_API_KEY env var. Gateway handles
    # auth in the server layer, so we pass a dummy value to satisfy LiteLLM.
    api_key = "mlflow-gateway-auth"

Collaborator Author

Will do in a follow-up PR!
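
For reference, the gateway routing quoted in this thread could be factored into a small helper along the lines below. This is only a sketch under stated assumptions: `resolve_litellm_kwargs` is a hypothetical name, the URI parsing is simplified, and the tracking URI is passed in explicitly rather than read from `MLFLOW_GATEWAY_URI` or `get_tracking_uri()`. The returned keys correspond to the `model`, `api_base`, and `api_key` keyword arguments of `litellm.completion`.

```python
from urllib.parse import urlparse


def resolve_litellm_kwargs(model_uri: str, tracking_uri: str) -> dict[str, str]:
    """Map a "<provider>:/<name>" model URI to keyword arguments for LiteLLM.

    Hypothetical helper: it mirrors the gateway branch quoted in this thread,
    but it is not part of MLflow.
    """
    provider, sep, model_name = model_uri.partition(":/")
    if not sep or not provider or not model_name:
        raise ValueError(f"Expected '<provider>:/<name>', got: {model_uri!r}")

    if provider != "gateway":
        # Other providers use LiteLLM's own "<provider>/<model>" naming.
        return {"model": f"{provider}/{model_name}"}

    parsed = urlparse(tracking_uri)
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        raise ValueError(
            f"Gateway provider requires an HTTP(S) tracking URI, got: {tracking_uri!r}"
        )

    return {
        # The openai/ prefix selects LiteLLM's OpenAI-compatible client.
        "model": f"openai/{model_name}",
        "api_base": tracking_uri.rstrip("/") + "/gateway/mlflow/v1/",
        # Placeholder key; the gateway server handles authentication.
        "api_key": "mlflow-gateway-auth",
    }
```

With a helper like this, the failing `gateway:/corey-gemini-flash` value from the traceback would resolve to an OpenAI-compatible call against the tracking server instead of reaching LiteLLM's provider lookup unchanged.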

}}"""
# NB: We include "rationale" to invoke chain-of-thought reasoning for better results.

DISTILL_GOAL_AND_PERSONA_PROMPT = """Analyze the following conversation and extract the user's \
Collaborator

It seems common for a user to try to achieve multiple goals in the same chat (e.g. I started by asking about the stock market and then reused the same chat window to ask about a loosely related or completely different topic).

I noticed this quite a bit when distilling logs from my own agent.

Can we try to support goal splitting? We should still bias the agent towards continuity (e.g. only split conversations if there's a very obvious change in topic / goal)

Collaborator Author

Glad you called this out! Do we have customer evidence of switching goals in chats? I definitely sympathize with this for personal assistants, but I'm not sure I see this being as common in, for instance, a food delivery company's agent where the user wants to know where their food is.

Biggest concern with goal splitting is the accuracy of it. I think it's relatively easy to introduce (e.g., just make the output type of the LLM call a list of test cases), but I'd like to test this beforehand. Let's leave as a follow-up pending customer evidence?
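
To make the trade-off concrete, here is a rough sketch of what "make the output type of the LLM call a list of test cases" could look like on the parsing side. The JSON shape, `SeedTestCase`, and `parse_distilled_test_cases` are assumptions for illustration, not this PR's implementation. The helper accepts both a single-object response and a hypothetical list response.

```python
import json
from dataclasses import dataclass


@dataclass(frozen=True)
class SeedTestCase:
    """Hypothetical container for one distilled goal/persona pair."""

    goal: str
    persona: str


def parse_distilled_test_cases(raw: str) -> list[SeedTestCase]:
    """Parse an LLM response holding one test case or a list of them."""
    payload = json.loads(raw)
    # A single object means the session had one goal; wrap it for uniformity.
    items = payload if isinstance(payload, list) else [payload]
    cases = []
    for item in items:
        # Extra keys such as "rationale" are ignored; goal and persona are required.
        cases.append(SeedTestCase(goal=item["goal"], persona=item["persona"]))
    return cases
```

Parsing is the easy part; whether the model splits only on an obvious topic change is the accuracy question raised above and would still need to be evaluated separately.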

test_cases = generate_seed_conversations(sessions)

# Use the generated test cases with ConversationSimulator
simulator = ConversationSimulator(test_cases=test_cases)
Collaborator

@dbczumar dbczumar Jan 24, 2026

@smoorjani I was testing this with my voice assistant. This assistant uses a "thread" concept / abstraction internally to keep track of message histories. For tracing purposes, I associate each thread with a session ID.

During simulation, I don't know whether to create a new session ID in my thread or reuse an existing session, since this information doesn't seem to be provided by the simulator when it calls my predict function. As a result, I naively end up with a new session for each turn. Everything in the following screenshot should be one session:

[Screenshot 2026-01-24 at 9:11:39 AM]

Can we solve this problem? One idea is to expose the session ID in the predict function so that the caller can pass it through to their trace metadata (e.g. with my voice assistant, I could pass it through to my thread abstraction, which then propagates it into MLflow trace metadata). We should also have docs around this.



# Define jarvis agent's predict function
def jarvis_agent(input: list[dict], **kwargs) -> dict:
    """
    Jarvis agent that uses process_query to handle conversations.

    Args:
        input: List of conversation messages in OpenAI format

    Returns:
        Dict with the assistant's response
    """
    # Create a thread for this conversation
    # (I can't really do anything else, since I'm not sure when a new conversation
    # should actually start / end <- THIS IS THE PROBLEM)
    thread = create_thread(include_briefing=True)

    # Get the last user message
    last_message = input[-1]
    query = "Hey Jarvis, " + last_message["content"]

    # Process the query using jarvis
    response = process_query(query=query, thread=thread)

    # Return in the format expected by the simulator
    return {
        "choices": [{
            "message": {
                "role": "assistant",
                "content": response.text or "I apologize, but I couldn't process that request."
            }
        }]
    }

test_cases = [{'goal': 'Provide the current major stock market index levels and their percentage change, using the most recent trading session when the market is closed.', 'persona': 'You are a frustrated, impatient user seeking quick market performance information; you communicate in short, emphatic all-caps demands, have basic financial context (e.g., knowing markets are closed on weekends), and focus on percent moves rather than raw index values.'}]


custom_scorer = mlflow.genai.make_judge(
    name="correct_terminology",
    instructions="Read the {{ conversation }} and evaluate whether the assistant used correct financial terminology in its responses.",
    feedback_value_type=bool,
    model="openai:/gpt-5-mini",
)

# Create simulator
simulator = ConversationSimulator(
    test_cases=test_cases,
    max_turns=8,  # Maximum conversation length
    user_model="openai:/gpt-5-mini",
)

# Evaluate - the simulator generates conversations and evaluates them
results = mlflow.genai.evaluate(
    data=simulator,
    predict_fn=jarvis_agent,
    scorers=[
        custom_scorer,
        ConversationCompleteness(),
        UserFrustration(),
    ],
)

Collaborator Author

Ah ok, we had originally marked stateful agents out of scope, but this is a super neat way to fix that. Let me file a new PR!

Collaborator Author

Fixed in this PR: #20438


Based on this conversation, identify:

1. **Goal**: What is the user trying to accomplish? Describe their objective in one clear \
Collaborator

@dbczumar dbczumar Jan 24, 2026

I tried this on a conversation with my voice assistant, where the assistant hallucinated that the stock market is currently open on a Saturday in the first conversation turn. In the second conversation turn, I then corrected the agent by informing it that the market is not currently open (i.e. as a user, I caught the hallucination and reported it).

[Screenshot 2026-01-24 at 9:43:24 AM]

The resulting distilled conversation's goal field says that the user wants to use the most recent trading session when the market is closed. This is reasonable, but it results in the user explicitly asking to use the most recent trading session in the simulation, which means that the original issue isn't reproduced.

[{'goal': 'Provide the current major stock market index levels and their percentage change, using the most recent trading session when the market is closed.', 'persona': 'You are a frustrated, impatient user seeking quick market performance information; you communicate in short, emphatic all-caps demands, have basic financial context (e.g., knowing markets are closed on weekends), and focus on percent moves rather than raw index values.'}]
[Screenshot 2026-01-24 at 9:27:28 AM]

I think this pattern of augmenting the user's initial message(s) with additional information that proactively corrects the agent's mistakes may hide a lot of agent bugs.

I then tried modifying the goal myself as follows:

test_cases = [{'goal': 'Provide the current major stock market index levels. If markets are closed, expect the agent to provide information for the most recent trading day without you needing to ask (DO *NOT* ASK EXPLICITLY)', 'persona': 'You are a frustrated, impatient user seeking quick market performance information; you communicate in short, emphatic all-caps demands, have basic financial context (e.g., knowing markets are closed on weekends), and focus on percent moves rather than raw index values.'}]
[Screenshot 2026-01-24 at 9:44:22 AM]

This reproduces the issue, but it doesn't seem entirely appropriate to put this in the goal. It's almost like a set of conversational guidelines (expectations). Since expectations are a part of the dataset format, should we distill some expectations and then, during simulation, ensure that the simulated user doesn't help the assistant "cheat"?

For example, from the original conversation, the following might be reasonable expectations:

{
    "expectations": [
        "If markets are closed, the agent should indicate this and provide information for the most recent trading day"
    ]
}

Distillation of expectations may be difficult (and we can / should be conservative here initially), but ensuring that the conversation simulator doesn't accidentally help the agent meet a set of developer-defined expectations seems important.

Collaborator Author

Hmm this is quite interesting - I think it's valid that this should not go in the goal. But I'm not sure we should keep it in "expectations" as I see that as a field for what is expected of the agent's interactions with the user, not of the simulator. Maybe another field like guidelines for the conversation! But I would prefer to wait on more feedback like this before pre-emptively adding the field. WDYT?

Collaborator Author

Discussed offline - we will move forward with this approach.
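
One way to picture the separation discussed in this thread: the goal stays free of agent-correcting detail, and a separate field tells the simulated user what it must not volunteer. The sketch below is illustrative only. The `simulation_guidelines` key and `build_simulated_user_instructions` are hypothetical names, not the API of the follow-up work.

```python
def build_simulated_user_instructions(test_case: dict) -> str:
    """Compose instructions for a simulated user from a test case dict."""
    lines = [
        f"Goal: {test_case['goal']}",
        f"Persona: {test_case['persona']}",
    ]
    # Hypothetical field: constraints on how the simulated user behaves.
    guidelines = test_case.get("simulation_guidelines", [])
    if guidelines:
        lines.append("Guidelines (follow silently, never reveal to the assistant):")
        lines.extend(f"- {g}" for g in guidelines)
    return "\n".join(lines)


test_case = {
    "goal": "Provide the current major stock market index levels.",
    "persona": "You are a frustrated, impatient user.",
    "simulation_guidelines": [
        "Do not ask for the most recent trading session explicitly.",
    ],
}
print(build_simulated_user_instructions(test_case))
```

Keeping the constraint out of `goal` means the simulated user's opening message no longer pre-empts the agent's mistake, which is what allowed the original hallucination to reproduce in the manual experiment above.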

from mlflow.genai.simulators import ConversationSimulator

# Get existing sessions
sessions = mlflow.search_sessions(...) # clint: disable=unknown-mlflow-function
Collaborator

Could you remove the clint comment after that PR merges?


@contextmanager
def _delete_trace_if_created():
    """Delete any trace created within this context to avoid polluting user traces."""
Collaborator

I've seen this pattern elsewhere, so this comment doesn't block this PR. But is there a way we can temporarily disable trace logging, rather than logging traces then deleting them? With the current approach, the user may see these traces, and additionally trace logging + deletion slows down simulation.

Collaborator Author

Ah good catch, this should have been moved from simulation (I don't see deleted lines). I had checked this with Yuki as well, since disabling trace logging was my original approach, but we cannot do that because the method for disabling it is not thread-safe: we can either parallelize these operations or disable trace logging, not both.

Collaborator

Ah got it!

Collaborator

@AveshCSingh AveshCSingh left a comment

LGTM! Conversation distillation is going to make it way easier to fix multi-turn issues

Signed-off-by: Samraj Moorjani <samraj.moorjani@databricks.com>
smoorjani and others added 2 commits February 2, 2026 16:31
Co-Authored-By: Claude <noreply@anthropic.com>
Signed-off-by: Samraj Moorjani <samraj.moorjani@databricks.com>
@smoorjani
Collaborator Author

Merging in for the release candidate. I will follow up with PRs for simulation guidelines:

#20449

@smoorjani smoorjani enabled auto-merge February 3, 2026 01:33
@smoorjani smoorjani added this pull request to the merge queue Feb 3, 2026
Merged via the queue into mlflow:master with commit 037ee1f Feb 3, 2026
46 checks passed
@smoorjani smoorjani deleted the gwt-distillation branch February 3, 2026 01:48
Labels

area/evaluation, rn/feature

4 participants