# -*- coding: utf-8 -*-
"""
A library of useful functions used throughout the *fyrd* package.
These include functions to handle data, format outputs, handle file opening,
run commands, check file extensions, get user input, and search and format
imports.
These functions are not intended to be accessed directly and so documentation
is limited.
"""
from __future__ import with_statement
import os as _os
import re as _re
import sys as _sys
import inspect as _inspect
import argparse as _argparse
from collections import OrderedDict as _OD
import bz2
import gzip
from subprocess import Popen
from subprocess import PIPE
from time import sleep
from glob import glob as _glob
from six import text_type as _txt
from six import string_types as _str
from six import integer_types as _int
from six.moves import input as _get_input
# Progress bar handling
from tqdm import tqdm, tqdm_notebook
# Pick the notebook-flavoured progress bar only when the active shell is the
# Jupyter ZMQ kernel; a NameError (e.g. `get_ipython` not being defined)
# falls back to the plain terminal bar.
try:
    _pb = (
        tqdm_notebook
        if str(type(get_ipython()))
        == "<class 'ipykernel.zmqshell.ZMQInteractiveShell'>"
        else tqdm
    )
except NameError:
    _pb = tqdm
from . import logme as _logme
# Non-greedy match of anything wrapped in one pair of braces, e.g. `{name}`;
# the single capture group is the text between the braces.
STRPRSR = _re.compile(r'{(.*?)}')
def get_pbar(iterable, name=None, unit=None, **kwargs):
    """Return a tqdm progress bar iterable.

    If progressbar is set to False in the config, will not be shown.

    Parameters
    ----------
    iterable : iterable
        The iterable to wrap.
    name : str, optional
        Description shown next to the bar. If empty, a `desc` keyword
        argument is used instead when one is given.
    unit : str, optional
        Passed through as the bar's `unit`.
    **kwargs
        Forwarded to the progress bar. `desc` and `disable` are consumed
        here; an explicit `disable` overrides the config setting.

    Returns
    -------
    The progress bar object wrapping `iterable`.
    """
    from . import conf  # Avoid reciprocal import issues
    show_pb = bool(conf.get_option('queue', 'progressbar', True))
    if 'desc' in kwargs:
        dname = kwargs.pop('desc')
        name = name if name else dname
    if 'disable' in kwargs:
        # Pop (not just read) so `disable` is not passed twice below, which
        # would raise a TypeError for a duplicated keyword argument.
        disable = kwargs.pop('disable')
    else:
        disable = not show_pb
    return _pb(iterable, desc=name, unit=unit, disable=disable, **kwargs)
###############################################################################
# Useful Classes #
###############################################################################
class CommandError(Exception):

    """Custom exception type; adds no behaviour on top of `Exception`."""
###############################################################################
# Misc Functions #
###############################################################################
def string_getter(string):
    """Collect the brace placeholders (`{}`, `{#}`, `{string}`) in a string.

    Parameters
    ----------
    string : str

    Returns
    -------
    ints : set
        A set of ints containing all `{#}` values
    vrs : set
        A set of all non-numeric placeholder names; an empty `{}` contributes
        the empty string.

    Raises
    ------
    ValueError
        If both `{}` and `{#}` are passed
    """
    numbered = set()
    named = set()
    for token in STRPRSR.findall(string):
        if token.isdigit():
            numbered.add(int(token))
        else:
            named.add(token)
    if numbered and '{}' in string:
        raise ValueError('Cannot parse string with both numbered and '
                         'unnumbered braces')
    return numbered, named
