import logging
import os
import shutil
import tarfile
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Tuple
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile

from .cache_file import CacheFile
from .common import PathOrStr, get_cache_dir
from .file_lock import FileLock
from .meta import Meta
from .schemes import (
    SchemeClient,
    get_scheme_client,
    get_supported_schemes,
    hf_get_from_cache,
)
from .util import (
    _lock_file_path,
    _meta_file_path,
    check_tarfile,
    find_latest_cached,
    resource_to_filename,
)

if TYPE_CHECKING:
    from rich.progress import Progress

logger = logging.getLogger("cached_path")


def _is_rarfile(path: PathOrStr, url_or_filename: PathOrStr) -> bool:
    """Return ``True`` if the file is a rar archive (``False`` when ``rarfile`` isn't installed)."""
    try:
        import rarfile

        return rarfile.is_rarfile(path)
    except ImportError:
        if str(url_or_filename).lower().endswith(".rar"):
            import warnings

            warnings.warn(
                f"The file has a 'rar' extension but rarfile is unavailable: {url_or_filename}"
                "\nSee https://rarfile.readthedocs.io/"
            )

        return False


def _is_archive(file_path: PathOrStr, url_or_filename: PathOrStr) -> bool:
    """Return ``True`` if the file is a zip, tar, or rar archive."""
    return (
        is_zipfile(file_path)
        or tarfile.is_tarfile(file_path)
        or _is_rarfile(file_path, url_or_filename)
    )


def cached_path(
    url_or_filename: PathOrStr,
    cache_dir: Optional[PathOrStr] = None,
    extract_archive: bool = False,
    force_extract: bool = False,
    quiet: bool = False,
    progress: Optional["Progress"] = None,
) -> Path:
    """
    Given something that might be a URL or local path, determine which.
    If it's a remote resource, download the file and cache it, and
    then return the path to the cached file. If it's already a local path,
    make sure the file exists and return the path.

    For URLs, the following schemes are all supported out-of-the-box:

    * ``http`` and ``https``,
    * ``s3`` for objects on `AWS S3`_,
    * ``gs`` for objects on `Google Cloud Storage (GCS)`_, and
    * ``hf`` for objects or repositories on `HuggingFace Hub`_.

    If you have `Beaker-py`_ installed you can also use URLs of the form:
    ``beaker://{user_name}/{dataset_name}/{file_path}``.

    You can also extend ``cached_path()`` to handle more schemes with :func:`add_scheme_client()`.

    .. _AWS S3: https://aws.amazon.com/s3/
    .. _Google Cloud Storage (GCS): https://cloud.google.com/storage
    .. _HuggingFace Hub: https://huggingface.co/
    .. _Beaker-py: https://github.com/allenai/beaker-py

    Examples
    --------

    To download a file over ``https``::

        cached_path("https://github.com/allenai/cached_path/blob/main/README.md")

    To download an object on GCS::

        cached_path("gs://allennlp-public-models/lerc-2020-11-18.tar.gz")

    To download the PyTorch weights for the model `epwalsh/bert-xsmall-dummy`_
    on HuggingFace, you could do::

        cached_path("hf://epwalsh/bert-xsmall-dummy/pytorch_model.bin")

    For paths or URLs that point to a tarfile or zipfile, you can append the path
    to a specific file within the archive to the ``url_or_filename``, preceded by a "!".
    The archive will be automatically extracted (provided you set ``extract_archive`` to ``True``),
    returning the local path to the specific file. For example::

        cached_path("model.tar.gz!weights.th", extract_archive=True)

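    You can also extract an archive without the ``!`` syntax, in which case the
    path to the extracted directory is returned. For example, using the same GCS
    archive from above::

        cached_path("gs://allennlp-public-models/lerc-2020-11-18.tar.gz", extract_archive=True)
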
    .. _epwalsh/bert-xsmall-dummy: https://huggingface.co/epwalsh/bert-xsmall-dummy
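
    To use a custom progress display, you can pass in your own
    ``rich.progress.Progress`` instance (a sketch; any ``Progress`` instance
    should work, otherwise a default one is created for you)::

        from rich.progress import Progress

        with Progress() as progress:
            cached_path(
                "https://github.com/allenai/cached_path/blob/main/README.md",
                progress=progress,
            )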

    Parameters
    ----------

    url_or_filename :
        A URL or path to parse and possibly download.

    cache_dir :
        The directory to cache downloads. If not specified, the global default cache directory
        will be used (``~/.cache/cached_path``). This can be set to something else with
        :func:`set_cache_dir()`.

    extract_archive :
        If ``True``, then zip or tar.gz archives will be automatically extracted,
        in which case the path to the extracted directory is returned.

    force_extract :
        If ``True`` and the file is an archive file, it will be extracted regardless
        of whether or not the extracted directory already exists.

        .. caution::
            Use this flag with caution! This can lead to race conditions if used
            from multiple processes on the same file.

    quiet :
        If ``True``, progress displays won't be printed.

    progress :
        A custom progress display to use. If not set and ``quiet=False``, a default display
        from :func:`~cached_path.get_download_progress()` will be used.

    Returns
    -------
    :class:`pathlib.Path`
        The local path to the (potentially cached) resource.

    Raises
    ------
    ``FileNotFoundError``

        If the resource cannot be found locally or remotely.

    ``ValueError``
        When the URL is invalid.

    ``Other errors``
        Other error types are possible as well depending on the client used to fetch
        the resource.

    """
    if not isinstance(url_or_filename, str):
        url_or_filename = str(url_or_filename)

    file_path: Path
    extraction_path: Optional[Path] = None
    etag: Optional[str] = None

    # If we're using the /a/b/foo.zip!c/d/file.txt syntax, handle it here.
    exclamation_index = url_or_filename.find("!")
    if extract_archive and exclamation_index >= 0:
        archive_path = url_or_filename[:exclamation_index]
        file_name = url_or_filename[exclamation_index + 1 :]

        # Call 'cached_path' recursively now to get the local path to the archive itself.
        cached_archive_path = cached_path(
            archive_path,
            cache_dir=cache_dir,
            extract_archive=True,
            force_extract=force_extract,
            quiet=quiet,
            progress=progress,
        )
        if not cached_archive_path.is_dir():
            raise ValueError(
                f"{url_or_filename} uses the ! syntax, but does not specify an archive file."
            )

        # Now return the full path to the desired file within the extracted archive,
        # provided it exists.
        file_path = cached_archive_path / file_name
        if not file_path.exists():
            raise FileNotFoundError(f"'{file_name}' not found within '{archive_path}'")

        return file_path

    parsed = urlparse(url_or_filename)

    if parsed.scheme in get_supported_schemes():
        # URL, so get it from the cache (downloading if necessary)
        file_path, etag = get_from_cache(url_or_filename, cache_dir, quiet=quiet, progress=progress)

        if extract_archive and _is_archive(file_path, url_or_filename):
            # This is the path the file should be extracted to.
            # For example ~/.cache/cached_path/234234.21341 -> ~/.cache/cached_path/234234.21341-extracted
            extraction_path = file_path.parent / (file_path.name + "-extracted")
    elif parsed.scheme == "file":
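        # Strip the "file://" prefix and treat the rest as a local path.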
        return cached_path(url_or_filename.replace("file://", "", 1))
    else:
        orig_url_or_filename = url_or_filename
        url_or_filename = Path(url_or_filename).expanduser()
        cache_dir = Path(cache_dir if cache_dir else get_cache_dir()).expanduser()
        cache_dir.mkdir(parents=True, exist_ok=True)

        if url_or_filename.exists():
            # File, and it exists.
            file_path = url_or_filename
            # Normalize the path.
            url_or_filename = url_or_filename.resolve()
            if extract_archive and file_path.is_file() and _is_archive(file_path, url_or_filename):
                # We'll use a unique directory within the cache root to extract the archive to.
                # The name of the directory is a hash of the resource file path and its modification
                # time. That way, if the file changes, we'll know to extract it again.
                extraction_name = (
                    resource_to_filename(url_or_filename, str(os.path.getmtime(file_path)))
                    + "-extracted"
                )
                extraction_path = cache_dir / extraction_name
        elif parsed.scheme == "":
            # File, but it doesn't exist.
            raise FileNotFoundError(f"file {url_or_filename} not found")
        else:
            # Something unknown
            raise ValueError(f"unable to parse {orig_url_or_filename} as a URL or as a local path")

    if extraction_path is not None:
        # If the extracted directory already exists (and is non-empty), then no
        # need to create a lock file and extract again unless `force_extract=True`.
        if os.path.isdir(extraction_path) and os.listdir(extraction_path) and not force_extract:
            return extraction_path

        # Extract it.
        with FileLock(_lock_file_path(extraction_path)):
            # Check again if the directory exists now that we've acquired the lock.
            if os.path.isdir(extraction_path) and os.listdir(extraction_path):
                if force_extract:
                    logger.warning(
                        "Extraction directory for %s (%s) already exists, "
                        "overwriting it since 'force_extract' is 'True'",
                        url_or_filename,
                        extraction_path,
                    )
                else:
                    return extraction_path

            logger.info("Extracting %s to %s", url_or_filename, extraction_path)
            shutil.rmtree(extraction_path, ignore_errors=True)

            # We extract first to a temporary directory in case something goes wrong
            # during the extraction process so we don't end up with a corrupted cache.
            tmp_extraction_dir = tempfile.mkdtemp(dir=os.path.split(extraction_path)[0])
            try:
                if tarfile.is_tarfile(file_path):
                    with tarfile.open(file_path) as tar_file:
                        check_tarfile(tar_file)
                        tar_file.extractall(tmp_extraction_dir)
                elif _is_rarfile(file_path, url_or_filename):
                    import rarfile

                    with rarfile.RarFile(file_path) as rar_file:
                        rar_file.extractall(tmp_extraction_dir)
                else:
                    with ZipFile(file_path) as zip_file:
                        zip_file.extractall(tmp_extraction_dir)
                # Extraction was successful, rename temp directory to final
                # cache directory and dump the meta data.
                os.replace(tmp_extraction_dir, extraction_path)
                meta = Meta.new(
                    url_or_filename,
                    extraction_path,
                    etag=etag,
                    extraction_dir=True,
                )
                meta.to_file()
            finally:
                shutil.rmtree(tmp_extraction_dir, ignore_errors=True)

        return extraction_path

    return file_path


def get_from_cache(
    url: str,
    cache_dir: Optional[PathOrStr] = None,
    quiet: bool = False,
    progress: Optional["Progress"] = None,
    no_downloads: bool = False,
    _client: Optional[SchemeClient] = None,
) -> Tuple[Path, Optional[str]]:
    """
    Given a URL, look for the corresponding resource in the local cache.
    If it's not there, download it. Then return the path to the cached file and its ETag.
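
    For example (the ETag may be ``None`` if the resource doesn't provide one)::

        path, etag = get_from_cache("https://github.com/allenai/cached_path/blob/main/README.md")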
    """
    if url.startswith("hf://"):
        return hf_get_from_cache(url, cache_dir), None

    cache_dir = Path(cache_dir if cache_dir else get_cache_dir()).expanduser()
    cache_dir.mkdir(parents=True, exist_ok=True)
    client = _client or get_scheme_client(url)

    # Get eTag to add to filename, if it exists.
    try:
        etag = client.get_etag()
    except client.recoverable_errors:  # type: ignore
        # We might be offline, in which case we don't want to throw an error
        # just yet. Instead, we'll try to use the latest cached version of the
        # target resource, if it exists. We'll only throw an exception if we
        # haven't cached the resource at all yet.
        logger.warning(
            "Connection error occurred while trying to fetch ETag for %s. "
            "Will attempt to use latest cached version of resource",
            url,
        )
        latest_cached = find_latest_cached(url, cache_dir)
        if latest_cached:
            logger.info(
                "ETag request failed with recoverable error, using latest cached "
                "version of %s: %s",
                url,
                latest_cached,
            )
            meta = Meta.from_path(_meta_file_path(latest_cached))
            return latest_cached, meta.etag
        else:
            logger.error(
                "ETag request failed with recoverable error, "
                "but no cached version of %s could be found",
                url,
            )
            raise

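    # The cache filename is a hash of the URL and the ETag (when available).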
    filename = resource_to_filename(url, etag)

    # Get cache path to put the file.
    cache_path = cache_dir / filename

    # Multiple processes may be trying to cache the same file at once, so we need
    # to be a little careful to avoid race conditions. We do this using a lock file.
    # Only one process can own this lock file at a time, and a process will block
    # on the call to `lock.acquire()` until the process currently holding the lock
    # releases it.
    logger.debug("waiting to acquire lock on %s", cache_path)
    with FileLock(_lock_file_path(cache_path), read_only_ok=True):
        if os.path.exists(cache_path):
            logger.info("cache of %s is up-to-date", url)
        elif no_downloads:
            raise FileNotFoundError(cache_path)
        else:
            size = client.get_size()
            with CacheFile(cache_path) as cache_file:
                logger.info("%s not found in cache, downloading to %s", url, cache_path)

                from .progress import BufferedWriterWithProgress, get_download_progress

                start_and_cleanup = progress is None
                progress = progress or get_download_progress(quiet=quiet)

                if start_and_cleanup:
                    progress.start()

                try:
                    display_url = url if len(url) <= 30 else f"\N{horizontal ellipsis}{url[-30:]}"
                    task_id = progress.add_task(f"Downloading [cyan i]{display_url}[/]", total=size)
                    writer_with_progress = BufferedWriterWithProgress(cache_file, progress, task_id)
                    client.get_resource(writer_with_progress)
                    progress.update(
                        task_id,
                        total=writer_with_progress.total_written,
                        completed=writer_with_progress.total_written,
                    )
                finally:
                    if start_and_cleanup:
                        progress.stop()

            logger.debug("creating metadata file for %s", cache_path)
            meta = Meta.new(
                url,
                cache_path,
                etag=etag,
            )
            meta.to_file()

    return cache_path, etag