Skip to content

Commit

Permalink
fix: fix integration tests after adding autoconnect=False
Browse files Browse the repository at this point in the history
during DB connection
  • Loading branch information
raphael0202 committed Dec 9, 2022
1 parent 7e43b4d commit 0382e35
Show file tree
Hide file tree
Showing 9 changed files with 390 additions and 351 deletions.
10 changes: 6 additions & 4 deletions tests/integration/insights/test_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@

@pytest.fixture(autouse=True)
def _set_up_and_tear_down(peewee_db):
clean_db()
# Run the test case.
yield
clean_db()
with peewee_db:
# clean db
clean_db()
# Run the test case.
yield
clean_db()


def test_annotation_fails_is_rolledback(mocker):
Expand Down
43 changes: 22 additions & 21 deletions tests/integration/insights/test_category_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,28 @@

@pytest.fixture(autouse=True)
def _set_up_and_tear_down(peewee_db):
# clean db
clean_db()
# a category already exists
PredictionFactory(
barcode=barcode1,
type="category",
value_tag="en:salmons",
automatic_processing=False,
predictor="matcher",
)
ProductInsightFactory(
id=insight_id1,
barcode=barcode1,
type="category",
value_tag="en:salmons",
predictor="matcher",
)
# Run the test case.
yield
# Tear down.
clean_db()
with peewee_db:
# clean db
clean_db()
# a category already exists
PredictionFactory(
barcode=barcode1,
type="category",
value_tag="en:salmons",
automatic_processing=False,
predictor="matcher",
)
ProductInsightFactory(
id=insight_id1,
barcode=barcode1,
type="category",
value_tag="en:salmons",
predictor="matcher",
)
# Run the test case.
yield
# Tear down.
clean_db()


def matcher_prediction(category):
Expand Down
22 changes: 12 additions & 10 deletions tests/integration/insights/test_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@

@pytest.fixture()
def image_model(peewee_db):
clean_db()
yield ImageModelFactory(source_image="/1/1.jpg")
clean_db()
with peewee_db:
clean_db()
yield ImageModelFactory(source_image="/1/1.jpg")
clean_db()


class FakeNutriscoreModel(RemoteModel):
Expand Down Expand Up @@ -75,11 +76,12 @@ def test_run_object_detection_model(mocker, image_model, model_name, label_names
assert image_prediction.max_confidence == 0.8


def test_run_object_detection_model_no_image_instance():
image_prediction = run_object_detection_model(
ObjectDetectionModel.nutriscore,
None,
source_image="/images/1/1.jpg",
threshold=0.1,
)
def test_run_object_detection_model_no_image_instance(peewee_db):
with peewee_db:
image_prediction = run_object_detection_model(
ObjectDetectionModel.nutriscore,
None,
source_image="/images/1/1.jpg",
threshold=0.1,
)
assert image_prediction is None
13 changes: 7 additions & 6 deletions tests/integration/insights/test_process_insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@

@pytest.fixture(autouse=True)
def _set_up_and_tear_down(peewee_db):
# clean db
clean_db()
# Run the test case.
yield
# Tear down.
clean_db()
with peewee_db:
# clean db
clean_db()
# Run the test case.
yield
# Tear down.
clean_db()


# global for generating items
Expand Down
58 changes: 39 additions & 19 deletions tests/integration/test_annotate_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ def client():

@pytest.fixture(autouse=True)
def _set_up_and_tear_down(peewee_db):
clean_db()
# Run the test case.
with peewee_db:
clean_db()
# Run the test case.
yield
clean_db()

with peewee_db:
clean_db()


def _fake_store(monkeypatch, barcode):
Expand Down Expand Up @@ -126,11 +129,14 @@ def test_logo_annotation_missing_value_when_required(logo_type, client):
}


def test_logo_annotation_incorrect_value_label_type(client):
def test_logo_annotation_incorrect_value_label_type(client, peewee_db):
"""A language-prefixed value is expected for label type."""
ann = LogoAnnotationFactory(
image_prediction__image__source_image="/images/2.jpg", annotation_type="label"
)

