Skip to content

Commit

Permalink
fix memory leak with iterable validation (#1271)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt authored Apr 22, 2024
1 parent f537a03 commit a99729a
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 2 deletions.
12 changes: 12 additions & 0 deletions src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use serde::{ser::Error, Serialize, Serializer};
use crate::errors::{
py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ToErrorValue, ValError, ValLineError, ValResult,
};
use crate::py_gc::PyGcTraverse;
use crate::tools::{extract_i64, new_py_string, py_err};
use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator};

Expand Down Expand Up @@ -327,6 +328,15 @@ pub enum GenericIterator<'data> {
JsonArray(GenericJsonIterator<'data>),
}

impl PyGcTraverse for GenericIterator<'_> {
fn py_gc_traverse(&self, visit: &pyo3::PyVisit<'_>) -> Result<(), pyo3::PyTraverseError> {
if let Self::PyIterator(iter) = self {
iter.py_gc_traverse(visit)?;
}
Ok(())
}
}

impl GenericIterator<'_> {
pub(crate) fn into_static(self) -> GenericIterator<'static> {
match self {
Expand Down Expand Up @@ -385,6 +395,8 @@ impl GenericPyIterator {
}
}

impl_py_gc_traverse!(GenericPyIterator { obj, iter });

#[derive(Debug, Clone)]
pub struct GenericJsonIterator<'data> {
array: JsonArray<'data>,
Expand Down
9 changes: 8 additions & 1 deletion src/validators/generator.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::fmt;
use std::sync::Arc;

use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::{prelude::*, PyTraverseError, PyVisit};

use crate::errors::{ErrorType, LocItem, ValError, ValResult};
use crate::input::{BorrowInput, GenericIterator, Input};
use crate::py_gc::PyGcTraverse;
use crate::recursion_guard::RecursionState;
use crate::tools::SchemaDict;
use crate::ValidationError;
Expand Down Expand Up @@ -201,6 +202,12 @@ impl ValidatorIterator {
fn __str__(&self) -> String {
self.__repr__()
}

fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
self.iterator.py_gc_traverse(&visit)?;
self.validator.py_gc_traverse(&visit)?;
Ok(())
}
}

/// Owned validator wrapper for use in generators in functions, this can be passed back to python
Expand Down
41 changes: 40 additions & 1 deletion tests/test_garbage_collection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import gc
import platform
from typing import Any
from typing import Any, Iterable
from weakref import WeakValueDictionary

import pytest
Expand Down Expand Up @@ -79,3 +79,42 @@ class MyModel(BaseModel):
gc.collect(2)

assert len(cache) == 0


@pytest.mark.xfail(
condition=platform.python_implementation() == 'PyPy', reason='https://foss.heptapod.net/pypy/pypy/-/issues/3899'
)
def test_gc_validator_iterator() -> None:
# test for /~https://github.com/pydantic/pydantic/issues/9243
class MyModel:
iter: Iterable[int]

v = SchemaValidator(
core_schema.model_schema(
MyModel,
core_schema.model_fields_schema(
{'iter': core_schema.model_field(core_schema.generator_schema(core_schema.int_schema()))}
),
),
)

class MyIterable:
def __iter__(self):
return self

def __next__(self):
raise StopIteration()

cache: 'WeakValueDictionary[int, Any]' = WeakValueDictionary()

for _ in range(10_000):
iterable = MyIterable()
cache[id(iterable)] = iterable
v.validate_python({'iter': iterable})
del iterable

gc.collect(0)
gc.collect(1)
gc.collect(2)

assert len(cache) == 0

0 comments on commit a99729a

Please sign in to comment.