184 lines
5.8 KiB
Python
184 lines
5.8 KiB
Python
# Copyright (c) 2015, Google Inc.
|
|
#
|
|
# Permission to use, copy, modify, and/or distribute this software for any
|
|
# purpose with or without fee is hereby granted, provided that the above
|
|
# copyright notice and this permission notice appear in all copies.
|
|
#
|
|
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
|
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
|
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
|
|
# SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
|
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
|
|
# OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
|
|
# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
|
"""Extracts archives."""
|
|
|
|
import hashlib
|
|
import optparse
|
|
import os
|
|
import os.path
|
|
import tarfile
|
|
import shutil
|
|
import sys
|
|
import zipfile
|
|
|
|
|
|
def CheckedJoin(output, path):
    """
    CheckedJoin returns os.path.join(output, path). It checks that the resulting
    path is under output, but shouldn't be used on untrusted input.

    Args:
        output: Directory the result is expected to stay inside.
        path: Relative path (for example an archive member name) to append.

    Returns:
        os.path.join(output, os.path.normpath(path)).

    Raises:
        ValueError: if path is absolute, carries a drive prefix, is empty or
            normalizes to ".", or climbs out of output through a leading "..".
    """
    path = os.path.normpath(path)
    # A drive-qualified path (Windows "C:foo") makes os.path.join discard
    # output entirely, so it can never be "under output".
    if os.path.isabs(path) or os.path.splitdrive(path)[0]:
        raise ValueError(path)
    # normpath("") and normpath(".") both give ".", which names output itself
    # rather than something inside it.
    if path == os.curdir:
        raise ValueError(path)
    # After normpath, a relative path can only contain ".." as leading
    # components. Test for that component exactly instead of any leading "."
    # so that hidden files such as ".gitignore" are still accepted.
    if path == os.pardir or path.startswith(os.pardir + os.sep):
        raise ValueError(path)
    return os.path.join(output, path)
|
|
|
|
|
|
class FileEntry(object):
    """
    FileEntry describes one regular file found in an archive.

    Attributes are stored exactly as passed in:
        path: name of the entry inside the archive.
        mode: permission bits to apply after extraction, or None when the
            archive iterator did not supply any (IterateZip passes None).
        fileobj: file-like object the entry's contents are read from.
    """

    def __init__(self, path, mode, fileobj):
        self.path, self.mode, self.fileobj = path, mode, fileobj
|
|
|
|
|
|
class SymlinkEntry(object):
    """
    SymlinkEntry describes one symbolic link found in an archive.

    Attributes are stored exactly as passed in:
        path: name of the link inside the archive.
        mode: permission bits to apply after extraction, or None for none
            (IterateTar passes None for symlinks).
        target: the location the link points at.
    """

    def __init__(self, path, mode, target):
        self.path, self.mode, self.target = path, mode, target
|
|
|
|
|
|
def IterateZip(path):
    """
    IterateZip opens the zip file at path and lazily yields a FileEntry for
    every member whose name does not end in "/" (those are directories and are
    skipped). Zip entries carry no mode here: FileEntry.mode is always None.

    The zip file stays open while the generator is being consumed and is closed
    when the generator finishes or is closed.
    """
    with zipfile.ZipFile(path, "r") as archive:
        for member in archive.infolist():
            name = member.filename
            if not name.endswith("/"):
                yield FileEntry(name, None, archive.open(member))
|
|
|
|
|
|
def IterateTar(path, compression):
    """
    IterateTar opens the compressed tar file at path and lazily yields an entry
    object for each member. compression is the suffix appended to the "r:" open
    mode, e.g. "gz" for tar.gz or "bz2" for tar.bz2.

    Directories are skipped, symlinks become SymlinkEntry objects (with no
    mode), and regular files become FileEntry objects carrying the member's
    mode. Any other member type raises ValueError.
    """
    with tarfile.open(path, "r:" + compression) as archive:
        for member in archive:
            if member.isdir():
                continue
            if member.issym():
                yield SymlinkEntry(member.name, None, member.linkname)
                continue
            if not member.isfile():
                raise ValueError('Unknown entry type "%s"' % (member.name, ))
            yield FileEntry(member.name, member.mode,
                            archive.extractfile(member))
