"""
schema1 implements pure data transformations according to the Docker Manifest v2.1 Specification.

https://github.com/docker/distribution/blob/master/docs/spec/manifest-v2-1.md
"""

import hashlib
import json
import logging

from collections import namedtuple, OrderedDict
from datetime import datetime

from jwkest.jws import SIGNER_ALGS, keyrep, BadSignature
from jwt.utils import base64url_encode, base64url_decode

from digest import digest_tools
from image.docker import ManifestException
from image.docker.v1 import DockerV1Metadata


logger = logging.getLogger(__name__)


# Media types a schema 1 manifest may be served or accepted under.
DOCKER_SCHEMA1_MANIFEST_CONTENT_TYPE = 'application/vnd.docker.distribution.manifest.v1+json'
DOCKER_SCHEMA1_SIGNED_MANIFEST_CONTENT_TYPE = 'application/vnd.docker.distribution.manifest.v1+prettyjws'
DOCKER_SCHEMA1_CONTENT_TYPES = set([
  DOCKER_SCHEMA1_MANIFEST_CONTENT_TYPE,
  DOCKER_SCHEMA1_SIGNED_MANIFEST_CONTENT_TYPE,
])

# JSON keys of a signature block and of its decoded `protected` header.
DOCKER_SCHEMA1_SIGNATURES_KEY = 'signatures'
DOCKER_SCHEMA1_HEADER_KEY = 'header'
DOCKER_SCHEMA1_SIGNATURE_KEY = 'signature'
DOCKER_SCHEMA1_PROTECTED_KEY = 'protected'
DOCKER_SCHEMA1_FORMAT_LENGTH_KEY = 'formatLength'
DOCKER_SCHEMA1_FORMAT_TAIL_KEY = 'formatTail'

# JSON keys of the manifest body itself.
DOCKER_SCHEMA1_REPO_NAME_KEY = 'name'
DOCKER_SCHEMA1_REPO_TAG_KEY = 'tag'
DOCKER_SCHEMA1_ARCH_KEY = 'architecture'
DOCKER_SCHEMA1_FS_LAYERS_KEY = 'fsLayers'
DOCKER_SCHEMA1_BLOB_SUM_KEY = 'blobSum'
DOCKER_SCHEMA1_HISTORY_KEY = 'history'
DOCKER_SCHEMA1_V1_COMPAT_KEY = 'v1Compatibility'
DOCKER_SCHEMA1_SCHEMA_VER_KEY = 'schemaVersion'

# strftime pattern for the UTC `time` field written into the protected header.
_ISO_DATETIME_FORMAT_ZULU = '%Y-%m-%dT%H:%M:%SZ'

# JWS algorithm name used when this module signs a manifest.
_JWS_SIGNING_ALGORITHM = 'RS256'


class MalformedSchema1Manifest(ManifestException):
  """
  Signals that a manifest violates a requirement of the Docker Manifest v2.1 Specification
  (unparseable JSON, missing or inconsistent fields, bad digests, and so on).
  """


class InvalidSchema1Signature(ManifestException):
  """
  Signals that a signature on a signed Docker 2.1 manifest could not be verified.
  """


class Schema1Layer(namedtuple('Schema1Layer', 'digest v1_metadata raw_v1_metadata')):
  """
  One layer of a manifest: an fsLayers entry joined with its matching history entry.

  Fields:
    digest: the layer's parsed blobSum digest.
    v1_metadata: the values extracted from the layer's v1Compatibility string.
    raw_v1_metadata: the v1Compatibility string exactly as it appeared in the manifest.
  """


class Schema1V1Metadata(namedtuple('Schema1V1Metadata',
                                   'image_id parent_image_id created comment command labels')):
  """
  The subset of a layer's v1Compatibility JSON that the rest of the system needs.

  Fields:
    image_id: the v1 image id.
    parent_image_id: the v1 parent image id, if any.
    created: the creation timestamp value, if any.
    comment: the layer comment, if any.
    command: the layer's command, serialized as a JSON string, if any.
    labels: the image labels mapping.
  """


