瀏覽代碼

fix losing form data when navigating

Daniel Sheffield 1 年之前
父節點
當前提交
186b62ff9d
共有 5 個文件被更改,包括 63 次插入51 次删除
  1. 23 34
      app/rest/pyapi.py
  2. 24 7
      app/rest/query_to_xml.py
  3. 9 6
      app/rest/route_decorators.py
  4. 1 1
      app/rest/templates/loading.tpl
  5. 6 3
      app/rest/trend.py

+ 23 - 34
app/rest/pyapi.py

@@ -6,15 +6,15 @@
 from io import BufferedReader
 import os
 from threading import Thread
+from typing import Union
 from bottle import (
     route, request, response,
     redirect, abort, 
     template, static_file,
-    FormsDict, HTTPError, HTTPResponse,
+    FormsDict, HTTPError,
 )
-from psycopg import connect
-
-from app.data.filter import get_filter, get_query_param
+from psycopg import Cursor, connect
+from psycopg.rows import TupleRow
 
 from .hash_util import blake, bytes_to_base32, hash_to_base32, hex_to_hash, normalize_base32
 from .route_decorators import normalize, normalize_query, poison, cursor
@@ -31,74 +31,63 @@ if not password.split('=',1)[1]:
     password = ''
 conn = connect(f"{host} {db} {user} {password}")
 
-CACHE = Cache(10)
-
 @route('/grocery/static/<filename:path>')
 def send_static(filename):
     return static_file(filename, root='app/rest/static')
 
 
 @route('/grocery/trend', method=['GET', 'POST'])
-@poison(cache=CACHE)
+@poison(cache=Cache(10))
 @normalize
-def trend():
-    key = normalize_query(request.params)
-    parts = key.split('=')
-    if len(parts) == 2 and parts[0] == 'hash':
-        _, _hash = parts
-        key = hex_to_hash(_hash)
-
-    page = CACHE[key]
-    
-    _, _, path, *_ = request.urlparts
+def trend(key: str, forms: FormsDict, cache: Cache):
+    page = cache[key]
     if page:
         return page
     
