#!/usr/bin/env python2.7
# vim:ts=4 sw=4 expandtab
"""
Get a new SAS token from Azure Key Vault (AKV) and replace the old one in the APT
source list. Meant for use in Azure IaaS and HDInsight (HDI) environments.
"""
import sys
import os
import re
import tempfile
import base64
import httplib
import json
import syslog
import fcntl
import signal
import logging

# Module logger; its debug output is only shown when DEBUG is set in the environment.
logger = logging.getLogger('update-sastoken')

# Any value of DEBUG (even empty) enables debug-level logging to the default handler.
if 'DEBUG' in os.environ:
    logging.basicConfig(level=logging.DEBUG)

# APT source list whose https-sas:// entry carries the SAS token this script refreshes.
source_list_file = "/etc/apt/sources.list.d/substrate.list"

def scrub_sas(token, b64=False):
    """
    Mask the signature of a SAS token so that it can be logged safely.

    Everything after 'sig=' is replaced with '*', preserving the token length.
    If the token contains no 'sig=' field, the whole token is masked, because
    we cannot tell which part of it is secret. Empty or None tokens are
    returned unchanged.

    :param token: SAS token (base64url-encoded when ``b64`` is True)
    :param b64: decode the token before scrubbing and re-encode it afterwards
    :returns: the scrubbed token
    """
    if not token:
        return token
    if b64:
        token = base64.urlsafe_b64decode(token)
    marker = 'sig='
    tlen = len(token)
    start = token.find(marker)
    # Keep the prefix up to and including 'sig='. Without the marker keep nothing;
    # the old "-1 + 4" arithmetic used to leak the first 3 characters.
    keep = start + len(marker) if start != -1 else 0
    token = token[:keep].ljust(tlen, '*')
    if b64:
        token = base64.urlsafe_b64encode(token)
    return token


def parse_sas_url(filename):
    """
    Return ``(sas_token, storage_account)`` from the first matching line of filename.

    A matching line is a ``deb`` entry that uses the https-sas:// scheme, i.e.
    ``deb ... https-sas://<token>@<account>.<host>/...``. The storage account is
    the first dot-separated label of the host. Returns ``(None, None)`` when no
    line matches. Logs to syslog and exits with status 2 if the file cannot be read.
    """
    deb_sas_line = re.compile(r'^deb .*https-sas:\/\/(.*)@(.+)\/.*$')
    try:
        with open(filename, 'r') as source:
            for raw_line in source:
                found = deb_sas_line.match(raw_line.rstrip())
                if not found:
                    continue
                token, host = found.group(1, 2)
                return token, host.split('.')[0]
    # pylint: disable=broad-except
    except Exception as e:
        syslog.syslog(syslog.LOG_ALERT, "Error opening {}: {}".format(filename, e))
        sys.exit(2)
    return None, None


def get_value_by_key(key, filename='/etc/default/gcp'):
    """
    Return the value of ``key`` from ``filename``, falling back to the environment.

    The file is read as ``key = value`` lines; a leading shell-style ``export``
    is tolerated. Blank lines, lines whose key part starts with '#', and lines
    without '=' are skipped. The key must equal the name left of the first '='
    exactly. Everything after that '=' (stripped) is the value, so values may
    themselves contain '='. The first matching line wins.

    If the file does not exist or has no such key, ``os.environ[key]`` is used.

    :returns: the value string, or False if the key is found nowhere or the
              file cannot be read
    """
    value = False
    if os.path.isfile(filename):
        try:
            with open(filename, 'r') as f:
                for line in f:
                    name, sep, rest = line.strip().partition('=')
                    if not sep or name.startswith('#'):
                        continue
                    words = name.split()
                    # accept "export KEY=value" as written in shell-style defaults files
                    if len(words) == 2 and words[0] == 'export':
                        words = words[1:]
                    if words == [key]:
                        value = rest.strip()
                        break
        # pylint: disable=broad-except
        except Exception as e:
            syslog.syslog(syslog.LOG_ALERT, "Error finding {} from {}: {}, ".format(key, filename, e))
            return False
    if value is False:
        value = os.environ.get(key, False)
    return value

