# Copyright (c) 2015, Google Inc.
#
# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
# SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
# OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""Extracts archives."""

import hashlib
import optparse
import os
import os.path
import shutil
import stat
import sys
import tarfile
import zipfile


def CheckedJoin(output, path):
    """
    CheckedJoin returns os.path.join(output, path). It checks that the resulting
    path is under output, but shouldn't be used on untrusted input.

    path is normalized with os.path.normpath first. ValueError is raised if the
    normalized path:
      * is absolute or carries a drive letter,
      * resolves to output itself (e.g. "", ".", "a/.."), or
      * escapes output through a leading ".." component.
    Names that merely begin with a dot (".gitignore", "..foo") are accepted.
    """
    path = os.path.normpath(path)
    # splitdrive() catches Windows drive-relative paths such as "C:foo", which
    # isabs() does not treat as absolute. On POSIX the drive is always "".
    if os.path.isabs(path) or os.path.splitdrive(path)[0]:
        raise ValueError(path)
    # After normpath, ".." can only survive as leading components of a relative
    # path, so inspecting the first component is enough to detect an escape.
    if path == os.curdir or path.split(os.sep, 1)[0] == os.pardir:
        raise ValueError(path)
    return os.path.join(output, path)


class FileEntry(object):
    """A regular-file member read out of an archive.

    Attributes:
      path: the member's name as recorded in the archive.
      mode: permission bits to chmod the extracted file to, or None to keep
        whatever permissions the file was created with.
      fileobj: a readable file object that yields the member's contents.
    """

    def __init__(self, path, mode, fileobj):
        self.path, self.mode, self.fileobj = path, mode, fileobj


class SymlinkEntry(object):
    """A symbolic-link member read out of an archive.

    Attributes:
      path: the member's name as recorded in the archive.
      mode: permission bits to chmod the extracted path to, or None to skip it.
      target: the link's destination string, as stored in the archive.
    """

    def __init__(self, path, mode, target):
        self.path, self.mode, self.target = path, mode, target


def IterateZip(path):
    """
    IterateZip opens the zip file at path and returns a generator of entry objects
    for each file in it.

    Directory members (names ending in "/") are skipped. Every other member is
    yielded as a FileEntry. When the member was written on a Unix host and
    records non-zero permission bits for a regular file, those bits are
    reported as the entry's mode (like IterateTar does), so executables stay
    executable. Otherwise mode is None.
    """
    with zipfile.ZipFile(path, "r") as zip_file:
        for info in zip_file.infolist():
            if info.filename.endswith("/"):
                continue
            mode = None
            # Unix archivers (create_system == 3) store the member's st_mode in
            # the upper 16 bits of external_attr. Other hosts (e.g. DOS/Windows)
            # use the low bits for DOS attributes, so those are not trusted.
            # Only regular files (or entries with no file-type bits, as
            # Python's ZipFile.writestr produces) get a mode, so that e.g. a
            # symlink's 0o777 is never applied to the plain file written for it.
            if info.create_system == 3:
                st_mode = info.external_attr >> 16
                perms = stat.S_IMODE(st_mode)
                if stat.S_IFMT(st_mode) in (0, stat.S_IFREG) and perms:
                    mode = perms
            yield FileEntry(info.filename, mode, zip_file.open(info))


def IterateTar(path, compression):
    """
    IterateTar opens the tar.gz or tar.bz2 file at path and returns a generator of
    entry objects for each file in it.

    compression is appended to "r:" to form the tarfile open mode (e.g. "gz"
    or "bz2"). Members are mapped as follows:
      * directories are skipped;
      * symlinks become SymlinkEntry objects (mode None, target = linkname);
      * regular files and hard links become FileEntry objects carrying the
        member's permission bits. A hard link is extracted as an independent
        copy of its target's contents.
    Any other member type (device node, FIFO, ...) raises ValueError. A hard
    link whose target is not found in the archive makes tarfile raise KeyError.
    """
    with tarfile.open(path, "r:" + compression) as tar_file:
        for info in tar_file:
            if info.isdir():
                continue
            if info.issym():
                yield SymlinkEntry(info.name, None, info.linkname)
            elif info.isfile() or info.islnk():
                # extractfile() resolves a hard link to the data of its target
                # member because the archive is opened in random-access (not
                # stream "|") mode.
                yield FileEntry(info.name, info.mode,
                                tar_file.extractfile(info))
            else:
                raise ValueError('Unknown entry type "%s"' % (info.name, ))


def _HashFile(path):
    """Returns the hex SHA-256 digest of the file at path, read in 1 MiB chunks."""
    sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        while True:
            chunk = f.read(1024 * 1024)
            if not chunk:
                break
            sha256.update(chunk)
    return sha256.hexdigest()