-    param = get_filter(request.params, allow=PARAMS)
-    params = FormsDict({
-        k: get_query_param(*param[k]) if k != 'organic' else BOOLEAN[
-            BOOLEAN.get(request.params.organic, None)
-        ] for k in sorted(param) if param[k]
-    })
-
-    return CACHE.add(key, CachedLoadingPage(
+    _, _, path, *_ = request.urlparts
+    
+    return cache.add(key, CachedLoadingPage(
         template("loading", progress=[]),
         lambda queue: Thread(target=worker.trend, args=(
-            queue, conn, path, params
+            queue, conn, path, forms
         )).start()
     ))
 
 
 @route('/grocery/groups', method=['GET', 'POST'])
+@poison(cache=Cache(10))
 @normalize
 @cursor(connection=conn)
-def groups(cur):
+def groups(cur: Cursor[TupleRow], key: Union[int, str], forms: FormsDict, cache: Cache):
     response.content_type = 'application/xhtml+xml; charset=utf-8'
-    return get_groups(cur, request.query)
+    return get_groups(cur, forms)
 
 
 @route('/grocery/categories', method=['GET', 'POST'])
+@poison(cache=Cache(10))
 @normalize
 @cursor(connection=conn)
-def categories(cur):
+def categories(cur: Cursor[TupleRow], key: Union[int, str], forms: FormsDict, cache: Cache):
     response.content_type = 'application/xhtml+xml; charset=utf-8'
-    return get_categories(cur, request.query)
+    return get_categories(cur, forms)
 
 
 @route('/grocery/products', method=['GET', 'POST'])
+@poison(cache=Cache(10))
 @normalize
 @cursor(connection=conn)
-def products(cur):
+def products(cur: Cursor[TupleRow], key: Union[int, str], forms: FormsDict, cache: Cache):
     response.content_type = 'application/xhtml+xml; charset=utf-8'
-    return get_products(cur, request.query)
+    return get_products(cur, forms)
 
 
 @route('/grocery/tags', method=['GET', 'POST'])
+@poison(cache=Cache(10))
 @normalize
 @cursor(connection=conn)
-def tags(cur):
+def tags(cur: Cursor[TupleRow], key: Union[int, str], forms: FormsDict, cache: Cache):
     response.content_type = 'application/xhtml+xml; charset=utf-8'
-    return get_tags(cur, request.query)
+    return get_tags(cur, forms)
 
 CLIP_SIZE_LIMIT = 65535
 SCHEME = "http://" #"https://"

+ 24 - 7
app/rest/query_to_xml.py

@@ -16,8 +16,9 @@ from . import BOOLEAN, PARAMS
 
 
 def get_product_rollup_statement(filters) -> SQL:
+    _map = { k: k[0] for k in ('product', 'category', 'group') }
     where = [ get_where_include_exclude(
-        k[0], "name", list(include), list(exclude)
+        _map[k], "name", list(include), list(exclude)
     ) for k, (include, exclude) in filters.items() ]
     return SQL("""
 SELECT
@@ -100,15 +101,20 @@ WHERE q.category IS NULL
     xml = get_xml(cur, sql)
     return template("query-to-xml", title="Groups", xml=xml, form=form)
 
-def get_tags(cur: Cursor, query: FormsDict):
-    form = template('form-nav', action='tags', method='get', params=[
-        {'name': k, 'value': query[k]} for k in query if k in PARAMS
-    ])
-    sql = SQL("""
+def get_tags_statement(filters) -> SQL:
+    _map = {
+        k: k[0] for k in ('product', 'category', 'group')
+    }
+    _map.update({ 'tag': 'tg' })
+    where = [ get_where_include_exclude(
+        _map[k], "name", list(include), list(exclude)
+    ) for k, (include, exclude) in filters.items() ]
+    return SQL("""
 SELECT * FROM (SELECT count(DISTINCT txn.id) AS "Uses", tg.name AS "Name"
 FROM tags tg
 JOIN tags_map tm ON tg.id = tm.tag_id
 JOIN transactions txn ON txn.id = tm.transaction_id
+WHERE {where}
 GROUP BY tg.name
 ORDER BY 1 DESC, 2) q
 UNION ALL
@@ -116,6 +122,17 @@ SELECT count(DISTINCT txn.id) AS "Uses", count(DISTINCT tg.name)||'' AS "Name"
 FROM tags tg
 JOIN tags_map tm ON tg.id = tm.tag_id
 JOIN transactions txn ON txn.id = tm.transaction_id
-""").as_string(cur)
+WHERE {where}
+""").format(where=SQL("\nAND").join(where))
+
+def get_inner_tags_query(query: FormsDict) -> SQL:
+    filters = get_filter(query, allow=('tag',))
+    inner = get_tags_statement(filters)
+    return inner
+
+def get_tags(cur: Cursor, query: FormsDict):
+    inner = get_inner_tags_query(query)
+    form = render_form(cur, inner, query)
+    sql = inner.as_string(cur)
     xml = get_xml(cur, sql)
     return template("query-to-xml", title="Tags", xml=xml, form=form)

+ 9 - 6
app/rest/route_decorators.py

@@ -15,10 +15,10 @@ from .Cache import Cache
 from .hash_util import hash_to_hex, normalize_hex
 
 def normalize_query(query: FormsDict, allow: Iterable[str] = None) -> str:
-    allow = allow or (PARAMS | { 'hash'})
     if 'hash' in query and query.hash:
         _hex = normalize_hex(query.hash)
         return f'hash={_hex}'
+    allow = allow or PARAMS
     param = get_filter(query, allow=allow)
     norm = urlencode([
         (
@@ -35,7 +35,10 @@ def _normalize_decorator(func: Callable, allow=None):
         normalized = normalize_query(request.params, allow=allow)
         if request.query_string != normalized:
             return redirect(f'{path}?{normalized}', 307)
-        return func(*args, **kwargs)
+        
+        _hash = request.params.hash
+        key = _hash if _hash else request.query_string
+        return func(key, request.forms if _hash else request.query, *args, **kwargs)
     return wrap
 
 
@@ -48,16 +51,16 @@ def normalize(*args, **kwargs):
 
 def _poison_decorator(func: Callable, cache: Cache = None):
     def wrap(*args, **kwargs):
-        normalized = normalize_query(request.params, allow=PARAMS)
         if request.params.get('reload') == 'true':
+            normalized = normalize_query(request.params, allow=PARAMS)
             cache.remove(normalized)
-        return func(*args, **kwargs)
+        return func(cache, *args, **kwargs)
     return wrap
 
 
-def poison(*args, **kwargs):
+def poison(*args, cache=None, **kwargs):
     if not len(args):
-        return lambda f: _poison_decorator(f, **kwargs)
+        return lambda f: _poison_decorator(f, cache=cache, **kwargs)
     raise Exception("decorator argument required")
 
 

+ 1 - 1
app/rest/templates/loading.tpl

@@ -22,6 +22,6 @@ body {
       end
       %>
     </div>
-    <meta http-equiv="Refresh" content="0;" />
+    <meta http-equiv="Refresh" content="0.2;" />
   </body>
 </html>

+ 6 - 3
app/rest/trend.py

@@ -24,6 +24,7 @@ from ..data.QueryManager import (
 )
 from ..data.filter import (
     get_filter,
+    get_query_param,
 )
 from ..activities.Plot import (
     get_data,
@@ -43,12 +44,15 @@ def trend(queue: Queue, conn: Connection[TupleRow], path: str, query: FormsDict)
     queue.put(None)
 
 def trend_internal(conn: Connection[TupleRow], path: str, query: FormsDict):
-    print({ k: query[k] for k in query })
     progress = []
     try:
         with conn.cursor() as cur:
             query_manager = QueryManager(cur, display_mapper)
-            fields = { k: query[k] or None for k in query.keys() if k in PARAMS }
+            _filter = get_filter(query, allow=PARAMS)
+            fields = {
+                k: get_query_param(*_filter[k])
+                for k in sorted(_filter) if k not in ('organic', 'unit') and _filter[k]
+            }
             unit = fields['unit'] = fields['unit'] or 'kg' if 'unit' in fields else 'kg'
             fields['organic'] = BOOLEAN.get(query.organic, None)
             if unit and unit not in ALL_UNITS:
@@ -119,7 +123,6 @@ def trend_internal(conn: Connection[TupleRow], path: str, query: FormsDict):
 
             f = StringIO()
             plt.savefig(f, format='svg')
-            _filter = get_filter(query, allow=PARAMS)
             organic = BOOLEAN.get(query.organic, None)
             action = path.split('/')[-1]
             form = get_form(action, 'post', _filter, organic, data)