Browse Source

fix threading

Daniel Sheffield 1 year ago
parent
commit
c3406eecb7
4 changed files with 98 additions and 41 deletions
  1. 54 0
      app/rest/CachedLoadingPage.py
  2. 31 29
      app/rest/pyapi.py
  3. 1 1
      app/rest/requirements.txt
  4. 12 11
      app/rest/trend.py

+ 54 - 0
app/rest/CachedLoadingPage.py

@@ -0,0 +1,54 @@
+#
+# Copyright (c) Daniel Sheffield 2023
+# All rights reserved
+#
+# THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
+from queue import Queue, Empty
+from time import time
+from threading import Lock
+
+class CachedLoadingPage():
+    
+    value: str
+
+    def __init__(self, value: str):
+        self._created = time()
+        self._queue = Queue()
+        self._loaded = False
+        self.value = value
+        self._lock = Lock()
+
+    def _age(self) -> float:
+        return time() - self._created
+
+    @property
+    def queue(self) -> Queue:
+        return self._queue
+
+    @property
+    def loaded(self) -> bool:
+        return self._loaded
+
+    def _set_loaded(self, value: bool) -> bool:
+        self._loaded = value
+        return self._loaded
+    
+    @property
+    def stale(self) -> bool:
+        return self._age() > 10*60
+
+    def update(self) -> str:
+        if not self._lock.acquire(blocking=True, timeout=0.5):
+            return self.value
+        try:
+            item = self._queue.get(block=True, timeout=0.5)
+            if item is None:
+                self._queue.task_done()
+                self._set_loaded(True)
+            else:
+                self.value = item
+                self.queue.task_done()
+        except Empty:
+            pass
+        self._lock.release()
+        return self.value

+ 31 - 29
app/rest/pyapi.py

@@ -3,8 +3,9 @@
 # All rights reserved
 # All rights reserved
 #
 #
 # THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
 # THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
-from typing import Iterable
+from typing import Iterable, Dict
 import os
 import os
+from time import time
 from urllib.parse import urlencode
 from urllib.parse import urlencode
 from bottle import (
 from bottle import (
     route,
     route,
@@ -17,11 +18,13 @@ from bottle import (
     static_file,
     static_file,
     TEMPLATE_PATH,
     TEMPLATE_PATH,
 )
 )
-from matplotlib.axes import Axes
+import matplotlib
 from psycopg import connect
 from psycopg import connect
 from psycopg.sql import SQL, Literal
 from psycopg.sql import SQL, Literal
 import seaborn as sns
 import seaborn as sns
-from multiprocessing import Lock
+from threading import Lock, Thread
+
+from queue import Queue, Empty
 
 
 from ..data.filter import(
 from ..data.filter import(
     get_filter,
     get_filter,
@@ -30,12 +33,13 @@ from ..data.filter import(
 from ..data.util import(
 from ..data.util import(
     get_where_include_exclude
     get_where_include_exclude
 )
 )
-from .trend import trend_internal
+from . import trend as worker
 from . import PARAMS
 from . import PARAMS
-import matplotlib
+from .CachedLoadingPage import CachedLoadingPage
+
 matplotlib.use('agg')
 matplotlib.use('agg')
 
 
-CACHE = dict()
+CACHE: Dict[str, CachedLoadingPage] = dict()
 
 
 host = f"host={os.getenv('HOST')}"
 host = f"host={os.getenv('HOST')}"
 db = f"dbname={os.getenv('DB', 'grocery')}"
 db = f"dbname={os.getenv('DB', 'grocery')}"
@@ -89,7 +93,6 @@ def normalize_query(query: FormsDict, allow: Iterable[str] = None) -> str:
     ])
     ])
 
 
 
 
-
 @route('/grocery/static/<filename:path>')
 @route('/grocery/static/<filename:path>')
 def send_static(filename):
 def send_static(filename):
     return static_file(filename, root='app/rest/static')
     return static_file(filename, root='app/rest/static')
@@ -106,33 +109,32 @@ def trend():
     if request.query_string != normalized:
     if request.query_string != normalized:
         return redirect(f'{path}?{normalized}')
         return redirect(f'{path}?{normalized}')
     
     
-    loading = template("loading", progress=[])
     if request.query_string in CACHE:
     if request.query_string in CACHE:
-        if LOCK.acquire(block=False):
-            try:
-                return CACHE[request.query_string]["state"]
-            finally:
-                try:
-                    CACHE[request.query_string]["state"] = next(CACHE[request.query_string]["iter"])
-                except StopIteration:
-                    del CACHE[request.query_string]
-                finally:
-                    LOCK.release()
+        page = CACHE[request.query_string]
+        if not page.stale:
+            if not page.loaded:
+                return page.update()
+            return page.value
         else:
         else:
-            return CACHE[request.query_string]["state"]
-
-    if LOCK.acquire(block=False):
+            del CACHE[request.query_string]
+    
+    if LOCK.acquire(blocking=True):
         if request.query_string in CACHE:
         if request.query_string in CACHE:
-            LOCK.release()
-            return CACHE[request.query_string]["state"]
+            page = CACHE[request.query_string]
+            if not page.stale:
+                LOCK.release()
+                if not page.loaded:
+                    return page.update()
+                return page.value
+        
         try:
         try:
