Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion mlflow/openai/_agent_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any

import agents.tracing as oai
from agents import add_trace_processor
from agents import add_trace_processor, set_trace_processors
from agents.tracing.setup import GLOBAL_TRACE_PROVIDER

from mlflow.entities.span import SpanType
Expand Down Expand Up @@ -55,6 +55,14 @@ class OpenAISpanType:
}


def clear_trace_processors():
    """
    Drop every registered OpenAI Agents SDK trace processor.

    Resetting the processor list to empty prevents the SDK's default exporter
    from emitting warnings (e.g. about a missing API key).
    https://github.com/openai/openai-agents-python/issues/1387#issuecomment-3165660183
    """
    empty_processors = []
    set_trace_processors(empty_processors)


def add_mlflow_trace_processor():
processors = GLOBAL_TRACE_PROVIDER._multi_processor._processors

Expand Down
8 changes: 8 additions & 0 deletions mlflow/openai/autolog.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def autolog(
disable_for_unsupported_versions=False,
silent=False,
log_traces=True,
disable_openai_agent_tracer=True,
):
"""
Enables (or disables) and configures autologging from OpenAI to MLflow.
Expand All @@ -58,6 +59,8 @@ def autolog(
autologging.
log_traces: If ``True``, traces are logged for OpenAI models. If ``False``, no traces are
collected during inference. Default to ``True``.
disable_openai_agent_tracer: If ``True``, disable the OpenAI Agent SDK tracer. If ``False``,
enable the OpenAI Agent SDK tracer. Default to ``True``.
Comment on lines +62 to +63
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
disable_openai_agent_tracer: If ``True``, disable the OpenAI Agent SDK tracer. If ``False``,
enable the OpenAI Agent SDK tracer. Default to ``True``.
disable_openai_agent_tracer: If ``True``, disable the OpenAI's native tracer that will send traces
to the OpenAI platform. Default to ``True``. Set this to ``False`` to export traces to both MLflow
and the OpenAI platform.

" OpenAI Agent SDK tracer" might be mistaken for our tracer if users don't know about the native tracer functionality.

"""
if Version(importlib.metadata.version("openai")).major < 1:
raise MlflowException("OpenAI autologging is only supported for openai >= 1.0.0")
Expand Down Expand Up @@ -86,9 +89,14 @@ def autolog(

from mlflow.openai._agent_tracer import (
add_mlflow_trace_processor,
clear_trace_processors,
remove_mlflow_trace_processor,
)

# if disable_openai_agent_tracer:
# print("Im clearing trace processors")
# clear_trace_processors()

if log_traces and not disable:
add_mlflow_trace_processor()
else:
Expand Down
191 changes: 182 additions & 9 deletions tests/openai/test_openai_agent_autolog.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
pytest.skip("OpenAI SDK is not installed. Skipping tests.", allow_module_level=True)

from agents import Agent, Runner, function_tool, set_default_openai_client, trace
from agents.tracing import set_trace_processors
from openai.types.responses.function_tool import FunctionTool
from openai.types.responses.response import Response
from openai.types.responses.response_output_item import (
Expand Down Expand Up @@ -40,12 +39,6 @@
set_default_openai_client(async_client)


@pytest.fixture(autouse=True)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deleting since this will be the case by default now -- tests should continue passing

def disable_default_tracing():
# Disable default OpenAI tracer
set_trace_processors([])


@pytest.mark.asyncio
async def test_autolog_agent():
mlflow.openai.autolog()
Expand Down Expand Up @@ -126,6 +119,8 @@

assert response.final_output == "¡Hola! Estoy bien, gracias. ¿Y tú, cómo estás?"
traces = get_traces()
print([t.to_dict() for t in traces])
assert False
assert len(traces) == 1
trace = traces[0]
assert trace.info.status == "OK"
Expand All @@ -139,7 +134,7 @@
assert spans[0].outputs == response.final_output
assert spans[1].name == "Triage Agent"
assert spans[1].parent_id == spans[0].span_id
assert spans[2].name == "Response_1"
assert spans[2].name == "Response"
assert spans[2].parent_id == spans[1].span_id
assert spans[2].inputs == [{"role": "user", "content": "Hola. ¿Como estás?"}]
assert spans[2].outputs == [
Expand All @@ -158,7 +153,7 @@
assert spans[3].parent_id == spans[1].span_id
assert spans[4].name == "Spanish Agent"
assert spans[4].parent_id == spans[0].span_id
assert spans[5].name == "Response_2"
assert spans[5].name == "Response"
assert spans[5].parent_id == spans[4].span_id

# Validate chat attributes
Expand Down Expand Up @@ -389,3 +384,181 @@
await Runner.run(agent, messages)

assert get_traces() == []


@pytest.mark.asyncio
async def test_autolog_agent_with_enabled_openai_agent_tracer():
    # TODO: move these builtin imports to module top level to satisfy lint rule
    # MLF0016 (lazy-builtin-import); kept together here for now.
    import logging
    import time

    class LogCapture(logging.Handler):
        """Collect every record emitted on the ``openai.agents`` logger."""

        def __init__(self):
            super().__init__()
            self.records = []

        def emit(self, record):
            self.records.append(record)

    log_capture = LogCapture()
    openai_agents_logger = logging.getLogger("openai.agents")
    original_level = openai_agents_logger.level
    openai_agents_logger.addHandler(log_capture)
    openai_agents_logger.setLevel(logging.DEBUG)

    try:
        mlflow.openai.autolog(disable_openai_agent_tracer=False)

        # NB: We have to mock the OpenAI SDK responses to make agent works
        DUMMY_RESPONSES = [
            Response(
                id="123",
                created_at=12345678.0,
                error=None,
                model="gpt-4o-mini",
                object="response",
                instructions="Handoff to the appropriate agent based on the language of the request.",
                output=[
                    ResponseFunctionToolCall(
                        id="123",
                        arguments="{}",
                        call_id="123",
                        name="transfer_to_spanish_agent",
                        type="function_call",
                        status="completed",
                    )
                ],
                tools=[
                    FunctionTool(
                        name="transfer_to_spanish_agent",
                        parameters={"type": "object", "properties": {}, "required": []},
                        type="function",
                        description="Handoff to the Spanish_Agent agent to handle the request.",
                        strict=False,
                    ),
                ],
                tool_choice="auto",
                temperature=1,
                parallel_tool_calls=True,
            ),
            Response(
                id="123",
                created_at=12345678.0,
                error=None,
                model="gpt-4o-mini",
                object="response",
                instructions="You only speak Spanish",
                output=[
                    ResponseOutputMessage(
                        id="123",
                        content=[
                            ResponseOutputText(
                                annotations=[],
                                text="¡Hola! Estoy bien, gracias. ¿Y tú, cómo estás?",
                                type="output_text",
                            )
                        ],
                        role="assistant",
                        status="completed",
                        type="message",
                    )
                ],
                tools=[],
                tool_choice="auto",
                temperature=1,
                parallel_tool_calls=True,
            ),
        ]

        set_dummy_client(DUMMY_RESPONSES)

        english_agent = Agent(name="English Agent", instructions="You only speak English")
        spanish_agent = Agent(name="Spanish Agent", instructions="You only speak Spanish")
        triage_agent = Agent(
            name="Triage Agent",
            instructions="Handoff to the appropriate agent based on the language of the request.",
            handoffs=[spanish_agent, english_agent],
        )

        messages = [{"role": "user", "content": "Hola. ¿Como estás?"}]
        response = await Runner.run(starting_agent=triage_agent, input=messages)

        assert response.final_output == "¡Hola! Estoy bien, gracias. ¿Y tú, cómo estás?"
        traces = get_traces()
        assert len(traces) == 1
        trace = traces[0]
        assert trace.info.status == "OK"
        assert json.loads(trace.info.request_preview) == messages
        assert json.loads(trace.info.response_preview) == response.final_output
        spans = trace.data.spans
        assert len(spans) == 6  # 1 root + 2 agent + 1 handoff + 2 response
        assert spans[0].name == "AgentRunner.run"
        assert spans[0].span_type == SpanType.AGENT
        assert spans[0].inputs == messages
        assert spans[0].outputs == response.final_output
        assert spans[1].name == "Triage Agent"
        assert spans[1].parent_id == spans[0].span_id
        assert spans[2].name == "Response"
        assert spans[2].parent_id == spans[1].span_id
        assert spans[2].inputs == [{"role": "user", "content": "Hola. ¿Como estás?"}]
        assert spans[2].outputs == [
            {
                "id": "123",
                "arguments": "{}",
                "call_id": "123",
                "name": "transfer_to_spanish_agent",
                "type": "function_call",
                "status": "completed",
            }
        ]
        assert spans[2].attributes["temperature"] == 1
        assert spans[3].name == "Handoff"
        assert spans[3].span_type == SpanType.CHAIN
        assert spans[3].parent_id == spans[1].span_id
        assert spans[4].name == "Spanish Agent"
        assert spans[4].parent_id == spans[0].span_id
        assert spans[5].name == "Response"
        assert spans[5].parent_id == spans[4].span_id

        # Validate chat attributes
        assert spans[2].attributes[SpanAttributeKey.CHAT_TOOLS] == [
            {
                "function": {
                    "description": "Handoff to the Spanish_Agent agent to handle the request.",
                    "name": "transfer_to_spanish_agent",
                    "parameters": {
                        "additionalProperties": None,
                        "properties": {},
                        "required": [],
                        "type": "object",
                    },
                    "strict": False,
                },
                "type": "function",
            },
        ]
        assert SpanAttributeKey.CHAT_TOOLS not in spans[5].attributes

        # The native OpenAI tracer exports traces with the dummy "test" API key,
        # so its background exporter should log a non-fatal 401 error. Poll
        # instead of a fixed sleep so the test finishes as soon as the record
        # arrives (with a 5-second deadline as the upper bound).
        deadline = time.monotonic() + 5.0
        api_key_errors = []
        while time.monotonic() < deadline:
            api_key_errors = [
                record.getMessage()
                for record in log_capture.records
                if "Incorrect API key provided" in record.getMessage()
            ]
            if api_key_errors:
                break
            time.sleep(0.1)

        # Guard against IndexError below and give a useful failure message.
        assert api_key_errors, (
            "Expected the OpenAI agent tracer to log a non-fatal API key error; "
            f"captured messages: {[r.getMessage() for r in log_capture.records]}"
        )
        error_msg = api_key_errors[0]
        assert "401" in error_msg
        assert "Incorrect API key provided: test" in error_msg
        assert "invalid_api_key" in error_msg
        assert "[non-fatal]" in error_msg
    finally:
        # Always restore logger state, even when an assertion above fails.
        openai_agents_logger.removeHandler(log_capture)
        openai_agents_logger.setLevel(original_level)
Loading