Parcourir la source

refactor, de-lint, and add more tests

Daniel Sheffield il y a 1 an
Parent
commit
7d35c40205

+ 16 - 0
app/activities/Banner.py

@@ -0,0 +1,16 @@
+#
+# Copyright (c) Daniel Sheffield 2023
+#
+# All rights reserved
+#
+# THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
+from urwid import AttrMap, Padding, Pile, Text
+from .. import COPYRIGHT
+
+def banner(title):
+    header = Text(title, 'center')
+    _copyright = Text(COPYRIGHT, 'center')
+    return AttrMap(Pile([
+            Padding(header, 'center', width=('relative', 100)),
+            Padding(_copyright, 'center', width=('relative', 100)),
+        ]), 'banner')

+ 3 - 12
app/activities/PriceCheck.py

@@ -14,13 +14,11 @@ from urwid import (
     Divider,
     Filler,
     LineBox,
-    Padding,
     Pile,
     RadioButton,
     Text,
 )
 
-from .. import COPYRIGHT
 from ..widgets import (
     AutoCompleteEdit,
     AutoCompleteFloatEdit,
@@ -32,6 +30,8 @@ from ..widgets import (
 from ..data.QueryManager import QueryManager
 from .ActivityManager import ActivityManager, show_or_exit
 from .Rating import Rating
+from .Banner import banner
+
 
 def get_historic_prices(df):
     return df.drop(labels=[
@@ -248,15 +248,6 @@ class PriceCheck(FocusWidget):
 
         self.clear()
 
-        header = Text(u'Price Check', 'center')
-        _copyright = Text(COPYRIGHT, 'center')
-
-        banner = Pile([
-            Padding(header, 'center', width=('relative', 100)),
-            Padding(_copyright, 'center', width=('relative', 100)),
-        ])
-        banner = AttrMap(banner, 'banner')
-
         _widgets = dict(chain(*list(map(lambda x: x.items(), [
                 self.edit_fields, self.text_fields, self.checkboxes
             ])
@@ -329,7 +320,7 @@ class PriceCheck(FocusWidget):
         ])})
 
         widget = Pile([
-            banner,
+            banner(u'Price Check'),
             Divider(),
             components['top_pane'],
             Columns((components['left_pane'], components['right_pane']),

+ 10 - 29
app/activities/RecipeEditor.py

@@ -29,7 +29,13 @@ from urwid.numedit import FloatEdit
 import yaml
 from yaml.representer import SafeRepresenter
 
-from .. import COPYRIGHT
+from .grouped_widget_util import (
+    in_same_row,
+    to_numbered_field,
+    to_unnumbered_field,
+    to_named_value,
+)
+
 from ..widgets import (
     AutoCompleteEdit,
     FocusWidget,
@@ -38,6 +44,7 @@ from ..widgets import (
 )
 from ..data.QueryManager import QueryManager
 from .ActivityManager import ActivityManager, show_or_exit
+from .Banner import banner
 
 def change_style(style, representer):
     def new_representer(dumper, data):
@@ -76,28 +83,13 @@ f"""<root>
     for e in filter(lambda x: x.tag == 'strong', depth_first_elements(xhtml)):
         yield e.text
 
-def to_numbered_field(x):
-    if len(x[0].split('#', 1)) > 1:
-        name, idx = x[0].split('#', 1)
-        idx = int(idx)
-    else:
-        name, idx = x[0], 0
-
-    return (name, int(idx)), x[1]
-
-def to_unnumbered_field(x):
-    return x[0][0], x[1]
-
-def in_same_row(name):
-    if len(name.split('#', 1)) > 1:
-        _, row = name.split('#', 1)
-    return lambda x: x[0][1] == int(row)
 
 def unzip(_iter: List[Tuple[AutoCompleteEdit, FloatEdit, AutoCompleteEdit]]) -> Tuple[
     List[AutoCompleteEdit], List[FloatEdit], List[AutoCompleteEdit]
 ]:
     return zip(*_iter)
 
+
 def extract_values(x: Union[List[AutoCompleteEdit], List[FloatEdit]]) -> Iterable[str]:
     if isinstance(x, (list, tuple)):
         if len(x) == 0:
@@ -105,8 +97,6 @@ def extract_values(x: Union[List[AutoCompleteEdit], List[FloatEdit]]) -> Iterabl
         return ( v.get_edit_text() for v in x )
     raise Exception(f"Unsupported type: {type(x)}")
 
-def to_named_value(name: str) -> Callable[[str], Tuple[str,str]]:
-    return lambda e: (f'{name}#{e[0]}', e[1])
 
 def blank_ingredients_row(idx: int) -> Tuple[AutoCompleteEdit, FloatEdit, AutoCompleteEdit]:
     return (
@@ -470,15 +460,6 @@ class RecipeEditor(FocusWidget):
         connect_signal(self.buttons['exit'], 'click', lambda _: show_or_exit('esc'))
         connect_signal(self.instructions, 'postchange', lambda w,_: self.update(w))
 
-        header = Text(u'Recipe Editor', 'center')
-        _copyright = Text(COPYRIGHT, 'center')
-
-        banner = Pile([
-            Padding(header, 'center', width=('relative', 100)),
-            Padding(_copyright, 'center', width=('relative', 100)),
-        ])
-        banner = AttrMap(banner, 'banner')
-        
         left_pane, middle_pane, right_pane, gutter = self.init_ingredients()
 
         self.components = {
@@ -507,7 +488,7 @@ class RecipeEditor(FocusWidget):
         }
 
         widget = Pile([
-            banner,
+            banner(u'Recipe Editor'),
             Divider(),
             self.components['top_pane'],
             Columns([

+ 15 - 48
app/activities/TransactionEditor.py

@@ -25,13 +25,16 @@ from urwid import (
     Edit,
     Filler,
     LineBox,
-    Padding,
     Pile,
     Text,
 )
 
-from .. import COPYRIGHT
 from ..data.QueryManager import QueryManager
+from .grouped_widget_util import (
+    to_numbered_field,
+    to_unnumbered_field,
+    to_named_value,
+)
 from ..widgets import (
     AutoCompleteEdit,
     AutoCompleteFloatEdit,
@@ -43,31 +46,14 @@ from ..widgets import (
 from . import ActivityManager
 from .Rating import Rating
 from .NewProduct import NewProduct
-
-def to_numbered_field(x):
-    if len(x[0].split('#', 1)) > 1:
-        name, idx = x[0].split('#', 1)
-        idx = int(idx)
-    else:
-        name, idx = x[0], 0
-
-    return (name, int(idx)), x[1]
-
-def to_unnumbered_field(x):
-    return x[0][0], x[1]
-
-def in_same_row(name):
-    if len(name.split('#', 1)) > 1:
-        _, row = name.split('#', 1)
-    else:
-        row = 0
-    return lambda x: x[0][1] == int(row)
+from .Banner import banner
 
 def unzip(_iter: List[Tuple[AutoCompleteEdit, Edit]]) -> Tuple[
     List[AutoCompleteEdit], List[Edit]
 ]:
     return zip(*_iter)
 
+
 def extract_values(x: Union[List[AutoCompleteEdit], List[Edit]]) -> Iterable[str]:
     if isinstance(x, (list, tuple)):
         if len(x) == 0:
@@ -75,8 +61,6 @@ def extract_values(x: Union[List[AutoCompleteEdit], List[Edit]]) -> Iterable[str
         return ( v.get_edit_text() for v in x )
     raise Exception(f"Unsupported type: {type(x)}")
 
-def to_named_value(name: str) -> Callable[[str], Tuple[str,str]]:
-    return lambda e: (f'{name}#{e[0]}', e[1])
 
 def blank_tags_row(idx: int) -> Tuple[AutoCompleteEdit, Edit]:
     return (
@@ -110,10 +94,11 @@ class TransactionEditor(FocusWidget):
 
     def apply_choice(self, name, value):
         self.apply_changes(name, value)
-        data = dict(#filter(
-        #    in_same_row(name),
-            map(to_numbered_field, self.data.items())
-        )#)
+        data = {
+            (field, idx): v for field, idx, v in map(
+                to_numbered_field, self.data.items()
+            )
+        }
         for k,v in data.items():
             if f'{k[0]}#{k[1]}' == name or v:
                 continue
@@ -172,16 +157,12 @@ class TransactionEditor(FocusWidget):
 
 
     def init_tags(self):
-        #_tags = LineBox(Pile([AttrMap(
         _tags = Pile([AttrMap(
             AutoCompletePopUp(
                 tag[0],
                 self.apply_choice,
                 lambda: self.activity_manager.show(self.update())
             ), 'streak') for tag in self._tags])
-        #    title=f'Tags',
-        #    title_align='left'
-        #)
         gutter = Pile([
             *[ Divider() for _ in self._tags[:-1] ],
             Divider(),
@@ -196,7 +177,6 @@ class TransactionEditor(FocusWidget):
             blank_tags_row(len(self._tags))
         )
         _tags, gutter = self.init_tags()
-        #self.components['tags'][1].original_widget.contents = list(_tags.original_widget.contents)
         self.components['tags'].contents = list(_tags.contents)
         self.components['gutter'][1].contents = list(gutter.contents)
         for widget in self._tags:
@@ -204,16 +184,13 @@ class TransactionEditor(FocusWidget):
             connect_signal(widget[0], 'apply', lambda w, name: self.autocomplete_callback(
                 w, name, self.autocomplete_options(name, dict(map(
                     to_unnumbered_field,
-                    #filter(
-                    #    in_same_row(name),
-                        map(to_numbered_field, self.data.items()
-                    #)
+                    map(to_numbered_field, self.data.items()
                 ))))
             ))
 
 
     def clear(self):
-        self._tags = []
+        self._tags: List[Tuple[AutoCompleteEdit, Edit]] = []
         self.add_tag()
         for (k, ef) in self.edit_fields.items():
             if k in ('ts', 'store',):
@@ -421,10 +398,7 @@ class TransactionEditor(FocusWidget):
             connect_signal(ef, 'apply', lambda w, name: self.autocomplete_callback(
                 w, name, self.autocomplete_options(name, dict(map(
                     to_unnumbered_field,
-                    #filter(
-                    #    in_same_row(name),
                         map(to_numbered_field, self.data.items()
-                    #)
                 ))))
             ))
 
@@ -438,8 +412,6 @@ class TransactionEditor(FocusWidget):
                 title=k.title(), title_align='left'
             ) for k in self.edit_fields if k != 'product'
         })
-        header = Text(u'Fill Transaction', 'center')
-        _copyright = Text(COPYRIGHT, 'center')
 
         self.components.update({
             'bottom_button_bar': Columns(
@@ -472,11 +444,6 @@ class TransactionEditor(FocusWidget):
         connect_signal(self.buttons['clear'], 'click', lambda _: self.clear())
         connect_signal(self.buttons['add'], 'click', lambda _: self.add_tag())
 
-        banner = Pile([
-            Padding(header, 'center', width=('relative', 100)),
-            Padding(_copyright, 'center', width=('relative', 100)),
-        ])
-        banner = AttrMap(banner, 'banner')
         _widgets.update({
             'product': LineBox(Columns([
                 AttrMap(AutoCompletePopUp(
@@ -507,7 +474,7 @@ class TransactionEditor(FocusWidget):
         self.add_tag()
 
         widget = Pile([
-            banner,
+            banner(u'Fill Transaction'),
             Divider(),
             Columns([
                 self.components['main_pane'],

+ 30 - 0
app/activities/grouped_widget_util.py

@@ -0,0 +1,30 @@
+#
+# Copyright (c) Daniel Sheffield 2023
+#
+# All rights reserved
+#
+# THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
+from typing import Callable, Tuple
+
+
+def to_numbered_field(x: Tuple[str, str]) -> Tuple[str, int, str]:
+    if len(x[0].split('#', 1)) > 1:
+        name, idx = x[0].split('#', 1)
+        idx = int(idx)
+    else:
+        name, idx = x[0], 0
+
+    return name, int(idx), x[1]
+
+def to_unnumbered_field(x: Tuple[str, int, str]) -> Tuple[str, str]:
+    return x[0], x[2]
+
+def in_same_row(name: str) -> Callable[[Tuple[str, int, str]], bool]:
+    if len(name.split('#', 1)) > 1:
+        _, row = name.split('#', 1)
+    else:
+        row = 0
+    return lambda x: x[1] == int(row)
+
+def to_named_value(name: str) -> Callable[[Tuple[int, str]], Tuple[str, str]]:
+    return lambda e: (f'{name}#{e[0]}', e[1])

+ 1 - 2
app/rest/pyapi.py

@@ -4,9 +4,9 @@
 #
 # THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
 import os
+from threading import Thread
 from bottle import route, request, response, template, static_file
 from psycopg import connect
-from threading import Thread
 
 from .route_decorators import normalize, poison, cursor
 from .query_to_xml import get_categories, get_groups, get_products, get_tags
@@ -73,4 +73,3 @@ def products(cur):
 def tags(cur):
     response.content_type = 'application/xhtml+xml; charset=utf-8'
     return get_tags(cur, request.query)
-

+ 4 - 3
app/rest/query_to_xml.py

@@ -44,7 +44,9 @@ def get_inner_query(query: FormsDict) -> SQL:
 def render_form(cur: Cursor, inner: str, query: FormsDict):
     _filter = get_filter(query, allow=PARAMS)
     data = DataFrame(get_data(cur, inner)).dropna()
-    return get_form(request.path.split('/')[-1], 'get', _filter, BOOLEAN.get(query.organic, None), data)
+    action = request.path.split('/')[-1]
+    organic = BOOLEAN.get(query.organic, None)
+    return get_form(action, 'get', _filter, organic, data)
 
 
 def get_xml(cur: Cursor, sql: str):
@@ -100,7 +102,7 @@ WHERE q.category IS NULL
 
 def get_tags(cur: Cursor, query: FormsDict):
     form = template('form-nav', action='tags', method='get', params=[
-        {'name': k, 'value': request.params[k]} for k in request.params if k in PARAMS
+        {'name': k, 'value': query[k]} for k in query if k in PARAMS
     ])
     sql = SQL("""
 SELECT * FROM (SELECT count(DISTINCT txn.id) AS "Uses", tg.name AS "Name"
@@ -117,4 +119,3 @@ JOIN transactions txn ON txn.id = tm.transaction_id
 """).as_string(cur)
     xml = get_xml(cur, sql)
     return template("query-to-xml", title="Tags", xml=xml, form=form)
-

+ 4 - 5
app/rest/route_decorators.py

@@ -3,9 +3,9 @@
 # All rights reserved
 #
 # THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
-from bottle import request, FormsDict, redirect
 from typing import Callable, Iterable
 from urllib.parse import urlencode
+from bottle import request, FormsDict, redirect
 from psycopg import Connection
 from psycopg.connection import TupleRow
 
@@ -25,10 +25,10 @@ def normalize_query(query: FormsDict, allow: Iterable[str] = None) -> str:
     ])
 
 
-def _normalize_decorator(func: Callable, poison_on_reload: bool = False):
+def _normalize_decorator(func: Callable):
     def wrap(*args, **kwargs):
         _, _, path, *_ = request.urlparts
-        normalized = normalize_query(request.query, allow=PARAMS)
+        normalized = normalize_query(request.params, allow=PARAMS)
         if request.query_string != normalized:
             return redirect(f'{path}?{normalized}')
         return func(*args, **kwargs)
@@ -44,7 +44,7 @@ def normalize(*args, **kwargs):
 
 def _poison_decorator(func: Callable, cache: Cache = None):
     def wrap(*args, **kwargs):
-        normalized = normalize_query(request.query, allow=PARAMS)
+        normalized = normalize_query(request.params, allow=PARAMS)
         if request.params.get('reload') == 'true':
             cache.remove(normalized)
         return func(*args, **kwargs)
@@ -71,4 +71,3 @@ def cursor(*args, **kwargs):
     if not len(args):
         return lambda f: _cursor_decorator(f, **kwargs)
     raise Exception("decorator argument required")
-

+ 7 - 5
app/rest/trend.py

@@ -4,19 +4,20 @@
 # All rights reserved
 #
 # THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
+from io import StringIO
+from queue import Queue
 from bottle import (
     DictProperty,
     HTTPError,
     template,
 )
-from io import StringIO
 import matplotlib.pyplot as plt
 import matplotlib
 import seaborn as sns
 from psycopg import Connection
 from psycopg.connection import TupleRow
-from queue import Queue
-from . import ALL_UNITS, BOOLEAN
+
+from . import ALL_UNITS, BOOLEAN, PARAMS
 from ..data.QueryManager import (
     display_mapper,
     QueryManager,
@@ -30,7 +31,6 @@ from ..activities.Plot import (
 from .form import(
     get_form,
 )
-from . import PARAMS
 
 matplotlib.use('agg')
 
@@ -118,7 +118,9 @@ def trend_internal(conn: Connection[TupleRow], path: str, query: DictProperty):
             f = StringIO()
             plt.savefig(f, format='svg')
             _filter = get_filter(query, allow=PARAMS)
-            form = get_form(path.split('/')[-1], 'get', _filter, BOOLEAN.get(query.organic, None), data)
+            organic = BOOLEAN.get(query.organic, None)
+            action = path.split('/')[-1]
+            form = get_form(action, 'get', _filter, organic, data)
             
             progress[-1]["status"] = "done"
             yield template("loading", progress=progress)

+ 1 - 2
test/activities/test_Rating.py

@@ -1,10 +1,9 @@
 #
-# Copyright (c) Daniel Sheffield 2021 - 2023
+# Copyright (c) Daniel Sheffield 2023
 #
 # All rights reserved
 #
 # THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
-import numpy as np
 from app.activities.Rating import Rating
 from pytest import mark, fixture
 from urwid import Text

+ 77 - 0
test/activities/test_grouped_widget_util.py

@@ -0,0 +1,77 @@
+#
+# Copyright (c) Daniel Sheffield 2023
+#
+# All rights reserved
+#
+# THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
+from pytest import mark, fixture
+from itertools import chain
+from app.activities.grouped_widget_util import (
+    to_named_value,
+    to_numbered_field,
+    to_unnumbered_field,
+    in_same_row,
+)
+
+@mark.parametrize("idx", chain(range(2), [None, 'any']))
+@mark.parametrize("key, value", [
+    ('label', 'value'),
+    ('label2', 3),
+    ('checkbox', True),
+    ('checkbox', 'mixed'),
+])
+def test_to_unnumbered_field(key, idx, value):
+    assert (key, value) == to_unnumbered_field((key, idx, value))
+
+@mark.parametrize("key, value, expected", [
+    ('label#1', 'value', ('label', 1, 'value')),
+    ('label2#0', 3, ('label2', 0, 3)),
+    ('checkbox', True, ('checkbox', 0, True)),
+    ('checkbox', 'mixed', ('checkbox', 0, 'mixed')),
+])
+def test_to_numbered_field(key, value, expected):
+    assert expected == to_numbered_field((key, value))
+
+@mark.parametrize("key, values, expected", [
+    ('label#1', [
+        ('label', 1, 'value'),
+        ('label2', 0, 3),
+        ('checkbox', 0, True),
+        ('checkbox', 0, 'mixed'),
+        ('label2', 1, 0),
+    ], [
+        ('label', 1, 'value'),
+        ('label2', 1, 0),
+    ]),
+    ('label', [
+        ('label', 1, 'value'),
+        ('label2', 0, 3),
+        ('checkbox', 0, True),
+        ('checkbox', 0, 'mixed'),
+        ('label2', 1, 0),
+    ], [
+        ('label2', 0, 3),
+        ('checkbox', 0, True),
+        ('checkbox', 0, 'mixed'),
+    ]),
+])
+def test_in_same_row(key, values, expected):
+    assert expected == list(filter(in_same_row(key), values))
+
+@mark.parametrize("key, values, expected", [
+    ('label2', [
+        (1, 'value'),
+        (0, 3),
+        (0, True),
+        (0, 'mixed'),
+        (1, 0),
+    ], [
+        ('label2#1', 'value'),
+        ('label2#0', 3),
+        ('label2#0', True),
+        ('label2#0', 'mixed'),
+        ('label2#1', 0)
+    ]),
+])
+def test_to_named_value(key, values, expected):
+    assert expected == list(map(to_named_value(key), values))