-
-
Notifications
You must be signed in to change notification settings - Fork 186
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add new base viewset for audit logging (#5173)
- Loading branch information
Showing
2 changed files
with
196 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from rest_framework import mixins, viewsets | ||
|
||
|
||
def get_nested_field(obj, field: str): | ||
""" | ||
Retrieve a period-separated nested field from an object or dict | ||
Raises an exception if the field is not found | ||
""" | ||
split = field.split('.') | ||
attribute = getattr(obj, split[0]) | ||
if len(split) > 1: | ||
for inner_field in split[1:]: | ||
if isinstance(attribute, dict): | ||
attribute = attribute.get(inner_field) | ||
else: | ||
attribute = getattr(attribute, inner_field) | ||
return attribute | ||
|
||
|
||
class AuditLoggedViewSet(viewsets.GenericViewSet): | ||
""" | ||
A ViewSet for adding arbitrary object data to a request before and after request | ||
Useful for storing information for audit logs on create/update. Allows inheriting | ||
ViewSets to implement additional logic for get_object, perform_update, | ||
perform_create, and perform_destroy via the get_object_override, | ||
perform_update_override, perform_create_override, and perform_destroy_override | ||
methods. | ||
Sets the values on the inner HttpRequest object rather than the DRF Request | ||
so middleware can access them. | ||
""" | ||
|
||
logged_fields = [] | ||
|
||
def get_object(self): | ||
# actually fetch the object | ||
obj = self.get_object_override() | ||
if self.request.method in ['GET', 'HEAD']: | ||
# since this is for audit logs, don't worry about read-only requests | ||
return obj | ||
audit_log_data = {} | ||
for field in self.logged_fields: | ||
value = get_nested_field(obj, field) | ||
audit_log_data[field] = value | ||
self.request._request.initial_data = audit_log_data | ||
return obj | ||
|
||
def perform_update(self, serializer): | ||
self.perform_update_override(serializer) | ||
audit_log_data = {} | ||
for field in self.logged_fields: | ||
value = get_nested_field(serializer.instance, field) | ||
audit_log_data[field] = value | ||
self.request._request.updated_data = audit_log_data | ||
|
||
def perform_create(self, serializer): | ||
self.perform_create_override(serializer) | ||
audit_log_data = {} | ||
for field in self.logged_fields: | ||
value = get_nested_field(serializer.instance, field) | ||
audit_log_data[field] = value | ||
self.request._request.updated_data = audit_log_data | ||
|
||
def perform_destroy(self, instance): | ||
audit_log_data = {} | ||
for field in self.logged_fields: | ||
value = get_nested_field(instance, field) | ||
audit_log_data[field] = value | ||
self.request._request.initial_data = audit_log_data | ||
self.perform_destroy_override(instance) | ||
|
||
def perform_destroy_override(self, instance): | ||
super().perform_destroy(instance) | ||
|
||
def perform_create_override(self, serializer): | ||
super().perform_create(serializer) | ||
|
||
def perform_update_override(self, serializer): | ||
super().perform_update(serializer) | ||
|
||
def get_object_override(self): | ||
return super().get_object() | ||
|
||
|
||
class AuditLoggedModelViewSet( | ||
AuditLoggedViewSet, | ||
mixins.CreateModelMixin, | ||
mixins.RetrieveModelMixin, | ||
mixins.UpdateModelMixin, | ||
mixins.DestroyModelMixin, | ||
mixins.ListModelMixin, | ||
): | ||
pass | ||
|
||
|
||
class AuditLoggedNoUpdateModelViewSet( | ||
AuditLoggedModelViewSet, | ||
mixins.CreateModelMixin, | ||
mixins.RetrieveModelMixin, | ||
mixins.DestroyModelMixin, | ||
mixins.ListModelMixin, | ||
): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from allauth.account.models import EmailAddress | ||
from django.test import override_settings | ||
from django.urls import reverse | ||
from rest_framework import permissions, serializers | ||
from rest_framework.routers import DefaultRouter | ||
|
||
from kobo.apps.audit_log.base_views import AuditLoggedModelViewSet | ||
from kobo.apps.kobo_auth.shortcuts import User | ||
from kpi.tests.kpi_test_case import KpiTestCase | ||
|
||
|
||
class DummyEmailSerializer(serializers.ModelSerializer): | ||
""" | ||
Basic model serializer for EmailAddresses | ||
""" | ||
|
||
class Meta: | ||
model = EmailAddress | ||
fields = '__all__' | ||
|
||
|
||
class DummyViewSet(AuditLoggedModelViewSet): | ||
""" | ||
DummyViewSet for testing the functionality of the AuditLoggedModelViewSet | ||
Uses the email address model because it's simple | ||
""" | ||
|
||
permission_classes = (permissions.AllowAny,) | ||
queryset = EmailAddress.objects.all() | ||
serializer_class = DummyEmailSerializer | ||
logged_fields = ['email', 'verified'] | ||
|
||
|
||
class TestUrls: | ||
""" | ||
Register our DummyViewSet at a test-only url | ||
""" | ||
|
||
router = DefaultRouter() | ||
router.register(r'test', DummyViewSet, basename='test-vs') | ||
urlpatterns = router.urls | ||
|
||
|
||
@override_settings(ROOT_URLCONF=TestUrls) | ||
class TestAuditLoggedViewSet(KpiTestCase): | ||
fixtures = ['test_data'] | ||
|
||
def test_creating_model_records_fields(self): | ||
response = self.client.post( | ||
reverse('test-vs-list'), data={'user': 1, 'email': 'new_email@example.com'} | ||
) | ||
request = response.wsgi_request | ||
self.assertDictEqual( | ||
request.updated_data, {'email': 'new_email@example.com', 'verified': False} | ||
) | ||
|
||
def test_updating_model_records_fields(self): | ||
user = User.objects.get(pk=1) | ||
email_address, _ = EmailAddress.objects.get_or_create( | ||
user=user, email='initial_email@example.com' | ||
) | ||
email_address.save() | ||
response = self.client.patch( | ||
reverse('test-vs-detail', args=[email_address.pk]), | ||
data={'email': 'newer_email@example.com'}, | ||
) | ||
request = response.wsgi_request | ||
self.assertEqual( | ||
request.initial_data, | ||
{'email': 'initial_email@example.com', 'verified': False}, | ||
) | ||
self.assertEqual( | ||
request.updated_data, | ||
{'email': 'newer_email@example.com', 'verified': False}, | ||
) | ||
|
||
def test_destroying_model_records_fields(self): | ||
user = User.objects.get(pk=1) | ||
email_address, _ = EmailAddress.objects.get_or_create( | ||
user=user, email='initial_email@example.com' | ||
) | ||
email_address.save() | ||
response = self.client.delete( | ||
reverse('test-vs-detail', args=[email_address.pk]) | ||
) | ||
request = response.wsgi_request | ||
self.assertEqual( | ||
request.initial_data, | ||
{'email': 'initial_email@example.com', 'verified': False}, | ||
) |