#!/usr/bin/env python3

"""
APT method to enable shared access signatures (SAS) authenticated HTTPS
requests to Azure Blob Storage using "https-sas://" scheme in APT data sources.

This method wraps around HTTPS method and acts like MITM between APT and HTTPS
method and rewrites lines starting with "URI:".

The SAS token can be provided in multiple ways (ordered by priority):

1. Inline, e.g. "deb [arch=amd64] https-sas://sv=2018-03-28...&sig=RVhBTVBMRQ%3D%3D@example.blob.core.windows.net/ ...".
2. File named as hostname in "/etc/apt/sastoken.d/", e.g. "/etc/apt/sastoken.d/example.blob.core.windows.net".
3. File "/etc/apt/sastoken".
"""

import logging
import logging.handlers
import os.path
import re
import subprocess
import sys

from fcntl import F_GETFL, F_SETFL, fcntl
from os import O_NONBLOCK, environ, kill
from select import select
from signal import SIGCHLD, SIGINT, SIGTERM, signal
from urllib.parse import unquote, urlparse


# Module-wide logger. Messages always go to syslog when it is available; on
# regular operation warnings and errors are additionally written to stderr.
logger = logging.getLogger("apt-transport-https-sas")
try:
    syslog_handler = logging.handlers.SysLogHandler(address="/dev/log")
except OSError:
    # SysLogHandler connects to the socket in its constructor and raises
    # OSError when nothing is listening on /dev/log. Missing syslog must not
    # prevent the method from starting, so carry on without it.
    syslog_handler = None
else:
    syslog_handler.setFormatter(logging.Formatter("%(name)s.%(funcName)s -- %(levelname)s -- %(message)s"))
    syslog_handler.setLevel(logging.DEBUG)
    logger.addHandler(syslog_handler)

# Debug logging is switched on by the mere presence of HTTPS_SAS_DEBUG in the
# environment; its value is not inspected.
_debug_enabled = "HTTPS_SAS_DEBUG" in environ
logger.setLevel(logging.DEBUG if _debug_enabled else logging.WARNING)
if syslog_handler is None or not _debug_enabled:
    # On regular operation write errors to stderr. The same handler is also
    # installed in debug mode when syslog is unavailable, so that warnings and
    # errors are never silently dropped.
    stderr_handler = logging.StreamHandler()
    stderr_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
    stderr_handler.setLevel(logging.WARNING)
    logger.addHandler(stderr_handler)


# Placeholder that may appear in place of an inline SAS token in a URI. When it
# is found, the real token is looked up elsewhere and the placeholder is put
# back into the URI lines returned to APT.
_INLINE_SASTOKEN_COMPAT = "COMPATIBILITY_PLACEHOLDER"


