diff --git a/README.md b/README.md index 69935a6..36f6024 100644 --- a/README.md +++ b/README.md @@ -128,6 +128,15 @@ The needed GitHub app permissions are the following under `Repository permission | `LABELS` | False | "" | A comma separated list of labels that should be added to pull requests opened by dependabot. | | `DEPENDABOT_CONFIG_FILE` | False | "" | Location of the configuration file for `dependabot.yml` configurations. If the file is present locally it takes precedence over the one in the repository. | +#### Rate Limiting + +| field | required | default | description | +| ----------------------------------- | -------- | ------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `RATE_LIMIT_ENABLED` | False | true | If set to true, rate limiting will be enabled to prevent hitting GitHub API rate limits. It is recommended to keep this enabled to avoid workflow failures. | +| `RATE_LIMIT_REQUESTS_PER_SECOND` | False | 2.0 | Maximum number of requests per second to the GitHub API. Adjust this based on your API rate limits. Lower values are more conservative. | +| `RATE_LIMIT_BACKOFF_FACTOR` | False | 2.0 | Exponential backoff multiplier for retries when rate limits are hit. A value of 2.0 means wait times double with each retry (1s, 2s, 4s, etc.). | +| `RATE_LIMIT_MAX_RETRIES` | False | 3 | Maximum number of retry attempts when a rate limit error occurs. Set to 0 to disable retries. | + ### Private repositories configuration Dependabot allows the configuration of [private registries](https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#configuration-options-for-private-registries) for dependabot to use. 
diff --git a/env.py b/env.py index 4c6c0bc..a93682a 100644 --- a/env.py +++ b/env.py @@ -47,6 +47,25 @@ def get_int_env_var(env_var_name: str) -> int | None: return None +def get_float_env_var(env_var_name: str, default: float | None = None) -> float | None: + """Get a float environment variable. + + Args: + env_var_name: The name of the environment variable to retrieve. + default: The default value to return if the environment variable is not set. + + Returns: + The value of the environment variable as a float or the default value. + """ + env_var = os.environ.get(env_var_name) + if env_var is None or not env_var.strip(): + return default + try: + return float(env_var) + except ValueError: + return default + + def parse_repo_specific_exemptions(repo_specific_exemptions_str: str) -> dict: """Parse the REPO_SPECIFIC_EXEMPTIONS environment variable into a dictionary. @@ -126,6 +145,10 @@ def get_env_vars( str | None, list[str], str | None, + bool, + float, + float, + int, ]: """ Get the environment variables for use in the action. @@ -162,6 +185,10 @@ def get_env_vars( team_name (str): The team to search for repositories in labels (list[str]): A list of labels to be added to dependabot configuration dependabot_config_file (str): Dependabot extra configuration file location path + rate_limit_enabled (bool): Whether rate limiting is enabled + rate_limit_requests_per_second (float): Maximum requests per second + rate_limit_backoff_factor (float): Exponential backoff factor for retries + rate_limit_max_retries (int): Maximum number of retry attempts """ if not test: # pragma: no cover @@ -352,6 +379,39 @@ def get_env_vars( f"No dependabot extra configuration found. 
Please create one in {dependabot_config_file}" ) + # Rate limiting configuration + rate_limit_enabled = get_bool_env_var("RATE_LIMIT_ENABLED", default=True) + rate_limit_requests_per_second_value = get_float_env_var( + "RATE_LIMIT_REQUESTS_PER_SECOND", default=2.0 + ) + rate_limit_backoff_factor_value = get_float_env_var( + "RATE_LIMIT_BACKOFF_FACTOR", default=2.0 + ) + rate_limit_max_retries_value = get_int_env_var("RATE_LIMIT_MAX_RETRIES") + + # Ensure non-None values with defaults + rate_limit_requests_per_second = ( + rate_limit_requests_per_second_value + if rate_limit_requests_per_second_value is not None + else 2.0 + ) + rate_limit_backoff_factor = ( + rate_limit_backoff_factor_value + if rate_limit_backoff_factor_value is not None + else 2.0 + ) + rate_limit_max_retries = ( + rate_limit_max_retries_value if rate_limit_max_retries_value is not None else 3 + ) + + # Validate rate limiting parameters + if rate_limit_requests_per_second <= 0: + raise ValueError("RATE_LIMIT_REQUESTS_PER_SECOND must be greater than 0") + if rate_limit_backoff_factor <= 0: + raise ValueError("RATE_LIMIT_BACKOFF_FACTOR must be greater than 0") + if rate_limit_max_retries < 0: + raise ValueError("RATE_LIMIT_MAX_RETRIES must be 0 or greater") + return ( organization, repositories_list, @@ -382,4 +442,8 @@ def get_env_vars( team_name, labels_list, dependabot_config_file, + rate_limit_enabled, + rate_limit_requests_per_second, + rate_limit_backoff_factor, + rate_limit_max_retries, ) diff --git a/evergreen.py b/evergreen.py index accd553..4d8930c 100644 --- a/evergreen.py +++ b/evergreen.py @@ -12,6 +12,7 @@ import ruamel.yaml from dependabot_file import build_dependabot_file from exceptions import OptionalFileNotFoundError, check_optional_file +from rate_limiter import RateLimiter def main(): # pragma: no cover @@ -48,8 +49,20 @@ def main(): # pragma: no cover team_name, labels, dependabot_config_file, + rate_limit_enabled, + rate_limit_requests_per_second, + rate_limit_backoff_factor, 
+ rate_limit_max_retries, ) = env.get_env_vars() + # Initialize rate limiter + rate_limiter = RateLimiter( + requests_per_second=rate_limit_requests_per_second, + enabled=rate_limit_enabled, + backoff_factor=rate_limit_backoff_factor, + max_retries=rate_limit_max_retries, + ) + # Auth to GitHub.com or GHE github_connection = auth.auth_to_github( token, @@ -75,7 +88,9 @@ def main(): # pragma: no cover raise ValueError( "ORGANIZATION environment variable was not set. Please set it" ) - project_global_id = get_global_project_id(ghe, token, organization, project_id) + project_global_id = get_global_project_id( + ghe, token, organization, project_id, rate_limiter + ) # Get the repositories from the organization, team name, or list of repositories repos = get_repos_iterator( @@ -211,9 +226,11 @@ def main(): # pragma: no cover # Get dependabot security updates enabled if possible if enable_security_updates: if not is_dependabot_security_updates_enabled( - ghe, repo.owner, repo.name, token + ghe, repo.owner, repo.name, token, rate_limiter ): - enable_dependabot_security_updates(ghe, repo.owner, repo.name, token) + enable_dependabot_security_updates( + ghe, repo.owner, repo.name, token, rate_limiter + ) if follow_up_type == "issue": skip = check_pending_issues_for_duplicates(title, repo) @@ -225,9 +242,11 @@ def main(): # pragma: no cover summary_content += f"| {repo.full_name} | {'✅' if enable_security_updates else '❌'} | {follow_up_type} | [Link]({issue.html_url}) |\n" if project_global_id: issue_id = get_global_issue_id( - ghe, token, organization, repo.name, issue.number + ghe, token, organization, repo.name, issue.number, rate_limiter + ) + link_item_to_project( + ghe, token, project_global_id, issue_id, rate_limiter ) - link_item_to_project(ghe, token, project_global_id, issue_id) print(f"\tLinked issue to project {project_global_id}") else: # Try to detect if the repo already has an open pull request for dependabot @@ -256,10 +275,15 @@ def main(): # pragma: no cover 
) if project_global_id: pr_id = get_global_pr_id( - ghe, token, organization, repo.name, pull.number + ghe, + token, + organization, + repo.name, + pull.number, + rate_limiter, ) response = link_item_to_project( - ghe, token, project_global_id, pr_id + ghe, token, project_global_id, pr_id, rate_limiter ) if response: print( @@ -283,7 +307,9 @@ def is_repo_created_date_before(repo_created_at: str, created_after_date: str): ) -def is_dependabot_security_updates_enabled(ghe, owner, repo, access_token): +def is_dependabot_security_updates_enabled( + ghe, owner, repo, access_token, rate_limiter +): """ Check if Dependabot security updates are enabled at the /repos/:owner/:repo/automated-security-fixes endpoint using the requests library API: https://docs.github.com/en/rest/repos/repos?apiVersion=2022-11-28#check-if-automated-security-fixes-are-enabled-for-a-repository @@ -295,7 +321,9 @@ def is_dependabot_security_updates_enabled(ghe, owner, repo, access_token): "Accept": "application/vnd.github.london-preview+json", } - response = requests.get(url, headers=headers, timeout=20) + response = rate_limiter.execute_with_backoff( + requests.get, url, headers=headers, timeout=20 + ) if response.status_code == 200: return response.json()["enabled"] return False @@ -325,7 +353,7 @@ def check_existing_config(repo, filename): return None -def enable_dependabot_security_updates(ghe, owner, repo, access_token): +def enable_dependabot_security_updates(ghe, owner, repo, access_token, rate_limiter): """ Enable Dependabot security updates at the /repos/:owner/:repo/automated-security-fixes endpoint using the requests library API: https://docs.github.com/en/rest/repos/repos?apiVersion=2022-11-28#enable-automated-security-fixes @@ -337,7 +365,9 @@ def enable_dependabot_security_updates(ghe, owner, repo, access_token): "Accept": "application/vnd.github.london-preview+json", } - response = requests.put(url, headers=headers, timeout=20) + response = rate_limiter.execute_with_backoff( + 
requests.put, url, headers=headers, timeout=20 + ) if response.status_code == 204: print("\tDependabot security updates enabled successfully.") else: @@ -438,7 +468,7 @@ def commit_changes( return pull -def get_global_project_id(ghe, token, organization, number): +def get_global_project_id(ghe, token, organization, number, rate_limiter): """ Fetches the project ID from GitHub's GraphQL API. API: https://docs.github.com/en/graphql/guides/forming-calls-with-graphql @@ -451,7 +481,9 @@ def get_global_project_id(ghe, token, organization, number): } try: - response = requests.post(url, headers=headers, json=data, timeout=20) + response = rate_limiter.execute_with_backoff( + requests.post, url, headers=headers, json=data, timeout=20 + ) response.raise_for_status() except requests.exceptions.RequestException as e: print(f"Request failed: {e}") @@ -464,7 +496,9 @@ def get_global_project_id(ghe, token, organization, number): return None -def get_global_issue_id(ghe, token, organization, repository, issue_number): +def get_global_issue_id( + ghe, token, organization, repository, issue_number, rate_limiter +): """ Fetches the issue ID from GitHub's GraphQL API API: https://docs.github.com/en/graphql/guides/forming-calls-with-graphql @@ -473,7 +507,7 @@ def get_global_issue_id(ghe, token, organization, repository, issue_number): url = f"{api_endpoint}/graphql" headers = {"Authorization": f"Bearer {token}"} data = { - "query": f""" + "query": f""" query {{ repository(owner: "{organization}", name: "{repository}") {{ issue(number: {issue_number}) {{ @@ -485,7 +519,9 @@ def get_global_issue_id(ghe, token, organization, repository, issue_number): } try: - response = requests.post(url, headers=headers, json=data, timeout=20) + response = rate_limiter.execute_with_backoff( + requests.post, url, headers=headers, json=data, timeout=20 + ) response.raise_for_status() except requests.exceptions.RequestException as e: print(f"Request failed: {e}") @@ -498,7 +534,7 @@ def 
get_global_issue_id(ghe, token, organization, repository, issue_number): return None -def get_global_pr_id(ghe, token, organization, repository, pr_number): +def get_global_pr_id(ghe, token, organization, repository, pr_number, rate_limiter): """ Fetches the pull request ID from GitHub's GraphQL API API: https://docs.github.com/en/graphql/guides/forming-calls-with-graphql @@ -507,7 +543,7 @@ def get_global_pr_id(ghe, token, organization, repository, pr_number): url = f"{api_endpoint}/graphql" headers = {"Authorization": f"Bearer {token}"} data = { - "query": f""" + "query": f""" query {{ repository(owner: "{organization}", name: "{repository}") {{ pullRequest(number: {pr_number}) {{ @@ -519,7 +555,9 @@ def get_global_pr_id(ghe, token, organization, repository, pr_number): } try: - response = requests.post(url, headers=headers, json=data, timeout=20) + response = rate_limiter.execute_with_backoff( + requests.post, url, headers=headers, json=data, timeout=20 + ) response.raise_for_status() except requests.exceptions.RequestException as e: print(f"Request failed: {e}") @@ -532,7 +570,7 @@ def get_global_pr_id(ghe, token, organization, repository, pr_number): return None -def link_item_to_project(ghe, token, project_global_id, item_id): +def link_item_to_project(ghe, token, project_global_id, item_id, rate_limiter): """ Links an item (issue or pull request) to a project in GitHub. 
API: https://docs.github.com/en/graphql/guides/forming-calls-with-graphql @@ -545,7 +583,9 @@ def link_item_to_project(ghe, token, project_global_id, item_id): } try: - response = requests.post(url, headers=headers, json=data, timeout=20) + response = rate_limiter.execute_with_backoff( + requests.post, url, headers=headers, json=data, timeout=20 + ) response.raise_for_status() return response except requests.exceptions.RequestException as e: diff --git a/rate_limiter.py b/rate_limiter.py new file mode 100644 index 0000000..545d93f --- /dev/null +++ b/rate_limiter.py @@ -0,0 +1,136 @@ +""" +Rate limiting module for GitHub API requests with exponential backoff. +""" + +import threading +import time +from typing import Any, Callable + + +class RateLimiter: + """ + Thread-safe rate limiter using token bucket algorithm with exponential backoff. + + Attributes: + requests_per_second (float): Maximum number of requests allowed per second + enabled (bool): Whether rate limiting is enabled + backoff_factor (float): Multiplier for exponential backoff (e.g., 2.0 means double wait time) + max_retries (int): Maximum number of retry attempts on rate limit errors + """ + + # pylint: disable=too-many-instance-attributes + def __init__( + self, + requests_per_second: float = 2.0, + enabled: bool = True, + backoff_factor: float = 2.0, + max_retries: int = 3, + ): + """ + Initialize the rate limiter. 
+ + Args: + requests_per_second: Maximum requests per second (default: 2.0) + enabled: Whether rate limiting is enabled (default: True) + backoff_factor: Exponential backoff multiplier (default: 2.0) + max_retries: Maximum retry attempts (default: 3) + """ + self.requests_per_second = requests_per_second + self.enabled = enabled + self.backoff_factor = backoff_factor + self.max_retries = max_retries + + # Token bucket algorithm state + self._tokens = requests_per_second + self._max_tokens = requests_per_second + self._last_update = time.time() + self._lock = threading.Lock() + + # Minimum interval between requests + self._min_interval = 1.0 / requests_per_second if requests_per_second > 0 else 0 + + def wait_for_token(self) -> None: + """ + Wait until a token is available in the bucket (rate limit allows next request). + This implements the token bucket algorithm for smooth rate limiting. + """ + if not self.enabled: + return + + with self._lock: + now = time.time() + time_passed = now - self._last_update + + # Refill tokens based on time passed + self._tokens = min( + self._max_tokens, self._tokens + time_passed * self.requests_per_second + ) + + # Wait if no tokens available + if self._tokens < 1.0: + sleep_time = (1.0 - self._tokens) / self.requests_per_second + time.sleep(sleep_time) + self._tokens = 0.0 + else: + self._tokens -= 1.0 + + self._last_update = time.time() + + def execute_with_backoff(self, func: Callable, *args: Any, **kwargs: Any) -> Any: + """ + Execute a function with rate limiting and exponential backoff on errors. 
+ + Args: + func: Function to execute + *args: Positional arguments to pass to func + **kwargs: Keyword arguments to pass to func + + Returns: + The return value of func + + Raises: + Exception: Re-raises the last exception if max_retries is exceeded + """ + if not self.enabled: + return func(*args, **kwargs) + + last_exception = None + initial_wait = 1.0 # Initial wait time in seconds + + for attempt in range(self.max_retries + 1): + try: + # Wait for rate limit token before making request + self.wait_for_token() + + # Execute the function + result = func(*args, **kwargs) + + # Check if response indicates rate limiting (if it's a requests.Response) + if hasattr(result, "status_code") and result.status_code == 429: + # Rate limit hit, trigger backoff + raise RateLimitExceeded("GitHub API rate limit exceeded (429)") + + return result + + except RateLimitExceeded as e: + last_exception = e + + if attempt < self.max_retries: + # Calculate exponential backoff wait time + wait_time = initial_wait * (self.backoff_factor**attempt) + print( + f"Rate limit exceeded, waiting {wait_time:.1f}s before retry {attempt + 1}/{self.max_retries}" + ) + time.sleep(wait_time) + else: + print(f"Max retries ({self.max_retries}) exceeded") + raise + + # Should not reach here, but raise last exception if we do + if last_exception: + raise last_exception + raise RuntimeError("Unexpected: no exception to raise") + + +class RateLimitExceeded(Exception): + """Exception raised when rate limit is exceeded.""" diff --git a/test_dependabot_file.py b/test_dependabot_file.py index 61c8901..a54f2ca 100644 --- a/test_dependabot_file.py +++ b/test_dependabot_file.py @@ -45,7 +45,7 @@ def test_build_dependabot_file_with_schedule_day(self): interval: 'weekly' day: 'tuesday' """ - ) +) result = build_dependabot_file( repo, False, [], {}, None, "weekly", "tuesday", [], None ) @@ -122,7 +122,7 @@ def test_build_dependabot_file_with_2_space_indent_existing_config_bundler_with_ schedule: interval: 'weekly' 
""" - ) +) existing_config = MagicMock() existing_config.content = base64.b64encode( b""" @@ -353,7 +353,7 @@ def test_build_dependabot_file_with_composer(self): schedule: interval: 'weekly' """ - ) +) result = build_dependabot_file( repo, False, [], {}, None, "weekly", "", [], None ) @@ -398,7 +398,7 @@ def test_build_dependabot_file_with_nuget(self): schedule: interval: 'weekly' """ - ) +) result = build_dependabot_file( repo, False, [], {}, None, "weekly", "", [], None ) @@ -483,7 +483,7 @@ def test_build_dependabot_file_with_terraform_with_files(self): schedule: interval: 'weekly' """ - ) +) result = build_dependabot_file( repo, False, [], {}, None, "weekly", "", [], None ) @@ -533,7 +533,7 @@ def test_build_dependabot_file_with_devcontainers(self): schedule: interval: 'weekly' """ - ) +) result = build_dependabot_file( repo, False, [], None, None, "weekly", "", [], None ) @@ -695,7 +695,7 @@ def test_build_dependabot_file_for_multiple_repos_with_few_existing_config(self) schedule: interval: 'weekly' """ - ) + ) result = build_dependabot_file( no_existing_config_repo, False, @@ -747,7 +747,7 @@ def test_check_multiple_repos_with_no_dependabot_config(self): schedule: interval: 'weekly' """ - ) +) result = build_dependabot_file( no_existing_config_repo, False, diff --git a/test_env.py b/test_env.py index 1ab77cd..1f07518 100644 --- a/test_env.py +++ b/test_env.py @@ -96,6 +96,10 @@ def test_get_env_vars_with_org(self): None, # team_name [], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) self.assertEqual(result, expected_result) @@ -152,6 +156,10 @@ def test_get_env_vars_with_org_and_repo_specific_exemptions(self): None, # team_name [], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) 
self.assertEqual(result, expected_result) @@ -265,6 +273,10 @@ def test_get_env_vars_with_repos(self): None, # team_name [], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) self.assertEqual(result, expected_result) @@ -323,6 +335,10 @@ def test_get_env_vars_with_team(self): "engineering", # team_name [], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) self.assertEqual(result, expected_result) @@ -398,6 +414,10 @@ def test_get_env_vars_optional_values(self): None, # team_name [], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) self.assertEqual(result, expected_result) @@ -445,6 +465,10 @@ def test_get_env_vars_with_update_existing(self): None, # team_name [], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) self.assertEqual(result, expected_result) @@ -506,6 +530,10 @@ def test_get_env_vars_auth_with_github_app_installation(self): None, # team_name [], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) self.assertEqual(result, expected_result) @@ -596,6 +624,10 @@ def test_get_env_vars_with_repos_no_dry_run(self): None, # team_name [], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) self.assertEqual(result, expected_result) @@ -643,6 +675,10 @@ def 
test_get_env_vars_with_repos_disabled_security_updates(self): None, # team_name [], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) self.assertEqual(result, expected_result) @@ -691,6 +727,10 @@ def test_get_env_vars_with_repos_filter_visibility_multiple_values(self): None, # team_name [], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) self.assertEqual(result, expected_result) @@ -739,6 +779,10 @@ def test_get_env_vars_with_repos_filter_visibility_single_value(self): None, # team_name [], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) self.assertEqual(result, expected_result) @@ -817,6 +861,10 @@ def test_get_env_vars_with_repos_filter_visibility_no_duplicates(self): None, # team_name [], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) self.assertEqual(result, expected_result) @@ -866,6 +914,10 @@ def test_get_env_vars_with_repos_exempt_ecosystems(self): None, # team_name [], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) self.assertEqual(result, expected_result) @@ -914,6 +966,10 @@ def test_get_env_vars_with_no_batch_size(self): None, # team_name [], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) self.assertEqual(result, expected_result) @@ -963,6 +1019,10 @@ def 
test_get_env_vars_with_batch_size(self): None, # team_name [], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) self.assertEqual(result, expected_result) @@ -1101,6 +1161,10 @@ def test_get_env_vars_with_valid_schedule_and_schedule_day(self): None, # team_name [], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) self.assertEqual(result, expected_result) @@ -1187,6 +1251,10 @@ def test_get_env_vars_with_a_valid_label(self): None, # team_name ["dependencies"], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) self.assertEqual(result, expected_result) @@ -1234,6 +1302,10 @@ def test_get_env_vars_with_valid_labels_containing_spaces(self): None, # team_name ["dependencies", "test", "test2"], # labels None, + True, # rate_limit_enabled + 2.0, # rate_limit_requests_per_second + 2.0, # rate_limit_backoff_factor + 3, # rate_limit_max_retries ) result = get_env_vars(True) self.assertEqual(result, expected_result) diff --git a/test_evergreen.py b/test_evergreen.py index b0a14a2..2bfdc76 100644 --- a/test_evergreen.py +++ b/test_evergreen.py @@ -21,6 +21,13 @@ is_repo_created_date_before, link_item_to_project, ) +from rate_limiter import RateLimiter + + +# Create a disabled rate limiter for tests to avoid adding delays +def get_mock_rate_limiter(): + """Get a disabled rate limiter for testing.""" + return RateLimiter(enabled=False) class TestDependabotSecurityUpdates(unittest.TestCase): @@ -54,7 +61,7 @@ def test_is_dependabot_security_updates_enabled(self): mock_get.return_value.json.return_value = expected_response result = is_dependabot_security_updates_enabled( - ghe, owner, repo, 
access_token + ghe, owner, repo, access_token, get_mock_rate_limiter() ) mock_get.assert_called_once_with( @@ -89,7 +96,7 @@ def test_is_dependabot_security_updates_disabled(self): mock_get.return_value.json.return_value = {"enabled": False} result = is_dependabot_security_updates_enabled( - ghe, owner, repo, access_token + ghe, owner, repo, access_token, get_mock_rate_limiter() ) mock_get.assert_called_once_with( @@ -123,7 +130,7 @@ def test_is_dependabot_security_updates_not_found(self): mock_get.return_value.status_code = 404 result = is_dependabot_security_updates_enabled( - ghe, owner, repo, access_token + ghe, owner, repo, access_token, get_mock_rate_limiter() ) mock_get.assert_called_once_with( @@ -157,7 +164,9 @@ def test_enable_dependabot_security_updates(self): mock_put.return_value.status_code = 204 with patch("builtins.print") as mock_print: - enable_dependabot_security_updates(ghe, owner, repo, access_token) + enable_dependabot_security_updates( + ghe, owner, repo, access_token, get_mock_rate_limiter() + ) mock_put.assert_called_once_with( expected_url, headers=expected_headers, timeout=20 @@ -192,7 +201,9 @@ def test_enable_dependabot_security_updates_failed(self): mock_put.return_value.status_code = 500 with patch("builtins.print") as mock_print: - enable_dependabot_security_updates(ghe, owner, repo, access_token) + enable_dependabot_security_updates( + ghe, owner, repo, access_token, get_mock_rate_limiter() + ) mock_put.assert_called_once_with( expected_url, headers=expected_headers, timeout=20 @@ -452,7 +463,9 @@ def test_get_global_project_id_success(self, mock_post): mock_post.return_value.status_code = 200 mock_post.return_value.json.return_value = expected_response - result = get_global_project_id(ghe, token, organization, number) + result = get_global_project_id( + ghe, token, organization, number, get_mock_rate_limiter() + ) mock_post.assert_called_once_with( expected_url, headers=expected_headers, json=expected_data, timeout=20 @@ -476,7 
+489,9 @@ def test_get_global_project_id_request_failed(self, mock_post): mock_post.side_effect = requests.exceptions.RequestException("Request failed") with patch("builtins.print") as mock_print: - result = get_global_project_id(ghe, token, organization, number) + result = get_global_project_id( + ghe, token, organization, number, get_mock_rate_limiter() + ) mock_post.assert_called_once_with( expected_url, headers=expected_headers, json=expected_data, timeout=20 @@ -503,7 +518,9 @@ def test_get_global_project_id_parse_response_failed(self, mock_post): mock_post.return_value.json.return_value = expected_response with patch("builtins.print") as mock_print: - result = get_global_project_id(ghe, token, organization, number) + result = get_global_project_id( + ghe, token, organization, number, get_mock_rate_limiter() + ) mock_post.assert_called_once_with( expected_url, headers=expected_headers, json=expected_data, timeout=20 @@ -529,7 +546,9 @@ def test_get_global_issue_id_success(self, mock_post): mock_post.return_value.status_code = 200 mock_post.return_value.json.return_value = expected_response - result = get_global_issue_id(ghe, token, organization, repository, issue_number) + result = get_global_issue_id( + ghe, token, organization, repository, issue_number, get_mock_rate_limiter() + ) mock_post.assert_called_once() self.assertEqual(result, "1234567890") @@ -545,7 +564,9 @@ def test_get_global_issue_id_request_failed(self, mock_post): mock_post.side_effect = requests.exceptions.RequestException("Request failed") - result = get_global_issue_id(ghe, token, organization, repository, issue_number) + result = get_global_issue_id( + ghe, token, organization, repository, issue_number, get_mock_rate_limiter() + ) mock_post.assert_called_once() self.assertIsNone(result) @@ -564,7 +585,9 @@ def test_get_global_issue_id_parse_response_failed(self, mock_post): mock_post.return_value.status_code = 200 mock_post.return_value.json.return_value = expected_response - result = 
get_global_issue_id(ghe, token, organization, repository, issue_number) + result = get_global_issue_id( + ghe, token, organization, repository, issue_number, get_mock_rate_limiter() + ) mock_post.assert_called_once() self.assertIsNone(result) @@ -585,7 +608,9 @@ def test_get_global_pr_id_success(self, mock_post): mock_post.return_value = mock_response # Call the function with test data - result = get_global_pr_id("", "test_token", "test_org", "test_repo", 1) + result = get_global_pr_id( + "", "test_token", "test_org", "test_repo", 1, get_mock_rate_limiter() + ) # Check that the result is as expected self.assertEqual(result, "test_id") @@ -597,7 +622,9 @@ def test_get_global_pr_id_request_exception(self, mock_post): mock_post.side_effect = requests.exceptions.RequestException # Call the function with test data - result = get_global_pr_id("", "test_token", "test_org", "test_repo", 1) + result = get_global_pr_id( + "", "test_token", "test_org", "test_repo", 1, get_mock_rate_limiter() + ) # Check that the result is None self.assertIsNone(result) @@ -612,7 +639,9 @@ def test_get_global_pr_id_key_error(self, mock_post): mock_post.return_value = mock_response # Call the function with test data - result = get_global_pr_id("", "test_token", "test_org", "test_repo", 1) + result = get_global_pr_id( + "", "test_token", "test_org", "test_repo", 1, get_mock_rate_limiter() + ) # Check that the result is None self.assertIsNone(result) @@ -639,7 +668,9 @@ def test_link_item_to_project_success(self, mock_post): mock_response.status_code = 200 mock_post.return_value = mock_response - result = link_item_to_project(ghe, token, project_id, item_id) + result = link_item_to_project( + ghe, token, project_id, item_id, get_mock_rate_limiter() + ) mock_post.assert_called_once_with( expected_url, headers=expected_headers, json=expected_data, timeout=20 @@ -666,7 +697,9 @@ def test_link_item_to_project_request_exception(self, mock_post): mock_post.side_effect = 
requests.exceptions.RequestException("Request failed") with patch("builtins.print") as mock_print: - result = link_item_to_project(ghe, token, project_id, item_id) + result = link_item_to_project( + ghe, token, project_id, item_id, get_mock_rate_limiter() + ) mock_post.assert_called_once_with( expected_url, headers=expected_headers, json=expected_data, timeout=20 diff --git a/test_rate_limiter.py b/test_rate_limiter.py new file mode 100644 index 0000000..bc3bae5 --- /dev/null +++ b/test_rate_limiter.py @@ -0,0 +1,216 @@ +"""Tests for the rate_limiter module.""" + +import threading +import time +import unittest +from unittest.mock import MagicMock, patch + +from rate_limiter import RateLimiter, RateLimitExceeded + + +class TestRateLimiter(unittest.TestCase): + """Test the RateLimiter class.""" + + def test_rate_limiter_initialization_default_values(self): + """Test RateLimiter initialization with default values.""" + limiter = RateLimiter() + self.assertEqual(limiter.requests_per_second, 2.0) + self.assertTrue(limiter.enabled) + self.assertEqual(limiter.backoff_factor, 2.0) + self.assertEqual(limiter.max_retries, 3) + + def test_rate_limiter_initialization_custom_values(self): + """Test RateLimiter initialization with custom values.""" + limiter = RateLimiter( + requests_per_second=5.0, enabled=False, backoff_factor=3.0, max_retries=5 + ) + self.assertEqual(limiter.requests_per_second, 5.0) + self.assertFalse(limiter.enabled) + self.assertEqual(limiter.backoff_factor, 3.0) + self.assertEqual(limiter.max_retries, 5) + + def test_rate_limiter_disabled_no_delay(self): + """Test that disabled rate limiter doesn't add delay.""" + limiter = RateLimiter(requests_per_second=1.0, enabled=False) + + start_time = time.time() + for _ in range(5): + limiter.wait_for_token() + elapsed = time.time() - start_time + + # Should complete almost instantly (allowing small overhead) + self.assertLess(elapsed, 0.1) + + def test_rate_limiter_enabled_enforces_delay(self): + """Test that 
enabled rate limiter enforces delay between requests.""" + limiter = RateLimiter(requests_per_second=10.0, enabled=True) + + # Make enough requests to exhaust initial tokens and trigger delays + start_time = time.time() + for _ in range( + 15 + ): # 15 requests should take at least ~0.5 seconds (needs 1.5s of tokens, starts with 1.0s worth) + limiter.wait_for_token() + elapsed = time.time() - start_time + + # Should take at least 0.4 seconds (15 requests - 10 initial tokens = 5 delayed tokens / 10 rps = 0.5s minimum) + self.assertGreaterEqual(elapsed, 0.4) + + def test_execute_with_backoff_success(self): + """Test execute_with_backoff with successful function execution.""" + limiter = RateLimiter(requests_per_second=10.0, enabled=True) + + mock_func = MagicMock(return_value="success") + result = limiter.execute_with_backoff(mock_func, "arg1", kwarg1="value1") + + self.assertEqual(result, "success") + mock_func.assert_called_once_with("arg1", kwarg1="value1") + + def test_execute_with_backoff_disabled(self): + """Test execute_with_backoff when rate limiting is disabled.""" + limiter = RateLimiter(enabled=False) + + mock_func = MagicMock(return_value="success") + result = limiter.execute_with_backoff(mock_func) + + self.assertEqual(result, "success") + mock_func.assert_called_once() + + def test_execute_with_backoff_rate_limit_then_success(self): + """Test execute_with_backoff with rate limit error then success.""" + limiter = RateLimiter( + requests_per_second=10.0, enabled=True, backoff_factor=2.0, max_retries=3 + ) + + # First call raises RateLimitExceeded, second call succeeds + mock_func = MagicMock( + side_effect=[RateLimitExceeded("Rate limit hit"), "success"] + ) + + with patch("time.sleep") as mock_sleep: + result = limiter.execute_with_backoff(mock_func) + + self.assertEqual(result, "success") + self.assertEqual(mock_func.call_count, 2) + # Should have slept once with initial wait time of 1.0 + mock_sleep.assert_called_once() + + def 
test_execute_with_backoff_429_response(self): + """Test execute_with_backoff with 429 status code response.""" + limiter = RateLimiter( + requests_per_second=10.0, enabled=True, backoff_factor=2.0, max_retries=2 + ) + + # Create mock response with 429 status + mock_response_429 = MagicMock() + mock_response_429.status_code = 429 + + mock_response_success = MagicMock() + mock_response_success.status_code = 200 + + mock_func = MagicMock(side_effect=[mock_response_429, mock_response_success]) + + with patch("time.sleep") as mock_sleep: + result = limiter.execute_with_backoff(mock_func) + + self.assertEqual(result, mock_response_success) + self.assertEqual(mock_func.call_count, 2) + mock_sleep.assert_called_once() + + def test_execute_with_backoff_max_retries_exceeded(self): + """Test execute_with_backoff when max retries is exceeded.""" + limiter = RateLimiter( + requests_per_second=10.0, enabled=True, backoff_factor=2.0, max_retries=2 + ) + + # Always raise RateLimitExceeded + mock_func = MagicMock(side_effect=RateLimitExceeded("Rate limit hit")) + + with patch("time.sleep") as mock_sleep: + with self.assertRaises(RateLimitExceeded): + limiter.execute_with_backoff(mock_func) + + # Should have tried max_retries + 1 times (initial + 2 retries = 3 total) + self.assertEqual(mock_func.call_count, 3) + # Should have slept 2 times (once for each retry, not for final failure) + self.assertEqual(mock_sleep.call_count, 2) + + def test_execute_with_backoff_exponential_backoff_timing(self): + """Test that exponential backoff timing increases correctly.""" + limiter = RateLimiter( + requests_per_second=10.0, enabled=True, backoff_factor=2.0, max_retries=3 + ) + + mock_func = MagicMock(side_effect=RateLimitExceeded("Rate limit hit")) + + with patch("time.sleep") as mock_sleep: + try: + limiter.execute_with_backoff(mock_func) + except RateLimitExceeded: + pass + + # Verify exponential backoff: 1.0, 2.0, 4.0 + calls = mock_sleep.call_args_list + self.assertEqual(len(calls), 3) + 
self.assertAlmostEqual(calls[0][0][0], 1.0, places=1) # First retry + self.assertAlmostEqual(calls[1][0][0], 2.0, places=1) # Second retry + self.assertAlmostEqual(calls[2][0][0], 4.0, places=1) # Third retry + + def test_execute_with_backoff_non_rate_limit_exception(self): + """Test that non-rate-limit exceptions are raised immediately.""" + limiter = RateLimiter(requests_per_second=10.0, enabled=True, max_retries=3) + + mock_func = MagicMock(side_effect=ValueError("Some other error")) + + with self.assertRaises(ValueError) as context: + limiter.execute_with_backoff(mock_func) + + self.assertEqual(str(context.exception), "Some other error") + # Should only be called once, no retries for non-rate-limit errors + self.assertEqual(mock_func.call_count, 1) + + def test_wait_for_token_refills_over_time(self): + """Test that tokens refill over time allowing burst requests.""" + limiter = RateLimiter(requests_per_second=5.0, enabled=True) + + # Wait for tokens to refill + time.sleep(0.5) + + # Should be able to make a couple requests quickly + start_time = time.time() + limiter.wait_for_token() + limiter.wait_for_token() + elapsed = time.time() - start_time + + # Should complete quickly due to token refill + self.assertLess(elapsed, 0.5) + + def test_rate_limiter_thread_safety(self): + """Test that rate limiter is thread-safe.""" + limiter = RateLimiter(requests_per_second=10.0, enabled=True) + results = [] + + def make_request(): + limiter.wait_for_token() + results.append(time.time()) + + # Use more threads to exhaust initial tokens + threads = [threading.Thread(target=make_request) for _ in range(15)] + + start_time = time.time() + for thread in threads: + thread.start() + for thread in threads: + thread.join() + elapsed = time.time() - start_time + + # All threads should complete + self.assertEqual(len(results), 15) + + # Should take at least 0.4 seconds for 15 requests at 10 rps + # (15 requests - 10 initial tokens = 5 delayed / 10 rps = 0.5s) + 
self.assertGreaterEqual(elapsed, 0.4) + + +if __name__ == "__main__": + unittest.main()