# Viewing File: /home/ubuntu/.local/lib/python3.10/site-packages/fastcore/xtras.py

# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_xtras.ipynb.

# %% ../nbs/03_xtras.ipynb 1
from __future__ import annotations

# %% auto 0
# Names made available by `from <this module> import *`
__all__ = ['spark_chars', 'walk', 'globtastic', 'maybe_open', 'mkdir', 'image_size', 'bunzip', 'loads', 'loads_multi', 'dumps',
           'untar_dir', 'repo_details', 'run', 'open_file', 'save_pickle', 'load_pickle', 'parse_env',
           'expand_wildcards', 'dict2obj', 'obj2dict', 'repr_dict', 'is_listy', 'mapped', 'IterLen',
           'ReindexCollection', 'get_source_link', 'truncstr', 'sparkline', 'modify_exception', 'round_multiple',
           'set_num_threads', 'join_path_file', 'autostart', 'EventTimer', 'stringfmt_names', 'PartialFormatter',
           'partial_format', 'utc2local', 'local2utc', 'trace', 'modified_env', 'ContextManagers', 'shufflish',
           'console_help', 'hl_md', 'type2str', 'dataclass_src', 'nullable_dc', 'make_nullable', 'mk_dataclass']

# %% ../nbs/03_xtras.ipynb 2
from .imports import *
from .foundation import *
from .basics import *
from importlib import import_module
from functools import wraps
import string,time
from contextlib import contextmanager,ExitStack
from datetime import datetime, timezone

# %% ../nbs/03_xtras.ipynb 7
def walk(
    path:Path|str, # path to start searching
    symlinks:bool=True, # follow symlinks?
    keep_file:callable=ret_true, # function that returns True for wanted files
    keep_folder:callable=ret_true, # function that returns True for folders to enter
    skip_folder:callable=ret_false, # function that returns True for folders to skip
    func:callable=os.path.join, # function to apply to each matched file
    ret_folders:bool=False  # return folders, not just files
):
    "Generator over `os.walk(path)` yielding `func(root, name)` for every file (and optionally folder) accepted by the filters"
    for root,dirs,files in os.walk(path, followlinks=symlinks):
        if keep_folder(root,''):
            if ret_folders: yield func(root, '')
            for name in files:
                if keep_file(root,name): yield func(root, name)
        # Prune in place: `os.walk` only descends into the names left in `dirs`
        dirs[:] = [name for name in dirs if not skip_folder(root,name)]

# %% ../nbs/03_xtras.ipynb 8
def globtastic(
    path:Path|str, # path to start searching
    recursive:bool=True, # search subfolders
    symlinks:bool=True, # follow symlinks?
    file_glob:str=None, # Only include files matching glob
    file_re:str=None, # Only include files matching regex
    folder_re:str=None, # Only enter folders matching regex
    skip_file_glob:str=None, # Skip files matching glob
    skip_file_re:str=None, # Skip files matching regex
    skip_folder_re:str=None, # Skip folders matching regex,
    func:callable=os.path.join, # function to apply to each matched file
    ret_folders:bool=False  # return folders, not just files
)->L: # Paths to matched files
    "A more powerful `glob`, including regex matches, symlink handling, and skip parameters"
    from fnmatch import fnmatch
    path = Path(path)
    # A single file is returned as-is, without applying any of the filters
    if path.is_file(): return L([path])
    # A non-recursive search is expressed as "skip every sub-folder"
    if not recursive: skip_folder_re='.'
    file_re,folder_re = compile_re(file_re),compile_re(folder_re)
    skip_file_re,skip_folder_re = compile_re(skip_file_re),compile_re(skip_folder_re)

    def _keep_file(root, name):
        "A file is kept only if it passes every include filter and no skip filter"
        if file_glob and not fnmatch(name, file_glob): return False
        if file_re and not file_re.search(name): return False
        if skip_file_glob and fnmatch(name, skip_file_glob): return False
        if skip_file_re and skip_file_re.search(name): return False
        return True

    def _keep_folder(root, name):
        "Folders are matched against their joined path"
        if not folder_re: return True
        return folder_re.search(os.path.join(root,name))

    def _skip_folder(root, name):
        "Folders to skip are matched against their bare name"
        if not skip_folder_re: return False
        return skip_folder_re.search(name)

    matches = walk(path, symlinks=symlinks, keep_file=_keep_file, keep_folder=_keep_folder,
                   skip_folder=_skip_folder, func=func, ret_folders=ret_folders)
    return L(matches)

# %% ../nbs/03_xtras.ipynb 10
@contextmanager
def maybe_open(f, mode='r', **kwargs):
    "Context manager: yield `f` unchanged if it is already a file object, otherwise open the path `f` (and close it on exit)"
    if not isinstance(f, (str,os.PathLike)):
        # Not a path: the caller keeps ownership, so nothing is closed here
        yield f
        return
    with open(f, mode, **kwargs) as handle: yield handle