|
|
|
|
|
|
def main(args):
    """
    main extracts an archive into an output directory.

    args is the command line without the program name: optionally
    "--no-prefix", followed by ARCHIVE and OUTPUT. Unless --no-prefix is given,
    every entry must live under one common top-level directory, which is
    stripped from the extracted paths.

    A SHA-256 digest of the archive is stored in OUTPUT/.dawn_archive_digest
    after a successful extraction; when that stamp already matches, nothing is
    done. Otherwise any existing OUTPUT is deleted and re-created from the
    archive.

    Returns:
        1 if the arguments are malformed (usage is printed), 0 otherwise,
        including when ARCHIVE does not exist.

    Raises:
        ValueError: for an unsupported archive extension, an entry path that
            contains a backslash or starts with "/", an entry that is not
            inside the common top-level directory, or a path rejected by
            CheckedJoin.
        TypeError: if the archive iterator yields an unknown entry object.
    """
    parser = optparse.OptionParser(usage="Usage: %prog ARCHIVE OUTPUT")
    parser.add_option(
        "--no-prefix",
        dest="no_prefix",
        action="store_true",
        help="Do not remove a prefix from paths in the archive.",
    )
    options, args = parser.parse_args(args)

    if len(args) != 2:
        parser.print_help()
        return 1

    archive, output = args

    if not os.path.exists(archive):
        # Skip archives that weren't downloaded.
        return 0

    # Hash the archive in 1 MiB chunks so large files are never fully loaded
    # into memory.
    sha256 = hashlib.sha256()
    with open(archive, "rb") as f:
        while True:
            chunk = f.read(1024 * 1024)
            if not chunk:
                break
            sha256.update(chunk)
    digest = sha256.hexdigest()

    # The stamp records which archive the current contents of output came
    # from; a match means there is nothing to do.
    stamp_path = os.path.join(output, ".dawn_archive_digest")
    if os.path.exists(stamp_path):
        with open(stamp_path) as f:
            if f.read().strip() == digest:
                print("Already up-to-date.")
                return 0

    # Pick the iterator before touching output so an unsupported archive
    # leaves the existing directory intact.
    if archive.endswith(".zip"):
        entries = IterateZip(archive)
    elif archive.endswith(".tar.gz"):
        entries = IterateTar(archive, "gz")
    elif archive.endswith(".tar.bz2"):
        entries = IterateTar(archive, "bz2")
    else:
        raise ValueError(archive)

    try:
        if os.path.exists(output):
            print("Removing %s" % (output, ))
            shutil.rmtree(output)

        print("Extracting %s to %s" % (archive, output))
        prefix = None
        num_extracted = 0
        for entry in entries:
            # Even on Windows, zip files must always use forward slashes.
            if "\\" in entry.path or entry.path.startswith("/"):
                raise ValueError(entry.path)

            if not options.no_prefix:
                new_prefix, sep, rest = entry.path.partition("/")
                if not sep:
                    # A bare top-level name has no prefix directory to strip.
                    raise ValueError(
                        "Archive entry %r is not inside a top-level "
                        "directory; use --no-prefix" % (entry.path, ))

                # Ensure the archive is consistent.
                if prefix is None:
                    prefix = new_prefix
                if prefix != new_prefix:
                    raise ValueError((prefix, new_prefix))
            else:
                rest = entry.path

            # Extract the file into the output directory.
            fixed_path = CheckedJoin(output, rest)
            parent = os.path.dirname(fixed_path)
            if not os.path.isdir(parent):
                os.makedirs(parent)
            if isinstance(entry, FileEntry):
                try:
                    with open(fixed_path, "wb") as out:
                        shutil.copyfileobj(entry.fileobj, out)
                finally:
                    # Release the archive member handle as soon as it has
                    # been copied instead of waiting for garbage collection.
                    entry.fileobj.close()
            elif isinstance(entry, SymlinkEntry):
                os.symlink(entry.target, fixed_path)
            else:
                raise TypeError("unknown entry type")

            # Fix up permissions if need be.
            # TODO(davidben): To be extra tidy, this should only track the
            # execute bit as in git.
            if entry.mode is not None:
                os.chmod(fixed_path, entry.mode)

            # Print every 100 files, so bots do not time out on large archives.
            num_extracted += 1
            if num_extracted % 100 == 0:
                print("Extracted %d files..." % (num_extracted, ))
    finally:
        # Closing the generator also closes the underlying archive file.
        entries.close()

    # An archive without any file entries never creates the output directory;
    # make sure the stamp has somewhere to go.
    stamp_dir = os.path.dirname(stamp_path)
    if stamp_dir and not os.path.isdir(stamp_dir):
        os.makedirs(stamp_dir)
    with open(stamp_path, "w") as f:
        f.write(digest)

    print("Done. Extracted %d files." % (num_extracted, ))
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Run as a script: hand main() the command line without the program name
    # and use its return value as the process exit status.
    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)
|