#!/usr/bin/python
# vim:ts=4 sw=4 expandtab
"""
Fetch a new SAS token from Azure Key Vault (AKV) and replace the old one in /etc/apt/sources.list.d/substrate.list.
Meant for use in Azure IaaS and HDInsight (HDI) environments.
"""
import sys
import os
import tempfile
import base64
import httplib
import json
import syslog


def scrub_sas(token, b64=False):
    """
    Return a copy of a SAS token with its signature masked, for logging.

    Everything after 'sig=' is replaced by '*' characters, keeping the
    original length. If the token has no 'sig=' part the whole token is
    masked, because there is no way to tell which part is secret.

    :param token: SAS token in query-string form
    :param b64: when True, token is URL-safe base64; it is decoded before
        scrubbing and the scrubbed result is re-encoded
    :return: the scrubbed token (re-encoded when b64 is True)
    """
    if b64:
        token = base64.urlsafe_b64decode(token)
    start = token.find('sig=')
    # Keep the prefix up to and including 'sig='. Without the -1 check a token
    # lacking 'sig=' would have its first three characters left unmasked.
    keep = start + 4 if start != -1 else 0
    token = token[:keep].ljust(len(token), '*')
    if b64:
        token = base64.urlsafe_b64encode(token)
    return token


def read_old_sas(filename='/etc/apt/sources.list.d/substrate.list'):
    """
    Read the first SAS token from an apt sources file.

    The token comes from the first non-comment line containing 'https-sas://'.
    It is the last '/'-separated segment of the text before the first '@'.

    :param filename: sources file to read (default: the substrate list)
    :return: the SAS token string
    Logs to syslog and exits with status 1 if the file cannot be read or no
    token is found, so callers never receive None.
    """
    try:
        with open(filename, 'r') as f:
            lines = f.readlines()
    # pylint: disable=broad-except
    except Exception as e:
        syslog.syslog(syslog.LOG_ALERT, "Error opening {}: {}".format(filename, e))
        sys.exit(1)
    for l in lines:
        # apt ignores commented-out entries, so a token there is not the live one.
        if l.lstrip().startswith('#') or l.find('https-sas://') == -1:
            continue
        head = l.split('@', 1)[0]
        if '/' in head:
            return head.rsplit('/', 1)[1]
    syslog.syslog(syslog.LOG_ALERT, "No SAS token found in {}".format(filename))
    sys.exit(1)


def get_value_by_key(key, filename='/etc/default/gcp'):
    """
    Return the value of the first line in filename that defines key.

    Lines are expected in 'key = value' form; words before the key (such as
    a leading 'export') are tolerated. Comment lines starting with '#' are
    ignored. The key must equal the last word left of the first '=', so 'Foo'
    does not match 'FooBar = x'. The value is everything after the first '=',
    stripped of surrounding whitespace, so values may themselves contain '='.

    :return: the value string, or False if the key is absent or the file
        cannot be read (read errors are logged to syslog)
    """
    try:
        with open(filename, 'r') as f:
            lines = f.readlines()
    # pylint: disable=broad-except
    except Exception as e:
        syslog.syslog(syslog.LOG_ALERT, "Error finding {} from {}: {}".format(key, filename, e))
        return False
    for l in lines:
        stripped = l.strip()
        if not stripped or stripped.startswith('#') or '=' not in stripped:
            continue
        name, value = stripped.split('=', 1)
        words = name.split()
        if words and words[-1] == key:
            return value.strip()
    return False


def get_bearer(hdinsight_quirk=False, timeout=30):
    """
    Get an OAuth bearer token for Key Vault from the metadata endpoint at
    169.254.169.254.

    When MSIresourceID is set in /etc/default/gcp it is passed as mi_res_id
    to select that identity.

    :param hdinsight_quirk: when True, any failure is reported by returning
        False so the caller can fall back to another auth method; when False,
        failures are logged to syslog and the process exits with status 1
    :param timeout: socket timeout in seconds for the metadata request, so a
        hung endpoint cannot block the script forever
    :return: the access token as a UTF-8 encoded string, or False
    """
    msresid = get_value_by_key('MSIresourceID')
    msresid = "&mi_res_id=" + msresid if msresid else ''
    con = httplib.HTTPConnection("169.254.169.254", timeout=timeout)
    headers = {'Metadata': 'true'}
    try:
        con.request('GET', "/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fvault.azure.net" +
                    msresid, headers=headers)
        response = con.getresponse()
        status, reason, data = response.status, response.reason, response.read()
    except (httplib.HTTPException, IOError) as e:
        # Network/protocol failure (socket errors are IOError subclasses).
        if hdinsight_quirk:
            return False
        syslog.syslog(syslog.LOG_ALERT, "Error while getting authorization token: {}".format(e))
        sys.exit(1)
    finally:
        con.close()
    if status != 200:
        if hdinsight_quirk:
            return False
        syslog.syslog(syslog.LOG_ALERT, "Unexpected answer ({} : {}) happened while getting authorization token."
                      " Response: {}".format(status, reason, data))
        sys.exit(1)
    return json.loads(data)["access_token"].encode("utf-8")


