"""
Get API information encoded in C files.
See ``find_function`` for how functions should be formatted, and
``read_order`` for how the order of the functions should be
specified.
"""
from __future__ import division, absolute_import, print_function
import sys, os, re
import hashlib
import textwrap
from os.path import join
__docformat__ = 'restructuredtext'
# The files under src/ that are scanned for API functions
API_FILES = [join('multiarray', 'alloc.c'),
join('multiarray', 'array_assign_array.c'),
join('multiarray', 'array_assign_scalar.c'),
join('multiarray', 'arrayobject.c'),
join('multiarray', 'arraytypes.c.src'),
join('multiarray', 'buffer.c'),
join('multiarray', 'calculation.c'),
join('multiarray', 'conversion_utils.c'),
join('multiarray', 'convert.c'),
join('multiarray', 'convert_datatype.c'),
join('multiarray', 'ctors.c'),
join('multiarray', 'datetime.c'),
join('multiarray', 'datetime_busday.c'),
join('multiarray', 'datetime_busdaycal.c'),
join('multiarray', 'datetime_strings.c'),
join('multiarray', 'descriptor.c'),
join('multiarray', 'einsum.c.src'),
join('multiarray', 'flagsobject.c'),
join('multiarray', 'getset.c'),
join('multiarray', 'item_selection.c'),
join('multiarray', 'iterators.c'),
join('multiarray', 'mapping.c'),
join('multiarray', 'methods.c'),
join('multiarray', 'multiarraymodule.c'),
join('multiarray', 'nditer_api.c'),
join('multiarray', 'nditer_constr.c'),
join('multiarray', 'nditer_pywrap.c'),
join('multiarray', 'nditer_templ.c.src'),
join('multiarray', 'number.c'),
join('multiarray', 'refcount.c'),
join('multiarray', 'scalartypes.c.src'),
join('multiarray', 'scalarapi.c'),
join('multiarray', 'sequence.c'),
join('multiarray', 'shape.c'),
join('multiarray', 'strfuncs.c'),
join('multiarray', 'usertypes.c'),
join('umath', 'loops.c.src'),
join('umath', 'ufunc_object.c'),
join('umath', 'ufunc_type_resolution.c'),
join('umath', 'reduction.c'),
]
THIS_DIR = os.path.dirname(__file__)
API_FILES = [os.path.join(THIS_DIR, '..', 'src', a) for a in API_FILES]
def file_in_this_dir(filename):
return os.path.join(THIS_DIR, filename)
def remove_whitespace(s):
return ''.join(s.split())
def _repl(str):
return str.replace('Bool', 'npy_bool')
class StealRef(object):
def __init__(self, arg):
self.arg = arg # counting from 1
def __str__(self):
try:
return ' '.join('NPY_STEALS_REF_TO_ARG(%d)' % x for x in self.arg)
except TypeError:
return 'NPY_STEALS_REF_TO_ARG(%d)' % self.arg
class NonNull(object):
def __init__(self, arg):
self.arg = arg # counting from 1
def __str__(self):
try:
return ' '.join('NPY_GCC_NONNULL(%d)' % x for x in self.arg)
except TypeError:
return 'NPY_GCC_NONNULL(%d)' % self.arg
class Function(object):
def __init__(self, name, return_type, args, doc=''):
self.name = name
self.return_type = _repl(return_type)
self.args = args
self.doc = doc
def _format_arg(self, typename, name):
if typename.endswith('*'):
return typename + name
else:
return typename + ' ' + name
def __str__(self):
argstr = ', '.join([self._format_arg(*a) for a in self.args])
if self.doc:
doccomment = '/* %s */\n' % self.doc
else:
doccomment = ''
return '%s%s %s(%s)' % (doccomment, self.return_type, self.name, argstr)
def to_ReST(self):
lines = ['::', '', ' ' + self.return_type]
argstr = ',\000'.join([self._format_arg(*a) for a in self.args])
name = ' %s' % (self.name,)
s = textwrap.wrap('(%s)' % (argstr,), width=72,
initial_indent=name,
subsequent_indent=' ' * (len(name)+1),
break_long_words=False)
for l in s:
lines.append(l.replace('\000', ' ').rstrip())
lines.append('')
if self.doc:
lines.append(textwrap.dedent(self.doc))
return '\n'.join(lines)
def api_hash(self):
m = hashlib.md5()
m.update(remove_whitespace(self.return_type))
m.update('\000')
m.update(self.name)
m.update('\000')
for typename, name in self.args:
m.update(remove_whitespace(typename))
m.update('\000')
return m.hexdigest()[:8]
class ParseError(Exception):
def __init__(self, filename, lineno, msg):
self.filename = filename
self.lineno = lineno
self.msg = msg
def __str__(self):
return '%s:%s:%s' % (self.filename, self.lineno, self.msg)
def skip_brackets(s, lbrac, rbrac):
count = 0
for i, c in enumerate(s):
if c == lbrac:
count += 1
elif c == rbrac:
count -= 1
if count == 0:
return i
raise ValueError("no match '%s' for '%s' (%r)" % (lbrac, rbrac, s))
def split_arguments(argstr):
arguments = []
bracket_counts = {'(': 0, '[': 0}
current_argument = []
state = 0
i = 0
def finish_arg():
if current_argument:
argstr = ''.join(current_argument).strip()
m = re.match(r'(.*(\s+|[*]))(\w+)$', argstr)
if m:
typename = m.group(1).strip()
name = m.group(3)
else:
typename = argstr
name = ''
arguments.append((typename, name))
del current_argument[:]
while i < len(argstr):
c = argstr[i]
if c == ',':
finish_arg()
elif c == '(':
p = skip_brackets(argstr[i:], '(', ')')
current_argument += argstr[i:i+p]
i += p-1
else:
current_argument += c
i += 1
finish_arg()
return arguments
def find_functions(filename, tag='API'):
"""
Scan the file, looking for tagged functions.
Assuming ``tag=='API'``, a tagged function looks like::
/*API*/
static returntype*
function_name(argtype1 arg1, argtype2 arg2)
{
}
where the return type must be on a separate line, the function
name must start the line, and the opening ``{`` must start the line.
An optional documentation comment in ReST format may follow the tag,
as in::
/*API
This function does foo...
*/
"""
fo = open(filename, 'r')
functions = []
return_type = None
function_name = None
function_args = []
doclist = []
SCANNING, STATE_DOC, STATE_RETTYPE, STATE_NAME, STATE_ARGS = list(range(5))
state = SCANNING
tagcomment = '/*' + tag
for lineno, line in enumerate(fo):
try:
line = line.strip()
if state == SCANNING:
if line.startswith(tagcomment):
if line.endswith('*/'):
state = STATE_RETTYPE
else:
state = STATE_DOC
elif state == STATE_DOC:
if line.startswith('*/'):
state = STATE_RETTYPE
else:
line = line.lstrip(' *')
doclist.append(line)
elif state == STATE_RETTYPE:
# first line of declaration with return type
m = re.match(r'NPY_NO_EXPORT\s+(.*)$', line)
if m:
line = m.group(1)
return_type = line
state = STATE_NAME
elif state == STATE_NAME:
# second line, with function name
m = re.match(r'(\w+)\s*\(', line)
if m:
function_name = m.group(1)
else:
raise ParseError(filename, lineno+1,
'could not find function name')
function_args.append(line[m.end():])
state = STATE_ARGS
elif state == STATE_ARGS:
if line.startswith('{'):
# finished
fargs_str = ' '.join(function_args).rstrip(' )')
fargs = split_arguments(fargs_str)
f = Function(function_name, return_type, fargs,
'\n'.join(doclist))
functions.append(f)
return_type = None
function_name = None
function_args = []
doclist = []
state = SCANNING
else:
function_args.append(line)
except Exception:
print(filename, lineno + 1)
raise
fo.close()
return functions
def should_rebuild(targets, source_files):
from distutils.dep_util import newer_group
for t in targets:
if not os.path.exists(t):
return True
sources = API_FILES + list(source_files) + [__file__]
if newer_group(sources, targets[0], missing='newer'):
return True
return False
def write_file(filename, data):
"""
Write data to filename
Only write changed data to avoid updating timestamps unnecessarily
"""
if os.path.exists(filename):
with open(filename) as f:
if data == f.read():
return
with open(filename, 'w') as fid:
fid.write(data)
# Those *Api classes instances know how to output strings for the generated code
class TypeApi(object):
def __init__(self, name, index, ptr_cast, api_name):
self.index = index
self.name = name
self.ptr_cast = ptr_cast
self.api_name = api_name
def define_from_array_api_string(self):
return "#define %s (*(%s *)%s[%d])" % (self.name,
self.ptr_cast,
self.api_name,
self.index)
def array_api_define(self):
return " (void *) &%s" % self.name
def internal_define(self):
astr = """\
extern NPY_NO_EXPORT PyTypeObject %(type)s;
""" % {'type': self.name}
return astr
class GlobalVarApi(object):
def __init__(self, name, index, type, api_name):
self.name = name
self.index = index
self.type = type
self.api_name = api_name
def define_from_array_api_string(self):
return "#define %s (*(%s *)%s[%d])" % (self.name,
self.type,
self.api_name,
self.index)
def array_api_define(self):
return " (%s *) &%s" % (self.type, self.name)
def internal_define(self):
astr = """\
extern NPY_NO_EXPORT %(type)s %(name)s;
""" % {'type': self.type, 'name': self.name}
return astr
# Dummy to be able to consistently use *Api instances for all items in the
# array api
class BoolValuesApi(object):
def __init__(self, name, index, api_name):
self.name = name
self.index = index
self.type = 'PyBoolScalarObject'
self.api_name = api_name
def define_from_array_api_string(self):
return "#define %s ((%s *)%s[%d])" % (self.name,
self.type,
self.api_name,
self.index)
def array_api_define(self):
return " (void *) &%s" % self.name
def internal_define(self):
astr = """\
extern NPY_NO_EXPORT PyBoolScalarObject _PyArrayScalar_BoolValues[2];
"""
return astr
class FunctionApi(object):
def __init__(self, name, index, annotations, return_type, args, api_name):
self.name = name
self.index = index
self.annotations = annotations
self.return_type = return_type
self.args = args
self.api_name = api_name
def _argtypes_string(self):
if not self.args:
return 'void'
argstr = ', '.join([_repl(a[0]) for a in self.args])
return argstr
def define_from_array_api_string(self):
define = """\
#define %s \\\n (*(%s (*)(%s)) \\
%s[%d])""" % (self.name,
self.return_type,
self._argtypes_string(),
self.api_name,
self.index)
return define
def array_api_define(self):
return " (void *) %s" % self.name
def internal_define(self):
annstr = []
for a in self.annotations:
annstr.append(str(a))
annstr = ' '.join(annstr)
astr = """\
NPY_NO_EXPORT %s %s %s \\\n (%s);""" % (annstr, self.return_type,
self.name,
self._argtypes_string())
return astr
def order_dict(d):
"""Order dict by its values."""
o = list(d.items())
def _key(x):
return x[1] + (x[0],)
return sorted(o, key=_key)
def merge_api_dicts(dicts):
ret = {}
for d in dicts:
for k, v in d.items():
ret[k] = v
return ret
def check_api_dict(d):
"""Check that an api dict is valid (does not use the same index twice)."""
# remove the extra value fields that aren't the index
index_d = {k: v[0] for k, v in d.items()}
# We have if a same index is used twice: we 'revert' the dict so that index
# become keys. If the length is different, it means one index has been used
# at least twice
revert_dict = {v: k for k, v in index_d.items()}
if not len(revert_dict) == len(index_d):
# We compute a dict index -> list of associated items
doubled = {}
for name, index in index_d.items():
try:
doubled[index].append(name)
except KeyError:
doubled[index] = [name]
fmt = "Same index has been used twice in api definition: {}"
val = ''.join(
'\n\tindex {} -> {}'.format(index, names)
for index, names in doubled.items() if len(names) != 1
)
raise ValueError(fmt.format(val))
# No 'hole' in the indexes may be allowed, and it must starts at 0
indexes = set(index_d.values())
expected = set(range(len(indexes)))
if indexes != expected:
diff = expected.symmetric_difference(indexes)
msg = "There are some holes in the API indexing: " \
"(symmetric diff is %s)" % diff
raise ValueError(msg)
def get_api_functions(tagname, api_dict):
"""Parse source files to get functions tagged by the given tag."""
functions = []
for f in API_FILES:
functions.extend(find_functions(f, tagname))
dfunctions = []
for func in functions:
o = api_dict[func.name][0]
dfunctions.append( (o, func) )
dfunctions.sort()
return [a[1] for a in dfunctions]
def fullapi_hash(api_dicts):
"""Given a list of api dicts defining the numpy C API, compute a checksum
of the list of items in the API (as a string)."""
a = []
for d in api_dicts:
for name, data in order_dict(d):
a.extend(name)
a.extend(','.join(map(str, data)))
return hashlib.md5(''.join(a).encode('ascii')).hexdigest()
# To parse strings like 'hex = checksum' where hex is e.g. 0x1234567F and
# checksum a 128 bits md5 checksum (hex format as well)
VERRE = re.compile(r'(^0x[\da-f]{8})\s*=\s*([\da-f]{32})')
def get_versions_hash():
d = []
file = os.path.join(os.path.dirname(__file__), 'cversions.txt')
fid = open(file, 'r')
try:
for line in fid:
m = VERRE.match(line)
if m:
d.append((int(m.group(1), 16), m.group(2)))
finally:
fid.close()
return dict(d)
def main():
tagname = sys.argv[1]
order_file = sys.argv[2]
functions = get_api_functions(tagname, order_file)
m = hashlib.md5(tagname)
for func in functions:
print(func)
ah = func.api_hash()
m.update(ah)
print(hex(int(ah, 16)))
print(hex(int(m.hexdigest()[:8], 16)))
if __name__ == '__main__':
main()