diff --git a/CLAUDE.md b/CLAUDE.md index d7b175636..e6baee770 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -28,6 +28,7 @@ This document contains critical information about working with this codebase. Fo - Bug fixes require regression tests - IMPORTANT: The `tests/client/test_client.py` is the most well designed test file. Follow its patterns. - IMPORTANT: Be minimal, and focus on E2E tests: Use the `mcp.client.Client` whenever possible. + - IMPORTANT: Do NOT test private functions (prefixed with `_`). Test them indirectly through the public API. Test files mirror the source tree: `src/mcp/client/streamable_http.py` → `tests/client/test_streamable_http.py` Add tests to the existing file for that module. diff --git a/pyproject.toml b/pyproject.toml index 65bde6966..45268b530 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "pyjwt[crypto]>=2.10.1", "typing-extensions>=4.13.0", "typing-inspection>=0.4.1", + "opentelemetry-api>=1.28.0", ] [project.optional-dependencies] @@ -71,6 +72,7 @@ dev = [ "coverage[toml]>=7.10.7,<=7.13", "pillow>=12.0", "strict-no-cover", + "opentelemetry-sdk>=1.28.0", ] docs = [ "mkdocs>=1.6.1", diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 453e36274..7de0ad483 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -8,12 +8,14 @@ import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from opentelemetry import trace from pydantic import BaseModel, TypeAdapter from typing_extensions import Self from mcp.shared.exceptions import MCPError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.response_router import ResponseRouter +from mcp.shared.tracing import end_span_error, end_span_ok, start_client_span, start_server_span from mcp.types import ( CONNECTION_CLOSED, INVALID_PARAMS, @@ -77,6 +79,7 @@ def __init__( session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], on_complete: Callable[[RequestResponder[ReceiveRequestT, SendResultT]], Any], message_metadata: MessageMetadata = None, + span: trace.Span | None = None, ) -> None: self.request_id = request_id self.request_meta = request_meta @@ -87,6 +90,7 @@ def __init__( self._cancel_scope = anyio.CancelScope() self._on_complete = on_complete self._entered = False # Track if we're in a context manager + self._span = span def __enter__(self) -> RequestResponder[ReceiveRequestT, SendResultT]: """Enter the context manager, enabling request cancellation tracking.""" @@ -126,6 +130,12 @@ async def respond(self, response: SendResultT | ErrorData) -> None: if not self.cancelled: # pragma: no branch self._completed = True + if self._span is not None: + if isinstance(response, ErrorData): + end_span_error(self._span, MCPError(code=response.code, message=response.message)) + else: + end_span_ok(self._span) + await self._session._send_response( # type: ignore[reportPrivateUsage] request_id=self.request_id, response=response ) @@ -139,6 +149,10 @@ async def cancel(self) -> None: self._cancel_scope.cancel() self._completed = True # Mark as completed so it's removed from in_flight + + if self._span is not None: + end_span_error(self._span, MCPError(code=0, message="Request cancelled")) + # Send an error response to indicate cancellation await self._session._send_response( # type: ignore[reportPrivateUsage] request_id=self.request_id, @@ -260,6 +274,9 @@ async def send_request( # Store the callback for this request self._progress_callbacks[request_id] = progress_callback + method = request_data["method"] + span = start_client_span(method, request_data.get("params")) + try: jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data) await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata)) @@ -278,7 +295,15 @@ async def send_request( if isinstance(response_or_error, JSONRPCError): raise MCPError.from_jsonrpc_error(response_or_error) else: - return result_type.model_validate(response_or_error.result, by_name=False) + result = result_type.model_validate(response_or_error.result, by_name=False) + if span is not None: + end_span_ok(span) + return result + + except BaseException as exc: + if span is not None: + end_span_error(span, exc) + raise finally: self._response_streams.pop(request_id, None) @@ -339,6 +364,8 @@ async def _receive_loop(self) -> None: message.message.model_dump(by_alias=True, mode="json", exclude_none=True), by_name=False, ) + request_data = message.message.model_dump(by_alias=True, mode="json", exclude_none=True) + server_span = start_server_span(request_data["method"], request_data.get("params")) responder = RequestResponder( request_id=message.message.id, request_meta=validated_request.params.meta if validated_request.params else None, @@ -346,6 +373,7 @@ async def _receive_loop(self) -> None: session=self, on_complete=lambda r: self._in_flight.pop(r.request_id, None), message_metadata=message.metadata, + span=server_span, ) self._in_flight[responder.request_id] = responder await self._received_request(responder) diff --git a/src/mcp/shared/tracing.py b/src/mcp/shared/tracing.py new file mode 100644 index 000000000..f404c370d --- /dev/null +++ b/src/mcp/shared/tracing.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import Any + +from opentelemetry import trace +from opentelemetry.trace import StatusCode + +_tracer = trace.get_tracer("mcp") + +_EXCLUDED_METHODS: frozenset[str] = frozenset({"notifications/message"}) + +# Semantic convention attribute keys +ATTR_MCP_METHOD_NAME = "mcp.method.name" +ATTR_ERROR_TYPE = "error.type" + +# Methods that have a meaningful target name in params +_TARGET_PARAM_KEY: dict[str, str] = { + "tools/call": "name", + "prompts/get": "name", + "resources/read": "uri", +} + + +def _extract_target(method: str, params: dict[str, Any] | None) -> str | None: + """Extract the target (e.g. tool name, prompt name) from request params.""" + key = _TARGET_PARAM_KEY.get(method) + if key is None or params is None: + return None + value = params.get(key) + if isinstance(value, str): + return value + return None + + +def start_client_span(method: str, params: dict[str, Any] | None) -> trace.Span | None: + """Start a CLIENT span for an outgoing MCP request. + + Returns None if the method is excluded from tracing. + """ + if method in _EXCLUDED_METHODS: + return None + + target = _extract_target(method, params) + span_name = f"{method} {target}" if target else method + span = _tracer.start_span( + span_name, + kind=trace.SpanKind.CLIENT, + attributes={ATTR_MCP_METHOD_NAME: method}, + ) + return span + + +def start_server_span(method: str, params: dict[str, Any] | None) -> trace.Span | None: + """Start a SERVER span for an incoming MCP request. + + Returns None if the method is excluded from tracing. + """ + if method in _EXCLUDED_METHODS: + return None + + target = _extract_target(method, params) + span_name = f"{method} {target}" if target else method + span = _tracer.start_span( + span_name, + kind=trace.SpanKind.SERVER, + attributes={ATTR_MCP_METHOD_NAME: method}, + ) + return span + + +def end_span_ok(span: trace.Span) -> None: + """Mark a span as successful and end it.""" + span.set_status(StatusCode.OK) + span.end() + + +def end_span_error(span: trace.Span, error: BaseException) -> None: + """Mark a span as errored and end it.""" + span.set_status(StatusCode.ERROR, str(error)) + span.set_attribute(ATTR_ERROR_TYPE, type(error).__qualname__) + span.end() diff --git a/tests/shared/test_tracing.py b/tests/shared/test_tracing.py new file mode 100644 index 000000000..504600e15 --- /dev/null +++ b/tests/shared/test_tracing.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +from typing import Any + +import anyio +import pytest +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.trace import SpanKind, StatusCode + +from mcp import Client, types +from mcp.server.lowlevel.server import Server +from mcp.shared.exceptions import MCPError +from mcp.shared.tracing import ATTR_ERROR_TYPE, ATTR_MCP_METHOD_NAME + +# Module-level provider + exporter — avoids the "Overriding of current +# TracerProvider is not allowed" warning that happens if you call +# set_tracer_provider() more than once. +_provider = TracerProvider() +_exporter = InMemorySpanExporter() +_provider.add_span_processor(SimpleSpanProcessor(_exporter)) + + +@pytest.fixture(autouse=True) +def _otel_setup(monkeypatch: pytest.MonkeyPatch) -> InMemorySpanExporter: + """Patch the module-level tracer to use our test provider and clear spans between tests.""" + import mcp.shared.tracing as tracing_mod + + monkeypatch.setattr(tracing_mod, "_tracer", _provider.get_tracer("mcp")) + _exporter.clear() + return _exporter + + +@pytest.mark.anyio +async def test_span_created_on_send_request(_otel_setup: InMemorySpanExporter) -> None: + """Verify a CLIENT span is created when send_request() succeeds.""" + exporter = _otel_setup + + server = Server(name="test server") + async with Client(server) as client: + await client.send_ping() + + spans = exporter.get_finished_spans() + # Filter to only the CLIENT ping span (initialize also produces one, plus server spans) + ping_spans = [ + s + for s in spans + if s.kind == SpanKind.CLIENT and s.attributes and s.attributes.get(ATTR_MCP_METHOD_NAME) == "ping" + ] + assert len(ping_spans) == 1 + + span = ping_spans[0] + assert span.name == "ping" + assert span.kind == SpanKind.CLIENT + assert span.status.status_code == StatusCode.OK + + +@pytest.mark.anyio +async def test_span_attributes_for_tool_call(_otel_setup: InMemorySpanExporter) -> None: + """Verify span name includes tool name for tools/call requests.""" + exporter = _otel_setup + + server = Server(name="test server") + + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [types.Tool(name="echo", description="Echo tool", input_schema={"type": "object"})] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]: + return [types.TextContent(type="text", text=str(arguments))] + + async with Client(server) as client: + await client.call_tool("echo", {"msg": "hi"}) + + spans = exporter.get_finished_spans() + tool_spans = [ + s + for s in spans + if s.kind == SpanKind.CLIENT and s.attributes and s.attributes.get(ATTR_MCP_METHOD_NAME) == "tools/call" + ] + assert len(tool_spans) == 1 + + span = tool_spans[0] + assert span.name == "tools/call echo" + assert span.status.status_code == StatusCode.OK + + +@pytest.mark.anyio +async def test_span_error_on_failure(_otel_setup: InMemorySpanExporter) -> None: + """Verify span records ERROR status when the request times out.""" + exporter = _otel_setup + + server = Server(name="test server") + + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [types.Tool(name="slow_tool", description="Slow", input_schema={"type": "object"})] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]: + await anyio.sleep(10) + return [] # pragma: no cover + + async with Client(server) as client: + with pytest.raises(MCPError, match="Timed out"): + await client.session.send_request( + types.CallToolRequest(params=types.CallToolRequestParams(name="slow_tool", arguments={})), + types.CallToolResult, + request_read_timeout_seconds=0.01, + ) + + spans = exporter.get_finished_spans() + tool_spans = [ + s + for s in spans + if s.kind == SpanKind.CLIENT and s.attributes and s.attributes.get(ATTR_MCP_METHOD_NAME) == "tools/call" + ] + assert len(tool_spans) == 1 + + span = tool_spans[0] + assert span.status.status_code == StatusCode.ERROR + assert span.attributes is not None + assert span.attributes.get(ATTR_ERROR_TYPE) == "MCPError" + + +@pytest.mark.anyio +async def test_no_span_for_excluded_method(_otel_setup: InMemorySpanExporter) -> None: + """Verify no span is created for excluded methods (notifications/message).""" + exporter = _otel_setup + + server = Server(name="test server") + async with Client(server) as client: + await client.send_ping() + + spans = exporter.get_finished_spans() + excluded_spans = [ + s for s in spans if s.attributes and s.attributes.get(ATTR_MCP_METHOD_NAME) == "notifications/message" + ] + assert len(excluded_spans) == 0 + + +@pytest.mark.anyio +async def test_server_span_on_successful_request(_otel_setup: InMemorySpanExporter) -> None: + """Verify a SERVER span is created when the server handles a request.""" + exporter = _otel_setup + + server = Server(name="test server") + async with Client(server) as client: + await client.send_ping() + + spans = exporter.get_finished_spans() + server_ping_spans = [ + s + for s in spans + if s.kind == SpanKind.SERVER and s.attributes and s.attributes.get(ATTR_MCP_METHOD_NAME) == "ping" + ] + assert len(server_ping_spans) == 1 + + span = server_ping_spans[0] + assert span.name == "ping" + assert span.status.status_code == StatusCode.OK + + +@pytest.mark.anyio +async def test_server_span_includes_target(_otel_setup: InMemorySpanExporter) -> None: + """Verify server span name includes tool name for tools/call requests.""" + exporter = _otel_setup + + server = Server(name="test server") + + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [types.Tool(name="echo", description="Echo tool", input_schema={"type": "object"})] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]: + return [types.TextContent(type="text", text=str(arguments))] + + async with Client(server) as client: + await client.call_tool("echo", {"msg": "hi"}) + + spans = exporter.get_finished_spans() + server_tool_spans = [ + s + for s in spans + if s.kind == SpanKind.SERVER and s.attributes and s.attributes.get(ATTR_MCP_METHOD_NAME) == "tools/call" + ] + assert len(server_tool_spans) == 1 + + span = server_tool_spans[0] + assert span.name == "tools/call echo" + assert span.status.status_code == StatusCode.OK + + +@pytest.mark.anyio +async def test_server_span_error_on_error_response(_otel_setup: InMemorySpanExporter) -> None: + """Verify server span records ERROR status when the server responds with ErrorData.""" + exporter = _otel_setup + + server = Server(name="test server") + + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + raise MCPError(code=-1, message="internal failure") + + async with Client(server) as client: + with pytest.raises(MCPError, match="internal failure"): + await client.list_tools() + + spans = exporter.get_finished_spans() + server_spans = [ + s + for s in spans + if s.kind == SpanKind.SERVER and s.attributes and s.attributes.get(ATTR_MCP_METHOD_NAME) == "tools/list" + ] + assert len(server_spans) == 1 + + span = server_spans[0] + assert span.status.status_code == StatusCode.ERROR + assert span.attributes is not None + assert span.attributes.get(ATTR_ERROR_TYPE) == "MCPError" diff --git a/uv.lock b/uv.lock index 364112ec8..83b605fed 100644 --- a/uv.lock +++ b/uv.lock @@ -573,6 +573,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, ] +[[package]] +name = "importlib-metadata" +version = "8.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "zipp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f3/49/3b30cad09e7771a4982d9975a8cbf64f00d4a1ececb53297f1d9a7be1b10/importlib_metadata-8.7.1.tar.gz", hash = "sha256:49fef1ae6440c182052f407c8d34a68f72efc36db9ca90dc0113398f2fdde8bb", size = 57107, upload-time = "2025-12-21T10:00:19.278Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865, upload-time = "2025-12-21T10:00:18.329Z" }, +] + [[package]] name = "iniconfig" version = "2.1.0" @@ -724,6 +736,7 @@ dependencies = [ { name = "httpx" }, { name = "httpx-sse" }, { name = "jsonschema" }, + { name = "opentelemetry-api" }, { name = "pydantic" }, { name = "pydantic-settings" }, { name = "pyjwt", extra = ["crypto"] }, @@ -754,6 +767,7 @@ dev = [ { name = "dirty-equals" }, { name = "inline-snapshot" }, { name = "mcp", extra = ["cli", "ws"] }, + { name = "opentelemetry-sdk" }, { name = "pillow" }, { name = "pyright" }, { name = "pytest" }, @@ -778,6 +792,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.27.1" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "jsonschema", specifier = ">=4.20.0" }, + { name = "opentelemetry-api", specifier = ">=1.28.0" }, { name = "pydantic", specifier = ">=2.12.0" }, { name = "pydantic-settings", specifier = ">=2.5.2" }, { name = "pyjwt", extras = ["crypto"], specifier = ">=2.10.1" }, @@ -802,6 +817,7 @@ dev = [ { name = "dirty-equals", specifier = ">=0.9.0" }, { name = "inline-snapshot", specifier = ">=0.23.0" }, { name = "mcp", extras = ["cli", "ws"], editable = "." }, + { name = "opentelemetry-sdk", specifier = ">=1.28.0" }, { name = "pillow", specifier = ">=12.0" }, { name = "pyright", specifier = ">=1.1.400" }, { name = "pytest", specifier = ">=8.3.4" }, @@ -1535,6 +1551,46 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314, upload-time = "2024-06-04T18:44:08.352Z" }, ] +[[package]] +name = "opentelemetry-api" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "importlib-metadata" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/97/b9/3161be15bb8e3ad01be8be5a968a9237c3027c5be504362ff800fca3e442/opentelemetry_api-1.39.1.tar.gz", hash = "sha256:fbde8c80e1b937a2c61f20347e91c0c18a1940cecf012d62e65a7caf08967c9c", size = 65767, upload-time = "2025-12-11T13:32:39.182Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cf/df/d3f1ddf4bb4cb50ed9b1139cc7b1c54c34a1e7ce8fd1b9a37c0d1551a6bd/opentelemetry_api-1.39.1-py3-none-any.whl", hash = "sha256:2edd8463432a7f8443edce90972169b195e7d6a05500cd29e6d13898187c9950", size = 66356, upload-time = "2025-12-11T13:32:17.304Z" }, +] + +[[package]] +name = "opentelemetry-sdk" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/fb/c76080c9ba07e1e8235d24cdcc4d125ef7aa3edf23eb4e497c2e50889adc/opentelemetry_sdk-1.39.1.tar.gz", hash = "sha256:cf4d4563caf7bff906c9f7967e2be22d0d6b349b908be0d90fb21c8e9c995cc6", size = 171460, upload-time = "2025-12-11T13:32:49.369Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/98/e91cf858f203d86f4eccdf763dcf01cf03f1dae80c3750f7e635bfa206b6/opentelemetry_sdk-1.39.1-py3-none-any.whl", hash = "sha256:4d5482c478513ecb0a5d938dcc61394e647066e0cc2676bee9f3af3f3f45f01c", size = 132565, upload-time = "2025-12-11T13:32:35.069Z" }, +] + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.60b1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/df/553f93ed38bf22f4b999d9be9c185adb558982214f33eae539d3b5cd0858/opentelemetry_semantic_conventions-0.60b1.tar.gz", hash = "sha256:87c228b5a0669b748c76d76df6c364c369c28f1c465e50f661e39737e84bc953", size = 137935, upload-time = "2025-12-11T13:32:50.487Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/5e/5958555e09635d09b75de3c4f8b9cae7335ca545d77392ffe7331534c402/opentelemetry_semantic_conventions-0.60b1-py3-none-any.whl", hash = "sha256:9fa8c8b0c110da289809292b0591220d3a7b53c1526a23021e977d68597893fb", size = 219982, upload-time = "2025-12-11T13:32:36.955Z" }, +] + [[package]] name = "outcome" version = "1.3.0.post0" @@ -2599,3 +2655,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/68/a1/dcb68430b1d00b698ae7a7e0194433bce4f07ded185f0ee5fb21e2a2e91e/websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122", size = 176884, upload-time = "2025-03-05T20:03:27.934Z" }, { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743, upload-time = "2025-03-05T20:03:39.41Z" }, ] + +[[package]] +name = "zipp" +version = "3.23.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547, upload-time = "2025-06-08T17:06:39.4Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, +]