class AptMethodHttpsSas:
    """
    Class for rewriting URI lines for HTTPS method and APT. All the decisions
    are made using full URI line and for clarity it is required argument for
    all private functions.

    An instance keeps state between calls: the SAS token is resolved once and
    then reused, and the flags below remember how the token was supplied so
    that URI lines returned to APT can be rebuilt in the same form.
    """

    # This will be set to False by _parse_sastoken if SAS token is not inline
    # and URI line for APT will be reconstructed without inline SAS token.
    _inline_sastoken = True

    # Set to True if inline SAS token is compatibility placeholder. This is
    # required to support safe package downgrade.
    _inline_sastoken_compat = False

    # SAS token cache; filled by the first successful _get_sastoken call.
    _sastoken = None

    # Matches a plain "https://" URI (as sent back by the HTTPS method) but not
    # "https-sas://". In such URIs "@" has no special meaning, the SAS token is
    # in the query string.
    _HTTPS_URI_RE = re.compile(r"(?:^|\s)https:\/\/")

    def https(self, arg_line):
        """
        Process lines sent to HTTPS method (in subprocess). Rewrite lines
        starting with "URI:" by replacing https-sas:// with https:// and adding
        SAS token to the end of URI.

        Returns the (possibly rewritten) line encoded as bytes. Exits the
        process (via _get_sastoken) if no SAS token can be found.
        """
        if arg_line.startswith("URI: "):
            logger.debug("Rewrite: %s", arg_line)
            urn = self._parse_urn(arg_line)
            sastoken = self._get_sastoken(arg_line)
            return_line = "URI: https://{}?{}\n".format(urn, sastoken)
        else:
            return_line = arg_line
        logger.debug("Return: %s", return_line)
        return return_line.encode()

    def apt(self, arg_line):
        """
        Process lines sent back to APT. This is reconstruction of line APT sent
        to this method and we need to add back https-sas:// and inline SAS
        token, if needed.

        Returns the (possibly rewritten) line as str.
        """
        if arg_line.startswith("URI: "):
            logger.debug("Rewrite: %s", arg_line)
            return_line = "URI: https-sas://"
            urn = self._parse_urn(arg_line)
            if self._inline_sastoken or self._inline_sastoken_compat:
                if self._inline_sastoken_compat:
                    # APT sent the placeholder, so it expects it back.
                    sastoken = _INLINE_SASTOKEN_COMPAT
                else:
                    sastoken = self._get_sastoken(arg_line)
                return_line += "{}@{}\n".format(sastoken, urn)
            else:
                return_line += "{}\n".format(urn)
        else:
            return_line = arg_line
        logger.debug("Return: %s", return_line)
        return return_line

    def _parse_urn(self, arg_uri_line):
        """
        URN is URI without resource access method (e.g. https://), query
        parameters and fragment identifiers.

        There are 3 possible URI line formats:
        1. https-sas://sv=2018-03-28...&sig=RVhBTVBMRQ%3D%3D@example.blob.core.windows.net/path/to/file
        2. https://example.blob.core.windows.net/path/to/file?sv=2018-03-28...&sig=RVhBTVBMRQ%3D%3D
        3. https-sas://example.blob.core.windows.net/path/to/file

        And URN for all 3 is "example.blob.core.windows.net/path/to/file".

        Returns None if the line contains nothing that looks like a URI.

        NOTE: for https-sas:// lines the first "@" is taken as the end of the
        inline SAS token, so a format 3 URI whose path contains "@" is still
        read as format 1.
        """
        logger.debug("Argument: %s", arg_uri_line)
        return_urn = None
        if self._HTTPS_URI_RE.search(arg_uri_line):
            # Format 2: "@" may only be part of the path, never a separator.
            match = re.search(r":\/\/(.+)\?", arg_uri_line) \
                or re.search(r":\/\/(.+)", arg_uri_line)
        else:
            match = re.search(r"@(.+)", arg_uri_line) \
                or re.search(r":\/\/(.+)\?", arg_uri_line) \
                or re.search(r":\/\/(.+)", arg_uri_line)
        if match:
            return_urn = match.group(1)
        logger.debug("Return: %s", return_urn)
        return return_urn

    def _get_sastoken(self, arg_uri_line):
        """
        SAS token can be inline or in file. If SAS token is not found, then
        exit with error, because there's nothing left to do without it.

        The token found first is cached on the instance and returned for every
        later call regardless of the URI line passed in.
        """
        logger.debug("Argument: %s", arg_uri_line)
        return_sastoken = self._sastoken \
            or self._parse_sastoken(arg_uri_line) \
            or self._read_sastoken_from_file(arg_uri_line)
        if return_sastoken:
            # If SAS token is inline, then decoding percent-encoded sequences
            # is usually handled by APT, but there might be weird cases when it
            # fails and we also need to do it if reading SAS token from file.
            sastoken_unquoted = unquote(return_sastoken)
            if len(sastoken_unquoted) < len(return_sastoken):
                logger.debug("Decoded percent-encoded sequences in SAS token")
                return_sastoken = sastoken_unquoted
            self._sastoken = return_sastoken
        else:
            logger.error("SAS token not found")
            sys.exit(1)
        logger.debug("Return: %s", return_sastoken)
        return return_sastoken

    def _parse_sastoken(self, arg_uri_line):
        """
        Parse inline SAS token from URI line.

        SAS token example:
        sv=2018-03-28&ss=b&srt=co&sp=r&se=1970-01-01T00:00:00Z&spr=https&sig=RVhBTVBMRQ%3D%3D

        Parse inline SAS token from URI line, which can be between "://" and
        the first "@" (APT -> HTTPS-SAS) or after "?" (HTTPS -> HTTPS-SAS). See
        _parse_urn for possible URI line formats. If SAS token is not in URI
        line, set self._inline_sastoken to False.

        Returns the token, or None if there is none or if it is the
        compatibility placeholder (self._inline_sastoken_compat is then set).
        """
        logger.debug("Argument: %s", arg_uri_line)
        return_sastoken = None
        if self._HTTPS_URI_RE.search(arg_uri_line):
            # Plain https:// URI: the token can only be the query string.
            match = re.search(r"\?(.+)", arg_uri_line)
        else:
            # Non-greedy so that the token ends at the same "@" at which
            # _parse_urn starts the URN.
            match = re.search(r":\/\/(.+?)@", arg_uri_line) \
                or re.search(r"\?(.+)", arg_uri_line)
        if match:
            return_sastoken = match.group(1)
            if return_sastoken == _INLINE_SASTOKEN_COMPAT:
                self._inline_sastoken_compat = True
                return_sastoken = None
                logger.debug("Inline SAS token is compatibility placeholder")
        else:
            self._inline_sastoken = False
            logger.debug("SAS token is not inline")
        logger.debug("Return: %s", return_sastoken)
        return return_sastoken

    def _read_sastoken_from_file(self, arg_uri_line):
        """
        Read SAS token from file.

        Returns the file content stripped of surrounding whitespace, or None
        if no SAS token file exists or it cannot be read.
        """
        logger.debug("Argument: %s", arg_uri_line)
        return_sastoken = None
        sastoken_file_path = self._get_sastoken_file_path(arg_uri_line)
        if sastoken_file_path:
            try:
                with open(sastoken_file_path, encoding="utf-8") as sastoken_file_handle:
                    return_sastoken = sastoken_file_handle.read().strip()
            except (OSError, UnicodeDecodeError) as error:
                # The file exists but is not usable (e.g. permission denied).
                # Report it and let the caller handle the missing token
                # instead of dying with a traceback.
                logger.error("Cannot read SAS token file %s: %s", sastoken_file_path, error)
        logger.debug("Return: %s", return_sastoken)
        return return_sastoken

    def _get_sastoken_file_path(self, arg_uri_line):
        """
        Get SAS token file path: the file named after the host of the URI in
        "/etc/apt/sastoken.d/" if it exists, otherwise "/etc/apt/sastoken" if
        it exists, otherwise None.
        """
        logger.debug("Argument: %s", arg_uri_line)
        return_sastoken_file_path = None
        sastoken_file_paths = ["/etc/apt/sastoken"]
        file_path_from_netloc = None
        try:
            urn = self._parse_urn(arg_uri_line)
            # Without URN (or host in it) there is no host-specific file to
            # look for; only the generic one remains.
            if urn:
                netloc_from_uri = urlparse("https://{}".format(urn)).netloc
                if netloc_from_uri:
                    file_path_from_netloc = os.path.join("/etc/apt/sastoken.d", netloc_from_uri)
        except Exception as error:
            logger.error(error)
        if file_path_from_netloc:
            # Host-specific file has priority over the generic one.
            sastoken_file_paths.insert(0, file_path_from_netloc)
        for sastoken_file_path_in_list in sastoken_file_paths:
            if os.path.isfile(sastoken_file_path_in_list):
                return_sastoken_file_path = sastoken_file_path_in_list
                break
        logger.debug("Return: %s", return_sastoken_file_path)
        return return_sastoken_file_path


