文件预览

prepare_youtube_transcript.py

查看 YouTube Chinese Subtitle Burn-in 技能包中的文件内容。

文件内容

scripts/prepare_youtube_transcript.py

#!/usr/bin/env python3
from __future__ import annotations

import argparse
import html
import json
import re
import sys
from pathlib import Path


# Matches one WebVTT cue. Group 1 is the start timestamp and group 2 the end
# timestamp (both hh:mm:ss.mmm); group 3 is the cue text. Whatever follows the
# end timestamp on the timing line is skipped, and the text runs up to the next
# blank line or the end of the input. re.S lets "." span newlines in the text.
CUE_RE = re.compile(
    r"(\d\d:\d\d:\d\d\.\d{3}) --> (\d\d:\d\d:\d\d\.\d{3}).*?\n(.*?)(?=\n\n|\Z)",
    re.S,
)
# Same shape as CUE_RE, but for SRT timestamps, which put a comma before the
# milliseconds (hh:mm:ss,mmm).
SRT_RE = re.compile(
    r"(\d\d:\d\d:\d\d,\d{3}) --> (\d\d:\d\d:\d\d,\d{3}).*?\n(.*?)(?=\n\n|\Z)",
    re.S,
)
# Any "<...>" span; clean_text() removes these from cue text.
HTML_TAG = re.compile(r"<[^>]+>")
# A token is either an ASCII alphanumeric run, optionally continued by one of
# ' . - followed by more alphanumerics (e.g. "don't", "3.5", "well-known"), or
# a single character that is neither whitespace nor ASCII alphanumeric.
WORD_RE = re.compile(r"[A-Za-z0-9]+(?:['.-][A-Za-z0-9]+)*|[^\sA-Za-z0-9]")


def parse_vtt_time(value: str) -> float:
    """Convert a WebVTT timestamp to seconds.

    Accepts ``hh:mm:ss.mmm`` as well as the short ``mm:ss.mmm`` form, in
    which the hours field is omitted.

    Args:
        value: Timestamp text such as ``"00:01:02.500"`` or ``"01:02.500"``.

    Returns:
        The offset in seconds.

    Raises:
        ValueError: If the text does not have two or three ``:``-separated
            fields, the last field is not ``seconds.millis``, or any field
            is not an integer.
    """
    fields = value.split(":")
    if len(fields) == 2:
        # Hours may be left out in WebVTT; treat a two-field value as 0 hours.
        fields.insert(0, "0")
    hh, mm, rest = fields
    ss, ms = rest.split(".")
    return int(hh) * 3600 + int(mm) * 60 + int(ss) + int(ms) / 1000


def parse_srt_time(value: str) -> float:
    """Convert an SRT timestamp (``hh:mm:ss,mmm``) to seconds.

    Raises:
        ValueError: If the text is not three ``:``-separated fields whose
            last field is ``seconds,millis``, or a field is not an integer.
    """
    hours, minutes, tail = value.split(":")
    seconds, millis = tail.split(",")
    whole_seconds = (int(hours) * 60 + int(minutes)) * 60 + int(seconds)
    return whole_seconds + int(millis) / 1000


def format_srt_time(seconds: float) -> str:
    """Render a number of seconds as an SRT timestamp ``hh:mm:ss,mmm``.

    The value is rounded to whole milliseconds; negative values are clamped
    to zero. Hours are zero-padded to at least two digits.
    """
    millis_total = int(round(seconds * 1000))
    if millis_total < 0:
        millis_total = 0
    secs_total, ms = divmod(millis_total, 1000)
    mins_total, ss = divmod(secs_total, 60)
    hh, mm = divmod(mins_total, 60)
    return f"{hh:02d}:{mm:02d}:{ss:02d},{ms:03d}"


def clean_text(value: str) -> str:
    """Flatten raw cue text into a single line of plain text.

    Line breaks become spaces, ``<...>`` tags are removed, HTML entities are
    decoded, and every whitespace run collapses to one space with the ends
    trimmed.
    """
    single_line = " ".join(value.splitlines())
    # Tags are removed before entities are decoded, so an escaped "&lt;"
    # survives as a literal "<" instead of being treated as markup.
    decoded = html.unescape(HTML_TAG.sub("", single_line))
    collapsed = re.sub(r"\s+", " ", decoded)
    return collapsed.strip()


