Mercurial > ~astiob > upreckon > hgweb
changeset 21:ec6f1a132109
A pretty usable version
Test groups and testconfs in non-ZIP archives or ZIP archives with comments are not yet supported.
| author | Oleg Oshmyan <chortos@inbox.lv> | 
|---|---|
| date | Fri, 06 Aug 2010 15:39:29 +0000 | 
| parents | 5bfa23cd638d | 
| children | f07b7a431ea6 | 
| files | 2.00/compat.py 2.00/config.py 2.00/files.py 2.00/problem.py 2.00/test-svn.py 2.00/testcases.py 2.00/zipfile.py 2.00/zipfile2.py 2.00/zipfile3.py | 
| diffstat | 9 files changed, 3947 insertions(+), 121 deletions(-) [+] | 
line wrap: on
 line diff
--- a/2.00/compat.py Mon Jun 14 21:02:06 2010 +0000 +++ b/2.00/compat.py Fri Aug 06 15:39:29 2010 +0000 @@ -1,20 +1,51 @@ -#!/usr/bin/python +#! /usr/bin/env python # Copyright (c) 2010 Chortos-2 <chortos@inbox.lv> +# A compatibility layer for Python 2.5+. This is what lets test.py +# run on all versions of Python starting with 2.5, including Python 3. + +# A few notes regarding some compatibility-driven peculiarities +# in the use of the language that can be seen in all modules: +# +# * Except statements never specify target; instead, when needed, +# the exception is taken from sys.exc_info(). Blame the incompatible +# syntaxes of the except clause in Python 2.5 and Python 3 and the lack +# of preprocessor macros in Python of any version ;P. +# +# * Keyword-only parameters are never used, even for parameters +# that should never be given in as arguments. The reason is +# the laziness of some Python developers who have failed to finish +# implementing them in Python 2 even though they had several years +# of time and multiple version releases to sneak them in. +# +# * Abstract classes are only implemented for Python 2.6 and 2.7. +# ABC's require the abc module and the specification of metaclasses, +# but in Python 2.5, the abc module does not exist, while in Python 3, +# metaclasses are specified using a syntax totally incompatible +# with Python 2 and not usable conditionally via exec() and such +# because it is a detail of the syntax of the class statement itself. + +__all__ = ('say', 'basestring', 'range', 'map', 'zip', 'filter', + 'items', 'keys', 'values', 'ABCMeta', 'abstractmethod') + try: # Python 3 exec('say = print') except SyntaxError: try: # Python 2.6/2.7 - exec('say = __builtins__["print"]') + # An alternative is exec('from __future__ import print_function; say = print'); + # if problems arise with the current line, one should try replacing it + # with this one with the future import before abandoning the idea altogether + say = __builtins__['print'] except Exception: # Python 2.5 import sys # This should fully emulate the print function of Python 2.6 in Python 2.3+ - # The error messages are taken from Python 2.6/2.7 + # The error messages are taken from Python 2.6 + # The name bindings at the bottom of this file are in effect def saytypeerror(value, name): - return TypeError(name + ' must be None, str or unicode, not ' + type(value).__name__) + return TypeError(' '.join((name, 'must be None, str or unicode, not', type(value).__name__))) def say(*values, **kwargs): sep = kwargs.pop('sep' , None) end = kwargs.pop('end' , None) @@ -25,7 +56,7 @@ if file is None: file = sys.stdout if not isinstance(sep, basestring): raise saytypeerror(sep, 'sep') if not isinstance(end, basestring): raise saytypeerror(end, 'end') - file.write(sep.join((str(i) for i in values)) + end) + file.write(sep.join(map(str, values)) + end) def import_urllib(): try: @@ -35,4 +66,47 @@ except ImportError: # Python 2 import urllib - return urllib, lambda url: urllib.urlopen(url).read() \ No newline at end of file + return urllib, lambda url: urllib.urlopen(url).read() + +try: + from abc import ABCMeta, abstractmethod +except ImportError: + ABCMeta, abstractmethod = None, lambda x: x + +# In all of the following, the try clause is for Python 2 and the except +# clause is for Python 3. More checks are performed than needed +# for standard builds of Python to ensure as much as possible works +# on custom builds. +try: + basestring = basestring +except NameError: + basestring = str + +try: + range = xrange +except NameError: + range = range + +try: + from itertools import imap as map +except ImportError: + map = map + +try: + from itertools import izip as zip +except ImportError: + zip = zip + +try: + from itertools import ifilter as filter +except ImportError: + filter = filter + +items = dict.iteritems if hasattr(dict, 'iteritems') else dict.items +keys = dict.iterkeys if hasattr(dict, 'iterkeys') else dict.keys +values = dict.itervalues if hasattr(dict, 'itervalues') else dict.values + +for name in __all__: + __builtins__[name] = globals()[name] + +__builtins__['xrange'] = range \ No newline at end of file
--- a/2.00/config.py Mon Jun 14 21:02:06 2010 +0000 +++ b/2.00/config.py Fri Aug 06 15:39:29 2010 +0000 @@ -1,8 +1,128 @@ -#!/usr/bin/python +#! /usr/bin/env python # Copyright (c) 2010 Chortos-2 <chortos@inbox.lv> -import os -tasknames = (os.path.curdir,) +from __future__ import division, with_statement + +try: + import files +except ImportError: + import __main__ + __main__.import_error(sys.exc_info()[1]) +else: + from __main__ import options + +if files.ZipArchive: + try: + import zipimport + except ImportError: + zipimport = None +else: + zipimport = None + +import imp, os, sys + +__all__ = 'load_problem', 'load_global', 'globalconf' + +defaults_problem = {'usegroups': False, + 'maxtime': None, + 'maxmemory': None, + 'dummies': {}, + 'testsexcluded': (), + 'padtests': 0, + 'paddummies': 0, + 'taskweight': 100, + 'pointmap': {}, + 'stdio': False, + 'dummyinname': '', + 'dummyoutname': '', + 'tester': None, + 'maxexitcode': 0, + 'inname': '', + 'ansname': ''} +patterns = ('inname', 'outname', 'ansname', 'testcaseinname', + 'testcaseoutname', 'dummyinname', 'dummyoutname') +defaults_global = {'tasknames': None, + 'force_zero_exitcode': True} + +class Config(object): + __slots__ = 'modules', '__dict__' + + def __init__(self, *modules): + self.modules = modules + + def __getattr__(self, name): + for module in self.modules: + try: + return getattr(module, name) + except AttributeError: + pass + # TODO: provide a message + raise AttributeError(name) -def load_problem(name): - return object() \ No newline at end of file +def load_problem(problem_name): + dwb = sys.dont_write_bytecode + sys.dont_write_bytecode = True + metafile = files.File('/'.join((problem_name, 'testconf.py')), True, 'configuration') + module = None + if zipimport and isinstance(metafile.archive, files.ZipArchive): + try: + module = zipimport.zipimporter(os.path.dirname(metafile.full_real_path)).load_module('testconf') + except zipimport.ZipImportError: + pass + else: + del sys.modules['testconf'] + if not module: + with metafile.open() as f: + module = imp.load_module('testconf', f, metafile.full_real_path, ('.py', 'r', imp.PY_SOURCE)) + del sys.modules['testconf'] + if hasattr(module, 'padwithzeroestolength'): + if not hasattr(module, 'padtests'): + try: + module.padtests = module.padwithzeroestolength[0] + except TypeError: + module.padtests = module.padwithzeroestolength + if not hasattr(module, 'paddummies'): + try: + module.paddummies = module.padwithzeroestolength[1] + except TypeError: + module.paddummies = module.padwithzeroestolength + for name in defaults_problem: + if not hasattr(globalconf, name): + setattr(module, name, getattr(module, name, defaults_problem[name])) + for name in patterns: + if hasattr(module, name): + setattr(module, name, getattr(module, name).replace('%', problem_name)) + if not hasattr(module, 'path'): + if hasattr(module, 'name'): + module.path = module.name + elif sys.platform != 'win32': + module.path = os.path.join(os.path.curdir, problem_name) + else: + module.path = problem_name + if options.no_maxtime: + module.maxtime = 0 + sys.dont_write_bytecode = dwb + return Config(module, globalconf) + +def load_global(): + dwb = sys.dont_write_bytecode + sys.dont_write_bytecode = True + metafile = files.File('testconf.py', True, 'configuration') + module = None + if zipimport and isinstance(metafile.archive, files.ZipArchive): + try: + module = zipimport.zipimporter(os.path.dirname(metafile.full_real_path)).load_module('testconf') + except zipimport.ZipImportError: + pass + else: + del sys.modules['testconf'] + if not module: + with metafile.open() as f: + module = imp.load_module('testconf', f, metafile.full_real_path, ('.py', 'r', imp.PY_SOURCE)) + del sys.modules['testconf'] + for name in defaults_global: + setattr(module, name, getattr(module, name, defaults_global[name])) + global globalconf + globalconf = module + sys.dont_write_bytecode = dwb + return module \ No newline at end of file
--- a/2.00/files.py Mon Jun 14 21:02:06 2010 +0000 +++ b/2.00/files.py Fri Aug 06 15:39:29 2010 +0000 @@ -1,25 +1,251 @@ -#!/usr/bin/python +#! /usr/bin/env python # Copyright (c) 2010 Chortos-2 <chortos@inbox.lv> -import os -tasknames = (os.path.curdir,) +"""File access routines and classes with support for archives.""" + +from __future__ import division, with_statement + +try: + from compat import * +except ImportError: + import __main__ + __main__.import_error(sys.exc_info()[1]) + +import contextlib, os, shutil, sys + +# You don't need to know about anything else. +__all__ = 'File', + +# In these two variables, use full stops no matter what os.extsep is; +# all full stops will be converted to os.extsep on the fly +archives = 'tests.tar', 'tests.zip', 'tests.tgz', 'tests.tar.gz', 'tests.tbz2', 'tests.tar.bz2' +formats = {} + +class Archive(object): + __slots__ = 'file' + + if ABCMeta: + __metaclass__ = ABCMeta + + def __new__(cls, path): + """ + Create a new instance of the archive class corresponding + to the file name in the given path. + """ + if cls is not Archive: + return object.__new__(cls) + else: + # Do this by hand rather than through os.path.splitext + # because we support multi-dotted file name extensions + ext = path.partition(os.path.extsep)[2] + while ext: + if ext in formats: + return formats[ext](path) + ext = ext.partition(os.path.extsep)[2] + raise LookupError("unsupported archive file name extension in file name '%s'" % filename) + + @abstractmethod + def __init__(self, path): raise NotImplementedError + + @abstractmethod + def extract(self, name, target): raise NotImplementedError + + def __del__(self): + del self.file -class Files(object): - __slots__ = 'name', 'paths' - stdpaths = '%/', '%/^:%/', '%/^:', 'tests/%/', 'tests/', '^:%/', '^:', '' +try: + import tarfile - def __init__(self, name, paths = stdpaths): - self.name = name - self.paths = paths + class TarArchive(Archive): + __slots__ = '__namelist' + + def __init__(self, path): + self.file = tarfile.open(path) + + def extract(self, name, target): + member = self.file.getmember(name) + member.name = target + self.file.extract(member) + + # TODO: somehow automagically emulate universal line break support + def open(self, name): + return self.file.extractfile(name) + + def exists(self, queried_name): + if not hasattr(self, '__namelist'): + names = set() + for name in self.file.getnames(): + cutname = name + while cutname: + names.add(cutname) + cutname = cutname.rpartition('/')[0] + self.__namelist = frozenset(names) + return queried_name in self.__namelist + + def __enter__(self): + if hasattr(self.file, '__enter__'): + self.file.__enter__() + return self + + def __exit__(self, exc_type, exc_value, traceback): + if hasattr(self.file, '__exit__'): + return self.file.__exit__(exc_type, exc_value, traceback) + elif exc_type is None: + self.file.close() + else: + # This code was shamelessly copied from tarfile.py of Python 2.7 + if not self.file._extfileobj: + self.file.fileobj.close() + self.file.closed = True + + formats['tar'] = formats['tgz'] = formats['tar.gz'] = formats['tbz2'] = formats['tar.bz2'] = TarArchive +except ImportError: + TarArchive = None + +try: + import zipfile - def __iter__(self): - for path in paths: - p = getpath(path, self.name) - if isfile(p): - yield p + class ZipArchive(Archive): + __slots__ = '__namelist' + + def __init__(self, path): + self.file = zipfile.ZipFile(path) + + def extract(self, name, target): + if os.path.isabs(target): + # To my knowledge, this is as portable as it gets + path = os.path.join(os.path.splitdrive(target)[0], os.path.sep) + else: + path = None + + member = self.file.getinfo(name) + # FIXME: 2.5 lacks os.path.realpath + member.filename = os.path.relpath(target, path) + # FIXME: 2.5 lacks ZipFile.extract + self.file.extract(member, path) + + def open(self, name): + return self.file.open(name, 'rU') + + def exists(self, queried_name): + if not hasattr(self, '__namelist'): + names = set() + for name in self.file.namelist(): + cutname = name + while cutname: + names.add(cutname) + cutname = cutname.rpartition('/')[0] + self.__namelist = frozenset(names) + return queried_name in self.__namelist + + def __enter__(self): + if hasattr(self.file, '__enter__'): + self.file.__enter__() + return self + + def __exit__(self, exc_type, exc_value, traceback): + if hasattr(self.file, '__exit__'): + return self.file.__exit__(exc_type, exc_value, traceback) + else: + return self.file.close() + + formats['zip'] = ZipArchive +except ImportError: + ZipArchive = None + +# Remove unsupported archive formats and replace full stops +# with the platform-dependent file name extension separator +def issupported(filename, formats=formats): + ext = filename.partition('.')[2] + while ext: + if ext in formats: return True + ext = ext.partition('.')[2] + return False +archives = [filename.replace('.', os.path.extsep) for filename in filter(issupported, archives)] +formats = dict((item[0].replace('.', os.path.extsep), item[1]) for item in items(formats)) + +open_archives = {} + +def open_archive(path): + if path in open_archives: + return open_archives[path] + else: + open_archives[path] = archive = Archive(path) + return archive -def isfile(path): - return os.path.isfile(path) - -def getpath(path, name): - return path + name \ No newline at end of file +class File(object): + __slots__ = 'virtual_path', 'real_path', 'full_real_path', 'archive' + + def __init__(self, virtpath, allow_root=False, msg='test data'): + self.virtual_path = virtpath + self.archive = None + if not self.realize_path('', tuple(comp.replace('.', os.path.extsep) for comp in virtpath.split('/')), allow_root): + raise IOError("%s file '%s' could not be found" % (msg, virtpath)) + + def realize_path(self, root, virtpath, allow_root=False, hastests=False): + if root and not os.path.exists(root): + return False + if len(virtpath) > 1: + if self.realize_path(os.path.join(root, virtpath[0]), virtpath[1:], allow_root, hastests): + return True + elif not hastests: + if self.realize_path(os.path.join(root, 'tests'), virtpath, allow_root, True): + return True + for archive in archives: + path = os.path.join(root, archive) + if os.path.exists(path): + if self.realize_path_archive(open_archive(path), '', virtpath, path): + return True + elif self.realize_path(root, virtpath[1:], allow_root, hastests): + return True + else: + if not hastests: + path = os.path.join(root, 'tests', virtpath[0]) + if os.path.exists(path): + self.full_real_path = self.real_path = path + return True + for archive in archives: + path = os.path.join(root, archive) + if os.path.exists(path): + if self.realize_path_archive(open_archive(path), '', virtpath, path): + return True + if hastests or allow_root: + path = os.path.join(root, virtpath[0]) + if os.path.exists(path): + self.full_real_path = self.real_path = path + return True + return False + + def realize_path_archive(self, archive, root, virtpath, archpath): + if root and not archive.exists(root): + return False + if root: path = ''.join((root, '/', virtpath[0])) + else: path = virtpath[0] + if len(virtpath) > 1: + if self.realize_path_archive(archive, path, virtpath[1:], archpath): + return True + elif self.realize_path_archive(archive, root, virtpath[1:], archpath): + return True + else: + if archive.exists(path): + self.archive = archive + self.real_path = path + self.full_real_path = os.path.join(archpath, *path.split('/')) + return True + return False + + def open(self): + if self.archive: + file = self.archive.open(self.real_path) + if hasattr(file, '__exit__'): + return file + else: + return contextlib.closing(file) + else: + return open(self.real_path) + + def copy(self, target): + if self.archive: + self.archive.extract(self.real_path, target) + else: + shutil.copy(self.real_path, target) \ No newline at end of file
--- a/2.00/problem.py Mon Jun 14 21:02:06 2010 +0000 +++ b/2.00/problem.py Fri Aug 06 15:39:29 2010 +0000 @@ -1,14 +1,44 @@ -#!/usr/bin/python +#! /usr/bin/env python # Copyright (c) 2010 Chortos-2 <chortos@inbox.lv> +from __future__ import division, with_statement + +try: + from compat import * + import config, testcases +except ImportError: + import __main__ + __main__.import_error(sys.exc_info()[1]) +else: + from __main__ import clock + +import sys, re + try: - import config as _config, testcases as _testcases -except ImportError as e: - import __main__ - __main__.import_error(e) + import signal +except ImportError: + signalnames = () +else: + # Construct a cache of all signal names available on the current + # platform. Prefer names from the UNIX standards over other versions. + unixnames = frozenset(('HUP', 'INT', 'QUIT', 'ILL', 'ABRT', 'FPE', 'KILL', 'SEGV', 'PIPE', 'ALRM', 'TERM', 'USR1', 'USR2', 'CHLD', 'CONT', 'STOP', 'TSTP', 'TTIN', 'TTOU', 'BUS', 'POLL', 'PROF', 'SYS', 'TRAP', 'URG', 'VTALRM', 'XCPU', 'XFSZ')) + signalnames = {} + for name in dir(signal): + if re.match('SIG[A-Z]+$', name): + value = signal.__dict__[name] + if isinstance(value, int) and (value not in signalnames or signalnames[value][3:] not in unixnames): + signalnames[value] = name + del unixnames +__all__ = 'Problem', + +# This should no more be needed; pass all work on to the TestCase inheritance tree # LIBRARY and STDIO refer to interactive aka reactive problems -BATCH, OUTONLY, LIBRARY, STDIO, BESTOUT = xrange(5) +#BATCH, OUTONLY, LIBRARY, STDIO, BESTOUT = xrange(5) + +class Cache(object): + def __init__(self, mydict): + self.__dict__ = mydict class Problem(object): __slots__ = 'name', 'config', 'cache', 'testcases' @@ -16,16 +46,111 @@ def __init__(prob, name): if not isinstance(name, basestring): # This shouldn't happen, of course - raise TypeError, "Problem() argument 1 must be string, not " + str(type(name)).split('\'')[1] + raise TypeError('Problem() argument 1 must be string, not ' + type(name).__name__) prob.name = name - prob.config = _config.load_problem(name) - prob.cache = type('Cache', (object,), {'padoutputtolength': 0})() - prob.testcases = _testcases.load_problem(prob) + prob.config = config.load_problem(name) + if not getattr(prob.config, 'kind', None): prob.config.kind = 'batch' + prob.cache = Cache({'padoutput': 0, 'usegroups': False}) + prob.testcases = testcases.load_problem(prob) + + # TODO + def build(prob): + raise NotImplementedError def test(prob): - real = max = 0 + real = max = ntotal = nvalued = ncorrect = ncorrectvalued = 0 for case in prob.testcases: - r, m = case() - real += r - max += m - return real, max \ No newline at end of file + ntotal += 1 + max += case.points + if case.points: nvalued += 1 + granted = 0 + id = str(case.id) + if case.isdummy: + id = 'sample ' + id + say('%*s: ' % (prob.cache.padoutput, id), end='') + sys.stdout.flush() + try: + granted = case() + except KeyboardInterrupt: + if not hasattr(case, 'time_stopped'): + # Too quick! The testing has not even started! + raise + verdict = 'canceled by the user' + except testcases.TimeLimitExceeded: + verdict = 'time limit exceeded' + except testcases.WrongAnswer: + e = sys.exc_info()[1] + if e.comment: + verdict = 'wrong answer (%s)' % e.comment + else: + verdict = 'wrong answer' + except testcases.NonZeroExitCode: + e = sys.exc_info()[1] + if e.exitcode < 0: + if sys.platform == 'win32': + verdict = 'terminated with error 0x%X' % (e.exitcode + 0x100000000) + elif -e.exitcode in signalnames: + verdict = 'terminated by signal %d (%s)' % (-e.exitcode, signalnames[-e.exitcode]) + else: + verdict = 'terminated by signal %d' % -e.exitcode + else: + verdict = 'non-zero return code %d' % e.exitcode + except testcases.CannotStartTestee: + e = sys.exc_info()[1] + if e.upstream.strerror: + verdict = 'cannot launch the program to test (%s)' % e.upstream.strerror.lower() + else: + verdict = 'cannot launch the program to test' + except testcases.CannotStartValidator: + e = sys.exc_info()[1] + if e.upstream.strerror: + verdict = 'cannot launch the validator (%s)' % e.upstream.strerror.lower() + else: + verdict = 'cannot launch the validator' + except testcases.CannotReadOutputFile: + e = sys.exc_info()[1] + if e.upstream.strerror: + verdict = 'cannot read the output file (%s)' % e.upstream.strerror.lower() + else: + verdict = 'cannot read the output file' + except testcases.CannotReadInputFile: + e = sys.exc_info()[1] + if e.upstream.strerror: + verdict = 'cannot read the input file (%s)' % e.upstream.strerror.lower() + else: + verdict = 'cannot read the input file' + except testcases.CannotReadAnswerFile: + e = sys.exc_info()[1] + if e.upstream.strerror: + verdict = 'cannot read the reference output file (%s)' % e.upstream.strerror.lower() + else: + verdict = 'cannot read the reference output file' + except testcases.TestCaseNotPassed: + e = sys.exc_info()[1] + verdict = 'unspecified reason [this may be a bug in test.py] (%s)' % e + #except Exception: + # e = sys.exc_info()[1] + # verdict = 'unknown error [this may be a bug in test.py] (%s)' % e + else: + if hasattr(granted, '__iter__'): + granted, comment = granted + if comment: + comment = ' (%s)' % comment + else: + comment = '' + if granted == case.points: + ncorrect += 1 + if granted: ncorrectvalued += 1 + verdict = 'OK' + comment + elif not granted: + verdict = 'wrong answer' + comment + else: + verdict = 'partly correct' + comment + say('%.3f%s s, %g/%g, %s' % (case.time_stopped - case.time_started, case.time_limit_string, granted, case.points, verdict)) + real += granted + weighted = real * prob.config.taskweight / max if max else 0 + if nvalued != ntotal: + say('Grand total: %d/%d tests (%d/%d valued); %g/%g points; weighted score: %g/%g' % (ncorrect, ntotal, ncorrectvalued, nvalued, real, max, weighted, prob.config.taskweight)) + else: + say('Grand total: %d/%d tests; %g/%g points; weighted score: %g/%g' % (ncorrect, ntotal, real, max, weighted, prob.config.taskweight)) + return weighted, prob.config.taskweight
--- a/2.00/test-svn.py Mon Jun 14 21:02:06 2010 +0000 +++ b/2.00/test-svn.py Fri Aug 06 15:39:29 2010 +0000 @@ -1,9 +1,14 @@ -#!/usr/bin/python +#! /usr/bin/env python # Copyright (c) 2009-2010 Chortos-2 <chortos@inbox.lv> from __future__ import division, with_statement import optparse, sys, compat -from compat import say + +def import_error(e): + say('Error: your installation of test.py is incomplete;', str(e).lower() + '.', file=sys.stderr) + sys.exit(3) + +from compat import * # $Rev$ version = '2.00.0 (SVN r$$REV$$)' @@ -11,6 +16,9 @@ parser.add_option('-u', '--update', dest='update', action='store_true', default=False, help='check for an updated version of test.py') parser.add_option('-m', '--copy-io', dest='copyonly', action='store_true', default=False, help='create a copy of the input/output files of the last test case for manual testing and exit') parser.add_option('-x', '--auto-exit', dest='pause', action='store_false', default=True, help='do not wait for a key to be pressed after finishing testing') +parser.add_option('-s', '--save-io', dest='erase', action='store_false', default=True, help='do not delete the copies of input/output files after the last test case; create copies of input files and store output in files even if the solution uses standard I/O; delete the stored input/output files if the solution uses standard I/O and the -c/--cleanup option is specified') +parser.add_option('-t', '--detect-time', dest='autotime', action='store_true', default=False, help='spend a second detecting the most precise time measurement function') +parser.add_option('--no-time-limits', dest='no_maxtime', action='store_true', default=False, help='disable all time limits') options, args = parser.parse_args() parser.destroy() @@ -49,73 +57,115 @@ say('Downloaded and installed. Now you are using test.py ' + latesttext + '.') sys.exit() -def import_error(e): - say('Your installation of test.py is incomplete:', str(e).lower() + '.', file=sys.stderr) - sys.exit(3) - -import os, config +import config, itertools, os, sys, time -# Do this check here so that if we have to warn them, we do it as early as possible -if options.pause and not hasattr(config, 'pause'): - try: - # If we have getch, we don't need config.pause - import msvcrt - msvcrt.getch.__call__ - except Exception: - if os.name == 'posix': - config.pause = 'read -s -n 1' - say('Warning: configuration variable pause is not defined; it was devised automatically but the choice might be incorrect, so test.py might exit immediately after the testing is completed.') - sys.stdout.flush() - elif os.name == 'nt': - config.pause = 'pause' - else: - sys.exit('Error: configuration variable pause is not defined and cannot be devised automatically.') +if options.autotime: + c = time.clock() + time.sleep(1) + c = time.clock() - c + if int(c + .5) == 1: + clock = time.clock + else: + clock = time.time +elif sys.platform == 'win32': + clock = time.clock +else: + clock = time.time try: - from problem import * -except ImportError as e: - import_error(e) + globalconf = config.load_global() -# Support single-problem configurations -try: - shouldprintnames = len(config.tasknames) > 1 -except Exception: - shouldprintnames = True + # Do this check here so that if we have to warn them, we do it as early as possible + if options.pause and not hasattr(globalconf, 'pause'): + try: + # If we have getch, we don't need config.pause + import msvcrt + msvcrt.getch.__call__ + except Exception: + if os.name == 'posix': + globalconf.pause = 'read -s -n 1' + say('Warning: configuration variable pause is not defined; it was devised automatically but the choice might be incorrect, so test.py might exit immediately after the testing is completed.') + sys.stdout.flush() + elif os.name == 'nt': + globalconf.pause = 'pause' + else: + sys.exit('Error: configuration variable pause is not defined and cannot be devised automatically.') -ntasks = 0 -nfulltasks = 0 -maxscore = 0 -realscore = 0 + try: + from problem import * + except ImportError: + import_error(sys.exc_info()[1]) -for taskname in config.tasknames: - problem = Problem(taskname) - - if ntasks: say() - if shouldprintnames: say(taskname) - - if options.copyonly: - problem.copytestdata() + # Support single-problem configurations + if globalconf.tasknames is None: + shouldprintnames = False + globalconf.multiproblem = False + globalconf.tasknames = os.path.curdir, else: - real, max = problem.test() - - ntasks += 1 - nfulltasks += (real == max) - realscore += real - maxscore += max + globalconf.multiproblem = True + try: + shouldprintnames = len(globalconf.tasknames) > 1 + except Exception: + # Try to retrieve the first two problem names and cache them on success + globalconf.tasknames = iter(globalconf.tasknames) + try: + try: + first = next(globalconf.tasknames) + except NameError: + # Python 2.5 lacks the next() built-in + first = globalconf.tasknames.next() + except StopIteration: + globalconf.tasknames = () + shouldprintnames = False + else: + try: + try: + second = next(globalconf.tasknames) + except NameError: + second = globalconf.tasknames.next() + except StopIteration: + globalconf.tasknames = first, + shouldprintnames = False + else: + globalconf.tasknames = itertools.chain((first, second), globalconf.tasknames) + shouldprintnames = True -if options.copyonly: - sys.exit() + ntasks = 0 + nfulltasks = 0 + maxscore = 0 + realscore = 0 -if ntasks != 1: - say() - say('Grand grand total: %g/%g weighted points; %d/%d problems solved fully' % (realscore, maxscore, nfulltasks, ntasks)) + for taskname in globalconf.tasknames: + problem = Problem(taskname) + + if ntasks: say() + if shouldprintnames: say(taskname) + + if options.copyonly: + problem.copytestdata() + else: + real, max = problem.test() + + ntasks += 1 + nfulltasks += (real == max) + realscore += real + maxscore += max + + if options.copyonly: + sys.exit() + + if ntasks != 1: + say() + say('Grand grand total: %g/%g weighted points; %d/%d problems solved fully' % (realscore, maxscore, nfulltasks, ntasks)) +except KeyboardInterrupt: + sys.exit('Exiting due to a keyboard interrupt.') if options.pause: - say('Press any key to exit...', end='') + say('Press any key to exit...') sys.stdout.flush() try: import msvcrt msvcrt.getch() except Exception: - os.system(config.pause + ' >' + os.devnull) \ No newline at end of file + os.system(globalconf.pause + ' >' + os.devnull) \ No newline at end of file
--- a/2.00/testcases.py Mon Jun 14 21:02:06 2010 +0000 +++ b/2.00/testcases.py Fri Aug 06 15:39:29 2010 +0000 @@ -1,35 +1,380 @@ -#!/usr/bin/python +#! /usr/bin/env python # Copyright (c) 2010 Chortos-2 <chortos@inbox.lv> +from __future__ import division, with_statement + +try: + from compat import * + import files, problem, config +except ImportError: + import __main__ + __main__.import_error(sys.exc_info()[1]) +else: + from __main__ import clock, options + +import glob, re, sys, tempfile, time +from subprocess import Popen, PIPE, STDOUT + +import os +devnull = open(os.path.devnull, 'w+') + +try: + from signal import SIGTERM, SIGKILL +except ImportError: + SIGTERM = 15 + SIGKILL = 9 + try: - import files as _files, problem as _problem -except ImportError as e: - import __main__ - __main__.import_error(e) + from _subprocess import TerminateProcess +except ImportError: + # CPython 2.5 does define _subprocess.TerminateProcess even though it is + # not used in the subprocess module, but maybe something else does not + try: + import ctypes + TerminateProcess = ctypes.windll.kernel32.TerminateProcess + except (ImportError, AttributeError): + TerminateProcess = None + +__all__ = ('TestCase', 'load_problem', 'TestCaseNotPassed', + 'TimeLimitExceeded', 'WrongAnswer', 'NonZeroExitCode', + 'CannotStartTestee', 'CannotStartValidator', + 'CannotReadOutputFile') + + + +# Exceptions + +class TestCaseNotPassed(Exception): __slots__ = () +class TimeLimitExceeded(TestCaseNotPassed): __slots__ = () + +class WrongAnswer(TestCaseNotPassed): + __slots__ = 'comment' + def __init__(self, comment=''): + self.comment = comment + +class NonZeroExitCode(TestCaseNotPassed): + __slots__ = 'exitcode' + def __init__(self, exitcode): + self.exitcode = exitcode + +class ExceptionWrapper(TestCaseNotPassed): + __slots__ = 'upstream' + def __init__(self, upstream): + self.upstream = upstream + +class CannotStartTestee(ExceptionWrapper): __slots__ = () +class CannotStartValidator(ExceptionWrapper): __slots__ = () +class CannotReadOutputFile(ExceptionWrapper): __slots__ = () +class CannotReadInputFile(ExceptionWrapper): __slots__ = () +class CannotReadAnswerFile(ExceptionWrapper): __slots__ = () + + + +# Test case types class TestCase(object): - __slots__ = 'problem', 'infile', 'outfile' + __slots__ = ('problem', 'id', 'isdummy', 'infile', 'outfile', 'points', + 'process', 'time_started', 'time_stopped', 'time_limit_string', + 'realinname', 'realoutname', 'maxtime', 'maxmemory') + + if ABCMeta: + __metaclass__ = ABCMeta - def __init__(case, prob, infile, outfile): + def __init__(case, prob, id, isdummy, points): case.problem = prob - case.infile = infile - case.outfile = outfile + case.id = id + case.isdummy = isdummy + case.points = points + case.maxtime = case.problem.config.maxtime + case.maxmemory = case.problem.config.maxmemory + if case.maxtime: + case.time_limit_string = '/%.3f' % case.maxtime + else: + case.time_limit_string = '' + if not isdummy: + case.realinname = case.problem.config.testcaseinname + case.realoutname = case.problem.config.testcaseoutname + else: + case.realinname = case.problem.config.dummyinname + case.realoutname = case.problem.config.dummyoutname + + @abstractmethod + def test(case): raise NotImplementedError def __call__(case): - os.copy() + try: + return case.test() + finally: + case.cleanup() + + def cleanup(case): + if not getattr(case, 'time_started', None): + case.time_started = case.time_stopped = clock() + elif not getattr(case, 'time_stopped', None): + case.time_stopped = clock() + #if getattr(case, 'infile', None): + # case.infile.close() + #if getattr(case, 'outfile', None): + # case.outfile.close() + if getattr(case, 'process', None): + # Try killing after three unsuccessful TERM attempts in a row + # (except on Windows, where TERMing is killing) + for i in range(3): + try: + try: + case.process.terminate() + except AttributeError: + # Python 2.5 + if TerminateProcess and hasattr(proc, '_handle'): + # Windows API + TerminateProcess(proc._handle, 1) + else: + # POSIX + os.kill(proc.pid, SIGTERM) + except Exception: + time.sleep(0) + case.process.poll() + else: + break + else: + # If killing the process is unsuccessful three times in a row, + # just silently stop trying + for i in range(3): + try: + try: + case.process.kill() + except AttributeError: + # Python 2.5 + if TerminateProcess and hasattr(proc, '_handle'): + # Windows API + TerminateProcess(proc._handle, 1) + else: + # POSIX + os.kill(proc.pid, SIGKILL) + except Exception: + time.sleep(0) + case.process.poll() + else: + break + + def open_infile(case): + try: + case.infile = files.File('/'.join((case.problem.name, case.realinname.replace('$', case.id)))) + except IOError: + e = sys.exc_info()[1] + raise CannotReadInputFile(e) + + def open_outfile(case): + try: + case.outfile = files.File('/'.join((case.problem.name, case.realoutname.replace('$', case.id)))) + except IOError: + e = sys.exc_info()[1] + raise CannotReadAnswerFile(e) + -def load_problem(prob): +class ValidatedTestCase(TestCase): + __slots__ = 'validator' + + def __init__(case, *args): + TestCase.__init__(case, *args) + if not case.problem.config.tester: + case.validator = None + else: + case.validator = case.problem.config.tester + + # TODO + def validate(case, output): + if not case.validator: + # Compare the output with the reference output + case.open_outfile() + with case.outfile.open() as refoutput: + for line, refline in zip(output, refoutput): + if not isinstance(refline, basestring): + line = bytes(line, sys.getdefaultencoding()) + if line != refline: + raise WrongAnswer() + try: + try: + next(output) + except NameError: + output.next() + except StopIteration: + pass + else: + raise WrongAnswer() + try: + try: + next(refoutput) + except NameError: + refoutput.next() + except StopIteration: + pass + else: + raise WrongAnswer() + return case.points + elif callable(case.validator): + return case.validator(output) + else: + # Call the validator program + output.close() + case.open_outfile() + if case.problem.config.ansname: + case.outfile.copy(case.problem.config.ansname) + case.process = Popen(case.validator, stdin=devnull, stdout=PIPE, stderr=STDOUT, universal_newlines=True, bufsize=-1) + comment = case.process.communicate()[0].strip() + lower = comment.lower() + match = re.match(r'(ok|correct|wrong(?:(?:\s|_)*answer)?)(?:$|\s+|[.,!:]+\s*)', lower) + if match: + comment = comment[match.end():] + if not case.problem.config.maxexitcode: + if case.process.returncode: + raise WrongAnswer(comment) + else: + return case.points, comment + else: + return case.points * case.process.returncode / case.problem.config.maxexitcode, comment + + +class BatchTestCase(ValidatedTestCase): + __slots__ = () + + def test(case): + if sys.platform == 'win32' or not case.maxmemory: + preexec_fn = None + else: + def preexec_fn(): + try: + import resource + maxmemory = int(case.maxmemory * 1048576) + resource.setrlimit(resource.RLIMIT_AS, (maxmemory, maxmemory)) + # I would also set a CPU time limit but I do not want the time + # that passes between the calls to fork and exec to be counted in + except MemoryError: + # We do not have enough memory for ourselves; + # let the parent know about this + raise + except Exception: + # Well, at least we tried + pass + case.open_infile() + case.time_started = None + if case.problem.config.stdio: + if options.erase and not case.validator: + # FIXME: 2.5 lacks the delete parameter + with tempfile.NamedTemporaryFile(delete=False) as f: + inputdatafname = f.name + else: + inputdatafname = case.problem.config.inname + case.infile.copy(inputdatafname) + # FIXME: inputdatafname should be deleted on __exit__ + with open(inputdatafname, 'rU') as infile: + with tempfile.TemporaryFile('w+') if options.erase and not case.validator else open(case.problem.config.outname, 'w+') as outfile: + try: + try: + case.process = Popen(case.problem.config.path, stdin=infile, stdout=outfile, stderr=devnull, universal_newlines=True, bufsize=-1, preexec_fn=preexec_fn) + except MemoryError: + # If there is not enough memory for the forked test.py, + # opt for silent dropping of the limit + case.process = Popen(case.problem.config.path, stdin=infile, stdout=outfile, stderr=devnull, universal_newlines=True, bufsize=-1) + except OSError: + raise CannotStartTestee(sys.exc_info()[1]) + case.time_started = clock() + # If we use a temporary file, it may not be a true file object, + # and if so, Popen will relay the standard output through pipes + if not case.maxtime: + case.process.communicate() + case.time_stopped = clock() + else: + time_end = case.time_started + case.maxtime + # FIXME: emulate communicate() + while True: + exitcode = case.process.poll() + now = clock() + if exitcode is not None: + case.time_stopped = now + break + elif now >= time_end: + raise TimeLimitExceeded() + if config.globalconf.force_zero_exitcode and case.process.returncode: + raise NonZeroExitCode(case.process.returncode) + outfile.seek(0) + return case.validate(outfile) + else: + if case.problem.config.inname: + case.infile.copy(case.problem.config.inname) + try: + try: + case.process = Popen(case.problem.config.path, stdin=devnull, stdout=devnull, stderr=STDOUT, preexec_fn=preexec_fn) + except MemoryError: + # If there is not enough memory for the forked test.py, + # opt for silent dropping of the limit + case.process = Popen(case.problem.config.path, stdin=devnull, stdout=devnull, stderr=STDOUT) + except OSError: + raise CannotStartTestee(sys.exc_info()[1]) + case.time_started = clock() + if not case.maxtime: + case.process.wait() + case.time_stopped = clock() + else: + time_end = case.time_started + case.maxtime + while True: + exitcode = case.process.poll() + now = clock() + if exitcode is not None: + case.time_stopped = now + break + elif now >= time_end: + raise TimeLimitExceeded() + if config.globalconf.force_zero_exitcode and case.process.returncode: + raise NonZeroExitCode(case.process.returncode) + with open(case.problem.config.outname, 'rU') as output: + return case.validate(output) + + +# This is the only test case type not executing any programs to be tested +class OutputOnlyTestCase(ValidatedTestCase): + __slots__ = () + def cleanup(case): pass + +class BestOutputTestCase(ValidatedTestCase): + __slots__ = () + +# This is the only test case type executing two programs simultaneously +class ReactiveTestCase(TestCase): + __slots__ = () + # The basic idea is to launch the program to be tested and the grader + # and to pipe their standard I/O from and to each other, + # and then to capture the grader's exit code and use it + # like the exit code of a test validator is used. + + +def load_problem(prob, _types={'batch' : BatchTestCase, + 'outonly' : OutputOnlyTestCase, + 'bestout' : BestOutputTestCase, + 'reactive': ReactiveTestCase}): if prob.config.usegroups: pass else: + # We will need to iterate over these configuration variables twice + try: + len(prob.config.dummies) + except Exception: + prob.config.dummies = tuple(prob.config.dummies) + try: + len(prob.config.tests) + except Exception: + prob.config.dummies = tuple(prob.config.tests) + # First get prob.cache.padoutput right + for i in prob.config.dummies: + s = 'sample ' + str(i).zfill(prob.config.paddummies) + prob.cache.padoutput = max(prob.cache.padoutput, len(s)) for i in prob.config.tests: - s = str(i).zfill(prob.config.padwithzeroestolength) - prob.cache.padoutputtolength = max(prob.cache.padoutputtolength, len(s)) - infile = _files.TestCaseFile(prob, prob.config.testcaseinname.replace('$', s)) - if infile: - if prob.config.kind != _problem.BATCH: - yield TestCase(prob, infile, None) - else: - outfile = _files.TestCaseFile(prob, prob.config.testcaseoutname.replace('$', s)) - if outfile: - yield TestCase(prob, infile, outfile) \ No newline at end of file + s = str(i).zfill(prob.config.padtests) + prob.cache.padoutput = max(prob.cache.padoutput, len(s)) + # Now yield the actual test cases + for i in prob.config.dummies: + s = str(i).zfill(prob.config.paddummies) + yield _types[prob.config.kind](prob, s, True, 0) + for i in prob.config.tests: + s = str(i).zfill(prob.config.padtests) + yield _types[prob.config.kind](prob, s, False, prob.config.pointmap.get(i, prob.config.pointmap.get(None, prob.config.maxexitcode if prob.config.maxexitcode else 1))) \ No newline at end of file
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/2.00/zipfile.py Fri Aug 06 15:39:29 2010 +0000 @@ -0,0 +1,5 @@ +import sys +if sys.version_info[0] == 3: + from zipfile3 import * +else: + from zipfile2 import * \ No newline at end of file
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/2.00/zipfile2.py Fri Aug 06 15:39:29 2010 +0000 @@ -0,0 +1,1427 @@ +""" +Read and write ZIP files. +""" +# Improved by Chortos-2 in 2010 (added bzip2 support) +import struct, os, time, sys, shutil +import binascii, cStringIO, stat +import io +import re + +try: + import zlib # We may need its compression method + crc32 = zlib.crc32 +except ImportError: + zlib = None + crc32 = binascii.crc32 + +try: + import bz2 # We may need its compression method +except ImportError: + bz2 = None + +__all__ = ["BadZipfile", "error", "ZIP_STORED", "ZIP_DEFLATED", "is_zipfile", + "ZipInfo", "ZipFile", "PyZipFile", "LargeZipFile", "ZIP_BZIP2" ] + +class BadZipfile(Exception): + pass + + +class LargeZipFile(Exception): + """ + Raised when writing a zipfile, the zipfile requires ZIP64 extensions + and those extensions are disabled. + """ + +error = BadZipfile # The exception raised by this module + +ZIP64_LIMIT = (1 << 31) - 1 +ZIP_FILECOUNT_LIMIT = 1 << 16 +ZIP_MAX_COMMENT = (1 << 16) - 1 + +# constants for Zip file compression methods +ZIP_STORED = 0 +ZIP_DEFLATED = 8 +ZIP_BZIP2 = 12 +# Other ZIP compression methods not supported + +# Below are some formats and associated data for reading/writing headers using +# the struct module. The names and structures of headers/records are those used +# in the PKWARE description of the ZIP file format: +# http://www.pkware.com/documents/casestudies/APPNOTE.TXT +# (URL valid as of January 2008) + +# The "end of central directory" structure, magic number, size, and indices +# (section V.I in the format document) +structEndArchive = "<4s4H2LH" +stringEndArchive = "PK\005\006" +sizeEndCentDir = struct.calcsize(structEndArchive) + +_ECD_SIGNATURE = 0 +_ECD_DISK_NUMBER = 1 +_ECD_DISK_START = 2 +_ECD_ENTRIES_THIS_DISK = 3 +_ECD_ENTRIES_TOTAL = 4 +_ECD_SIZE = 5 +_ECD_OFFSET = 6 +_ECD_COMMENT_SIZE = 7 +# These last two indices are not part of the structure as defined in the +# spec, but they are used internally by this module as a convenience +_ECD_COMMENT = 8 +_ECD_LOCATION = 9 + +# The "central directory" structure, magic number, size, and indices +# of entries in the structure (section V.F in the format document) +structCentralDir = "<4s4B4HL2L5H2L" +stringCentralDir = "PK\001\002" +sizeCentralDir = struct.calcsize(structCentralDir) + +# indexes of entries in the central directory structure +_CD_SIGNATURE = 0 +_CD_CREATE_VERSION = 1 +_CD_CREATE_SYSTEM = 2 +_CD_EXTRACT_VERSION = 3 +_CD_EXTRACT_SYSTEM = 4 +_CD_FLAG_BITS = 5 +_CD_COMPRESS_TYPE = 6 +_CD_TIME = 7 +_CD_DATE = 8 +_CD_CRC = 9 +_CD_COMPRESSED_SIZE = 10 +_CD_UNCOMPRESSED_SIZE = 11 +_CD_FILENAME_LENGTH = 12 +_CD_EXTRA_FIELD_LENGTH = 13 +_CD_COMMENT_LENGTH = 14 +_CD_DISK_NUMBER_START = 15 +_CD_INTERNAL_FILE_ATTRIBUTES = 16 +_CD_EXTERNAL_FILE_ATTRIBUTES = 17 +_CD_LOCAL_HEADER_OFFSET = 18 + +# The "local file header" structure, magic number, size, and indices +# (section V.A in the format document) +structFileHeader = "<4s2B4HL2L2H" +stringFileHeader = "PK\003\004" +sizeFileHeader = struct.calcsize(structFileHeader) + +_FH_SIGNATURE = 0 +_FH_EXTRACT_VERSION = 1 +_FH_EXTRACT_SYSTEM = 2 +_FH_GENERAL_PURPOSE_FLAG_BITS = 3 +_FH_COMPRESSION_METHOD = 4 +_FH_LAST_MOD_TIME = 5 +_FH_LAST_MOD_DATE = 6 +_FH_CRC = 7 +_FH_COMPRESSED_SIZE = 8 +_FH_UNCOMPRESSED_SIZE = 9 +_FH_FILENAME_LENGTH = 10 +_FH_EXTRA_FIELD_LENGTH = 11 + +# The "Zip64 end of central directory locator" structure, magic number, and size +structEndArchive64Locator = "<4sLQL" +stringEndArchive64Locator = "PK\x06\x07" +sizeEndCentDir64Locator = struct.calcsize(structEndArchive64Locator) + +# The "Zip64 end of central directory" record, magic number, size, and indices +# (section V.G in the format document) +structEndArchive64 = "<4sQ2H2L4Q" +stringEndArchive64 = "PK\x06\x06" +sizeEndCentDir64 = struct.calcsize(structEndArchive64) + +_CD64_SIGNATURE = 0 +_CD64_DIRECTORY_RECSIZE = 1 +_CD64_CREATE_VERSION = 2 +_CD64_EXTRACT_VERSION = 3 +_CD64_DISK_NUMBER = 4 +_CD64_DISK_NUMBER_START = 5 +_CD64_NUMBER_ENTRIES_THIS_DISK = 6 +_CD64_NUMBER_ENTRIES_TOTAL = 7 +_CD64_DIRECTORY_SIZE = 8 +_CD64_OFFSET_START_CENTDIR = 9 + +def _check_zipfile(fp): + try: + if _EndRecData(fp): + return True # file has correct magic number + except IOError: + pass + return False + +def is_zipfile(filename): + """Quickly see if a file is a ZIP file by checking the magic number. + + The filename argument may be a file or file-like object too. + """ + result = False + try: + if hasattr(filename, "read"): + result = _check_zipfile(fp=filename) + else: + with open(filename, "rb") as fp: + result = _check_zipfile(fp) + except IOError: + pass + return result + +def _EndRecData64(fpin, offset, endrec): + """ + Read the ZIP64 end-of-archive records and use that to update endrec + """ + fpin.seek(offset - sizeEndCentDir64Locator, 2) + data = fpin.read(sizeEndCentDir64Locator) + sig, diskno, reloff, disks = struct.unpack(structEndArchive64Locator, data) + if sig != stringEndArchive64Locator: + return endrec + + if diskno != 0 or disks != 1: + raise BadZipfile("zipfiles that span multiple disks are not supported") + + # Assume no 'zip64 extensible data' + fpin.seek(offset - sizeEndCentDir64Locator - sizeEndCentDir64, 2) + data = fpin.read(sizeEndCentDir64) + sig, sz, create_version, read_version, disk_num, disk_dir, \ + dircount, dircount2, dirsize, diroffset = \ + struct.unpack(structEndArchive64, data) + if sig != stringEndArchive64: + return endrec + + # Update the original endrec using data from the ZIP64 record + endrec[_ECD_SIGNATURE] = sig + endrec[_ECD_DISK_NUMBER] = disk_num + endrec[_ECD_DISK_START] = disk_dir + endrec[_ECD_ENTRIES_THIS_DISK] = dircount + endrec[_ECD_ENTRIES_TOTAL] = dircount2 + endrec[_ECD_SIZE] = dirsize + endrec[_ECD_OFFSET] = diroffset + return endrec + + +def _EndRecData(fpin): + """Return data from the "End of Central Directory" record, or None. + + The data is a list of the nine items in the ZIP "End of central dir" + record followed by a tenth item, the file seek offset of this record.""" + + # Determine file size + fpin.seek(0, 2) + filesize = fpin.tell() + + # Check to see if this is ZIP file with no archive comment (the + # "end of central directory" structure should be the last item in the + # file if this is the case). + try: + fpin.seek(-sizeEndCentDir, 2) + except IOError: + return None + data = fpin.read() + if data[0:4] == stringEndArchive and data[-2:] == "\000\000": + # the signature is correct and there's no comment, unpack structure + endrec = struct.unpack(structEndArchive, data) + endrec=list(endrec) + + # Append a blank comment and record start offset + endrec.append("") + endrec.append(filesize - sizeEndCentDir) + + # Try to read the "Zip64 end of central directory" structure + return _EndRecData64(fpin, -sizeEndCentDir, endrec) + + # Either this is not a ZIP file, or it is a ZIP file with an archive + # comment. Search the end of the file for the "end of central directory" + # record signature. The comment is the last item in the ZIP file and may be + # up to 64K long. It is assumed that the "end of central directory" magic + # number does not appear in the comment. + maxCommentStart = max(filesize - (1 << 16) - sizeEndCentDir, 0) + fpin.seek(maxCommentStart, 0) + data = fpin.read() + start = data.rfind(stringEndArchive) + if start >= 0: + # found the magic number; attempt to unpack and interpret + recData = data[start:start+sizeEndCentDir] + endrec = list(struct.unpack(structEndArchive, recData)) + comment = data[start+sizeEndCentDir:] + # check that comment length is correct + if endrec[_ECD_COMMENT_SIZE] == len(comment): + # Append the archive comment and start offset + endrec.append(comment) + endrec.append(maxCommentStart + start) + + # Try to read the "Zip64 end of central directory" structure + return _EndRecData64(fpin, maxCommentStart + start - filesize, + endrec) + + # Unable to find a valid end of central directory structure + return + + +class ZipInfo (object): + """Class with attributes describing each file in the ZIP archive.""" + + __slots__ = ( + 'orig_filename', + 'filename', + 'date_time', + 'compress_type', + 'comment', + 'extra', + 'create_system', + 'create_version', + 'extract_version', + 'reserved', + 'flag_bits', + 'volume', + 'internal_attr', + 'external_attr', + 'header_offset', + 'CRC', + 'compress_size', + 'file_size', + '_raw_time', + ) + + def __init__(self, filename="NoName", date_time=(1980,1,1,0,0,0)): + self.orig_filename = filename # Original file name in archive + + # Terminate the file name at the first null byte. Null bytes in file + # names are used as tricks by viruses in archives. + null_byte = filename.find(chr(0)) + if null_byte >= 0: + filename = filename[0:null_byte] + # This is used to ensure paths in generated ZIP files always use + # forward slashes as the directory separator, as required by the + # ZIP format specification. + if os.sep != "/" and os.sep in filename: + filename = filename.replace(os.sep, "/") + + self.filename = filename # Normalized file name + self.date_time = date_time # year, month, day, hour, min, sec + # Standard values: + self.compress_type = ZIP_STORED # Type of compression for the file + self.comment = "" # Comment for each file + self.extra = "" # ZIP extra data + if sys.platform == 'win32': + self.create_system = 0 # System which created ZIP archive + else: + # Assume everything else is unix-y + self.create_system = 3 # System which created ZIP archive + self.create_version = 20 # Version which created ZIP archive + self.extract_version = 20 # Version needed to extract archive + self.reserved = 0 # Must be zero + self.flag_bits = 0 # ZIP flag bits + self.volume = 0 # Volume number of file header + self.internal_attr = 0 # Internal attributes + self.external_attr = 0 # External file attributes + # Other attributes are set by class ZipFile: + # header_offset Byte offset to the file header + # CRC CRC-32 of the uncompressed file + # compress_size Size of the compressed file + # file_size Size of the uncompressed file + + def FileHeader(self): + """Return the per-file header as a string.""" + dt = self.date_time + dosdate = (dt[0] - 1980) << 9 | dt[1] << 5 | dt[2] + dostime = dt[3] << 11 | dt[4] << 5 | (dt[5] // 2) + if self.flag_bits & 0x08: + # Set these to zero because we write them after the file data + CRC = compress_size = file_size = 0 + else: + CRC = self.CRC + compress_size = self.compress_size + file_size = self.file_size + + extra = self.extra + + if file_size > ZIP64_LIMIT or compress_size > ZIP64_LIMIT: + # File is larger than what fits into a 4 byte integer, + # fall back to the ZIP64 extension + fmt = '<HHQQ' + extra = extra + struct.pack(fmt, + 1, struct.calcsize(fmt)-4, file_size, compress_size) + file_size = 0xffffffff + compress_size = 0xffffffff + self.extract_version = max(45, self.extract_version) + self.create_version = max(45, self.extract_version) + + filename, flag_bits = self._encodeFilenameFlags() + header = struct.pack(structFileHeader, stringFileHeader, + self.extract_version, self.reserved, flag_bits, + self.compress_type, dostime, dosdate, CRC, + compress_size, file_size, + len(filename), len(extra)) + return header + filename + extra + + def _encodeFilenameFlags(self): + if isinstance(self.filename, unicode): + try: + return self.filename.encode('ascii'), self.flag_bits + except UnicodeEncodeError: + return self.filename.encode('utf-8'), self.flag_bits | 0x800 + else: + return self.filename, self.flag_bits + + def _decodeFilename(self): + if self.flag_bits & 0x800: + return self.filename.decode('utf-8') + else: + return self.filename + + def _decodeExtra(self): + # Try to decode the extra field. + extra = self.extra + unpack = struct.unpack + while extra: + tp, ln = unpack('<HH', extra[:4]) + if tp == 1: + if ln >= 24: + counts = unpack('<QQQ', extra[4:28]) + elif ln == 16: + counts = unpack('<QQ', extra[4:20]) + elif ln == 8: + counts = unpack('<Q', extra[4:12]) + elif ln == 0: + counts = () + else: + raise RuntimeError, "Corrupt extra field %s"%(ln,) + + idx = 0 + + # ZIP64 extension (large files and/or large archives) + if self.file_size in (0xffffffffffffffffL, 0xffffffffL): + self.file_size = counts[idx] + idx += 1 + + if self.compress_size == 0xFFFFFFFFL: + self.compress_size = counts[idx] + idx += 1 + + if self.header_offset == 0xffffffffL: + old = self.header_offset + self.header_offset = counts[idx] + idx+=1 + + extra = extra[ln+4:] + + +class _ZipDecrypter: + """Class to handle decryption of files stored within a ZIP archive. + + ZIP supports a password-based form of encryption. Even though known + plaintext attacks have been found against it, it is still useful + to be able to get data out of such a file. + + Usage: + zd = _ZipDecrypter(mypwd) + plain_char = zd(cypher_char) + plain_text = map(zd, cypher_text) + """ + + def _GenerateCRCTable(): + """Generate a CRC-32 table. + + ZIP encryption uses the CRC32 one-byte primitive for scrambling some + internal keys. We noticed that a direct implementation is faster than + relying on binascii.crc32(). + """ + poly = 0xedb88320 + table = [0] * 256 + for i in range(256): + crc = i + for j in range(8): + if crc & 1: + crc = ((crc >> 1) & 0x7FFFFFFF) ^ poly + else: + crc = ((crc >> 1) & 0x7FFFFFFF) + table[i] = crc + return table + crctable = _GenerateCRCTable() + + def _crc32(self, ch, crc): + """Compute the CRC32 primitive on one byte.""" + return ((crc >> 8) & 0xffffff) ^ self.crctable[(crc ^ ord(ch)) & 0xff] + + def __init__(self, pwd): + self.key0 = 305419896 + self.key1 = 591751049 + self.key2 = 878082192 + for p in pwd: + self._UpdateKeys(p) + + def _UpdateKeys(self, c): + self.key0 = self._crc32(c, self.key0) + self.key1 = (self.key1 + (self.key0 & 255)) & 4294967295 + self.key1 = (self.key1 * 134775813 + 1) & 4294967295 + self.key2 = self._crc32(chr((self.key1 >> 24) & 255), self.key2) + + def __call__(self, c): + """Decrypt a single character.""" + c = ord(c) + k = self.key2 | 2 + c = c ^ (((k * (k^1)) >> 8) & 255) + c = chr(c) + self._UpdateKeys(c) + return c + +class ZipExtFile(io.BufferedIOBase): + """File-like object for reading an archive member. + Is returned by ZipFile.open(). + """ + + # Max size supported by decompressor. + MAX_N = 1 << 31 - 1 + + # Read from compressed files in 4k blocks. + MIN_READ_SIZE = 4096 + + # Search for universal newlines or line chunks. + PATTERN = re.compile(r'^(?P<chunk>[^\r\n]+)|(?P<newline>\n|\r\n?)') + + def __init__(self, fileobj, mode, zipinfo, decrypter=None): + self._fileobj = fileobj + self._decrypter = decrypter + + self._compress_type = zipinfo.compress_type + self._compress_size = zipinfo.compress_size + self._compress_left = zipinfo.compress_size + + if self._compress_type == ZIP_DEFLATED: + self._decompressor = zlib.decompressobj(-15) + elif self._compress_type == ZIP_BZIP2: + self._decompressor = bz2.BZ2Decompressor() + self.MIN_READ_SIZE = 900000 + self._unconsumed = '' + + self._readbuffer = '' + self._offset = 0 + + self._universal = 'U' in mode + self.newlines = None + + # Adjust read size for encrypted files since the first 12 bytes + # are for the encryption/password information. + if self._decrypter is not None: + self._compress_left -= 12 + + self.mode = mode + self.name = zipinfo.filename + + def readline(self, limit=-1): + """Read and return a line from the stream. + + If limit is specified, at most limit bytes will be read. + """ + + if not self._universal and limit < 0: + # Shortcut common case - newline found in buffer. + i = self._readbuffer.find('\n', self._offset) + 1 + if i > 0: + line = self._readbuffer[self._offset: i] + self._offset = i + return line + + if not self._universal: + return io.BufferedIOBase.readline(self, limit) + + line = '' + while limit < 0 or len(line) < limit: + readahead = self.peek(2) + if readahead == '': + return line + + # + # Search for universal newlines or line chunks. + # + # The pattern returns either a line chunk or a newline, but not + # both. Combined with peek(2), we are assured that the sequence + # '\r\n' is always retrieved completely and never split into + # separate newlines - '\r', '\n' due to coincidental readaheads. + # + match = self.PATTERN.search(readahead) + newline = match.group('newline') + if newline is not None: + if self.newlines is None: + self.newlines = [] + if newline not in self.newlines: + self.newlines.append(newline) + self._offset += len(newline) + return line + '\n' + + chunk = match.group('chunk') + if limit >= 0: + chunk = chunk[: limit - len(line)] + + self._offset += len(chunk) + line += chunk + + return line + + def peek(self, n=1): + """Returns buffered bytes without advancing the position.""" + if n > len(self._readbuffer) - self._offset: + chunk = self.read(n) + self._offset -= len(chunk) + + # Return up to 512 bytes to reduce allocation overhead for tight loops. + return self._readbuffer[self._offset: self._offset + 512] + + def readable(self): + return True + + def read(self, n=-1): + """Read and return up to n bytes. + If the argument is omitted, None, or negative, data is read and returned until EOF is reached.. + """ + + buf = '' + while n < 0 or n is None or n > len(buf): + data = self.read1(n) + if len(data) == 0: + return buf + + buf += data + + return buf + + def read1(self, n): + """Read up to n bytes with at most one read() system call.""" + + # Simplify algorithm (branching) by transforming negative n to large n. + if n < 0 or n is None: + n = self.MAX_N + + # Bytes available in read buffer. + len_readbuffer = len(self._readbuffer) - self._offset + + # Read from file. + if self._compress_left > 0 and n > len_readbuffer + len(self._unconsumed): + nbytes = n - len_readbuffer - len(self._unconsumed) + nbytes = max(nbytes, self.MIN_READ_SIZE) + nbytes = min(nbytes, self._compress_left) + + data = self._fileobj.read(nbytes) + self._compress_left -= len(data) + + if data and self._decrypter is not None: + data = ''.join(map(self._decrypter, data)) + + if self._compress_type == ZIP_STORED: + self._readbuffer = self._readbuffer[self._offset:] + data + self._offset = 0 + else: + # Prepare deflated bytes for decompression. + self._unconsumed += data + + # Handle unconsumed data. + if (len(self._unconsumed) > 0 and n > len_readbuffer and + self._compress_type == ZIP_DEFLATED): + data = self._decompressor.decompress( + self._unconsumed, + max(n - len_readbuffer, self.MIN_READ_SIZE) + ) + + self._unconsumed = self._decompressor.unconsumed_tail + if len(self._unconsumed) == 0 and self._compress_left == 0: + data += self._decompressor.flush() + + self._readbuffer = self._readbuffer[self._offset:] + data + self._offset = 0 + elif (len(self._unconsumed) > 0 and n > len_readbuffer and + self._compress_type == ZIP_BZIP2): + data = self._decompressor.decompress(self._unconsumed) + + self._unconsumed = '' + self._readbuffer = self._readbuffer[self._offset:] + data + self._offset = 0 + + # Read from buffer. + data = self._readbuffer[self._offset: self._offset + n] + self._offset += len(data) + return data + + + +class ZipFile: + """ Class with methods to open, read, write, close, list zip files. + + z = ZipFile(file, mode="r", compression=ZIP_STORED, allowZip64=False) + + file: Either the path to the file, or a file-like object. + If it is a path, the file will be opened and closed by ZipFile. + mode: The mode can be either read "r", write "w" or append "a". + compression: ZIP_STORED (no compression), ZIP_DEFLATED (requires zlib), + or ZIP_BZIP2 (requires bz2). + allowZip64: if True ZipFile will create files with ZIP64 extensions when + needed, otherwise it will raise an exception when this would + be necessary. + + """ + + fp = None # Set here since __del__ checks it + + def __init__(self, file, mode="r", compression=ZIP_STORED, allowZip64=False): + """Open the ZIP file with mode read "r", write "w" or append "a".""" + if mode not in ("r", "w", "a"): + raise RuntimeError('ZipFile() requires mode "r", "w", or "a"') + + if compression == ZIP_STORED: + pass + elif compression == ZIP_DEFLATED: + if not zlib: + raise RuntimeError,\ + "Compression requires the (missing) zlib module" + elif compression == ZIP_BZIP2: + if not bz2: + raise RuntimeError,\ + "Compression requires the (missing) bz2 module" + else: + raise RuntimeError, "That compression method is not supported" + + self._allowZip64 = allowZip64 + self._didModify = False + self.debug = 0 # Level of printing: 0 through 3 + self.NameToInfo = {} # Find file info given name + self.filelist = [] # List of ZipInfo instances for archive + self.compression = compression # Method of compression + self.mode = key = mode.replace('b', '')[0] + self.pwd = None + self.comment = '' + + # Check if we were passed a file-like object + if isinstance(file, basestring): + self._filePassed = 0 + self.filename = file + modeDict = {'r' : 'rb', 'w': 'wb', 'a' : 'r+b'} + try: + self.fp = open(file, modeDict[mode]) + except IOError: + if mode == 'a': + mode = key = 'w' + self.fp = open(file, modeDict[mode]) + else: + raise + else: + self._filePassed = 1 + self.fp = file + self.filename = getattr(file, 'name', None) + + if key == 'r': + self._GetContents() + elif key == 'w': + pass + elif key == 'a': + try: # See if file is a zip file + self._RealGetContents() + # seek to start of directory and overwrite + self.fp.seek(self.start_dir, 0) + except BadZipfile: # file is not a zip file, just append + self.fp.seek(0, 2) + else: + if not self._filePassed: + self.fp.close() + self.fp = None + raise RuntimeError, 'Mode must be "r", "w" or "a"' + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() + + def _GetContents(self): + """Read the directory, making sure we close the file if the format + is bad.""" + try: + self._RealGetContents() + except BadZipfile: + if not self._filePassed: + self.fp.close() + self.fp = None + raise + + def _RealGetContents(self): + """Read in the table of contents for the ZIP file.""" + fp = self.fp + endrec = _EndRecData(fp) + if not endrec: + raise BadZipfile, "File is not a zip file" + if self.debug > 1: + print endrec + size_cd = endrec[_ECD_SIZE] # bytes in central directory + offset_cd = endrec[_ECD_OFFSET] # offset of central directory + self.comment = endrec[_ECD_COMMENT] # archive comment + + # "concat" is zero, unless zip was concatenated to another file + concat = endrec[_ECD_LOCATION] - size_cd - offset_cd + if endrec[_ECD_SIGNATURE] == stringEndArchive64: + # If Zip64 extension structures are present, account for them + concat -= (sizeEndCentDir64 + sizeEndCentDir64Locator) + + if self.debug > 2: + inferred = concat + offset_cd + print "given, inferred, offset", offset_cd, inferred, concat + # self.start_dir: Position of start of central directory + self.start_dir = offset_cd + concat + fp.seek(self.start_dir, 0) + data = fp.read(size_cd) + fp = cStringIO.StringIO(data) + total = 0 + while total < size_cd: + centdir = fp.read(sizeCentralDir) + if centdir[0:4] != stringCentralDir: + raise BadZipfile, "Bad magic number for central directory" + centdir = struct.unpack(structCentralDir, centdir) + if self.debug > 2: + print centdir + filename = fp.read(centdir[_CD_FILENAME_LENGTH]) + # Create ZipInfo instance to store file information + x = ZipInfo(filename) + x.extra = fp.read(centdir[_CD_EXTRA_FIELD_LENGTH]) + x.comment = fp.read(centdir[_CD_COMMENT_LENGTH]) + x.header_offset = centdir[_CD_LOCAL_HEADER_OFFSET] + (x.create_version, x.create_system, x.extract_version, x.reserved, + x.flag_bits, x.compress_type, t, d, + x.CRC, x.compress_size, x.file_size) = centdir[1:12] + x.volume, x.internal_attr, x.external_attr = centdir[15:18] + # Convert date/time code to (year, month, day, hour, min, sec) + x._raw_time = t + x.date_time = ( (d>>9)+1980, (d>>5)&0xF, d&0x1F, + t>>11, (t>>5)&0x3F, (t&0x1F) * 2 ) + + x._decodeExtra() + x.header_offset = x.header_offset + concat + x.filename = x._decodeFilename() + self.filelist.append(x) + self.NameToInfo[x.filename] = x + + # update total bytes read from central directory + total = (total + sizeCentralDir + centdir[_CD_FILENAME_LENGTH] + + centdir[_CD_EXTRA_FIELD_LENGTH] + + centdir[_CD_COMMENT_LENGTH]) + + if self.debug > 2: + print "total", total + + + def namelist(self): + """Return a list of file names in the archive.""" + l = [] + for data in self.filelist: + l.append(data.filename) + return l + + def infolist(self): + """Return a list of class ZipInfo instances for files in the + archive.""" + return self.filelist + + def printdir(self): + """Print a table of contents for the zip file.""" + print "%-46s %19s %12s" % ("File Name", "Modified ", "Size") + for zinfo in self.filelist: + date = "%d-%02d-%02d %02d:%02d:%02d" % zinfo.date_time[:6] + print "%-46s %s %12d" % (zinfo.filename, date, zinfo.file_size) + + def testzip(self): + """Read all the files and check the CRC.""" + chunk_size = 2 ** 20 + for zinfo in self.filelist: + try: + # Read by chunks, to avoid an OverflowError or a + # MemoryError with very large embedded files. + f = self.open(zinfo.filename, "r") + while f.read(chunk_size): # Check CRC-32 + pass + except BadZipfile: + return zinfo.filename + + def getinfo(self, name): + """Return the instance of ZipInfo given 'name'.""" + info = self.NameToInfo.get(name) + if info is None: + raise KeyError( + 'There is no item named %r in the archive' % name) + + return info + + def setpassword(self, pwd): + """Set default password for encrypted files.""" + self.pwd = pwd + + def read(self, name, pwd=None): + """Return file bytes (as a string) for name.""" + return self.open(name, "r", pwd).read() + + def open(self, name, mode="r", pwd=None): + """Return file-like object for 'name'.""" + if mode not in ("r", "U", "rU"): + raise RuntimeError, 'open() requires mode "r", "U", or "rU"' + if not self.fp: + raise RuntimeError, \ + "Attempt to read ZIP archive that was already closed" + + # Only open a new file for instances where we were not + # given a file object in the constructor + if self._filePassed: + zef_file = self.fp + else: + zef_file = open(self.filename, 'rb') + + # Make sure we have an info object + if isinstance(name, ZipInfo): + # 'name' is already an info object + zinfo = name + else: + # Get info object for name + zinfo = self.getinfo(name) + + zef_file.seek(zinfo.header_offset, 0) + + # Skip the file header: + fheader = zef_file.read(sizeFileHeader) + if fheader[0:4] != stringFileHeader: + raise BadZipfile, "Bad magic number for file header" + + fheader = struct.unpack(structFileHeader, fheader) + fname = zef_file.read(fheader[_FH_FILENAME_LENGTH]) + if fheader[_FH_EXTRA_FIELD_LENGTH]: + zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH]) + + if fname != zinfo.orig_filename: + raise BadZipfile, \ + 'File name in directory "%s" and header "%s" differ.' % ( + zinfo.orig_filename, fname) + + # check for encrypted flag & handle password + is_encrypted = zinfo.flag_bits & 0x1 + zd = None + if is_encrypted: + if not pwd: + pwd = self.pwd + if not pwd: + raise RuntimeError, "File %s is encrypted, " \ + "password required for extraction" % name + + zd = _ZipDecrypter(pwd) + # The first 12 bytes in the cypher stream is an encryption header + # used to strengthen the algorithm. The first 11 bytes are + # completely random, while the 12th contains the MSB of the CRC, + # or the MSB of the file time depending on the header type + # and is used to check the correctness of the password. + bytes = zef_file.read(12) + h = map(zd, bytes[0:12]) + if zinfo.flag_bits & 0x8: + # compare against the file type from extended local headers + check_byte = (zinfo._raw_time >> 8) & 0xff + else: + # compare against the CRC otherwise + check_byte = (zinfo.CRC >> 24) & 0xff + if ord(h[11]) != check_byte: + raise RuntimeError("Bad password for file", name) + + return ZipExtFile(zef_file, mode, zinfo, zd) + + def extract(self, member, path=None, pwd=None): + """Extract a member from the archive to the current working directory, + using its full name. Its file information is extracted as accurately + as possible. `member' may be a filename or a ZipInfo object. You can + specify a different directory using `path'. + """ + if not isinstance(member, ZipInfo): + member = self.getinfo(member) + + if path is None: + path = os.getcwd() + + return self._extract_member(member, path, pwd) + + def extractall(self, path=None, members=None, pwd=None): + """Extract all members from the archive to the current working + directory. `path' specifies a different directory to extract to. + `members' is optional and must be a subset of the list returned + by namelist(). + """ + if members is None: + members = self.namelist() + + for zipinfo in members: + self.extract(zipinfo, path, pwd) + + def _extract_member(self, member, targetpath, pwd): + """Extract the ZipInfo object 'member' to a physical + file on the path targetpath. + """ + # build the destination pathname, replacing + # forward slashes to platform specific separators. + # Strip trailing path separator, unless it represents the root. + if (targetpath[-1:] in (os.path.sep, os.path.altsep) + and len(os.path.splitdrive(targetpath)[1]) > 1): + targetpath = targetpath[:-1] + + # don't include leading "/" from file name if present + if member.filename[0] == '/': + targetpath = os.path.join(targetpath, member.filename[1:]) + else: + targetpath = os.path.join(targetpath, member.filename) + + targetpath = os.path.normpath(targetpath) + + # Create all upper directories if necessary. + upperdirs = os.path.dirname(targetpath) + if upperdirs and not os.path.exists(upperdirs): + os.makedirs(upperdirs) + + if member.filename[-1] == '/': + if not os.path.isdir(targetpath): + os.mkdir(targetpath) + return targetpath + + source = self.open(member, pwd=pwd) + target = file(targetpath, "wb") + shutil.copyfileobj(source, target) + source.close() + target.close() + + return targetpath + + def _writecheck(self, zinfo): + """Check for errors before writing a file to the archive.""" + if zinfo.filename in self.NameToInfo: + if self.debug: # Warning for duplicate names + print "Duplicate name:", zinfo.filename + if self.mode not in ("w", "a"): + raise RuntimeError, 'write() requires mode "w" or "a"' + if not self.fp: + raise RuntimeError, \ + "Attempt to write ZIP archive that was already closed" + if zinfo.compress_type == ZIP_DEFLATED and not zlib: + raise RuntimeError, \ + "Compression requires the (missing) zlib module" + if zinfo.compress_type == ZIP_BZIP2 and not bz2: + raise RuntimeError, \ + "Compression requires the (missing) bz2 module" + if zinfo.compress_type not in (ZIP_STORED, ZIP_DEFLATED, ZIP_BZIP2): + raise RuntimeError, \ + "That compression method is not supported" + if zinfo.file_size > ZIP64_LIMIT: + if not self._allowZip64: + raise LargeZipFile("Filesize would require ZIP64 extensions") + if zinfo.header_offset > ZIP64_LIMIT: + if not self._allowZip64: + raise LargeZipFile("Zipfile size would require ZIP64 extensions") + + def write(self, filename, arcname=None, compress_type=None): + """Put the bytes from filename into the archive under the name + arcname.""" + if not self.fp: + raise RuntimeError( + "Attempt to write to ZIP archive that was already closed") + + st = os.stat(filename) + isdir = stat.S_ISDIR(st.st_mode) + mtime = time.localtime(st.st_mtime) + date_time = mtime[0:6] + # Create ZipInfo instance to store file information + if arcname is None: + arcname = filename + arcname = os.path.normpath(os.path.splitdrive(arcname)[1]) + while arcname[0] in (os.sep, os.altsep): + arcname = arcname[1:] + if isdir: + arcname += '/' + zinfo = ZipInfo(arcname, date_time) + zinfo.external_attr = (st[0] & 0xFFFF) << 16L # Unix attributes + if compress_type is None: + zinfo.compress_type = self.compression + else: + zinfo.compress_type = compress_type + + zinfo.file_size = st.st_size + zinfo.flag_bits = 0x00 + zinfo.header_offset = self.fp.tell() # Start of header bytes + + self._writecheck(zinfo) + self._didModify = True + + if isdir: + zinfo.file_size = 0 + zinfo.compress_size = 0 + zinfo.CRC = 0 + self.filelist.append(zinfo) + self.NameToInfo[zinfo.filename] = zinfo + self.fp.write(zinfo.FileHeader()) + return + + with open(filename, "rb") as fp: + # Must overwrite CRC and sizes with correct data later + zinfo.CRC = CRC = 0 + zinfo.compress_size = compress_size = 0 + zinfo.file_size = file_size = 0 + self.fp.write(zinfo.FileHeader()) + if zinfo.compress_type == ZIP_DEFLATED: + cmpr = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, + zlib.DEFLATED, -15) + elif zinfo.compress_type == ZIP_BZIP2: + cmpr = bz2.BZ2Compressor() + else: + cmpr = None + while 1: + buf = fp.read(1024 * 8) + if not buf: + break + file_size = file_size + len(buf) + CRC = crc32(buf, CRC) & 0xffffffff + if cmpr: + buf = cmpr.compress(buf) + compress_size = compress_size + len(buf) + self.fp.write(buf) + if cmpr: + buf = cmpr.flush() + compress_size = compress_size + len(buf) + self.fp.write(buf) + zinfo.compress_size = compress_size + else: + zinfo.compress_size = file_size + zinfo.CRC = CRC + zinfo.file_size = file_size + # Seek backwards and write CRC and file sizes + position = self.fp.tell() # Preserve current position in file + self.fp.seek(zinfo.header_offset + 14, 0) + self.fp.write(struct.pack("<LLL", zinfo.CRC, zinfo.compress_size, + zinfo.file_size)) + self.fp.seek(position, 0) + self.filelist.append(zinfo) + self.NameToInfo[zinfo.filename] = zinfo + + def writestr(self, zinfo_or_arcname, bytes, compress_type=None): + """Write a file into the archive. The contents is the string + 'bytes'. 'zinfo_or_arcname' is either a ZipInfo instance or + the name of the file in the archive.""" + if not isinstance(zinfo_or_arcname, ZipInfo): + zinfo = ZipInfo(filename=zinfo_or_arcname, + date_time=time.localtime(time.time())[:6]) + + zinfo.compress_type = self.compression + zinfo.external_attr = 0600 << 16 + else: + zinfo = zinfo_or_arcname + + if not self.fp: + raise RuntimeError( + "Attempt to write to ZIP archive that was already closed") + + if compress_type is not None: + zinfo.compress_type = compress_type + + zinfo.file_size = len(bytes) # Uncompressed size + zinfo.header_offset = self.fp.tell() # Start of header bytes + self._writecheck(zinfo) + self._didModify = True + zinfo.CRC = crc32(bytes) & 0xffffffff # CRC-32 checksum + if zinfo.compress_type == ZIP_DEFLATED: + co = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, + zlib.DEFLATED, -15) + bytes = co.compress(bytes) + co.flush() + zinfo.compress_size = len(bytes) # Compressed size + elif zinfo.compress_type == ZIP_BZIP2: + co = bz2.BZ2Compressor() + bytes = co.compress(bytes) + co.flush() + zinfo.compress_size = len(bytes) # Compressed size + else: + zinfo.compress_size = zinfo.file_size + zinfo.header_offset = self.fp.tell() # Start of header bytes + self.fp.write(zinfo.FileHeader()) + self.fp.write(bytes) + self.fp.flush() + if zinfo.flag_bits & 0x08: + # Write CRC and file sizes after the file data + self.fp.write(struct.pack("<LLL", zinfo.CRC, zinfo.compress_size, + zinfo.file_size)) + self.filelist.append(zinfo) + self.NameToInfo[zinfo.filename] = zinfo + + def __del__(self): + """Call the "close()" method in case the user forgot.""" + self.close() + + def close(self): + """Close the file, and for mode "w" and "a" write the ending + records.""" + if self.fp is None: + return + + if self.mode in ("w", "a") and self._didModify: # write ending records + count = 0 + pos1 = self.fp.tell() + for zinfo in self.filelist: # write central directory + count = count + 1 + dt = zinfo.date_time + dosdate = (dt[0] - 1980) << 9 | dt[1] << 5 | dt[2] + dostime = dt[3] << 11 | dt[4] << 5 | (dt[5] // 2) + extra = [] + if zinfo.file_size > ZIP64_LIMIT \ + or zinfo.compress_size > ZIP64_LIMIT: + extra.append(zinfo.file_size) + extra.append(zinfo.compress_size) + file_size = 0xffffffff + compress_size = 0xffffffff + else: + file_size = zinfo.file_size + compress_size = zinfo.compress_size + + if zinfo.header_offset > ZIP64_LIMIT: + extra.append(zinfo.header_offset) + header_offset = 0xffffffffL + else: + header_offset = zinfo.header_offset + + extra_data = zinfo.extra + if extra: + # Append a ZIP64 field to the extra's + extra_data = struct.pack( + '<HH' + 'Q'*len(extra), + 1, 8*len(extra), *extra) + extra_data + + extract_version = max(45, zinfo.extract_version) + create_version = max(45, zinfo.create_version) + else: + extract_version = zinfo.extract_version + create_version = zinfo.create_version + + try: + filename, flag_bits = zinfo._encodeFilenameFlags() + centdir = struct.pack(structCentralDir, + stringCentralDir, create_version, + zinfo.create_system, extract_version, zinfo.reserved, + flag_bits, zinfo.compress_type, dostime, dosdate, + zinfo.CRC, compress_size, file_size, + len(filename), len(extra_data), len(zinfo.comment), + 0, zinfo.internal_attr, zinfo.external_attr, + header_offset) + except DeprecationWarning: + print >>sys.stderr, (structCentralDir, + stringCentralDir, create_version, + zinfo.create_system, extract_version, zinfo.reserved, + zinfo.flag_bits, zinfo.compress_type, dostime, dosdate, + zinfo.CRC, compress_size, file_size, + len(zinfo.filename), len(extra_data), len(zinfo.comment), + 0, zinfo.internal_attr, zinfo.external_attr, + header_offset) + raise + self.fp.write(centdir) + self.fp.write(filename) + self.fp.write(extra_data) + self.fp.write(zinfo.comment) + + pos2 = self.fp.tell() + # Write end-of-zip-archive record + centDirCount = count + centDirSize = pos2 - pos1 + centDirOffset = pos1 + if (centDirCount >= ZIP_FILECOUNT_LIMIT or + centDirOffset > ZIP64_LIMIT or + centDirSize > ZIP64_LIMIT): + # Need to write the ZIP64 end-of-archive records + zip64endrec = struct.pack( + structEndArchive64, stringEndArchive64, + 44, 45, 45, 0, 0, centDirCount, centDirCount, + centDirSize, centDirOffset) + self.fp.write(zip64endrec) + + zip64locrec = struct.pack( + structEndArchive64Locator, + stringEndArchive64Locator, 0, pos2, 1) + self.fp.write(zip64locrec) + centDirCount = min(centDirCount, 0xFFFF) + centDirSize = min(centDirSize, 0xFFFFFFFF) + centDirOffset = min(centDirOffset, 0xFFFFFFFF) + + # check for valid comment length + if len(self.comment) >= ZIP_MAX_COMMENT: + if self.debug > 0: + msg = 'Archive comment is too long; truncating to %d bytes' \ + % ZIP_MAX_COMMENT + self.comment = self.comment[:ZIP_MAX_COMMENT] + + endrec = struct.pack(structEndArchive, stringEndArchive, + 0, 0, centDirCount, centDirCount, + centDirSize, centDirOffset, len(self.comment)) + self.fp.write(endrec) + self.fp.write(self.comment) + self.fp.flush() + + if not self._filePassed: + self.fp.close() + self.fp = None + + +class PyZipFile(ZipFile): + """Class to create ZIP archives with Python library files and packages.""" + + def writepy(self, pathname, basename = ""): + """Add all files from "pathname" to the ZIP archive. + + If pathname is a package directory, search the directory and + all package subdirectories recursively for all *.py and enter + the modules into the archive. If pathname is a plain + directory, listdir *.py and enter all modules. Else, pathname + must be a Python *.py file and the module will be put into the + archive. Added modules are always module.pyo or module.pyc. + This method will compile the module.py into module.pyc if + necessary. + """ + dir, name = os.path.split(pathname) + if os.path.isdir(pathname): + initname = os.path.join(pathname, "__init__.py") + if os.path.isfile(initname): + # This is a package directory, add it + if basename: + basename = "%s/%s" % (basename, name) + else: + basename = name + if self.debug: + print "Adding package in", pathname, "as", basename + fname, arcname = self._get_codename(initname[0:-3], basename) + if self.debug: + print "Adding", arcname + self.write(fname, arcname) + dirlist = os.listdir(pathname) + dirlist.remove("__init__.py") + # Add all *.py files and package subdirectories + for filename in dirlist: + path = os.path.join(pathname, filename) + root, ext = os.path.splitext(filename) + if os.path.isdir(path): + if os.path.isfile(os.path.join(path, "__init__.py")): + # This is a package directory, add it + self.writepy(path, basename) # Recursive call + elif ext == ".py": + fname, arcname = self._get_codename(path[0:-3], + basename) + if self.debug: + print "Adding", arcname + self.write(fname, arcname) + else: + # This is NOT a package directory, add its files at top level + if self.debug: + print "Adding files from directory", pathname + for filename in os.listdir(pathname): + path = os.path.join(pathname, filename) + root, ext = os.path.splitext(filename) + if ext == ".py": + fname, arcname = self._get_codename(path[0:-3], + basename) + if self.debug: + print "Adding", arcname + self.write(fname, arcname) + else: + if pathname[-3:] != ".py": + raise RuntimeError, \ + 'Files added with writepy() must end with ".py"' + fname, arcname = self._get_codename(pathname[0:-3], basename) + if self.debug: + print "Adding file", arcname + self.write(fname, arcname) + + def _get_codename(self, pathname, basename): + """Return (filename, archivename) for the path. + + Given a module name path, return the correct file path and + archive name, compiling if necessary. For example, given + /python/lib/string, return (/python/lib/string.pyc, string). + """ + file_py = pathname + ".py" + file_pyc = pathname + ".pyc" + file_pyo = pathname + ".pyo" + if os.path.isfile(file_pyo) and \ + os.stat(file_pyo).st_mtime >= os.stat(file_py).st_mtime: + fname = file_pyo # Use .pyo file + elif not os.path.isfile(file_pyc) or \ + os.stat(file_pyc).st_mtime < os.stat(file_py).st_mtime: + import py_compile + if self.debug: + print "Compiling", file_py + try: + py_compile.compile(file_py, file_pyc, None, True) + except py_compile.PyCompileError,err: + print err.msg + fname = file_pyc + else: + fname = file_pyc + archivename = os.path.split(fname)[1] + if basename: + archivename = "%s/%s" % (basename, archivename) + return (fname, archivename) + + +def main(args = None): + import textwrap + USAGE=textwrap.dedent("""\ + Usage: + zipfile.py -l zipfile.zip # Show listing of a zipfile + zipfile.py -t zipfile.zip # Test if a zipfile is valid + zipfile.py -e zipfile.zip target # Extract zipfile into target dir + zipfile.py -c zipfile.zip src ... # Create zipfile from sources + """) + if args is None: + args = sys.argv[1:] + + if not args or args[0] not in ('-l', '-c', '-e', '-t'): + print USAGE + sys.exit(1) + + if args[0] == '-l': + if len(args) != 2: + print USAGE + sys.exit(1) + zf = ZipFile(args[1], 'r') + zf.printdir() + zf.close() + + elif args[0] == '-t': + if len(args) != 2: + print USAGE + sys.exit(1) + zf = ZipFile(args[1], 'r') + zf.testzip() + print "Done testing" + + elif args[0] == '-e': + if len(args) != 3: + print USAGE + sys.exit(1) + + zf = ZipFile(args[1], 'r') + out = args[2] + for path in zf.namelist(): + if path.startswith('./'): + tgt = os.path.join(out, path[2:]) + else: + tgt = os.path.join(out, path) + + tgtdir = os.path.dirname(tgt) + if not os.path.exists(tgtdir): + os.makedirs(tgtdir) + with open(tgt, 'wb') as fp: + fp.write(zf.read(path)) + zf.close() + + elif args[0] == '-c': + if len(args) < 3: + print USAGE + sys.exit(1) + + def addToZip(zf, path, zippath): + if os.path.isfile(path): + zf.write(path, zippath, ZIP_DEFLATED) + elif os.path.isdir(path): + for nm in os.listdir(path): + addToZip(zf, + os.path.join(path, nm), os.path.join(zippath, nm)) + # else: ignore + + zf = ZipFile(args[1], 'w', allowZip64=True) + for src in args[2:]: + addToZip(zf, src, os.path.basename(src)) + + zf.close() + +if __name__ == "__main__": + main()
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/2.00/zipfile3.py Fri Aug 06 15:39:29 2010 +0000 @@ -0,0 +1,1454 @@ +""" +Read and write ZIP files. + +XXX references to utf-8 need further investigation. +""" +# Improved by Chortos-2 in 2010 (added bzip2 support) +import struct, os, time, sys, shutil +import binascii, io, stat + +try: + import zlib # We may need its compression method + crc32 = zlib.crc32 +except ImportError: + zlib = None + crc32 = binascii.crc32 + +try: + import bz2 # We may need its compression method +except ImportError: + bz2 = None + +__all__ = ["BadZipfile", "error", "ZIP_STORED", "ZIP_DEFLATED", "is_zipfile", + "ZipInfo", "ZipFile", "PyZipFile", "LargeZipFile", "ZIP_BZIP2" ] + +class BadZipfile(Exception): + pass + + +class LargeZipFile(Exception): + """ + Raised when writing a zipfile, the zipfile requires ZIP64 extensions + and those extensions are disabled. + """ + +error = BadZipfile # The exception raised by this module + +ZIP64_LIMIT = (1 << 31) - 1 +ZIP_FILECOUNT_LIMIT = 1 << 16 +ZIP_MAX_COMMENT = (1 << 16) - 1 + +# constants for Zip file compression methods +ZIP_STORED = 0 +ZIP_DEFLATED = 8 +ZIP_BZIP2 = 12 +# Other ZIP compression methods not supported + +# Below are some formats and associated data for reading/writing headers using +# the struct module. The names and structures of headers/records are those used +# in the PKWARE description of the ZIP file format: +# http://www.pkware.com/documents/casestudies/APPNOTE.TXT +# (URL valid as of January 2008) + +# The "end of central directory" structure, magic number, size, and indices +# (section V.I in the format document) +structEndArchive = b"<4s4H2LH" +stringEndArchive = b"PK\005\006" +sizeEndCentDir = struct.calcsize(structEndArchive) + +_ECD_SIGNATURE = 0 +_ECD_DISK_NUMBER = 1 +_ECD_DISK_START = 2 +_ECD_ENTRIES_THIS_DISK = 3 +_ECD_ENTRIES_TOTAL = 4 +_ECD_SIZE = 5 +_ECD_OFFSET = 6 +_ECD_COMMENT_SIZE = 7 +# These last two indices are not part of the structure as defined in the +# spec, but they are used internally by this module as a convenience +_ECD_COMMENT = 8 +_ECD_LOCATION = 9 + +# The "central directory" structure, magic number, size, and indices +# of entries in the structure (section V.F in the format document) +structCentralDir = "<4s4B4HL2L5H2L" +stringCentralDir = b"PK\001\002" +sizeCentralDir = struct.calcsize(structCentralDir) + +# indexes of entries in the central directory structure +_CD_SIGNATURE = 0 +_CD_CREATE_VERSION = 1 +_CD_CREATE_SYSTEM = 2 +_CD_EXTRACT_VERSION = 3 +_CD_EXTRACT_SYSTEM = 4 +_CD_FLAG_BITS = 5 +_CD_COMPRESS_TYPE = 6 +_CD_TIME = 7 +_CD_DATE = 8 +_CD_CRC = 9 +_CD_COMPRESSED_SIZE = 10 +_CD_UNCOMPRESSED_SIZE = 11 +_CD_FILENAME_LENGTH = 12 +_CD_EXTRA_FIELD_LENGTH = 13 +_CD_COMMENT_LENGTH = 14 +_CD_DISK_NUMBER_START = 15 +_CD_INTERNAL_FILE_ATTRIBUTES = 16 +_CD_EXTERNAL_FILE_ATTRIBUTES = 17 +_CD_LOCAL_HEADER_OFFSET = 18 + +# The "local file header" structure, magic number, size, and indices +# (section V.A in the format document) +structFileHeader = "<4s2B4HL2L2H" +stringFileHeader = b"PK\003\004" +sizeFileHeader = struct.calcsize(structFileHeader) + +_FH_SIGNATURE = 0 +_FH_EXTRACT_VERSION = 1 +_FH_EXTRACT_SYSTEM = 2 +_FH_GENERAL_PURPOSE_FLAG_BITS = 3 +_FH_COMPRESSION_METHOD = 4 +_FH_LAST_MOD_TIME = 5 +_FH_LAST_MOD_DATE = 6 +_FH_CRC = 7 +_FH_COMPRESSED_SIZE = 8 +_FH_UNCOMPRESSED_SIZE = 9 +_FH_FILENAME_LENGTH = 10 +_FH_EXTRA_FIELD_LENGTH = 11 + +# The "Zip64 end of central directory locator" structure, magic number, and size +structEndArchive64Locator = "<4sLQL" +stringEndArchive64Locator = b"PK\x06\x07" +sizeEndCentDir64Locator = struct.calcsize(structEndArchive64Locator) + +# The "Zip64 end of central directory" record, magic number, size, and indices +# (section V.G in the format document) +structEndArchive64 = "<4sQ2H2L4Q" +stringEndArchive64 = b"PK\x06\x06" +sizeEndCentDir64 = struct.calcsize(structEndArchive64) + +_CD64_SIGNATURE = 0 +_CD64_DIRECTORY_RECSIZE = 1 +_CD64_CREATE_VERSION = 2 +_CD64_EXTRACT_VERSION = 3 +_CD64_DISK_NUMBER = 4 +_CD64_DISK_NUMBER_START = 5 +_CD64_NUMBER_ENTRIES_THIS_DISK = 6 +_CD64_NUMBER_ENTRIES_TOTAL = 7 +_CD64_DIRECTORY_SIZE = 8 +_CD64_OFFSET_START_CENTDIR = 9 + +def _check_zipfile(fp): + try: + if _EndRecData(fp): + return True # file has correct magic number + except IOError: + pass + return False + +def is_zipfile(filename): + """Quickly see if a file is a ZIP file by checking the magic number. + + The filename argument may be a file or file-like object too. + """ + result = False + try: + if hasattr(filename, "read"): + result = _check_zipfile(fp=filename) + else: + with open(filename, "rb") as fp: + result = _check_zipfile(fp) + except IOError: + pass + return result + +def _EndRecData64(fpin, offset, endrec): + """ + Read the ZIP64 end-of-archive records and use that to update endrec + """ + fpin.seek(offset - sizeEndCentDir64Locator, 2) + data = fpin.read(sizeEndCentDir64Locator) + sig, diskno, reloff, disks = struct.unpack(structEndArchive64Locator, data) + if sig != stringEndArchive64Locator: + return endrec + + if diskno != 0 or disks != 1: + raise BadZipfile("zipfiles that span multiple disks are not supported") + + # Assume no 'zip64 extensible data' + fpin.seek(offset - sizeEndCentDir64Locator - sizeEndCentDir64, 2) + data = fpin.read(sizeEndCentDir64) + sig, sz, create_version, read_version, disk_num, disk_dir, \ + dircount, dircount2, dirsize, diroffset = \ + struct.unpack(structEndArchive64, data) + if sig != stringEndArchive64: + return endrec + + # Update the original endrec using data from the ZIP64 record + endrec[_ECD_SIGNATURE] = sig + endrec[_ECD_DISK_NUMBER] = disk_num + endrec[_ECD_DISK_START] = disk_dir + endrec[_ECD_ENTRIES_THIS_DISK] = dircount + endrec[_ECD_ENTRIES_TOTAL] = dircount2 + endrec[_ECD_SIZE] = dirsize + endrec[_ECD_OFFSET] = diroffset + return endrec + + +def _EndRecData(fpin): + """Return data from the "End of Central Directory" record, or None. + + The data is a list of the nine items in the ZIP "End of central dir" + record followed by a tenth item, the file seek offset of this record.""" + + # Determine file size + fpin.seek(0, 2) + filesize = fpin.tell() + + # Check to see if this is ZIP file with no archive comment (the + # "end of central directory" structure should be the last item in the + # file if this is the case). + try: + fpin.seek(-sizeEndCentDir, 2) + except IOError: + return None + data = fpin.read() + if data[0:4] == stringEndArchive and data[-2:] == b"\000\000": + # the signature is correct and there's no comment, unpack structure + endrec = struct.unpack(structEndArchive, data) + endrec=list(endrec) + + # Append a blank comment and record start offset + endrec.append(b"") + endrec.append(filesize - sizeEndCentDir) + + # Try to read the "Zip64 end of central directory" structure + return _EndRecData64(fpin, -sizeEndCentDir, endrec) + + # Either this is not a ZIP file, or it is a ZIP file with an archive + # comment. Search the end of the file for the "end of central directory" + # record signature. The comment is the last item in the ZIP file and may be + # up to 64K long. It is assumed that the "end of central directory" magic + # number does not appear in the comment. + maxCommentStart = max(filesize - (1 << 16) - sizeEndCentDir, 0) + fpin.seek(maxCommentStart, 0) + data = fpin.read() + start = data.rfind(stringEndArchive) + if start >= 0: + # found the magic number; attempt to unpack and interpret + recData = data[start:start+sizeEndCentDir] + endrec = list(struct.unpack(structEndArchive, recData)) + comment = data[start+sizeEndCentDir:] + # check that comment length is correct + if endrec[_ECD_COMMENT_SIZE] == len(comment): + # Append the archive comment and start offset + endrec.append(comment) + endrec.append(maxCommentStart + start) + + # Try to read the "Zip64 end of central directory" structure + return _EndRecData64(fpin, maxCommentStart + start - filesize, + endrec) + + # Unable to find a valid end of central directory structure + return + + +class ZipInfo (object): + """Class with attributes describing each file in the ZIP archive.""" + + __slots__ = ( + 'orig_filename', + 'filename', + 'date_time', + 'compress_type', + 'comment', + 'extra', + 'create_system', + 'create_version', + 'extract_version', + 'reserved', + 'flag_bits', + 'volume', + 'internal_attr', + 'external_attr', + 'header_offset', + 'CRC', + 'compress_size', + 'file_size', + '_raw_time', + ) + + def __init__(self, filename="NoName", date_time=(1980,1,1,0,0,0)): + self.orig_filename = filename # Original file name in archive + + # Terminate the file name at the first null byte. Null bytes in file + # names are used as tricks by viruses in archives. + null_byte = filename.find(chr(0)) + if null_byte >= 0: + filename = filename[0:null_byte] + # This is used to ensure paths in generated ZIP files always use + # forward slashes as the directory separator, as required by the + # ZIP format specification. + if os.sep != "/" and os.sep in filename: + filename = filename.replace(os.sep, "/") + + self.filename = filename # Normalized file name + self.date_time = date_time # year, month, day, hour, min, sec + # Standard values: + self.compress_type = ZIP_STORED # Type of compression for the file + self.comment = b"" # Comment for each file + self.extra = b"" # ZIP extra data + if sys.platform == 'win32': + self.create_system = 0 # System which created ZIP archive + else: + # Assume everything else is unix-y + self.create_system = 3 # System which created ZIP archive + self.create_version = 20 # Version which created ZIP archive + self.extract_version = 20 # Version needed to extract archive + self.reserved = 0 # Must be zero + self.flag_bits = 0 # ZIP flag bits + self.volume = 0 # Volume number of file header + self.internal_attr = 0 # Internal attributes + self.external_attr = 0 # External file attributes + # Other attributes are set by class ZipFile: + # header_offset Byte offset to the file header + # CRC CRC-32 of the uncompressed file + # compress_size Size of the compressed file + # file_size Size of the uncompressed file + + def FileHeader(self): + """Return the per-file header as a string.""" + dt = self.date_time + dosdate = (dt[0] - 1980) << 9 | dt[1] << 5 | dt[2] + dostime = dt[3] << 11 | dt[4] << 5 | (dt[5] // 2) + if self.flag_bits & 0x08: + # Set these to zero because we write them after the file data + CRC = compress_size = file_size = 0 + else: + CRC = self.CRC + compress_size = self.compress_size + file_size = self.file_size + + extra = self.extra + + if file_size > ZIP64_LIMIT or compress_size > ZIP64_LIMIT: + # File is larger than what fits into a 4 byte integer, + # fall back to the ZIP64 extension + fmt = '<HHQQ' + extra = extra + struct.pack(fmt, + 1, struct.calcsize(fmt)-4, file_size, compress_size) + file_size = 0xffffffff + compress_size = 0xffffffff + self.extract_version = max(45, self.extract_version) + self.create_version = max(45, self.extract_version) + + filename, flag_bits = self._encodeFilenameFlags() + header = struct.pack(structFileHeader, stringFileHeader, + self.extract_version, self.reserved, flag_bits, + self.compress_type, dostime, dosdate, CRC, + compress_size, file_size, + len(filename), len(extra)) + return header + filename + extra + + def _encodeFilenameFlags(self): + try: + return self.filename.encode('ascii'), self.flag_bits + except UnicodeEncodeError: + return self.filename.encode('utf-8'), self.flag_bits | 0x800 + + def _decodeExtra(self): + # Try to decode the extra field. + extra = self.extra + unpack = struct.unpack + while extra: + tp, ln = unpack('<HH', extra[:4]) + if tp == 1: + if ln >= 24: + counts = unpack('<QQQ', extra[4:28]) + elif ln == 16: + counts = unpack('<QQ', extra[4:20]) + elif ln == 8: + counts = unpack('<Q', extra[4:12]) + elif ln == 0: + counts = () + else: + raise RuntimeError("Corrupt extra field %s"%(ln,)) + + idx = 0 + + # ZIP64 extension (large files and/or large archives) + if self.file_size in (0xffffffffffffffff, 0xffffffff): + self.file_size = counts[idx] + idx += 1 + + if self.compress_size == 0xFFFFFFFF: + self.compress_size = counts[idx] + idx += 1 + + if self.header_offset == 0xffffffff: + old = self.header_offset + self.header_offset = counts[idx] + idx+=1 + + extra = extra[ln+4:] + + +class _ZipDecrypter: + """Class to handle decryption of files stored within a ZIP archive. + + ZIP supports a password-based form of encryption. Even though known + plaintext attacks have been found against it, it is still useful + to be able to get data out of such a file. + + Usage: + zd = _ZipDecrypter(mypwd) + plain_char = zd(cypher_char) + plain_text = map(zd, cypher_text) + """ + + def _GenerateCRCTable(): + """Generate a CRC-32 table. + + ZIP encryption uses the CRC32 one-byte primitive for scrambling some + internal keys. We noticed that a direct implementation is faster than + relying on binascii.crc32(). + """ + poly = 0xedb88320 + table = [0] * 256 + for i in range(256): + crc = i + for j in range(8): + if crc & 1: + crc = ((crc >> 1) & 0x7FFFFFFF) ^ poly + else: + crc = ((crc >> 1) & 0x7FFFFFFF) + table[i] = crc + return table + crctable = _GenerateCRCTable() + + def _crc32(self, ch, crc): + """Compute the CRC32 primitive on one byte.""" + return ((crc >> 8) & 0xffffff) ^ self.crctable[(crc ^ ch) & 0xff] + + def __init__(self, pwd): + self.key0 = 305419896 + self.key1 = 591751049 + self.key2 = 878082192 + for p in pwd: + self._UpdateKeys(p) + + def _UpdateKeys(self, c): + self.key0 = self._crc32(c, self.key0) + self.key1 = (self.key1 + (self.key0 & 255)) & 4294967295 + self.key1 = (self.key1 * 134775813 + 1) & 4294967295 + self.key2 = self._crc32((self.key1 >> 24) & 255, self.key2) + + def __call__(self, c): + """Decrypt a single character.""" + assert isinstance(c, int) + k = self.key2 | 2 + c = c ^ (((k * (k^1)) >> 8) & 255) + self._UpdateKeys(c) + return c + +class ZipExtFile: + """File-like object for reading an archive member. + Is returned by ZipFile.open(). + """ + + def __init__(self, fileobj, zipinfo, decrypt=None): + self.fileobj = fileobj + self.decrypter = decrypt + self.bytes_read = 0 + self.rawbuffer = b'' + self.readbuffer = b'' + self.linebuffer = b'' + self.eof = False + self.univ_newlines = False + self.nlSeps = (b"\n", ) + self.lastdiscard = b'' + + self.compress_type = zipinfo.compress_type + self.compress_size = zipinfo.compress_size + + self.closed = False + self.mode = "r" + self.name = zipinfo.filename + + # read from compressed files in 64k blocks + self.compreadsize = 64*1024 + if self.compress_type == ZIP_DEFLATED: + self.dc = zlib.decompressobj(-15) + elif self.compress_type == ZIP_BZIP2: + self.dc = bz2.BZ2Decompressor() + self.compreadsize = 900000 + + def set_univ_newlines(self, univ_newlines): + self.univ_newlines = univ_newlines + + # pick line separator char(s) based on universal newlines flag + self.nlSeps = (b"\n", ) + if self.univ_newlines: + self.nlSeps = (b"\r\n", b"\r", b"\n") + + def __iter__(self): + return self + + def __next__(self): + nextline = self.readline() + if not nextline: + raise StopIteration() + + return nextline + + def close(self): + self.closed = True + + def _checkfornewline(self): + nl, nllen = -1, -1 + if self.linebuffer: + # ugly check for cases where half of an \r\n pair was + # read on the last pass, and the \r was discarded. In this + # case we just throw away the \n at the start of the buffer. + if (self.lastdiscard, self.linebuffer[:1]) == (b'\r', b'\n'): + self.linebuffer = self.linebuffer[1:] + + for sep in self.nlSeps: + nl = self.linebuffer.find(sep) + if nl >= 0: + nllen = len(sep) + return nl, nllen + + return nl, nllen + + def readline(self, size = -1): + """Read a line with approx. size. If size is negative, + read a whole line. + """ + if size < 0: + size = sys.maxsize + elif size == 0: + return b'' + + # check for a newline already in buffer + nl, nllen = self._checkfornewline() + + if nl >= 0: + # the next line was already in the buffer + nl = min(nl, size) + else: + # no line break in buffer - try to read more + size -= len(self.linebuffer) + while nl < 0 and size > 0: + buf = self.read(min(size, 100)) + if not buf: + break + self.linebuffer += buf + size -= len(buf) + + # check for a newline in buffer + nl, nllen = self._checkfornewline() + + # we either ran out of bytes in the file, or + # met the specified size limit without finding a newline, + # so return current buffer + if nl < 0: + s = self.linebuffer + self.linebuffer = b'' + return s + + buf = self.linebuffer[:nl] + self.lastdiscard = self.linebuffer[nl:nl + nllen] + self.linebuffer = self.linebuffer[nl + nllen:] + + # line is always returned with \n as newline char (except possibly + # for a final incomplete line in the file, which is handled above). + return buf + b"\n" + + def readlines(self, sizehint = -1): + """Return a list with all (following) lines. The sizehint parameter + is ignored in this implementation. + """ + result = [] + while True: + line = self.readline() + if not line: break + result.append(line) + return result + + def read(self, size = None): + # act like file obj and return empty string if size is 0 + if size == 0: + return b'' + + # determine read size + bytesToRead = self.compress_size - self.bytes_read + + # adjust read size for encrypted files since the first 12 bytes + # are for the encryption/password information + if self.decrypter is not None: + bytesToRead -= 12 + + if size is not None and size >= 0: + if self.compress_type == ZIP_STORED: + lr = len(self.readbuffer) + bytesToRead = min(bytesToRead, size - lr) + else: + if len(self.readbuffer) > size: + # the user has requested fewer bytes than we've already + # pulled through the decompressor; don't read any more + bytesToRead = 0 + else: + # user will use up the buffer, so read some more + lr = len(self.rawbuffer) + bytesToRead = min(bytesToRead, self.compreadsize - lr) + + # avoid reading past end of file contents + if bytesToRead + self.bytes_read > self.compress_size: + bytesToRead = self.compress_size - self.bytes_read + + # try to read from file (if necessary) + if bytesToRead > 0: + data = self.fileobj.read(bytesToRead) + self.bytes_read += len(data) + try: + self.rawbuffer += data + except: + print(repr(self.fileobj), repr(self.rawbuffer), + repr(data)) + raise + + # handle contents of raw buffer + if self.rawbuffer: + newdata = self.rawbuffer + self.rawbuffer = b'' + + # decrypt new data if we were given an object to handle that + if newdata and self.decrypter is not None: + newdata = bytes(map(self.decrypter, newdata)) + + # decompress newly read data if necessary + if newdata and self.compress_type != ZIP_STORED: + newdata = self.dc.decompress(newdata) + self.rawbuffer = self.dc.unconsumed_tail if self.compress_type == ZIP_DEFLATED else '' + if self.eof and len(self.rawbuffer) == 0: + # we're out of raw bytes (both from the file and + # the local buffer); flush just to make sure the + # decompressor is done + if hasattr(self.dc, 'flush'): + newdata += self.dc.flush() + # prevent decompressor from being used again + self.dc = None + + self.readbuffer += newdata + + + # return what the user asked for + if size is None or len(self.readbuffer) <= size: + data = self.readbuffer + self.readbuffer = b'' + else: + data = self.readbuffer[:size] + self.readbuffer = self.readbuffer[size:] + + return data + + +class ZipFile: + """ Class with methods to open, read, write, close, list zip files. + + z = ZipFile(file, mode="r", compression=ZIP_STORED, allowZip64=False) + + file: Either the path to the file, or a file-like object. + If it is a path, the file will be opened and closed by ZipFile. + mode: The mode can be either read "r", write "w" or append "a". + compression: ZIP_STORED (no compression), ZIP_DEFLATED (requires zlib), + or ZIP_BZIP2 (requires bz2). + allowZip64: if True ZipFile will create files with ZIP64 extensions when + needed, otherwise it will raise an exception when this would + be necessary. + + """ + + fp = None # Set here since __del__ checks it + + def __init__(self, file, mode="r", compression=ZIP_STORED, allowZip64=False): + """Open the ZIP file with mode read "r", write "w" or append "a".""" + if mode not in ("r", "w", "a"): + raise RuntimeError('ZipFile() requires mode "r", "w", or "a"') + + if compression == ZIP_STORED: + pass + elif compression == ZIP_DEFLATED: + if not zlib: + raise RuntimeError( + "Compression requires the (missing) zlib module") + elif compression == ZIP_BZIP2: + if not bz2: + raise RuntimeError( + "Compression requires the (missing) bz2 module") + else: + raise RuntimeError("That compression method is not supported") + + self._allowZip64 = allowZip64 + self._didModify = False + self.debug = 0 # Level of printing: 0 through 3 + self.NameToInfo = {} # Find file info given name + self.filelist = [] # List of ZipInfo instances for archive + self.compression = compression # Method of compression + self.mode = key = mode.replace('b', '')[0] + self.pwd = None + self.comment = b'' + + # Check if we were passed a file-like object + if isinstance(file, str): + # No, it's a filename + self._filePassed = 0 + self.filename = file + modeDict = {'r' : 'rb', 'w': 'wb', 'a' : 'r+b'} + try: + self.fp = io.open(file, modeDict[mode]) + except IOError: + if mode == 'a': + mode = key = 'w' + self.fp = io.open(file, modeDict[mode]) + else: + raise + else: + self._filePassed = 1 + self.fp = file + self.filename = getattr(file, 'name', None) + + if key == 'r': + self._GetContents() + elif key == 'w': + pass + elif key == 'a': + try: # See if file is a zip file + self._RealGetContents() + # seek to start of directory and overwrite + self.fp.seek(self.start_dir, 0) + except BadZipfile: # file is not a zip file, just append + self.fp.seek(0, 2) + else: + if not self._filePassed: + self.fp.close() + self.fp = None + raise RuntimeError('Mode must be "r", "w" or "a"') + + def _GetContents(self): + """Read the directory, making sure we close the file if the format + is bad.""" + try: + self._RealGetContents() + except BadZipfile: + if not self._filePassed: + self.fp.close() + self.fp = None + raise + + def _RealGetContents(self): + """Read in the table of contents for the ZIP file.""" + fp = self.fp + endrec = _EndRecData(fp) + if not endrec: + raise BadZipfile("File is not a zip file") + if self.debug > 1: + print(endrec) + size_cd = endrec[_ECD_SIZE] # bytes in central directory + offset_cd = endrec[_ECD_OFFSET] # offset of central directory + self.comment = endrec[_ECD_COMMENT] # archive comment + + # "concat" is zero, unless zip was concatenated to another file + concat = endrec[_ECD_LOCATION] - size_cd - offset_cd + if endrec[_ECD_SIGNATURE] == stringEndArchive64: + # If Zip64 extension structures are present, account for them + concat -= (sizeEndCentDir64 + sizeEndCentDir64Locator) + + if self.debug > 2: + inferred = concat + offset_cd + print("given, inferred, offset", offset_cd, inferred, concat) + # self.start_dir: Position of start of central directory + self.start_dir = offset_cd + concat + fp.seek(self.start_dir, 0) + data = fp.read(size_cd) + fp = io.BytesIO(data) + total = 0 + while total < size_cd: + centdir = fp.read(sizeCentralDir) + if centdir[0:4] != stringCentralDir: + raise BadZipfile("Bad magic number for central directory") + centdir = struct.unpack(structCentralDir, centdir) + if self.debug > 2: + print(centdir) + filename = fp.read(centdir[_CD_FILENAME_LENGTH]) + flags = centdir[5] + if flags & 0x800: + # UTF-8 file names extension + filename = filename.decode('utf-8') + else: + # Historical ZIP filename encoding + filename = filename.decode('cp437') + # Create ZipInfo instance to store file information + x = ZipInfo(filename) + x.extra = fp.read(centdir[_CD_EXTRA_FIELD_LENGTH]) + x.comment = fp.read(centdir[_CD_COMMENT_LENGTH]) + x.header_offset = centdir[_CD_LOCAL_HEADER_OFFSET] + (x.create_version, x.create_system, x.extract_version, x.reserved, + x.flag_bits, x.compress_type, t, d, + x.CRC, x.compress_size, x.file_size) = centdir[1:12] + x.volume, x.internal_attr, x.external_attr = centdir[15:18] + # Convert date/time code to (year, month, day, hour, min, sec) + x._raw_time = t + x.date_time = ( (d>>9)+1980, (d>>5)&0xF, d&0x1F, + t>>11, (t>>5)&0x3F, (t&0x1F) * 2 ) + + x._decodeExtra() + x.header_offset = x.header_offset + concat + self.filelist.append(x) + self.NameToInfo[x.filename] = x + + # update total bytes read from central directory + total = (total + sizeCentralDir + centdir[_CD_FILENAME_LENGTH] + + centdir[_CD_EXTRA_FIELD_LENGTH] + + centdir[_CD_COMMENT_LENGTH]) + + if self.debug > 2: + print("total", total) + + + def namelist(self): + """Return a list of file names in the archive.""" + l = [] + for data in self.filelist: + l.append(data.filename) + return l + + def infolist(self): + """Return a list of class ZipInfo instances for files in the + archive.""" + return self.filelist + + def printdir(self, file=None): + """Print a table of contents for the zip file.""" + print("%-46s %19s %12s" % ("File Name", "Modified ", "Size"), + file=file) + for zinfo in self.filelist: + date = "%d-%02d-%02d %02d:%02d:%02d" % zinfo.date_time[:6] + print("%-46s %s %12d" % (zinfo.filename, date, zinfo.file_size), + file=file) + + def testzip(self): + """Read all the files and check the CRC.""" + chunk_size = 2 ** 20 + for zinfo in self.filelist: + try: + # Read by chunks, to avoid an OverflowError or a + # MemoryError with very large embedded files. + f = self.open(zinfo.filename, "r") + while f.read(chunk_size): # Check CRC-32 + pass + except BadZipfile: + return zinfo.filename + + def getinfo(self, name): + """Return the instance of ZipInfo given 'name'.""" + info = self.NameToInfo.get(name) + if info is None: + raise KeyError( + 'There is no item named %r in the archive' % name) + + return info + + def setpassword(self, pwd): + """Set default password for encrypted files.""" + assert isinstance(pwd, bytes) + self.pwd = pwd + + def read(self, name, pwd=None): + """Return file bytes (as a string) for name.""" + return self.open(name, "r", pwd).read() + + def open(self, name, mode="r", pwd=None): + """Return file-like object for 'name'.""" + if mode not in ("r", "U", "rU"): + raise RuntimeError('open() requires mode "r", "U", or "rU"') + if not self.fp: + raise RuntimeError( + "Attempt to read ZIP archive that was already closed") + + # Only open a new file for instances where we were not + # given a file object in the constructor + if self._filePassed: + zef_file = self.fp + else: + zef_file = io.open(self.filename, 'rb') + + # Make sure we have an info object + if isinstance(name, ZipInfo): + # 'name' is already an info object + zinfo = name + else: + # Get info object for name + zinfo = self.getinfo(name) + + zef_file.seek(zinfo.header_offset, 0) + + # Skip the file header: + fheader = zef_file.read(sizeFileHeader) + if fheader[0:4] != stringFileHeader: + raise BadZipfile("Bad magic number for file header") + + fheader = struct.unpack(structFileHeader, fheader) + fname = zef_file.read(fheader[_FH_FILENAME_LENGTH]) + if fheader[_FH_EXTRA_FIELD_LENGTH]: + zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH]) + + if fname != zinfo.orig_filename.encode("utf-8"): + raise BadZipfile( + 'File name in directory %r and header %r differ.' + % (zinfo.orig_filename, fname)) + + # check for encrypted flag & handle password + is_encrypted = zinfo.flag_bits & 0x1 + zd = None + if is_encrypted: + if not pwd: + pwd = self.pwd + if not pwd: + raise RuntimeError("File %s is encrypted, " + "password required for extraction" % name) + + zd = _ZipDecrypter(pwd) + # The first 12 bytes in the cypher stream is an encryption header + # used to strengthen the algorithm. The first 11 bytes are + # completely random, while the 12th contains the MSB of the CRC, + # or the MSB of the file time depending on the header type + # and is used to check the correctness of the password. + bytes = zef_file.read(12) + h = list(map(zd, bytes[0:12])) + if zinfo.flag_bits & 0x8: + # compare against the file type from extended local headers + check_byte = (zinfo._raw_time >> 8) & 0xff + else: + # compare against the CRC otherwise + check_byte = (zinfo.CRC >> 24) & 0xff + if h[11] != check_byte: + raise RuntimeError("Bad password for file", name) + + # build and return a ZipExtFile + if zd is None: + zef = ZipExtFile(zef_file, zinfo) + else: + zef = ZipExtFile(zef_file, zinfo, zd) + + # set universal newlines on ZipExtFile if necessary + if "U" in mode: + zef.set_univ_newlines(True) + return zef + + def extract(self, member, path=None, pwd=None): + """Extract a member from the archive to the current working directory, + using its full name. Its file information is extracted as accurately + as possible. `member' may be a filename or a ZipInfo object. You can + specify a different directory using `path'. + """ + if not isinstance(member, ZipInfo): + member = self.getinfo(member) + + if path is None: + path = os.getcwd() + + return self._extract_member(member, path, pwd) + + def extractall(self, path=None, members=None, pwd=None): + """Extract all members from the archive to the current working + directory. `path' specifies a different directory to extract to. + `members' is optional and must be a subset of the list returned + by namelist(). + """ + if members is None: + members = self.namelist() + + for zipinfo in members: + self.extract(zipinfo, path, pwd) + + def _extract_member(self, member, targetpath, pwd): + """Extract the ZipInfo object 'member' to a physical + file on the path targetpath. + """ + # build the destination pathname, replacing + # forward slashes to platform specific separators. + # Strip trailing path separator, unless it represents the root. + if (targetpath[-1:] in (os.path.sep, os.path.altsep) + and len(os.path.splitdrive(targetpath)[1]) > 1): + targetpath = targetpath[:-1] + + # don't include leading "/" from file name if present + if member.filename[0] == '/': + targetpath = os.path.join(targetpath, member.filename[1:]) + else: + targetpath = os.path.join(targetpath, member.filename) + + targetpath = os.path.normpath(targetpath) + + # Create all upper directories if necessary. + upperdirs = os.path.dirname(targetpath) + if upperdirs and not os.path.exists(upperdirs): + os.makedirs(upperdirs) + + if member.filename[-1] == '/': + if not os.path.isdir(targetpath): + os.mkdir(targetpath) + return targetpath + + source = self.open(member, pwd=pwd) + target = open(targetpath, "wb") + shutil.copyfileobj(source, target) + source.close() + target.close() + + return targetpath + + def _writecheck(self, zinfo): + """Check for errors before writing a file to the archive.""" + if zinfo.filename in self.NameToInfo: + if self.debug: # Warning for duplicate names + print("Duplicate name:", zinfo.filename) + if self.mode not in ("w", "a"): + raise RuntimeError('write() requires mode "w" or "a"') + if not self.fp: + raise RuntimeError( + "Attempt to write ZIP archive that was already closed") + if zinfo.compress_type == ZIP_DEFLATED and not zlib: + raise RuntimeError( + "Compression requires the (missing) zlib module") + if zinfo.compress_type == ZIP_BZIP2 and not bz2: + raise RuntimeError( + "Compression requires the (missing) bz2 module") + if zinfo.compress_type not in (ZIP_STORED, ZIP_DEFLATED, ZIP_BZIP2): + raise RuntimeError("That compression method is not supported") + if zinfo.file_size > ZIP64_LIMIT: + if not self._allowZip64: + raise LargeZipFile("Filesize would require ZIP64 extensions") + if zinfo.header_offset > ZIP64_LIMIT: + if not self._allowZip64: + raise LargeZipFile( + "Zipfile size would require ZIP64 extensions") + + def write(self, filename, arcname=None, compress_type=None): + """Put the bytes from filename into the archive under the name + arcname.""" + if not self.fp: + raise RuntimeError( + "Attempt to write to ZIP archive that was already closed") + + st = os.stat(filename) + isdir = stat.S_ISDIR(st.st_mode) + mtime = time.localtime(st.st_mtime) + date_time = mtime[0:6] + # Create ZipInfo instance to store file information + if arcname is None: + arcname = filename + arcname = os.path.normpath(os.path.splitdrive(arcname)[1]) + while arcname[0] in (os.sep, os.altsep): + arcname = arcname[1:] + if isdir: + arcname += '/' + zinfo = ZipInfo(arcname, date_time) + zinfo.external_attr = (st[0] & 0xFFFF) << 16 # Unix attributes + if compress_type is None: + zinfo.compress_type = self.compression + else: + zinfo.compress_type = compress_type + + zinfo.file_size = st.st_size + zinfo.flag_bits = 0x00 + zinfo.header_offset = self.fp.tell() # Start of header bytes + + self._writecheck(zinfo) + self._didModify = True + + if isdir: + zinfo.file_size = 0 + zinfo.compress_size = 0 + zinfo.CRC = 0 + self.filelist.append(zinfo) + self.NameToInfo[zinfo.filename] = zinfo + self.fp.write(zinfo.FileHeader()) + return + + with open(filename, "rb") as fp: + # Must overwrite CRC and sizes with correct data later + zinfo.CRC = CRC = 0 + zinfo.compress_size = compress_size = 0 + zinfo.file_size = file_size = 0 + self.fp.write(zinfo.FileHeader()) + if zinfo.compress_type == ZIP_DEFLATED: + cmpr = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, + zlib.DEFLATED, -15) + elif zinfo.compress_type == ZIP_BZIP2: + cmpr = bz2.BZ2Compressor() + else: + cmpr = None + while 1: + buf = fp.read(1024 * 8) + if not buf: + break + file_size = file_size + len(buf) + CRC = crc32(buf, CRC) & 0xffffffff + if cmpr: + buf = cmpr.compress(buf) + compress_size = compress_size + len(buf) + self.fp.write(buf) + if cmpr: + buf = cmpr.flush() + compress_size = compress_size + len(buf) + self.fp.write(buf) + zinfo.compress_size = compress_size + else: + zinfo.compress_size = file_size + zinfo.CRC = CRC + zinfo.file_size = file_size + # Seek backwards and write CRC and file sizes + position = self.fp.tell() # Preserve current position in file + self.fp.seek(zinfo.header_offset + 14, 0) + self.fp.write(struct.pack("<LLL", zinfo.CRC, zinfo.compress_size, + zinfo.file_size)) + self.fp.seek(position, 0) + self.filelist.append(zinfo) + self.NameToInfo[zinfo.filename] = zinfo + + def writestr(self, zinfo_or_arcname, data): + """Write a file into the archive. The contents is 'data', which + may be either a 'str' or a 'bytes' instance; if it is a 'str', + it is encoded as UTF-8 first. + 'zinfo_or_arcname' is either a ZipInfo instance or + the name of the file in the archive.""" + if isinstance(data, str): + data = data.encode("utf-8") + if not isinstance(zinfo_or_arcname, ZipInfo): + zinfo = ZipInfo(filename=zinfo_or_arcname, + date_time=time.localtime(time.time())[:6]) + zinfo.compress_type = self.compression + zinfo.external_attr = 0o600 << 16 + else: + zinfo = zinfo_or_arcname + + if not self.fp: + raise RuntimeError( + "Attempt to write to ZIP archive that was already closed") + + zinfo.file_size = len(data) # Uncompressed size + zinfo.header_offset = self.fp.tell() # Start of header data + self._writecheck(zinfo) + self._didModify = True + zinfo.CRC = crc32(data) & 0xffffffff # CRC-32 checksum + if zinfo.compress_type == ZIP_DEFLATED: + co = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, + zlib.DEFLATED, -15) + data = co.compress(data) + co.flush() + zinfo.compress_size = len(data) # Compressed size + elif zinfo.compress_type == ZIP_BZIP2: + co = bz2.BZ2Compressor() + data = co.compress(data) + co.flush() + zinfo.compress_size = len(data) # Compressed size + else: + zinfo.compress_size = zinfo.file_size + zinfo.header_offset = self.fp.tell() # Start of header data + self.fp.write(zinfo.FileHeader()) + self.fp.write(data) + self.fp.flush() + if zinfo.flag_bits & 0x08: + # Write CRC and file sizes after the file data + self.fp.write(struct.pack("<LLL", zinfo.CRC, zinfo.compress_size, + zinfo.file_size)) + self.filelist.append(zinfo) + self.NameToInfo[zinfo.filename] = zinfo + + def __del__(self): + """Call the "close()" method in case the user forgot.""" + self.close() + + def close(self): + """Close the file, and for mode "w" and "a" write the ending + records.""" + if self.fp is None: + return + + if self.mode in ("w", "a") and self._didModify: # write ending records + count = 0 + pos1 = self.fp.tell() + for zinfo in self.filelist: # write central directory + count = count + 1 + dt = zinfo.date_time + dosdate = (dt[0] - 1980) << 9 | dt[1] << 5 | dt[2] + dostime = dt[3] << 11 | dt[4] << 5 | (dt[5] // 2) + extra = [] + if zinfo.file_size > ZIP64_LIMIT \ + or zinfo.compress_size > ZIP64_LIMIT: + extra.append(zinfo.file_size) + extra.append(zinfo.compress_size) + file_size = 0xffffffff + compress_size = 0xffffffff + else: + file_size = zinfo.file_size + compress_size = zinfo.compress_size + + if zinfo.header_offset > ZIP64_LIMIT: + extra.append(zinfo.header_offset) + header_offset = 0xffffffff + else: + header_offset = zinfo.header_offset + + extra_data = zinfo.extra + if extra: + # Append a ZIP64 field to the extra's + extra_data = struct.pack( + '<HH' + 'Q'*len(extra), + 1, 8*len(extra), *extra) + extra_data + + extract_version = max(45, zinfo.extract_version) + create_version = max(45, zinfo.create_version) + else: + extract_version = zinfo.extract_version + create_version = zinfo.create_version + + try: + filename, flag_bits = zinfo._encodeFilenameFlags() + centdir = struct.pack(structCentralDir, + stringCentralDir, create_version, + zinfo.create_system, extract_version, zinfo.reserved, + flag_bits, zinfo.compress_type, dostime, dosdate, + zinfo.CRC, compress_size, file_size, + len(filename), len(extra_data), len(zinfo.comment), + 0, zinfo.internal_attr, zinfo.external_attr, + header_offset) + except DeprecationWarning: + print((structCentralDir, stringCentralDir, create_version, + zinfo.create_system, extract_version, zinfo.reserved, + zinfo.flag_bits, zinfo.compress_type, dostime, dosdate, + zinfo.CRC, compress_size, file_size, + len(zinfo.filename), len(extra_data), len(zinfo.comment), + 0, zinfo.internal_attr, zinfo.external_attr, + header_offset), file=sys.stderr) + raise + self.fp.write(centdir) + self.fp.write(filename) + self.fp.write(extra_data) + self.fp.write(zinfo.comment) + + pos2 = self.fp.tell() + # Write end-of-zip-archive record + centDirCount = count + centDirSize = pos2 - pos1 + centDirOffset = pos1 + if (centDirCount >= ZIP_FILECOUNT_LIMIT or + centDirOffset > ZIP64_LIMIT or + centDirSize > ZIP64_LIMIT): + # Need to write the ZIP64 end-of-archive records + zip64endrec = struct.pack( + structEndArchive64, stringEndArchive64, + 44, 45, 45, 0, 0, centDirCount, centDirCount, + centDirSize, centDirOffset) + self.fp.write(zip64endrec) + + zip64locrec = struct.pack( + structEndArchive64Locator, + stringEndArchive64Locator, 0, pos2, 1) + self.fp.write(zip64locrec) + centDirCount = min(centDirCount, 0xFFFF) + centDirSize = min(centDirSize, 0xFFFFFFFF) + centDirOffset = min(centDirOffset, 0xFFFFFFFF) + + # check for valid comment length + if len(self.comment) >= ZIP_MAX_COMMENT: + if self.debug > 0: + msg = 'Archive comment is too long; truncating to %d bytes' \ + % ZIP_MAX_COMMENT + self.comment = self.comment[:ZIP_MAX_COMMENT] + + endrec = struct.pack(structEndArchive, stringEndArchive, + 0, 0, centDirCount, centDirCount, + centDirSize, centDirOffset, len(self.comment)) + self.fp.write(endrec) + self.fp.write(self.comment) + self.fp.flush() + + if not self._filePassed: + self.fp.close() + self.fp = None + + +class PyZipFile(ZipFile): + """Class to create ZIP archives with Python library files and packages.""" + + def writepy(self, pathname, basename=""): + """Add all files from "pathname" to the ZIP archive. + + If pathname is a package directory, search the directory and + all package subdirectories recursively for all *.py and enter + the modules into the archive. If pathname is a plain + directory, listdir *.py and enter all modules. Else, pathname + must be a Python *.py file and the module will be put into the + archive. Added modules are always module.pyo or module.pyc. + This method will compile the module.py into module.pyc if + necessary. + """ + dir, name = os.path.split(pathname) + if os.path.isdir(pathname): + initname = os.path.join(pathname, "__init__.py") + if os.path.isfile(initname): + # This is a package directory, add it + if basename: + basename = "%s/%s" % (basename, name) + else: + basename = name + if self.debug: + print("Adding package in", pathname, "as", basename) + fname, arcname = self._get_codename(initname[0:-3], basename) + if self.debug: + print("Adding", arcname) + self.write(fname, arcname) + dirlist = os.listdir(pathname) + dirlist.remove("__init__.py") + # Add all *.py files and package subdirectories + for filename in dirlist: + path = os.path.join(pathname, filename) + root, ext = os.path.splitext(filename) + if os.path.isdir(path): + if os.path.isfile(os.path.join(path, "__init__.py")): + # This is a package directory, add it + self.writepy(path, basename) # Recursive call + elif ext == ".py": + fname, arcname = self._get_codename(path[0:-3], + basename) + if self.debug: + print("Adding", arcname) + self.write(fname, arcname) + else: + # This is NOT a package directory, add its files at top level + if self.debug: + print("Adding files from directory", pathname) + for filename in os.listdir(pathname): + path = os.path.join(pathname, filename) + root, ext = os.path.splitext(filename) + if ext == ".py": + fname, arcname = self._get_codename(path[0:-3], + basename) + if self.debug: + print("Adding", arcname) + self.write(fname, arcname) + else: + if pathname[-3:] != ".py": + raise RuntimeError( + 'Files added with writepy() must end with ".py"') + fname, arcname = self._get_codename(pathname[0:-3], basename) + if self.debug: + print("Adding file", arcname) + self.write(fname, arcname) + + def _get_codename(self, pathname, basename): + """Return (filename, archivename) for the path. + + Given a module name path, return the correct file path and + archive name, compiling if necessary. For example, given + /python/lib/string, return (/python/lib/string.pyc, string). + """ + file_py = pathname + ".py" + file_pyc = pathname + ".pyc" + file_pyo = pathname + ".pyo" + if os.path.isfile(file_pyo) and \ + os.stat(file_pyo).st_mtime >= os.stat(file_py).st_mtime: + fname = file_pyo # Use .pyo file + elif not os.path.isfile(file_pyc) or \ + os.stat(file_pyc).st_mtime < os.stat(file_py).st_mtime: + import py_compile + if self.debug: + print("Compiling", file_py) + try: + py_compile.compile(file_py, file_pyc, None, True) + except py_compile.PyCompileError as err: + print(err.msg) + fname = file_pyc + else: + fname = file_pyc + archivename = os.path.split(fname)[1] + if basename: + archivename = "%s/%s" % (basename, archivename) + return (fname, archivename) + + +def main(args = None): + import textwrap + USAGE=textwrap.dedent("""\ + Usage: + zipfile.py -l zipfile.zip # Show listing of a zipfile + zipfile.py -t zipfile.zip # Test if a zipfile is valid + zipfile.py -e zipfile.zip target # Extract zipfile into target dir + zipfile.py -c zipfile.zip src ... # Create zipfile from sources + """) + if args is None: + args = sys.argv[1:] + + if not args or args[0] not in ('-l', '-c', '-e', '-t'): + print(USAGE) + sys.exit(1) + + if args[0] == '-l': + if len(args) != 2: + print(USAGE) + sys.exit(1) + zf = ZipFile(args[1], 'r') + zf.printdir() + zf.close() + + elif args[0] == '-t': + if len(args) != 2: + print(USAGE) + sys.exit(1) + zf = ZipFile(args[1], 'r') + zf.testzip() + print("Done testing") + + elif args[0] == '-e': + if len(args) != 3: + print(USAGE) + sys.exit(1) + + zf = ZipFile(args[1], 'r') + out = args[2] + for path in zf.namelist(): + if path.startswith('./'): + tgt = os.path.join(out, path[2:]) + else: + tgt = os.path.join(out, path) + + tgtdir = os.path.dirname(tgt) + if not os.path.exists(tgtdir): + os.makedirs(tgtdir) + with open(tgt, 'wb') as fp: + fp.write(zf.read(path)) + zf.close() + + elif args[0] == '-c': + if len(args) < 3: + print(USAGE) + sys.exit(1) + + def addToZip(zf, path, zippath): + if os.path.isfile(path): + zf.write(path, zippath, ZIP_DEFLATED) + elif os.path.isdir(path): + for nm in os.listdir(path): + addToZip(zf, + os.path.join(path, nm), os.path.join(zippath, nm)) + # else: ignore + + zf = ZipFile(args[1], 'w', allowZip64=True) + for src in args[2:]: + addToZip(zf, src, os.path.basename(src)) + + zf.close() + +if __name__ == "__main__": + main()
