#!/usr/bin/env python
# -----------------------------------------------------------------------------
#
# Project: preprocessor.py
# Authors: Stephan Meissl <stephan.meissl@eox.at>
#
# -----------------------------------------------------------------------------
# Copyright (c) 2019 EOX IT Services GmbH
#
# Python script to preprocess product data.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies of this Software or works derived from this Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
# -----------------------------------------------------------------------------


import sys
import os
import argparse
import textwrap
import logging
import traceback
import redis
import tempfile
import tarfile
import re
import subprocess

from swiftclient.multithreading import OutputManager
from swiftclient.service import SwiftError, SwiftService, SwiftUploadObject

import transform_chain

# Feature flags from the environment, parsed into real booleans so that
# e.g. SPLIT_PARTS_CHECK=false does not accidentally enable the check
# (a raw non-empty string would always be truthy).
_TRUTHY = ('1', 'true', 'yes', 'on')
SPLIT_PARTS_CHECK = (os.environ.get('SPLIT_PARTS_CHECK') or '').lower() in _TRUTHY
ENFORCE_FOUR_BANDS = (os.environ.get('ENFORCE_FOUR_BANDS') or '').lower() in _TRUTHY

FILESIZE_LIMIT = 4 * (1024 ** 3)  # swift 5GB limit for filesize (non-compressed), here less to have margin
# default options for swift uploads; segment_size is added later if needed
swift_upload_options = {
    'use_slo': True
}

logger = logging.getLogger("preprocessor")


def setup_logging(verbosity):
    """Configure the "preprocessor" logger from a 0-4 verbosity level.

    Maps the command-line verbosity to a log level (0=CRITICAL,
    1=ERROR, 2=WARNING, 3=INFO, anything else=DEBUG) and attaches a
    stream handler with a timestamped format.

    :param verbosity: integer verbosity from the command line (0-4).
    """
    # dispatch table instead of an if/elif ladder; unknown values
    # (e.g. > 3) fall through to DEBUG, as before
    levels = {
        0: logging.CRITICAL,
        1: logging.ERROR,
        2: logging.WARNING,
        3: logging.INFO,
    }
    level = levels.get(verbosity, logging.DEBUG)
    # getLogger returns the same module-level "preprocessor" logger
    log = logging.getLogger("preprocessor")
    log.setLevel(level)
    sh = logging.StreamHandler()
    sh.setLevel(level)
    sh.setFormatter(logging.Formatter("%(asctime)s %(levelname)s: %(message)s"))
    log.addHandler(sh)


