From cc25a154acff0f4f74819704c869c0f12b948307 Mon Sep 17 00:00:00 2001 From: James Graham Date: Sun, 25 Apr 2021 11:25:15 +0100 Subject: [PATCH] refactor: reduce duplication in network filters --- people/views/network.py | 101 +++++++++++++++++----------------------- 1 file changed, 44 insertions(+), 57 deletions(-) diff --git a/people/views/network.py b/people/views/network.py index 7f2037d..cf81d61 100644 --- a/people/views/network.py +++ b/people/views/network.py @@ -3,6 +3,7 @@ Views for displaying networks of :class:`People` and :class:`Relationship`s. """ import logging +import typing from django.contrib.auth.mixins import LoginRequiredMixin from django.db.models import Q @@ -15,51 +16,37 @@ from people import forms, models, serializers logger = logging.getLogger(__name__) # pylint: disable=invalid-name -def filter_relationships(form, at_date): - relationship_answerset_set = models.RelationshipAnswerSet.objects.filter( - Q(replaced_timestamp__gte=at_date) - | Q(replaced_timestamp__isnull=True), - timestamp__lte=at_date) +def filter_by_form_answers(model: typing.Type, answerset_model: typing.Type, + relationship_key: str): + """Build a filter to select based on form responses.""" + def inner(form, at_date): + answerset_set = answerset_model.objects.filter( + Q(replaced_timestamp__gte=at_date) + | Q(replaced_timestamp__isnull=True), + timestamp__lte=at_date) - # Filter answers to relationship questions - for field, values in form.cleaned_data.items(): - if field.startswith(f'{form.question_prefix}question_') and values: - relationship_answerset_set = relationship_answerset_set.filter( - question_answers__in=values) + # Filter answers to relationship questions + for field, values in form.cleaned_data.items(): + if field.startswith(f'{form.question_prefix}question_') and values: + answerset_set = answerset_set.filter( + question_answers__in=values) - return models.Relationship.objects.filter( - pk__in=relationship_answerset_set.values_list('relationship', - flat=True)) + return model.objects.filter( + pk__in=answerset_set.values_list(relationship_key, flat=True)) + + return inner -def filter_people(form, at_date): - answerset_set = models.PersonAnswerSet.objects.filter( - Q(replaced_timestamp__gte=at_date) - | Q(replaced_timestamp__isnull=True), - timestamp__lte=at_date) +filter_relationships = filter_by_form_answers(models.Relationship, + models.RelationshipAnswerSet, + 'relationship') - # Filter answers to questions - for field, values in form.cleaned_data.items(): - if field.startswith(f'{form.question_prefix}question_') and values: - answerset_set = answerset_set.filter(question_answers__in=values) +filter_organisations = filter_by_form_answers(models.Organisation, + models.OrganisationAnswerSet, + 'organisation') - return models.Person.objects.filter( - pk__in=answerset_set.values_list('person', flat=True)) - - -def filter_organisations(form, at_date): - answerset_set = models.OrganisationAnswerSet.objects.filter( - Q(replaced_timestamp__gte=at_date) - | Q(replaced_timestamp__isnull=True), - timestamp__lte=at_date) - - # Filter answers to questions - for field, values in form.cleaned_data.items(): - if field.startswith(f'{form.question_prefix}question_') and values: - answerset_set = answerset_set.filter(question_answers__in=values) - - return models.Organisation.objects.filter( - pk__in=answerset_set.values_list('organisation', flat=True)) +filter_people = filter_by_form_answers(models.Person, models.PersonAnswerSet, + 'person') class NetworkView(LoginRequiredMixin, TemplateView): @@ -67,11 +54,11 @@ class NetworkView(LoginRequiredMixin, TemplateView): template_name = 'people/network.html' def post(self, request, *args, **kwargs): - forms = self.get_forms() - if all(map(lambda f: f.is_valid(), forms.values())): - return self.forms_valid(forms) + all_forms = self.get_forms() + if all(map(lambda f: f.is_valid(), all_forms.values())): + return self.forms_valid(all_forms) - return self.forms_invalid(forms) + return self.forms_invalid(all_forms) def get_forms(self): form_kwargs = self.get_form_kwargs() @@ -100,23 +87,23 @@ class NetworkView(LoginRequiredMixin, TemplateView): """ context = super().get_context_data(**kwargs) - forms = self.get_forms() - context['relationship_form'] = forms['relationship'] - context['person_form'] = forms['person'] - context['organisation_form'] = forms['organisation'] + all_forms = self.get_forms() + context['relationship_form'] = all_forms['relationship'] + context['person_form'] = all_forms['person'] + context['organisation_form'] = all_forms['organisation'] - if not all(map(lambda f: f.is_valid(), forms.values())): + if not all(map(lambda f: f.is_valid(), all_forms.values())): return context - relationship_at_date = forms['relationship'].cleaned_data['date'] + relationship_at_date = all_forms['relationship'].cleaned_data['date'] if not relationship_at_date: relationship_at_date = timezone.now().date() - person_at_date = forms['person'].cleaned_data['date'] + person_at_date = all_forms['person'].cleaned_data['date'] if not person_at_date: person_at_date = timezone.now().date() - organisation_at_date = forms['organisation'].cleaned_data['date'] + organisation_at_date = all_forms['organisation'].cleaned_data['date'] if not organisation_at_date: organisation_at_date = timezone.now().date() @@ -126,15 +113,15 @@ class NetworkView(LoginRequiredMixin, TemplateView): relationship_at_date += timezone.timedelta(days=1) context['person_set'] = serializers.PersonSerializer( - filter_people(forms['person'], person_at_date), + filter_people(all_forms['person'], person_at_date), many=True).data context['organisation_set'] = serializers.OrganisationSerializer( - filter_organisations(forms['organisation'], organisation_at_date), + filter_organisations(all_forms['organisation'], organisation_at_date), many=True).data context['relationship_set'] = serializers.RelationshipSerializer( - filter_relationships(forms['relationship'], relationship_at_date), + filter_relationships(all_forms['relationship'], relationship_at_date), many=True).data logger.info('Found %d distinct relationships matching filters', @@ -142,12 +129,12 @@ class NetworkView(LoginRequiredMixin, TemplateView): return context - def forms_valid(self, forms): + def forms_valid(self, all_forms): try: return self.render_to_response(self.get_context_data()) except ValidationError: - return self.forms_invalid(forms) + return self.forms_invalid(all_forms) - def forms_invalid(self, forms): + def forms_invalid(self, all_forms): return self.render_to_response(self.get_context_data())