def get_vault_suffix():
    """
    Return the Key Vault DNS suffix.

    Uses the VAULT_SUFFIX setting (file or environment) when it is set and
    non-empty; otherwise uses the public Azure default 'vault.azure.net'.
    """
    configured = get_value_by_key('VAULT_SUFFIX')
    return configured or 'vault.azure.net'

def _get_bearer_via_adal(vault_suffix):
    """
    HDInsight fallback: get a Key Vault token via ADAL with the cluster MSI certificate.

    Relies on the module globals ``adal`` and ``enc_msi_setting``, which are set
    up during main() when the HDInsight helpers are available. The private key is
    read from /var/lib/waagent/<thumbprint>.prv. Logs to syslog and exits with
    status 3 on any failure.
    """
    try:
        with open('/var/lib/waagent/{0}.prv'.format(enc_msi_setting['thumbprint'])) as f:
            key = f.read()
        aad_endpoint = get_value_by_key('AAD_ENDPOINT') or 'login.microsoftonline.com'
        auth_context = adal.AuthenticationContext(
            "https://{}/{}".format(aad_endpoint, enc_msi_setting['tenantId']),
            validate_authority=False)
        logger.debug('ADAL auth context: %s', auth_context.__dict__)
        logger.debug('ADAL auth context authority: %s', auth_context.authority.__dict__)
        auth_result = auth_context.acquire_token_with_client_certificate(
            "https://{}".format(vault_suffix),
            enc_msi_setting['clientId'],
            key,
            enc_msi_setting['thumbprint'])
        logger.debug('ADAL auth result: %s', auth_result)
    # pylint: disable=broad-except
    except Exception as e:
        syslog.syslog(syslog.LOG_ALERT,
                      "Unexpected error ({}) happened while retrieving authorization token.".format(e))
        sys.exit(3)
    return auth_result['accessToken']


def get_bearer(hdinsight_quirk=False, timeout=60):
    """
    Get auth bearer for keyvault.

    Requests a managed-identity token for ``https://<vault suffix>`` from the
    instance metadata endpoint at 169.254.169.254. ``mi_res_id`` is added when
    'MSIresourceID' is configured. If that request fails (non-200 answer,
    connection error or timeout) and ``hdinsight_quirk`` is set, the function
    falls back to ADAL client-certificate authentication.

    :param hdinsight_quirk: allow the ADAL fallback
    :param timeout: socket timeout in seconds for the metadata request. Without
        one, an unresponsive endpoint would block forever while the caller may
        be holding the APT locks.
    :returns: the access token
    Exits with status 3 if the ADAL fallback fails, or with status 4 if the
    metadata request fails and no fallback is allowed.
    """
    vault_suffix = get_vault_suffix()
    msresid = get_value_by_key('MSIresourceID')
    msresid = "&mi_res_id=" + msresid if msresid else ''
    con = httplib.HTTPConnection("169.254.169.254", timeout=timeout)
    try:
        try:
            con.request(
                'GET',
                "/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2F{}{}".format(
                    vault_suffix, msresid),
                headers={'Metadata': 'true'})
            response = con.getresponse()
            status, reason, data = response.status, response.reason, response.read()
        # socket errors (IOError subclasses, including timeouts) and malformed HTTP answers
        except (IOError, httplib.HTTPException) as e:
            status, reason, data = None, e, ''
    finally:
        con.close()
    if status == 200:
        return json.loads(data)["access_token"].encode("utf-8")
    if hdinsight_quirk:
        return _get_bearer_via_adal(vault_suffix)
    syslog.syslog(syslog.LOG_ALERT, "Unexpected answer ({} : {}) happened while getting authorization token."
                  " Response: {}".format(status, reason, data))
    sys.exit(4)


