diff --git a/pyproject.toml b/pyproject.toml index da05cfcee9..6f606708ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -164,6 +164,8 @@ extensions = [ "toolbox-adk>=0.5.7, <0.6.0", # For tools.toolbox_toolset.ToolboxToolset ] +firestore = ["google-cloud-firestore>=2.19.0, <3.0.0"] + otel-gcp = ["opentelemetry-instrumentation-google-genai>=0.6b0, <1.0.0"] toolbox = ["toolbox-adk>=0.5.7, <0.6.0"] diff --git a/src/google/adk/sessions/__init__.py b/src/google/adk/sessions/__init__.py index 7505eda346..5fd2751534 100644 --- a/src/google/adk/sessions/__init__.py +++ b/src/google/adk/sessions/__init__.py @@ -20,6 +20,7 @@ __all__ = [ 'BaseSessionService', 'DatabaseSessionService', + 'FirestoreSessionService', 'InMemorySessionService', 'Session', 'State', @@ -38,4 +39,14 @@ def __getattr__(name: str): 'DatabaseSessionService requires sqlalchemy>=2.0, please ensure it is' ' installed correctly.' ) from e + if name == 'FirestoreSessionService': + try: + from .firestore_session_service import FirestoreSessionService + + return FirestoreSessionService + except ImportError as e: + raise ImportError( + 'FirestoreSessionService requires google-cloud-firestore, please' + ' install it with: pip install google-cloud-firestore' + ) from e raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/src/google/adk/sessions/firestore_session_service.py b/src/google/adk/sessions/firestore_session_service.py new file mode 100644 index 0000000000..c5cee8332c --- /dev/null +++ b/src/google/adk/sessions/firestore_session_service.py @@ -0,0 +1,415 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firestore-backed session service for Google ADK. + +Provides persistent, serverless session storage using Google Cloud Firestore. +This is well-suited for production deployments on Cloud Run, Cloud Functions, +or any GCP environment where managing a SQL database is undesirable. + +Firestore collection layout:: + + adk_app_states/{app_name} + adk_user_states/{app_name}_{user_id} + adk_sessions/{session_id} + -> subcollection: events/{event_id} + +Requires the ``google-cloud-firestore`` package:: + + pip install google-cloud-firestore +""" + +from __future__ import annotations + +import copy +import logging +import time +from typing import Any +from typing import Optional +import uuid + +from typing_extensions import override + +from . import _session_util +from ..errors.already_exists_error import AlreadyExistsError +from ..events.event import Event +from .base_session_service import BaseSessionService +from .base_session_service import GetSessionConfig +from .base_session_service import ListSessionsResponse +from .session import Session +from .state import State + +logger = logging.getLogger("google_adk." + __name__) + +# Firestore collection names +_APP_STATES_COLLECTION = "adk_app_states" +_USER_STATES_COLLECTION = "adk_user_states" +_SESSIONS_COLLECTION = "adk_sessions" +_EVENTS_SUBCOLLECTION = "events" + +# Firestore document field names +_FIELD_APP_NAME = "app_name" +_FIELD_USER_ID = "user_id" +_FIELD_STATE = "state" +_FIELD_CREATE_TIME = "create_time" +_FIELD_UPDATE_TIME = "update_time" +_FIELD_EVENT_DATA = "event_data" +_FIELD_TIMESTAMP = "timestamp" +_FIELD_INVOCATION_ID = "invocation_id" + + +def _user_state_doc_id(app_name: str, user_id: str) -> str: + """Builds a deterministic document ID for a user state entry.""" + return f"{app_name}_{user_id}" + + +class FirestoreSessionService(BaseSessionService): + """A session service backed by Google Cloud Firestore. + + This service stores sessions, events, and state in Firestore collections, + providing serverless, persistent storage suitable for production use. + + Args: + project: GCP project ID. If ``None``, uses Application Default Credentials. + database: Firestore database ID. Defaults to ``"(default)"``. + collection_prefix: Optional prefix for all collection names, useful for + multi-tenant setups or testing isolation. + """ + + def __init__( + self, + *, + project: Optional[str] = None, + database: str = "(default)", + collection_prefix: str = "", + ): + try: + from google.cloud.firestore_v1 import AsyncClient + except ImportError as e: + raise ImportError( + "FirestoreSessionService requires google-cloud-firestore. " + "Install it with: pip install google-cloud-firestore" + ) from e + + self._db = AsyncClient(project=project, database=database) + self._prefix = collection_prefix + + # -- Collection helpers -------------------------------------------------- + + def _col_app_states(self): + return self._db.collection(f"{self._prefix}{_APP_STATES_COLLECTION}") + + def _col_user_states(self): + return self._db.collection(f"{self._prefix}{_USER_STATES_COLLECTION}") + + def _col_sessions(self): + return self._db.collection(f"{self._prefix}{_SESSIONS_COLLECTION}") + + def _events_col(self, session_id: str): + """Returns the events subcollection for a given session.""" + return ( + self._col_sessions() + .document(session_id) + .collection(_EVENTS_SUBCOLLECTION) + ) + + # -- State helpers ------------------------------------------------------- + + async def _get_app_state(self, app_name: str) -> dict[str, Any]: + """Fetches the app-level state dict, returning empty dict if missing.""" + doc = await self._col_app_states().document(app_name).get() + if doc.exists: + return doc.to_dict().get(_FIELD_STATE, {}) + return {} + + async def _set_app_state(self, app_name: str, state: dict[str, Any]) -> None: + await self._col_app_states().document(app_name).set( + {_FIELD_STATE: state}, merge=True + ) + + async def _get_user_state( + self, app_name: str, user_id: str + ) -> dict[str, Any]: + doc_id = _user_state_doc_id(app_name, user_id) + doc = await self._col_user_states().document(doc_id).get() + if doc.exists: + return doc.to_dict().get(_FIELD_STATE, {}) + return {} + + async def _set_user_state( + self, app_name: str, user_id: str, state: dict[str, Any] + ) -> None: + doc_id = _user_state_doc_id(app_name, user_id) + await self._col_user_states().document(doc_id).set( + { + _FIELD_APP_NAME: app_name, + _FIELD_USER_ID: user_id, + _FIELD_STATE: state, + }, + merge=True, + ) + + def _merge_state( + self, + app_state: dict[str, Any], + user_state: dict[str, Any], + session_state: dict[str, Any], + ) -> dict[str, Any]: + """Merges app, user, and session state into a single dict.""" + merged = copy.deepcopy(session_state) + for key, value in app_state.items(): + merged[State.APP_PREFIX + key] = value + for key, value in user_state.items(): + merged[State.USER_PREFIX + key] = value + return merged + + # -- CRUD ---------------------------------------------------------------- + + @override + async def create_session( + self, + *, + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + session_id: Optional[str] = None, + ) -> Session: + session_id = ( + session_id.strip() + if session_id and session_id.strip() + else str(uuid.uuid4()) + ) + + # Check for duplicate + existing = await self._col_sessions().document(session_id).get() + if existing.exists: + raise AlreadyExistsError(f"Session with id {session_id} already exists.") + + # Extract state deltas + state_deltas = _session_util.extract_state_delta(state) + app_state_delta = state_deltas["app"] + user_state_delta = state_deltas["user"] + session_state = state_deltas["session"] + + # Update app / user state + if app_state_delta: + current_app_state = await self._get_app_state(app_name) + current_app_state.update(app_state_delta) + await self._set_app_state(app_name, current_app_state) + + if user_state_delta: + current_user_state = await self._get_user_state(app_name, user_id) + current_user_state.update(user_state_delta) + await self._set_user_state(app_name, user_id, current_user_state) + + now = time.time() + # Store session document + await self._col_sessions().document(session_id).set({ + _FIELD_APP_NAME: app_name, + _FIELD_USER_ID: user_id, + _FIELD_STATE: session_state, + _FIELD_CREATE_TIME: now, + _FIELD_UPDATE_TIME: now, + }) + + # Build merged state for response + app_state = await self._get_app_state(app_name) + user_state = await self._get_user_state(app_name, user_id) + merged = self._merge_state(app_state, user_state, session_state) + + return Session( + app_name=app_name, + user_id=user_id, + id=session_id, + state=merged, + last_update_time=now, + ) + + @override + async def get_session( + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: Optional[GetSessionConfig] = None, + ) -> Optional[Session]: + doc = await self._col_sessions().document(session_id).get() + if not doc.exists: + return None + + data = doc.to_dict() + if data.get(_FIELD_APP_NAME) != app_name: + return None + if data.get(_FIELD_USER_ID) != user_id: + return None + + session_state = data.get(_FIELD_STATE, {}) + + # Fetch events from subcollection + events_query = self._events_col(session_id).order_by(_FIELD_TIMESTAMP) + + if config and config.after_timestamp: + events_query = events_query.where( + filter=self._db.field_filter( + _FIELD_TIMESTAMP, ">=", config.after_timestamp + ) + ) + + event_docs = events_query.stream() + events: list[Event] = [] + async for event_doc in event_docs: + event_data = event_doc.to_dict() + raw = event_data.get(_FIELD_EVENT_DATA, {}) + if raw: + events.append(Event.model_validate(raw)) + + if config and config.num_recent_events: + events = events[-config.num_recent_events :] + + # Merge states + app_state = await self._get_app_state(app_name) + user_state = await self._get_user_state(app_name, user_id) + merged = self._merge_state(app_state, user_state, session_state) + + return Session( + app_name=app_name, + user_id=user_id, + id=session_id, + state=merged, + events=events, + last_update_time=data.get(_FIELD_UPDATE_TIME, 0.0), + ) + + @override + async def list_sessions( + self, *, app_name: str, user_id: Optional[str] = None + ) -> ListSessionsResponse: + query = self._col_sessions().where( + filter=self._db.field_filter(_FIELD_APP_NAME, "==", app_name) + ) + if user_id is not None: + query = query.where( + filter=self._db.field_filter(_FIELD_USER_ID, "==", user_id) + ) + + sessions: list[Session] = [] + async for doc in query.stream(): + data = doc.to_dict() + session_state = data.get(_FIELD_STATE, {}) + sid = doc.id + suid = data.get(_FIELD_USER_ID, "") + + app_state = await self._get_app_state(app_name) + user_state = await self._get_user_state(app_name, suid) + merged = self._merge_state(app_state, user_state, session_state) + + sessions.append( + Session( + app_name=app_name, + user_id=suid, + id=sid, + state=merged, + last_update_time=data.get(_FIELD_UPDATE_TIME, 0.0), + ) + ) + + return ListSessionsResponse(sessions=sessions) + + @override + async def delete_session( + self, *, app_name: str, user_id: str, session_id: str + ) -> None: + session_ref = self._col_sessions().document(session_id) + doc = await session_ref.get() + if not doc.exists: + return + + # Delete all events in the subcollection first + events_ref = session_ref.collection(_EVENTS_SUBCOLLECTION) + async for event_doc in events_ref.stream(): + await event_doc.reference.delete() + + # Delete the session document + await session_ref.delete() + + @override + async def append_event(self, session: Session, event: Event) -> Event: + if event.partial: + return event + + app_name = session.app_name + user_id = session.user_id + session_id = session.id + + session_ref = self._col_sessions().document(session_id) + doc = await session_ref.get() + if not doc.exists: + logger.warning( + "Cannot append event: session %s not found in Firestore.", + session_id, + ) + return event + + # Update in-memory session state via base class + await super().append_event(session=session, event=event) + session.last_update_time = event.timestamp + + # Extract and apply state deltas + if event.actions and event.actions.state_delta: + state_deltas = _session_util.extract_state_delta( + event.actions.state_delta + ) + app_state_delta = state_deltas["app"] + user_state_delta = state_deltas["user"] + session_state_delta = state_deltas["session"] + + if app_state_delta: + current_app_state = await self._get_app_state(app_name) + current_app_state.update(app_state_delta) + await self._set_app_state(app_name, current_app_state) + + if user_state_delta: + current_user_state = await self._get_user_state(app_name, user_id) + current_user_state.update(user_state_delta) + await self._set_user_state(app_name, user_id, current_user_state) + + if session_state_delta: + stored_data = doc.to_dict() + stored_state = stored_data.get(_FIELD_STATE, {}) + stored_state.update(session_state_delta) + await session_ref.update({_FIELD_STATE: stored_state}) + + # Store event in subcollection + event_data = event.model_dump(exclude_none=True, mode="json") + await self._events_col(session_id).document(event.id).set({ + _FIELD_EVENT_DATA: event_data, + _FIELD_TIMESTAMP: event.timestamp, + _FIELD_INVOCATION_ID: event.invocation_id, + }) + + # Update session timestamp + await session_ref.update({_FIELD_UPDATE_TIME: event.timestamp}) + + return event + + async def close(self) -> None: + """Closes the underlying Firestore client.""" + self._db.close() + + async def __aenter__(self) -> FirestoreSessionService: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self.close() diff --git a/tests/unittests/sessions/test_firestore_session_service.py b/tests/unittests/sessions/test_firestore_session_service.py new file mode 100644 index 0000000000..eec9fe9873 --- /dev/null +++ b/tests/unittests/sessions/test_firestore_session_service.py @@ -0,0 +1,463 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for FirestoreSessionService. + +All Firestore interactions are mocked so no real Firestore connection is +needed. +""" + +from __future__ import annotations + +import copy +from typing import Any +from unittest import mock + +from google.adk.errors.already_exists_error import AlreadyExistsError +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.sessions.base_session_service import GetSessionConfig +from google.genai import types +import pytest + +# --------------------------------------------------------------------------- +# Helpers – lightweight in-memory Firestore mock +# --------------------------------------------------------------------------- + + +class _FakeDocSnapshot: + """Mimics a Firestore DocumentSnapshot.""" + + def __init__(self, doc_id: str, data: dict[str, Any] | None): + self.id = doc_id + self._data = data + self.exists = data is not None + self.reference = mock.AsyncMock() + self.reference.delete = mock.AsyncMock() + + def to_dict(self) -> dict[str, Any]: + return copy.deepcopy(self._data) if self._data else {} + + +class _FakeDocRef: + """Mimics an async Firestore DocumentReference.""" + + def __init__(self, doc_id: str, store: dict[str, dict[str, Any]]): + self.id = doc_id + self._store = store + self._subcollections: dict[str, _FakeCollection] = {} + + async def get(self): + data = self._store.get(self.id) + return _FakeDocSnapshot(self.id, copy.deepcopy(data) if data else None) + + async def set(self, data, merge=False): + if merge and self.id in self._store: + self._store[self.id].update(data) + else: + self._store[self.id] = copy.deepcopy(data) + + async def update(self, data): + if self.id in self._store: + self._store[self.id].update(data) + + async def delete(self): + self._store.pop(self.id, None) + + def collection(self, name: str) -> _FakeCollection: + if name not in self._subcollections: + self._subcollections[name] = _FakeCollection({}) + return self._subcollections[name] + + +class _FakeCollection: + """Mimics an async Firestore CollectionReference backed by a dict.""" + + def __init__(self, store: dict[str, dict[str, Any]] | None = None): + self._store: dict[str, dict[str, Any]] = store if store is not None else {} + # Persistent doc refs so subcollections survive across calls + self._doc_refs: dict[str, _FakeDocRef] = {} + + def document(self, doc_id: str) -> _FakeDocRef: + if doc_id not in self._doc_refs: + self._doc_refs[doc_id] = _FakeDocRef(doc_id, self._store) + return self._doc_refs[doc_id] + + def where(self, **kwargs): + return _FakeQuery(self._store, kwargs.get("filter")) + + def order_by(self, field): + return _FakeQuery(self._store, None, order_field=field) + + async def stream(self): + for doc_id, data in list(self._store.items()): + snapshot = _FakeDocSnapshot(doc_id, copy.deepcopy(data)) + snapshot.reference = self.document(doc_id) + yield snapshot + + +class _FakeQuery: + """Mimics an async Firestore query.""" + + def __init__(self, store, filt=None, order_field=None): + self._store = store + self._filters: list = [] + if filt: + self._filters.append(filt) + self._order_field = order_field + + def where(self, **kwargs): + new_q = _FakeQuery(self._store, order_field=self._order_field) + new_q._filters = list(self._filters) + filt = kwargs.get("filter") + if filt: + new_q._filters.append(filt) + return new_q + + def order_by(self, field): + self._order_field = field + return self + + def _matches(self, data: dict) -> bool: + for f in self._filters: + field_path = f.field_path + op = f.op_string + val = f.value + actual = data.get(field_path) + if op == "==" and actual != val: + return False + if op == ">=" and (actual is None or actual < val): + return False + return True + + async def stream(self): + items = [ + (doc_id, copy.deepcopy(data)) + for doc_id, data in self._store.items() + if self._matches(data) + ] + if self._order_field: + items.sort(key=lambda x: x[1].get(self._order_field, 0)) + for doc_id, data in items: + yield _FakeDocSnapshot(doc_id, data) + + +class _FakeFieldFilter: + + def __init__(self, field_path, op_string, value): + self.field_path = field_path + self.op_string = op_string + self.value = value + + +class _FakeFirestoreClient: + """Mimics google.cloud.firestore_v1.AsyncClient.""" + + def __init__(self): + self._collections: dict[str, _FakeCollection] = {} + + def collection(self, name: str) -> _FakeCollection: + if name not in self._collections: + self._collections[name] = _FakeCollection({}) + return self._collections[name] + + @staticmethod + def field_filter(field_path, op_string, value): + return _FakeFieldFilter(field_path, op_string, value) + + def close(self): + pass + + +# --------------------------------------------------------------------------- +# Fixture +# --------------------------------------------------------------------------- + + +@pytest.fixture +def firestore_service(): + """Creates a FirestoreSessionService with a mocked Firestore client.""" + fake_client = _FakeFirestoreClient() + + with mock.patch.dict( + "sys.modules", + { + "google.cloud.firestore_v1": mock.MagicMock( + AsyncClient=lambda **kwargs: fake_client + ), + }, + ): + from google.adk.sessions.firestore_session_service import FirestoreSessionService + + service = FirestoreSessionService(project="test-project") + # Replace the client with our fake + service._db = fake_client + return service + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +APP = "test_app" +USER = "test_user" + + +@pytest.mark.asyncio +async def test_create_session(firestore_service): + session = await firestore_service.create_session(app_name=APP, user_id=USER) + assert session.app_name == APP + assert session.user_id == USER + assert session.id # auto-generated UUID + assert session.last_update_time > 0 + + +@pytest.mark.asyncio +async def test_create_session_with_custom_id(firestore_service): + session = await firestore_service.create_session( + app_name=APP, user_id=USER, session_id="custom-123" + ) + assert session.id == "custom-123" + + +@pytest.mark.asyncio +async def test_create_session_duplicate_id_raises(firestore_service): + await firestore_service.create_session( + app_name=APP, user_id=USER, session_id="dup-id" + ) + with pytest.raises(AlreadyExistsError): + await firestore_service.create_session( + app_name=APP, user_id=USER, session_id="dup-id" + ) + + +@pytest.mark.asyncio +async def test_create_session_with_state(firestore_service): + state = { + "app:theme": "dark", + "user:lang": "en", + "counter": 0, + } + session = await firestore_service.create_session( + app_name=APP, user_id=USER, state=state + ) + assert session.state.get("app:theme") == "dark" + assert session.state.get("user:lang") == "en" + assert session.state.get("counter") == 0 + + +@pytest.mark.asyncio +async def test_get_session(firestore_service): + created = await firestore_service.create_session( + app_name=APP, user_id=USER, session_id="s1" + ) + retrieved = await firestore_service.get_session( + app_name=APP, user_id=USER, session_id="s1" + ) + assert retrieved is not None + assert retrieved.id == "s1" + assert retrieved.app_name == APP + + +@pytest.mark.asyncio +async def test_get_session_nonexistent_returns_none(firestore_service): + result = await firestore_service.get_session( + app_name=APP, user_id=USER, session_id="nonexistent" + ) + assert result is None + + +@pytest.mark.asyncio +async def test_get_session_wrong_user_returns_none(firestore_service): + await firestore_service.create_session( + app_name=APP, user_id=USER, session_id="s1" + ) + result = await firestore_service.get_session( + app_name=APP, user_id="other_user", session_id="s1" + ) + assert result is None + + +@pytest.mark.asyncio +async def test_list_sessions(firestore_service): + await firestore_service.create_session( + app_name=APP, user_id=USER, session_id="s1" + ) + await firestore_service.create_session( + app_name=APP, user_id=USER, session_id="s2" + ) + await firestore_service.create_session( + app_name=APP, user_id="other", session_id="s3" + ) + + # List for specific user + response = await firestore_service.list_sessions(app_name=APP, user_id=USER) + assert len(response.sessions) == 2 + + # List all sessions for app + response_all = await firestore_service.list_sessions(app_name=APP) + assert len(response_all.sessions) == 3 + + +@pytest.mark.asyncio +async def test_delete_session(firestore_service): + await firestore_service.create_session( + app_name=APP, user_id=USER, session_id="s1" + ) + await firestore_service.delete_session( + app_name=APP, user_id=USER, session_id="s1" + ) + result = await firestore_service.get_session( + app_name=APP, user_id=USER, session_id="s1" + ) + assert result is None + + +@pytest.mark.asyncio +async def test_delete_nonexistent_session_is_noop(firestore_service): + # Should not raise + await firestore_service.delete_session( + app_name=APP, user_id=USER, session_id="nonexistent" + ) + + +@pytest.mark.asyncio +async def test_append_event(firestore_service): + session = await firestore_service.create_session( + app_name=APP, user_id=USER, session_id="s1" + ) + event = Event( + invocation_id="inv-1", + author="user", + content=types.Content(role="user", parts=[types.Part(text="Hello")]), + ) + result = await firestore_service.append_event(session, event) + assert result.id == event.id + + # Verify event is retrievable + retrieved = await firestore_service.get_session( + app_name=APP, user_id=USER, session_id="s1" + ) + assert retrieved is not None + assert len(retrieved.events) == 1 + + +@pytest.mark.asyncio +async def test_append_event_with_state_delta(firestore_service): + session = await firestore_service.create_session( + app_name=APP, user_id=USER, session_id="s1" + ) + event = Event( + invocation_id="inv-1", + author="agent", + actions=EventActions( + state_delta={"counter": 42, "app:global_key": "val"} + ), + ) + await firestore_service.append_event(session, event) + + retrieved = await firestore_service.get_session( + app_name=APP, user_id=USER, session_id="s1" + ) + assert retrieved is not None + assert retrieved.state.get("counter") == 42 + assert retrieved.state.get("app:global_key") == "val" + + +@pytest.mark.asyncio +async def test_append_partial_event_skipped(firestore_service): + session = await firestore_service.create_session( + app_name=APP, user_id=USER, session_id="s1" + ) + event = Event( + invocation_id="inv-1", + author="agent", + partial=True, + content=types.Content(role="model", parts=[types.Part(text="partial")]), + ) + result = await firestore_service.append_event(session, event) + assert result.partial is True + + retrieved = await firestore_service.get_session( + app_name=APP, user_id=USER, session_id="s1" + ) + assert retrieved is not None + assert len(retrieved.events) == 0 + + +@pytest.mark.asyncio +async def test_get_session_with_num_recent_events(firestore_service): + session = await firestore_service.create_session( + app_name=APP, user_id=USER, session_id="s1" + ) + for i in range(5): + event = Event( + invocation_id=f"inv-{i}", + author="user", + content=types.Content(role="user", parts=[types.Part(text=f"msg-{i}")]), + ) + await firestore_service.append_event(session, event) + + config = GetSessionConfig(num_recent_events=2) + retrieved = await firestore_service.get_session( + app_name=APP, user_id=USER, session_id="s1", config=config + ) + assert retrieved is not None + assert len(retrieved.events) == 2 + + +@pytest.mark.asyncio +async def test_app_state_shared_across_sessions(firestore_service): + state = {"app:shared": "value"} + await firestore_service.create_session( + app_name=APP, user_id=USER, session_id="s1", state=state + ) + s2 = await firestore_service.create_session( + app_name=APP, user_id=USER, session_id="s2" + ) + assert s2.state.get("app:shared") == "value" + + +@pytest.mark.asyncio +async def test_user_state_shared_across_sessions(firestore_service): + state = {"user:pref": "compact"} + await firestore_service.create_session( + app_name=APP, user_id=USER, session_id="s1", state=state + ) + s2 = await firestore_service.create_session( + app_name=APP, user_id=USER, session_id="s2" + ) + assert s2.state.get("user:pref") == "compact" + + +@pytest.mark.asyncio +async def test_temp_state_not_persisted(firestore_service): + session = await firestore_service.create_session( + app_name=APP, user_id=USER, session_id="s1" + ) + event = Event( + invocation_id="inv-1", + author="agent", + actions=EventActions( + state_delta={"temp:scratch": "tmp", "keep_me": "yes"} + ), + ) + await firestore_service.append_event(session, event) + + retrieved = await firestore_service.get_session( + app_name=APP, user_id=USER, session_id="s1" + ) + assert retrieved is not None + assert "temp:scratch" not in retrieved.state + assert retrieved.state.get("keep_me") == "yes"