#!/usr/bin/python3
# pylint: disable=invalid-name

"""
Download HTTP access logs from AWS s3 storage, parse them, increment the
hit counters on the frontend, and clean up.

AWS CLI cheatsheet:

    aws s3 ls
    aws s3 ls s3://fedora-copr
    aws s3 cp s3://fedora-copr/cloudwatch/E2PUZIRCXCOXTG.2021-12-07-15.00d7a244.gz ./
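    aws s3 rm s3://fedora-copr/cloudwatch/E2PUZIRCXCOXTG.2021-12-07-15.00d7a244.gz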

Token permissions are required:
https://pagure.io/fedora-infrastructure/issue/10395
"""


import os
import argparse
import logging
import tempfile
import gzip
from socket import gethostname
import boto3
from copr_common.log import setup_script_logger
from copr_backend.hitcounter import update_frontend


# We allow only these hostnames to delete files from the S3 storage
PRODUCTION_HOSTNAME = "copr-be.aws.fedoraproject.org"
DEVEL_HOSTNAME = "copr-be-dev.aws.fedoraproject.org"

PRODUCTION_CDN_HOSTNAMES = [
    "download.copr.fedorainfracloud.org",
    "d1nld9ovj32u75.cloudfront.net",
]

DEVEL_CDN_HOSTNAMES = [
    "download.copr-dev.fedorainfracloud.org",
    "d1p7mxc66bhrst.cloudfront.net",
]


log = logging.getLogger(__name__)
setup_script_logger(log, "/var/log/copr-backend/hitcounter-s3.log")


class S3Bucket:
    """
    A high-level interface for interacting with files in an AWS s3 bucket
    """

    def __init__(self, bucket=None, directory=None, dry_run=False):
        self.s3 = boto3.client("s3")
        self.bucket = bucket or "fedora-copr"
        self.directory = directory or "cloudwatch/"
        self.dry_run = dry_run

    def list_files(self):
        """
        List all files under the configured directory in our AWS s3 bucket
        """
        paginator = self.s3.get_paginator("list_objects")
        page_iterator = paginator.paginate(
            Bucket=self.bucket,
            Prefix=self.directory)

        result = []
        for page in page_iterator:
            for obj in page.get("Contents", []):
                result.append(obj["Key"])
        return result

    def download_file(self, s3file, dstdir):
        """
        Download a file from AWS s3 bucket
        """
        dst = os.path.join(dstdir, os.path.basename(s3file))
        self.s3.download_file(self.bucket, s3file, dst)
        return dst

    def delete_file(self, s3file):
        """
        Delete a file from AWS s3 bucket
        """
        # Refuse to delete anything from anywhere other than the production or
        # devel instance (e.g. from a docker container on our laptops)
        if gethostname() not in [PRODUCTION_HOSTNAME, DEVEL_HOSTNAME]:
            log.debug("Not deleting %s on this instance", s3file)
            return
        if self.dry_run:
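            # Per --dry-run, only print what would happen
            log.info("Dry run, would delete %s", s3file)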
            return
        self.s3.delete_object(Bucket=self.bucket, Key=s3file)


def gunzip(path):
    """
    Take a .gz file and uncompress it in the same directory
    """
    # Drop only the trailing ".gz" suffix; str.rstrip(".gz") would also strip
    # any other trailing '.', 'g' and 'z' characters from the name
    assert path.endswith(".gz")
    dst_path = path[:-len(".gz")]
    with gzip.open(path, "rb") as src:
        with open(dst_path, "w") as dst:
            dst.write(src.read().decode("utf-8"))
    return dst_path


def parse_access_file(path):
    """
    Take a raw access file and return its contents as a list of dicts.
    """
    with open(path, "r") as fd:
        content = fd.readlines()

    # The file starts with meta information and thanks to #Fields, we know what
    # each column means.
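    # A CloudFront access log typically begins with header lines like these
    # (the exact field list depends on the distribution's log configuration):
    #   #Version: 1.0
    #   #Fields: date time x-edge-location sc-bytes c-ip cs-method ... x-host-header ...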
    assert content[0].startswith("#Version:")
    assert content[1].startswith("#Fields:")
    # Cut off the "#Fields:" prefix only; str.lstrip() would also eat leading
    # characters of the first field name
    keys = content[1][len("#Fields:"):].split()

    accesses = []
    for line in content[2:]:
        # Make sure we are not parsing any more meta information
        assert not line.startswith("#")

        # Combine field names and this row values to create a dict
        values = line.split()
        access = dict(zip(keys, values))
        accesses.append(access)

    return accesses


def get_cdn_hostnames(args):
    """
    The devel and production accesses are mixed together. Which ones do we want
    to count?
    """
    # If a CDN hostname was explicitly specified when calling the script
    if args.cdn_hostnames:
        return args.cdn_hostnames

    # Count hits from devel CDN hostname on devel instance
    hostname = gethostname()
    if hostname == DEVEL_HOSTNAME:
        return DEVEL_CDN_HOSTNAMES

    # Default to production hits. Don't worry, we don't accidentally
    # remove them from any other instance
    return PRODUCTION_CDN_HOSTNAMES


def check_different_cdn_hostname(accesses, cdn_hostnames):
    """
    If a list of HTTP accesses contains any access for a different CDN hostname
    (e.g. for the devel instance when the script is running on production),
    return its value. Otherwise return `None`.
    """
    for access in accesses:
        if access["x-host-header"] not in cdn_hostnames:
            return access["x-host-header"]
    return None


def get_arg_parser():
    """
    Generate argument parser for this script
    """
    name = os.path.basename(__file__)
    description = (
        "Download HTTP access logs from AWS s3 storage, parse them, increment "
        "on frontend, and clean up."
    )
    parser = argparse.ArgumentParser(name, description=description)
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help=("Do not perform any destructive changes, only print what "
              "would happen"))
    parser.add_argument(
        "--verbose",
        action="store_true",
        help=("Print verbose information about what is going on"))
    parser.add_argument(
        "--try-indefinitely",
        action="store_true",
        help=("If true, try infinite number of attempts when contacting the "
              "frontend. Do not use this option for cron tasks because the "
              "number of simultaneously running instances might go up"))
    parser.add_argument(
        "--cdn-hostname",
        action="append",
        dest="cdn_hostnames",
        help=("By default the devel instance counts only hits from devel, and "
              "the production instance from production. You can override this "
              "by explicitly specifying the CDN hostname of interest, e.g. {0}"
              .format(PRODUCTION_CDN_HOSTNAMES[0])))
    return parser


def main():
    """
    Main function
    """
    parser = get_arg_parser()
    args = parser.parse_args()
    tmp = tempfile.mkdtemp(prefix="copr-aws-s3-hitcounter-")
    cdn_hostnames = get_cdn_hostnames(args)

    if args.verbose:
        log.setLevel(logging.DEBUG)

    s3 = S3Bucket(dry_run=args.dry_run)
    files = s3.list_files()

    for i, s3file in enumerate(files, start=1):
        gz = s3.download_file(s3file, dstdir=tmp)
        raw = gunzip(gz)
        accesses = parse_access_file(raw)

        # Clean up temporary files; we don't need them for the rest of the cycle
        for path in [gz, raw]:
            os.remove(path)

        different_cdn = check_different_cdn_hostname(accesses, cdn_hostnames)
        if different_cdn:
            log.debug("Skipping: %s (different hostname: %s)",
                      s3file, different_cdn)
            continue

        log.info("[%s/%s] %s (%s accesses)",
                 i, len(files), s3file, len(accesses))

        # Maybe we want to use some locking or transaction mechanism to avoid
        # a scenario where we increment the accesses on the frontend but then
        # leave the s3 file untouched, which would result in parsing and
        # incrementing from the same file again in the next run
        update_frontend(accesses, log=log, dry_run=args.dry_run,
                        try_indefinitely=args.try_indefinitely)
        s3.delete_file(s3file)

    os.removedirs(tmp)


if __name__ == "__main__":
    main()