class DockerSchema1Manifest(object):
  """
  A parsed Docker schema 1 signed manifest, verified against its signatures by default.

  Layers are exposed ordered from the base image toward the leaf image (see `layers`).
  """

  def __init__(self, manifest_bytes, validate=True):
    """
    Parses the given manifest bytes.

    Args:
      manifest_bytes: the raw JSON text of the signed manifest.
      validate: when True, every signature block is verified before returning.

    Raises:
      MalformedSchema1Manifest: if the bytes are not a JSON object, a required top-level field
        (`signatures`, `tag`, `name`) is missing, the repository name contains more than one
        '/', or (when validating) a signature block is malformed.
      InvalidSchema1Signature: if validating and a signature does not verify.
    """
    self._layers = None
    self._bytes = manifest_bytes

    try:
      self._parsed = json.loads(manifest_bytes)
    except ValueError as ve:
      raise MalformedSchema1Manifest('malformed manifest data: %s' % ve)

    if not isinstance(self._parsed, dict):
      raise MalformedSchema1Manifest('malformed manifest data: expected a JSON object')

    try:
      self._signatures = self._parsed[DOCKER_SCHEMA1_SIGNATURES_KEY]
      self._tag = self._parsed[DOCKER_SCHEMA1_REPO_TAG_KEY]
      repo_name = self._parsed[DOCKER_SCHEMA1_REPO_NAME_KEY]
    except KeyError as ke:
      raise MalformedSchema1Manifest('missing required manifest field: %s' % ke)

    # The name is either `repo` or `namespace/repo`. split() always yields at least one part, so
    # anything other than one or two parts is malformed (rather than a tuple-unpacking ValueError).
    repo_name_parts = repo_name.split('/')
    if len(repo_name_parts) == 2:
      self._namespace, self._repo_name = repo_name_parts
    elif len(repo_name_parts) == 1:
      self._namespace = ''
      self._repo_name = repo_name_parts[0]
    else:
      raise MalformedSchema1Manifest('malformed repository name: %s' % repo_name)

    if validate:
      self._validate()

  def _validate(self):
    """
    Verifies every signature block against the signed payload.

    A manifest with no signature blocks is accepted without verification.

    Raises:
      MalformedSchema1Manifest: if a signature block is missing a required field.
      InvalidSchema1Signature: if a signature names an unsupported algorithm or fails to verify.
    """
    if not self._signatures:
      return

    # The payload is reconstructed from the first signature's protected header (see `payload`),
    # so compute and encode it once rather than once per signature.
    encoded_payload = base64url_encode(self.payload)

    for signature in self._signatures:
      try:
        header = signature[DOCKER_SCHEMA1_HEADER_KEY]
        algorithm = header['alg']
        jwk = header['jwk']
        protected = signature[DOCKER_SCHEMA1_PROTECTED_KEY]
        encoded_signature = signature[DOCKER_SCHEMA1_SIGNATURE_KEY]
      except (KeyError, TypeError):
        raise MalformedSchema1Manifest('malformed signature block in manifest')

      try:
        signer = SIGNER_ALGS[algorithm]
      except KeyError:
        raise InvalidSchema1Signature('unsupported signature algorithm: %s' % algorithm)

      bytes_to_verify = '{0}.{1}'.format(protected, encoded_payload)
      gk = keyrep(jwk).get_key()
      sig = base64url_decode(encoded_signature.encode('utf-8'))

      # jwkest signers may raise BadSignature on a mismatch instead of returning False; treat
      # both outcomes as an invalid signature.
      try:
        verified = signer.verify(bytes_to_verify, sig, gk)
      except BadSignature:
        raise InvalidSchema1Signature()

      if not verified:
        raise InvalidSchema1Signature()

  @property
  def content_type(self):
    """ The signed schema 1 manifest content type. """
    return DOCKER_SCHEMA1_SIGNED_MANIFEST_CONTENT_TYPE

  @property
  def media_type(self):
    """ The signed schema 1 manifest media type (same value as `content_type`). """
    return DOCKER_SCHEMA1_SIGNED_MANIFEST_CONTENT_TYPE

  @property
  def signatures(self):
    """ The raw list of signature blocks from the manifest JSON. """
    return self._signatures

  @property
  def namespace(self):
    """ The namespace part of the repository name, or '' when the name has no '/'. """
    return self._namespace

  @property
  def repo_name(self):
    """ The repository part of the repository name. """
    return self._repo_name

  @property
  def tag(self):
    """ The manifest's tag. """
    return self._tag

  @property
  def json(self):
    """ The raw manifest bytes (alias of `bytes`). """
    return self._bytes

  @property
  def bytes(self):
    """ The raw manifest bytes exactly as supplied to the constructor. """
    return self._bytes

  @property
  def manifest_json(self):
    """ The manifest parsed into Python objects. """
    return self._parsed

  @property
  def digest(self):
    """ The sha256 digest of the signed payload (not of the full manifest bytes). """
    return digest_tools.sha256_digest(self.payload)

  @property
  def image_ids(self):
    """ The set of v1 image ids of all layers. """
    return {mdata.v1_metadata.image_id for mdata in self.layers}

  @property
  def parent_image_ids(self):
    """ The set of non-empty v1 parent image ids referenced by the layers. """
    return {mdata.v1_metadata.parent_image_id for mdata in self.layers
            if mdata.v1_metadata.parent_image_id}

  @property
  def checksums(self):
    """ The distinct layer digests as strings, in no particular order. """
    return list({str(mdata.digest) for mdata in self.layers})

  @property
  def leaf_layer(self):
    """ The last (leaf-most) layer. Raises IndexError if the manifest has no layers. """
    return self.layers[-1]

  @property
  def layers(self):
    """ The Schema1Layer list, base image first; computed on first access and then cached. """
    if self._layers is None:
      self._layers = list(self._generate_layers())
    return self._layers

  def _generate_layers(self):
    """
    Returns a generator of Schema1Layer objects built by pairing each fsLayers entry with its
    history entry, starting from the base image and working toward the leaf node.

    Raises:
      MalformedSchema1Manifest: if fsLayers/history are missing or differ in length, a blobSum is
        not a parseable digest, or a v1Compatibility entry is not a JSON object with an `id`.
    """
    try:
      fs_layers = self._parsed[DOCKER_SCHEMA1_FS_LAYERS_KEY]
      history = self._parsed[DOCKER_SCHEMA1_HISTORY_KEY]
    except KeyError as ke:
      raise MalformedSchema1Manifest('missing required manifest field: %s' % ke)

    # zip() would silently drop unpaired trailing entries, so require the lists to line up.
    if len(fs_layers) != len(history):
      raise MalformedSchema1Manifest('mismatched fsLayers (%d) and history (%d) lengths' %
                                     (len(fs_layers), len(history)))

    # The manifest lists layers leaf-first; reverse so the base image comes out first.
    for blob_sum_obj, history_obj in reversed(list(zip(fs_layers, history))):
      try:
        image_digest = digest_tools.Digest.parse_digest(blob_sum_obj[DOCKER_SCHEMA1_BLOB_SUM_KEY])
      except digest_tools.InvalidDigestException:
        raise MalformedSchema1Manifest('could not parse manifest digest: %s' %
                                       blob_sum_obj[DOCKER_SCHEMA1_BLOB_SUM_KEY])

      metadata_string = history_obj[DOCKER_SCHEMA1_V1_COMPAT_KEY]

      try:
        v1_metadata = json.loads(metadata_string)
      except (TypeError, ValueError):
        raise MalformedSchema1Manifest('could not parse v1Compatibility JSON: %s' %
                                       metadata_string)

      if not isinstance(v1_metadata, dict):
        raise MalformedSchema1Manifest('v1Compatibility JSON must be an object')

      if 'id' not in v1_metadata:
        raise MalformedSchema1Manifest('id field missing from v1Compatibility JSON')

      # Treat a null `container_config` / `config` the same as a missing one.
      container_config = v1_metadata.get('container_config') or {}
      command_list = container_config.get('Cmd', None)
      command = json.dumps(command_list) if command_list else None

      config = v1_metadata.get('config') or {}
      labels = config.get('Labels', {}) or {}

      extracted = Schema1V1Metadata(v1_metadata['id'], v1_metadata.get('parent'),
                                    v1_metadata.get('created'), v1_metadata.get('comment'),
                                    command, labels)
      yield Schema1Layer(image_digest, extracted, metadata_string)

  @property
  def payload(self):
    """
    The bytes that were signed: the manifest bytes truncated to the first signature's
    `formatLength`, followed by its base64url-decoded `formatTail`.
    """
    protected = str(self._signatures[0][DOCKER_SCHEMA1_PROTECTED_KEY])
    parsed_protected = json.loads(base64url_decode(protected))
    signed_content_head = self._bytes[:parsed_protected[DOCKER_SCHEMA1_FORMAT_LENGTH_KEY]]
    signed_content_tail = base64url_decode(str(parsed_protected[DOCKER_SCHEMA1_FORMAT_TAIL_KEY]))
    return signed_content_head + signed_content_tail

  def rewrite_invalid_image_ids(self, images_map):
    """
    Rewrites Docker v1 image IDs and returns a generator of DockerV1Metadata.

    If Docker gives us a layer with a v1 image ID that already points to existing
    content, but the checksums don't match, then we need to rewrite the image ID
    to something new in order to ensure consistency.

    Args:
      images_map: mapping of existing v1 image id -> object exposing `content_checksum`.

    Once one layer is rewritten, every subsequent (more leaf-ward) layer is rewritten too,
    with its id derived from a running sha256 over all layers processed so far.
    """

    # Used to synthesize a new "content addressable" image id
    digest_history = hashlib.sha256()
    has_rewritten_ids = False
    updated_id_map = {}

    for layer in self.layers:
      digest_str = str(layer.digest)
      extracted_v1_metadata = layer.v1_metadata
      working_image_id = extracted_v1_metadata.image_id

      # Update our digest_history hash for the new layer data.
      digest_history.update(digest_str)
      digest_history.update("@")
      digest_history.update(layer.raw_v1_metadata.encode('utf-8'))
      digest_history.update("|")

      # Ensure that the v1 image's storage matches the V2 blob. If not, we've
      # found a data inconsistency and need to create a new layer ID for the V1
      # image, and all images that follow it in the ancestry chain.
      digest_mismatch = (extracted_v1_metadata.image_id in images_map and
                         images_map[extracted_v1_metadata.image_id].content_checksum != digest_str)
      if digest_mismatch or has_rewritten_ids:
        working_image_id = digest_history.hexdigest()
        has_rewritten_ids = True

      # Store the new docker id in the map
      updated_id_map[extracted_v1_metadata.image_id] = working_image_id

      # Lookup the parent image for the layer, if any.
      parent_image_id = extracted_v1_metadata.parent_image_id
      if parent_image_id is not None:
        parent_image_id = updated_id_map.get(parent_image_id, parent_image_id)

      # Synthesize and store the v1 metadata in the db.
      v1_metadata_json = layer.raw_v1_metadata
      if has_rewritten_ids:
        v1_metadata_json = _updated_v1_metadata(v1_metadata_json, updated_id_map)

      updated_image = DockerV1Metadata(
        namespace_name=self.namespace,
        repo_name=self.repo_name,
        image_id=working_image_id,
        created=extracted_v1_metadata.created,
        comment=extracted_v1_metadata.comment,
        command=extracted_v1_metadata.command,
        compat_json=v1_metadata_json,
        parent_image_id=parent_image_id,
        checksum=None, # TODO: Check if we need this.
        content_checksum=digest_str,
      )

      yield updated_image


