from html.parser import HTMLParser
from typing import override
import cross.fragments as f
class HTMLToFragmentsParser(HTMLParser):
def __init__(self) -> None:
super().__init__()
self.builder: bytearray = bytearray()
self.fragments: list[f.Fragment] = []
self._tag_stack: dict[str, tuple[int, dict[str, str | None]]] = {}
self.in_pre: bool = False
self.in_code: bool = False
self.invisible: bool = False
def handle_a_endtag(self):
current_end = len(self.builder)
start, _attr = self._tag_stack.pop("a")
href = _attr.get('href')
if href and current_end > start:
self.fragments.append(
f.LinkFragment(start=start, end=current_end, url=href)
)
@override
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
_attr = dict(attrs)
def append_newline():
if self.builder and not self.builder.endswith(b"\n"):
self.builder.extend(b"\n")
if self.invisible:
return
match tag:
case "p":
cls = _attr.get('class', '')
if cls and 'quote-inline' in cls:
self.invisible = True
case "a":
self._tag_stack["a"] = (len(self.builder), _attr)
case "code":
if not self.in_pre:
self.builder.extend(b"`")
self.in_code = True
case "pre":
append_newline()
self.builder.extend(b"```\n")
self.in_pre = True
case "blockquote":
append_newline()
self.builder.extend(b"> ")
case "strong" | "b":
self.builder.extend(b"**")
case "em" | "i":
self.builder.extend(b"*")
case "del" | "s":
self.builder.extend(b"~~")
case "br":
self.builder.extend(b"\n")
case _:
if tag in {"h1", "h2", "h3", "h4", "h5", "h6"}:
level = int(tag[1])
self.builder.extend(("\n" + "#" * level + " ").encode('utf-8'))
@override
def handle_endtag(self, tag: str) -> None:
if self.invisible:
if tag == "p":
self.invisible = False
return
match tag:
case "a":
if "a" in self._tag_stack:
self.handle_a_endtag()
case "code":
if not self.in_pre and self.in_code:
self.builder.extend(b"`")
self.in_code = False
case "pre":
self.builder.extend(b"\n```\n")
self.in_pre = False
case "blockquote":
self.builder.extend(b"\n")
case "strong" | "b":
self.builder.extend(b"**")
case "em" | "i":
self.builder.extend(b"*")
case "del" | "s":
self.builder.extend(b"~~")
case "p":
self.builder.extend(b"\n\n")
case _:
if tag in ["h1", "h2", "h3", "h4", "h5", "h6"]:
self.builder.extend(b'\n')
@override
def handle_data(self, data: str) -> None:
if not self.invisible:
self.builder.extend(data.encode('utf-8'))
def get_result(self) -> tuple[str, list[f.Fragment]]:
if self.builder.endswith(b'\n\n'):
return self.builder[:-2].decode('utf-8'), self.fragments
return self.builder.decode('utf-8'), self.fragments