# %% ../nbs/03_xtras.ipynb 26
def mkdir(path, exist_ok=False, parents=False, overwrite=False, **kwargs):
    "Create the directory `path` and return it as a `Path`; with `overwrite`, an existing directory tree is removed first"
    target = Path(path)
    if target.exists() and overwrite:
        import shutil
        shutil.rmtree(target)
    # Remaining keyword arguments go straight to `Path.mkdir`
    target.mkdir(exist_ok=exist_ok, parents=parents, **kwargs)
    return target

# %% ../nbs/03_xtras.ipynb 28
def image_size(fn):
    "Tuple of (w,h) for png, gif, or jpg; `None` otherwise"
    from fastcore import imghdr
    import struct

    def _jpg_size(f):
        "Scan the JPEG segments from the current position until a marker in 0xC0-0xCF, then read height and width"
        size,ftype = 2,0
        while not 0xc0 <= ftype <= 0xcf:
            f.seek(size, 1)  # skip the rest of the previous segment
            byte = f.read(1)
            while ord(byte) == 0xff: byte = f.read(1)  # skip 0xFF fill bytes before the marker code
            ftype = ord(byte)
            size = struct.unpack('>H', f.read(2))[0] - 2
        f.seek(1, 1)  # `precision'
        h,w = struct.unpack('>HH', f.read(4))
        return w,h

    def _gif_size(f):
        "Width and height are read as two little-endian shorts at bytes 6-10"
        head = f.read(10)
        return struct.unpack('<HH', head[6:10])

    def _png_size(f):
        "Width and height are read as two big-endian ints at bytes 16-24"
        head = f.read(24)
        assert struct.unpack('>i', head[4:8])[0]==0x0d0a1a0a
        return struct.unpack('>ii', head[16:24])

    d = dict(png=_png_size, gif=_gif_size, jpeg=_jpg_size)
    with maybe_open(fn, 'rb') as f:
        # Remember where the image starts so each reader begins there, whatever `imghdr.what` does to the position
        start = f.tell()
        reader = d.get(imghdr.what(f))
        # Unrecognised or unsupported formats give `None`, as the docstring promises
        if reader is None: return None
        f.seek(start)
        return reader(f)

# %% ../nbs/03_xtras.ipynb 30
def bunzip(fn):
    "Decompress the bz2 file `fn` into a sibling file without the last suffix, raising an exception if that output already exists"
    import bz2
    src_path = Path(fn)
    assert src_path.exists(), f"{src_path} doesn't exist"
    dst_path = src_path.with_suffix('')
    assert not dst_path.exists(), f"{dst_path} already exists"
    chunk_size = 1024*1024  # decompress 1 MiB at a time rather than loading the whole file
    with bz2.BZ2File(src_path, 'rb') as src, dst_path.open('wb') as dst:
        while True:
            chunk = src.read(chunk_size)
            if chunk == b'': break
            dst.write(chunk)

# %% ../nbs/03_xtras.ipynb 32
def loads(s, **kw):
    "Same as `json.loads` (using `ujson` if available), but returns `{}` for `None` or any other falsy `s`"
    if not s: return {}
    try:
        import ujson
        decode = ujson.loads
    except ModuleNotFoundError:
        import json
        decode = json.loads
    return decode(s, **kw)

# %% ../nbs/03_xtras.ipynb 33
def loads_multi(s:str):
    "Generator of >=0 decoded json dicts, possibly with non-json ignored text at start and end"
    import json
    decoder = json.JSONDecoder()
    while True:
        # Jump to the next candidate object; stop once no `{` is left
        start = s.find('{')
        if start < 0: return
        s = s[start:]
        obj,end = decoder.raw_decode(s)
        if not end: raise ValueError(f'no JSON object found at {end}')
        yield obj
        # Continue scanning right after the object that was just decoded
        s = s[end:]

# %% ../nbs/03_xtras.ipynb 35
def dumps(obj, **kw):
    "Same as `json.dumps`, but uses `ujson` if available"
    try: import ujson
    except ModuleNotFoundError:
        import json
        return json.dumps(obj, **kw)
    # Only `ujson` gets this flag: it stops forward slashes being escaped
    kw['escape_forward_slashes']=False
    return ujson.dumps(obj, **kw)

# %% ../nbs/03_xtras.ipynb 36
def _unpack(fname, out):
    "Unpack archive `fname` into folder `out`; return the single item it contains, or `out` itself if there are zero or several"
    import shutil
    shutil.unpack_archive(str(fname), str(out))
    items = out.ls()
    if len(items) == 1: return items[0]
    return out

