Daniel Sheffield 1 rok pred
rodič
commit
c3406eecb7

+ 54 - 0
app/rest/CachedLoadingPage.py

@@ -0,0 +1,54 @@
+#
+# Copyright (c) Daniel Sheffield 2023
+# All rights reserved
+#
+# THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
+from queue import Queue, Empty
+from time import time
+from threading import Lock
+
+class CachedLoadingPage():
+    
+    value: str
+
+    def __init__(self, value: str):
+        self._created = time()
+        self._queue = Queue()
+        self._loaded = False
+        self.value = value
+        self._lock = Lock()
+
+    def _age(self) -> float:
+        return time() - self._created
+
+    @property
+    def queue(self) -> Queue:
+        return self._queue
+
+    @property
+    def loaded(self) -> bool:
+        return self._loaded
+
+    def _set_loaded(self, value: bool) -> bool:
+        self._loaded = value
+        return self._loaded
+    
+    @property
+    def stale(self) -> bool:
+        return self._age() > 10*60
+
+    def update(self) -> str:
+        if not self._lock.acquire(blocking=True, timeout=0.5):
+            return self.value
+        try:
+            item = self._queue.get(block=True, timeout=0.5)
+            if item is None:
+                self._queue.task_done()
+                self._set_loaded(True)
+            else:
+                self.value = item
+                self.queue.task_done()
+        except Empty:
+            pass
+        self._lock.release()
+        return self.value

+ 31 - 29
app/rest/pyapi.py

@@ -3,8 +3,9 @@
 # All rights reserved
 #
 # THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
-from typing import Iterable
+from typing import Iterable, Dict
 import os