def get_tenant_id(hdinsight_quirk=True):
    """
    Return the AAD tenant ID (the 'tid' claim) of the Key Vault bearer token.

    The token from get_bearer() is a JWT. Its second dot-separated segment is the
    base64url-encoded JSON claims set, sent without '=' padding, so the padding
    is restored before decoding. Logs to syslog and exits with status 5 if the
    claim cannot be extracted.
    """
    try:
        token = get_bearer(hdinsight_quirk)
        payload = token.split('.')[1].encode("utf-8")
        # Pad the length up to the next multiple of 4. The old "len % 4" added
        # 3 '=' where 1 was needed and relied on the decoder ignoring the excess.
        payload += '=' * (-len(payload) % 4)
        tid = json.loads(base64.urlsafe_b64decode(payload))['tid']
    # pylint: disable=broad-except
    except Exception as e:
        syslog.syslog(syslog.LOG_ALERT, "Failed to get Tenant ID: {}".format(e))
        sys.exit(5)
    return tid


def get_new_sas(kv_name='kaupstestkv', secret_name='gcplinuxdev', hdinsight_quirk=False, timeout=60):
    """
    Acquire a new SAS token from Azure Key Vault.

    Reads secret ``secret_name`` from vault ``kv_name`` (host
    ``<kv_name>.<vault suffix>``) over HTTPS, authenticating with a bearer token
    from get_bearer().

    :param kv_name: Key Vault name
    :param secret_name: name of the secret holding the SAS token
    :param hdinsight_quirk: passed to get_bearer() to allow its ADAL fallback
    :param timeout: socket timeout in seconds for the Key Vault request, so a
        stalled connection cannot block forever
    :returns: the secret value, UTF-8 encoded
    Logs to syslog and exits with status 6 on a connection error or a non-200 answer.
    """
    vault_suffix = get_vault_suffix()
    bearer = get_bearer(hdinsight_quirk)
    con = httplib.HTTPSConnection("{}.{}".format(kv_name, vault_suffix), timeout=timeout)
    try:
        try:
            con.request('GET', "/secrets/{}?api-version=2016-10-01".format(secret_name),
                        headers={'Authorization': "Bearer {}".format(bearer)})
            response = con.getresponse()
            status, reason, data = response.status, response.reason, response.read()
        # socket errors (IOError subclasses, including timeouts) and malformed HTTP answers
        except (IOError, httplib.HTTPException) as e:
            status, reason, data = None, e, ''
    finally:
        con.close()
    if status != 200:
        syslog.syslog(
            syslog.LOG_ALERT,
            "Unexpected answer ({} : {}) happened while connecting to {}.{} Response: {}".format(
                status,
                reason,
                kv_name,
                vault_suffix,
                data))
        sys.exit(6)
    return json.loads(data)["value"].encode("utf-8")

def _enable_hdinsight_quirk():
    """
    Prepare the ADAL fallback used on HDInsight clusters, where MSI through the
    metadata endpoint may not be supported.

    On success, binds the module globals ``adal`` and ``enc_msi_setting`` (read
    by the token code) and returns True. Returns False after logging when the
    HDInsight helpers are unavailable or the cluster manifest cannot be used.
    Exits with status 8 if the helpers are present but 'MSIresourceID' is not
    configured.
    """
    global adal  # make it available to other functions
    global enc_msi_setting  # make it available to other functions
    try:
        # pylint: disable=import-error
        from hdinsight_common import ClusterManifestParser
        import adal
        cluster_manifest = ClusterManifestParser.parse_local_manifest()
        msi_setting = json.loads(cluster_manifest.settings['managedServiceIdentity'])
        enc_msi_arm_resource_id = get_value_by_key('MSIresourceID')
        logger.debug('MSI response from ADAL: %s', msi_setting)
        if not enc_msi_arm_resource_id:
            syslog.syslog(syslog.LOG_INFO, "'MSIresourceID' is missing in /etc/default/gcp, exiting...")
            sys.exit(8)
        try:
            enc_msi_setting = msi_setting[enc_msi_arm_resource_id]
        except KeyError:
            # pylint: disable=no-member
            enc_msi_setting = msi_setting[enc_msi_arm_resource_id.lower()]
    # pylint: disable=broad-except
    except Exception as ex:
        syslog.syslog(syslog.LOG_INFO, "Fallback to ADAL does not work. Exception: {}".format(ex))
        return False
    return True


