Browse Source

add iclude/exclude filter support to get_historic_price_data

Daniel Sheffield 1 year ago
parent
commit
0f7880fd04
3 changed files with 70 additions and 65 deletions
  1. 17 20
      app/data/PriceView.py
  2. 24 1
      app/data/util.py
  3. 29 44
      app/rest/pyapi.py

+ 17 - 20
app/data/PriceView.py

@@ -14,7 +14,12 @@ from psycopg.sql import (
     Literal,
     Composable,
 )
-from .util import get_select, get_from
+from .util import(
+    get_include_exclude,
+    get_select,
+    get_from,
+    get_where_include_exclude
+)
 
 def get_window(unit, organic):
     window = SQL(f"""(
@@ -75,26 +80,17 @@ JOINS = OrderedDict([
     ('groups', ('id', 'group_id')),
 ])
 
+
 def get_where(product=None, category=None, group=None, organic=None, limit='90 days'):
-    where = [ ]
-    if product is not None:
-        where.append(SQL(' ').join([
-            Identifier('products', 'name'),
-            SQL('='),
-            Literal(product)
-        ]))
-    if category is not None:
-        where.append(SQL(' ').join([
-            Identifier('categories', 'name'),
-            SQL('='),
-            Literal(category)
-        ]))
-    if group is not None:
-        where.append(SQL(' ').join([
-            Identifier('groups', 'name'),
-            SQL('='),
-            Literal(group)
-        ]))
+    where = [
+        get_where_include_exclude(
+            k, 'name', *get_include_exclude(v)
+        ) for k, v in {
+            'products': product,
+            'categories': category,
+            'groups': group
+        }.items()
+    ]
     if organic is not None:
         where.append(SQL(' ').join([
             Identifier('organic'),
@@ -114,6 +110,7 @@ def get_where(product=None, category=None, group=None, organic=None, limit='90 d
         SQL("\n  AND ").join(where),
     ]) if where else SQL('')
 
+
 def get_historic_prices_statement(unit, sort=None, product=None, category=None, group=None, organic=None, limit='90 days'):
     
     return SQL('\n').join([

+ 24 - 1
app/data/util.py

@@ -4,18 +4,39 @@
 # All rights reserved
 #
 # THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
-from typing import Tuple
+from typing import Tuple, Iterable
 from psycopg.sql import (
     Identifier,
+    Literal,
     SQL,
     Composable,
 )
+
+
 def get_include_exclude(value):
+    value = value or ''
     include, exclude, *_ = [
             *map(lambda x: x.split('|') if x else [], value.split('!')), []
     ]
     return list(set(include)), list(set(exclude))
 
+
+def get_where_include_exclude(
+        table: str,
+        col: str,
+        include: Iterable[str],
+        exclude: Iterable[str]
+) -> SQL:
+    return SQL("""
+  ({identifier} = ANY({include}) OR ARRAY[]::text[] @> {include}::text[])
+AND
+  NOT {identifier} = ANY({exclude})
+""").format(
+        identifier=Identifier(table, col),
+        include=Literal(include),
+        exclude=Literal(exclude)
+    )
+
 def get_select(alias_to_sql: dict[str,Composable]) -> Composable:
     select = SQL(""",
     """).join([
@@ -26,6 +47,7 @@ def get_select(alias_to_sql: dict[str,Composable]) -> Composable:
     return SQL("""
     """).join([SQL("SELECT"), *select])
 
+
 def get_from(
     base: str,
     table_to_join_on: dict[Tuple[str, Tuple[str,str]]]
@@ -42,6 +64,7 @@ LEFT JOIN """).format(base=Identifier(base)),
         SQL("""
 LEFT JOIN """).join(joins)])
 
+
 def get_groupby(alias_to_sql: dict[str, Composable]) -> Composable:
     groupby = SQL(""",
     """).join([ v for k, v in alias_to_sql.items() if k != 'tags'])

+ 29 - 44
app/rest/pyapi.py

@@ -13,7 +13,10 @@ from ..activities.Plot import (
     get_data,
 )
 from ..data.QueryManager import QueryManager, display_mapper