class DockerSchema1ManifestBuilder(object):
  """
  A convenient abstraction around creating new, signed DockerSchema1Manifests.

  Layers are emitted in the fsLayers/history lists in the order `add_layer` is called.
  """
  def __init__(self, namespace_name, repo_name, tag, architecture='amd64'):
    """
    Args:
      namespace_name: the repository namespace; an empty value (or None) yields a bare name.
      repo_name: the repository name.
      tag: the tag to record in the manifest.
      architecture: the architecture to record in the manifest.
    """
    # Without a namespace the manifest name is just the repository name (never "None/repo").
    if namespace_name:
      repo_name_key = '{0}/{1}'.format(namespace_name, repo_name)
    else:
      repo_name_key = repo_name

    self._base_payload = {
      DOCKER_SCHEMA1_REPO_TAG_KEY: tag,
      DOCKER_SCHEMA1_REPO_NAME_KEY: repo_name_key,
      DOCKER_SCHEMA1_ARCH_KEY: architecture,
      DOCKER_SCHEMA1_SCHEMA_VER_KEY: 1,
    }

    self._fs_layer_digests = []
    self._history = []

  def add_layer(self, layer_digest, v1_json_metadata):
    """
    Appends a layer (its blob digest and v1Compatibility JSON string) and returns self so calls
    can be chained.
    """
    self._fs_layer_digests.append({
      DOCKER_SCHEMA1_BLOB_SUM_KEY: layer_digest,
    })
    self._history.append({
      DOCKER_SCHEMA1_V1_COMPAT_KEY: v1_json_metadata,
    })
    return self


  def build(self, json_web_key):
    """
    Builds a DockerSchema1Manifest object complete with signature.

    Args:
      json_web_key: the signing key; it must provide `get_key()`, `public_members` and
        `to_dict()` (presumably a jwkest RSA key, since signing uses RS256).
    """
    payload = OrderedDict(self._base_payload)
    payload.update({
      DOCKER_SCHEMA1_HISTORY_KEY: self._history,
      DOCKER_SCHEMA1_FS_LAYERS_KEY: self._fs_layer_digests,
    })

    payload_str = json.dumps(payload, indent=3)

    # The signed content is everything before the final "\n}" plus that tail. Signatures are
    # appended as the last key below, so the final document shares the same first `split_point`
    # characters and a verifier can rebuild the signed bytes from formatLength/formatTail.
    split_point = payload_str.rfind('\n}')

    protected_payload = {
      DOCKER_SCHEMA1_FORMAT_TAIL_KEY: base64url_encode(payload_str[split_point:]),
      DOCKER_SCHEMA1_FORMAT_LENGTH_KEY: split_point,
      'time': datetime.utcnow().strftime(_ISO_DATETIME_FORMAT_ZULU),
    }
    protected = base64url_encode(json.dumps(protected_payload))
    logger.debug('Generated protected block: %s', protected)

    bytes_to_sign = '{0}.{1}'.format(protected, base64url_encode(payload_str))

    signer = SIGNER_ALGS[_JWS_SIGNING_ALGORITHM]
    signature = base64url_encode(signer.sign(bytes_to_sign, json_web_key.get_key()))
    logger.debug('Generated signature: %s', signature)

    # Only the public members of the key are embedded in the manifest.
    public_members = set(json_web_key.public_members)
    public_key = {comp: value for comp, value in json_web_key.to_dict().items()
                  if comp in public_members}

    signature_block = {
      DOCKER_SCHEMA1_HEADER_KEY: {'jwk': public_key, 'alg': _JWS_SIGNING_ALGORITHM},
      DOCKER_SCHEMA1_SIGNATURE_KEY: signature,
      DOCKER_SCHEMA1_PROTECTED_KEY: protected,
    }

    logger.debug('Encoded signature block: %s', json.dumps(signature_block))

    payload.update({DOCKER_SCHEMA1_SIGNATURES_KEY: [signature_block]})

    return DockerSchema1Manifest(json.dumps(payload, indent=3))


def _updated_v1_metadata(v1_metadata_json, updated_id_map):
  """
  Updates v1_metadata with new image IDs.

  Args:
    v1_metadata_json: a v1Compatibility JSON string containing an `id` field.
    updated_id_map: mapping of original v1 image id -> rewritten image id. It must contain the
      metadata's own `id`.

  Returns:
    A JSON string in which `id`, plus `parent` and `container_config.Image` when they appear in
    the map, are replaced by their rewritten ids.

  Raises:
    KeyError: if the metadata's `id` is not in `updated_id_map`.
  """
  parsed = json.loads(v1_metadata_json)
  parsed['id'] = updated_id_map[parsed['id']]

  parent_id = parsed.get('parent')
  if parent_id and parent_id in updated_id_map:
    parsed['parent'] = updated_id_map[parent_id]

  # `container_config` may be absent or null; only rewrite an existing `Image` reference, and
  # write it back under the same (capitalized) key rather than adding a stray lowercase one.
  container_config = parsed.get('container_config') or {}
  existing_image = container_config.get('Image')
  if existing_image and existing_image in updated_id_map:
    container_config['Image'] = updated_id_map[existing_image]

  return json.dumps(parsed)