def parse_glob(string, get_vars=None):
    """Return a list of files that match a simple regex glob.

    Parameters
    ----------
    string : str
        Search string; may mix glob wildcards with `{variable}` placeholders.
    get_vars : list
        A list of variable names to search for. The string must contain these
        variables in the form `{variable}`. These variables will be temporarily
        replaced with a `*` and then run through `glob.glob` to generate a list
        of files. This list is then parsed to create the output.

    Returns
    -------
    dict
        Keys are all files that match the string, values are None if `get_vars`
        is not passed. If `get_vars` is passed, the values are dictionaries
        of `{'variable': 'result'}`. e.g. for `{name}.txt` and `hi.txt`::
            {hi.txt: {name: 'hi'}}
        A file for which a variable yields no match is logged as an error and
        that variable is left out of the file's dictionary.

    Raises
    ------
    ValueError
        If blank or numeric variable names are used or if get_vars returns
        multiple different names for a file.
    """
    get_vars = listify(get_vars)
    test_string = STRPRSR.sub('*', string)
    files = _glob(test_string)
    if not get_vars:
        return _OD([(f, None) for f in files])
    results = _OD([(f, {}) for f in files])
    int_vars, str_vars = string_getter(string)
    if '{}' in string or int_vars:
        # One concatenated message (a stray comma previously made the
        # exception carry a two-item tuple instead of a readable string).
        raise ValueError('Cannot have numeric placeholders in file strings '
                         "i.e. no '{0}', '{1}', '{}', etc")
    for var in get_vars:
        if var not in str_vars:
            _logme.log('Variable {0} not in search string: {1}'
                       .format(var, string), 'warn')
            continue
        # Turn search string into a regular expression
        test_var = var if var.startswith('{') else '{' + var + '}'
        test_string = STRPRSR.sub('.*?', string.replace(test_var, '(.*?)'))
        test_string = _re.sub(r'([^.])\*', r'\1.*', test_string)
        test_string = _re.sub(r'^\*', r'.*', test_string)
        # Replace terminal non-greedy operators so we parse the whole string
        if test_string.endswith('?'):
            test_string = test_string[:-1]
        if test_string.endswith('?)'):
            test_string = test_string[:-2] + ')'
        test_regex = _re.compile(test_string)
        # Add to file dict
        for fl in files:
            vrs = test_regex.findall(fl)
            ulen = len(set(vrs))
            # The no-match case must be tested first: it used to sit behind
            # the `!= 1` check and so could never be reached.
            if ulen == 0:
                _logme.log('File {0} has no results for {1}'
                           .format(fl, test_var), 'error')
                continue
            if ulen > 1:
                _logme.log('File {0} has multiple values for {1}: {2}'
                           .format(fl, test_var, vrs), 'critical')
                raise ValueError('Invalid file search string')
            results[fl][var] = vrs[0]
    return results
def file_getter(file_strings, variables, extra_vars=None, max_count=None):
    """Get a list of files and variable values using the search string.

    The file strings can contain standard unix glob (like `*`) and variable
    containing strings in the form `{name}`.

    For example, a file_string of `{dir}/*.txt` will match every file that
    ends in `.txt` in every directory relative to the current path.
    The result for a directory name test with two files named 1.txt and 2.txt
    is a list of::
        [(('dir/1.txt'), {'dir': 'test'}),
         (('dir/2.txt'), {'dir': 'test'})]

    This is repeated for every file_string in file_strings, and the results
    are combined positionally: the n-th file of every file_string forms one
    file set, and all variables must agree within a file set.

    Parameters
    ----------
    file_strings : list of str
        List of search strings, e.g. `*/*`, `*/*.txt`, `{dir}/*.txt` or
        `{dir}/{file}.txt`
    variables : list of str
        List of variables to look for
    extra_vars : list of str, optional
        A list of additional variables specified in a very precise format::
            new_var:orig_var:regex:sub_str
        or
            new_var:value
        The orig_var must correspond to a variable in variables. new_var will
        be generated by running re.sub(regex, sub_str, string) where string is
        the result of orig_var for the given file set
    max_count : int, optional
        Max number of file_strings to parse, default is all.

    Returns
    -------
    list
        A list of files. Each list item will be a two-item tuple of
        `(files, variables)`. Files will be a tuple of absolute paths, one
        per parsed file_string. Variables will be a dictionary of all
        variables and extra_vars for this file set. e.g.::
            [((file1, dir1, file2), {var1: val, var2: val})]

    Raises
    ------
    ValueError
        If a variable has an empty value for a file, if variables disagree
        within a file set, or if an extra_var is malformed or refers to an
        unknown original variable.
    """
    # Make extra_var an empty list if None
    extra_vars = listify(extra_vars) if extra_vars else []

    # Get all file information
    files = []
    for count, file_string in enumerate(file_strings, 1):
        files.append(parse_glob(file_string, variables))
        if max_count and count == max_count:
            break

    # Make sure all files have values for variables
    empty = {}
    for parsed in files:
        for fname, pvars in parsed.items():
            # It's fine if the file string has no variables
            if not pvars:
                continue
            for var, val in pvars.items():
                if not val:
                    empty.setdefault(var, []).append(fname)

    fail = False
    if empty:
        # Explicit `+`: with implicit literal concatenation the header was
        # glued onto the '\n' separator and used as the join string.
        _logme.log('The following variables had no '
                   'result in some files, cannot continue:\n' +
                   '\n'.join(
                       ['{0} files: {1}'.format(i, j)
                        for i, j in empty.items()]
                   ), 'critical')
        fail = True

    # Join files
    # NOTE: zip() stops at the shortest file list, so file_strings that match
    # different numbers of files are silently truncated here.
    bad = []
    results = []
    for file_info in zip(*[fl.items() for fl in files]):
        good = True
        # Make sure dicts are compatible and make combined dict
        final_dict = {}
        final_files = tuple([_os.path.abspath(fl[0]) for fl in file_info])
        bad_dcts = []
        for dct in [f[1] for f in file_info]:
            if not dct:
                continue
            final_dict.update(dct)
            bad_dcts.append(dct)
            for var, val in dct.items():
                for fl in file_info:
                    if var in fl[1] and val != fl[1][var]:
                        good = False
        if not good:
            bad.append((final_files, bad_dcts))
            break
        for extra_var in extra_vars:
            try:
                evari = extra_var.split(':')
                if len(evari) == 2:
                    # new_var:value (indices 1 and 2 were used before, which
                    # raised IndexError on every two-part specification)
                    final_dict[evari[0]] = evari[1]
                    continue
                var, orig_var, parse_str, sub_str = evari
            except ValueError:
                _logme.log(
                    '{} is malformatted should be: '.format(extra_var) +
                    'either new_var:orig_var:regex:sub '
                    'or variable:value', 'critical'
                )
                raise
            if orig_var not in final_dict:
                raise ValueError(
                    'Extra variable {0} sets {1} as '.format(var, orig_var) +
                    'the original variable, but it is not in the dict for '
                    '{0}'.format(final_files)
                )
            final_dict[var] = _re.sub(parse_str, sub_str, final_dict[orig_var])
        results.append((final_files, final_dict))

    # Take care of bad items (this was previously gated on `empty`, so
    # mismatched combinations were never reported)
    if bad:
        _logme.log('The following file combinations had mismatched variables, '
                   'cannot continue:\n' +
                   '\n'.join(
                       ['{0} dicts: {1}'.format(i, j) for i, j in bad]
                   ), 'critical')
        fail = True
    if fail:
        raise ValueError('File parsing failure')
    return results