def _IterateArchive(archive):
    """Returns an entry generator for archive, chosen by its file extension.

    Raises ValueError if the extension is not .zip, .tar.gz or .tar.bz2.
    """
    if archive.endswith(".zip"):
        return IterateZip(archive)
    if archive.endswith(".tar.gz"):
        return IterateTar(archive, "gz")
    if archive.endswith(".tar.bz2"):
        return IterateTar(archive, "bz2")
    raise ValueError(archive)


def main(args):
    """
    Extracts ARCHIVE into OUTPUT, skipping the work if OUTPUT is already current.

    args are the command-line arguments without the program name:
    ARCHIVE OUTPUT, plus the optional --no-prefix flag. Unless --no-prefix is
    given, every entry must sit under one common top-level directory, which is
    stripped from the extracted paths.

    OUTPUT is deleted and recreated on each extraction. Afterwards a
    ".dawn_archive_digest" stamp holding the archive's SHA-256 is written into
    OUTPUT. A later run with an archive of the same digest does nothing.

    Returns 1 (after printing help) on a usage error, otherwise 0, including
    when ARCHIVE does not exist. Problems with the archive itself (unsupported
    extension, unsafe or inconsistent entry paths) raise ValueError.
    """
    parser = optparse.OptionParser(usage="Usage: %prog ARCHIVE OUTPUT")
    parser.add_option(
        "--no-prefix",
        dest="no_prefix",
        action="store_true",
        help="Do not remove a prefix from paths in the archive.",
    )
    options, args = parser.parse_args(args)

    if len(args) != 2:
        parser.print_help()
        return 1

    archive, output = args

    if not os.path.exists(archive):
        # Skip archives that weren't downloaded.
        return 0

    # The stamp records which archive OUTPUT was produced from, so re-running
    # with an unchanged archive is a cheap no-op.
    digest = _HashFile(archive)
    stamp_path = os.path.join(output, ".dawn_archive_digest")
    if os.path.exists(stamp_path):
        with open(stamp_path) as f:
            if f.read().strip() == digest:
                print("Already up-to-date.")
                return 0

    entries = _IterateArchive(archive)

    try:
        if os.path.exists(output):
            print("Removing %s" % (output, ))
            shutil.rmtree(output)

        print("Extracting %s to %s" % (archive, output))
        prefix = None
        num_extracted = 0
        for entry in entries:
            # Even on Windows, zip files must always use forward slashes.
            if "\\" in entry.path or entry.path.startswith("/"):
                raise ValueError(entry.path)

            if not options.no_prefix:
                # Without this check, split() below fails with an opaque
                # "not enough values to unpack" error.
                if "/" not in entry.path:
                    raise ValueError(
                        '"%s" is not under a top-level directory; use '
                        '--no-prefix to extract this archive' % (entry.path, ))
                new_prefix, rest = entry.path.split("/", 1)

                # Ensure the archive is consistent.
                if prefix is None:
                    prefix = new_prefix
                if prefix != new_prefix:
                    raise ValueError((prefix, new_prefix))
            else:
                rest = entry.path

            # Extract the file into the output directory.
            fixed_path = CheckedJoin(output, rest)
            if not os.path.isdir(os.path.dirname(fixed_path)):
                os.makedirs(os.path.dirname(fixed_path))
            if isinstance(entry, FileEntry):
                with open(fixed_path, "wb") as out:
                    shutil.copyfileobj(entry.fileobj, out)
            elif isinstance(entry, SymlinkEntry):
                os.symlink(entry.target, fixed_path)
            else:
                raise TypeError("unknown entry type")

            # Fix up permissions if needbe.
            # TODO(davidben): To be extra tidy, this should only track the execute bit
            # as in git.
            if entry.mode is not None:
                os.chmod(fixed_path, entry.mode)

            # Print every 100 files, so bots do not time out on large archives.
            num_extracted += 1
            if num_extracted % 100 == 0:
                print("Extracted %d files..." % (num_extracted, ))
    finally:
        # Closing the generator closes the underlying archive file.
        entries.close()

    # OUTPUT was removed above and is only recreated when an entry is written,
    # so an archive with no file entries would otherwise leave no directory to
    # hold the stamp.
    if not os.path.isdir(output):
        os.makedirs(output)
    with open(stamp_path, "w") as f:
        f.write(digest)

    print("Done. Extracted %d files." % (num_extracted, ))
    return 0


if __name__ == "__main__":
    # Hand everything after the script name to main() and use its return
    # value as the process exit status.
    sys.exit(main(args=sys.argv[1:]))