def stitch_rolling_vtt(path: Path) -> list[tuple[float, str]]:
    """Flatten rolling WebVTT captions into a de-duplicated timed word list.

    Each cue's cleaned text is split on whitespace. The longest run of words
    at the start of the cue that repeats the words already collected
    (compared case-insensitively) is skipped, and only the remaining words
    are added.

    Args:
        path: WebVTT file, read as UTF-8 (a leading BOM is tolerated).

    Returns:
        ``(start_seconds, word)`` pairs in order of appearance; every word
        carries the start time of the cue that first introduced it.
    """
    # Lower-cased copy of every word collected so far, grown incrementally so
    # the whole history is not re-lowercased for each cue.
    seen_lower: list[str] = []
    timed: list[tuple[float, str]] = []
    data = path.read_text(encoding="utf-8-sig")
    for match in CUE_RE.finditer(data):
        start = parse_vtt_time(match.group(1))
        cue_words = clean_text(match.group(3)).split()
        if not cue_words:
            continue
        lower_cue = [word.lower() for word in cue_words]
        max_overlap = min(len(seen_lower), len(lower_cue))
        overlap = 0
        # Try the longest candidate first so the largest repeated run wins.
        for size in range(max_overlap, 0, -1):
            if seen_lower[-size:] == lower_cue[:size]:
                overlap = size
                break
        seen_lower.extend(lower_cue[overlap:])
        timed.extend((start, word) for word in cue_words[overlap:])
    return timed


def tokenize(text: str) -> list[str]:
    """Return the ``WORD_RE`` matches found in *text*, in order."""
    return [found.group(0) for found in WORD_RE.finditer(text)]


def untokenize(parts: list[str]) -> str:
    """Join tokens back into text with punctuation-aware spacing.

    A token that is exactly one of ``, . ! ? ; : % )`` is attached to the
    text before it, and any token following text that ends in ``(`` or ``$``
    is attached without a space. Everything else is separated by a single
    space. The result is stripped of leading and trailing whitespace.
    """
    text = ""
    for token in parts:
        if not text:
            text = token
            continue
        attaches_left = re.fullmatch(r"[,.!?;:%)]", token) is not None
        attaches_right = text.endswith(("(", "$"))
        if attaches_left or attaches_right:
            text += token
        else:
            text += " " + token
    return text.strip()


def overlap_prefix(emitted: list[str], current: list[str], max_lookback: int = 80) -> int:
    """Return how many leading tokens of *current* repeat the end of *emitted*.

    The comparison is case-insensitive and the longest match wins.

    Args:
        emitted: Tokens collected so far.
        current: Tokens of the cue being added.
        max_lookback: Upper bound on the overlap length that is considered.

    Returns:
        The size of the longest suffix of *emitted* that equals a prefix of
        *current*, or 0 when there is none.
    """
    limit = min(len(emitted), len(current), max_lookback)
    if limit <= 0:
        return 0
    # Only the last `limit` emitted tokens and the first `limit` current
    # tokens can take part in a match, so lower-case just those slices
    # instead of the whole (ever-growing) history on every call.
    tail = [part.lower() for part in emitted[-limit:]]
    head = [part.lower() for part in current[:limit]]
    for size in range(limit, 0, -1):
        if tail[-size:] == head[:size]:
            return size
    return 0


def stitch_rolling_srt(path: Path) -> list[tuple[float, str]]:
    """Flatten rolling SRT captions into a de-duplicated timed word list.

    Each cue's cleaned text is tokenized, the tokens that repeat the end of
    what was already collected are skipped (see ``overlap_prefix``), and the
    rest are added. Punctuation tokens are glued onto the neighbouring word
    using the spacing rules of ``untokenize``, so the returned words read
    ``"Hello,"`` rather than ``"Hello"`` followed by a separate ``","``.

    Args:
        path: SRT file, read as UTF-8 (a leading BOM is tolerated).

    Returns:
        ``(start_seconds, word)`` pairs in order of appearance; every word
        carries the start time of the cue that first introduced it.
    """
    # Raw tokens, kept unmerged because overlap detection works per token.
    emitted: list[str] = []
    timed: list[tuple[float, str]] = []
    data = path.read_text(encoding="utf-8-sig")
    for match in SRT_RE.finditer(data):
        start = parse_srt_time(match.group(1))
        cue_tokens = tokenize(clean_text(match.group(3)))
        if not cue_tokens:
            continue
        overlap = overlap_prefix(emitted, cue_tokens)
        for token in cue_tokens[overlap:]:
            emitted.append(token)
            if timed:
                prev_start, prev_word = timed[-1]
                joined = untokenize([prev_word, token])
                # Tokens never contain whitespace, so a space in the joined
                # text means untokenize kept them apart; no space means the
                # token belongs to the previous word.
                if " " not in joined:
                    timed[-1] = (prev_start, joined)
                    continue
            timed.append((start, token))
    return timed


