import hashlib
import json

from enum import Enum, unique

from image.docker.schema1 import DockerSchema1ManifestBuilder, DockerSchema1Manifest
from test.registry.protocols import (RegistryProtocol, Failures, ProtocolOptions, PushResult,
                                     PullResult)


@unique
class V2ProtocolSteps(Enum):
  """ Named stages of the V2 registry protocol flow.

  Members are used as keys of V2Protocol.FAILURE_CODES, tying an expected failure to the step
  at which it should surface. @unique rejects accidental duplicate values.
  """
  AUTH = "auth"                        # token request against /v2/auth
  BLOB_HEAD_CHECK = "blob-head-check"  # HEAD of a blob before it is uploaded
  GET_MANIFEST = "get-manifest"        # GET /v2/<repo>/manifests/<tag>
  PUT_MANIFEST = "put-manifest"        # PUT /v2/<repo>/manifests/<tag>
  MOUNT_BLOB = "mount-blob"            # POST /v2/<repo>/blobs/uploads/ with mount/from params
  CATALOG = "catalog"                  # GET /v2/_catalog
  LIST_TAGS = "list-tags"              # GET /v2/<repo>/tags/list


class V2Protocol(RegistryProtocol):
  """ Test driver for the Docker Registry V2 protocol.

  Each public method walks one client flow (ping, token auth, push, pull, delete, tag listing,
  catalog listing) against the registry reachable through `session`, asserting on the status
  codes and headers returned at each step.
  """

  # Maps a protocol step to {Failure: HTTP status expected at that step}. Presumably consumed
  # by RegistryProtocol.conduct when given an (expected_status, expected_failure, step) tuple
  # -- confirm against the base class.
  FAILURE_CODES = {
    V2ProtocolSteps.AUTH: {
      Failures.UNAUTHENTICATED: 401,
      Failures.INVALID_REGISTRY: 400,
      Failures.APP_REPOSITORY: 405,
      Failures.ANONYMOUS_NOT_ALLOWED: 401,
      Failures.INVALID_REPOSITORY: 400,
      Failures.NAMESPACE_DISABLED: 400,
    },
    V2ProtocolSteps.MOUNT_BLOB: {
      Failures.UNAUTHORIZED_FOR_MOUNT: 202,
    },
    V2ProtocolSteps.GET_MANIFEST: {
      Failures.UNKNOWN_TAG: 404,
      Failures.UNAUTHORIZED: 403,
      Failures.DISALLOWED_LIBRARY_NAMESPACE: 400,
    },
    V2ProtocolSteps.PUT_MANIFEST: {
      Failures.DISALLOWED_LIBRARY_NAMESPACE: 400,
      Failures.MISSING_TAG: 404,
      Failures.INVALID_TAG: 400,
      Failures.INVALID_IMAGES: 400,
      Failures.INVALID_BLOB: 400,
      Failures.UNSUPPORTED_CONTENT_TYPE: 415,
    },
  }

  def __init__(self, jwk):
    # Passed to DockerSchema1ManifestBuilder.build() for every manifest built in push();
    # presumably the key the schema1 manifests are signed with.
    self.jwk = jwk

  def ping(self, session):
    """ Hits the V2 base endpoint, asserting a 401 and a `registry/2.0` API version header. """
    result = session.get('/v2/')
    assert result.status_code == 401
    assert result.headers['Docker-Distribution-API-Version'] == 'registry/2.0'

  def login(self, session, username, password, scopes, expect_success):
    """ Requests a token from /v2/auth and asserts the class of the response status.

    `scopes` may be a single scope string or a list of them. Basic auth is only sent when both
    `username` and `password` are non-empty. Asserts a 2xx status when `expect_success` is
    truthy and a 4xx status otherwise. Returns the raw response.
    """
    scopes = scopes if isinstance(scopes, list) else [scopes]
    params = {
      'account': username,
      'service': 'localhost:5000',
      'scope': scopes,
    }

    auth = (username, password)
    if not username or not password:
      auth = None

    response = session.get('/v2/auth', params=params, auth=auth)

    # Floor division: `/` only yields the status class under Python 2 integer division; with
    # true division 204 / 100 == 2.04 and the check would spuriously fail.
    status_class = response.status_code // 100
    if expect_success:
      assert status_class == 2
    else:
      assert status_class == 4

    return response

  def auth(self, session, credentials, namespace, repo_name, scopes=None,
           expected_failure=None):
    """
    Performs the V2 Auth flow, returning the token (if any) and the response.

    `credentials` is an optional (username, password) pair sent as basic auth; `scopes` are
    sent as the `scope` parameter when non-empty. When `expected_failure` is None the response
    body must carry a `token`, and `(token, response)` is returned; otherwise
    `(None, response)` is returned. `namespace` and `repo_name` are not used by this method.

    Spec: https://docs.docker.com/registry/spec/auth/token/
    """
    scopes = scopes or []
    auth = None
    username = None

    if credentials is not None:
      username, _ = credentials
      auth = credentials

    params = {
      'account': username,
      'service': 'localhost:5000',
    }

    if scopes:
      params['scope'] = scopes

    response = self.conduct(session, 'GET', '/v2/auth', params=params, auth=auth,
                            expected_status=(200, expected_failure, V2ProtocolSteps.AUTH))

    if expected_failure is None:
      token = response.json().get('token')
      assert token is not None
      return token, response

    return None, response

  def push(self, session, namespace, repo_name, tag_names, images, credentials=None,
           expected_failure=None, options=None):
    """ Uploads `images` as layer blobs and writes a schema1 manifest for each tag.

    Flow: ping, token auth, build one manifest per tag, then for each image (in reverse order)
    optionally HEAD-check the blob, upload it (whole, in chunks, or via a cross-repository
    mount when `options.mount_blobs` names the image), verify it with HEAD and GET, and finally
    PUT every manifest.

    `tag_names` may be a single string or a list. Returns a PushResult holding the per-image
    layer checksums, the built manifests and the auth headers. Returns None when no token was
    obtained, or when the flow stops early at an expected failure (mount with an expected
    failure, a chunk expected to be rejected, or a cancelled upload).
    """
    options = options or ProtocolOptions()
    scopes = options.scopes or ['repository:%s:push,pull' % self.repo_name(namespace, repo_name)]
    tag_names = [tag_names] if isinstance(tag_names, str) else tag_names

    # Ping!
    self.ping(session)

    # Perform auth and retrieve a token.
    token, _ = self.auth(session, credentials, namespace, repo_name, scopes=scopes,
                         expected_failure=expected_failure)
    if token is None:
      return

    headers = {
      'Authorization': 'Bearer ' + token,
    }
    repo_path = self.repo_name(namespace, repo_name)

    # Digest substituted for every layer when the caller asks for manifests that reference
    # blobs which are never uploaded. A bytes literal keeps this valid on Python 3, where
    # hashlib rejects str input; on Python 2 it is the same value.
    invalid_checksum = 'sha256:' + hashlib.sha256(b'notarealthing').hexdigest()

    # Build fake manifests.
    manifests = {}
    for tag_name in tag_names:
      builder = DockerSchema1ManifestBuilder(namespace, repo_name, tag_name)

      for image in reversed(images):
        if options.manifest_invalid_blob_references:
          checksum = invalid_checksum
        else:
          checksum = 'sha256:' + hashlib.sha256(image.bytes).hexdigest()

        layer_dict = {'id': image.id, 'parent': image.parent_id}
        if image.config is not None:
          layer_dict['config'] = image.config

        if image.size is not None:
          layer_dict['Size'] = image.size

        builder.add_layer(checksum, json.dumps(layer_dict))

      # Build the manifest.
      manifests[tag_name] = builder.build(self.jwk)

    # Upload locations are returned as absolute URLs on this host.
    server_prefix = 'http://localhost:5000'
    uploads_url = '/v2/%s/blobs/uploads/' % repo_path

    # Push the layer data.
    checksums = {}
    for image in reversed(images):
      checksum = 'sha256:' + hashlib.sha256(image.bytes).hexdigest()
      checksums[image.id] = checksum
      blob_url = '/v2/%s/blobs/%s' % (repo_path, checksum)

      if not options.skip_head_checks:
        # Layer data should not yet exist.
        self.conduct(session, 'HEAD', blob_url,
                     expected_status=(404, expected_failure, V2ProtocolSteps.BLOB_HEAD_CHECK),
                     headers=headers)

      # Check for mounting of blobs.
      if options.mount_blobs and image.id in options.mount_blobs:
        self.conduct(session, 'POST', uploads_url,
                     params={
                       'mount': checksum,
                       'from': options.mount_blobs[image.id],
                     },
                     expected_status=(201, expected_failure, V2ProtocolSteps.MOUNT_BLOB),
                     headers=headers)
        if expected_failure is not None:
          return
      else:
        # Start a new upload of the layer data.
        response = self.conduct(session, 'POST', uploads_url, expected_status=202,
                                headers=headers)

        upload_uuid = response.headers['Docker-Upload-UUID']
        new_upload_location = response.headers['Location']
        assert new_upload_location.startswith(server_prefix)

        # We need to make this relative just for the tests because the live server test
        # case modifies the port.
        location = new_upload_location[len(server_prefix):]
        status_url = '/v2/%s/blobs/uploads/%s' % (repo_path, upload_uuid)

        # PATCH the image data into the layer.
        if options.chunks_for_upload is None:
          self.conduct(session, 'PATCH', location, data=image.bytes, expected_status=204,
                       headers=headers)
        else:
          # If chunked upload is requested, upload the data as a series of chunks, checking
          # status at every point. Each entry is (start_byte, end_byte) or
          # (start_byte, end_byte, expected_code); the expected code defaults to 204.
          for chunk_data in options.chunks_for_upload:
            if len(chunk_data) == 3:
              (start_byte, end_byte, expected_code) = chunk_data
            else:
              (start_byte, end_byte) = chunk_data
              expected_code = 204

            patch_headers = {'Range': 'bytes=%s-%s' % (start_byte, end_byte)}
            patch_headers.update(headers)

            contents_chunk = image.bytes[start_byte:end_byte]
            self.conduct(session, 'PATCH', location, data=contents_chunk,
                         expected_status=expected_code,
                         headers=patch_headers)
            if expected_code != 204:
              # This chunk was expected to be rejected; stop the push here.
              return

            # Retrieve the upload status at each point, and ensure it is valid.
            response = self.conduct(session, 'GET', status_url, expected_status=204,
                                    headers=headers)
            assert response.headers['Docker-Upload-UUID'] == upload_uuid
            assert response.headers['Range'] == "bytes=0-%s" % end_byte

        if options.cancel_blob_upload:
          self.conduct(session, 'DELETE', location, params=dict(digest=checksum),
                       expected_status=204, headers=headers)

          # Ensure the upload was canceled.
          self.conduct(session, 'GET', status_url, expected_status=404, headers=headers)
          return

        # Finish the layer upload with a PUT.
        response = self.conduct(session, 'PUT', location, params=dict(digest=checksum),
                                expected_status=201, headers=headers)
        assert response.headers['Docker-Content-Digest'] == checksum

      # Ensure the layer exists now.
      response = self.conduct(session, 'HEAD', blob_url, expected_status=200, headers=headers)
      assert response.headers['Docker-Content-Digest'] == checksum
      assert response.headers['Content-Length'] == str(len(image.bytes))

      # And retrieve the layer data.
      result = self.conduct(session, 'GET', blob_url, headers=headers, expected_status=200)
      assert result.content == image.bytes

    # Write a manifest for each tag. If we expect it to be invalid, we expect a 404 code.
    # Otherwise, we expect a 202 response for success.
    put_code = 404 if options.manifest_invalid_blob_references else 202
    for tag_name in tag_names:
      manifest_headers = {'Content-Type': 'application/json'}
      manifest_headers.update(headers)

      if options.manifest_content_type is not None:
        manifest_headers['Content-Type'] = options.manifest_content_type

      self.conduct(session, 'PUT', '/v2/%s/manifests/%s' % (repo_path, tag_name),
                   data=manifests[tag_name].bytes,
                   expected_status=(put_code, expected_failure, V2ProtocolSteps.PUT_MANIFEST),
                   headers=manifest_headers)

    return PushResult(checksums=checksums, manifests=manifests, headers=headers)


  def delete(self, session, namespace, repo_name, tag_names, credentials=None,
             expected_failure=None, options=None):
    """ Deletes the manifest referenced by each of `tag_names`, expecting 202 for each.

    Requests a `repository:<repo>:*` scope unless `options.scopes` overrides it. `tag_names`
    may be a single string or a list. Always returns None.
    """
    options = options or ProtocolOptions()
    scopes = options.scopes or ['repository:%s:*' % self.repo_name(namespace, repo_name)]
    tag_names = [tag_names] if isinstance(tag_names, str) else tag_names

    # Ping!
    self.ping(session)

    # Perform auth and retrieve a token.
    token, _ = self.auth(session, credentials, namespace, repo_name, scopes=scopes,
                         expected_failure=expected_failure)
    if token is None:
      return None

    headers = {
      'Authorization': 'Bearer ' + token,
    }
    repo_path = self.repo_name(namespace, repo_name)

    for tag_name in tag_names:
      self.conduct(session, 'DELETE', '/v2/%s/manifests/%s' % (repo_path, tag_name),
                   headers=headers,
                   expected_status=202)


  def pull(self, session, namespace, repo_name, tag_names, images, credentials=None,
           expected_failure=None, options=None):
    """ Fetches the manifest for each tag and verifies every referenced layer blob.

    Each manifest is parsed as a DockerSchema1Manifest, and the blob of its i-th layer must
    equal `images[i].bytes`. Returns a PullResult with the parsed manifests and each tag's
    leaf-layer image id, or None when no token was obtained or `expected_failure` is set
    (checked after the first manifest GET).
    """
    options = options or ProtocolOptions()
    scopes = options.scopes or ['repository:%s:pull' % self.repo_name(namespace, repo_name)]
    tag_names = [tag_names] if isinstance(tag_names, str) else tag_names

    # Ping!
    self.ping(session)

    # Perform auth and retrieve a token.
    token, _ = self.auth(session, credentials, namespace, repo_name, scopes=scopes,
                         expected_failure=expected_failure)
    if token is None:
      return None

    headers = {
      'Authorization': 'Bearer ' + token,
    }
    repo_path = self.repo_name(namespace, repo_name)

    manifests = {}
    image_ids = {}
    for tag_name in tag_names:
      # Retrieve the manifest for the tag or digest.
      response = self.conduct(session, 'GET', '/v2/%s/manifests/%s' % (repo_path, tag_name),
                              expected_status=(200, expected_failure, V2ProtocolSteps.GET_MANIFEST),
                              headers=headers)
      if expected_failure is not None:
        return None

      # Ensure the manifest returned by us is valid.
      manifest = DockerSchema1Manifest(response.text)
      manifests[tag_name] = manifest
      image_ids[tag_name] = manifest.leaf_layer.v1_metadata.image_id

      # Verify the layers. Assumes manifest.layers is ordered the same way as `images`.
      for index, layer in enumerate(manifest.layers):
        result = self.conduct(session, 'GET', '/v2/%s/blobs/%s' % (repo_path, layer.digest),
                              expected_status=200,
                              headers=headers)
        assert result.content == images[index].bytes

    return PullResult(manifests=manifests, image_ids=image_ids)


  def _list_paginated(self, session, url, headers, page_size, expected_failure, step,
                      result_key):
    """ Collects the `result_key` list from `url` and from every page linked after it.

    Each page is fetched with GET, passing (200, expected_failure, step) as the expected
    status. When `page_size` is not None it is sent as the `n` parameter and each page is
    asserted to hold at most that many entries. Stops at the first response without a `Link`
    header and returns all collected entries in order.
    """
    results = []
    params = {}
    if page_size is not None:
      params['n'] = page_size

    while True:
      response = self.conduct(session, 'GET', url, headers=headers, params=params,
                              expected_status=(200, expected_failure, step))
      entries = response.json()[result_key]

      # The bound only applies when a page size was requested; comparing against None
      # would always fail.
      assert page_size is None or len(entries) <= page_size
      results.extend(entries)

      link_header = response.headers.get('Link')
      if not link_header:
        return results

      # Follow the next page using only the `/v2/...` part of the advertised URL.
      url = link_header[link_header.find('/v2/'):]

  def tags(self, session, namespace, repo_name, page_size=2, credentials=None, options=None,
           expected_failure=None):
    """ Lists every tag of the repository, following `Link` pagination headers.

    Auth is only performed when `credentials` are given; otherwise the listing is requested
    without an Authorization header. `page_size` is sent as `n` (None omits it). Returns the
    concatenated tag names, or None when auth yields no token.
    """
    options = options or ProtocolOptions()
    scopes = options.scopes or ['repository:%s:pull' % self.repo_name(namespace, repo_name)]

    # Ping!
    self.ping(session)

    # Perform auth and retrieve a token.
    headers = {}
    if credentials is not None:
      token, _ = self.auth(session, credentials, namespace, repo_name, scopes=scopes,
                           expected_failure=expected_failure)
      if token is None:
        return None

      headers = {
        'Authorization': 'Bearer ' + token,
      }

    url = '/v2/%s/tags/list' % (self.repo_name(namespace, repo_name))
    return self._list_paginated(session, url, headers, page_size, expected_failure,
                                V2ProtocolSteps.LIST_TAGS, 'tags')

  def catalog(self, session, page_size=2, credentials=None, options=None, expected_failure=None,
              namespace=None, repo_name=None, bearer_token=None):
    """ Lists every repository in /v2/_catalog, following `Link` pagination headers.

    A token is obtained when `credentials` are given; an explicit `bearer_token` takes
    precedence over it. `page_size` is sent as `n` (None omits it). Returns the concatenated
    repository names, or None when auth with `credentials` yields no token.
    """
    options = options or ProtocolOptions()
    scopes = options.scopes or []

    # Ping!
    self.ping(session)

    # Perform auth and retrieve a token.
    headers = {}
    if credentials is not None:
      token, _ = self.auth(session, credentials, namespace, repo_name, scopes=scopes,
                           expected_failure=expected_failure)
      if token is None:
        return None

      headers = {
        'Authorization': 'Bearer ' + token,
      }

    if bearer_token is not None:
      headers = {
        'Authorization': 'Bearer ' + bearer_token,
      }

    return self._list_paginated(session, '/v2/_catalog', headers, page_size, expected_failure,
                                V2ProtocolSteps.CATALOG, 'repositories')