# %% ../nbs/03_xtras.ipynb 37
def untar_dir(fname, dest, rename=False, overwrite=False):
    "Extract archive `fname` into `dest`, creating a directory if the root contains more than one item; returns the resulting path"
    import tempfile,shutil
    # Extraction happens in a scratch directory; only the final result is moved under `dest`
    with tempfile.TemporaryDirectory() as d:
        # Scratch folder named after the archive stem, passed through `remove_suffix(..., '.tar')`
        out = Path(d)/remove_suffix(Path(fname).stem, '.tar')
        out.mkdir()
        # NOTE(review): `dest` must support the `/` operator — presumably a `Path`; confirm against callers
        if rename: dest = dest/out.name  # with `rename`, the target is named after the archive
        else:
            # Without `rename` the target name comes from the archive content, so unpack first
            src = _unpack(fname, out)
            dest = dest/src.name
        if dest.exists():
            if overwrite: shutil.rmtree(dest) if dest.is_dir() else dest.unlink()
            else: return dest  # keep the existing target untouched
        # With `rename`, unpacking was deferred until the target was known to be free
        if rename: src = _unpack(fname, out)
        shutil.move(str(src), dest)
        return dest

# %% ../nbs/03_xtras.ipynb 45
def repo_details(url):
    "List of `[owner, name]` from ssh or https git repo `url`"
    cleaned = remove_suffix(url.strip(), '.git')
    # Keep only what follows the last ':' (drops `https:` as well as `git@host:`)
    _,_,tail = cleaned.rpartition(':')
    return tail.split('/')[-2:]

# %% ../nbs/03_xtras.ipynb 47
def run(cmd, *rest, same_in_win=False, ignore_ex=False, as_bytes=False, stderr=False):
    "Pass `cmd` (splitting with `shlex` if string) to `subprocess.run`; return `stdout`; raise `IOError` if fails"
    import subprocess
    # Even when the command is the same on Windows, it has to be prefixed with `cmd /c`
    via_cmd = sys.platform == 'win32' and same_in_win
    if rest:
        cmd = ('cmd', '/c', cmd, *rest) if via_cmd else (cmd, *rest)
    elif isinstance(cmd, str):
        import shlex
        if via_cmd: cmd = 'cmd /c ' + cmd
        cmd = shlex.split(cmd)
    elif isinstance(cmd, list) and via_cmd:
        cmd = ['cmd', '/c'] + cmd
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out = proc.stdout
    # With `stderr`, anything written to stderr is appended after a ` ;; ` separator
    if stderr and proc.stderr: out += b' ;; ' + proc.stderr
    if not as_bytes: out = out.decode().strip()
    # With `ignore_ex`, a failing command is reported through the return code instead of an exception
    if ignore_ex: return (proc.returncode, out)
    if proc.returncode: raise IOError(out)
    return out

# %% ../nbs/03_xtras.ipynb 55
def open_file(fn, mode='r', **kwargs):
    "Open `fn` (returned unchanged if already a file object), picking the opener from a `.bz2`, `.gz` or `.zip` suffix"
    if isinstance(fn, io.IOBase): return fn
    import bz2,gzip,zipfile
    fn = Path(fn)
    openers = {'.bz2': bz2.BZ2File, '.gz': gzip.GzipFile, '.zip': zipfile.ZipFile}
    # Any other suffix falls back to the builtin `open`
    opener = openers.get(fn.suffix, open)
    return opener(fn, mode, **kwargs)

# %% ../nbs/03_xtras.ipynb 56
def save_pickle(fn, o):
    "Pickle `o` to `fn`, which is a file name or an already opened file (opened through `open_file`)"
    import pickle
    with open_file(fn, 'wb') as out:
        pickle.dump(o, out)

# %% ../nbs/03_xtras.ipynb 57
def load_pickle(fn):
    "Unpickle and return the object stored in `fn`, which is a file name or an already opened file"
    import pickle
    # NOTE: unpickling can execute arbitrary code, so only use this on trusted files
    with open_file(fn, 'rb') as src:
        return pickle.load(src)

# %% ../nbs/03_xtras.ipynb 59
def parse_env(s:str=None, fn:Union[str,Path]=None) -> dict:
    "Parse a shell-style environment string `s` or file `fn` into a dict of variable names to values"
    assert bool(s)^bool(fn), "Must pass exactly one of `s` or `fn`"
    if fn: s = Path(fn).read_text()
    # optional `export`, NAME, `=`, optionally quoted value, optional trailing comment
    assign_pat = r'^\s*(?:export\s+)?(\w+)\s*=\s*(["\']?)(.*?)(\2)\s*(?:#.*)?$'
    env = {}
    for raw in s.splitlines():
        line = raw.strip()
        # Blank lines and full-line comments carry no assignment
        if not line or re.match(r'\s*#', raw): continue
        name,_,value,_ = re.match(assign_pat, line).groups()
        env[name] = value
    return env