with peewee_db:
ann = LogoAnnotationFactory(
image_prediction__image__source_image="/images/2.jpg",
annotation_type="label",
)
result = client.simulate_post(
"/api/v1/images/logos/annotate",
json={
Expand All @@ -148,10 +154,12 @@ def test_logo_annotation_incorrect_value_label_type(client):
}


def test_logo_annotation_brand(client, monkeypatch, fake_taxonomy):
ann = LogoAnnotationFactory(
image_prediction__image__source_image="/images/2.jpg", annotation_type="brand"
)
def test_logo_annotation_brand(client, peewee_db, monkeypatch, fake_taxonomy):
with peewee_db:
ann = LogoAnnotationFactory(
image_prediction__image__source_image="/images/2.jpg",
annotation_type="brand",
)
barcode = ann.image_prediction.image.barcode
_fake_store(monkeypatch, barcode)
monkeypatch.setattr(
Expand All @@ -169,15 +177,19 @@ def test_logo_annotation_brand(client, monkeypatch, fake_taxonomy):
end = datetime.utcnow()
assert result.status_code == 200
assert result.json == {"created insights": 1}
ann = LogoAnnotation.get(LogoAnnotation.id == ann.id)

with peewee_db:
ann = LogoAnnotation.get(LogoAnnotation.id == ann.id)
assert ann.annotation_type == "brand"
assert ann.annotation_value == "etorki"
assert ann.annotation_value_tag == "etorki"
assert ann.taxonomy_value == "Etorki"
assert ann.username == "a"
assert start <= ann.completed_at <= end
# we generate a prediction
predictions = list(Prediction.select().filter(barcode=barcode).execute())

with peewee_db:
predictions = list(Prediction.select().filter(barcode=barcode).execute())
assert len(predictions) == 1
(prediction,) = predictions
assert prediction.type == "brand"
Expand All @@ -195,7 +207,9 @@ def test_logo_annotation_brand(client, monkeypatch, fake_taxonomy):
assert start <= prediction.timestamp <= end
assert prediction.automatic_processing
# We check that this prediction in turn generates an insight
insights = list(ProductInsight.select().filter(barcode=barcode).execute())

with peewee_db:
insights = list(ProductInsight.select().filter(barcode=barcode).execute())
assert len(insights) == 1
(insight,) = insights
assert insight.type == "brand"
Expand All @@ -216,11 +230,14 @@ def test_logo_annotation_brand(client, monkeypatch, fake_taxonomy):
assert insight.completed_at is None # we did not run annotate yet


def test_logo_annotation_label(client, monkeypatch, fake_taxonomy):
def test_logo_annotation_label(client, peewee_db, monkeypatch, fake_taxonomy):
"""This test will check that, given an image with a logo above the confidence threshold,
that is then fed into the ANN logos and labels model, we annotate properly a product.
"""
ann = LogoAnnotationFactory(image_prediction__image__source_image="/images/2.jpg")
with peewee_db:
ann = LogoAnnotationFactory(
image_prediction__image__source_image="/images/2.jpg"
)
barcode = ann.image_prediction.image.barcode
_fake_store(monkeypatch, barcode)
start = datetime.utcnow()
Expand All @@ -237,15 +254,17 @@ def test_logo_annotation_label(client, monkeypatch, fake_taxonomy):
end = datetime.utcnow()
assert result.status_code == 200
assert result.json == {"created insights": 1}
ann = LogoAnnotation.get(LogoAnnotation.id == ann.id)
with peewee_db:
ann = LogoAnnotation.get(LogoAnnotation.id == ann.id)
assert ann.annotation_type == "label"
assert ann.annotation_value == "en:eu-organic"
assert ann.annotation_value_tag == "en:eu-organic"
assert ann.taxonomy_value == "en:eu-organic"
assert ann.username == "a"
assert start <= ann.completed_at <= end
# we generate a prediction
predictions = list(Prediction.select().filter(barcode=barcode).execute())
with peewee_db:
predictions = list(Prediction.select().filter(barcode=barcode).execute())
assert len(predictions) == 1
(prediction,) = predictions
assert prediction.type == "label"
Expand All @@ -263,7 +282,8 @@ def test_logo_annotation_label(client, monkeypatch, fake_taxonomy):
assert start <= prediction.timestamp <= end
assert prediction.automatic_processing
# We check that this prediction in turn generates an insight
insights = list(ProductInsight.select().filter(barcode=barcode).execute())
with peewee_db:
insights = list(ProductInsight.select().filter(barcode=barcode).execute())
assert len(insights) == 1
(insight,) = insights
assert insight.type == "label"
Expand Down
Loading

0 comments on commit 0382e35

Please sign in to comment.