if __name__ == "__main__":
    # Entry point: start the stock HTTPS method as a subprocess and relay
    # lines between APT (our stdin/stdout) and that subprocess, passing every
    # line through AptMethodHttpsSas on the way.

    # Ignore "consider-using-with" if running with newer pylint and ignore
    # "bad-option-value" because older pylint doesn't know the former.
    # pylint: disable=bad-option-value,consider-using-with
    https = subprocess.Popen(
        ["/usr/lib/apt/methods/https"],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=sys.stderr,
        shell=False,
        close_fds=True,
        bufsize=65536)
    # The SIGCHLD signal is sent to the parent of a child process when it
    # exits, is interrupted, or resumes after being interrupted. If HTTPS
    # subprocess exits supposedly prematurely, let's exit HTTPS-SAS too.
    signal(SIGCHLD, lambda *_: sys.exit(1))
    # If HTTPS-SAS gets interrupted, send SIGINT to HTTPS subprocess too.
    signal(SIGINT, lambda *_: kill(https.pid, SIGINT))
    # Set HTTPS subprocess and HTTPS-SAS as non-blocking, because we need to
    # process every line from stdout (as they come) and write to stdin of
    # running processes.
    fcntl(sys.stdin, F_SETFL, fcntl(sys.stdin, F_GETFL) | O_NONBLOCK)
    fcntl(https.stdout, F_SETFL, fcntl(https.stdout, F_GETFL) | O_NONBLOCK)
    # Number of consecutive select() timeouts seen so far.
    timeout_counter = 0
    # One rewriter instance for the whole run: it caches the SAS token and
    # remembers how the token was supplied.
    https_sas = AptMethodHttpsSas()
    # Relay until the HTTPS subprocess has terminated.
    while https.poll() is None:
        # The return value (of select) is a triple of lists of objects that are
        # ready. When the timeout is reached without a file descriptor becoming
        # ready, three empty lists are returned.
        read, _, _ = select([sys.stdin, https.stdout], [], [], 600)
        if not read:
            timeout_counter += 1
            # Let's wait another 2*10 minutes (600 seconds = 10 minutes).
            if timeout_counter < 3:
                # Let's flush too just in case there's something stuck in buffer.
                sys.stdout.flush()
                https.stdin.flush()
                logger.warning("select() timed out, waiting for another 10 minutes.")
                continue
            # HTTPS subprocess probably got stuck with something.
            # For example very slow download speed from blob storage.
            # Let's terminate HTTPS subprocess and exit with error.
            logger.error("Have been waiting for 30 minutes, exiting.")
            kill(https.pid, SIGTERM)
            sys.exit(1)
        # Reset counter if we succeed.
        timeout_counter = 0
        for i in read:
            # APT -> HTTPS-SAS -> HTTPS
            if i.fileno() == sys.stdin.fileno():
                # Drain everything that can currently be read from APT.
                while https.poll() is None:
                    try:
                        # Read line from APT.
                        line = sys.stdin.readline()
                        # An empty result ends this drain pass; presumably it
                        # means nothing more is pending or EOF -- NOTE(review):
                        # confirm behaviour with the non-blocking stdin.
                        if len(line) == 0:
                            break
                    except Exception:
                        # Any read error also ends the pass; push what has
                        # been written to HTTPS so far before leaving.
                        https.stdin.flush()
                        break
                    # Process line from APT and write it to HTTPS.
                    https.stdin.write(https_sas.https(line))
            # HTTPS -> HTTPS-SAS -> APT
            if i.fileno() == https.stdout.fileno():
                # Drain everything that can currently be read from HTTPS.
                while https.poll() is None:
                    try:
                        # Read line from HTTPS. The pipe delivers bytes, so
                        # decode before processing.
                        line = https.stdout.readline().decode()
                        if len(line) == 0:
                            break
                    except Exception:
                        # Any read or decode error ends the pass; push what
                        # has been written to APT so far before leaving.
                        sys.stdout.flush()
                        break
                    # Process line from HTTPS and write it to APT.
                    sys.stdout.write(https_sas.apt(line))
        # Make sure both sides see everything relayed in this round.
        sys.stdout.flush()
        https.stdin.flush()