# %% ../nbs/03_xtras.ipynb 61
def expand_wildcards(code):
    "Expand all wildcard imports in the given code string."
    import ast,importlib
    tree = ast.parse(code)

    def _splice(src, old, new):
        "Return `src` with the lines of node `old` replaced by the unparsed node `new`, keeping the indentation"
        rows = src.splitlines()
        first = rows[old.lineno-1]
        pad = ' ' * (len(first) - len(first.lstrip()))
        rows[old.lineno-1 : old.end_lineno] = [pad+row for row in ast.unparse(new).splitlines()]
        return '\n'.join(rows)

    def _explicit_import(node, mod, existing):
        "Build an import of the names of `mod` that `tree` uses; return `node` unchanged if there are none"
        exported = getattr(mod, '__all__', None)
        candidates = set(dir(mod)) if exported is None else set(exported)
        used = set()
        for n in ast.walk(tree):
            if isinstance(n, ast.Name) and n.id in candidates: used.add(n.id)
        # Names already brought in by an explicit import are not imported again
        used -= existing
        if not used: return node
        aliases = [ast.alias(name=nm, asname=None) for nm in sorted(used)]
        return ast.ImportFrom(module=node.module, names=aliases, level=node.level)

    # First pass: collect every name provided by the non-wildcard imports
    existing = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom):
            if node.names[0].name != '*': existing.update(n.name for n in node.names)
        elif isinstance(node, ast.Import):
            existing.update(n.name.split('.')[0] for n in node.names)

    # Second pass: rewrite each wildcard import in the source text
    for node in ast.walk(tree):
        if not isinstance(node, ast.ImportFrom): continue
        if not any(n.name == '*' for n in node.names): continue
        mod = importlib.import_module(node.module)
        code = _splice(code, node, _explicit_import(node, mod, existing))
    return code

# %% ../nbs/03_xtras.ipynb 64
def dict2obj(d, list_func=L, dict_func=AttrDict):
    "Convert (possibly nested) dicts (or lists of dicts) to `AttrDict` (or to `dict_func`/`list_func` if given)"
    # Recurse with the same converters, so nested values honour custom `list_func`/`dict_func` too
    def _conv(o): return dict2obj(o, list_func=list_func, dict_func=dict_func)
    if isinstance(d, (L,list)): return list_func(d).map(_conv)
    # Anything that is neither a list nor a dict is a leaf and is returned as-is
    if not isinstance(d, dict): return d
    return dict_func(**{k:_conv(v) for k,v in d.items()})

# %% ../nbs/03_xtras.ipynb 69
def obj2dict(d):
    "Convert (possibly nested) AttrDicts (or lists of AttrDicts) to `dict`"
    if isinstance(d, (L,list)): return [obj2dict(o) for o in d]
    # Anything that is neither a list nor a dict is a leaf and is returned as-is
    if not isinstance(d, dict): return d
    # A dict comprehension (rather than `dict(**...)`) also works for keys that are not strings
    return {k:obj2dict(v) for k,v in d.items()}

# %% ../nbs/03_xtras.ipynb 72
def _repr_dict(d, lvl):
    "Render `d` as an indented bullet list (two spaces per `lvl`); leaves are rendered with `str`"
    if isinstance(d,dict): items = [f"{k}: {_repr_dict(v,lvl+1)}" for k,v in d.items()]
    elif isinstance(d,(list,L)): items = [_repr_dict(o,lvl+1) for o in d]
    else: return str(d)
    bullet = " "*(lvl*2) + "- "
    # The leading newline puts a nested block below its parent key
    return '\n' + '\n'.join(bullet + o for o in items)

# %% ../nbs/03_xtras.ipynb 73
def repr_dict(d):
    "String rendering of nested dicts and lists, such as returned by `dict2obj`, as an indented bullet list"
    rendered = _repr_dict(d,0)
    return rendered.strip()

# %% ../nbs/03_xtras.ipynb 75
def is_listy(x):
    "`isinstance(x, (tuple,list,L,slice,Generator))`"
    listy_types = (tuple,list,L,slice,Generator)
    return isinstance(x, listy_types)

# %% ../nbs/03_xtras.ipynb 77
def mapped(f, it):
    "Map `f` over `it` if it is listy (see `is_listy`), otherwise return `f(it)`"
    if not is_listy(it): return f(it)
    return L(it).map(f)

# %% ../nbs/03_xtras.ipynb 81
@patch
def readlines(self:Path, hint=-1, encoding='utf8'):
    "Read the content of `self` as a list of lines (`hint` is passed on to the file's `readlines`)"
    with self.open(encoding=encoding) as fh:
        lines = fh.readlines(hint)
    return lines

# %% ../nbs/03_xtras.ipynb 82
@patch
def read_json(self:Path, encoding=None, errors=None):
    "Same as `read_text` followed by `loads`"
    text = self.read_text(encoding=encoding, errors=errors)
    return loads(text)

# %% ../nbs/03_xtras.ipynb 83
@patch
def mk_write(self:Path, data, encoding=None, errors=None, mode=511):
    "Make all parent dirs of `self` (created with `mode`, by default 511 == 0o777), and write the text `data`"
    folder = self.parent
    folder.mkdir(exist_ok=True, parents=True, mode=mode)
    self.write_text(data, encoding=encoding, errors=errors)

