From c56cc2f99507e9a2504b9225a3b39eaa664ec2ef Mon Sep 17 00:00:00 2001 From: Jake Howard Date: Wed, 7 Sep 2022 14:30:52 +0100 Subject: [PATCH] Correctly get lexer for language `get_lexer_by_name` doesn't actually get the lexer by its name... --- website/contrib/code_block/blocks.py | 6 +++--- website/contrib/code_block/tests.py | 15 +++++++++++++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/website/contrib/code_block/blocks.py b/website/contrib/code_block/blocks.py index ddadf4a..e73f5f6 100644 --- a/website/contrib/code_block/blocks.py +++ b/website/contrib/code_block/blocks.py @@ -3,7 +3,7 @@ from typing import Iterator from django.utils.safestring import mark_safe from pygments import highlight from pygments.formatters.html import HtmlFormatter -from pygments.lexers import get_all_lexers, get_lexer_by_name +from pygments.lexers import find_lexer_class, get_all_lexers from wagtail.blocks import ( BooleanBlock, CharBlock, @@ -21,11 +21,11 @@ def get_language_choices() -> Iterator[tuple[str, str]]: class CodeStructValue(StructValue): def code(self) -> str: - lexer = get_lexer_by_name(self.get("language")) + lexer = find_lexer_class(self["language"])() formatter = HtmlFormatter( linenos=None, ) - return mark_safe(highlight(self.get("source"), lexer, formatter)) + return mark_safe(highlight(self["source"], lexer, formatter)) class CodeBlock(StructBlock): diff --git a/website/contrib/code_block/tests.py b/website/contrib/code_block/tests.py index 60a7f5d..cb08576 100644 --- a/website/contrib/code_block/tests.py +++ b/website/contrib/code_block/tests.py @@ -1,10 +1,11 @@ -from django.test import TestCase +from django.test import SimpleTestCase from django.urls import reverse +from .blocks import CodeStructValue, get_language_choices from .utils import PYGMENTS_VERSION_SLUG -class PygmentsStylesTestCase(TestCase): +class PygmentsStylesTestCase(SimpleTestCase): url = reverse("code-block:styles") def test_accessible(self) -> None: @@ -15,3 +16,13 @@ class PygmentsStylesTestCase(TestCase): def test_url_contains_version(self) -> None: self.assertIn(PYGMENTS_VERSION_SLUG, self.url) + + +class CodeStructValueTestCase(SimpleTestCase): + def test_highlights(self) -> None: + for language, _ in get_language_choices(): + with self.subTest(language): + block = CodeStructValue( + None, [("source", "test"), ("language", language)] + ) + self.assertIsInstance(block.code(), str)