from html.parser import HTMLParser
from typing import override
import cross.fragments as f
class HTMLToFragmentsParser(HTMLParser):
def __init__(self) -> None:
super().__init__()
self.text: str = ""
self.fragments: list[f.Fragment] = []
self._tag_stack: dict[str, tuple[int, dict[str, str | None]]] = {}
self.in_pre: bool = False
self.in_code: bool = False
self.invisible: bool = False
def handle_a_endtag(self):
current_end = len(self.text)
start, _attr = self._tag_stack.pop("a")
href = _attr.get('href')
if href and current_end > start:
self.fragments.append(
f.LinkFragment(start=start, end=current_end, url=href)
)
@override
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
_attr = dict(attrs)
def append_newline():
if self.text and not self.text.endswith("\n"):
self.text += "\n"
if self.invisible:
return
match tag:
case "p":
cls = _attr.get('class', '')
if cls and 'quote-inline' in cls:
self.invisible = True
case "a":
self._tag_stack["a"] = (len(self.text), _attr)
case "code":
if not self.in_pre:
self.text += "`"
self.in_code = True
case "pre":
append_newline()
self.text += "```\n"
self.in_pre = True
case "blockquote":
append_newline()
self.text += "> "
case "strong" | "b":
self.text += "**"
case "em" | "i":
self.text += "*"
case "del" | "s":
self.text += "~~"
case "br":
self.text += "\n"
case _:
if tag in {"h1", "h2", "h3", "h4", "h5", "h6"}:
level = int(tag[1])
self.text += "\n" + "#" * level + " "
@override
def handle_endtag(self, tag: str) -> None:
if self.invisible:
if tag == "p":
self.invisible = False
return
match tag:
case "a":
if "a" in self._tag_stack:
self.handle_a_endtag()
case "code":
if not self.in_pre and self.in_code:
self.text += "`"
self.in_code = False
case "pre":
self.text += "\n```\n"
self.in_pre = False
case "blockquote":
self.text += "\n"
case "strong" | "b":
self.text += "**"
case "em" | "i":
self.text += "*"
case "del" | "s":
self.text += "~~"
case "p":
self.text += "\n\n"
case _:
if tag in ["h1", "h2", "h3", "h4", "h5", "h6"]:
self.text += '\n'
@override
def handle_data(self, data: str) -> None:
if not self.invisible:
self.text += data
def get_result(self) -> tuple[str, list[f.Fragment]]:
if self.text.endswith('\n\n'):
return self.text[:-2], self.fragments
return self.text, self.fragments