# %% ../nbs/03_xtras.ipynb 84
@patch
def relpath(self:Path, start=None):
    "Same as `os.path.relpath`, but returns a `Path`, and resolves symlinks"
    # `Path(None)` raises `TypeError`, so the default is mapped to the current directory, as `os.path.relpath` does
    start = Path.cwd() if start is None else Path(start)
    return Path(os.path.relpath(self.resolve(), start.resolve()))

# %% ../nbs/03_xtras.ipynb 87
@patch
def ls(self:Path, n_max=None, file_type=None, file_exts=None):
    "Contents of path as a list, optionally limited to `n_max` items and to extensions from `file_exts` and/or mime type `file_type`"
    import mimetypes
    extns=L(file_exts)
    # `file_type` adds every extension whose mime type starts with `file_type/`
    if file_type: extns += L(k for k,v in mimetypes.types_map.items() if v.startswith(file_type+'/'))
    keep_all = len(extns)==0  # no extension filter was requested
    entries = (o for o in self.iterdir() if keep_all or o.suffix in extns)
    if n_max is not None: entries = itertools.islice(entries, n_max)
    return L(entries)

# %% ../nbs/03_xtras.ipynb 93
@patch
def __repr__(self:Path):
    "Show the path in posix form, relative to `Path.BASE_PATH` when that is set and the path is below it"
    shown = self
    base = getattr(Path, 'BASE_PATH', None)
    if base:
        # Best effort: if the path cannot be made relative to the base, show it unchanged
        try: shown = self.relative_to(base)
        except: pass
    return f"Path({shown.as_posix()!r})"

# %% ../nbs/03_xtras.ipynb 96
@patch
def delete(self:Path):
    "Delete a file, symlink, or directory tree; does nothing if `self` doesn't exist"
    # Remove a symlink itself, never its target. This covers dangling links (for which
    # `exists()` is False) and links to directories (which `shutil.rmtree` refuses).
    if self.is_symlink():
        self.unlink()
        return
    if not self.exists(): return
    if self.is_dir():
        import shutil
        shutil.rmtree(self)
    else: self.unlink()

# %% ../nbs/03_xtras.ipynb 98
class IterLen:
    "Base class to add iteration to anything supporting `__len__` and `__getitem__`"
    def __iter__(self):
        "Yield `self[i]` for each index given by `range_of(self)`"
        idxs = range_of(self)
        return (self[i] for i in idxs)

# %% ../nbs/03_xtras.ipynb 99
@docs
class ReindexCollection(GetAttr, IterLen):
    "Reindexes collection `coll` with indices `idxs` and optional LRU cache of size `cache`"
    _default='coll'
    def __init__(self, coll, idxs=None, cache=None, tfm=noop):
        if idxs is None: idxs = L.range(coll)
        store_attr()
        # Shadow `_get` on the instance with a cached version, so each instance has its own cache
        if cache is not None: self._get = functools.lru_cache(maxsize=cache)(self._get)

    # Item of the underlying collection at `i`, passed through `tfm`
    def _get(self, i): return self.tfm(self.coll[i])
    # Position `i` is first translated through `idxs`
    def __getitem__(self, i): return self._get(self.idxs[i])
    # NOTE(review): the length is that of `coll`, not of `idxs` — confirm this is intended when `idxs` is a subset
    def __len__(self): return len(self.coll)
    def reindex(self, idxs): self.idxs = idxs
    def shuffle(self):
        import random
        random.shuffle(self.idxs)
    def cache_clear(self): self._get.cache_clear()
    # The cached `_get` wrapper is left out of the pickled state; only the constructor arguments are stored
    def __getstate__(self): return {'coll': self.coll, 'idxs': self.idxs, 'cache': self.cache, 'tfm': self.tfm}
    def __setstate__(self, s):
        self.coll,self.idxs,self.cache,self.tfm = s['coll'],s['idxs'],s['cache'],s['tfm']
        # Re-install the LRU cache as `__init__` does; otherwise a restored object loses caching and `cache_clear` fails
        if self.cache is not None: self._get = functools.lru_cache(maxsize=self.cache)(self._get)

    _docs = dict(reindex="Replace `self.idxs` with idxs",
                shuffle="Randomly shuffle indices",
                cache_clear="Clear LRU cache")

# %% ../nbs/03_xtras.ipynb 118
def _is_type_dispatch(x):
    "Whether the class of `x` is named `TypeDispatch` (matched by class name)"
    cls_name = type(x).__name__
    return cls_name == "TypeDispatch"
def _unwrapped_type_dispatch_func(x):
    "Return `x.first()` if `x` is a `TypeDispatch`, otherwise `x` itself"
    if _is_type_dispatch(x): return x.first()
    return x

def _is_property(x):
    "Whether `x` is exactly a `property` (instances of subclasses do not count)"
    kind = type(x)
    return kind==property