-from ..data.util import get_include_exclude
+from ..data.util import(
+    get_include_exclude,
+    get_where_include_exclude
+)
 
 def line(pivot, ylabel=None, xlabel=None):
     ax = sns.lineplot(data=pivot, markers=True)
@@ -41,7 +44,6 @@ def trend():
             unit = fields['unit'] = fields['unit'] if 'unit' in fields else None or 'kg'
             if unit not in ALL_UNITS:
                 raise abort(400, f"Unsupported unit {unit}")
-            
             data = get_data(query_manager, **fields)
             if data.empty:
                 raise abort(404, f"No data for {fields}")
@@ -144,29 +146,22 @@ ORDER BY 1', false, false, ''::text)
 
 @route('/grocery/categories')
 def categories():
-    fields = { 'group': ([],[]) }
+    fields = { 'group': '' }
+    fields.update({
+        k: request.query[k] for k in request.query.keys() if k == 'group'
+    })
     try:
         with conn.cursor() as cur:
-            fields.update({
-                k: get_include_exclude(
-                    request.query[k]
-                ) for k in request.query.keys() if k == 'group'
-            })
-            inner = SQL("""
-SELECT
-  c.name AS "Category",
-  g.name AS "Group"
+            inner = SQL('\n').join([SQL("""
+SELECT c.name AS "Category",  g.name AS "Group"
 FROM categories c
 JOIN groups g ON c.group_id = g.id
 WHERE
-  (g.name = ANY({group}) OR array_length({group}::text[], 1) IS NULL)
-AND
-  NOT g.name = ANY({ex_group})
+"""), get_where_include_exclude(
+    "g", "name", *get_include_exclude(fields['group'])
+), SQL("""
 ORDER BY 1, 2
-""").format(
-  **{ k: Literal(i) for k, (i, _) in fields.items() },
-  **{ f'ex_{k}': Literal(e) for k, (_, e) in fields.items() }
-).as_string(cur)
+""")]).as_string(cur)
             xml = cur.execute(SQL("""
 SELECT query_to_xml_and_xmlschema({inner}, false, false, ''::text)
 """).format(inner=Literal(inner))).fetchone()[0].splitlines()
@@ -178,39 +173,29 @@ SELECT query_to_xml_and_xmlschema({inner}, false, false, ''::text)
 @route('/grocery/products')
 def products():
     fields = {
-        'group': ([],[]),
-        'category': ([],[]),
+        'group': '',
+        'category': ''
     }
+    fields.update({
+        k: request.query[k] for k in request.query.keys() if k in (
+            'group', 'category'
+        )
+    })
     try:
         with conn.cursor() as cur:
-            fields.update({
-                k: get_include_exclude(
-                    request.query[k]
-                ) for k in request.query.keys() if k in (
-                    'group', 'category'
-                )
-            })
-            inner = SQL("""
-SELECT
-  p.name AS "Product",
-  c.name AS "Category",
-  g.name AS "Group"
+            inner = SQL('\n').join([SQL("""
+SELECT p.name AS "Product", c.name AS "Category", g.name AS "Group"
 FROM products p
 JOIN categories c ON p.category_id = c.id
 JOIN groups g ON c.group_id = g.id
 WHERE
-  (g.name = ANY({group}) OR array_length({group}::text[], 1) IS NULL)
-AND
-  NOT g.name = ANY({ex_group})
-AND
-  (c.name = ANY({category}) OR array_length({category}::text[], 1) IS NULL)
-AND
-  NOT c.name = ANY({ex_category})
+"""), SQL('\nAND\n').join([get_where_include_exclude(
+    "g", "name", *get_include_exclude(fields['group'])
+), get_where_include_exclude(
+    "c", "name", *get_include_exclude(fields['category'])
+)]), SQL("""
 ORDER BY 1, 2, 3
-""").format(
-  **{ k: Literal(i) for k, (i, _) in fields.items() },
-  **{ f'ex_{k}': Literal(e) for k, (_, e) in fields.items() }
-).as_string(cur)
+""")]).as_string(cur)
             xml = cur.execute(SQL("""
 SELECT query_to_xml_and_xmlschema({inner}, false, false, ''::text)
 """).format(inner=Literal(inner))).fetchone()[0].splitlines()