def listify(iterable):
    """Coerce the argument into a list.

    Lists are returned unchanged; strings, integers and floats are wrapped in
    a one-item list; any other falsy value gives an empty list; everything
    else is passed to `list()`, or wrapped if that raises TypeError.
    """
    if isinstance(iterable, list):
        return iterable
    if isinstance(iterable, (_str, _txt, _int, float)):
        return [iterable]
    if not iterable:
        return []
    try:
        return list(iterable)
    except TypeError:
        return [iterable]
def merge_lists(lists):
    """Flatten one level: concatenate every item of `lists` into a new list."""
    merged = []
    for chunk in listify(lists):
        merged.extend(chunk)
    return merged
def write_iterable(iterable, outfile):
    """Write the items of `iterable` to `outfile`, separated by newlines.

    `outfile` is opened through `open_zipped` in write mode.
    """
    with open_zipped(outfile, 'w') as handle:
        joined = '\n'.join(iterable)
        handle.write(joined)
def indent(string, prefix=' '):
    """Prefix every line of a string (python2 stand-in for textwrap.indent).

    Parameters
    ----------
    string : str
        Any string.
    prefix : str
        What to indent with.

    Returns
    -------
    str
        Every newline-separated piece of `string`, prefixed and terminated
        by a newline.
    """
    return ''.join(
        '{}{}\n'.format(prefix, line) for line in string.split('\n')
    )
def is_exc(x):
    """Check if x is the output of sys.exc_info().

    Parameters
    ----------
    x : object
        Any object.

    Returns
    -------
    bool
        True if `x` is a 3-tuple whose first item is an exception class.
    """
    return bool(
        isinstance(x, tuple)
        and len(x) == 3
        # Guard first: issubclass raises TypeError for non-class arguments
        and isinstance(x[0], type)
        # The exception type is the first argument (these were reversed, so
        # no real exception class ever matched)
        and issubclass(x[0], BaseException)
    )
###############################################################################
# File Management #
###############################################################################
def open_zipped(infile, mode='r'):
    """Open a regular, gzipped, or bz2 file.

    If infile is a file handle or text device (anything with a `write`
    attribute), it is returned without changes.

    Parameters
    ----------
    infile : str or file-like
        Path to open; a `.gz` or `.bz2` suffix selects the decompressor.
    mode : str, optional
        Only the first character is used; text mode is always requested.

    Returns
    -------
    text mode file handle.
    """
    mode = mode[0] + 't'
    if hasattr(infile, 'write'):
        return infile
    if isinstance(infile, _str):
        if infile.endswith('.gz'):
            return gzip.open(infile, mode)
        if infile.endswith('.bz2'):
            if hasattr(bz2, 'open'):
                return bz2.open(infile, mode)
            # Fallback for interpreters without bz2.open: BZ2File has no
            # text mode and rejects the 't' flag, so pass the bare mode.
            return bz2.BZ2File(infile, mode[0])
    return open(infile, mode)
def exp_file(infile):
    """Return an expanded path to a file.

    Environment variables are expanded, and a leading `~` / `~user` is
    replaced by the matching home directory. A `~` elsewhere in the path is
    left untouched (previously every `~` was rewritten to `$HOME`, which
    corrupted names such as `back~up.txt`).

    Parameters
    ----------
    infile : str
        Path, possibly containing `~` or `$VARIABLE` references.

    Returns
    -------
    str
        The expanded path.
    """
    return _os.path.expanduser(_os.path.expandvars(infile))
