# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_basics.ipynb.
# %% auto 0
__all__ = ['defaults', 'null', 'num_methods', 'rnum_methods', 'inum_methods', 'arg0', 'arg1', 'arg2', 'arg3', 'arg4', 'Self',
'ifnone', 'maybe_attr', 'basic_repr', 'is_array', 'listify', 'tuplify', 'true', 'NullType', 'tonull',
'get_class', 'mk_class', 'wrap_class', 'ignore_exceptions', 'exec_local', 'risinstance', 'Inf', 'in_',
'ret_true', 'ret_false', 'stop', 'gen', 'chunked', 'otherwise', 'custom_dir', 'AttrDict', 'NS',
'get_annotations_ex', 'eval_type', 'type_hints', 'annotations', 'anno_ret', 'signature_ex', 'union2tuple',
'argnames', 'with_cast', 'store_attr', 'attrdict', 'properties', 'camel2words', 'camel2snake', 'snake2camel',
'class2attr', 'getcallable', 'getattrs', 'hasattrs', 'setattrs', 'try_attrs', 'GetAttrBase', 'GetAttr',
'delegate_attr', 'ShowPrint', 'Int', 'Str', 'Float', 'partition', 'flatten', 'concat', 'strcat', 'detuplify',
'replicate', 'setify', 'merge', 'range_of', 'groupby', 'last_index', 'filter_dict', 'filter_keys',
'filter_values', 'cycle', 'zip_cycle', 'sorted_ex', 'not_', 'argwhere', 'filter_ex', 'renumerate', 'first',
'only', 'nested_attr', 'nested_setdefault', 'nested_callable', 'nested_idx', 'set_nested_idx', 'val2idx',
'uniqueify', 'loop_first_last', 'loop_first', 'loop_last', 'first_match', 'last_match', 'fastuple', 'bind',
'mapt', 'map_ex', 'compose', 'maps', 'partialler', 'instantiate', 'using_attr', 'copy_func', 'patch_to',
'patch', 'patch_property', 'compile_re', 'ImportEnum', 'StrEnum', 'str_enum', 'Stateful', 'NotStr',
'PrettyString', 'even_mults', 'num_cpus', 'add_props', 'typed', 'exec_new', 'exec_import', 'str2bool', 'lt',
'gt', 'le', 'ge', 'eq', 'ne', 'add', 'sub', 'mul', 'truediv', 'is_', 'is_not', 'mod']
# %% ../nbs/01_basics.ipynb 1
from .imports import *
import builtins,types,typing
import pprint
try: from types import UnionType
except ImportError: UnionType = None
# %% ../nbs/01_basics.ipynb 5
defaults = SimpleNamespace()
# %% ../nbs/01_basics.ipynb 6
def ifnone(a, b):
"`b` if `a` is None else `a`"
return b if a is None else a
# %% ../nbs/01_basics.ipynb 9
def maybe_attr(o, attr):
"`getattr(o,attr,o)`"
return getattr(o,attr,o)
# %% ../nbs/01_basics.ipynb 12
def basic_repr(flds=None):
"Minimal `__repr__`"
if isinstance(flds, str): flds = re.split(', *', flds)
flds = list(flds or [])
def _f(self):
res = f'{type(self).__module__}.{type(self).__name__}'
if not flds: return f'<{res}>'
sig = ', '.join(f'{o}={getattr(self,o)!r}' for o in flds)
return f'{res}({sig})'
return _f
# %% ../nbs/01_basics.ipynb 18
def is_array(x):
"`True` if `x` supports `__array__` or `iloc`"
return hasattr(x,'__array__') or hasattr(x,'iloc')
# %% ../nbs/01_basics.ipynb 20
def listify(o=None, *rest, use_list=False, match=None):
"Convert `o` to a `list`"
if rest: o = (o,)+rest
if use_list: res = list(o)
elif o is None: res = []
elif isinstance(o, list): res = o
elif isinstance(o, str) or is_array(o): res = [o]
elif is_iter(o): res = list(o)
else: res = [o]
if match is not None:
if is_coll(match): match = len(match)
if len(res)==1: res = res*match
else: assert len(res)==match, 'Match length mismatch'
return res
# %% ../nbs/01_basics.ipynb 33
def tuplify(o, use_list=False, match=None):
"Make `o` a tuple"
return tuple(listify(o, use_list=use_list, match=match))
# %% ../nbs/01_basics.ipynb 35
def true(x):
"Test whether `x` is truthy; collections with >0 elements are considered `True`"
try: return bool(len(x))
except: return bool(x)
# %% ../nbs/01_basics.ipynb 37
class NullType:
"An object that is `False` and can be called, chained, and indexed"
def __getattr__(self,*args):return null
def __call__(self,*args, **kwargs):return null
def __getitem__(self, *args):return null
def __bool__(self): return False
null = NullType()
# %% ../nbs/01_basics.ipynb 39
def tonull(x):
"Convert `None` to `null`"
return null if x is None else x
# %% ../nbs/01_basics.ipynb 41
def get_class(nm, *fld_names, sup=None, doc=None, funcs=None, anno=None, **flds):
"Dynamically create a class, optionally inheriting from `sup`, containing `fld_names`"
attrs = {}
if not anno: anno = {}
for f in fld_names:
attrs[f] = None
if f not in anno: anno[f] = typing.Any
for f in listify(funcs): attrs[f.__name__] = f
for k,v in flds.items(): attrs[k] = v
sup = ifnone(sup, ())
if not isinstance(sup, tuple): sup=(sup,)
def _init(self, *args, **kwargs):
for i,v in enumerate(args): setattr(self, list(attrs.keys())[i], v)
for k,v in kwargs.items(): setattr(self,k,v)
attrs['_fields'] = [*fld_names,*flds.keys()]
def _eq(self,b):
return all([getattr(self,k)==getattr(b,k) for k in self._fields])
if not sup: attrs['__repr__'] = basic_repr(attrs['_fields'])
attrs['__init__'] = _init
attrs['__eq__'] = _eq
if anno: attrs['__annotations__'] = anno
res = type(nm, sup, attrs)
if doc is not None: res.__doc__ = doc
return res
# %% ../nbs/01_basics.ipynb 45
def mk_class(nm, *fld_names, sup=None, doc=None, funcs=None, mod=None, anno=None, **flds):
"Create a class using `get_class` and add to the caller's module"
if mod is None: mod = sys._getframe(1).f_locals
res = get_class(nm, *fld_names, sup=sup, doc=doc, funcs=funcs, anno=anno, **flds)
mod[nm] = res
# %% ../nbs/01_basics.ipynb 50
def wrap_class(nm, *fld_names, sup=None, doc=None, funcs=None, **flds):
"Decorator: makes function a method of a new class `nm` passing parameters to `mk_class`"
def _inner(f):
mk_class(nm, *fld_names, sup=sup, doc=doc, funcs=listify(funcs)+[f], mod=f.__globals__, **flds)
return f
return _inner
# %% ../nbs/01_basics.ipynb 52
class ignore_exceptions:
"Context manager to ignore exceptions"
def __enter__(self): pass
def __exit__(self, *args): return True
# %% ../nbs/01_basics.ipynb 55
def exec_local(code, var_name):
"Call `exec` on `code` and return the var `var_name`"
loc = {}
exec(code, globals(), loc)
return loc[var_name]
# %% ../nbs/01_basics.ipynb 57
def risinstance(types, obj=None):
"Curried `isinstance` but with args reversed"
types = tuplify(types)
if obj is None: return partial(risinstance,types)
if any(isinstance(t,str) for t in types):
return any(t.__name__ in types for t in type(obj).__mro__)
return isinstance(obj, types)
# %% ../nbs/01_basics.ipynb 69
class _InfMeta(type):
@property
def count(self): return itertools.count()
@property
def zeros(self): return itertools.cycle([0])
@property
def ones(self): return itertools.cycle([1])
@property
def nones(self): return itertools.cycle([None])
# %% ../nbs/01_basics.ipynb 70
class Inf(metaclass=_InfMeta):
"Infinite lists"
pass
# %% ../nbs/01_basics.ipynb 75
_dumobj = object()
def _oper(op,a,b=_dumobj): return (lambda o:op(o,a)) if b is _dumobj else op(a,b)
def _mk_op(nm, mod):
"Create an operator using `oper` and add to the caller's module"
op = getattr(operator,nm)
def _inner(a, b=_dumobj): return _oper(op, a,b)
_inner.__name__ = _inner.__qualname__ = nm
_inner.__doc__ = f'Same as `operator.{nm}`, or returns partial if 1 arg'
mod[nm] = _inner
# %% ../nbs/01_basics.ipynb 76
def in_(x, a):
"`True` if `x in a`"
return x in a
operator.in_ = in_
# %% ../nbs/01_basics.ipynb 77
_all_ = ['lt','gt','le','ge','eq','ne','add','sub','mul','truediv','is_','is_not','in_', 'mod']
# %% ../nbs/01_basics.ipynb 78
for op in _all_: _mk_op(op, globals())
# %% ../nbs/01_basics.ipynb 84
def ret_true(*args, **kwargs):
"Predicate: always `True`"
return True
# %% ../nbs/01_basics.ipynb 86
def ret_false(*args, **kwargs):
"Predicate: always `False`"
return False
# %% ../nbs/01_basics.ipynb 87
def stop(e=StopIteration):
"Raises exception `e` (by default `StopIteration`)"
raise e
# %% ../nbs/01_basics.ipynb 88
def gen(func, seq, cond=ret_true):
"Like `(func(o) for o in seq if cond(func(o)))` but handles `StopIteration`"
return itertools.takewhile(cond, map(func,seq))
# %% ../nbs/01_basics.ipynb 90
def chunked(it, chunk_sz=None, drop_last=False, n_chunks=None):
"Return batches from iterator `it` of size `chunk_sz` (or return `n_chunks` total)"
assert bool(chunk_sz) ^ bool(n_chunks)
if n_chunks: chunk_sz = max(math.ceil(len(it)/n_chunks), 1)
if not isinstance(it, Iterator): it = iter(it)
while True:
res = list(itertools.islice(it, chunk_sz))
if res and (len(res)==chunk_sz or not drop_last): yield res
if len(res)<chunk_sz: return
# %% ../nbs/01_basics.ipynb 93
def otherwise(x, tst, y):
"`y if tst(x) else x`"
return y if tst(x) else x
# %% ../nbs/01_basics.ipynb 97
def custom_dir(c, add):
"Implement custom `__dir__`, adding `add` to `cls`"
return object.__dir__(c) + listify(add)
# %% ../nbs/01_basics.ipynb 100
class AttrDict(dict):
"`dict` subclass that also provides access to keys as attrs"
def __getattr__(self,k): return self[k] if k in self else stop(AttributeError(k))
def __setattr__(self, k, v): (self.__setitem__,super().__setattr__)[k[0]=='_'](k,v)
def __dir__(self): return super().__dir__() + list(self.keys())
def _repr_markdown_(self): return f'```json\n{pprint.pformat(self, indent=2)}\n```'
def copy(self): return AttrDict(**self)
# %% ../nbs/01_basics.ipynb 104
class NS(SimpleNamespace):
"`SimpleNamespace` subclass that also adds `iter` and `dict` support"
def __iter__(self): return iter(self.__dict__)
def __getitem__(self,x): return self.__dict__[x]
def __setitem__(self,x,y): self.__dict__[x] = y
# %% ../nbs/01_basics.ipynb 111
def get_annotations_ex(obj, *, globals=None, locals=None):
"Backport of py3.10 `get_annotations` that returns globals/locals"
if isinstance(obj, type):
obj_dict = getattr(obj, '__dict__', None)
if obj_dict and hasattr(obj_dict, 'get'):
ann = obj_dict.get('__annotations__', None)
if isinstance(ann, types.GetSetDescriptorType): ann = None
else: ann = None
obj_globals = None
module_name = getattr(obj, '__module__', None)
if module_name:
module = sys.modules.get(module_name, None)
if module: obj_globals = getattr(module, '__dict__', None)
obj_locals = dict(vars(obj))
unwrap = obj
elif isinstance(obj, types.ModuleType):
ann = getattr(obj, '__annotations__', None)
obj_globals = getattr(obj, '__dict__')
obj_locals,unwrap = None,None
elif callable(obj):
ann = getattr(obj, '__annotations__', None)
obj_globals = getattr(obj, '__globals__', None)
obj_locals,unwrap = None,obj
else: raise TypeError(f"{obj!r} is not a module, class, or callable.")
if ann is None: ann = {}
if not isinstance(ann, dict): raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")
if not ann: ann = {}
if unwrap is not None:
while True:
if hasattr(unwrap, '__wrapped__'):
unwrap = unwrap.__wrapped__
continue
if isinstance(unwrap, functools.partial):
unwrap = unwrap.func
continue
break
if hasattr(unwrap, "__globals__"): obj_globals = unwrap.__globals__
if globals is None: globals = obj_globals
if locals is None: locals = obj_locals
return dict(ann), globals, locals
# %% ../nbs/01_basics.ipynb 113
def eval_type(t, glb, loc):
"`eval` a type or collection of types, if needed, for annotations in py3.10+"
if isinstance(t,str):
if '|' in t: return Union[eval_type(tuple(t.split('|')), glb, loc)]
return eval(t, glb, loc)
if isinstance(t,(tuple,list)): return type(t)([eval_type(c, glb, loc) for c in t])
return t
# %% ../nbs/01_basics.ipynb 118
def _eval_type(t, glb, loc):
res = eval_type(t, glb, loc)
return NoneType if res is None else res
def type_hints(f):
"Like `typing.get_type_hints` but returns `{}` if not allowed type"
if not isinstance(f, typing._allowed_types): return {}
ann,glb,loc = get_annotations_ex(f)
return {k:_eval_type(v,glb,loc) for k,v in ann.items()}
# %% ../nbs/01_basics.ipynb 125
def annotations(o):
"Annotations for `o`, or `type(o)`"
res = {}
if not o: return res
res = type_hints(o)
if not res: res = type_hints(getattr(o,'__init__',None))
if not res: res = type_hints(type(o))
return res
# %% ../nbs/01_basics.ipynb 128
def anno_ret(func):
"Get the return annotation of `func`"
return annotations(func).get('return', None) if func else None
# %% ../nbs/01_basics.ipynb 134
def _ispy3_10(): return sys.version_info.major >=3 and sys.version_info.minor >=10
def signature_ex(obj, eval_str:bool=False):
"Backport of `inspect.signature(..., eval_str=True` to <py310"
from inspect import Signature, Parameter, signature
def _eval_param(ann, k, v):
if k not in ann: return v
return Parameter(v.name, v.kind, annotation=ann[k], default=v.default)
if not eval_str: return signature(obj)
if _ispy3_10(): return signature(obj, eval_str=eval_str)
sig = signature(obj)
if sig is None: return None
ann = type_hints(obj)
params = [_eval_param(ann,k,v) for k,v in sig.parameters.items()]
return Signature(params, return_annotation=sig.return_annotation)
# %% ../nbs/01_basics.ipynb 135
def union2tuple(t):
if (getattr(t, '__origin__', None) is Union
or (UnionType and isinstance(t, UnionType))): return t.__args__
return t
# %% ../nbs/01_basics.ipynb 137
def argnames(f, frame=False):
"Names of arguments to function or frame `f`"
code = getattr(f, 'f_code' if frame else '__code__')
return code.co_varnames[:code.co_argcount+code.co_kwonlyargcount]
# %% ../nbs/01_basics.ipynb 139
def with_cast(f):
"Decorator which uses any parameter annotations as preprocessing functions"
anno, out_anno, params = annotations(f), anno_ret(f), argnames(f)
c_out = ifnone(out_anno, noop)
defaults = dict(zip(reversed(params), reversed(f.__defaults__ or {})))
@functools.wraps(f)
def _inner(*args, **kwargs):
args = list(args)
for i,v in enumerate(params):
if v in anno:
c = anno[v]
if v in kwargs: kwargs[v] = c(kwargs[v])
elif i<len(args): args[i] = c(args[i])
elif v in defaults: kwargs[v] = c(defaults[v])
return c_out(f(*args, **kwargs))
return _inner
# %% ../nbs/01_basics.ipynb 141
def _store_attr(self, anno, **attrs):
stored = getattr(self, '__stored_args__', None)
for n,v in attrs.items():
if n in anno: v = anno[n](v)
setattr(self, n, v)
if stored is not None: stored[n] = v
# %% ../nbs/01_basics.ipynb 142
def store_attr(names=None, self=None, but='', cast=False, store_args=None, **attrs):
"Store params named in comma-separated `names` from calling context into attrs in `self`"
fr = sys._getframe(1)
args = argnames(fr, True)
if self: args = ('self', *args)
else: self = fr.f_locals[args[0]]
if store_args is None: store_args = not hasattr(self,'__slots__')
if store_args and not hasattr(self, '__stored_args__'): self.__stored_args__ = {}
anno = annotations(self) if cast else {}
if names and isinstance(names,str): names = re.split(', *', names)
ns = names if names is not None else getattr(self, '__slots__', args[1:])
added = {n:fr.f_locals[n] for n in ns}
attrs = {**attrs, **added}
if isinstance(but,str): but = re.split(', *', but)
attrs = {k:v for k,v in attrs.items() if k not in but}
return _store_attr(self, anno, **attrs)
# %% ../nbs/01_basics.ipynb 171
def attrdict(o, *ks, default=None):
"Dict from each `k` in `ks` to `getattr(o,k)`"
return {k:getattr(o, k, default) for k in ks}
# %% ../nbs/01_basics.ipynb 173
def properties(cls, *ps):
"Change attrs in `cls` with names in `ps` to properties"
for p in ps: setattr(cls,p,property(getattr(cls,p)))
# %% ../nbs/01_basics.ipynb 175
_c2w_re = re.compile(r'((?<=[a-z])[A-Z]|(?<!\A)[A-Z](?=[a-z]))')
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
# %% ../nbs/01_basics.ipynb 176
def camel2words(s, space=' '):
"Convert CamelCase to 'spaced words'"
return re.sub(_c2w_re, rf'{space}\1', s)
# %% ../nbs/01_basics.ipynb 178
def camel2snake(name):
"Convert CamelCase to snake_case"
s1 = re.sub(_camel_re1, r'\1_\2', name)
return re.sub(_camel_re2, r'\1_\2', s1).lower()
# %% ../nbs/01_basics.ipynb 180
def snake2camel(s):
"Convert snake_case to CamelCase"
return ''.join(s.title().split('_'))
# %% ../nbs/01_basics.ipynb 182
def class2attr(self, cls_name):
"Return the snake-cased name of the class; strip ending `cls_name` if it exists."
return camel2snake(re.sub(rf'{cls_name}$', '', self.__class__.__name__) or cls_name.lower())
# %% ../nbs/01_basics.ipynb 184
def getcallable(o, attr):
"Calls `getattr` with a default of `noop`"
return getattr(o, attr, noop)
# %% ../nbs/01_basics.ipynb 186
def getattrs(o, *attrs, default=None):
"List of all `attrs` in `o`"
return [getattr(o,attr,default) for attr in attrs]
# %% ../nbs/01_basics.ipynb 189
def hasattrs(o,attrs):
"Test whether `o` contains all `attrs`"
return all(hasattr(o,attr) for attr in attrs)
# %% ../nbs/01_basics.ipynb 191
def setattrs(dest, flds, src):
f = dict.get if isinstance(src, dict) else getattr
flds = re.split(r",\s*", flds)
for fld in flds: setattr(dest, fld, f(src, fld))
# %% ../nbs/01_basics.ipynb 194
def try_attrs(obj, *attrs):
"Return first attr that exists in `obj`"
for att in attrs:
try: return getattr(obj, att)
except: pass
raise AttributeError(attrs)
# %% ../nbs/01_basics.ipynb 197
class GetAttrBase:
"Basic delegation of `__getattr__` and `__dir__`"
_attr=noop
def __getattr__(self,k):
if k[0]=='_' or k==self._attr: return super().__getattr__(k)
return self._getattr(getattr(self, self._attr)[k])
def __dir__(self): return custom_dir(self, getattr(self, self._attr))
# %% ../nbs/01_basics.ipynb 198
class GetAttr:
"Inherit from this to have all attr accesses in `self._xtra` passed down to `self.default`"
_default='default'
def _component_attr_filter(self,k):
if k.startswith('__') or k in ('_xtra',self._default): return False
xtra = getattr(self,'_xtra',None)
return xtra is None or k in xtra
def _dir(self): return [k for k in dir(getattr(self,self._default)) if self._component_attr_filter(k)]
def __getattr__(self,k):
if self._component_attr_filter(k):
attr = getattr(self,self._default,None)
if attr is not None: return getattr(attr,k)
raise AttributeError(k)
def __dir__(self): return custom_dir(self,self._dir())
# def __getstate__(self): return self.__dict__
def __setstate__(self,data): self.__dict__.update(data)
# %% ../nbs/01_basics.ipynb 218
def delegate_attr(self, k, to):
"Use in `__getattr__` to delegate to attr `to` without inheriting from `GetAttr`"
if k.startswith('_') or k==to: raise AttributeError(k)
try: return getattr(getattr(self,to), k)
except AttributeError: raise AttributeError(k) from None
# %% ../nbs/01_basics.ipynb 224
class ShowPrint:
"Base class that prints for `show`"
def show(self, *args, **kwargs): print(str(self))
# %% ../nbs/01_basics.ipynb 226
class Int(int,ShowPrint):
"An extensible `int`"
pass
# %% ../nbs/01_basics.ipynb 227
class Str(str,ShowPrint):
"An extensible `str`"
pass
class Float(float,ShowPrint):
"An extensible `float`"
pass
# %% ../nbs/01_basics.ipynb 232
def partition(coll, f):
"Partition a collection by a predicate"
ts,fs = [],[]
for o in coll: (fs,ts)[f(o)].append(o)
if isinstance(coll,tuple):
typ = type(coll)
ts,fs = typ(ts),typ(fs)
return ts,fs
# %% ../nbs/01_basics.ipynb 234
def flatten(o):
"Concatenate all collections and items as a generator"
for item in o:
if isinstance(item, str): yield item; continue
try: yield from flatten(item)
except TypeError: yield item
# %% ../nbs/01_basics.ipynb 235
def concat(colls)->list:
"Concatenate all collections and items as a list"
return list(flatten(colls))
# %% ../nbs/01_basics.ipynb 238
def strcat(its, sep:str='')->str:
"Concatenate stringified items `its`"
return sep.join(map(str,its))
# %% ../nbs/01_basics.ipynb 240
def detuplify(x):
"If `x` is a tuple with one thing, extract it"
return None if len(x)==0 else x[0] if len(x)==1 and getattr(x, 'ndim', 1)==1 else x
# %% ../nbs/01_basics.ipynb 242
def replicate(item,match):
"Create tuple of `item` copied `len(match)` times"
return (item,)*len(match)
# %% ../nbs/01_basics.ipynb 244
def setify(o):
"Turn any list like-object into a set."
return o if isinstance(o,set) else set(listify(o))
# %% ../nbs/01_basics.ipynb 246
def merge(*ds):
"Merge all dictionaries in `ds`"
return {k:v for d in ds if d is not None for k,v in d.items()}
# %% ../nbs/01_basics.ipynb 248
def range_of(x):
"All indices of collection `x` (i.e. `list(range(len(x)))`)"
return list(range(len(x)))
# %% ../nbs/01_basics.ipynb 250
def groupby(x, key, val=noop):
"Like `itertools.groupby` but doesn't need to be sorted, and isn't lazy, plus some extensions"
if isinstance(key,int): key = itemgetter(key)
elif isinstance(key,str): key = attrgetter(key)
if isinstance(val,int): val = itemgetter(val)
elif isinstance(val,str): val = attrgetter(val)
res = {}
for o in x: res.setdefault(key(o), []).append(val(o))
return res
# %% ../nbs/01_basics.ipynb 254
def last_index(x, o):
"Finds the last index of occurence of `x` in `o` (returns -1 if no occurence)"
try: return next(i for i in reversed(range(len(o))) if o[i] == x)
except StopIteration: return -1
# %% ../nbs/01_basics.ipynb 256
def filter_dict(d, func):
"Filter a `dict` using `func`, applied to keys and values"
return {k:v for k,v in d.items() if func(k,v)}
# %% ../nbs/01_basics.ipynb 259
def filter_keys(d, func):
"Filter a `dict` using `func`, applied to keys"
return {k:v for k,v in d.items() if func(k)}
# %% ../nbs/01_basics.ipynb 261
def filter_values(d, func):
"Filter a `dict` using `func`, applied to values"
return {k:v for k,v in d.items() if func(v)}
# %% ../nbs/01_basics.ipynb 263
def cycle(o):
"Like `itertools.cycle` except creates list of `None`s if `o` is empty"
o = listify(o)
return itertools.cycle(o) if o is not None and len(o) > 0 else itertools.cycle([None])
# %% ../nbs/01_basics.ipynb 265
def zip_cycle(x, *args):
"Like `itertools.zip_longest` but `cycle`s through elements of all but first argument"
return zip(x, *map(cycle,args))
# %% ../nbs/01_basics.ipynb 267
def sorted_ex(iterable, key=None, reverse=False):
"Like `sorted`, but if key is str use `attrgetter`; if int use `itemgetter`"
if isinstance(key,str): k=lambda o:getattr(o,key,0)
elif isinstance(key,int): k=itemgetter(key)
else: k=key
return sorted(iterable, key=k, reverse=reverse)
# %% ../nbs/01_basics.ipynb 268
def not_(f):
"Create new function that negates result of `f`"
def _f(*args, **kwargs): return not f(*args, **kwargs)
return _f
# %% ../nbs/01_basics.ipynb 270
def argwhere(iterable, f, negate=False, **kwargs):
"Like `filter_ex`, but return indices for matching items"
if kwargs: f = partial(f,**kwargs)
if negate: f = not_(f)
return [i for i,o in enumerate(iterable) if f(o)]
# %% ../nbs/01_basics.ipynb 271
def filter_ex(iterable, f=noop, negate=False, gen=False, **kwargs):
"Like `filter`, but passing `kwargs` to `f`, defaulting `f` to `noop`, and adding `negate` and `gen`"
if f is None: f = lambda _: True
if kwargs: f = partial(f,**kwargs)
if negate: f = not_(f)
res = filter(f, iterable)
if gen: return res
return list(res)
# %% ../nbs/01_basics.ipynb 272
def range_of(a, b=None, step=None):
"All indices of collection `a`, if `a` is a collection, otherwise `range`"
if is_coll(a): a = len(a)
return list(range(a,b,step) if step is not None else range(a,b) if b is not None else range(a))
# %% ../nbs/01_basics.ipynb 274
def renumerate(iterable, start=0):
"Same as `enumerate`, but returns index as 2nd element instead of 1st"
return ((o,i) for i,o in enumerate(iterable, start=start))
# %% ../nbs/01_basics.ipynb 276
def first(x, f=None, negate=False, **kwargs):
"First element of `x`, optionally filtered by `f`, or None if missing"
x = iter(x)
if f: x = filter_ex(x, f=f, negate=negate, gen=True, **kwargs)
return next(x, None)
# %% ../nbs/01_basics.ipynb 278
def only(o):
"Return the only item of `o`, raise if `o` doesn't have exactly one item"
it = iter(o)
try: res = next(it)
except StopIteration: raise ValueError('iterable has 0 items') from None
try: next(it)
except StopIteration: return res
raise ValueError(f'iterable has more than 1 item')
# %% ../nbs/01_basics.ipynb 280
def nested_attr(o, attr, default=None):
"Same as `getattr`, but if `attr` includes a `.`, then looks inside nested objects"
try:
for a in attr.split("."): o = getattr(o, a)
except AttributeError: return default
return o
# %% ../nbs/01_basics.ipynb 282
def nested_setdefault(o, attr, default):
"Same as `setdefault`, but if `attr` includes a `.`, then looks inside nested objects"
attrs = attr.split('.')
for a in attrs[:-1]: o = o.setdefault(a, type(o)())
return o.setdefault(attrs[-1], default)
# %% ../nbs/01_basics.ipynb 286
def nested_callable(o, attr):
"Same as `nested_attr` but if not found will return `noop`"
return nested_attr(o, attr, noop)
# %% ../nbs/01_basics.ipynb 288
def _access(coll, idx):
if isinstance(idx,str) and hasattr(coll, idx): return getattr(coll, idx)
if hasattr(coll, 'get'): return coll.get(idx, None)
try: length = len(coll)
except TypeError: length = 0
if isinstance(idx,int) and idx<length: return coll[idx]
return None
def _nested_idx(coll, *idxs):
*idxs,last_idx = idxs
for idx in idxs:
if isinstance(idx,str) and hasattr(coll, idx): coll = getattr(coll, idx)
else:
if isinstance(coll,str) or not isinstance(coll, typing.Collection): return None,None
coll = coll.get(idx, None) if hasattr(coll, 'get') else coll[idx] if idx<len(coll) else None
return coll,last_idx
# %% ../nbs/01_basics.ipynb 289
def nested_idx(coll, *idxs):
"Index into nested collections, dicts, etc, with `idxs`"
if not coll or not idxs: return coll
coll,idx = _nested_idx(coll, *idxs)
if not coll or not idxs: return coll
return _access(coll, idx)
# %% ../nbs/01_basics.ipynb 292
def set_nested_idx(coll, value, *idxs):
"Set value indexed like `nested_idx"
coll,idx = _nested_idx(coll, *idxs)
coll[idx] = value
# %% ../nbs/01_basics.ipynb 294
def val2idx(x):
"Dict from value to index"
return {v:k for k,v in enumerate(x)}
# %% ../nbs/01_basics.ipynb 296
def uniqueify(x, sort=False, bidir=False, start=None):
"Unique elements in `x`, optional `sort`, optional return reverse correspondence, optional prepend with elements."
res = list(dict.fromkeys(x))
if start is not None: res = listify(start)+res
if sort: res.sort()
return (res,val2idx(res)) if bidir else res
# %% ../nbs/01_basics.ipynb 298
# looping functions from https://github.com/willmcgugan/rich/blob/master/rich/_loop.py
def loop_first_last(values):
"Iterate and generate a tuple with a flag for first and last value."
iter_values = iter(values)
try: previous_value = next(iter_values)
except StopIteration: return
first = True
for value in iter_values:
yield first,False,previous_value
first,previous_value = False,value
yield first,True,previous_value
# %% ../nbs/01_basics.ipynb 300
def loop_first(values):
"Iterate and generate a tuple with a flag for first value."
return ((b,o) for b,_,o in loop_first_last(values))
# %% ../nbs/01_basics.ipynb 302
def loop_last(values):
"Iterate and generate a tuple with a flag for last value."
return ((b,o) for _,b,o in loop_first_last(values))
# %% ../nbs/01_basics.ipynb 304
def first_match(lst, f, default=None):
"First element of `lst` matching predicate `f`, or `default` if none"
return next((i for i,o in enumerate(lst) if f(o)), default)
# %% ../nbs/01_basics.ipynb 306
def last_match(lst, f, default=None):
"Last element of `lst` matching predicate `f`, or `default` if none"
return next((i for i in range(len(lst)-1, -1, -1) if f(lst[i])), default)
# %% ../nbs/01_basics.ipynb 310
num_methods = """
__add__ __sub__ __mul__ __matmul__ __truediv__ __floordiv__ __mod__ __divmod__ __pow__
__lshift__ __rshift__ __and__ __xor__ __or__ __neg__ __pos__ __abs__
""".split()
rnum_methods = """
__radd__ __rsub__ __rmul__ __rmatmul__ __rtruediv__ __rfloordiv__ __rmod__ __rdivmod__
__rpow__ __rlshift__ __rrshift__ __rand__ __rxor__ __ror__
""".split()
inum_methods = """
__iadd__ __isub__ __imul__ __imatmul__ __itruediv__
__ifloordiv__ __imod__ __ipow__ __ilshift__ __irshift__ __iand__ __ixor__ __ior__
""".split()
# %% ../nbs/01_basics.ipynb 311
class fastuple(tuple):
"A `tuple` with elementwise ops and more friendly __init__ behavior"
def __new__(cls, x=None, *rest):
if x is None: x = ()
if not isinstance(x,tuple):
if len(rest): x = (x,)
else:
try: x = tuple(iter(x))
except TypeError: x = (x,)
return super().__new__(cls, x+rest if rest else x)
def _op(self,op,*args):
if not isinstance(self,fastuple): self = fastuple(self)
return type(self)(map(op,self,*map(cycle, args)))
def mul(self,*args):
"`*` is already defined in `tuple` for replicating, so use `mul` instead"
return fastuple._op(self, operator.mul,*args)
def add(self,*args):
"`+` is already defined in `tuple` for concat, so use `add` instead"
return fastuple._op(self, operator.add,*args)
def _get_op(op):
if isinstance(op,str): op = getattr(operator,op)
def _f(self,*args): return self._op(op,*args)
return _f
for n in num_methods:
if not hasattr(fastuple, n) and hasattr(operator,n): setattr(fastuple,n,_get_op(n))
for n in 'eq ne lt le gt ge'.split(): setattr(fastuple,n,_get_op(n))
setattr(fastuple,'__invert__',_get_op('__not__'))
setattr(fastuple,'max',_get_op(max))
setattr(fastuple,'min',_get_op(min))
# %% ../nbs/01_basics.ipynb 329
class _Arg:
def __init__(self,i): self.i = i
arg0 = _Arg(0)
arg1 = _Arg(1)
arg2 = _Arg(2)
arg3 = _Arg(3)
arg4 = _Arg(4)
# %% ../nbs/01_basics.ipynb 330
class bind:
"Same as `partial`, except you can use `arg0` `arg1` etc param placeholders"
def __init__(self, func, *pargs, **pkwargs):
self.func,self.pargs,self.pkwargs = func,pargs,pkwargs
self.maxi = max((x.i for x in pargs if isinstance(x, _Arg)), default=-1)
def __call__(self, *args, **kwargs):
args = list(args)
kwargs = {**self.pkwargs,**kwargs}
for k,v in kwargs.items():
if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
return self.func(*fargs, **kwargs)
# %% ../nbs/01_basics.ipynb 342
def mapt(func, *iterables):
"Tuplified `map`"
return tuple(map(func, *iterables))
# %% ../nbs/01_basics.ipynb 344
def map_ex(iterable, f, *args, gen=False, **kwargs):
"Like `map`, but use `bind`, and supports `str` and indexing"
g = (bind(f,*args,**kwargs) if callable(f)
else f.format if isinstance(f,str)
else f.__getitem__)
res = map(g, iterable)
if gen: return res
return list(res)
# %% ../nbs/01_basics.ipynb 352
def compose(*funcs, order=None):
"Create a function that composes all functions in `funcs`, passing along remaining `*args` and `**kwargs` to all"
funcs = listify(funcs)
if len(funcs)==0: return noop
if len(funcs)==1: return funcs[0]
if order is not None: funcs = sorted_ex(funcs, key=order)
def _inner(x, *args, **kwargs):
for f in funcs: x = f(x, *args, **kwargs)
return x
return _inner
# %% ../nbs/01_basics.ipynb 354
def maps(*args, retain=noop):
"Like `map`, except funcs are composed first"
f = compose(*args[:-1])
def _f(b): return retain(f(b), b)
return map(_f, args[-1])
# %% ../nbs/01_basics.ipynb 356
def partialler(f, *args, order=None, **kwargs):
"Like `functools.partial` but also copies over docstring"
fnew = partial(f,*args,**kwargs)
fnew.__doc__ = f.__doc__
if order is not None: fnew.order=order
elif hasattr(f,'order'): fnew.order=f.order
return fnew
# %% ../nbs/01_basics.ipynb 360
def instantiate(t):
"Instantiate `t` if it's a type, otherwise do nothing"
return t() if isinstance(t, type) else t
# %% ../nbs/01_basics.ipynb 362
def _using_attr(f, attr, x): return f(getattr(x,attr))
# %% ../nbs/01_basics.ipynb 363
def using_attr(f, attr):
"Construct a function which applies `f` to the argument's attribute `attr`"
return partial(_using_attr, f, attr)
# %% ../nbs/01_basics.ipynb 367
class _Self:
"An alternative to `lambda` for calling methods on passed object."
def __init__(self): self.nms,self.args,self.kwargs,self.ready = [],[],[],True
def __repr__(self): return f'self: {self.nms}({self.args}, {self.kwargs})'
def __call__(self, *args, **kwargs):
if self.ready:
x = args[0]
for n,a,k in zip(self.nms,self.args,self.kwargs):
x = getattr(x,n)
if callable(x) and a is not None: x = x(*a, **k)
return x
else:
self.args.append(args)
self.kwargs.append(kwargs)
self.ready = True
return self
def __getattr__(self,k):
if not self.ready:
self.args.append(None)
self.kwargs.append(None)
self.nms.append(k)
self.ready = False
return self
def _call(self, *args, **kwargs):
self.args,self.kwargs,self.nms = [args],[kwargs],['__call__']
self.ready = True
return self
# %% ../nbs/01_basics.ipynb 368
class _SelfCls:
def __getattr__(self,k): return getattr(_Self(),k)
def __getitem__(self,i): return self.__getattr__('__getitem__')(i)
def __call__(self,*args,**kwargs): return self.__getattr__('_call')(*args,**kwargs)
Self = _SelfCls()
# %% ../nbs/01_basics.ipynb 369
_all_ = ['Self']
# %% ../nbs/01_basics.ipynb 375
def copy_func(f):
"Copy a non-builtin function (NB `copy.copy` does not work for this)"
if not isinstance(f,FunctionType): return copy(f)
fn = FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
fn.__kwdefaults__ = f.__kwdefaults__
fn.__dict__.update(f.__dict__)
fn.__annotations__.update(f.__annotations__)
fn.__qualname__ = f.__qualname__
return fn
# %% ../nbs/01_basics.ipynb 382
class _clsmethod:
def __init__(self, f): self.f = f
def __get__(self, _, f_cls): return MethodType(self.f, f_cls)
# %% ../nbs/01_basics.ipynb 383
def patch_to(cls, as_prop=False, cls_method=False):
"Decorator: add `f` to `cls`"
if not isinstance(cls, (tuple,list)): cls=(cls,)
def _inner(f):
for c_ in cls:
nf = copy_func(f)
nm = f.__name__
# `functools.update_wrapper` when passing patched function to `Pipeline`, so we do it manually
for o in functools.WRAPPER_ASSIGNMENTS: setattr(nf, o, getattr(f,o))
nf.__qualname__ = f"{c_.__name__}.{nm}"
if cls_method: setattr(c_, nm, _clsmethod(nf))
else:
if as_prop: setattr(c_, nm, property(nf))
else:
onm = '_orig_'+nm
if hasattr(c_, nm) and not hasattr(c_, onm): setattr(c_, onm, getattr(c_, nm))
setattr(c_, nm, nf)
# Avoid clobbering existing functions
return globals().get(nm, builtins.__dict__.get(nm, None))
return _inner
# %% ../nbs/01_basics.ipynb 394
def patch(f=None, *, as_prop=False, cls_method=False):
"Decorator: add `f` to the first parameter's class (based on f's type annotations)"
if f is None: return partial(patch, as_prop=as_prop, cls_method=cls_method)
ann,glb,loc = get_annotations_ex(f)
cls = union2tuple(eval_type(ann.pop('cls') if cls_method else next(iter(ann.values())), glb, loc))
return patch_to(cls, as_prop=as_prop, cls_method=cls_method)(f)
# %% ../nbs/01_basics.ipynb 402
def patch_property(f):
"Deprecated; use `patch(as_prop=True)` instead"
warnings.warn("`patch_property` is deprecated and will be removed; use `patch(as_prop=True)` instead")
cls = next(iter(f.__annotations__.values()))
return patch_to(cls, as_prop=True)(f)
# %% ../nbs/01_basics.ipynb 406
def compile_re(pat):
"Compile `pat` if it's not None"
return None if pat is None else re.compile(pat)
# %% ../nbs/01_basics.ipynb 408
class ImportEnum(enum.Enum):
"An `Enum` that can have its values imported"
@classmethod
def imports(cls):
g = sys._getframe(1).f_locals
for o in cls: g[o.name]=o
# %% ../nbs/01_basics.ipynb 411
class StrEnum(str,ImportEnum):
"An `ImportEnum` that behaves like a `str`"
def __str__(self): return self.name
# %% ../nbs/01_basics.ipynb 413
def str_enum(name, *vals):
"Simplified creation of `StrEnum` types"
return StrEnum(name, {o:o for o in vals})
# %% ../nbs/01_basics.ipynb 415
class Stateful:
"A base class/mixin for objects that should not serialize all their state"
_stateattrs=()
def __init__(self,*args,**kwargs):
self._init_state()
super().__init__(*args,**kwargs) # required for mixin usage
def __getstate__(self):
return {k:v for k,v in self.__dict__.items()
if k not in self._stateattrs+('_state',)}
def __setstate__(self, state):
self.__dict__.update(state)
self._init_state()
def _init_state(self):
"Override for custom init and deserialization logic"
self._state = {}
# %% ../nbs/01_basics.ipynb 421
class NotStr(GetAttr):
"Behaves like a `str`, but isn't an instance of one"
_default = 's'
def __init__(self, s): self.s = s.s if isinstance(s, NotStr) else s
def __repr__(self): return repr(self.s)
def __str__(self): return self.s
def __add__(self, b): return NotStr(self.s+str(b))
def __mul__(self, b): return NotStr(self.s*b)
def __len__(self): return len(self.s)
def __eq__(self, b): return self.s==b.s if isinstance(b, NotStr) else b
def __lt__(self, b): return self.s<b
def __hash__(self): return hash(self.s)
def __bool__(self): return bool(self.s)
def __contains__(self, b): return b in self.s
def __iter__(self): return iter(self.s)
# %% ../nbs/01_basics.ipynb 423
class PrettyString(str):
"Little hack to get strings to show properly in Jupyter."
def __repr__(self): return self
# %% ../nbs/01_basics.ipynb 429
def even_mults(start, stop, n):
"Build log-stepped array from `start` to `stop` in `n` steps."
if n==1: return stop
mult = stop/start
step = mult**(1/(n-1))
return [start*(step**i) for i in range(n)]
# %% ../nbs/01_basics.ipynb 431
def num_cpus():
"Get number of cpus"
try: return len(os.sched_getaffinity(0))
except AttributeError: return os.cpu_count()
defaults.cpus = num_cpus()
# %% ../nbs/01_basics.ipynb 433
def add_props(f, g=None, n=2):
"Create properties passing each of `range(n)` to f"
if g is None: return (property(partial(f,i)) for i in range(n))
return (property(partial(f,i), partial(g,i)) for i in range(n))
# %% ../nbs/01_basics.ipynb 436
def _typeerr(arg, val, typ): return TypeError(f"{arg}=={val} not {typ}")
# %% ../nbs/01_basics.ipynb 437
def typed(f):
"Decorator to check param and return types at runtime"
names = f.__code__.co_varnames
anno = annotations(f)
ret = anno.pop('return',None)
def _f(*args,**kwargs):
kw = {**kwargs}
if len(anno) > 0:
for i,arg in enumerate(args): kw[names[i]] = arg
for k,v in kw.items():
if k in anno and not isinstance(v,anno[k]): raise _typeerr(k, v, anno[k])
res = f(*args,**kwargs)
if ret is not None and not isinstance(res,ret): raise _typeerr("return", res, ret)
return res
return functools.update_wrapper(_f, f)
# %% ../nbs/01_basics.ipynb 445
def exec_new(code):
"Execute `code` in a new environment and return it"
pkg = None if __name__=='__main__' else Path().cwd().name
g = {'__name__': __name__, '__package__': pkg}
exec(code, g)
return g
# %% ../nbs/01_basics.ipynb 447
def exec_import(mod, sym):
"Import `sym` from `mod` in a new environment"
# pref = '' if __name__=='__main__' or mod[0]=='.' else '.'
return exec_new(f'from {mod} import {sym}')
# %% ../nbs/01_basics.ipynb 448
def str2bool(s):
"Case-insensitive convert string `s` too a bool (`y`,`yes`,`t`,`true`,`on`,`1`->`True`)"
if not isinstance(s,str): return bool(s)
if not s: return False
s = s.lower()
if s in ('y', 'yes', 't', 'true', 'on', '1'): return 1
elif s in ('n', 'no', 'f', 'false', 'off', '0'): return 0
else: raise ValueError()