From f4266d08d27fb41504daf5d8b356498dd1fda6f6 Mon Sep 17 00:00:00 2001
From: Joseph Schorr <josephschorr@users.noreply.github.com>
Date: Fri, 20 Nov 2015 11:12:34 -0500
Subject: [PATCH] Fix handling of aggregate size in V2

Fixes #931
---
 data/model/image.py   | 53 +++++++++++++++++++++++++++----------------
 data/model/storage.py |  2 +-
 2 files changed, 34 insertions(+), 21 deletions(-)

diff --git a/data/model/image.py b/data/model/image.py
index 95ee563c9..a067b5e76 100644
--- a/data/model/image.py
+++ b/data/model/image.py
@@ -354,31 +354,40 @@ def set_image_size(docker_image_id, namespace_name, repository_name, image_size,
 
   image.storage.image_size = image_size
   image.storage.uncompressed_size = uncompressed_size
-
-  ancestors = image.ancestors.split('/')[1:-1]
-  if ancestors:
-    try:
-      # TODO(jschorr): Switch to this faster route once we have full ancestor aggregate_size
-      # parent_image = Image.get(Image.id == ancestors[-1])
-      ancestor_size = (ImageStorage
-                       .select(fn.Sum(ImageStorage.image_size))
-                       .join(Image)
-                       .where(Image.id << ancestors)
-                       .scalar())
-
-      if ancestor_size is not None:
-        image.aggregate_size = ancestor_size + image_size
-    except Image.DoesNotExist:
-      pass
-  else:
-    image.aggregate_size = image_size
-
   image.storage.save()
+
+  image.aggregate_size = calculate_image_aggregate_size(image.ancestors, image.storage,
+                                                        image.parent)
   image.save()
 
   return image
 
 
+def calculate_image_aggregate_size(ancestors_str, image_storage, parent_image):
+  ancestors = ancestors_str.split('/')[1:-1]  # drops the first and last split pieces
+  if not ancestors:
+    return image_storage.image_size  # no ancestors: aggregate is this image's own size
+
+  if parent_image is None:
+    raise DataModelException('Could not load parent image')
+
+  ancestor_size = parent_image.aggregate_size  # fast path: reuse the parent's stored total
+  if ancestor_size is not None:
+    return ancestor_size + image_storage.image_size
+
+  # Slow fallback for parents with no saved aggregate size: sum every ancestor's image_size.
+  # TODO: remove this code if/when we do a full backfill.
+  ancestor_size = (ImageStorage
+                   .select(fn.Sum(ImageStorage.image_size))
+                   .join(Image)
+                   .where(Image.id << ancestors)
+                   .scalar())
+  if ancestor_size is None:
+    return None  # query yielded no sum; callers receive None
+
+  return ancestor_size + image_storage.image_size
+
+
 def get_image(repo, docker_image_id):
   try:
     return Image.get(Image.docker_image_id == docker_image_id, Image.repository == repo)
@@ -442,9 +451,13 @@ def synthesize_v1_image(repo, image_storage, docker_image_id, created_date_str,
       # parse raises different exceptions, so we cannot use a specific kind of handler here.
       pass
 
+  # Get the aggregate size for the image.
+  aggregate_size = calculate_image_aggregate_size(ancestors, image_storage, parent_image)
+
   return Image.create(docker_image_id=docker_image_id, ancestors=ancestors, comment=comment,
                       command=command, v1_json_metadata=v1_json_metadata, created=created,
-                      storage=image_storage, repository=repo, parent=parent_image)
+                      storage=image_storage, repository=repo, parent=parent_image,
+                      aggregate_size=aggregate_size)
 
 
 def ensure_image_locations(*names):
diff --git a/data/model/storage.py b/data/model/storage.py
index 10f4c9f6b..1b702f2e0 100644
--- a/data/model/storage.py
+++ b/data/model/storage.py
@@ -229,7 +229,7 @@ def lookup_repo_storages_by_content_checksum(repo, checksums):
   for counter, checksum in enumerate(set(checksums)):
     query_alias = 'q{0}'.format(counter)
     candidate_subq = (ImageStorage
-                      .select(ImageStorage.id, ImageStorage.content_checksum)
+                      .select(ImageStorage.id, ImageStorage.content_checksum, ImageStorage.image_size)
                       .join(Image)
                       .where(Image.repository == repo, ImageStorage.content_checksum == checksum)
                       .limit(1)