Viewing File: /home/ubuntu/.local/lib/python3.10/site-packages/whisper_timestamped/make_subtitles.py

#!/usr/bin/env python3

import json
import string

# Characters after which a long segment is preferably cut.
# "-" and "'" are excluded because they usually occur inside words.
# Full-width CJK punctuation is listed explicitly (string.punctuation is ASCII
# only) so that languages written without spaces can also be cut at commas,
# question marks, etc.
_punctuation = "".join(c for c in string.punctuation if c not in ["-", "'"]) + "。，！？：；”、…"

def split_long_segments(segments, max_length, use_space = True):
    """Split transcription segments whose text is longer than ``max_length`` characters.

    Segments whose ``"text"`` fits are passed through unchanged (same object).
    Longer segments are rebuilt word by word from their ``"words"`` list; when
    adding a word would exceed ``max_length``, the accumulated text is cut,
    preferably just after the last punctuation mark seen, otherwise just
    before the current word.  A single word longer than ``max_length`` is
    never cut.  Note that after a punctuation cut the remainder is not
    re-checked, so it may still exceed ``max_length``.

    Args:
        segments: list of dicts with ``"text"``, ``"start"``, ``"end"`` and,
            for segments that need splitting, ``"words"`` (a list of dicts
            with ``"text"``, ``"start"``, ``"end"``).
        max_length: maximum number of characters per output segment.
        use_space: True if words are separated by spaces in the text; False
            for languages written without spaces (words are concatenated).

    Returns:
        A new list of segments.  Pieces produced by splitting only contain
        the ``"text"``, ``"start"`` and ``"end"`` keys.
    """
    # Separator inserted between consecutive words when rebuilding the text.
    separator = " " if use_space else ""
    new_segments = []
    for segment in segments:
        text = segment["text"]
        if len(text) <= max_length:
            new_segments.append(segment)
            continue

        meta_words = segment["words"]
        # Note: we do this in case punctuation were removed from words
        if use_space:
            # Split text around spaces (punctuation stays attached to words)
            words = text.split()
        else:
            words = [w["text"] for w in meta_words]
        if len(words) != len(meta_words):
            new_words = [w["text"] for w in meta_words]
            print(f"WARNING: {' '.join(words)} != {' '.join(new_words)}")
            words = new_words

        current_text = ""
        current_start = segment["start"]
        # Best cut point found so far: index just after a punctuation mark in
        # current_text, the end time of that word and the start of the next one.
        current_best_idx = None
        current_best_end = None
        current_best_next_start = None
        for i, (word, meta) in enumerate(zip(words, meta_words)):
            current_text_before = current_text
            if current_text:
                current_text += separator
            current_text += word

            if len(current_text) > max_length and current_text_before:
                start = current_start
                if current_best_idx is not None:
                    # Cut after the punctuation and skip the separator that
                    # follows it (no character is skipped when separator is "").
                    chunk_text = current_text[:current_best_idx]
                    end = current_best_end
                    current_text = current_text[current_best_idx + len(separator):]
                    current_start = current_best_next_start
                else:
                    # No punctuation available: cut just before the current word.
                    chunk_text = current_text_before
                    end = meta_words[i - 1]["end"]
                    current_text = word
                    current_start = meta["start"]

                current_best_idx = None
                current_best_end = None
                current_best_next_start = None

                new_segments.append({"text": chunk_text, "start": start, "end": end})

            # Try to cut after punctuation
            if current_text and current_text[-1] in _punctuation:
                current_best_idx = len(current_text)
                current_best_end = meta["end"]
                current_best_next_start = meta_words[i + 1]["start"] if i + 1 < len(meta_words) else None

        if current_text:
            new_segments.append({"text": current_text, "start": current_start, "end": segment["end"]})

    return new_segments

def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
    """Render a duration in seconds as ``[HH:]MM:SS<marker>mmm``.

    The value is rounded to the nearest millisecond.  The hours field is
    emitted when ``always_include_hours`` is set or the duration reaches one
    hour.  ``decimal_marker`` separates seconds from milliseconds ("." for
    VTT, "," for SRT).  A negative input trips an assertion.
    """
    assert seconds >= 0, "non-negative timestamp expected"
    total_ms = round(seconds * 1000.0)

    # Peel off each unit from the largest to the smallest.
    hrs, remainder = divmod(total_ms, 3_600_000)
    mins, remainder = divmod(remainder, 60_000)
    secs, ms = divmod(remainder, 1_000)

    prefix = f"{hrs:02d}:" if (always_include_hours or hrs > 0) else ""
    return f"{prefix}{mins:02d}:{secs:02d}{decimal_marker}{ms:03d}"