def cmd_or_file(string):
    """Return the contents of `string` if it names a file, else `string`.

    Parameters
    ----------
    string : str
        Path to a file or any other string

    Returns
    -------
    script : str
        The stripped file contents when `string` is an existing file,
        otherwise the stripped string itself.
    """
    if not _os.path.isfile(string):
        return string.strip()
    with open_zipped(string) as handle:
        return handle.read().strip()
def block_read(files, size=65536):
    """Yield successive `size`-sized reads from an open file handle.

    Stops at the first empty read.
    """
    chunk = files.read(size)
    while chunk:
        yield chunk
        chunk = files.read(size)
def count_lines(infile, force_blocks=False):
    """Return the line count of a file as quickly as possible.

    Uses `wc` if available, otherwise does a rapid block read.

    Parameters
    ----------
    infile : str
        Path to the file; `.gz` and `.bz2` suffixes are decompressed.
    force_blocks : bool, optional
        Skip `wc` and always count by reading blocks.

    Returns
    -------
    int
        Number of newline characters in the file.
    """
    if not force_blocks and which('wc'):
        try:
            from shlex import quote as _quote
        except ImportError:  # Python 2
            from pipes import quote as _quote
        _logme.log('Using wc', 'debug')
        if infile.endswith('.gz'):
            cat = 'zcat'
        elif infile.endswith('.bz2'):
            cat = 'bzcat'
        else:
            cat = 'cat'
        # The file name is quoted because the pipeline runs through a shell:
        # unquoted, spaces or metacharacters in the path break the command.
        command = "{cat} {infile} | wc -l | awk '{{print $1}}'".format(
            cat=cat, infile=_quote(infile)
        )
        return int(cmd(command)[1])
    _logme.log('Using block read', 'debug')
    with open_zipped(infile) as fin:
        return sum(bl.count("\n") for bl in block_read(fin))
def split_file(infile, parts, outpath='', keep_header=False):
    """Split a file in parts and return the paths of the pieces.

    .. note:: Uses `count_lines` to size the pieces.

    If keep_header is True, the top line is stripped off the infile prior to
    splitting and assumed to be the header.

    Parameters
    ----------
    infile : str
        Path of the file to split.
    parts : int
        Number of pieces to aim for.
    outpath : str, optional
        The directory to save the split files.
    keep_header : bool, optional
        Add the header line to the top of every file.

    Returns
    -------
    tuple
        Paths to split files, named
        `<basename>.split_<NNNN>.<last extension of infile>`.
    """
    # Determine how many lines will be in each split file.
    _logme.log('Getting line count', 'debug')
    num_lines = count_lines(infile) // int(parts) + 1

    extension = infile.split('.')[-1]
    file_name = _os.path.basename(infile)

    def _part_path(index):
        """Return the output path of piece number `index`."""
        suffix = '.split_' + str(index).zfill(4) + '.' + extension
        return _os.path.join(outpath, file_name + suffix)

    # Subset the file into X number of jobs, maintain extension
    cnt = 0
    currjob = 1
    run_file = _part_path(currjob)
    outfiles = [run_file]

    # Actually split the file
    _logme.log('Splitting file', 'debug')
    with open_zipped(infile) as fin:
        header = fin.readline() if keep_header else ''
        sfile = open_zipped(run_file, 'w')
        # try/finally so the current output handle is closed even when a
        # read or write fails part-way through
        try:
            sfile.write(header)
            for line in fin:
                cnt += 1
                sfile.write(line)
                if cnt == num_lines:
                    # Piece is full: rotate to the next output file
                    sfile.close()
                    currjob += 1
                    run_file = _part_path(currjob)
                    sfile = open_zipped(run_file, 'w')
                    outfiles.append(run_file)
                    sfile.write(header)
                    cnt = 0
        finally:
            sfile.close()
    _logme.log('Split files: {}'.format(outfiles), 'debug')
    return tuple(outfiles)
def is_exe(fpath):
    """Return True if fpath is an existing file with execute permission."""
    if not _os.path.isfile(fpath):
        return False
    return _os.access(fpath, _os.X_OK)
def file_type(infile):
    """Return the last dot-separated part of a name, ignoring gz/bz2."""
    pieces = infile.split('.')
    if pieces[-1] in ('gz', 'bz2'):
        del pieces[-1]
    return pieces[-1]
def is_file_type(infile, types):
    """Return True if infile is one of types.

    Parameters
    ----------
    infile : str
        Any file name; an object with a `write` attribute is replaced by its
        `name` attribute.
    types : list
        String or list/tuple of strings (e.g `['bed', 'gtf']`)

    Returns
    -------
    is_file_type : bool
    """
    if hasattr(infile, 'write'):
        infile = infile.name
    return any(
        file_type(infile) == candidate for candidate in listify(types)
    )
