Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,19 @@ async def _append_new_message_to_session(
new_message.parts[i] = types.Part(
text=f'Uploaded file: {file_name}. It is saved into artifacts'
)

if self._has_duplicate_user_event_for_invocation(
session=session,
invocation_id=invocation_context.invocation_id,
new_message=new_message,
state_delta=state_delta,
):
logger.info(
'Skipping duplicate user event append for invocation_id=%s',
invocation_context.invocation_id,
)
return

# Appends only. We do not yield the event because it's not from the model.
if state_delta:
event = Event(
Expand All @@ -918,6 +931,25 @@ async def _append_new_message_to_session(

await self.session_service.append_event(session=session, event=event)

def _has_duplicate_user_event_for_invocation(
self,
*,
session: Session,
invocation_id: str,
new_message: types.Content,
state_delta: Optional[dict[str, Any]],
) -> bool:
expected_state_delta = state_delta or {}
for event in session.events:
if event.invocation_id != invocation_id or event.author != 'user':
continue
if (
event.content == new_message
and event.actions.state_delta == expected_state_delta
):
return True
return False
Comment on lines +943 to +951
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better readability and conciseness, you can refactor this loop into a single statement using a generator expression with any(). This is a more Pythonic way to check for the existence of an item in a sequence that matches a condition.

    return any(
        event.content == new_message
        and event.actions.state_delta == expected_state_delta
        for event in session.events
        if event.author == "user" and event.invocation_id == invocation_id
    )


async def run_live(
self,
*,
Expand Down
145 changes: 145 additions & 0 deletions tests/unittests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,151 @@ def _infer_agent_origin(
assert event.content.parts[0].text == "Test LLM response"


@pytest.mark.asyncio
async def test_append_new_message_to_session_skips_duplicate_retry_message():
session_service = InMemorySessionService()
runner = Runner(
app_name="test_app",
agent=MockLlmAgent("root_agent"),
session_service=session_service,
artifact_service=InMemoryArtifactService(),
)
session = await session_service.create_session(
app_name="test_app",
user_id="test_user",
)
user_message = types.Content(
role="user",
parts=[types.Part(text="retry message")],
)
invocation_context = runner._new_invocation_context(
session,
invocation_id="inv-retry",
new_message=user_message,
run_config=RunConfig(),
)

await runner._append_new_message_to_session(
session=session,
new_message=user_message,
invocation_context=invocation_context,
)
await runner._append_new_message_to_session(
session=session,
new_message=user_message,
invocation_context=invocation_context,
)

matched_events = [
event
for event in session.events
if event.author == "user"
and event.invocation_id == "inv-retry"
and event.content == user_message
]
assert len(matched_events) == 1
Comment on lines +293 to +334
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The new tests cover the duplicate detection logic well for messages. However, the check for duplicates also involves state_delta, which is not covered in the tests.

To improve test coverage, please consider adding a test case that verifies the behavior with different state_delta values. For example:

  1. Append a message with a specific state_delta.
  2. Append the same message with the same state_delta (should be skipped).
  3. Append the same message with a different state_delta (should be appended).
  4. Append the same message with state_delta=None (should be appended).

This will ensure the state_delta comparison in _has_duplicate_user_event_for_invocation works as expected.



@pytest.mark.asyncio
async def test_append_new_message_to_session_keeps_non_duplicate_messages():
session_service = InMemorySessionService()
runner = Runner(
app_name="test_app",
agent=MockLlmAgent("root_agent"),
session_service=session_service,
artifact_service=InMemoryArtifactService(),
)
session = await session_service.create_session(
app_name="test_app",
user_id="test_user",
)
invocation_context = runner._new_invocation_context(
session,
invocation_id="inv-retry",
new_message=types.Content(role="user", parts=[types.Part(text="first")]),
run_config=RunConfig(),
)
first_message = types.Content(role="user", parts=[types.Part(text="first")])
second_message = types.Content(role="user", parts=[types.Part(text="second")])

await runner._append_new_message_to_session(
session=session,
new_message=first_message,
invocation_context=invocation_context,
)
await runner._append_new_message_to_session(
session=session,
new_message=second_message,
invocation_context=invocation_context,
)

matched_events = [
event
for event in session.events
if event.author == "user" and event.invocation_id == "inv-retry"
]
assert len(matched_events) == 2


@pytest.mark.asyncio
async def test_append_new_message_to_session_state_delta_deduping():
session_service = InMemorySessionService()
runner = Runner(
app_name="test_app",
agent=MockLlmAgent("root_agent"),
session_service=session_service,
artifact_service=InMemoryArtifactService(),
)
session = await session_service.create_session(
app_name="test_app",
user_id="test_user",
)
user_message = types.Content(role="user", parts=[types.Part(text="same message")])
invocation_context = runner._new_invocation_context(
session,
invocation_id="inv-state-delta",
new_message=user_message,
run_config=RunConfig(),
)

await runner._append_new_message_to_session(
session=session,
new_message=user_message,
invocation_context=invocation_context,
state_delta={"attempt": 1},
)
await runner._append_new_message_to_session(
session=session,
new_message=user_message,
invocation_context=invocation_context,
state_delta={"attempt": 1},
)
await runner._append_new_message_to_session(
session=session,
new_message=user_message,
invocation_context=invocation_context,
state_delta={"attempt": 2},
)
await runner._append_new_message_to_session(
session=session,
new_message=user_message,
invocation_context=invocation_context,
state_delta=None,
)

matched_events = [
event
for event in session.events
if event.author == "user"
and event.invocation_id == "inv-state-delta"
and event.content == user_message
]
assert len(matched_events) == 3
assert matched_events[0].actions.state_delta == {"attempt": 1}
assert matched_events[1].actions.state_delta == {"attempt": 2}
assert matched_events[2].actions.state_delta == {}


@pytest.mark.asyncio
async def test_rewind_auto_create_session_on_missing_session():
"""When auto_create_session=True, rewind should create session if missing.
Expand Down