Jelajahi Sumber

separate out cache and add tests for it

Daniel Sheffield 1 tahun lalu
induk
melakukan
8e96e1c824
4 mengubah file dengan 125 tambahan dan 32 penghapusan
  1. 1 1
      app/data/util.py
  2. 44 0
      app/rest/Cache.py
  3. 8 31
      app/rest/pyapi.py
  4. 72 0
      test/rest/test_Cache.py

+ 1 - 1
app/data/util.py

@@ -4,7 +4,7 @@
 # All rights reserved
 #
 # THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
-from typing import List, Set, Tuple, Iterable
+from typing import Tuple, Iterable
 from psycopg.sql import (
     Identifier,
     Literal,

+ 44 - 0
app/rest/Cache.py

@@ -0,0 +1,44 @@
+#
+# Copyright (c) Daniel Sheffield 2023
+# All rights reserved
+#
+# THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
+from typing import Dict
+
+from .CachedLoadingPage import CachedLoadingPage
+
+class Cache:
+    def __init__(self, limit) -> None:
+        self._cache: Dict[str, CachedLoadingPage] = dict()
+        self._limit = limit
+
+    def get(self, key: str) -> str:
+        if key not in self._cache:
+            return None
+        
+        page = self._cache[key]
+        if page.stale:
+            del self._cache[key]
+            return None
+        return page.value if page.loaded else page.update()
+    
+    def _enforce_limit(self, limit):
+        for idx, (_, k) in enumerate(sorted([
+                (v.age, k) for k, v in self._cache.items()
+            ])):
+            if idx >= limit: del self._cache[k]
+
+    def _clear_stale(self):
+        for k in [k for k, v in self._cache.items() if v.stale]:
+            del self._cache[k]
+    
+    def add(self, key: str, page: CachedLoadingPage) -> str:
+        self._clear_stale()
+        self._enforce_limit(self._limit)
+        self._cache[key] = page
+        return page.value
+    
+    def remove(self, key: str):
+        if key in self._cache:
+            del self._cache[key]
+

+ 8 - 31
app/rest/pyapi.py

@@ -17,7 +17,7 @@ from bottle import (
 )
 from psycopg import connect
 from psycopg.sql import SQL, Literal
-from threading import Lock, Thread
+from threading import Thread
 
 from ..data.filter import(
     get_filter,
@@ -30,6 +30,7 @@ from ..data.util import(
 from . import trend as worker
 from . import PARAMS
 from .CachedLoadingPage import CachedLoadingPage
+from .Cache import Cache
 
 host = f"host={os.getenv('HOST')}"
 db = f"dbname={os.getenv('DB', 'grocery')}"
@@ -39,19 +40,7 @@ if not password.split('=',1)[1]:
     password = ''
 conn = connect(f"{host} {db} {user} {password}")
 
-CACHE: Dict[str, CachedLoadingPage] = dict()
-
-def enforce_limit(cache, limit):
-    for idx, (_, k) in enumerate(sorted([
-            (v.age, k) for k, v in cache.items()
-        ])):
-        if idx > limit: del cache[k]
-
-
-def clear_stale(cache):
-    for k in [k for k, v in cache.items() if v.stale]:
-        del cache[k]
-
+CACHE = Cache(10)
 
 def get_product_rollup_statement(filters, having=None):
     where = [ get_where_include_exclude(
@@ -94,36 +83,24 @@ def normalize_query(query: FormsDict, allow: Iterable[str] = None) -> str:
 def send_static(filename):
     return static_file(filename, root='app/rest/static')
 
-global LOCK
-LOCK = Lock()
-
 @route('/grocery/trend')
 def trend():
     _, _, path, *_ = request.urlparts
     normalized = normalize_query(request.query, allow=PARAMS)
-    if normalized in CACHE:
-        if request.params.get('reload') == 'true':
-            del CACHE[normalized]
+    if request.params.get('reload') == 'true':
+        CACHE.remove(normalized)
 
     if request.query_string != normalized:
         return redirect(f'{path}?{normalized}')
     
-    if request.query_string in CACHE:
-        page = CACHE[request.query_string]
-        if not page.stale:
-            return page.value if page.loaded else page.update()
-        del CACHE[request.query_string]
+    page = CACHE.get(normalized)
 
-    page = CachedLoadingPage(
+    return page if page else CACHE.add(normalized, CachedLoadingPage(
         template("loading", progress=[]),
         lambda queue: Thread(target=worker.trend, args=(
             queue, conn, path, request.query
         )).start()
-    )
-    clear_stale(CACHE)
-    enforce_limit(CACHE, 10)
-    CACHE[request.query_string] = page
-    return page.value
+    ))
 
 @route('/grocery/groups')
 def groups():

+ 72 - 0
test/rest/test_Cache.py

@@ -0,0 +1,72 @@
+#
+# Copyright (c) Daniel Sheffield 2023
+#
+# All rights reserved
+#
+# THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
+from pytest import fixture
+from time import time
+from app.rest.Cache import Cache
+from app.rest.CachedLoadingPage import (
+    CachedLoadingPage,
+)
+
+@fixture
+def cache():
+    return Cache(0)
+
+def test_add(cache: Cache):
+    val = 'test-cached-value'
+    key = 'test-key'
+    assert cache.add(key, CachedLoadingPage(val, lambda _: None)) == val
+
+def test_get(cache: Cache):
+    val = 'test-cached-value'
+    key = 'test-key'
+
+    assert cache.get(key) is None
+
+    assert cache.add(key, CachedLoadingPage(val, lambda q: q.put('next-val'))) == val
+    assert cache.get(key) == 'next-val'
+
+def test_remove(cache: Cache):
+    val = 'test-cached-value'
+    key = 'test-key'
+
+    assert cache.get(key) is None
+    assert cache.add(key, CachedLoadingPage(val, lambda q: q.put('next-val'))) == val
+    cache.remove(key)
+    assert cache.get(key) is None
+
+def test_enforce_limit(cache: Cache):
+    val = 'test-cached-value'
+    key = 'test-key'
+    assert cache.add(key, CachedLoadingPage(val, lambda _: None)) == val
+    # adding more exceeds limit
+    assert cache.add('other', CachedLoadingPage(val, lambda _: None)) == val
+    assert cache.get(key) is None
+    assert cache.get('other') == val
+
+def test_clean_stale(cache: Cache):
+    val = 'test-cached-value'
+    key = 'test-key'
+    page = CachedLoadingPage(val, lambda _: None)
+    page._created = time() - 10*60
+    # add stale page
+    assert cache.add(key, page) == val
+    assert cache.get(key) is None
+    
+    page = CachedLoadingPage(val, lambda _: None)
+    assert cache.add(key, page) == val
+    # make page stale
+    page._created = time() - 10*60
+    assert cache.get(key) is None
+
+    page = CachedLoadingPage(val, lambda _: None)
+    assert cache.add(key, page) == val
+    # make page stale
+    page._created = time() - 10*60
+    # stale page is rotated out on addition
+    assert cache.add('other', CachedLoadingPage(val, lambda _: None)) == val
+    assert cache.get(key) is None
+    assert cache.get('other') == val