# pylint: disable=unused-argument
def aad_auth_callback(server, resource, scope):
    """
    Authentication callback handed to KeyVaultAuthentication.

    Uses ADAL's client-certificate flow. The client id and certificate
    thumbprint come from the module-level enc_msi_setting; the private key is
    read from /var/lib/waagent/<thumbprint>.prv.

    :param server: authority passed to adal.AuthenticationContext
    :param resource: resource the token is requested for
    :param scope: unused
    :return: (token type, access token) tuple
    """
    settings = enc_msi_setting
    client_id = settings['clientId']
    thumbprint = settings['thumbprint']
    key_path = '/var/lib/waagent/{0}.prv'.format(thumbprint)
    with open(key_path) as key_file:
        private_key = key_file.read()
    token = adal.AuthenticationContext(server).acquire_token_with_client_certificate(
        resource, client_id, private_key, thumbprint)
    return token['tokenType'], token['accessToken']


def get_new_sas_hdinsight_way(enc_vault_uri, secret_name):
    """
    Fetch the SAS secret through the Key Vault SDK instead of the REST API.

    Authentication goes through aad_auth_callback. This is the fallback used
    where no metadata bearer token is available (e.g. HDInsight).

    :param enc_vault_uri: vault URI, e.g. https://<name>.vault.azure.net
    :param secret_name: name of the secret holding the SAS token
    :return: the secret's value (latest version)
    """
    credentials = KeyVaultAuthentication(aad_auth_callback)
    client = KeyVaultClient(credentials)
    # An empty version string asks for the latest version of the secret.
    return client.get_secret(enc_vault_uri, secret_name, '').value


def get_new_sas(kv_name='kaupstestkv', secret_name='gcplinuxdev', hdinsight_quirk=False, timeout=30):
    """
    Acquire the current SAS token from Azure Key Vault.

    First asks get_bearer for a token. If it returns nothing (possible only
    when hdinsight_quirk is set), falls back to get_new_sas_hdinsight_way.
    Otherwise the secret is read through the Key Vault REST API with that
    bearer token.

    :param kv_name: Key Vault name; the host is <kv_name>.vault.azure.net
    :param secret_name: name of the secret holding the SAS token
    :param hdinsight_quirk: allow the HDInsight fallback when no bearer
    :param timeout: socket timeout in seconds for the Key Vault request
    :return: the secret value (UTF-8 encoded on the REST path)
    Logs to syslog and exits with status 1 on any failure.
    """
    bearer = get_bearer(hdinsight_quirk)
    if not bearer:
        try:
            return get_new_sas_hdinsight_way("https://{}.vault.azure.net".format(kv_name), secret_name)
        # pylint: disable=broad-except
        except Exception as e:
            syslog.syslog(syslog.LOG_ALERT, "Getting new SAS token failed: {}".format(e))
            sys.exit(1)
    con = httplib.HTTPSConnection("{}.vault.azure.net".format(kv_name), timeout=timeout)
    try:
        # Reuse the bearer obtained above: fetching a second one cost an extra
        # metadata round-trip and ignored hdinsight_quirk.
        headers = {'Authorization': "Bearer {}".format(bearer)}
        con.request('GET', "/secrets/{}?api-version=2016-10-01".format(secret_name), headers=headers)
        response = con.getresponse()
        if response.status != 200:
            syslog.syslog(syslog.LOG_ALERT, "Unexpected answer ({} : {}) happened while connecting to {}.vault.azure.net"
                          " Response: {}".format(response.status, response.reason, kv_name, response.read()))
            sys.exit(1)
        data = response.read()
    except (httplib.HTTPException, IOError) as e:
        # Network/protocol failure (socket and SSL errors are IOError subclasses).
        syslog.syslog(syslog.LOG_ALERT, "Connecting to {}.vault.azure.net failed: {}".format(kv_name, e))
        sys.exit(1)
    finally:
        con.close()
    return json.loads(data)["value"].encode("utf-8")


# Nothing to do on hosts that are not set up for the substrate repository.
if not os.path.isfile('/etc/default/gcp') or not os.path.isfile('/etc/apt/sources.list.d/substrate.list'):
    sys.exit(0)

RepoComponent = get_value_by_key('RepoComponent')
if not RepoComponent:
    syslog.syslog(syslog.LOG_ALERT, "RepoComponent is not set in /etc/default/gcp")
    sys.exit(1)

