Use serializer to validate search params

This commit is contained in:
Jake Howard 2022-07-29 16:50:44 +01:00
parent 639f5885a4
commit c4109e42f1
Signed by: jake
GPG key ID: 57AFB45680EDD477
2 changed files with 43 additions and 25 deletions

View file

@ -1,6 +1,7 @@
from django.core.paginator import EmptyPage, PageNotAnInteger, Paginator from django.core.paginator import EmptyPage, Paginator
from django.http.request import HttpRequest from django.http.request import HttpRequest
from django.utils.functional import cached_property from django.utils.functional import cached_property
from rest_framework import serializers
from wagtail.models import Page from wagtail.models import Page
from wagtail.query import PageQuerySet from wagtail.query import PageQuerySet
from wagtail.search.models import Query from wagtail.search.models import Query
@ -19,6 +20,10 @@ class SearchPage(BaseContentMixin, BasePage): # type: ignore[misc]
search_fields = BasePage.search_fields + BaseContentMixin.search_fields search_fields = BasePage.search_fields + BaseContentMixin.search_fields
PAGE_SIZE = 15 PAGE_SIZE = 15
class SearchParamsSerializer(serializers.Serializer):
q = serializers.CharField()
page = serializers.IntegerField(min_value=1, default=1)
@cached_property @cached_property
def reading_time(self) -> int: def reading_time(self) -> int:
""" """
@ -35,15 +40,19 @@ class SearchPage(BaseContentMixin, BasePage): # type: ignore[misc]
def get_context(self, request: HttpRequest) -> dict: def get_context(self, request: HttpRequest) -> dict:
context = super().get_context(request) context = super().get_context(request)
if query_string := request.GET.get("q", ""):
filters, query = parse_query_string(query_string) serializer = self.SearchParamsSerializer(data=request.GET)
Query.get(query_string).add_hit()
if serializer.is_valid():
search_query = serializer.validated_data["q"]
filters, query = parse_query_string(search_query)
Query.get(search_query).add_hit()
pages = self.get_search_pages().search(query) pages = self.get_search_pages().search(query)
else:
pages = Page.objects.none()
paginator = Paginator(pages, self.PAGE_SIZE) paginator = Paginator(pages, self.PAGE_SIZE)
page_num = request.GET.get("page", "1") context["paginator"] = paginator
page_num = serializer.validated_data["page"]
context["page_num"] = page_num
try: try:
results = paginator.page(page_num) results = paginator.page(page_num)
@ -51,13 +60,16 @@ class SearchPage(BaseContentMixin, BasePage): # type: ignore[misc]
if not isinstance(results.object_list, PageQuerySet): if not isinstance(results.object_list, PageQuerySet):
results.object_list = Page.objects.filter( results.object_list = Page.objects.filter(
id__in=list( id__in=list(
results.object_list.get_queryset().values_list("id", flat=True) results.object_list.get_queryset().values_list(
"id", flat=True
)
) )
).specific() ).specific()
except (PageNotAnInteger, EmptyPage): except EmptyPage:
results = None results = []
context["invalid_page"] = True
context["results"] = results context["results"] = results
else:
context["invalid_search"] = True
return context return context

View file

@ -11,8 +11,14 @@
{% endif %} {% endif %}
<section class="container"> <section class="container">
{% if invalid_page %} {% if invalid_search %}
<p>Invalid page</p> <p>Invalid search</p>
{% elif results|length == 0 %}
{% if page_num > paginator.num_pages %}
<p>There aren't {{ page_num }} page - only {{ paginator.num_pages }}.</p>
{% else %}
<p>No results</p>
{% endif %}
{% else %} {% else %}
{% for page in results %} {% for page in results %}
{% include "common/listing-item.html" %} {% include "common/listing-item.html" %}