def _has_property_getter(x):
    "Whether `x` is a property whose `fget` has a `func` attribute"
    if not _is_property(x): return False
    return hasattr(x, 'fget') and hasattr(x.fget, 'func')
def _property_getter(x):
    "Return `x.fget.func` if `x` is a property with such a getter, otherwise `x` itself"
    if _has_property_getter(x): return x.fget.func
    return x

def _unwrapped_func(x):
    "Unwrap `x` from a `TypeDispatch` first, then from a property getter"
    return _property_getter(_unwrapped_type_dispatch_func(x))


def get_source_link(func):
    "Return link to `func` in source code (an empty string if its source lines cannot be found)"
    import inspect
    func = _unwrapped_func(func)
    try: _,line = inspect.getsourcelines(func)
    except Exception: return ''
    mod = inspect.getmodule(func)
    module = mod.__name__.replace('.', '/') + '.py'
    rel_link = f"{module}#L{line}"
    try:
        # The `_nbdev` module of the top-level package supplies the `git_url` prefix
        top_pkg = mod.__package__.split('.')[0]
        nbdev_mod = import_module(top_pkg + '._nbdev')
        return f"{nbdev_mod.git_url}{rel_link}"
    # Without a usable `_nbdev` module, fall back to the bare relative link
    except: return rel_link

# %% ../nbs/03_xtras.ipynb 122
def truncstr(s:str, maxlen:int, suf:str='…', space='')->str:
    "Truncate `s` to length `maxlen`, adding suffix `suf` if truncated; otherwise return `s` followed by `space`"
    if len(s)+len(space) <= maxlen: return s+space
    # Leave room for the suffix so the result is `maxlen` characters long
    keep = maxlen-len(suf)
    return s[:keep]+suf

# %% ../nbs/03_xtras.ipynb 124
# Bar glyphs used by `sparkline`; `_sparkchar` indexes into them, the first being the lowest bar
spark_chars = '▁▂▃▅▆▇'

# %% ../nbs/03_xtras.ipynb 125
def _ceil(x, lim=None):
    "Cap `x` at `lim`; a falsy `lim` (`None` or 0) means no cap"
    if lim and not x <= lim: return lim
    return x

def _sparkchar(x, mn, mx, incr, empty_zero):
    "Glyph for value `x` on a scale starting at `mn` with steps of `incr`; a blank for `None` (and for zero if `empty_zero`)"
    if x is None or (empty_zero and not x): return ' '
    # A zero step means all values are equal, so use the lowest bar
    if incr == 0: return spark_chars[0]
    capped = _ceil(x,mx)
    idx = int((capped-mn)/incr-0.5)
    return spark_chars[idx]

# %% ../nbs/03_xtras.ipynb 126
def sparkline(data, mn=None, mx=None, empty_zero=False):
    "Sparkline for `data`, with `None`s (and zero, if `empty_zero`) shown as empty column"
    valid = [o for o in data if o is not None]
    if not valid: return ' '
    # The scale defaults to the range of the values that are present
    mn = ifnone(mn,min(valid))
    mx = ifnone(mx,max(valid))
    incr = (mx-mn)/len(spark_chars)
    return ''.join(_sparkchar(x=o, mn=mn, mx=mx, incr=incr, empty_zero=empty_zero) for o in data)

# %% ../nbs/03_xtras.ipynb 130
def modify_exception(
    e:Exception, # An exception
    msg:str=None, # A custom message
    replace:bool=False, # Whether to replace e.args with [msg]
) -> Exception:
    "Modify `e` in place by appending `msg` to its first argument (or by replacing all arguments), and return it"
    if replace or len(e.args) == 0: e.args = [msg]
    else: e.args = [f'{e.args[0]} {msg}']
    return e

# %% ../nbs/03_xtras.ipynb 132
def round_multiple(x, mult, round_down=False):
    "Round `x` (or each item of a listy `x`) to nearest multiple of `mult`; `round_down` applies `int` instead of `round`"
    rounder = int if round_down else round
    def _f(x_): return rounder(x_/mult)*mult
    res = L(x).map(_f)
    if is_listy(x): return res
    return res[0]

# %% ../nbs/03_xtras.ipynb 134
def set_num_threads(nt):
    "Get numpy (and others) to use `nt` threads"
    # `mkl` and `torch` are optional: any failure to import or configure them is ignored
    try:
        import mkl
        mkl.set_num_threads(nt)
    except: pass
    try:
        import torch
        torch.set_num_threads(nt)
    except: pass
    os.environ['IPC_ENABLE']='1'
    thread_vars = ('OPENBLAS_NUM_THREADS','NUMEXPR_NUM_THREADS','OMP_NUM_THREADS','MKL_NUM_THREADS')
    os.environ.update({o: str(nt) for o in thread_vars})