# Mapping RepoComponent => Key Vault name (the secret name is the RepoComponent itself)
akvmap = {
    'stable': 'gcplinux',
    'testing': 'gcplinuxtesting',
    'unstable': 'gcplinuxdev'
}

# Reject unknown components here instead of failing later with a KeyError.
if RepoComponent not in akvmap:
    syslog.syslog(syslog.LOG_ALERT, "Unknown RepoComponent {}; expected one of: {}".format(
        RepoComponent, ', '.join(sorted(akvmap))))
    sys.exit(1)

# Decide whether the ADAL-based fallback (talking to AAD directly) is usable.
# It is for environments such as HDInsight clusters, where the metadata token
# endpoint may not work. It is enabled only if the Key Vault SDK, the HDInsight
# helpers and ADAL import, the cluster manifest has MSI settings, and
# MSIresourceID is configured and present in those settings.
hdinsight_quirk_available = False
try:
    # pylint: disable=import-error
    from azure.keyvault import KeyVaultClient, KeyVaultAuthentication
    from hdinsight_common import ClusterManifestParser
    import adal
    cluster_manifest = ClusterManifestParser.parse_local_manifest()
    msi_setting = json.loads(cluster_manifest.settings['managedServiceIdentity'])
    enc_msi_arm_resource_id = get_value_by_key('MSIresourceID')
    # Without MSIresourceID the fallback cannot pick credentials. The old
    # sys.exit(1) here was swallowed by a bare except, so the effective
    # behaviour was "fallback unavailable"; keep that, explicitly.
    if enc_msi_arm_resource_id:
        try:
            enc_msi_setting = msi_setting[enc_msi_arm_resource_id]
        except KeyError:
            # pylint: disable=no-member
            enc_msi_setting = msi_setting[enc_msi_arm_resource_id.lower()]
        hdinsight_quirk_available = True
# pylint: disable=broad-except
except Exception as e:
    # Expected on non-HDInsight hosts (e.g. ImportError); not an error.
    syslog.syslog(syslog.LOG_DEBUG, "HDInsight fallback unavailable: {}".format(e))


old_sas = read_old_sas()
if not old_sas:
    # Without the current token there is nothing to compare with or replace
    # (str.find/replace below would fail on None).
    syslog.syslog(syslog.LOG_ALERT, "No SAS token found in /etc/apt/sources.list.d/substrate.list")
    sys.exit(1)
new_sas = get_new_sas(akvmap[RepoComponent], RepoComponent, hdinsight_quirk_available)

if old_sas == new_sas:
    syslog.syslog(syslog.LOG_INFO, "Newest SAS token is {}, no update is needed".format(scrub_sas(old_sas)))
    sys.exit(0)

substrate_list = "/etc/apt/sources.list.d/substrate.list"

# Read the current list before creating the temporary file, so a read failure
# cannot leave a stray temporary file behind.
try:
    with open(substrate_list, 'r') as oldf:
        old_lines = oldf.readlines()
# pylint: disable=broad-except
except Exception as e:
    syslog.syslog(syslog.LOG_ALERT, "Error opening /etc/apt/sources.list.d/substrate.list: {}".format(e))
    sys.exit(1)

# Write the updated list to a temporary file in the same directory, then rename
# it over the original so the list is swapped in a single rename.
try:
    newf = tempfile.NamedTemporaryFile(mode='w+b', prefix='substrate.list', dir='/etc/apt/sources.list.d', delete=False)
# pylint: disable=broad-except
except Exception as e:
    syslog.syslog(syslog.LOG_ALERT, "Making a new file failed with an exception: {}".format(e))
    sys.exit(1)

try:
    for line in old_lines:
        if line.find(old_sas) != -1:
            syslog.syslog(syslog.LOG_INFO, "Updating old SAS ({}) with {}".format(scrub_sas(old_sas),
                                                                              scrub_sas(new_sas)))
            line = line.replace(old_sas, new_sas)
        newf.write(line)
    # Make sure the contents reach the disk before the rename makes them visible.
    newf.flush()
    os.fsync(newf.fileno())
    newf.close()
    # The temporary file is created with mode 0600; carry over the original
    # list's permissions instead of silently tightening them.
    os.chmod(newf.name, os.stat(substrate_list).st_mode & 0o7777)
    os.rename(newf.name, substrate_list)
# pylint: disable=broad-except
except Exception as e:
    syslog.syslog(syslog.LOG_ALERT, "Updating {} failed: {}".format(substrate_list, e))
    newf.close()
    try:
        os.remove(newf.name)
    except OSError:
        # Best effort cleanup: the temporary file may already be gone.
        pass
    sys.exit(1)