###############################################################################
# Running Commands #
###############################################################################
def cmd(command, args=None, stdout=None, stderr=None, tries=1):
    """Run command and return status, output, stderr.

    The command line is executed through the shell, so it must come from a
    trusted source.

    Parameters
    ----------
    command : str
        Path to executable; a list/tuple is joined with spaces.
    args : tuple, optional
        Tuple of arguments (or a single string), appended to `command`
        separated by a space.
    stdout : str, optional
        File or open file like object to write STDOUT to.
    stderr : str, optional
        File or open file like object to write STDERR to.
    tries : int, optional
        Number of times to try to execute. 1+

    Returns
    -------
    exit_code : int
    STDOUT : str
    STDERR : str

    Raises
    ------
    ValueError
        If `command` is a list/tuple and `args` is also given.
    OSError
        If the process cannot be started.
    """
    tries = int(tries)
    assert tries > 0
    if isinstance(command, (list, tuple)):
        if args:
            raise ValueError('Cannot submit list/tuple command as ' +
                             'well as args argument')
        command = ' '.join(command)
    assert isinstance(command, _str)
    if args:
        if isinstance(args, (list, tuple)):
            args = ' '.join(args)
        # The separating space was missing, fusing the command name with
        # its first argument.
        args = command + ' ' + args
    else:
        args = command
    _logme.log('Running {} as {}'.format(command, args), 'verbose')
    count = 1
    while True:
        try:
            pp = Popen(args, shell=True, universal_newlines=True,
                       stdout=PIPE, stderr=PIPE)
        except OSError:
            # OSError rather than FileNotFoundError: the latter is a
            # subclass on Python 3 and does not exist at all on Python 2.
            _logme.log('{} does not exist'.format(command), 'critical')
            raise
        out, err = pp.communicate()
        code = pp.returncode
        if code == 0 or count == tries:
            break
        _logme.log('Command {} failed with code {}, retrying.'
                   .format(command, code), 'warn')
        sleep(1)
        count += 1
    _logme.log('{} completed with code {}'.format(command, code), 'debug')
    if stdout:
        with open_zipped(stdout, 'w') as fout:
            fout.write(out)
    if stderr:
        with open_zipped(stderr, 'w') as fout:
            fout.write(err)
    return code, out.rstrip(), err.rstrip()
def export_run(function, args, kwargs):
    """Call `function` with an `imports` keyword built by `export_imports`.

    The `imports` entry is stored in the passed `kwargs` dictionary before
    the call.
    """
    kwargs.update(imports=export_imports(function, kwargs))
    return function(*args, **kwargs)
def which(program):
    """Replicate the UNIX which command.

    Adapted from:
    stackoverflow.com/questions/377017/test-if-executable-exists-in-python

    Parameters
    ----------
    program : str
        Name of executable to test, or a path to it.

    Returns
    -------
    str or None
        Absolute path to the program or None on failure.
    """
    # Keep `program` intact: overwriting it with the basename made explicit
    # paths get checked relative to the current directory.
    fpath, fname = _os.path.split(program)
    if fpath:
        if is_exe(program):
            return _os.path.abspath(program)
    else:
        # Fall back to the platform default when PATH is not set at all
        search_path = _os.environ.get("PATH", _os.defpath)
        for path in search_path.split(_os.pathsep):
            path = path.strip('"')
            exe_file = _os.path.join(path, fname)
            if is_exe(exe_file):
                return _os.path.abspath(exe_file)
    return None
def check_pid(pid):
    """Check for the existence of a unix pid.

    Parameters
    ----------
    pid : int
        Process id to probe.

    Returns
    -------
    bool
        True if a process with this pid exists, including one we are not
        permitted to signal.
    """
    import errno
    try:
        # Signal 0 performs the existence/permission check without
        # delivering anything to the process.
        _os.kill(pid, 0)
    except OSError as err:
        # EPERM means the process exists but belongs to someone else
        return err.errno == errno.EPERM
    return True