# %% ../nbs/03_xtras.ipynb 136
def join_path_file(file, path, ext=''):
    "Return `path/file` (creating the folder `path`) if file is a string or a `Path`, file otherwise"
    if isinstance(file, (str, Path)):
        path.mkdir(parents=True, exist_ok=True)
        return path/f'{file}{ext}'
    return file

# %% ../nbs/03_xtras.ipynb 138
def autostart(g):
    "Decorator that automatically starts a generator, by advancing it to its first `yield`"
    @functools.wraps(g)
    def f():
        started = g()
        next(started)  # run up to the first `yield` so the generator is ready for `send`
        return started
    return f

# %% ../nbs/03_xtras.ipynb 139
class EventTimer:
    "An event timer with history of `store` items of time `span`"

    def __init__(self, store=5, span=60):
        from collections import deque
        self.hist = deque(maxlen=store)
        self.span = span
        self.last = time.perf_counter()
        self._reset()

    def _reset(self):
        "Start a new measuring window at the time of the last recorded event"
        self.start = self.last
        self.events = 0

    def add(self, n=1):
        "Record `n` events"
        window_over = self.duration>self.span
        if window_over:
            # Archive the frequency of the finished window before starting a new one
            self.hist.append(self.freq)
            self._reset()
        self.events +=n
        self.last = time.perf_counter()

    @property
    def duration(self):
        "Time elapsed since the current window started"
        return time.perf_counter()-self.start
    @property
    def freq(self):
        "Events per unit of time in the current window"
        return self.events/self.duration

# %% ../nbs/03_xtras.ipynb 143
# Shared formatter instance, used by `stringfmt_names` to parse format strings
_fmt = string.Formatter()

# %% ../nbs/03_xtras.ipynb 144
def stringfmt_names(s:str)->list:
    "Unique brace-delimited names in `s`"
    # `parse` yields (literal text, field name, format spec, conversion) tuples
    names = (field for _,field,_,_ in _fmt.parse(s) if field)
    return uniqueify(names)

# %% ../nbs/03_xtras.ipynb 146
class PartialFormatter(string.Formatter):
    "A `string.Formatter` that doesn't error on missing fields, and tracks missing fields and unused args"
    def __init__(self):
        super().__init__()
        self.missing = set()

    def get_field(self, nm, args, kwargs):
        "Look up field `nm`; a missing key is recorded in `self.missing` and rendered back as `{nm}`"
        try: found = super().get_field(nm, args, kwargs)
        except KeyError:
            self.missing.add(nm)
            found = ('{'+nm+'}', nm)
        return found

    def check_unused_args(self, used, args, kwargs):
        "Store the keyword arguments that the format string did not use in `self.xtra`"
        self.xtra = filter_keys(kwargs, lambda o: o not in used)

# %% ../nbs/03_xtras.ipynb 148
def partial_format(s:str, **kwargs):
    "string format `s`, ignoring missing field errors, returning the result, the missing fields and the extra fields"
    formatter = PartialFormatter()
    formatted = formatter.format(s, **kwargs)
    missing = list(formatter.missing)
    return formatted,missing,formatter.xtra

# %% ../nbs/03_xtras.ipynb 151
def utc2local(dt:datetime)->datetime:
    "Convert `dt` from UTC to local time (any existing tzinfo on `dt` is overwritten with UTC first)"
    as_utc = dt.replace(tzinfo=timezone.utc)
    return as_utc.astimezone(tz=None)

# %% ../nbs/03_xtras.ipynb 153
def local2utc(dt:datetime)->datetime:
    "Convert `dt` from local to UTC time (any existing tzinfo on `dt` is dropped first)"
    naive = dt.replace(tzinfo=None)
    return naive.astimezone(tz=timezone.utc)

# %% ../nbs/03_xtras.ipynb 155
def trace(f):
    "Add `set_trace` to an existing function `f`, so a debugger starts each time it is called"
    from pdb import set_trace
    # A function that is already traced is returned as-is rather than wrapped twice
    already_traced = getattr(f, '_traced', False)
    if already_traced: return f
    def _inner(*args,**kwargs):
        set_trace()
        return f(*args,**kwargs)
    setattr(_inner, '_traced', True)
    return _inner

# %% ../nbs/03_xtras.ipynb 157
@contextmanager
def modified_env(*delete, **replace):
    "Context manager temporarily modifying `os.environ` by deleting `delete` and replacing `replace`"
    snapshot = dict(os.environ)
    try:
        for name,value in replace.items(): os.environ[name] = value
        # Deletion comes last, so a name given both ways ends up removed
        for name in delete: os.environ.pop(name, None)
        yield
    finally:
        # Restore the snapshot exactly, dropping anything added in the meantime
        os.environ.clear()
        os.environ.update(snapshot)

# %% ../nbs/03_xtras.ipynb 159
class ContextManagers(GetAttr):
    "Wrapper for `contextlib.ExitStack` which enters a collection of context managers"
    def __init__(self, mgrs):
        self.default = L(mgrs)
        self.stack = ExitStack()
    def __enter__(self):
        # Enter each manager, in order, on the shared stack
        for mgr in self.default: self.stack.enter_context(mgr)
    def __exit__(self, *args, **kwargs):
        self.stack.__exit__(*args, **kwargs)

