refactor: reduce duplication in network filters

This commit is contained in:
James Graham
2021-04-25 11:25:15 +01:00
parent 20812dfc40
commit cc25a154ac

View File

@@ -3,6 +3,7 @@ Views for displaying networks of :class:`People` and :class:`Relationship`s.
""" """
import logging import logging
import typing
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
from django.db.models import Q from django.db.models import Q
@@ -15,8 +16,11 @@ from people import forms, models, serializers
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def filter_relationships(form, at_date): def filter_by_form_answers(model: typing.Type, answerset_model: typing.Type,
relationship_answerset_set = models.RelationshipAnswerSet.objects.filter( relationship_key: str):
"""Build a filter to select based on form responses."""
def inner(form, at_date):
answerset_set = answerset_model.objects.filter(
Q(replaced_timestamp__gte=at_date) Q(replaced_timestamp__gte=at_date)
| Q(replaced_timestamp__isnull=True), | Q(replaced_timestamp__isnull=True),
timestamp__lte=at_date) timestamp__lte=at_date)
@@ -24,42 +28,25 @@ def filter_relationships(form, at_date):
# Filter answers to relationship questions # Filter answers to relationship questions
for field, values in form.cleaned_data.items(): for field, values in form.cleaned_data.items():
if field.startswith(f'{form.question_prefix}question_') and values: if field.startswith(f'{form.question_prefix}question_') and values:
relationship_answerset_set = relationship_answerset_set.filter( answerset_set = answerset_set.filter(
question_answers__in=values) question_answers__in=values)
return models.Relationship.objects.filter( return model.objects.filter(
pk__in=relationship_answerset_set.values_list('relationship', pk__in=answerset_set.values_list(relationship_key, flat=True))
flat=True))
return inner
def filter_people(form, at_date): filter_relationships = filter_by_form_answers(models.Relationship,
answerset_set = models.PersonAnswerSet.objects.filter( models.RelationshipAnswerSet,
Q(replaced_timestamp__gte=at_date) 'relationship')
| Q(replaced_timestamp__isnull=True),
timestamp__lte=at_date)
# Filter answers to questions filter_organisations = filter_by_form_answers(models.Organisation,
for field, values in form.cleaned_data.items(): models.OrganisationAnswerSet,
if field.startswith(f'{form.question_prefix}question_') and values: 'organisation')
answerset_set = answerset_set.filter(question_answers__in=values)
return models.Person.objects.filter( filter_people = filter_by_form_answers(models.Person, models.PersonAnswerSet,
pk__in=answerset_set.values_list('person', flat=True)) 'person')
def filter_organisations(form, at_date):
answerset_set = models.OrganisationAnswerSet.objects.filter(
Q(replaced_timestamp__gte=at_date)
| Q(replaced_timestamp__isnull=True),
timestamp__lte=at_date)
# Filter answers to questions
for field, values in form.cleaned_data.items():
if field.startswith(f'{form.question_prefix}question_') and values:
answerset_set = answerset_set.filter(question_answers__in=values)
return models.Organisation.objects.filter(
pk__in=answerset_set.values_list('organisation', flat=True))
class NetworkView(LoginRequiredMixin, TemplateView): class NetworkView(LoginRequiredMixin, TemplateView):
@@ -67,11 +54,11 @@ class NetworkView(LoginRequiredMixin, TemplateView):
template_name = 'people/network.html' template_name = 'people/network.html'
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
forms = self.get_forms() all_forms = self.get_forms()
if all(map(lambda f: f.is_valid(), forms.values())): if all(map(lambda f: f.is_valid(), all_forms.values())):
return self.forms_valid(forms) return self.forms_valid(all_forms)
return self.forms_invalid(forms) return self.forms_invalid(all_forms)
def get_forms(self): def get_forms(self):
form_kwargs = self.get_form_kwargs() form_kwargs = self.get_form_kwargs()
@@ -100,23 +87,23 @@ class NetworkView(LoginRequiredMixin, TemplateView):
""" """
context = super().get_context_data(**kwargs) context = super().get_context_data(**kwargs)
forms = self.get_forms() all_forms = self.get_forms()
context['relationship_form'] = forms['relationship'] context['relationship_form'] = all_forms['relationship']
context['person_form'] = forms['person'] context['person_form'] = all_forms['person']
context['organisation_form'] = forms['organisation'] context['organisation_form'] = all_forms['organisation']
if not all(map(lambda f: f.is_valid(), forms.values())): if not all(map(lambda f: f.is_valid(), all_forms.values())):
return context return context
relationship_at_date = forms['relationship'].cleaned_data['date'] relationship_at_date = all_forms['relationship'].cleaned_data['date']
if not relationship_at_date: if not relationship_at_date:
relationship_at_date = timezone.now().date() relationship_at_date = timezone.now().date()
person_at_date = forms['person'].cleaned_data['date'] person_at_date = all_forms['person'].cleaned_data['date']
if not person_at_date: if not person_at_date:
person_at_date = timezone.now().date() person_at_date = timezone.now().date()
organisation_at_date = forms['organisation'].cleaned_data['date'] organisation_at_date = all_forms['organisation'].cleaned_data['date']
if not organisation_at_date: if not organisation_at_date:
organisation_at_date = timezone.now().date() organisation_at_date = timezone.now().date()
@@ -126,15 +113,15 @@ class NetworkView(LoginRequiredMixin, TemplateView):
relationship_at_date += timezone.timedelta(days=1) relationship_at_date += timezone.timedelta(days=1)
context['person_set'] = serializers.PersonSerializer( context['person_set'] = serializers.PersonSerializer(
filter_people(forms['person'], person_at_date), filter_people(all_forms['person'], person_at_date),
many=True).data many=True).data
context['organisation_set'] = serializers.OrganisationSerializer( context['organisation_set'] = serializers.OrganisationSerializer(
filter_organisations(forms['organisation'], organisation_at_date), filter_organisations(all_forms['organisation'], organisation_at_date),
many=True).data many=True).data
context['relationship_set'] = serializers.RelationshipSerializer( context['relationship_set'] = serializers.RelationshipSerializer(
filter_relationships(forms['relationship'], relationship_at_date), filter_relationships(all_forms['relationship'], relationship_at_date),
many=True).data many=True).data
logger.info('Found %d distinct relationships matching filters', logger.info('Found %d distinct relationships matching filters',
@@ -142,12 +129,12 @@ class NetworkView(LoginRequiredMixin, TemplateView):
return context return context
def forms_valid(self, forms): def forms_valid(self, all_forms):
try: try:
return self.render_to_response(self.get_context_data()) return self.render_to_response(self.get_context_data())
except ValidationError: except ValidationError:
return self.forms_invalid(forms) return self.forms_invalid(all_forms)
def forms_invalid(self, forms): def forms_invalid(self, all_forms):
return self.render_to_response(self.get_context_data()) return self.render_to_response(self.get_context_data())