###############################################################################
# Option and Argument Management #
###############################################################################
def replace_argument(args, find_string, replace_string, error=True):
    """Substitute `replace_string` for `find_string` in a tuple or dict.

    Positional string items containing `find_string` are rewritten with
    `str.format`, using `find_string` stripped of braces as the field name.
    Dictionary values that are strings containing `find_string` are replaced
    wholesale by `replace_string`; keys are never touched.

    Note: args can also be a list, in which case the first item is assumed
    to be a tuple, and the second a dictionary

    Parameters
    ----------
    args : list/tuple/dict
        Tuple or dict of args
    find_string : str
        A string to search for
    replace_string : str
        A string to replace with
    error : bool
        Raise ValueError if replacement fails

    Returns
    -------
    The same kind of object as was passed, with alterations made. When
    nothing could be replaced and `error` is False, None (or `(None, None)`
    for list input).
    """
    paired = False
    if isinstance(args, list):
        positional, keywords = args
        paired = True
    elif isinstance(args, tuple):
        positional, keywords = args, None
    elif isinstance(args, dict):
        positional, keywords = None, args.copy()
    else:
        raise ValueError('args must be list/tuple/dict, is {}\nval: {}'
                         .format(type(args), args))

    def _give_up(msg):
        """Raise or warn, depending on `error`; return the empty result."""
        if error:
            raise ValueError(msg)
        _logme.log(msg, 'warn')
        return (None, None) if paired else None

    if not positional and not keywords:
        return _give_up('No arguments or keyword arguments found')

    replaced = False

    new_positional = tuple()
    if positional:
        for item in listify(positional):
            if isinstance(item, _str) and find_string in item:
                item = item.format(**{find_string.strip('{}'): replace_string})
                replaced = True
            new_positional += (item,)

    new_keywords = {}
    if keywords:
        for key, val in keywords.items():
            if isinstance(val, _str) and find_string in val:
                val = replace_string
                replaced = True
            new_keywords[key] = val

    if replaced is not True:
        return _give_up('Could not find {}'.format(find_string))

    if paired:
        return [new_positional, new_keywords]
    return new_positional if new_positional else new_keywords
def opt_split(opt, split_on):
    """Split options by chars in split_on, merge all into single list.

    Parameters
    ----------
    opt : list
        A list of strings, can be a single string.
    split_on : list
        A list of characters to use to split the options. The characters are
        taken literally, so `]`, `^`, `-` and `\\` are safe to use.

    Returns
    -------
    list
        A single merged list of split options, uniqueness guaranteed, order
        not.
    """
    opt = listify(opt)
    split_on = listify(split_on)
    if not split_on:
        # Nothing to split on; an empty character class is not a valid regex
        return list(set(opt))
    # Escape so regex metacharacters cannot alter or break the class
    splitter = _re.compile('[{}]'.format(_re.escape(''.join(split_on))))
    final_list = []
    for o in opt:
        final_list += splitter.split(o)
    return list(set(final_list))  # Return unique options only, order lost.
###############################################################################
# User Input #
###############################################################################
def get_yesno(message, default=None):
    """Get yes/no answer from user.

    Parameters
    ----------
    message : str
        A message to print, an additional space will be added.
    default : {'y', 'n'}, optional
        One of `{'y', 'n'}`, the default if the user gives no answer. If None,
        answer forced.

    Returns
    -------
    bool
        True on yes, False on no

    Raises
    ------
    ValueError
        If `default` is not a yes/no value, or the answer starts with
        neither 'y' nor 'n'.
    """
    if default:
        if default.lower().startswith('y'):
            tailstr = '[Y/n] '
        elif default.lower().startswith('n'):
            tailstr = '[y/N] '
        else:
            raise ValueError('Invalid default')
    else:
        tailstr = '[y/n] '
    message = message + tailstr if message.endswith(' ') \
        else message + ' ' + tailstr
    # Prompt through the `input` wrapper imported at the top of this module;
    # the `get_input` helper called here before is not defined in it.
    while True:
        ans = _get_input(message)
        if ans:
            break
        if default:
            ans = default
            break
        # No answer and no default: ask again (answer forced)
    if ans.lower().startswith('y'):
        return True
    elif ans.lower().startswith('n'):
        return False
    else:
        raise ValueError('Invalid response: {}'.format(ans))
###############################################################################
# Imports #
###############################################################################
def syspath_fmt(syspaths):
    """Turn a list of paths into newline-joined sys.path.append statements.

    Items that already mention `sys.path` are kept verbatim; other items
    must be existing paths and are made absolute.

    Raises
    ------
    OSError
        If a path does not exist.
    """
    lines = []
    for entry in listify(syspaths):
        if 'sys.path' in entry:
            lines.append(entry)
        elif _os.path.exists(entry):
            lines.append(
                "sys.path.append('{}')".format(_os.path.abspath(entry))
            )
        else:
            raise OSError('Paths must exist, {} does not.'
                          .format(entry))
    return '\n'.join(lines)
