-
Notifications
You must be signed in to change notification settings - Fork 2.9k
runner: avoid duplicate user event append on invocation retry #4507
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -289,6 +289,151 @@ def _infer_agent_origin( | |
| assert event.content.parts[0].text == "Test LLM response" | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_append_new_message_to_session_skips_duplicate_retry_message(): | ||
| session_service = InMemorySessionService() | ||
| runner = Runner( | ||
| app_name="test_app", | ||
| agent=MockLlmAgent("root_agent"), | ||
| session_service=session_service, | ||
| artifact_service=InMemoryArtifactService(), | ||
| ) | ||
| session = await session_service.create_session( | ||
| app_name="test_app", | ||
| user_id="test_user", | ||
| ) | ||
| user_message = types.Content( | ||
| role="user", | ||
| parts=[types.Part(text="retry message")], | ||
| ) | ||
| invocation_context = runner._new_invocation_context( | ||
| session, | ||
| invocation_id="inv-retry", | ||
| new_message=user_message, | ||
| run_config=RunConfig(), | ||
| ) | ||
|
|
||
| await runner._append_new_message_to_session( | ||
| session=session, | ||
| new_message=user_message, | ||
| invocation_context=invocation_context, | ||
| ) | ||
| await runner._append_new_message_to_session( | ||
| session=session, | ||
| new_message=user_message, | ||
| invocation_context=invocation_context, | ||
| ) | ||
|
|
||
| matched_events = [ | ||
| event | ||
| for event in session.events | ||
| if event.author == "user" | ||
| and event.invocation_id == "inv-retry" | ||
| and event.content == user_message | ||
| ] | ||
| assert len(matched_events) == 1 | ||
|
Comment on lines
+293
to
+334
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new tests cover the duplicate detection logic well for messages. However, the check for duplicates also involves To improve test coverage, please consider adding a test case that verifies the behavior with different
This will ensure the |
||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_append_new_message_to_session_keeps_non_duplicate_messages(): | ||
| session_service = InMemorySessionService() | ||
| runner = Runner( | ||
| app_name="test_app", | ||
| agent=MockLlmAgent("root_agent"), | ||
| session_service=session_service, | ||
| artifact_service=InMemoryArtifactService(), | ||
| ) | ||
| session = await session_service.create_session( | ||
| app_name="test_app", | ||
| user_id="test_user", | ||
| ) | ||
| invocation_context = runner._new_invocation_context( | ||
| session, | ||
| invocation_id="inv-retry", | ||
| new_message=types.Content(role="user", parts=[types.Part(text="first")]), | ||
| run_config=RunConfig(), | ||
| ) | ||
| first_message = types.Content(role="user", parts=[types.Part(text="first")]) | ||
| second_message = types.Content(role="user", parts=[types.Part(text="second")]) | ||
|
|
||
| await runner._append_new_message_to_session( | ||
| session=session, | ||
| new_message=first_message, | ||
| invocation_context=invocation_context, | ||
| ) | ||
| await runner._append_new_message_to_session( | ||
| session=session, | ||
| new_message=second_message, | ||
| invocation_context=invocation_context, | ||
| ) | ||
|
|
||
| matched_events = [ | ||
| event | ||
| for event in session.events | ||
| if event.author == "user" and event.invocation_id == "inv-retry" | ||
| ] | ||
| assert len(matched_events) == 2 | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_append_new_message_to_session_state_delta_deduping(): | ||
| session_service = InMemorySessionService() | ||
| runner = Runner( | ||
| app_name="test_app", | ||
| agent=MockLlmAgent("root_agent"), | ||
| session_service=session_service, | ||
| artifact_service=InMemoryArtifactService(), | ||
| ) | ||
| session = await session_service.create_session( | ||
| app_name="test_app", | ||
| user_id="test_user", | ||
| ) | ||
| user_message = types.Content(role="user", parts=[types.Part(text="same message")]) | ||
| invocation_context = runner._new_invocation_context( | ||
| session, | ||
| invocation_id="inv-state-delta", | ||
| new_message=user_message, | ||
| run_config=RunConfig(), | ||
| ) | ||
|
|
||
| await runner._append_new_message_to_session( | ||
| session=session, | ||
| new_message=user_message, | ||
| invocation_context=invocation_context, | ||
| state_delta={"attempt": 1}, | ||
| ) | ||
| await runner._append_new_message_to_session( | ||
| session=session, | ||
| new_message=user_message, | ||
| invocation_context=invocation_context, | ||
| state_delta={"attempt": 1}, | ||
| ) | ||
| await runner._append_new_message_to_session( | ||
| session=session, | ||
| new_message=user_message, | ||
| invocation_context=invocation_context, | ||
| state_delta={"attempt": 2}, | ||
| ) | ||
| await runner._append_new_message_to_session( | ||
| session=session, | ||
| new_message=user_message, | ||
| invocation_context=invocation_context, | ||
| state_delta=None, | ||
| ) | ||
|
|
||
| matched_events = [ | ||
| event | ||
| for event in session.events | ||
| if event.author == "user" | ||
| and event.invocation_id == "inv-state-delta" | ||
| and event.content == user_message | ||
| ] | ||
| assert len(matched_events) == 3 | ||
| assert matched_events[0].actions.state_delta == {"attempt": 1} | ||
| assert matched_events[1].actions.state_delta == {"attempt": 2} | ||
| assert matched_events[2].actions.state_delta == {} | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_rewind_auto_create_session_on_missing_session(): | ||
| """When auto_create_session=True, rewind should create session if missing. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For better readability and conciseness, you can refactor this loop into a single statement using a generator expression with
any(). This is a more Pythonic way to check for the existence of an item in a sequence that matches a condition.