def write_vtt(result, file):
    """Write ``result`` (segments with "start", "end", "text") as WebVTT to ``file``.

    Each cue's text is stripped, and any "-->" inside it is replaced by "->"
    so that it cannot be mistaken for a timing line.
    """
    print("WEBVTT\n", file=file)
    for seg in result:
        timing = f"{format_timestamp(seg['start'])} --> {format_timestamp(seg['end'])}"
        caption = seg['text'].strip().replace('-->', '->')
        print(f"{timing}\n{caption}\n", file=file, flush=True)

def write_srt(result, file):
    """Write ``result`` (segments with "start", "end", "text") as SRT to ``file``.

    Cues are numbered from 1; timestamps always include hours and use ","
    before the milliseconds.  Any "-->" in the text is replaced by "->".
    """
    for index, seg in enumerate(result, start=1):
        begin = format_timestamp(seg['start'], always_include_hours=True, decimal_marker=',')
        finish = format_timestamp(seg['end'], always_include_hours=True, decimal_marker=',')
        caption = seg['text'].strip().replace('-->', '->')
        print(f"{index}\n{begin} --> {finish}\n{caption}\n", file=file, flush=True)

def cli():
    """Command-line entry point: convert ``.words.json`` files to SRT and/or VTT.

    ``input`` is either a single JSON file or a folder (every ``*.words.json``
    file inside it is converted).  ``output`` is either a single ``.srt`` /
    ``.vtt`` file or a folder that receives one file per input and per
    requested format.  Segments longer than ``--max_length`` characters are
    split (0 disables splitting).
    """

    import os
    import argparse

    supported_formats = ["srt", "vtt"]
    input_suffix = ".words.json"
    # Languages written without spaces between words.
    no_space_languages = ["zh", "ja", "th", "lo", "my"]

    def _stem(filename):
        # Strip the ".words.json" suffix; for other names drop the plain
        # extension instead of blindly cutting 11 characters.
        if filename.endswith(input_suffix):
            return filename[:-len(input_suffix)]
        return os.path.splitext(filename)[0]

    parser = argparse.ArgumentParser(
        description='Convert .words.json transcription files (output of whisper_timestamped) to srt or vtt, being able to cut long segments',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('input', type=str, help='Input json file, or input folder')
    parser.add_argument('output', type=str, help='Output srt or vtt file, or output folder')
    parser.add_argument('--max_length', default=200, help='Maximum length of a segment in characters', type=int)
    parser.add_argument('--format', type=str, default="all", help='Output format (if the output is a folder, i.e. not a file with an explicit extension)', choices= supported_formats + ["all"])
    args = parser.parse_args()

    # Use the same ".ext" test as the writer dispatch below.
    output_is_file = any(args.output.endswith("." + e) for e in supported_formats)
    if os.path.isdir(args.input) or not output_is_file:
        if os.path.isdir(args.input):
            names = [f for f in os.listdir(args.input) if f.endswith(input_suffix)]
            input_files = [os.path.join(args.input, f) for f in names]
        else:
            names = [os.path.basename(args.input)]
            input_files = [args.input]
        extensions = [args.format] if args.format != "all" else list(supported_formats)
        output_files = [[os.path.join(args.output, _stem(f) + "." + e) for e in extensions] for f in names]
        os.makedirs(args.output, exist_ok=True)
    else:
        input_files = [args.input]
        output_files = [[args.output]]
        output_dir = os.path.dirname(args.output)
        # dirname is "" for a bare file name (current directory): nothing to create.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    for fn, outputs in zip(input_files, output_files):
        with open(fn, "r", encoding="utf-8") as f:
            transcript = json.load(f)
        segments = transcript["segments"]
        if args.max_length:
            use_space = transcript.get("language") not in no_space_languages
            segments = split_long_segments(segments, args.max_length, use_space=use_space)
        for output in outputs:
            if output.endswith(".srt"):
                writer = write_srt
            elif output.endswith(".vtt"):
                writer = write_vtt
            else:
                raise RuntimeError(f"Unknown output format for {output}")
            with open(output, "w", encoding="utf-8") as f:
                writer(segments, file=f)

# Run the command-line interface when this module is executed as a script.
if __name__ == "__main__":
    cli()
Back to Directory File Manager