diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py index 55ffc36ea5b0c7..512a956ff31b64 100644 --- a/Lib/collections/__init__.py +++ b/Lib/collections/__init__.py @@ -328,14 +328,14 @@ def __ior__(self, other): return self def __or__(self, other): - if not isinstance(other, dict): + if not isinstance(other, (dict, frozendict)): return NotImplemented new = self.__class__(self) new.update(other) return new def __ror__(self, other): - if not isinstance(other, dict): + if not isinstance(other, (dict, frozendict)): return NotImplemented new = self.__class__(other) new.update(self) @@ -1221,14 +1221,14 @@ def __repr__(self): def __or__(self, other): if isinstance(other, UserDict): return self.__class__(self.data | other.data) - if isinstance(other, dict): + if isinstance(other, (dict, frozendict)): return self.__class__(self.data | other) return NotImplemented def __ror__(self, other): if isinstance(other, UserDict): return self.__class__(other.data | self.data) - if isinstance(other, dict): + if isinstance(other, (dict, frozendict)): return self.__class__(other | self.data) return NotImplemented diff --git a/Lib/test/test_ordered_dict.py b/Lib/test/test_ordered_dict.py index 4204a6a47d2a81..b5e6adf0db1660 100644 --- a/Lib/test/test_ordered_dict.py +++ b/Lib/test/test_ordered_dict.py @@ -698,6 +698,7 @@ def test_merge_operator(self): d |= list(b.items()) expected = OrderedDict({0: 0, 1: 1, 2: 2, 3: 3}) self.assertEqual(a | dict(b), expected) + self.assertEqual(a | frozendict(b), expected) self.assertEqual(a | b, expected) self.assertEqual(c, expected) self.assertEqual(d, expected) @@ -706,6 +707,7 @@ def test_merge_operator(self): c |= a expected = OrderedDict({1: 1, 2: 1, 3: 3, 0: 0}) self.assertEqual(dict(b) | a, expected) + self.assertEqual(frozendict(b) | a, expected) self.assertEqual(b | a, expected) self.assertEqual(c, expected) diff --git a/Objects/odictobject.c b/Objects/odictobject.c index 25928028919c9c..238a4845a8924e 100644 --- a/Objects/odictobject.c +++ b/Objects/odictobject.c @@ -2271,7 +2271,7 @@ static int mutablemapping_update_arg(PyObject *self, PyObject *arg) { int res = 0; - if (PyDict_CheckExact(arg)) { + if (PyAnyDict_CheckExact(arg)) { PyObject *items = PyDict_Items(arg); if (items == NULL) { return -1;