#!/usr/bin/python
# vim:ts=4 sw=4 expandtab
"""
Get a new SAS token from Azure Key Vault (AKV) and replace the old one in the apt source list.
Meant for use in Azure IaaS and HDInsight (HDI) environments.
"""
import sys
import os
import re
import tempfile
import base64
import httplib
import json
import syslog

# apt source list whose "deb ... https-sas://<SAS token>@<account>..." entry
# carries the SAS token this script replaces.
source_list_file = "/etc/apt/sources.list.d/substrate.list"

def scrub_sas(token, b64=False):
    """
    Mask the signature part of a SAS token so it can be logged.

    Everything after 'sig=' is replaced with '*', keeping the original length.
    If the token contains no 'sig=' marker, the whole token is masked, because
    there is no way to tell where the secret part starts.

    :param token: SAS token string (base64url-encoded when b64 is True)
    :param b64: decode the token before masking and re-encode the result
    :return: the masked token, re-encoded when b64 is True
    """
    if b64:
        token = base64.urlsafe_b64decode(token)
    marker = 'sig='
    start = token.find(marker)
    # find() returns -1 when the marker is absent; using it unchecked would
    # keep token[:3] in the clear.
    keep = start + len(marker) if start != -1 else 0
    token = token[:keep].ljust(len(token), '*')
    if b64:
        token = base64.urlsafe_b64encode(token)
    return token


def parse_sas_url(filename):
    """
    Find the first ``deb ... https-sas://<token>@<host>/...`` line in filename.

    :param filename: apt source list to scan
    :return: (token, account) where account is the first dot-separated label
             of the host part, or (None, None) when no line matches. If the
             file cannot be read, the error is logged to syslog and the
             process exits with status 1.
    """
    deb_line = re.compile(r'^deb .*https-sas:\/\/(.*)@(.+)\/.*$')
    try:
        with open(filename, 'r') as source:
            for raw in source:
                found = deb_line.match(raw.rstrip())
                if found is None:
                    continue
                sas_token, location = found.groups()
                return sas_token, location.split('.')[0]
    # pylint: disable=broad-except
    except Exception as e:
        syslog.syslog(syslog.LOG_ALERT, "Error opening {}: {}".format(filename, e))
        sys.exit(1)
    return None, None

def get_value_by_key(key, filename='/etc/default/gcp'):
    """
    Return the value of the first ``key = value`` line for key in filename.

    A line matches only when the text before its first '=' is exactly key
    (surrounding whitespace and an optional leading ``export `` ignored).
    Lines starting with '#' are skipped. As a result, commented-out settings
    and longer keys that merely contain key (``OtherKey = x``) do not match.
    The value is everything after the first '=', stripped of surrounding
    whitespace, so values that themselves contain '=' are returned intact.

    :param key: name of the setting to look up
    :param filename: file to read, /etc/default/gcp by default
    :return: the value string, or False if the key is absent or the file
             cannot be read (read errors are logged to syslog)
    """
    try:
        with open(filename, 'r') as f:
            for line in f:
                stripped = line.strip()
                if stripped.startswith('#'):
                    continue
                name, sep, value = stripped.partition('=')
                name = name.strip()
                if name.startswith('export '):
                    name = name[len('export '):].strip()
                if sep and name == key:
                    return value.strip()
    # pylint: disable=broad-except
    except Exception as e:
        syslog.syslog(syslog.LOG_ALERT, "Error finding {} from {}: {}, ".format(key, filename, e))
        return False
    return False


