Bläddra i källkod

add simple query caching

Daniel Sheffield 1 år sedan
förälder
incheckning
a79d037763
1 ändrade filer med 25 tillägg och 10 borttagningar
  1. 25 10
      app/rest/pyapi.py

+ 25 - 10
app/rest/pyapi.py

@@ -3,7 +3,11 @@
 # All rights reserved
 #
 # THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
+from time import time
 from typing import Iterable
+from io import StringIO
+import os
+from urllib.parse import urlencode
 from bottle import (
     route,
     request,
@@ -13,13 +17,12 @@ from bottle import (
     DictProperty,
     redirect,
     template,
+    HTTPError,
 )
 from psycopg import connect
 from psycopg.sql import SQL, Literal
-import os
 import matplotlib.pyplot as plt
 import seaborn as sns
-from urllib.parse import urlencode
 from ..activities.Plot import (
     get_data,
 )
@@ -35,7 +38,9 @@ def line(pivot, ylabel=None, xlabel=None):
     ax.set_ylabel(ylabel)
 
 ALL_UNITS = {'g','kg','mL','L','Pieces','Bunches','Bags'}
-PARAMS ={ 'group', 'category', 'product', 'unit' }
+PARAMS = { 'group', 'category', 'product', 'unit' }
+CACHE = dict()
+
 host = f"host={os.getenv('HOST')}"
 db = f"dbname={os.getenv('DB', 'grocery')}"
 user = f"user={os.getenv('USER', 'das')}"
@@ -45,8 +50,6 @@ if not password.split('=',1)[1]:
 conn = connect(f"{host} {db} {user} {password}")
 sns.set_theme()
 
-from io import StringIO
-
 def get_filter(query: DictProperty, allow: Iterable[str] = None):
     return {
         k: get_include_exclude(
@@ -112,10 +115,18 @@ def trend():
 
     if request.query_string != normalized:
         return redirect(f'{path}?{normalized}')
+    
+    if request.query_string in CACHE:
+        return next(CACHE[request.query_string], None)
+    
+    CACHE[request.query_string] = trend_internal(path, request.query)
+    return next(CACHE[request.query_string])
+
+def trend_internal(path, query):
     try:
         with conn.cursor() as cur:
             query_manager = QueryManager(cur, display_mapper)
-            fields = { k: request.query[k] or None for k in request.query.keys() if k in PARAMS }
+            fields = { k: query[k] or None for k in query.keys() if k in PARAMS }
             unit = fields['unit'] = fields['unit'] or 'kg' if 'unit' in fields else 'kg'
             if unit and unit not in ALL_UNITS:
                 raise abort(400, f"Unsupported unit {unit}")
@@ -123,16 +134,14 @@ def trend():
             data = get_data(query_manager, **fields)
             if data.empty:
                 raise abort(404, f"No data for {fields}")
+            
             pivot = data.pivot_table(index=['ts_raw',], columns=['product',], values=['$/unit'], aggfunc='mean')
             pivot.columns = pivot.columns.droplevel()
             plt.figure(figsize=[16, 9])
             line(pivot, xlabel='Time', ylabel=f'$ / {unit}')
             f = StringIO()
             plt.savefig(f, format='svg')
-    finally:
-        conn.commit()
-    
-    return f"""
+            resp = lambda: f"""
 <!DOCTYPE html>
 <html>
     <head>
@@ -148,6 +157,12 @@ def trend():
     </body>
 </html>
 """
+    except HTTPError as e:
+        resp = lambda exception=e: exception
+    finally:
+        conn.commit()
+     
+    yield from iter(resp, lambda started=time(): time() - started > 600)
 
 heading = """<?xml version="1.0" encoding="UTF-8"?>
 <?xml-stylesheet type="text/xsl" href="/grocery/style/table"?>