diff --git a/tools/license_header.py b/tools/license_header.py index 53d450c4e0da..199d56c7ee35 100755 --- a/tools/license_header.py +++ b/tools/license_header.py @@ -37,6 +37,7 @@ from itertools import chain import logging import sys +import subprocess # the default apache license _LICENSE = """Licensed to the Apache Software Foundation (ASF) under one @@ -101,9 +102,23 @@ # Previous license header, which will be removed _OLD_LICENSE = re.compile('.*Copyright.*by Contributors') -def _has_license(lines): + +def get_mxnet_root(): + curpath = os.path.abspath(os.path.dirname(__file__)) + def is_mxnet_root(path: str) -> bool: + return os.path.exists(os.path.join(path, ".mxnet_root")) + while not is_mxnet_root(curpath): + parent = os.path.abspath(os.path.join(curpath, os.pardir)) + if parent == curpath: + raise RuntimeError("Got to the root and couldn't find a parent folder with .mxnet_root") + curpath = parent + return curpath + + +def _lines_have_license(lines): return any([any([p in l for p in _LICENSE_PATTERNS]) for l in lines]) + def _get_license(comment_mark): if comment_mark == '*': body = '/*\n' @@ -122,65 +137,88 @@ def _get_license(comment_mark): body += '\n' return body -def _valid_file(fname, verbose=False): + +def should_have_license(fname): if any([l in fname for l in _WHITE_LIST]): - if verbose: - logging.info('skip ' + fname + ', it matches the white list') + logging.debug('skip ' + fname + ', it matches the white list') return False _, ext = os.path.splitext(fname) if ext not in _LANGS: - if verbose: - logging.info('skip ' + fname + ', unknown file extension') + logging.debug('skip ' + fname + ', unknown file extension') return False return True -def process_file(fname, action, verbose=True): - if not _valid_file(fname, verbose): + +def file_has_license(fname): + if not should_have_license(fname): return True try: with open(fname, 'r', encoding="utf-8") as f: lines = f.readlines() - if not lines: + if not lines or _lines_have_license(lines): return True - if _has_license(lines): - return True - elif action == 'check': + else: + logging.error("File %s doesn't have a license", fname) return False - _, ext = os.path.splitext(fname) - with open(fname, 'w', encoding="utf-8") as f: - # shebang line - if lines[0].startswith('#!'): - f.write(lines[0].rstrip()+'\n\n') - del lines[0] - f.write(_get_license(_LANGS[ext])) - for l in lines: - f.write(l.rstrip()+'\n') - logging.info('added license header to ' + fname) except UnicodeError: return True return True -def process_folder(root, action): - excepts = [] - for root, _, files in os.walk(root): - for f in files: - fname = os.path.normpath(os.path.join(root, f)) - if not process_file(fname, action): - excepts.append(fname) - if action == 'check' and excepts: - logging.warning('The following files do not contain a valid license, '+ - 'you can use `tools/license_header.py add [file]` to add'+ - 'them automatically: ') - for x in excepts: - logging.warning(x) - return False - return True -if __name__ == '__main__': - logging.getLogger().setLevel(logging.INFO) - logging.basicConfig(format='%(asctime)-15s %(message)s') +def file_add_license(fname): + if not should_have_license(fname): + return + with open(fname, 'r', encoding="utf-8") as f: + lines = f.readlines() + if _lines_have_license(lines): + return + _, ext = os.path.splitext(fname) + with open(fname, 'w', encoding="utf-8") as f: + # shebang line + if lines[0].startswith('#!'): + f.write(lines[0].rstrip()+'\n\n') + del lines[0] + f.write(_get_license(_LANGS[ext])) + for l in lines: + f.write(l.rstrip()+'\n') + logging.info('added license header to ' + fname) + return + + +def under_git(): + return subprocess.run(['git', 'rev-parse', 'HEAD'], + stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0 + + +def git_files(): + return list(map(os.fsdecode, + subprocess.check_output('git ls-tree -r HEAD --name-only -z'.split()).split(b'\0'))) + + +def file_generator(path: str): + for (dirpath, dirnames, files) in os.walk(path): + for file in files: + yield os.path.abspath(os.path.join(dirpath, file)) + + +def foreach(fn, iterable): + for x in iterable: + fn(x) + + +def script_name(): + """:returns: script name with leading paths removed""" + return os.path.split(sys.argv[0])[1] + + +def main(): + logging.basicConfig( + format='{}: %(levelname)s %(message)s'.format(script_name()), + level=os.environ.get("LOGLEVEL", "INFO")) + parser = argparse.ArgumentParser( description='Add or check source license header') + parser.add_argument( 'action', nargs=1, type=str, choices=['add', 'check'], default='add', @@ -191,19 +229,26 @@ def process_folder(root, action): help='Files to add license header to') args = parser.parse_args() - files = list(chain(*args.file)) action = args.action[0] - has_license = True - if len(files) > 0: - for file in files: - has_license = process_file(file, action) - if action == 'check' and not has_license: - logging.warn("{} doesn't have a license".format(file)) - has_license = False - else: - has_license = process_folder(os.path.join(os.path.dirname(__file__), '..'), action) - if not has_license: - sys.exit(1) + files = list(chain(*args.file)) + if not files and action =='check': + if under_git(): + logging.info("Git detected: Using files under version control") + files = git_files() + else: + logging.info("Using files under mxnet sources root") + files = file_generator(get_mxnet_root()) + + if action == 'check': + if not all(map(file_has_license, files)): + return 1 + else: + logging.info("All known and whitelisted files have license") + return 0 else: - sys.exit(0) + assert action == 'add' + foreach(file_add_license, files) + return 0 +if __name__ == '__main__': + sys.exit(main())