28 changes: 23 additions & 5 deletions src/google/adk/cli/cli_eval.py
@@ -14,6 +14,7 @@

from __future__ import annotations

import asyncio
import importlib.util
import logging
import os
@@ -24,7 +25,7 @@
import click
from google.genai import types as genai_types

from ..agents.llm_agent import Agent
from ..agents.base_agent import BaseAgent
from ..evaluation.base_eval_service import BaseEvalService
from ..evaluation.base_eval_service import EvaluateConfig
from ..evaluation.base_eval_service import EvaluateRequest
@@ -86,11 +87,28 @@ def get_default_metric_info(
)


def get_root_agent(agent_module_file_path: str) -> Agent:
"""Returns root agent given the agent module."""
def get_root_agent(agent_module_file_path: str) -> BaseAgent:
"""Returns root agent given the agent module.

Supports modules exporting either `root_agent` or `get_agent_async`.
"""
agent_module = _get_agent_module(agent_module_file_path)
root_agent = agent_module.agent.root_agent
return root_agent
agent_module_with_agent = (
agent_module.agent if hasattr(agent_module, "agent") else agent_module
)
if hasattr(agent_module_with_agent, "root_agent"):
return agent_module_with_agent.root_agent

if hasattr(agent_module_with_agent, "get_agent_async"):
result = asyncio.run(agent_module_with_agent.get_agent_async())
if isinstance(result, tuple):
root_agent, _ = result
return root_agent
return result

raise ValueError(
"Module does not have a root_agent or get_agent_async method."
)
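
For context, a minimal sketch of an agent module that the updated loader would now accept via the `get_agent_async` path (the module path, agent name, and model string are illustrative assumptions, not taken from this PR):

# my_agent/agent.py -- hypothetical module passed to get_root_agent
from google.adk.agents.llm_agent import Agent

async def get_agent_async():
  # Any async setup (loading tools, opening connections) would go here.
  root_agent = Agent(name="my_agent", model="gemini-2.0-flash")
  # Returning either the agent or an (agent, exit_stack) tuple is handled above.
  return root_agent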


def try_get_reset_func(agent_module_file_path: str) -> Any:
62 changes: 52 additions & 10 deletions src/google/adk/cli/cli_tools_click.py
@@ -709,6 +709,13 @@ def wrapper(*args, **kwargs):
)
@click.argument("eval_set_file_path_or_id", nargs=-1)
@click.option("--config_file_path", help="Optional. The path to config file.")
@click.option(
"--num_runs",
type=click.IntRange(min=1),
default=1,
show_default=True,
help="Optional. Number of times to run each eval case.",
)
@click.option(
"--print_detailed_results",
is_flag=True,
@@ -721,6 +728,7 @@ def cli_eval(
agent_module_file_path: str,
eval_set_file_path_or_id: list[str],
config_file_path: str,
num_runs: int,
print_detailed_results: bool,
eval_storage_uri: Optional[str] = None,
log_level: str = "INFO",
@@ -789,6 +797,7 @@ def cli_eval(
from ..evaluation.base_eval_service import InferenceRequest
from ..evaluation.custom_metric_evaluator import _CustomMetricEvaluator
from ..evaluation.eval_config import get_eval_metrics_from_config
from ..evaluation.eval_config import discover_eval_config_for_test_file
from ..evaluation.eval_config import get_evaluation_criteria_or_default
from ..evaluation.eval_result import EvalCaseResult
from ..evaluation.evaluator import EvalStatus
@@ -808,9 +817,12 @@
except ModuleNotFoundError as mnf:
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE) from mnf

eval_config = get_evaluation_criteria_or_default(config_file_path)
print(f"Using evaluation criteria: {eval_config}")
eval_metrics = get_eval_metrics_from_config(eval_config)
eval_metrics_by_eval_set_id = {}
global_eval_metrics = None
if config_file_path:
eval_config = get_evaluation_criteria_or_default(config_file_path)
print(f"Using evaluation criteria: {eval_config}")
global_eval_metrics = get_eval_metrics_from_config(eval_config)
Comment on lines +820 to +825
Contributor

critical

There's a potential UnboundLocalError for `eval_config`. It is used on line 922, but it's only defined within this `if config_file_path:` block, or later when handling eval set IDs. If `config_file_path` is not provided and the code proceeds to handle eval set file paths, `eval_config` will not be defined when it's needed for `UserSimulatorProvider`.

To fix this, initialize `eval_config` unconditionally at the beginning of the function. Applying this fix also allows you to simplify the code in a couple of other places:

  1. On lines 869-873, you can reuse the `eval_config` variable instead of calling `get_evaluation_criteria_or_default` again.
  2. On lines 900-903, you can remove the redundant call to `get_evaluation_criteria_or_default`.
Suggested change
- eval_metrics_by_eval_set_id = {}
- global_eval_metrics = None
- if config_file_path:
-   eval_config = get_evaluation_criteria_or_default(config_file_path)
-   print(f"Using evaluation criteria: {eval_config}")
-   global_eval_metrics = get_eval_metrics_from_config(eval_config)
+ eval_config = get_evaluation_criteria_or_default(config_file_path)
+ eval_metrics_by_eval_set_id = {}
+ global_eval_metrics = None
+ if config_file_path:
+   print(f"Using evaluation criteria: {eval_config}")
+   global_eval_metrics = get_eval_metrics_from_config(eval_config)
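
To make the failure mode concrete, a stripped-down sketch of the pattern being flagged (names are placeholders, not code from this PR):

def cli_eval_sketch(config_file_path=None):
  if config_file_path:
    eval_config = {"user_simulator_config": None}  # only bound on this branch
  # ... later, unconditionally referenced:
  return eval_config  # raises UnboundLocalError when config_file_path is falsy

Initializing `eval_config` before the branch, as in the suggestion above, removes that path.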


root_agent = get_root_agent(agent_module_file_path)
app_name = os.path.basename(agent_module_file_path)
@@ -854,6 +866,18 @@ def cli_eval(
f"`{eval_set_file_path}` should be a valid eval set file."
) from fne

eval_config_for_eval_set = (
get_evaluation_criteria_or_default(config_file_path)
if config_file_path
else discover_eval_config_for_test_file(eval_set_file_path)
)
print(
f"Using evaluation criteria for {eval_set_file_path}:"
f" {eval_config_for_eval_set}"
)
eval_metrics_by_eval_set_id[eval_set.eval_set_id] = (
get_eval_metrics_from_config(eval_config_for_eval_set)
)
eval_sets_manager.create_eval_set(
app_name=app_name, eval_set_id=eval_set.eval_set_id
)
@@ -873,6 +897,10 @@
)
else:
# We assume that what we have are eval set ids instead.
if global_eval_metrics is None:
eval_config = get_evaluation_criteria_or_default(config_file_path)
print(f"Using evaluation criteria: {eval_config}")
global_eval_metrics = get_eval_metrics_from_config(eval_config)
eval_sets_manager = (
eval_sets_manager
if eval_storage_uri
@@ -888,6 +916,7 @@
inference_config=InferenceConfig(),
)
)
eval_metrics_by_eval_set_id[eval_set_id_key] = global_eval_metrics

user_simulator_provider = UserSimulatorProvider(
user_simulator_config=eval_config.user_simulator_config
@@ -920,18 +949,31 @@
metric_evaluator_registry=metric_evaluator_registry,
)

repeated_inference_requests = inference_requests * num_runs
inference_results = asyncio.run(
_collect_inferences(
inference_requests=inference_requests, eval_service=eval_service
inference_requests=repeated_inference_requests,
)
)
eval_results = asyncio.run(
_collect_eval_results(
inference_results=inference_results,
eval_service=eval_service,
eval_metrics=eval_metrics,
)
)
eval_results = []
for eval_set_id, eval_metrics in eval_metrics_by_eval_set_id.items():
inference_results_for_eval_set = [
inference_result
for inference_result in inference_results
if inference_result.eval_set_id == eval_set_id
]
if not inference_results_for_eval_set:
continue
eval_results.extend(
asyncio.run(
_collect_eval_results(
inference_results=inference_results_for_eval_set,
eval_service=eval_service,
eval_metrics=eval_metrics,
)
)
)
Comment on lines +959 to +976
Contributor

medium

The current implementation filters `inference_results` inside the loop for each `eval_set_id`. This can be inefficient if you have a large number of eval sets and inference results, as it iterates over all results for each set (O(num_eval_sets * num_inference_results)).

You can improve performance by grouping the inference results by `eval_set_id` once before the loop.

Suggested change
- eval_results = []
- for eval_set_id, eval_metrics in eval_metrics_by_eval_set_id.items():
-   inference_results_for_eval_set = [
-       inference_result
-       for inference_result in inference_results
-       if inference_result.eval_set_id == eval_set_id
-   ]
-   if not inference_results_for_eval_set:
-     continue
-   eval_results.extend(
-       asyncio.run(
-           _collect_eval_results(
-               inference_results=inference_results_for_eval_set,
-               eval_service=eval_service,
-               eval_metrics=eval_metrics,
-           )
-       )
-   )
+ inference_results_by_eval_set_id = {}
+ for res in inference_results:
+   inference_results_by_eval_set_id.setdefault(res.eval_set_id, []).append(res)
+ eval_results = []
+ for eval_set_id, eval_metrics in eval_metrics_by_eval_set_id.items():
+   inference_results_for_eval_set = inference_results_by_eval_set_id.get(eval_set_id)
+   if not inference_results_for_eval_set:
+     continue
+   eval_results.extend(
+       asyncio.run(
+           _collect_eval_results(
+               inference_results=inference_results_for_eval_set,
+               eval_service=eval_service,
+               eval_metrics=eval_metrics,
+           )
+       )
+   )

except ModuleNotFoundError as mnf:
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE) from mnf

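With the new option in place, each eval case in the supplied eval sets is run `--num_runs` times during inference. A hedged invocation example (the agent directory and eval set file are placeholder paths; the argument order follows the `agent_module_file_path` and `eval_set_file_path_or_id` arguments shown above):

adk eval path/to/my_agent path/to/my_agent/my_evalset.evalset.json --num_runs 3 --print_detailed_results
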
16 changes: 13 additions & 3 deletions src/google/adk/evaluation/agent_evaluator.py
@@ -38,6 +38,7 @@
from .eval_case import IntermediateDataType
from .eval_case import Invocation
from .eval_config import EvalConfig
from .eval_config import discover_eval_config_for_test_file
from .eval_config import get_eval_metrics_from_config
from .eval_config import get_evaluation_criteria_or_default
from .eval_metrics import BaseCriterion
@@ -46,6 +47,7 @@
from .eval_metrics import PrebuiltMetrics
from .eval_result import EvalCaseResult
from .eval_set import EvalSet
from .eval_set_results_manager import EvalSetResultsManager
from .eval_sets_manager import EvalSetsManager
from .evaluator import EvalStatus
from .in_memory_eval_sets_manager import InMemoryEvalSetsManager
@@ -100,9 +102,7 @@ class AgentEvaluator:
@staticmethod
def find_config_for_test_file(test_file: str) -> EvalConfig:
"""Find the test_config.json file in the same folder as the test file."""
test_folder = os.path.dirname(test_file)
config_path = os.path.join(test_folder, "test_config.json")
return get_evaluation_criteria_or_default(config_path)
return discover_eval_config_for_test_file(test_file)

@staticmethod
async def evaluate_eval_set(
@@ -113,6 +113,7 @@ async def evaluate_eval_set(
num_runs: int = NUM_RUNS,
agent_name: Optional[str] = None,
print_detailed_results: bool = True,
eval_set_results_manager: Optional[EvalSetResultsManager] = None,
):
"""Evaluates an agent using the given EvalSet.

@@ -130,6 +131,8 @@ async def evaluate_eval_set(
than root agent. If left empty or none, then root agent is evaluated.
print_detailed_results: Whether to print detailed results for each metric
evaluation.
eval_set_results_manager: Optional results manager for persisting eval
outputs.
"""
if criteria:
logger.warning(
@@ -161,6 +164,7 @@ async def evaluate_eval_set(
eval_metrics=eval_metrics,
num_runs=num_runs,
user_simulator_provider=user_simulator_provider,
eval_set_results_manager=eval_set_results_manager,
)

# Step 2: Post-process the results!
@@ -200,6 +204,7 @@ async def evaluate(
agent_name: Optional[str] = None,
initial_session_file: Optional[str] = None,
print_detailed_results: bool = True,
eval_set_results_manager: Optional[EvalSetResultsManager] = None,
):
"""Evaluates an Agent given eval data.

@@ -218,6 +223,8 @@
needed by all the evals in the eval dataset.
print_detailed_results: Whether to print detailed results for each metric
evaluation.
eval_set_results_manager: Optional results manager for persisting eval
outputs.
"""
test_files = []
if isinstance(eval_dataset_file_path_or_dir, str) and os.path.isdir(
@@ -245,6 +252,7 @@
num_runs=num_runs,
agent_name=agent_name,
print_detailed_results=print_detailed_results,
eval_set_results_manager=eval_set_results_manager,
)

@staticmethod
@@ -536,6 +544,7 @@ async def _get_eval_results_by_eval_id(
eval_metrics: list[EvalMetric],
num_runs: int,
user_simulator_provider: UserSimulatorProvider,
eval_set_results_manager: Optional[EvalSetResultsManager] = None,
) -> dict[str, list[EvalCaseResult]]:
"""Returns EvalCaseResults grouped by eval case id.

@@ -560,6 +569,7 @@
app_name=app_name, eval_set=eval_set
),
user_simulator_provider=user_simulator_provider,
eval_set_results_manager=eval_set_results_manager,
)

inference_requests = [
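A hedged sketch of how the new eval_set_results_manager parameter might be wired up, assuming a concrete EvalSetResultsManager implementation such as LocalEvalSetResultsManager is available in your install (the class name, its constructor argument, and all paths are assumptions; substitute whichever implementation your setup provides):

import asyncio

from google.adk.evaluation.agent_evaluator import AgentEvaluator
from google.adk.evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager

async def main():
  # Persist per-run eval results instead of discarding them after printing.
  results_manager = LocalEvalSetResultsManager(agents_dir=".")  # assumed constructor
  await AgentEvaluator.evaluate(
      agent_module="my_agent",  # illustrative agent module name
      eval_dataset_file_path_or_dir="my_agent/evals/",
      num_runs=2,
      eval_set_results_manager=results_manager,
  )

asyncio.run(main())
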
11 changes: 11 additions & 0 deletions src/google/adk/evaluation/eval_config.py
@@ -180,6 +180,17 @@ def get_evaluation_criteria_or_default(
return _DEFAULT_EVAL_CONFIG


def discover_eval_config_for_test_file(test_file_path: str) -> EvalConfig:
"""Returns EvalConfig for a test file via adjacent test_config.json lookup.

The lookup checks for a `test_config.json` in the same directory as the test
file, and falls back to the default criteria if not found.
"""
test_folder = os.path.dirname(test_file_path)
config_path = os.path.join(test_folder, "test_config.json")
return get_evaluation_criteria_or_default(config_path)
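
A small usage sketch of the lookup described above (directory layout and file names are illustrative):

# Given a layout like:
#   my_agent/evals/weather.test.json
#   my_agent/evals/test_config.json   <- picked up automatically
# the adjacent config is used; if it is absent, the default criteria apply.
eval_config = discover_eval_config_for_test_file("my_agent/evals/weather.test.json")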


def get_eval_metrics_from_config(eval_config: EvalConfig) -> list[EvalMetric]:
"""Returns a list of EvalMetrics mapped from the EvalConfig."""
eval_metric_list = []
Expand Down