def get_bearer(hdinsight_quirk=False, timeout=30):
    """
    Get an auth bearer token for Key Vault (resource https://vault.azure.net).

    First asks the instance metadata endpoint at 169.254.169.254 for a
    managed-identity token. When 'MSIresourceID' is set in /etc/default/gcp,
    that identity is requested explicitly via the mi_res_id parameter.

    If the endpoint does not answer 200 and hdinsight_quirk is True, falls back
    to ADAL client-certificate authentication. This path uses the module-level
    ``enc_msi_setting`` ('clientId', 'tenantId', 'thumbprint') and the private
    key /var/lib/waagent/<thumbprint>.prv. Any other failure is logged to
    syslog and the process exits with status 1.

    :param hdinsight_quirk: allow the ADAL fallback described above
    :param timeout: socket timeout in seconds for the metadata request
    :return: the access token
    """
    msresid = get_value_by_key('MSIresourceID')
    msresid = "&mi_res_id=" + msresid if msresid else ''
    # Bound the wait: without a timeout an unresponsive metadata endpoint
    # would block this script indefinitely.
    con = httplib.HTTPConnection("169.254.169.254", timeout=timeout)
    try:
        con.request('GET', "/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fvault.azure.net" +
                    msresid, headers={'Metadata': 'true'})
        response = con.getresponse()
        if response.status == 200:
            return json.loads(response.read())["access_token"].encode("utf-8")
        if not hdinsight_quirk:
            syslog.syslog(syslog.LOG_ALERT, "Unexpected answer ({} : {}) happened while getting authorization token."
                          " Response: {}".format(response.status, response.reason, response.read()))
            sys.exit(1)
    finally:
        con.close()

    # Managed identity is not available (e.g. on HDInsight): authenticate to
    # AAD directly with the cluster's client certificate.
    try:
        with open('/var/lib/waagent/{0}.prv'.format(enc_msi_setting['thumbprint'])) as f:
            key = f.read()
        auth_context = adal.AuthenticationContext(
            "https://login.microsoftonline.com/{}".format(enc_msi_setting['tenantId']))
        auth_result = auth_context.acquire_token_with_client_certificate("https://vault.azure.net",
                                                                         enc_msi_setting['clientId'],
                                                                         key, enc_msi_setting['thumbprint'])
    # pylint: disable=broad-except
    except Exception as e:
        syslog.syslog(syslog.LOG_ALERT,
                      "Unexpected error ({}) happened while retrieving authorization token.".format(e))
        sys.exit(1)
    return auth_result['accessToken']


def get_tenant_id(hdinsight_quirk=True):
    """
    Get the Azure AD tenant ID of the identity this host authenticates as.

    The bearer token from get_bearer() is treated as a JWT: its second
    dot-separated segment is base64url-decoded as JSON and its 'tid' claim is
    returned. Any failure is logged to syslog and the process exits with
    status 1.

    :param hdinsight_quirk: passed to get_bearer() to allow its ADAL fallback
    :return: the tenant ID
    """
    try:
        token = get_bearer(hdinsight_quirk)
        claims_b64 = token.split('.')[1].encode("utf-8")
        # JWT segments are base64url with the '=' padding stripped; restore it
        # to the next multiple of 4 characters (len % 4 alone is wrong when the
        # remainder is 3, which needs 1 pad character).
        padding = '=' * (-len(claims_b64) % 4)
        tid = json.loads(base64.urlsafe_b64decode(claims_b64 + padding))['tid']
    # pylint: disable=broad-except
    except Exception as e:
        syslog.syslog(syslog.LOG_ALERT, "Failed to get Tenant ID: {}".format(e))
        sys.exit(1)
    return tid


def get_new_sas(kv_name='kaupstestkv', secret_name='gcplinuxdev', hdinsight_quirk=False, timeout=30):
    """
    Acquire a new SAS token from Azure Key Vault (AKV).

    Fetches secret secret_name from https://<kv_name>.vault.azure.net using a
    bearer token from get_bearer(). A non-200 answer is logged to syslog and
    the process exits with status 1.

    :param kv_name: Key Vault name
    :param secret_name: name of the secret that holds the SAS token
    :param hdinsight_quirk: passed to get_bearer() to allow its ADAL fallback
    :param timeout: socket timeout in seconds for the Key Vault request
    :return: the secret value, UTF-8 encoded
    """
    bearer = get_bearer(hdinsight_quirk)
    con = httplib.HTTPSConnection("{}.vault.azure.net".format(kv_name), timeout=timeout)
    try:
        con.request('GET', "/secrets/{}?api-version=2016-10-01".format(secret_name),
                    headers={'Authorization': "Bearer {}".format(bearer)})
        response = con.getresponse()
        if response.status != 200:
            syslog.syslog(syslog.LOG_ALERT, "Unexpected answer ({} : {}) happened while connecting to {}.vault.azure.net"
                          " Response: {}".format(response.status, response.reason, kv_name, response.read()))
            sys.exit(1)
        data = response.read()
    finally:
        # Runs on sys.exit() too, so the connection is never left open.
        con.close()
    return json.loads(data)["value"].encode("utf-8")


# Hosts without the config file or the apt source list are not set up for
# this repository; there is nothing to update.
if not os.path.isfile('/etc/default/gcp') or not os.path.isfile(source_list_file):
    sys.exit(0)

RepoComponent = get_value_by_key('RepoComponent')
if not RepoComponent:
    # Record why we give up; otherwise this exit leaves no trace in the logs.
    syslog.syslog(syslog.LOG_ALERT, "RepoComponent is not set in /etc/default/gcp, cannot continue")
    sys.exit(1)