def preprocessor(
    collection, tar_object_path, replace=False,
    client=None, register_queue_key=None
):
    """Fetch a product tar archive from swift, convert its imagery to COG
    and upload the results to the target container.

    The archive is downloaded via a dedicated swift account (configured
    through the ``*_DOWNLOAD`` environment variables). Image
    (``IMG*.TIF/JP2``), metadata (``GSC*.xml``) and RPC/world files are
    extracted, multiple images are mosaicked into a VRT, RPC/geotransform
    corrections are applied and the data is written as (possibly split)
    cloud optimized GeoTIFF before being uploaded together with the
    metadata file. On success the tar object path is pushed onto the
    redis register queue if a client is given.

    :param collection: collection identifier (unused here; kept for
        interface compatibility with the callers).
    :param tar_object_path: swift path "/<container>/<package>" of the tar.
    :param replace: if False, abort when the package already exists in
        the target container.
    :param client: optional redis client used to register processed paths.
    :param register_queue_key: redis list key the path is pushed to.
    :returns: 1 on failure, None on success.
    """
    logger.info("Starting preprocessing of '%s'." % (tar_object_path))

    try:
        # "/<container>/<path/to/package>" -> container, package
        container = tar_object_path.split("/")[1]
        package = "/".join(tar_object_path.split("/")[2:])

        with SwiftService() as swift, OutputManager(), \
                tempfile.TemporaryDirectory() as tmpdirname:
            if not replace:
                try:
                    list_parts_gen = swift.list(
                        container=container, options={"prefix": package},
                    )
                    for page in list_parts_gen:
                        # abort only when objects with this prefix really
                        # exist; a successful page with an empty listing
                        # just means the package is not there yet
                        if page["success"] and page["listing"]:
                            logger.critical(
                                "Aborting, package '%s' already exists at "
                                "target container '%s'." % (package, container)
                            )
                            return 1
                except SwiftError as e:
                    logger.debug(traceback.format_exc())
                    logger.error("%s: %s\n" % (type(e).__name__, str(e)))
                    return 1

            tmpfilename = os.path.join(tmpdirname, "tmp.tar")

            # the download goes through a separate swift account,
            # configured via the *_DOWNLOAD environment variables
            options = {
                "os_username": os.environ.get('OS_USERNAME_DOWNLOAD'),
                "os_password": os.environ.get('OS_PASSWORD_DOWNLOAD'),
                "os_tenant_name": os.environ.get('OS_TENANT_NAME_DOWNLOAD'),
                "os_tenant_id": os.environ.get('OS_TENANT_ID_DOWNLOAD'),
                "os_region_name": os.environ.get('OS_REGION_NAME_DOWNLOAD'),
                "os_auth_url": os.environ.get('OS_AUTH_URL_DOWNLOAD'),
                "auth_version": os.environ.get('ST_AUTH_VERSION_DOWNLOAD'),
            }
            with SwiftService(options=options) as swift_down:
                for down_res in swift_down.download(
                    container=container,
                    objects=[package, ],
                    options={"out_file": tmpfilename},
                ):
                    if down_res["success"]:
                        logger.debug(
                            "'%s' downloaded" % down_res["object"]
                        )
                    else:
                        logger.error(
                            "'%s' download failed" % down_res["object"]
                        )
                        return 1

            # context manager guarantees the tar file is closed even if
            # member selection or extraction fails
            with tarfile.open(tmpfilename, mode="r") as tf:
                all_members = tf.getmembers()
                data_files_ti = [
                    m for m in all_members if
                    m is not None and re.search(r"IMG.+\.(TIF|JP2)", m.name, re.IGNORECASE)
                ]
                # next(..., None): a missing GSC metadata file must reach
                # the error branch below instead of raising StopIteration
                metadata_file_ti = next(
                    (m for m in all_members
                     if m is not None and re.search(r"GSC.+\.xml", m.name, re.IGNORECASE)),
                    None
                )
                world_files_ti = [
                    m for m in all_members if m is not None and
                    re.search(r"RPC.+\.xml", m.name, re.IGNORECASE)
                ]
                # add J2W files only if more than one file is present
                # which signalizes that file was split into multiple or
                # has panchromatic
                if len(data_files_ti) > 1:
                    world_files_ti += [
                        m for m in all_members if m is not None and
                        re.search(r".+\.J2W", m.name, re.IGNORECASE)
                    ]
                data_files = [member.name for member in data_files_ti]

                if not data_files or metadata_file_ti is None:
                    logger.error(
                        "Aborting, not all needed files found in package."
                    )
                    return 1

                metadata_file = metadata_file_ti.name
                members = data_files_ti + [metadata_file_ti] + world_files_ti
                tf.extractall(path=tmpdirname, members=members)

            # cleanup after use to save space
            os.remove(tmpfilename)

            source_name_first = os.path.join(tmpdirname, data_files[0])

            # if there is more than one file, make a VRT to mosaic them
            if len(data_files) > 1:
                logger.debug("More files found, creating a VRT")
                source_name_vrt = os.path.join(tmpdirname, 'tmp.vrt')
                # open all datasets one by one and create an array of open datasets
                dataset_array = [
                    transform_chain.open_gdal_dataset(
                        os.path.join(tmpdirname, data_file))
                    for data_file in data_files
                ]
                if ENFORCE_FOUR_BANDS:
                    # remove and close datasets with a different number of
                    # bands than expected
                    dataset_array = list(filter(None, [
                        transform_chain.validate_band_count(dataset, 4)
                        for dataset in dataset_array
                    ]))
                    if len(dataset_array) == 0:
                        logger.error(
                            "Aborting, wrong number of bands for all datasets %s" % ",".join(data_files)
                        )
                        return 1
                # try to fix geotransform for ortho images one by one
                # before making a vrt, which fails otherwise
                dataset_array = [
                    transform_chain.correct_geo_transform(dataset_entity)
                    for dataset_entity in dataset_array
                ]
                # create a vrt out of them
                dataset = transform_chain.create_vrt_dataset(dataset_array, source_name_vrt)
                # during creating of a vrt, reference to RPC is lost;
                # if there was rpc, set it to the vrt
                dataset = transform_chain.set_rpc_metadata(dataset_array[0], dataset)
                dataset_array = None
            else:
                # open file using gdal
                dataset = transform_chain.open_gdal_dataset(source_name_first)
            # close dataset if it has a different number of bands than expected
            if ENFORCE_FOUR_BANDS:
                dataset = transform_chain.validate_band_count(dataset, 4)
                if dataset is None:
                    logger.error(
                        "Aborting, wrong number of bands for dataset %s" % data_files[0]
                    )
                    return 1
            # change RPC to geotransform if present
            dataset = transform_chain.apply_rpc(dataset)

            # perform transformation correction if necessary
            dataset = transform_chain.correct_geo_transform(dataset)

            # save file with given options - should use ENV
            creation_options = ["BLOCKSIZE=512", "COMPRESS=DEFLATE", "LEVEL=6", "NUM_THREADS=8",
                                "BIGTIFF=IF_SAFER", "OVERVIEWS=AUTO", "RESAMPLING=CUBIC"]

            # only compute split parts when the check is enabled via the
            # environment flag (truthiness test; a "== True" comparison
            # would never match)
            split_parts = (
                transform_chain.split_check(dataset, FILESIZE_LIMIT)
                if SPLIT_PARTS_CHECK else 1
            )

            output_file_list = transform_chain.write_gdal_dataset_split(
                dataset, "COG",
                "%s.tif" % os.path.splitext(source_name_first)[0],
                creation_options, split_parts
            )
            dataset = None
            objects = []
            # create vrt if file was split
            if len(output_file_list) > 1:
                logger.debug("Creating .vrt of previously split files.")
                vrt_name = "%s.vrt" % os.path.splitext(source_name_first)[0]
                subprocess.run(
                    ['gdalbuildvrt', '-quiet', os.path.basename(vrt_name)] + [
                        os.path.basename(data_file) for data_file in output_file_list],
                    timeout=600, check=True, cwd=os.path.dirname(vrt_name)
                )  # use cwd to create relative paths in vrt
                # add vrt to files to be uploaded
                objects.append(
                    SwiftUploadObject(
                        vrt_name,
                        object_name=os.path.join(
                            package, os.path.basename(vrt_name))
                    )
                )

            # add image files to files to be uploaded
            for data_file in output_file_list:
                # check if 5GB swift upload limit is exceeded by any of
                # the files; if yes, use segmentation
                size = os.stat(data_file).st_size
                if size > 1024 * 1024 * 1024 * 5:
                    swift_upload_options["segment_size"] = 2 * 1024 * 1024 * 1024  # 2gb segments

                dest_object_name = os.path.join(
                    package, os.path.basename(data_file)
                )
                objects.append(
                    SwiftUploadObject(data_file, object_name=dest_object_name)
                )

            # add metadata to files to be uploaded after data files
            objects.append(
                SwiftUploadObject(
                    os.path.join(tmpdirname, metadata_file),
                    object_name=os.path.join(package, metadata_file)
                )
            )

            # upload files
            for upload in swift.upload(
                container=container,
                objects=objects,
                options=swift_upload_options
            ):
                if upload["success"]:
                    if "object" in upload:
                        logger.info(
                            "'%s' successfully uploaded." % upload["object"]
                        )
                    elif "for_object" in upload:
                        logger.debug(
                            "Successfully uploaded '%s' segment '%s'."
                            % (upload["for_object"], upload["segment_index"])
                        )
                else:
                    logger.error(
                        "'%s' upload failed" % upload["error"]
                    )
                    return 1

            if client is not None:
                logger.debug(
                    "Storing path in redis queue '%s'" % register_queue_key
                )
                client.lpush(
                    register_queue_key, "%s" % tar_object_path
                )

    except Exception as e:
        logger.debug(traceback.format_exc())
        logger.error("%s: %s\n" % (type(e).__name__, str(e)))
        return 1

    logger.info(
        "Successfully finished preprocessing of '%s'." % (tar_object_path)
    )


