# Viewing File: /home/ubuntu/.local/lib/python3.10/site-packages/gradio/pipelines.py

"""This module should not be used directly as its API is subject to change. Instead,
please use the `gr.Interface.from_pipeline()` function."""

from __future__ import annotations

from typing import TYPE_CHECKING

from gradio import components

if TYPE_CHECKING:  # Only import for type checking (is False at runtime).
    from transformers import pipelines


def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
    """
    Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline.
    Parameters:
        pipeline (transformers.Pipeline): the transformers.Pipeline from which to create an interface
    Returns:
        (dict): a dictionary of kwargs that can be used to construct an Interface object. It holds
        the keys "inputs", "outputs", "fn" and "title".
    Raises:
        ImportError: if the transformers library is not installed
        ValueError: if pipeline is not a transformers.Pipeline or its pipeline type is not supported
    """
    try:
        import transformers
        from transformers import pipelines
    except ImportError as ie:
        raise ImportError(
            "transformers not installed. Please try `pip install transformers`"
        ) from ie
    if not isinstance(pipeline, pipelines.base.Pipeline):
        raise ValueError("pipeline must be a transformers.Pipeline")

    # Handle the different pipelines. The hasattr() checks make sure the pipeline class exists in
    # the version of the transformers library that the user has installed.
    #
    # Every branch builds the same four entries:
    #   "inputs" / "outputs": the (unrendered) components shown in the Interface
    #   "preprocess": maps the component values to the arguments of the pipeline call; a dict is
    #                 passed as keyword arguments, any other sequence as positional arguments
    #   "postprocess": maps the raw pipeline result to the value(s) of the output component(s)
    if hasattr(transformers, "AudioClassificationPipeline") and isinstance(
        pipeline, pipelines.audio_classification.AudioClassificationPipeline
    ):
        pipeline_info = {
            "inputs": components.Audio(
                source="microphone",
                type="filepath",
                label="Input",
                render=False,
            ),
            "outputs": components.Label(label="Class", render=False),
            "preprocess": lambda i: {"inputs": i},
            # only the text before the first ", " of each label is used as the class name
            "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
        }
    elif hasattr(transformers, "AutomaticSpeechRecognitionPipeline") and isinstance(
        pipeline,
        pipelines.automatic_speech_recognition.AutomaticSpeechRecognitionPipeline,
    ):
        pipeline_info = {
            "inputs": components.Audio(
                source="microphone", type="filepath", label="Input", render=False
            ),
            "outputs": components.Textbox(label="Output", render=False),
            "preprocess": lambda i: {"inputs": i},
            "postprocess": lambda r: r["text"],
        }
    elif hasattr(transformers, "FeatureExtractionPipeline") and isinstance(
        pipeline, pipelines.feature_extraction.FeatureExtractionPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input", render=False),
            "outputs": components.Dataframe(label="Output", render=False),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: r[0],
        }
    elif hasattr(transformers, "FillMaskPipeline") and isinstance(
        pipeline, pipelines.fill_mask.FillMaskPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input", render=False),
            "outputs": components.Label(label="Classification", render=False),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: {i["token_str"]: i["score"] for i in r},
        }
    elif hasattr(transformers, "ImageClassificationPipeline") and isinstance(
        pipeline, pipelines.image_classification.ImageClassificationPipeline
    ):
        pipeline_info = {
            "inputs": components.Image(
                type="filepath", label="Input Image", render=False
            ),
            "outputs": components.Label(label="Classification", render=False),
            "preprocess": lambda i: {"images": i},
            "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
        }
    elif hasattr(transformers, "QuestionAnsweringPipeline") and isinstance(
        pipeline, pipelines.question_answering.QuestionAnsweringPipeline
    ):
        pipeline_info = {
            "inputs": [
                components.Textbox(lines=7, label="Context", render=False),
                components.Textbox(label="Question", render=False),
            ],
            "outputs": [
                components.Textbox(label="Answer", render=False),
                components.Label(label="Score", render=False),
            ],
            "preprocess": lambda c, q: {"context": c, "question": q},
            "postprocess": lambda r: (r["answer"], r["score"]),
        }
    # SummarizationPipeline must be tested before Text2TextGenerationPipeline further down.
    elif hasattr(transformers, "SummarizationPipeline") and isinstance(
        pipeline, pipelines.text2text_generation.SummarizationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(lines=7, label="Input", render=False),
            "outputs": components.Textbox(label="Summary", render=False),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: r[0]["summary_text"],
        }
    elif hasattr(transformers, "TextClassificationPipeline") and isinstance(
        pipeline, pipelines.text_classification.TextClassificationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input", render=False),
            "outputs": components.Label(label="Classification", render=False),
            "preprocess": lambda x: [x],
            "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
        }
    elif hasattr(transformers, "TextGenerationPipeline") and isinstance(
        pipeline, pipelines.text_generation.TextGenerationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input", render=False),
            "outputs": components.Textbox(label="Output", render=False),
            "preprocess": lambda x: {"text_inputs": x},
            "postprocess": lambda r: r[0]["generated_text"],
        }
    # TranslationPipeline must be tested before Text2TextGenerationPipeline just below.
    elif hasattr(transformers, "TranslationPipeline") and isinstance(
        pipeline, pipelines.text2text_generation.TranslationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input", render=False),
            "outputs": components.Textbox(label="Translation", render=False),
            "preprocess": lambda x: [x],
            "postprocess": lambda r: r[0]["translation_text"],
        }
    elif hasattr(transformers, "Text2TextGenerationPipeline") and isinstance(
        pipeline, pipelines.text2text_generation.Text2TextGenerationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input", render=False),
            "outputs": components.Textbox(label="Generated Text", render=False),
            "preprocess": lambda x: [x],
            "postprocess": lambda r: r[0]["generated_text"],
        }
    elif hasattr(transformers, "ZeroShotClassificationPipeline") and isinstance(
        pipeline, pipelines.zero_shot_classification.ZeroShotClassificationPipeline
    ):
        pipeline_info = {
            "inputs": [
                components.Textbox(label="Input", render=False),
                components.Textbox(
                    label="Possible class names (comma-separated)", render=False
                ),
                components.Checkbox(label="Allow multiple true classes", render=False),
            ],
            "outputs": components.Label(label="Classification", render=False),
            "preprocess": lambda i, c, m: {
                "sequences": i,
                "candidate_labels": c,
                "multi_label": m,
            },
            # pair every returned label with the score at the same position
            "postprocess": lambda r: dict(zip(r["labels"], r["scores"])),
        }
    elif hasattr(transformers, "DocumentQuestionAnsweringPipeline") and isinstance(
        pipeline,
        pipelines.document_question_answering.DocumentQuestionAnsweringPipeline,  # type: ignore
    ):
        pipeline_info = {
            "inputs": [
                components.Image(type="filepath", label="Input Document", render=False),
                components.Textbox(label="Question", render=False),
            ],
            "outputs": components.Label(label="Label", render=False),
            "preprocess": lambda img, q: {"image": img, "question": q},
            "postprocess": lambda r: {i["answer"]: i["score"] for i in r},
        }
    elif hasattr(transformers, "VisualQuestionAnsweringPipeline") and isinstance(
        pipeline, pipelines.visual_question_answering.VisualQuestionAnsweringPipeline
    ):
        pipeline_info = {
            "inputs": [
                components.Image(type="filepath", label="Input Image", render=False),
                components.Textbox(label="Question", render=False),
            ],
            "outputs": components.Label(label="Score", render=False),
            "preprocess": lambda img, q: {"image": img, "question": q},
            "postprocess": lambda r: {i["answer"]: i["score"] for i in r},
        }
    elif hasattr(transformers, "ImageToTextPipeline") and isinstance(
        pipeline, pipelines.image_to_text.ImageToTextPipeline  # type: ignore
    ):
        pipeline_info = {
            "inputs": components.Image(
                type="filepath", label="Input Image", render=False
            ),
            "outputs": components.Textbox(label="Text", render=False),
            "preprocess": lambda i: {"images": i},
            "postprocess": lambda r: r[0]["generated_text"],
        }
    else:
        raise ValueError(f"Unsupported pipeline type: {type(pipeline)}")

    preprocess = pipeline_info["preprocess"]
    postprocess = pipeline_info["postprocess"]

    # define the function that will be called by the Interface
    def fn(*params):
        data = preprocess(*params)
        # The calling convention is decided by what preprocess produced: a dict holds keyword
        # arguments, anything else holds positional arguments. An isinstance() check against
        # Text2TextGenerationPipeline is not a reliable substitute, because it also matches its
        # subclass SummarizationPipeline, whose preprocess returns a dict; unpacking that dict
        # with * would pass the key name "inputs" to the pipeline instead of the user's text.
        if isinstance(data, dict):
            result = pipeline(**data)
        else:
            result = pipeline(*data)
        return postprocess(result)

    # Only the Interface kwargs are returned; preprocess/postprocess stay private to fn.
    interface_info = {
        "inputs": pipeline_info["inputs"],
        "outputs": pipeline_info["outputs"],
        "fn": fn,
        # define the title/description of the Interface
        "title": pipeline.model.__class__.__name__,
    }

    return interface_info
# Back to Directory File Manager