# Template wrapping one import statement in try/except ImportError. The body
# lines must be indented, otherwise the formatted snippet is not valid Python.
PROT_IMPT = """\
try:
    {}
except ImportError:
    pass
"""
def normalize_imports(imports, prot=True):
    """Take a heterogenous list of imports and normalize it.

    Each item may be a full `import`/`from` statement, an already protected
    `try:` block, a `sys.path.append`/`sys.path.insert` call, or a bare
    module name. Items starting with `@` are dropped.

    Parameters
    ----------
    imports : list
        A list of strings, formatted differently.
    prot : bool
        Protect imports with try..except blocks

    Returns
    -------
    list
        A list of strings that can be used for imports: sys.path statements
        first, then the imports, duplicates removed and first-seen order
        kept.

    Raises
    ------
    ValueError
        If any item is not a string.
    """
    out_impts = []
    prot_impts = []
    path_impts = []
    imports = listify(imports)
    if not imports:
        return []
    for imp in imports:
        if not isinstance(imp, _str):
            raise ValueError('All imports must be strings')
        if imp.startswith('try:'):
            prot_impts.append(imp.rstrip())
        # Require whitespace after the keyword so that module names such as
        # `importlib` are not mistaken for complete import statements.
        elif _re.match(r'(import|from)\s', imp):
            out_impts.append(imp.rstrip())
        elif imp.startswith(('sys.path.append', 'sys.path.insert')):
            path_impts.append(imp.rstrip())
        elif imp.startswith('@'):
            continue
        else:
            out_impts.append('import {}'.format(imp))

    if prot:
        out = prot_impts + [PROT_IMPT.format(imp) for imp in out_impts]
    else:
        out = out_impts + prot_impts

    # Remove duplicates, keeping order (a set made the order of the imports
    # and of the sys.path statements vary from run to run)
    out = list(_OD.fromkeys(out))

    # Add PATHS
    if path_impts:
        out = list(_OD.fromkeys(path_impts)) + out

    return out
def get_function_path(function):
    """Return path to module defining a function if it exists.

    Parameters
    ----------
    function : callable
        Object to look up with `inspect.getmodule`.

    Returns
    -------
    str or None
        Directory holding the file that defines `function`, or None when no
        module is found or the module is `__main__`.
    """
    mod = _inspect.getmodule(function)
    # Compare the module *name*: a module object never equals a string, so
    # the `__main__` case was not being excluded.
    if mod and mod.__name__ != '__main__':
        return _os.path.dirname(_inspect.getabsfile(function))
    return None
def update_syspaths(function, kwds=None):
    """Add function path to 'syspaths' in kwds.

    Parameters
    ----------
    function : callable
        Function whose defining directory should be importable.
    kwds : dict, optional
        Keyword arguments; an existing `syspaths` entry is kept after the
        function's own path.

    Returns
    -------
    list
        The function's path (when one can be determined) followed by any
        paths from `kwds['syspaths']`.
    """
    if kwds and 'syspaths' in kwds:
        syspaths = listify(kwds['syspaths'])
    else:
        syspaths = []
    func_path = get_function_path(function)
    # get_function_path returns None when there is no usable module path;
    # leave it out instead of inserting None into the path list.
    return ([func_path] if func_path else []) + syspaths
def import_function(function, mode='string'):
    """Return an import string for the function.

    Attempts to resolve the parent module also, if the parent module is a file,
    ie it isn't __main__, the import string will include a call to
    sys.path.append to ensure the module is importable.

    If this function isn't defined by a module, returns an empty string.

    Parameters
    ----------
    function : callable
        The function (or bound method) to build imports for. For a bound
        method the class name of the bound object is imported.
    mode : {'string', 'list'}, optional
        string/list, return as a unified string or a list.

    Returns
    -------
    str or list
        The import statements, newline-joined in 'string' mode.

    Raises
    ------
    ValueError
        If `function` is not callable or `mode` is invalid.
    """
    if not callable(function):
        raise ValueError('Function must be callable, {} is not'
                         .format(function))
    if mode not in ['string', 'list']:
        raise ValueError("Invalid mode {}, must be 'list' or 'string'"
                         .format(mode))

    if _inspect.ismethod(function):
        name = function.__self__.__class__.__name__
    else:
        name = function.__name__

    # Attempt to resolve defining file
    parent = _inspect.getmodule(function)

    imports = []
    if parent and parent.__name__ != '__main__':
        module = parent.__name__
        # Built-in and some extension modules have no __file__; they need no
        # sys.path entry, so only the import lines are emitted for them.
        mod_file = getattr(parent, '__file__', None)
        if mod_file:
            path = _os.path.dirname(mod_file)
            # Climb from the module's directory to the directory holding the
            # top-level package: one level per dot in the module name, plus
            # one more when the module is itself a package (its file is the
            # __init__ inside the package directory).
            levels = module.count('.')
            if hasattr(parent, '__path__'):
                levels += 1
            if levels:
                path = _os.path.abspath(
                    _os.path.join(path, *['..'] * levels)
                )
            imports.append("sys.path.append('{}')".format(path))
        imports.append('import {}'.format(module))
        imports.append('from {} import *'.format(module))
        imports.append('from {} import {}'.format(module, name))

    return imports if mode == 'list' else '\n'.join(imports)