# Mapping for Tenant ID <=> AKV: the Key Vault to fetch the SAS secret from
# for hosts authenticating in each Azure AD tenant.
akvmap = {
    'cdc5aeea-15c5-4db6-b079-fcadd2505dc2': 'kevlarlinuxprod',
    '124edf19-b350-4797-aefc-3206115ffdb3': 'kevlarlinuxgme',
    '72f988bf-86f1-41af-91ab-2d7cd011db47': 'kevlarlinuxmsint'
}

# Fallback Key Vault name per RepoComponent, used when no vault was found
# for the tenant.
storagemap = {
    'stable': 'gcplinux',
    'testing': 'kevlarlinuxgme',
    'unstable': 'kevlarlinuxgme'
}

try:
    # Probe whether we can fall back to ADAL and talk to AAD directly when
    # MSI is not supported; on HDInsight clusters this is the only supported
    # way. Any failure below just disables the fallback.
    # pylint: disable=import-error
    from hdinsight_common import ClusterManifestParser
    import adal
    cluster_manifest = ClusterManifestParser.parse_local_manifest()
    msi_setting = json.loads(cluster_manifest.settings['managedServiceIdentity'])
    enc_msi_arm_resource_id = get_value_by_key('MSIresourceID')
    if not enc_msi_arm_resource_id:
        # Without an identity to look up there is no ADAL fallback. (A
        # sys.exit() here would be caught by the handler below anyway.)
        raise KeyError('MSIresourceID')
    try:
        enc_msi_setting = msi_setting[enc_msi_arm_resource_id]
    except KeyError:
        # pylint: disable=no-member
        enc_msi_setting = msi_setting[enc_msi_arm_resource_id.lower()]
    hdinsight_quirk_available = True
# Catch ordinary errors only, so KeyboardInterrupt/SystemExit still propagate.
# pylint: disable=broad-except
except Exception:
    hdinsight_quirk_available = False

# The Key Vault secret is named after the storage account in the source list.
old_sas, storagename = parse_sas_url(source_list_file)
if old_sas is None:
    # Without an existing entry there is no secret name to fetch and no line
    # to rewrite.
    syslog.syslog(syslog.LOG_ALERT, "No https-sas:// entry found in {}".format(source_list_file))
    sys.exit(1)

# Unknown tenants get None here, so the RepoComponent fallback below is
# reachable instead of the script dying with a KeyError.
akv = akvmap.get(get_tenant_id(hdinsight_quirk_available))

if not akv:
    if RepoComponent in storagemap:
        akv = storagemap[RepoComponent]
    else:
        syslog.syslog(syslog.LOG_INFO, "No defaults for RepoComponent {}, nothing to update".format(RepoComponent))
        sys.exit(0)

new_sas = get_new_sas(akv, storagename, hdinsight_quirk_available)

if old_sas == new_sas:
    syslog.syslog(syslog.LOG_INFO, "Newest SAS token is {}, no update is needed".format(scrub_sas(old_sas)))
    sys.exit(0)
# Write the updated list to a temporary file in the same directory, then
# rename it over the original so readers never see a partially written list.
try:
    newf = tempfile.NamedTemporaryFile(mode='w+b', prefix=os.path.basename(source_list_file),
                                       dir=os.path.dirname(source_list_file), delete=False)
    # pylint: disable=broad-except
except Exception as e:
    syslog.syslog(syslog.LOG_ALERT, "Making a new file failed with an exception: {}".format(e))
    sys.exit(1)

try:
    oldf = open(source_list_file, 'r')
    # pylint: disable=broad-except
except Exception as e:
    syslog.syslog(syslog.LOG_ALERT, "Error opening {}: {}".format(source_list_file, e))
    newf.close()
    os.remove(newf.name)
    sys.exit(1)

try:
    with oldf:
        for line in oldf:
            if old_sas in line:
                syslog.syslog(syslog.LOG_INFO, "Updating old SAS ({}) with {}".format(scrub_sas(old_sas),
                                                                                      scrub_sas(new_sas)))
                line = line.replace(old_sas, new_sas)
            newf.write(line)
    # Ensure the new contents are on disk before the rename makes them live.
    newf.flush()
    os.fsync(newf.fileno())
    # The temporary file is created readable only by its owner; keep the
    # original list's permission bits instead.
    os.chmod(newf.name, os.stat(source_list_file).st_mode & 0o7777)
    # pylint: disable=broad-except
except Exception as e:
    syslog.syslog(syslog.LOG_ALERT, "Writing {} failed with an exception: {}".format(newf.name, e))
    newf.close()
    os.remove(newf.name)
    sys.exit(1)

newf_name = newf.name
newf.close()

os.rename(newf_name, source_list_file)