def _replace_sas_in_file(filename, old_sas, new_sas):
    """
    Atomically rewrite ``filename`` with every occurrence of ``old_sas`` replaced by ``new_sas``.

    The new content goes to a temporary file in the same directory. That file is
    flushed to disk, given the original file's permission bits, and renamed over
    the original. Exits with status 9 if the temporary file cannot be created and
    with status 10 if ``filename`` cannot be opened. The temporary file is removed
    on any failure after it was created.
    """
    try:
        newf = tempfile.NamedTemporaryFile(mode='w+b', prefix=os.path.basename(filename),
                                           dir=os.path.dirname(filename), delete=False)
    # pylint: disable=broad-except
    except Exception as e:
        syslog.syslog(syslog.LOG_ALERT, "Making a new file failed with an exception: {}".format(e))
        sys.exit(9)

    try:
        oldf = open(filename, 'r')
    # pylint: disable=broad-except
    except Exception as e:
        syslog.syslog(syslog.LOG_ALERT, "Error opening {}: {}".format(filename, e))
        newf.close()
        os.remove(newf.name)
        sys.exit(10)

    try:
        with oldf:
            for line in oldf:
                if old_sas in line:
                    syslog.syslog(syslog.LOG_INFO, "Updating old SAS ({}) with {}".format(scrub_sas(old_sas),
                                                                                          scrub_sas(new_sas)))
                    line = line.replace(old_sas, new_sas)
                newf.write(line)
        newf.flush()
        os.fsync(newf.fileno())
        newf.close()
        # The temp file is created 0600; keep the replaced file's own mode instead.
        os.chmod(newf.name, os.stat(filename).st_mode & 0o7777)
        os.rename(newf.name, filename)
    except Exception:
        newf.close()  # closing an already-closed file is a no-op
        if os.path.exists(newf.name):
            os.remove(newf.name)
        raise


def main():
    """
    Refresh the SAS token in the APT source list from Azure Key Vault.

    Steps:
    1. Read the current token and storage account from the source list.
    2. Pick the vault. SASTOKEN_VAULT_NAME wins; otherwise the vault is looked
       up by tenant ID, and unknown tenants fall back to a per-RepoComponent
       default.
    3. Fetch the new token and rewrite the file when it differs.

    Exits with status 0 when there is nothing to do, and with status 7 when
    RepoComponent is not configured. Helpers exit with their own codes.
    """
    if not os.path.isfile('/etc/default/gcp') and not os.path.isfile(source_list_file):
        sys.exit(0)

    repo_component = get_value_by_key('RepoComponent')
    if not repo_component:
        syslog.syslog(syslog.LOG_INFO, "'RepoComponent' is missing in /etc/default/gcp, exiting...")
        sys.exit(7)

    # Mapping for Tenant ID <=> AKV
    akvmap = {
        'cdc5aeea-15c5-4db6-b079-fcadd2505dc2': 'kevlarlinuxprod',
        '124edf19-b350-4797-aefc-3206115ffdb3': 'kevlarlinuxgme',
        '72f988bf-86f1-41af-91ab-2d7cd011db47': 'kevlarlinuxmsint'
    }

    # Fallback vault per RepoComponent when the tenant is not in akvmap
    storagemap = {
        'stable': 'gcplinux',
        'testing': 'kevlarlinuxgme',
        'unstable': 'kevlarlinuxgme'
    }

    hdinsight_quirk_available = _enable_hdinsight_quirk()

    old_sas, storagename = parse_sas_url(source_list_file)
    # Without an existing token there is nothing to replace. An empty token must
    # also stop here: replacing '' would splice the new token between every character.
    if not old_sas:
        syslog.syslog(syslog.LOG_INFO, "No SAS token found in {}, nothing to update".format(source_list_file))
        sys.exit(0)

    storagename_from_config = get_value_by_key('SASTOKEN_VAULT_SECRET')
    if storagename_from_config:
        storagename = storagename_from_config

    akv_from_config = get_value_by_key('SASTOKEN_VAULT_NAME')
    if akv_from_config:
        akv = akv_from_config
    else:
        # .get(): an unknown tenant must reach the storagemap fallback, not raise KeyError
        akv = akvmap.get(get_tenant_id(hdinsight_quirk_available))

    if not akv:
        if repo_component in storagemap:
            akv = storagemap[repo_component]
        else:
            syslog.syslog(syslog.LOG_INFO, "No defaults for RepoComponent {}, nothing to update".format(repo_component))
            sys.exit(0)

    new_sas = get_new_sas(akv, storagename, hdinsight_quirk_available)

    if old_sas == new_sas:
        syslog.syslog(syslog.LOG_INFO, "Newest SAS token is {}, no update is needed".format(scrub_sas(old_sas)))
        sys.exit(0)

    _replace_sas_in_file(source_list_file, old_sas, new_sas)

