from html.parser import HTMLParser
from typing import override
from cross.tokens import LinkToken, TextToken, Token
from util.splitter import canonical_label
class HTMLToTokensParser(HTMLParser):
def __init__(self) -> None:
super().__init__()
self.tokens: list[Token] = []
self._tag_stack: dict[str, tuple[str, dict[str, str | None]]] = {}
self.in_pre: bool = False
self.in_code: bool = False
self.invisible: bool = False
def handle_a_endtag(self):
label, _attr = self._tag_stack.pop("a")
href = _attr.get("href")
if href:
if canonical_label(label, href):
self.tokens.append(LinkToken(href=href))
else:
self.tokens.append(LinkToken(href=href, label=label))
def append_text(self, text: str):
self.tokens.append(TextToken(text=text))
def append_newline(self):
if self.tokens:
last_token = self.tokens[-1]
if isinstance(last_token, TextToken) and not last_token.text.endswith("\n"):
self.tokens.append(TextToken(text="\n"))
@override
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
_attr = dict(attrs)
if self.invisible:
return
match tag:
case "p":
cls = _attr.get("class", "")
if cls and "quote-inline" in cls:
self.invisible = True
case "a":
self._tag_stack["a"] = ("", _attr)
case "code":
if not self.in_pre:
self.append_text("`")
self.in_code = True
case "pre":
self.append_newline()
self.append_text("```\n")
self.in_pre = True
case "blockquote":
self.append_newline()
self.append_text("> ")
case "strong" | "b":
self.append_text("**")
case "em" | "i":
self.append_text("*")
case "del" | "s":
self.append_text("~~")
case "br":
self.append_text("\n")
case "h1" | "h2" | "h3" | "h4" | "h5" | "h6":
level = int(tag[1])
self.append_text("\n" + "#" * level + " ")
case _:
# self.builder.extend(f"<{tag}>".encode("utf-8"))
pass
@override
def handle_endtag(self, tag: str) -> None:
if self.invisible:
if tag == "p":
self.invisible = False
return
match tag:
case "a":
if "a" in self._tag_stack:
self.handle_a_endtag()
case "code":
if not self.in_pre and self.in_code:
self.append_text("`")
self.in_code = False
case "pre":
self.append_newline()
self.append_text("```\n")
self.in_pre = False
case "blockquote":
self.append_text("\n")
case "strong" | "b":
self.append_text("**")
case "em" | "i":
self.append_text("*")
case "del" | "s":
self.append_text("~~")
case "p":
self.append_text("\n\n")
case "h1" | "h2" | "h3" | "h4" | "h5" | "h6":
self.append_text("\n")
case _:
# self.builder.extend(f"{tag}>".encode("utf-8"))
pass
@override
def handle_data(self, data: str) -> None:
if self.invisible:
return
if self._tag_stack.get('a'):
label, _attr = self._tag_stack.pop("a")
self._tag_stack["a"] = (label + data, _attr)
return
def get_result(self) -> list[Token]:
if not self.tokens:
return []
combined: list[Token] = []
buffer: list[str] = []
def flush_buffer():
if buffer:
merged = "".join(buffer)
combined.append(TextToken(text=merged))
buffer.clear()
for token in self.tokens:
if isinstance(token, TextToken):
buffer.append(token.text)
else:
flush_buffer()
combined.append(token)
flush_buffer()
if combined and isinstance(combined[-1], TextToken):
if combined[-1].text.endswith("\n\n"):
combined[-1] = TextToken(text=combined[-1].text[:-2])
if combined[-1].text.endswith("\n"):
combined[-1] = TextToken(text=combined[-1].text[:-1])
return combined