def get_imports(function, mode='string'):
    """Build a list of potentially useful imports from a function handle.

    Gets:
        - All modules among the function's members and its `__globals__`
        - All callables among the function's members and its `__globals__`
        - All modules that are members of the function's defining module

    Modes:
        string:
            Return a list of strings formatted as unprotected import calls
        prot:
            Similar to string, but with try..except blocks
        list:
            Return two lists: (import name, module name) for modules and (import
            name, function name, module name) for functions

    Parameters
    ----------
    function : callable
        A function handle
    mode : str
        A string corresponding to one of the above modes

    Returns
    -------
    str or list
    """
    if mode not in ['string', 'prot', 'list']:
        raise ValueError('mode must be one of string/prot/list')

    rootmod = _inspect.getmodule(function)

    mod_pairs = []
    func_triples = []

    # For interactive sessions
    members = dict(_inspect.getmembers(function))
    namespaces = [members]
    if '__globals__' in members:
        namespaces.append(members['__globals__'])

    for namespace in namespaces:
        for label, obj in namespace.items():
            if label.startswith('__'):
                continue
            if _inspect.ismodule(obj):
                # Modules
                mod_pairs.append((label, obj.__name__))
            elif callable(obj):
                # Functions
                try:
                    func_triples.append(
                        (label, obj.__name__, obj.__module__)
                    )
                except AttributeError:
                    pass

    # Import all modules in the root module
    for label, module in _inspect.getmembers(rootmod, _inspect.ismodule):
        if not label.startswith('__'):
            mod_pairs.append((label, module.__name__))

    # Make unique
    imports = sorted(set(mod_pairs), key=_sort_imports)
    func_imports = sorted(set(func_triples), key=_sort_imports)

    _logme.log('Imports: {}'.format(imports), 'debug')
    _logme.log('Func imports: {}'.format(func_imports), 'debug')

    # Create a sane set of imports
    ignore_list = ['os', 'sys', 'dill', 'pickle', '__main__']

    def _wanted(iname, name):
        """Return True unless the import is ignored or starts with '@'."""
        if iname in ignore_list:
            return False
        return not (name.startswith('@') or iname.startswith('@'))

    filtered_imports = [
        (iname, name) for iname, name in imports if _wanted(iname, name)
    ]
    filtered_func_imports = [
        (iname, name, mod) for iname, name, mod in func_imports
        if _wanted(iname, name)
    ]

    if mode == 'list':
        return filtered_imports, filtered_func_imports

    import_strings = []
    for iname, name in filtered_imports:
        names = name.split('.')
        if names[0] == '__main__':
            continue
        if iname == name:
            statement = 'import {}'.format(name)
        elif len(names) <= 1:
            statement = ('import {} as {}').format(name, iname)
        elif '.'.join(names[1:]) != iname:
            statement = ('from {} import {} as {}'
                         .format('.'.join(names[:-1]), names[-1], iname))
        else:
            statement = ('from {} import {}'
                         .format(names[0], '.'.join(names[1:])))
        import_strings.append(statement)

    # Function imports
    for iname, name, mod in filtered_func_imports:
        if mod == '__main__':
            continue
        if iname == name:
            import_strings.append('from {} import {}'.format(mod, name))
        else:
            import_strings.append('from {} import {} as {}'
                                  .format(mod, name, iname))

    if mode == 'string':
        return import_strings
    elif mode == 'prot':
        return normalize_imports(import_strings, prot=True)
    else:
        raise ValueError('Mode changed unexpectedly')
def export_globals(function):
    """Copy a function's module, and the modules it holds, into our globals.

    The defining module is stored under its own name; every module that is a
    member of it is stored under its member name unless that name starts
    with a double underscore.
    """
    module = _inspect.getmodule(function)
    namespace = globals()
    namespace[module.__name__] = module
    namespace.update(
        (label, member)
        for label, member in _inspect.getmembers(module, _inspect.ismodule)
        if not label.startswith('__')
    )
def get_all_imports(function, kwds, prot=False):
    """Combine the imports listed in kwds with those found for a function.

    Parameters
    ----------
    function : callable
        A function handle
    kwds : dict
        A dictionary of keyword arguments; its `imports` entry, if present,
        is included.
    prot : bool
        Wrap all import in try statement

    Returns
    -------
    list
        Imports
    """
    requested = kwds['imports'] if 'imports' in kwds else None
    collected = normalize_imports(listify(requested), prot=False)
    collected += get_imports(function, mode='string')
    return normalize_imports(collected, prot=prot)
def export_imports(function, kwds):
    """Get imports from a function and from kwds.

    Also calls `export_globals` for the function before collecting.

    Parameters
    ----------
    function : callable
        A function handle
    kwds : dict
        A dictionary of keyword arguments

    Returns
    -------
    list
        The `import_function` statements for the function followed by the
        result of `get_all_imports`.
    """
    export_globals(function)
    own_imports = import_function(function, 'list')
    return own_imports + get_all_imports(function, kwds)
def _sort_imports(x):
    """Sort key for a list of tuples and strings, for use with sorted.

    Tuples are keyed on their second item, with `__main__` sorting first;
    anything else is its own key.
    """
    if isinstance(x, tuple):
        if x[1] == '__main__':
            # Empty string rather than 0: it still sorts ahead of every
            # name, but stays comparable with the other (string) keys on
            # Python 3, where int and str cannot be ordered.
            return ''
        return x[1]
    return x