diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index bc0251a81e..bc7b27764c 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -896,6 +896,19 @@ async def _append_new_message_to_session( new_message.parts[i] = types.Part( text=f'Uploaded file: {file_name}. It is saved into artifacts' ) + + if self._has_duplicate_user_event_for_invocation( + session=session, + invocation_id=invocation_context.invocation_id, + new_message=new_message, + state_delta=state_delta, + ): + logger.info( + 'Skipping duplicate user event append for invocation_id=%s', + invocation_context.invocation_id, + ) + return + # Appends only. We do not yield the event because it's not from the model. if state_delta: event = Event( @@ -918,6 +931,25 @@ async def _append_new_message_to_session( await self.session_service.append_event(session=session, event=event) + def _has_duplicate_user_event_for_invocation( + self, + *, + session: Session, + invocation_id: str, + new_message: types.Content, + state_delta: Optional[dict[str, Any]], + ) -> bool: + expected_state_delta = state_delta or {} + for event in session.events: + if event.invocation_id != invocation_id or event.author != 'user': + continue + if ( + event.content == new_message + and event.actions.state_delta == expected_state_delta + ): + return True + return False + async def run_live( self, *, diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index ca7eb37533..179739118c 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -289,6 +289,151 @@ def _infer_agent_origin( assert event.content.parts[0].text == "Test LLM response" +@pytest.mark.asyncio +async def test_append_new_message_to_session_skips_duplicate_retry_message(): + session_service = InMemorySessionService() + runner = Runner( + app_name="test_app", + agent=MockLlmAgent("root_agent"), + session_service=session_service, + artifact_service=InMemoryArtifactService(), + ) + session = await session_service.create_session( + app_name="test_app", + user_id="test_user", + ) + user_message = types.Content( + role="user", + parts=[types.Part(text="retry message")], + ) + invocation_context = runner._new_invocation_context( + session, + invocation_id="inv-retry", + new_message=user_message, + run_config=RunConfig(), + ) + + await runner._append_new_message_to_session( + session=session, + new_message=user_message, + invocation_context=invocation_context, + ) + await runner._append_new_message_to_session( + session=session, + new_message=user_message, + invocation_context=invocation_context, + ) + + matched_events = [ + event + for event in session.events + if event.author == "user" + and event.invocation_id == "inv-retry" + and event.content == user_message + ] + assert len(matched_events) == 1 + + +@pytest.mark.asyncio +async def test_append_new_message_to_session_keeps_non_duplicate_messages(): + session_service = InMemorySessionService() + runner = Runner( + app_name="test_app", + agent=MockLlmAgent("root_agent"), + session_service=session_service, + artifact_service=InMemoryArtifactService(), + ) + session = await session_service.create_session( + app_name="test_app", + user_id="test_user", + ) + invocation_context = runner._new_invocation_context( + session, + invocation_id="inv-retry", + new_message=types.Content(role="user", parts=[types.Part(text="first")]), + run_config=RunConfig(), + ) + first_message = types.Content(role="user", parts=[types.Part(text="first")]) + second_message = types.Content(role="user", parts=[types.Part(text="second")]) + + await runner._append_new_message_to_session( + session=session, + new_message=first_message, + invocation_context=invocation_context, + ) + await runner._append_new_message_to_session( + session=session, + new_message=second_message, + invocation_context=invocation_context, + ) + + matched_events = [ + event + for event in session.events + if event.author == "user" and event.invocation_id == "inv-retry" + ] + assert len(matched_events) == 2 + + +@pytest.mark.asyncio +async def test_append_new_message_to_session_state_delta_deduping(): + session_service = InMemorySessionService() + runner = Runner( + app_name="test_app", + agent=MockLlmAgent("root_agent"), + session_service=session_service, + artifact_service=InMemoryArtifactService(), + ) + session = await session_service.create_session( + app_name="test_app", + user_id="test_user", + ) + user_message = types.Content(role="user", parts=[types.Part(text="same message")]) + invocation_context = runner._new_invocation_context( + session, + invocation_id="inv-state-delta", + new_message=user_message, + run_config=RunConfig(), + ) + + await runner._append_new_message_to_session( + session=session, + new_message=user_message, + invocation_context=invocation_context, + state_delta={"attempt": 1}, + ) + await runner._append_new_message_to_session( + session=session, + new_message=user_message, + invocation_context=invocation_context, + state_delta={"attempt": 1}, + ) + await runner._append_new_message_to_session( + session=session, + new_message=user_message, + invocation_context=invocation_context, + state_delta={"attempt": 2}, + ) + await runner._append_new_message_to_session( + session=session, + new_message=user_message, + invocation_context=invocation_context, + state_delta=None, + ) + + matched_events = [ + event + for event in session.events + if event.author == "user" + and event.invocation_id == "inv-state-delta" + and event.content == user_message + ] + assert len(matched_events) == 3 + assert matched_events[0].actions.state_delta == {"attempt": 1} + assert matched_events[1].actions.state_delta == {"attempt": 2} + assert matched_events[2].actions.state_delta == {} + + @pytest.mark.asyncio async def test_rewind_auto_create_session_on_missing_session(): """When auto_create_session=True, rewind should create session if missing.