diff --git a/src/google/adk/telemetry/tracing.py b/src/google/adk/telemetry/tracing.py index fbb55ec914..a5c627e59c 100644 --- a/src/google/adk/telemetry/tracing.py +++ b/src/google/adk/telemetry/tracing.py @@ -26,6 +26,7 @@ from collections.abc import Iterator from collections.abc import Mapping from contextlib import contextmanager +import hashlib import json import logging import os @@ -116,6 +117,50 @@ def _safe_json_serialize(obj) -> str: return '' +def _stable_json_serialize(obj: Any) -> str: + """Serializes with stable key ordering for deterministic receipts.""" + return json.dumps( + obj, + ensure_ascii=False, + sort_keys=True, + separators=(',', ':'), + default=lambda _: '', + ) + + +def _build_tool_call_receipt( + tool: BaseTool, + args: dict[str, Any], + function_response_event: Event | None, +) -> dict[str, str]: + """Builds a deterministic receipt for tool call tracing.""" + tool_call_id = '' + outcome = 'unknown' + if ( + function_response_event is not None + and function_response_event.content is not None + and function_response_event.content.parts + ): + function_response = function_response_event.content.parts[0].function_response + if function_response is not None: + if function_response.id is not None: + tool_call_id = function_response.id + if function_response.response is not None: + outcome = 'success' + + args_hash = hashlib.sha256( + _stable_json_serialize(args).encode('utf-8') + ).hexdigest() + return { + 'schema_version': '1', + 'tool_name': tool.name, + 'tool_type': tool.__class__.__name__, + 'tool_call_id': tool_call_id, + 'args_sha256': args_hash, + 'outcome': outcome, + } + + def trace_agent_invocation( span: trace.Span, agent: BaseAgent, ctx: InvocationContext ) -> None: @@ -184,6 +229,17 @@ def trace_tool_call( else: span.set_attribute('gcp.vertex.agent.tool_call_args', '{}') + span.set_attribute( + 'gcp.vertex.agent.tool_call_receipt', + _stable_json_serialize( + _build_tool_call_receipt( + tool=tool, + args=args, + function_response_event=function_response_event, + ) + ), + ) + # Tracing tool response tool_call_id = '' tool_response = '' diff --git a/tests/unittests/telemetry/test_spans.py b/tests/unittests/telemetry/test_spans.py index bb0846765f..523a90eb40 100644 --- a/tests/unittests/telemetry/test_spans.py +++ b/tests/unittests/telemetry/test_spans.py @@ -398,7 +398,20 @@ def test_trace_tool_call_with_scalar_response( mock.call('gcp.vertex.agent.llm_response', '{}'), ] - assert mock_span_fixture.set_attribute.call_count == len(expected_calls) + receipt_calls = [ + call_obj + for call_obj in mock_span_fixture.set_attribute.call_args_list + if call_obj.args[0] == 'gcp.vertex.agent.tool_call_receipt' + ] + assert len(receipt_calls) == 1 + receipt = json.loads(receipt_calls[0].args[1]) + assert receipt['schema_version'] == '1' + assert receipt['tool_name'] == mock_tool_fixture.name + assert receipt['tool_call_id'] == test_tool_call_id + assert receipt['outcome'] == 'success' + assert len(receipt['args_sha256']) == 64 + + assert mock_span_fixture.set_attribute.call_count == len(expected_calls) + 1 mock_span_fixture.set_attribute.assert_has_calls( expected_calls, any_order=True ) @@ -457,12 +470,75 @@ def test_trace_tool_call_with_dict_response( mock.call('gcp.vertex.agent.llm_response', '{}'), ] - assert mock_span_fixture.set_attribute.call_count == len(expected_calls) + receipt_calls = [ + call_obj + for call_obj in mock_span_fixture.set_attribute.call_args_list + if call_obj.args[0] == 'gcp.vertex.agent.tool_call_receipt' + ] + assert len(receipt_calls) == 1 + receipt = json.loads(receipt_calls[0].args[1]) + assert receipt['schema_version'] == '1' + assert receipt['tool_name'] == mock_tool_fixture.name + assert receipt['tool_call_id'] == test_tool_call_id + assert receipt['outcome'] == 'success' + assert len(receipt['args_sha256']) == 64 + + assert mock_span_fixture.set_attribute.call_count == len(expected_calls) + 1 mock_span_fixture.set_attribute.assert_has_calls( expected_calls, any_order=True ) +def test_trace_tool_call_receipt_is_deterministic( + monkeypatch, + mock_span_fixture, + mock_tool_fixture, + mock_event_fixture, +): + monkeypatch.setattr( + 'opentelemetry.trace.get_current_span', lambda: mock_span_fixture + ) + + mock_event_fixture.content = types.Content( + role='user', + parts=[ + types.Part( + function_response=types.FunctionResponse( + id='receipt-call', + name='test_function_1', + response={'ok': True}, + ) + ), + ], + ) + mock_event_fixture.id = 'receipt-event' + + trace_tool_call( + tool=mock_tool_fixture, + args={'b': 2, 'a': 1}, + function_response_event=mock_event_fixture, + ) + first_receipt = next( + call_obj.args[1] + for call_obj in mock_span_fixture.set_attribute.call_args_list + if call_obj.args[0] == 'gcp.vertex.agent.tool_call_receipt' + ) + + mock_span_fixture.set_attribute.reset_mock() + trace_tool_call( + tool=mock_tool_fixture, + args={'a': 1, 'b': 2}, + function_response_event=mock_event_fixture, + ) + second_receipt = next( + call_obj.args[1] + for call_obj in mock_span_fixture.set_attribute.call_args_list + if call_obj.args[0] == 'gcp.vertex.agent.tool_call_receipt' + ) + + assert first_receipt == second_receipt + + def test_trace_merged_tool_calls_sets_correct_attributes( monkeypatch, mock_span_fixture, mock_event_fixture ): @@ -604,6 +680,12 @@ def test_trace_tool_call_disabling_request_response_content( call_obj.args for call_obj in mock_span_fixture.set_attribute.call_args_list ) + receipt_calls = [ + call_obj + for call_obj in mock_span_fixture.set_attribute.call_args_list + if call_obj.args[0] == 'gcp.vertex.agent.tool_call_receipt' + ] + assert len(receipt_calls) == 1 def test_trace_merged_tool_disabling_request_response_content(