# %% ../nbs/03_xtras.ipynb 161
def shufflish(x, pct=0.04):
    "Randomly relocate items of `x` up to `pct` of `len(x)` from their starting location"
    import random
    n = len(x)
    # Sort the indices by their position plus a random offset that grows with `pct`
    def _jittered(i): return i+n*(1+random.random()*pct)
    order = sorted(range_of(x), key=_jittered)
    return L(x[i] for i in order)

# %% ../nbs/03_xtras.ipynb 162
def console_help(
    libname:str):  # name of library for console script listing
    "Show help for all console scripts from `libname`"
    from fastcore.style import S
    from pkg_resources import iter_entry_points as ep
    submodule_prefix = libname+'.'
    for e in ep('console_scripts'):
        # Only scripts defined in `libname` itself or in one of its submodules are listed
        if e.module_name != libname and not e.module_name.startswith(submodule_prefix): continue
        nm = S.bold.light_blue(e.name)
        print(f'{nm:45}{e.load().__doc__}')


# %% ../nbs/03_xtras.ipynb 163
def hl_md(s, lang='xml', show=True):
    "Syntax highlight `s` using `lang`: returns the fenced markdown text, or an IPython `Markdown` object if `show`"
    md = f'```{lang}\n{s}\n```'
    if show:
        try:
            from IPython import display
            return display.Markdown(md)
        # Without IPython, fall back to printing the plain text
        except ImportError: print(s)
    else: return md

# %% ../nbs/03_xtras.ipynb 166
def type2str(typ:type)->str:
    "Stringify `typ`"
    if typ is None or typ is NoneType: return 'None'
    if hasattr(typ, '__origin__'):
        args = ", ".join(type2str(arg) for arg in typ.__args__)
        if typ.__origin__ is Union: return f"Union[{args}]"
        return f"{typ.__origin__.__name__}[{args}]"
    elif isinstance(typ, type): return typ.__name__
    return str(typ)

# %% ../nbs/03_xtras.ipynb 168
def dataclass_src(cls):
    "Source text of a `@dataclass` class declaring the name, type and default value of each field of `cls`"
    import dataclasses
    parts = [f"@dataclass\nclass {cls.__name__}:\n"]
    for f in dataclasses.fields(cls):
        # Fields without a default value are written without an assignment
        default = "" if f.default is dataclasses.MISSING else f" = {f.default!r}"
        parts.append(f"    {f.name}: {type2str(f.type)}{default}\n")
    return "".join(parts)

# %% ../nbs/03_xtras.ipynb 171
def nullable_dc(cls):
    "Like `dataclass`, but default of `None` added to fields without defaults"
    from dataclasses import dataclass, field
    annotations = get_annotations_ex(cls)[0]
    for name,_ in annotations.items():
        # An annotation that already has a class attribute keeps it as its default
        if hasattr(cls,name): continue
        setattr(cls, name, field(default=None))
    return dataclass(cls)

# %% ../nbs/03_xtras.ipynb 173
def make_nullable(clas):
    "Modify dataclass `clas` in place so fields without a default can be omitted (they become `None`); returns `clas`"
    from dataclasses import fields, MISSING
    # Already processed: hand back the class (not `None`), so the function stays usable as a decorator
    if hasattr(clas, '_nullable'): return clas
    clas._nullable = True

    original_init = clas.__init__
    def __init__(self, *args, **kwargs):
        flds = fields(clas)
        # Fields that were supplied positionally
        dargs = {k.name:v for k,v in zip(flds, args)}
        for f in flds:
            nm = f.name
            # Fill in `None` for fields that were not passed and whose default is `None`
            if nm not in dargs and nm not in kwargs and f.default is None and f.default_factory is MISSING:
                kwargs[nm] = None
        original_init(self, *args, **kwargs)

    clas.__init__ = __init__

    # Give fields that have neither a default nor a default factory a default of `None`
    for f in fields(clas):
        if f.default is MISSING and f.default_factory is MISSING: f.default = None

    return clas

# %% ../nbs/03_xtras.ipynb 177
def mk_dataclass(cls):
    "Turn `cls` into a dataclass in place, giving fields without a default a default of `None`; returns `cls`"
    from dataclasses import dataclass, field, is_dataclass, MISSING
    if is_dataclass(cls):
        # Existing dataclasses are adjusted in place by `make_nullable`
        make_nullable(cls)
        return cls
    for k,v in get_annotations_ex(cls)[0].items():
        # Annotated names without a value get an explicit `None` default
        if not hasattr(cls,k) or getattr(cls,k) is MISSING:
            setattr(cls, k, field(default=None))
    # Return the class so both branches give the caller the same result
    return dataclass(cls, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False)
# Back to Directory File Manager