Skip to content

Commit

Permalink
Add new base viewset for audit logging (#5173)
Browse files Browse the repository at this point in the history
  • Loading branch information
rgraber authored Oct 16, 2024
1 parent 16c9dd3 commit 799c0fb
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 0 deletions.
105 changes: 105 additions & 0 deletions kobo/apps/audit_log/base_views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from rest_framework import mixins, viewsets


def get_nested_field(obj, field: str):
"""
Retrieve a period-separated nested field from an object or dict
Raises an exception if the field is not found
"""
split = field.split('.')
attribute = getattr(obj, split[0])
if len(split) > 1:
for inner_field in split[1:]:
if isinstance(attribute, dict):
attribute = attribute.get(inner_field)
else:
attribute = getattr(attribute, inner_field)
return attribute


class AuditLoggedViewSet(viewsets.GenericViewSet):
"""
A ViewSet for adding arbitrary object data to a request before and after request
Useful for storing information for audit logs on create/update. Allows inheriting
ViewSets to implement additional logic for get_object, perform_update,
perform_create, and perform_destroy via the get_object_override,
perform_update_override, perform_create_override, and perform_destroy_override
methods.
Sets the values on the inner HttpRequest object rather than the DRF Request
so middleware can access them.
"""

logged_fields = []

def get_object(self):
# actually fetch the object
obj = self.get_object_override()
if self.request.method in ['GET', 'HEAD']:
# since this is for audit logs, don't worry about read-only requests
return obj
audit_log_data = {}
for field in self.logged_fields:
value = get_nested_field(obj, field)
audit_log_data[field] = value
self.request._request.initial_data = audit_log_data
return obj

def perform_update(self, serializer):
self.perform_update_override(serializer)
audit_log_data = {}
for field in self.logged_fields:
value = get_nested_field(serializer.instance, field)
audit_log_data[field] = value
self.request._request.updated_data = audit_log_data

def perform_create(self, serializer):
self.perform_create_override(serializer)
audit_log_data = {}
for field in self.logged_fields:
value = get_nested_field(serializer.instance, field)
audit_log_data[field] = value
self.request._request.updated_data = audit_log_data

def perform_destroy(self, instance):
audit_log_data = {}
for field in self.logged_fields:
value = get_nested_field(instance, field)
audit_log_data[field] = value
self.request._request.initial_data = audit_log_data
self.perform_destroy_override(instance)

def perform_destroy_override(self, instance):
super().perform_destroy(instance)

def perform_create_override(self, serializer):
super().perform_create(serializer)

def perform_update_override(self, serializer):
super().perform_update(serializer)

def get_object_override(self):
return super().get_object()


class AuditLoggedModelViewSet(
AuditLoggedViewSet,
mixins.CreateModelMixin,
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
mixins.ListModelMixin,
):
pass


class AuditLoggedNoUpdateModelViewSet(
AuditLoggedModelViewSet,
mixins.CreateModelMixin,
mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
mixins.ListModelMixin,
):
pass
91 changes: 91 additions & 0 deletions kobo/apps/audit_log/tests/test_base_views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from allauth.account.models import EmailAddress
from django.test import override_settings
from django.urls import reverse
from rest_framework import permissions, serializers
from rest_framework.routers import DefaultRouter

from kobo.apps.audit_log.base_views import AuditLoggedModelViewSet
from kobo.apps.kobo_auth.shortcuts import User
from kpi.tests.kpi_test_case import KpiTestCase


class DummyEmailSerializer(serializers.ModelSerializer):
"""
Basic model serializer for EmailAddresses
"""

class Meta:
model = EmailAddress
fields = '__all__'


class DummyViewSet(AuditLoggedModelViewSet):
"""
DummyViewSet for testing the functionality of the AuditLoggedModelViewSet
Uses the email address model because it's simple
"""

permission_classes = (permissions.AllowAny,)
queryset = EmailAddress.objects.all()
serializer_class = DummyEmailSerializer
logged_fields = ['email', 'verified']


class TestUrls:
"""
Register our DummyViewSet at a test-only url
"""

router = DefaultRouter()
router.register(r'test', DummyViewSet, basename='test-vs')
urlpatterns = router.urls


@override_settings(ROOT_URLCONF=TestUrls)
class TestAuditLoggedViewSet(KpiTestCase):
fixtures = ['test_data']

def test_creating_model_records_fields(self):
response = self.client.post(
reverse('test-vs-list'), data={'user': 1, 'email': 'new_email@example.com'}
)
request = response.wsgi_request
self.assertDictEqual(
request.updated_data, {'email': 'new_email@example.com', 'verified': False}
)

def test_updating_model_records_fields(self):
user = User.objects.get(pk=1)
email_address, _ = EmailAddress.objects.get_or_create(
user=user, email='initial_email@example.com'
)
email_address.save()
response = self.client.patch(
reverse('test-vs-detail', args=[email_address.pk]),
data={'email': 'newer_email@example.com'},
)
request = response.wsgi_request
self.assertEqual(
request.initial_data,
{'email': 'initial_email@example.com', 'verified': False},
)
self.assertEqual(
request.updated_data,
{'email': 'newer_email@example.com', 'verified': False},
)

def test_destroying_model_records_fields(self):
user = User.objects.get(pk=1)
email_address, _ = EmailAddress.objects.get_or_create(
user=user, email='initial_email@example.com'
)
email_address.save()
response = self.client.delete(
reverse('test-vs-detail', args=[email_address.pk])
)
request = response.wsgi_request
self.assertEqual(
request.initial_data,
{'email': 'initial_email@example.com', 'verified': False},
)

0 comments on commit 799c0fb

Please sign in to comment.