def build_cues(
    timed: list[tuple[float, str]],
    *,
    min_sentence_words: int = 8,
    max_words: int = 13,
    max_duration: float = 5.5,
    max_gap: float = 1.25,
) -> list[dict[str, object]]:
    """Group timed words into subtitle cues.

    Words are accumulated until one of the break conditions holds for the
    word just added, at which point a cue is emitted.

    Args:
        timed: ``(start_seconds, word)`` pairs in display order.
        min_sentence_words: Break after a word ending in ``.``, ``!`` or
            ``?`` once the cue holds at least this many words.
        max_words: Break once the cue holds this many words.
        max_duration: Break once the span from the cue's first word to the
            current word reaches this many seconds.
        max_gap: Break when the next word starts more than this many
            seconds after the current one.

    Returns:
        A list of dicts with keys ``id`` (1-based), ``start``, ``end``
        (seconds) and ``en`` (the cue's words joined by single spaces).
    """
    cues: list[dict[str, object]] = []
    current: list[str] = []
    start = None
    total = len(timed)
    for index, (timestamp, word) in enumerate(timed):
        if start is None:
            start = timestamp
        current.append(word)
        is_last = index + 1 == total
        next_start = None if is_last else timed[index + 1][0]
        ends_sentence = re.search(r"[.!?]$", word) is not None
        # The final word always closes its cue, so nothing is left to flush
        # once the loop finishes.
        should_break = (
            is_last
            or (len(current) >= min_sentence_words and ends_sentence)
            or len(current) >= max_words
            or timestamp - start >= max_duration
            or next_start - timestamp > max_gap
        )
        if not should_break:
            continue
        # End just before the next word appears; the last cue gets a fixed
        # 2.5 s tail because there is no next word to bound it.
        end = timestamp + 2.5 if is_last else next_start - 0.05
        if end <= start:
            # Happens when the next word shares (or precedes) this cue's start
            # time; fall back to a fixed 2 s hold.
            # NOTE(review): such a cue can overlap the one that follows it —
            # confirm whether that is acceptable downstream.
            end = timestamp + 2
        cues.append({"id": len(cues) + 1, "start": start, "end": end, "en": " ".join(current)})
        current = []
        start = None
    return cues


def write_srt(cues: list[dict[str, object]], path: Path, text_key: str) -> None:
    """Write *cues* to *path* as an SRT file encoded in UTF-8.

    Each cue contributes its number (the cue's ``id`` if present, otherwise
    its 1-based position), a ``start --> end`` timing line, the text stored
    under *text_key*, and a blank separator line.
    """
    rows: list[str] = []
    for position, cue in enumerate(cues, start=1):
        begin = format_srt_time(float(cue["start"]))
        finish = format_srt_time(float(cue["end"]))
        rows.extend(
            (
                str(cue.get("id", position)),
                f"{begin} --> {finish}",
                str(cue[text_key]),
                "",
            )
        )
    path.write_text("\n".join(rows), encoding="utf-8")


def main() -> int:
    """Command-line entry point.

    Reads the subtitle file named on the command line, stitches it into
    timed words, groups those into cues, and writes the cues as JSON (and
    optionally as SRT). A one-line PASS/FAIL status is printed.

    Returns:
        0 on success; 1 if the input is missing, has an unsupported
        extension, or yields no timed words.
    """
    parser = argparse.ArgumentParser(description="Clean YouTube rolling VTT/SRT captions into semantic cue JSON/SRT.")
    parser.add_argument("subtitle", type=Path)
    parser.add_argument("--json-out", type=Path, required=True)
    parser.add_argument("--srt-out", type=Path)
    args = parser.parse_args()

    source = args.subtitle
    if not source.exists():
        print(f"FAIL: missing subtitle {source}")
        return 1

    # Pick the stitcher from the (case-insensitive) file extension.
    stitchers = {".vtt": stitch_rolling_vtt, ".srt": stitch_rolling_srt}
    stitch = stitchers.get(source.suffix.lower())
    if stitch is None:
        print(f"FAIL: unsupported subtitle format {source.suffix}")
        return 1

    timed = stitch(source)
    if not timed:
        print(f"FAIL: no timed words found in {source}")
        return 1

    cues = build_cues(timed)
    json_target = args.json_out
    json_target.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(cues, ensure_ascii=False, indent=2)
    json_target.write_text(payload, encoding="utf-8")

    srt_target = args.srt_out
    if srt_target:
        srt_target.parent.mkdir(parents=True, exist_ok=True)
        write_srt(cues, srt_target, "en")

    print(f"PASS: wrote {len(cues)} cleaned cues")
    return 0


# Run the CLI only when this file is executed as a script; the value returned
# by main() becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())