From 2b07b6d8a908b3d87b845e038ba4b50780ac2a0b Mon Sep 17 00:00:00 2001
From: Jimmy Zelinskie <jimmy.zelinskie@coreos.com>
Date: Thu, 11 Feb 2016 17:00:38 -0500
Subject: [PATCH] allow HEAD on ACI images

Fixes #911.
---
 endpoints/verbs.py     | 6 ++++--
 storage/basestorage.py | 2 +-
 storage/cloud.py       | 4 +++-
 storage/fakestorage.py | 2 +-
 storage/swift.py       | 4 ++--
 5 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/endpoints/verbs.py b/endpoints/verbs.py
index f9ab20985..907006feb 100644
--- a/endpoints/verbs.py
+++ b/endpoints/verbs.py
@@ -258,7 +258,9 @@ def _repo_verb(namespace, repository, tag, verb, formatter, sign=False, checker=
   if not derived.uploading:
     logger.debug('Derived %s image %s exists in storage', verb, derived.uuid)
     derived_layer_path = model.storage.get_layer_path(derived)
-    download_url = storage.get_direct_download_url(derived.locations, derived_layer_path)
+    is_head_request = request.method == 'HEAD'
+    download_url = storage.get_direct_download_url(derived.locations, derived_layer_path,
+                                                   head=is_head_request)
     if download_url:
       logger.debug('Redirecting to download URL for derived %s image %s', verb, derived.uuid)
       return redirect(download_url)
@@ -359,7 +361,7 @@ def get_aci_signature(server, namespace, repository, tag, os, arch):
 
 
 @anon_protect
-@verbs.route('/aci/<server>/<namespace>/<repository>/<tag>/aci/<os>/<arch>/', methods=['GET'])
+@verbs.route('/aci/<server>/<namespace>/<repository>/<tag>/aci/<os>/<arch>/', methods=['GET', 'HEAD'])
 @process_auth
 def get_aci_image(server, namespace, repository, tag, os, arch):
   return _repo_verb(namespace, repository, tag, 'aci', ACIImage(),
diff --git a/storage/basestorage.py b/storage/basestorage.py
index 75f08d6e7..008eaba47 100644
--- a/storage/basestorage.py
+++ b/storage/basestorage.py
@@ -46,7 +46,7 @@ class BaseStorage(StoragePaths):
         client to use for any external calls. """
     pass
 
-  def get_direct_download_url(self, path, expires_in=60, requires_cors=False):
+  def get_direct_download_url(self, path, expires_in=60, requires_cors=False, head=False):
     return None
 
   def get_direct_upload_url(self, path, mime_type, requires_cors=True):
diff --git a/storage/cloud.py b/storage/cloud.py
index e6831d271..6b6296042 100644
--- a/storage/cloud.py
+++ b/storage/cloud.py
@@ -116,10 +116,12 @@ class _CloudStorage(BaseStorageV2):
   def get_supports_resumable_downloads(self):
     return True
 
-  def get_direct_download_url(self, path, expires_in=60, requires_cors=False):
+  def get_direct_download_url(self, path, expires_in=60, requires_cors=False, head=False):
     self._initialize_cloud_conn()
     path = self._init_path(path)
     k = self._key_class(self._cloud_bucket, path)
+    if head:
+      return k.generate_url(expires_in, 'HEAD')
     return k.generate_url(expires_in)
 
   def get_direct_upload_url(self, path, mime_type, requires_cors=True):
diff --git a/storage/fakestorage.py b/storage/fakestorage.py
index 7187b7a6b..50bb5720e 100644
--- a/storage/fakestorage.py
+++ b/storage/fakestorage.py
@@ -15,7 +15,7 @@ class FakeStorage(BaseStorageV2):
   def _init_path(self, path=None, create=False):
     return path
 
-  def get_direct_download_url(self, path, expires_in=60, requires_cors=False):
+  def get_direct_download_url(self, path, expires_in=60, requires_cors=False, head=False):
     try:
       if self.get_content('supports_direct_download') == 'true':
         return 'http://somefakeurl'
diff --git a/storage/swift.py b/storage/swift.py
index e1dbdfe96..75bd058ce 100644
--- a/storage/swift.py
+++ b/storage/swift.py
@@ -114,7 +114,7 @@ class SwiftStorage(BaseStorage):
       logger.exception('Could not head object: %s', path)
       return None
 
-  def get_direct_download_url(self, object_path, expires_in=60, requires_cors=False):
+  def get_direct_download_url(self, object_path, expires_in=60, requires_cors=False, head=False):
     if requires_cors:
       return None
 
@@ -137,7 +137,7 @@ class SwiftStorage(BaseStorage):
     object_path = self._normalize_path(object_path)
 
     # Generate the signed HMAC body.
-    method = 'GET'
+    method = 'HEAD' if head else 'GET'
     expires = int(time() + expires_in)
     full_path = '%s/%s/%s' % (path, self._swift_container, object_path)