Correctly get lexer for language

`get_lexer_by_name` doesn't actually get the lexer by its name...
This commit is contained in:
Jake Howard 2022-09-07 14:30:52 +01:00
parent c7e56ab038
commit c56cc2f995
Signed by: jake
GPG key ID: 57AFB45680EDD477
2 changed files with 16 additions and 5 deletions

View file

@ -3,7 +3,7 @@ from typing import Iterator
from django.utils.safestring import mark_safe from django.utils.safestring import mark_safe
from pygments import highlight from pygments import highlight
from pygments.formatters.html import HtmlFormatter from pygments.formatters.html import HtmlFormatter
from pygments.lexers import get_all_lexers, get_lexer_by_name from pygments.lexers import find_lexer_class, get_all_lexers
from wagtail.blocks import ( from wagtail.blocks import (
BooleanBlock, BooleanBlock,
CharBlock, CharBlock,
@ -21,11 +21,11 @@ def get_language_choices() -> Iterator[tuple[str, str]]:
class CodeStructValue(StructValue): class CodeStructValue(StructValue):
def code(self) -> str: def code(self) -> str:
lexer = get_lexer_by_name(self.get("language")) lexer = find_lexer_class(self["language"])()
formatter = HtmlFormatter( formatter = HtmlFormatter(
linenos=None, linenos=None,
) )
return mark_safe(highlight(self.get("source"), lexer, formatter)) return mark_safe(highlight(self["source"], lexer, formatter))
class CodeBlock(StructBlock): class CodeBlock(StructBlock):

View file

@ -1,10 +1,11 @@
from django.test import TestCase from django.test import SimpleTestCase
from django.urls import reverse from django.urls import reverse
from .blocks import CodeStructValue, get_language_choices
from .utils import PYGMENTS_VERSION_SLUG from .utils import PYGMENTS_VERSION_SLUG
class PygmentsStylesTestCase(TestCase): class PygmentsStylesTestCase(SimpleTestCase):
url = reverse("code-block:styles") url = reverse("code-block:styles")
def test_accessible(self) -> None: def test_accessible(self) -> None:
@ -15,3 +16,13 @@ class PygmentsStylesTestCase(TestCase):
def test_url_contains_version(self) -> None: def test_url_contains_version(self) -> None:
self.assertIn(PYGMENTS_VERSION_SLUG, self.url) self.assertIn(PYGMENTS_VERSION_SLUG, self.url)
class CodeStructValueTestCase(SimpleTestCase):
def test_highlights(self) -> None:
for language, _ in get_language_choices():
with self.subTest(language):
block = CodeStructValue(
None, [("source", "test"), ("language", language)]
)
self.assertIsInstance(block.code(), str)