def preprocessor_redis_wrapper(
    collection, replace=False, host="localhost", port=6379,
    preprocess_queue_key="preprocess_queue",
    register_queue_key="register_queue"
):
    """Consume tar object paths from a redis list and preprocess each one.

    Blocks forever on the preprocess queue; every popped path is handed
    to preprocessor() together with the redis client so that successful
    runs can be registered on the register queue.
    """
    queue_client = redis.Redis(
        host=host, port=port, charset="utf-8", decode_responses=True
    )
    while True:
        logger.debug("waiting for redis queue '%s'..." % preprocess_queue_key)
        # brpop blocks until an item is available and returns (key, value)
        _, popped_path = queue_client.brpop(preprocess_queue_key)
        preprocessor(
            collection,
            popped_path,
            replace=replace,
            client=queue_client,
            register_queue_key=register_queue_key,
        )


if __name__ == "__main__":
    # command-line entry point: one-off ("standard") or queue-driven
    # ("redis") preprocessing
    parser = argparse.ArgumentParser()
    parser.description = textwrap.dedent("""\
    Preprocess product data.
    """)

    parser.add_argument(
        "--mode", default="standard", choices=["standard", "redis"],
        help=(
            "The mode to run the preprocessor. Either one-off (standard) or "
            "reading from a redis queue."
        )
    )
    parser.add_argument(
        "--tar-object-path", default=None,
        help=(
            "Path to object holding tar archive file of product."
        )
    )
    parser.add_argument(
        "--replace", action="store_true",
        help=(
            "Replace existing products instead of skipping the preprocessing."
        )
    )
    parser.add_argument(
        "--redis-preprocess-queue-key", default="preprocess_queue"
    )
    parser.add_argument(
        "--redis-register-queue-key", default="register_queue"
    )
    parser.add_argument(
        "--redis-host", default="localhost"
    )
    parser.add_argument(
        "--redis-port", type=int, default=6379
    )

    parser.add_argument(
        "-v", "--verbosity", type=int, default=3, choices=[0, 1, 2, 3, 4],
        help=(
            "Set verbosity of log output "
            "(4=DEBUG, 3=INFO, 2=WARNING, 1=ERROR, 0=CRITICAL). (default: 3)"
        )
    )

    arg_values = parser.parse_args()

    setup_logging(arg_values.verbosity)

    # the target collection is configured via the environment
    collection = os.environ.get('Collection')
    if collection is None:
        logger.critical("Collection environment variable not set.")
        sys.exit(1)

    if arg_values.mode == "standard":
        # fail fast with a clear message instead of an AttributeError
        # deep inside preprocessor()
        if arg_values.tar_object_path is None:
            logger.critical("--tar-object-path is required in standard mode.")
            sys.exit(1)
        preprocessor(
            collection,
            arg_values.tar_object_path,
            replace=arg_values.replace,
        )
    else:
        preprocessor_redis_wrapper(
            collection,
            replace=arg_values.replace,
            host=arg_values.redis_host,
            port=arg_values.redis_port,
            preprocess_queue_key=arg_values.redis_preprocess_queue_key,
            register_queue_key=arg_values.redis_register_queue_key,
        )