From e80a2bb062a1c08404dab04388843a0cea2372a8 Mon Sep 17 00:00:00 2001 From: p1c2u Date: Tue, 24 Feb 2026 00:18:49 +0000 Subject: [PATCH] Fix malformed schema traversal to report validation errors instead of internal exceptions --- openapi_spec_validator/validation/keywords.py | 53 +++++++++-- tests/bench/runner.py | 91 +++++++++++++------ tests/integration/test_main.py | 28 ++++++ .../integration/validation/test_exceptions.py | 64 +++++++++++++ 4 files changed, 198 insertions(+), 38 deletions(-) diff --git a/openapi_spec_validator/validation/keywords.py b/openapi_spec_validator/validation/keywords.py index 2d1561e..6c4b751 100644 --- a/openapi_spec_validator/validation/keywords.py +++ b/openapi_spec_validator/validation/keywords.py @@ -1,12 +1,14 @@ import string -from collections.abc import Iterator from collections.abc import Callable +from collections.abc import Iterator +from collections.abc import Mapping from collections.abc import Sequence from typing import TYPE_CHECKING from typing import Any from typing import cast from jsonschema._format import FormatChecker +from jsonschema.exceptions import SchemaError from jsonschema.exceptions import ValidationError from jsonschema.protocols import Validator from jsonschema_path.paths import SchemaPath @@ -19,6 +21,7 @@ DuplicateOperationIDError, ) from openapi_spec_validator.validation.exceptions import ExtraParametersError +from openapi_spec_validator.validation.exceptions import OpenAPIValidationError from openapi_spec_validator.validation.exceptions import ( ParameterDuplicateError, ) @@ -67,7 +70,11 @@ class SchemaValidator(KeywordValidator): def __init__(self, registry: "KeywordValidatorRegistry"): super().__init__(registry) - self.schema_ids_registry: list[int] | None = [] + # recursion/visit dedupe registry + self.visited_schema_ids: list[int] | None = [] + # meta-schema-check dedupe registry + # to avoid validating the same schema multiple times + self.meta_checked_schema_ids: list[int] | None = [] @property def default_validator(self) -> ValueValidator: @@ -95,23 +102,48 @@ def _collect_properties(self, schema: SchemaPath) -> set[str]: return props def __call__( - self, schema: SchemaPath, require_properties: bool = True + self, + schema: SchemaPath, + require_properties: bool = True, + meta_checked: bool = False, ) -> Iterator[ValidationError]: schema_value = schema.read_value() - if not hasattr(schema_value, "__getitem__"): + if not isinstance(schema_value, (Mapping, bool)): + yield OpenAPIValidationError( + f"{schema_value!r} is not of type 'object', 'boolean'" + ) return - assert self.schema_ids_registry is not None + if not meta_checked: + assert self.meta_checked_schema_ids is not None + schema_id = id(schema_value) + if schema_id not in self.meta_checked_schema_ids: + try: + schema_check = getattr( + self.default_validator.value_validator_cls, + "check_schema", + ) + schema_check(schema_value) + except (SchemaError, ValidationError) as err: + yield OpenAPIValidationError.create_from(err) + return + self.meta_checked_schema_ids.append(schema_id) + + assert self.visited_schema_ids is not None schema_id = id(schema_value) - if schema_id in self.schema_ids_registry: + if schema_id in self.visited_schema_ids: return - self.schema_ids_registry.append(schema_id) + self.visited_schema_ids.append(schema_id) nested_properties = [] if "allOf" in schema: all_of = schema / "allOf" for inner_schema in all_of: - yield from self(inner_schema, require_properties=False) + yield from self( + inner_schema, + require_properties=False, + meta_checked=True, + ) nested_properties += list( self._collect_properties(inner_schema) ) @@ -122,6 +154,7 @@ def __call__( yield from self( inner_schema, require_properties=False, + meta_checked=True, ) if "oneOf" in schema: @@ -130,6 +163,7 @@ def __call__( yield from self( inner_schema, require_properties=False, + meta_checked=True, ) if "not" in schema: @@ -137,6 +171,7 @@ def __call__( yield from self( not_schema, require_properties=False, + meta_checked=True, ) if "items" in schema: @@ -144,6 +179,7 @@ def __call__( yield from self( array_schema, require_properties=False, + meta_checked=True, ) if "properties" in schema: @@ -152,6 +188,7 @@ def __call__( yield from self( prop_schema, require_properties=False, + meta_checked=True, ) required = ( diff --git a/tests/bench/runner.py b/tests/bench/runner.py index d93f375..d059519 100644 --- a/tests/bench/runner.py +++ b/tests/bench/runner.py @@ -10,23 +10,23 @@ import argparse import cProfile import gc -from io import StringIO import json import pstats import statistics import time +from collections.abc import Iterator from dataclasses import dataclass from functools import cached_property +from io import StringIO from pathlib import Path from typing import Any -from collections.abc import Iterator from jsonschema_path import SchemaPath from jsonschema_path.typing import Schema +from openapi_spec_validator import schemas from openapi_spec_validator import validate from openapi_spec_validator.readers import read_from_filename -from openapi_spec_validator import schemas from openapi_spec_validator.shortcuts import get_validator_cls @@ -128,7 +128,10 @@ def benchmark_spec_file( spec_size_kb = spec_path.stat().st_size / 1024 spec, _ = read_from_filename(str(spec_path)) return benchmark_spec( - spec, repeats, warmup, no_gc, + spec, + repeats, + warmup, + no_gc, spec_name=spec_name, spec_size_kb=spec_size_kb, ) @@ -148,15 +151,17 @@ def benchmark_spec( spec_version = get_spec_version(spec) paths_count = count_paths(spec) schemas_count = count_schemas(spec) - print(f"⚔ Benchmarking {spec_name} spec (version {spec_version}, {paths_count} paths, {schemas_count} schemas)...") - + print( + f"⚔ Benchmarking {spec_name} spec (version {spec_version}, {paths_count} paths, {schemas_count} schemas)..." + ) + if no_gc: gc.disable() - + # Warmup for _ in range(warmup): run_once(spec) - + pr: cProfile.Profile | None = None if profile: print("\nšŸ”¬ Profiling mode enabled...") @@ -174,10 +179,10 @@ def benchmark_spec( # Print profile stats s = StringIO() - ps = pstats.Stats(pr, stream=s).sort_stats('cumulative') + ps = pstats.Stats(pr, stream=s).sort_stats("cumulative") ps.print_stats(30) print(s.getvalue()) - + # Save profile data pr.dump_stats(profile) print(f"šŸ’¾ Profile data saved to {profile}") @@ -185,7 +190,7 @@ def benchmark_spec( if no_gc: gc.enable() - + return BenchResult( spec_name=spec_name, spec_version=spec_version, @@ -197,7 +202,7 @@ def benchmark_spec( seconds=seconds, success=True, ) - + except Exception as e: return BenchResult( spec_name=spec_name, @@ -228,14 +233,16 @@ def generate_synthetic_spec( "description": "Success", "content": { "application/json": { - "schema": {"$ref": f"#/components/schemas/Schema{i % schemas}"} + "schema": { + "$ref": f"#/components/schemas/Schema{i % schemas}" + } } - } + }, } } } } - + schemas_obj = {} for i in range(schemas): schemas_obj[f"Schema{i}"] = { @@ -243,15 +250,20 @@ def generate_synthetic_spec( "properties": { "id": {"type": "integer"}, "name": {"type": "string"}, - "nested": {"$ref": f"#/components/schemas/Schema{(i + 1) % schemas}"} - } + "nested": { + "$ref": f"#/components/schemas/Schema{(i + 1) % schemas}" + }, + }, } - + return { "openapi": version, - "info": {"title": f"Synthetic API ({paths} paths, {schemas} schemas)", "version": "1.0.0"}, + "info": { + "title": f"Synthetic API ({paths} paths, {schemas} schemas)", + "version": "1.0.0", + }, "paths": paths_obj, - "components": {"schemas": schemas_obj} + "components": {"schemas": schemas_obj}, } @@ -274,13 +286,28 @@ def get_specs_iterator( def main(): - parser = argparse.ArgumentParser(description="Benchmark openapi-spec-validator") - parser.add_argument("specs", type=Path, nargs='*', help="File(s) with custom specs to benchmark, otherwise use synthetic specs.") - parser.add_argument("--repeats", type=int, default=1, help="Number of benchmark repeats") - parser.add_argument("--warmup", type=int, default=0, help="Number of warmup runs") - parser.add_argument("--no-gc", action="store_true", help="Disable GC during benchmark") + parser = argparse.ArgumentParser( + description="Benchmark openapi-spec-validator" + ) + parser.add_argument( + "specs", + type=Path, + nargs="*", + help="File(s) with custom specs to benchmark, otherwise use synthetic specs.", + ) + parser.add_argument( + "--repeats", type=int, default=1, help="Number of benchmark repeats" + ) + parser.add_argument( + "--warmup", type=int, default=0, help="Number of warmup runs" + ) + parser.add_argument( + "--no-gc", action="store_true", help="Disable GC during benchmark" + ) parser.add_argument("--output", type=str, help="Output JSON file path") - parser.add_argument("--profile", type=str, help="Profile file path (cProfile)") + parser.add_argument( + "--profile", type=str, help="Profile file path (cProfile)" + ) args = parser.parse_args() results: list[dict[str, Any]] = [] @@ -291,7 +318,9 @@ def main(): # Benchmark custom specs if args.specs: - print(f"\nšŸ” Testing with custom specs {[str(spec) for spec in args.specs]}") + print( + f"\nšŸ” Testing with custom specs {[str(spec) for spec in args.specs]}" + ) spec_iterator = get_specs_iterator(args.specs) # Synthetic specs for stress testing @@ -318,7 +347,9 @@ def main(): ) results.append(result.as_dict()) if result.success: - print(f" āœ… {result.median_s:.4f}s, {result.validations_per_sec:.2f} val/s") + print( + f" āœ… {result.median_s:.4f}s, {result.validations_per_sec:.2f} val/s" + ) else: print(f" āŒ Error: {result.error}") @@ -331,10 +362,10 @@ def main(): }, "results": results, } - + print(f"\nšŸ“Š Summary: {len(results)} specs benchmarked") print(json.dumps(output, indent=2)) - + if args.output: with open(args.output, "w") as f: json.dump(output, f, indent=2) diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py index 07fb2db..9125cdd 100644 --- a/tests/integration/test_main.py +++ b/tests/integration/test_main.py @@ -193,6 +193,34 @@ def test_schema_stdin(capsys): assert "stdin: OK\n" in out +def test_malformed_schema_stdin(capsys): + """Malformed schema from STDIN reports validation error.""" + spec_io = StringIO( + """ +openapi: 3.1.0 +info: + version: "1" + title: "Title" +components: + schemas: + Component: + type: object + properties: + name: string +""" + ) + + testargs = ["--schema", "3.1.0", "-"] + with mock.patch("openapi_spec_validator.__main__.sys.stdin", spec_io): + with pytest.raises(SystemExit): + main(testargs) + + out, err = capsys.readouterr() + assert not err + assert "stdin: Validation Error:" in out + assert "stdin: OK" not in out + + def test_version(capsys): """Test --version flag outputs correct version.""" testargs = ["--version"] diff --git a/tests/integration/validation/test_exceptions.py b/tests/integration/validation/test_exceptions.py index efdcf03..2d2aba1 100644 --- a/tests/integration/validation/test_exceptions.py +++ b/tests/integration/validation/test_exceptions.py @@ -1,5 +1,8 @@ +import pytest + from openapi_spec_validator import OpenAPIV2SpecValidator from openapi_spec_validator import OpenAPIV30SpecValidator +from openapi_spec_validator import OpenAPIV31SpecValidator from openapi_spec_validator.validation.exceptions import ( DuplicateOperationIDError, ) @@ -495,3 +498,64 @@ def validate(to_validate) -> bool: assert len(errors_list) == 1 assert errors_list[0].__class__ == OpenAPIValidationError assert errors_list[0].message == ("'invalid' is not a 'custom'") + + def test_malformed_property_schema(self): + spec = { + "openapi": "3.1.0", + "info": { + "title": "Test Api", + "version": "0.0.1", + }, + "components": { + "schemas": { + "Component": { + "type": "object", + "properties": { + "name": "string", + }, + } + }, + }, + } + + errors = OpenAPIV31SpecValidator(spec).iter_errors() + + errors_list = list(errors) + assert len(errors_list) == 1 + assert errors_list[0].__class__ == OpenAPIValidationError + assert ( + "'string' is not of type 'object', 'boolean'" + in errors_list[0].message + ) + + @pytest.mark.parametrize( + "component_schema", + [ + {"allOf": {"type": "string"}}, + {"type": "array", "items": [{"type": "string"}]}, + {"type": 123}, + {"type": "object", "required": "name"}, + {"type": "string", "minLength": "1"}, + {"$ref": 42}, + ], + ) + def test_malformed_schema_examples(self, component_schema): + spec = { + "openapi": "3.1.0", + "info": { + "title": "Test Api", + "version": "0.0.1", + }, + "paths": {}, + "components": { + "schemas": { + "Component": component_schema, + }, + }, + } + + errors = OpenAPIV31SpecValidator(spec).iter_errors() + + errors_list = list(errors) + assert len(errors_list) > 0 + assert errors_list[0].__class__ == OpenAPIValidationError