-            CACHE[request.query_string] = {
-                "iter": trend_internal(conn, path, request.query),
-                "state": loading,
-            }
+            page = CachedLoadingPage(template("loading", progress=[]))
+            CACHE[request.query_string] = page
+            thread = Thread(target=worker.trend, args=(page.queue, conn, path, request.query))
+            thread.start()
+            return page.value
         finally:
         finally:
             LOCK.release()
             LOCK.release()
-    return loading
 
 
 
 
 @route('/grocery/groups')
 @route('/grocery/groups')
@@ -242,4 +244,4 @@ SELECT query_to_xml_and_xmlschema({inner}, false, false, ''::text)
     response.content_type = 'application/xhtml+xml; charset=utf-8'
     response.content_type = 'application/xhtml+xml; charset=utf-8'
     return template("query-to-xml", title="Tags", xml=xml, form=form)
     return template("query-to-xml", title="Tags", xml=xml, form=form)
 
 
-run(host='0.0.0.0', port=6772, server='gunicorn')
+run(host='0.0.0.0', port=6772, server='paste')

+ 1 - 1
app/rest/requirements.txt

@@ -1,4 +1,4 @@
 seaborn
 seaborn
 psycopg[binary]
 psycopg[binary]
 bottle
 bottle
-gunicorn
+paste

+ 12 - 11
app/rest/trend.py

@@ -8,14 +8,13 @@ from bottle import (
     DictProperty,
     DictProperty,
     HTTPError,
     HTTPError,
     template,
     template,
-    request,
 )
 )
 from io import StringIO
 from io import StringIO
 import matplotlib.pyplot as plt
 import matplotlib.pyplot as plt
 import seaborn as sns
 import seaborn as sns
 from psycopg import Connection
 from psycopg import Connection
 from psycopg.connection import TupleRow
 from psycopg.connection import TupleRow
-from time import time
+from queue import Queue
 from . import ALL_UNITS
 from . import ALL_UNITS
 from ..data.QueryManager import (
 from ..data.QueryManager import (
     display_mapper,
     display_mapper,
@@ -35,6 +34,11 @@ from . import PARAMS
 def abort(code, text):
 def abort(code, text):
     return HTTPError(code, text)
     return HTTPError(code, text)
 
 
+def trend(queue: Queue, conn: Connection[TupleRow], path: str, query: DictProperty):
+    for item in trend_internal(conn, path, query):
+        queue.put(item, block=True)
+    queue.put(None)
+
 def trend_internal(conn: Connection[TupleRow], path: str, query: DictProperty):
 def trend_internal(conn: Connection[TupleRow], path: str, query: DictProperty):
     progress = []
     progress = []
     try:
     try:
@@ -43,7 +47,8 @@ def trend_internal(conn: Connection[TupleRow], path: str, query: DictProperty):
             fields = { k: query[k] or None for k in query.keys() if k in PARAMS }
             fields = { k: query[k] or None for k in query.keys() if k in PARAMS }
             unit = fields['unit'] = fields['unit'] or 'kg' if 'unit' in fields else 'kg'
             unit = fields['unit'] = fields['unit'] or 'kg' if 'unit' in fields else 'kg'
             if unit and unit not in ALL_UNITS:
             if unit and unit not in ALL_UNITS:
-                raise abort(400, f"Unsupported unit {unit}")
+                yield abort(400, f"Unsupported unit {unit}")
+                return
 
 
             progress.append({ "name": "Loading data", "status": ""})
             progress.append({ "name": "Loading data", "status": ""})
             yield template("loading", progress=progress)
             yield template("loading", progress=progress)
@@ -53,7 +58,8 @@ def trend_internal(conn: Connection[TupleRow], path: str, query: DictProperty):
             yield template("loading", progress=progress)
             yield template("loading", progress=progress)
 
 
             if data.empty:
             if data.empty:
-                raise abort(404, f"No data for {fields}")
+                yield abort(404, f"No data for {fields}")
+                return
             
             
             progress.append({ "name": "Loading chart", "status": ""})
             progress.append({ "name": "Loading chart", "status": ""})
             yield template("loading", progress=progress)
             yield template("loading", progress=progress)
@@ -93,18 +99,13 @@ def trend_internal(conn: Connection[TupleRow], path: str, query: DictProperty):
 
 
             f = StringIO()
             f = StringIO()
             plt.savefig(f, format='svg')
             plt.savefig(f, format='svg')
-            _filter = get_filter(request.query, allow=PARAMS)
+            _filter = get_filter(query, allow=PARAMS)
             form = get_form(path.split('/')[-1], 'get', _filter, data)
             form = get_form(path.split('/')[-1], 'get', _filter, data)
             
             
             progress[-1]["status"] = "done"
             progress[-1]["status"] = "done"
             yield template("loading", progress=progress)
             yield template("loading", progress=progress)
             
             
-            resp = lambda: template("trend", form=form, svg=f.getvalue())
-
-    except HTTPError as e:
-        resp = lambda exception=e: exception
+            yield template("trend", form=form, svg=f.getvalue())
 
 
     finally:
     finally:
         conn.commit()
         conn.commit()
-     
-    yield from iter(resp, lambda started=time(): time() - started > 600)