class Locker(object):
    """Acquire exclusive locks on APT functions.

    Locker is a context manager that grabs three locks:
    /var/lib/apt/lists/lock - locks apt-get update
    /var/lib/dpkg/lock-frontend - locks the frontend
    /var/lib/dpkg/lock - locks any package installations

    There is a logic to this sequence: first we prevent apt-get update,
    then we make sure that no apt commands can be used,
    and finally we lock any package modification.

    Acquisition is bounded by ``timeout_seconds`` using SIGALRM. If the locks
    cannot be obtained in time, ``__enter__`` raises IOError. The locks are
    released when the lock files are closed in ``__exit__``.
    """
    def __init__(self, timeout_seconds=300):
        self._update_lock = "/var/lib/apt/lists/lock"
        self._dpkg_lock = "/var/lib/dpkg/lock"
        self._fe_lock = "/var/lib/dpkg/lock-frontend"
        self.lock_file_update = open(self._update_lock, "w")
        self.lock_file_dpkg = open(self._dpkg_lock, "w")
        self.lock_file_fe = open(self._fe_lock, "w")
        # default timeout is 5 minutes
        self.timeout_seconds = timeout_seconds

    # pylint: disable=W0622
    def __enter__(self):
        """Acquire locks.

        Exclusive locks on all three files, taken in the order documented on
        the class. Blocks until the locks are acquired or the timeout expires.
        The alarm is always cancelled afterwards, and any previous SIGALRM
        handler is restored.
        """
        previous_handler = signal.signal(signal.SIGALRM, self.sighandler)
        signal.alarm(self.timeout_seconds)
        try:
            # do the locking, will block for max self.timeout_seconds
            fcntl.lockf(self.lock_file_update, fcntl.LOCK_EX)
            fcntl.lockf(self.lock_file_fe, fcntl.LOCK_EX)
            fcntl.lockf(self.lock_file_dpkg, fcntl.LOCK_EX)
        finally:
            # cancel a pending alarm even if a lockf call failed
            signal.alarm(0)
            # None means the old handler was not installed from Python and cannot be restored
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)

    def __exit__(self, type, value, traceback):
        """Release the locks by closing the lock files."""
        self.lock_file_dpkg.close()
        self.lock_file_fe.close()
        self.lock_file_update.close()

    def sighandler(self, signum, frame):
        """SIGALRM handler: abort the blocking lockf() call.

        Raises IOError explicitly instead of relying on EINTR. On Python 3.5+
        (PEP 475), an interrupted system call is retried automatically when the
        handler returns normally, so a no-op handler would never end the wait.
        """
        raise IOError("Timed out waiting for APT/dpkg locks")

if __name__ == "__main__":
    # "--do-not-lock" as the first argument is basic support for running without
    # locking; by default every apt/dpkg lock is held for the whole run of main().
    skip_locking = sys.argv[1:2] == ["--do-not-lock"]
    if skip_locking:
        main()
    else:
        with Locker():
            main()
