Viewing File: /home/ubuntu/.local/lib/python3.10/site-packages/gradio/route_utils.py

from __future__ import annotations

import json
from typing import TYPE_CHECKING, Optional, Union

import fastapi
import httpx
from gradio_client.documentation import document, set_documentation_group

from gradio import utils
from gradio.data_classes import PredictBody
from gradio.exceptions import Error
from gradio.helpers import EventData
from gradio.state_holder import SessionState

if TYPE_CHECKING:
    from gradio.blocks import Blocks
    from gradio.routes import App

set_documentation_group("routes")


class Obj:
    """
    Using a class to convert dictionaries into objects. Used by the `Request` class.
    Credit: https://www.geeksforgeeks.org/convert-nested-python-dictionary-to-object/
    """

    def __init__(self, dict_):
        self.__dict__.update(dict_)
        for key, value in dict_.items():
            if isinstance(value, (dict, list)):
                value = Obj(value)
            setattr(self, key, value)

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, item, value):
        self.__dict__[item] = value

    def __iter__(self):
        for key, value in self.__dict__.items():
            if isinstance(value, Obj):
                yield (key, dict(value))
            else:
                yield (key, value)

    def __contains__(self, item) -> bool:
        if item in self.__dict__:
            return True
        for value in self.__dict__.values():
            if isinstance(value, Obj) and item in value:
                return True
        return False

    def keys(self):
        return self.__dict__.keys()

    def values(self):
        return self.__dict__.values()

    def items(self):
        return self.__dict__.items()

    def __str__(self) -> str:
        return str(self.__dict__)

    def __repr__(self) -> str:
        return str(self.__dict__)


@document()
class Request:
    """
    A Gradio request object that can be used to access the request headers, cookies,
    query parameters and other information about the request from within the prediction
    function. The class is a thin wrapper around the fastapi.Request class. Attributes
    of this class include: `headers`, `client`, `query_params`, and `path_params`. If
    auth is enabled, the `username` attribute can be used to get the logged in user.
    Example:
        import gradio as gr
        def echo(text, request: gr.Request):
            if request:
                print("Request headers dictionary:", request.headers)
                print("IP address:", request.client.host)
                print("Query parameters:", dict(request.query_params))
            return text
        io = gr.Interface(echo, "textbox", "textbox").launch()
    Demos: request_ip_headers
    """

    def __init__(
        self,
        request: fastapi.Request | None = None,
        username: str | None = None,
        **kwargs,
    ):
        """
        Can be instantiated with either a fastapi.Request or by manually passing in
        attributes (needed for websocket-based queueing).
        Parameters:
            request: A fastapi.Request
        """
        self.request = request
        self.username = username
        self.kwargs: dict = kwargs

    def dict_to_obj(self, d):
        if isinstance(d, dict):
            return json.loads(json.dumps(d), object_hook=Obj)
        else:
            return d

    def __getattr__(self, name):
        if self.request:
            return self.dict_to_obj(getattr(self.request, name))
        else:
            try:
                obj = self.kwargs[name]
            except KeyError as ke:
                raise AttributeError(
                    f"'Request' object has no attribute '{name}'"
                ) from ke
            return self.dict_to_obj(obj)


class FnIndexInferError(Exception):
    pass


def infer_fn_index(app: App, api_name: str, body: PredictBody) -> int:
    if body.fn_index is None:
        for i, fn in enumerate(app.get_blocks().dependencies):
            if fn["api_name"] == api_name:
                return i

        raise FnIndexInferError(f"Could not infer fn_index for api_name {api_name}.")
    else:
        return body.fn_index


def compile_gr_request(
    app: App,
    body: PredictBody,
    fn_index_inferred: int,
    username: Optional[str],
    request: Optional[fastapi.Request],
):
    # If this fn_index cancels jobs, then the only input we need is the
    # current session hash
    if app.get_blocks().dependencies[fn_index_inferred]["cancels"]:
        body.data = [body.session_hash]
    if body.request:
        if body.batched:
            gr_request = [Request(username=username, **req) for req in body.request]
        else:
            assert isinstance(body.request, dict)
            gr_request = Request(username=username, **body.request)
    else:
        if request is None:
            raise ValueError("request must be provided if body.request is None")
        gr_request = Request(username=username, request=request)

    return gr_request


def restore_session_state(app: App, body: PredictBody):
    fn_index = body.fn_index
    session_hash = getattr(body, "session_hash", None)
    if session_hash is not None:
        session_state = app.state_holder[session_hash]
        # The should_reset set keeps track of the fn_indices
        # that have been cancelled. When a job is cancelled,
        # the /reset route will mark the jobs as having been reset.
        # That way if the cancel job finishes BEFORE the job being cancelled
        # the job being cancelled will not overwrite the state of the iterator.
        if fn_index in app.iterators_to_reset[session_hash]:
            iterators = {}
            app.iterators_to_reset[session_hash].remove(fn_index)
        else:
            iterators = app.iterators[session_hash]
    else:
        session_state = SessionState(app.get_blocks())
        iterators = {}

    return session_state, iterators


def prepare_event_data(
    blocks: Blocks,
    body: PredictBody,
    fn_index_inferred: int,
) -> EventData:
    dependency = blocks.dependencies[fn_index_inferred]
    target = dependency["targets"][0] if len(dependency["targets"]) else None
    event_data = EventData(
        blocks.blocks.get(target[0]) if target else None,
        body.event_data,
    )
    return event_data


async def call_process_api(
    app: App,
    body: PredictBody,
    gr_request: Union[Request, list[Request]],
    fn_index_inferred: int,
):
    session_state, iterators = restore_session_state(app=app, body=body)

    dependency = app.get_blocks().dependencies[fn_index_inferred]
    event_data = prepare_event_data(app.get_blocks(), body, fn_index_inferred)
    event_id = getattr(body, "event_id", None)

    fn_index = body.fn_index
    session_hash = getattr(body, "session_hash", None)
    inputs = body.data

    batch_in_single_out = not body.batched and dependency["batch"]
    if batch_in_single_out:
        inputs = [inputs]

    try:
        with utils.MatplotlibBackendMananger():
            output = await app.get_blocks().process_api(
                fn_index=fn_index_inferred,
                inputs=inputs,
                request=gr_request,
                state=session_state,
                iterators=iterators,
                session_hash=session_hash,
                event_id=event_id,
                event_data=event_data,
                in_event_listener=True,
            )
        iterator = output.pop("iterator", None)
        if hasattr(body, "session_hash"):
            app.iterators[body.session_hash][fn_index] = iterator
        if isinstance(output, Error):
            raise output
    except BaseException:
        iterator = iterators.get(fn_index, None)
        if iterator is not None:  # close off any streams that are still open
            run_id = id(iterator)
            pending_streams: dict[int, list] = (
                app.get_blocks().pending_streams[session_hash].get(run_id, {})
            )
            for stream in pending_streams.values():
                stream.append(None)
        raise

    if batch_in_single_out:
        output["data"] = output["data"][0]

    return output


def strip_url(orig_url: str) -> str:
    """
    Strips the query parameters and trailing slash from a URL.
    """
    parsed_url = httpx.URL(orig_url)
    stripped_url = parsed_url.copy_with(query=None)
    stripped_url = str(stripped_url)
    return stripped_url.rstrip("/")
Back to Directory File Manager