+from time import time
 from urllib.parse import urlencode
 from bottle import (
     route,
@@ -17,11 +18,13 @@ from bottle import (
     static_file,
     TEMPLATE_PATH,
 )
-from matplotlib.axes import Axes
+import matplotlib
 from psycopg import connect
 from psycopg.sql import SQL, Literal
 import seaborn as sns
-from multiprocessing import Lock
+from threading import Lock, Thread
+
+from queue import Queue, Empty
 
 from ..data.filter import(
     get_filter,
@@ -30,12 +33,13 @@ from ..data.filter import(
 from ..data.util import(
     get_where_include_exclude
 )
-from .trend import trend_internal
+from . import trend as worker
 from . import PARAMS
-import matplotlib
+from .CachedLoadingPage import CachedLoadingPage
+
 matplotlib.use('agg')
 
-CACHE = dict()
+CACHE: Dict[str, CachedLoadingPage] = dict()
 
 host = f"host={os.getenv('HOST')}"
 db = f"dbname={os.getenv('DB', 'grocery')}"
@@ -89,7 +93,6 @@ def normalize_query(query: FormsDict, allow: Iterable[str] = None) -> str:
     ])
 
 
-
 @route('/grocery/static/<filename:path>')
 def send_static(filename):
     return static_file(filename, root='app/rest/static')
@@ -106,33 +109,32 @@ def trend():
     if request.query_string != normalized:
         return redirect(f'{path}?{normalized}')
     
-    loading = template("loading", progress=[])
     if request.query_string in CACHE:
-        if LOCK.acquire(block=False):
-            try:
-                return CACHE[request.query_string]["state"]
-            finally:
-                try:
-                    CACHE[request.query_string]["state"] = next(CACHE[request.query_string]["iter"])
-                except StopIteration:
-                    del CACHE[request.query_string]
-                finally:
-                    LOCK.release()
+        page = CACHE[request.query_string]
+        if not page.stale:
+            if not page.loaded:
+                return page.update()
+            return page.value
         else:
-            return CACHE[request.query_string]["state"]
-
-    if LOCK.acquire(block=False):
+            del CACHE[request.query_string]
+    
+    if LOCK.acquire(blocking=True):
         if request.query_string in CACHE:
-            LOCK.release()
-            return CACHE[request.query_string]["state"]
+            page = CACHE[request.query_string]
+            if not page.stale:
+                LOCK.release()
+                if not page.loaded:
+                    return page.update()
+                return page.value
+        
         try:
-            CACHE[request.query_string] = {
-                "iter": trend_internal(conn, path, request.query),
-                "state": loading,
-            }
+            page = CachedLoadingPage(template("loading", progress=[]))
+            CACHE[request.query_string] = page
+            thread = Thread(target=worker.trend, args=(page.queue, conn, path, request.query))
+            thread.start()
+            return page.value
         finally:
             LOCK.release()
-    return loading
 
 
 @route('/grocery/groups')
@@ -242,4 +244,4 @@ SELECT query_to_xml_and_xmlschema({inner}, false, false, ''::text)
     response.content_type = 'application/xhtml+xml; charset=utf-8'
     return template("query-to-xml", title="Tags", xml=xml, form=form)
 
-run(host='0.0.0.0', port=6772, server='gunicorn')
+run(host='0.0.0.0', port=6772, server='paste')

+ 1 - 1
app/rest/requirements.txt

@@ -1,4 +1,4 @@
 seaborn
 psycopg[binary]
 bottle
-gunicorn
+paste

+ 12 - 11
app/rest/trend.py

@@ -8,14 +8,13 @@ from bottle import (
     DictProperty,
     HTTPError,
     template,
-    request,
 )
 from io import StringIO
 import matplotlib.pyplot as plt
 import seaborn as sns
 from psycopg import Connection
 from psycopg.connection import TupleRow
-from time import time
+from queue import Queue
 from . import ALL_UNITS
 from ..data.QueryManager import (
     display_mapper,
@@ -35,6 +34,11 @@ from . import PARAMS
 def abort(code, text):
     return HTTPError(code, text)
 
+def trend(queue: Queue, conn: Connection[TupleRow], path: str, query: DictProperty):
+    for item in trend_internal(conn, path, query):
+        queue.put(item, block=True)
+    queue.put(None)
+
 def trend_internal(conn: Connection[TupleRow], path: str, query: DictProperty):
     progress = []
     try:
@@ -43,7 +47,8 @@ def trend_internal(conn: Connection[TupleRow], path: str, query: DictProperty):
             fields = { k: query[k] or None for k in query.keys() if k in PARAMS }
             unit = fields['unit'] = fields['unit'] or 'kg' if 'unit' in fields else 'kg'
             if unit and unit not in ALL_UNITS:
-                raise abort(400, f"Unsupported unit {unit}")
+                yield abort(400, f"Unsupported unit {unit}")
+                return
 
             progress.append({ "name": "Loading data", "status": ""})
             yield template("loading", progress=progress)
@@ -53,7 +58,8 @@ def trend_internal(conn: Connection[TupleRow], path: str, query: DictProperty):
             yield template("loading", progress=progress)
 
             if data.empty:
-                raise abort(404, f"No data for {fields}")
+                yield abort(404, f"No data for {fields}")
+                return
             
             progress.append({ "name": "Loading chart", "status": ""})
             yield template("loading", progress=progress)
@@ -93,18 +99,13 @@ def trend_internal(conn: Connection[TupleRow], path: str, query: DictProperty):
 
             f = StringIO()
             plt.savefig(f, format='svg')
-            _filter = get_filter(request.query, allow=PARAMS)
+            _filter = get_filter(query, allow=PARAMS)
             form = get_form(path.split('/')[-1], 'get', _filter, data)
             
             progress[-1]["status"] = "done"
             yield template("loading", progress=progress)
             
-            resp = lambda: template("trend", form=form, svg=f.getvalue())
-
-    except HTTPError as e:
-        resp = lambda exception=e: exception
+            yield template("trend", form=form, svg=f.getvalue())
 
     finally:
         conn.commit()
-     
-    yield from iter(resp, lambda started=time(): time() - started > 600)