Parcourir la source

refactor to tidy up

Daniel Sheffield il y a 1 an
Parent
commit
35701cfc47
3 fichiers modifiés avec 44 ajouts et 38 suppressions
  1. 37 31
      app/data/PriceView.py
  2. 5 0
      app/data/util.py
  3. 2 7
      app/rest/pyapi.py

+ 37 - 31
app/data/PriceView.py

@@ -16,42 +16,51 @@ from psycopg.sql import (
 )
 from .util import get_select, get_from
 
+def get_window(unit, organic):
+    window = SQL(f"""(
+PARTITION BY {'organic,' if organic is not None else ''} product_id
+ORDER BY convert_unit(units.name, {{unit}}, products.name) NULLS FIRST, ts
+ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+)""").format(unit=Literal(unit))
+
+    return window
+
 def get_selectors(
     unit: str,
-    product: str,
-    window: SQL
+    organic: str
 ) -> OrderedDict[Tuple[str, Composable]]:
+    window = get_window(unit, organic)
     return  OrderedDict([
         ('id', Identifier('transactions', 'id')),
         ('ts_raw', SQL("""(transactions.ts AT TIME ZONE 'UTC')::timestamp without time zone""")),
         ('%d/%m/%y %_I%P', SQL("""(transactions.ts AT TIME ZONE 'UTC')::timestamp without time zone""")),
         ('code', Identifier('stores', 'code')),
         ('$/unit', SQL("""TRUNC(
-    price / quantity / convert_unit(units.name, {unit}, {product}), 4
-    )""").format(unit=Literal(unit), product=Literal(product))),
+    price / quantity / convert_unit(units.name, {unit}, products.name), 4
+    )""").format(unit=Literal(unit))),
         ('last', SQL("""TRUNC(last_value(
-    price / quantity / convert_unit(units.name, {unit}, {product})
+    price / quantity / convert_unit(units.name, {unit}, products.name)
 ) OVER {window}, 4)
-""").format(unit=Literal(unit), product=Literal(product), window=window)),
+""").format(unit=Literal(unit), window=window)),
         ('avg', SQL("""TRUNC(sum(CASE
-    WHEN convert_unit(units.name, {unit}, {product}) IS NOT NULL THEN price
+    WHEN convert_unit(units.name, {unit}, products.name) IS NOT NULL THEN price
     ELSE NULL
 END) OVER {window} / sum(
-    quantity * convert_unit(units.name, {unit}, {product})
+    quantity * convert_unit(units.name, {unit}, products.name)
 ) OVER {window}, 4)
-""").format(unit=Literal(unit), product=Literal(product), window=window)),
+""").format(unit=Literal(unit), window=window)),
         ('min', SQL("""TRUNC(min(
-    price / quantity / convert_unit(units.name, {unit}, {product})
+    price / quantity / convert_unit(units.name, {unit}, products.name)
 ) OVER {window}, 4)
-""").format(unit=Literal(unit), product=Literal(product), window=window)),
+""").format(unit=Literal(unit), window=window)),
         ('max', SQL("""TRUNC(max(
-    price / quantity / convert_unit(units.name, {unit}, {product})
+    price / quantity / convert_unit(units.name, {unit}, products.name)
 ) OVER {window}, 4)
-""").format(unit=Literal(unit), product=Literal(product), window=window)),
+""").format(unit=Literal(unit), window=window)),
         ('price', SQL("""TRUNC(price, 4)""")),
         ('quantity', SQL("""TRUNC(
-    quantity * convert_unit(units.name, {unit}, {product}), 4
-)""").format(unit=Literal(unit), product=Literal(product))),
+    quantity * convert_unit(units.name, {unit}, products.name), 4
+)""").format(unit=Literal(unit))),
         ('product', Identifier('products', 'name')),
         ('category', Identifier('categories', 'name')),
         ('group', Identifier('groups', 'name')),
@@ -106,11 +115,15 @@ def get_where(product=None, category=None, group=None, organic=None, limit='90 d
     ]) if where else SQL('')
 
 def get_historic_prices_statement(unit, sort=None, product=None, category=None, group=None, organic=None, limit='90 days'):
-    window = SQL(f"""(
-PARTITION BY {'organic,' if organic is not None else ''} product_id
-ORDER BY convert_unit(units.name, {{unit}}, {{product}}) NULLS FIRST, ts
-ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
-)""").format(unit=Literal(unit), product=Literal(product))
+    
+    return SQL('\n').join([
+        get_select(get_selectors(unit, organic)),
+        get_from("transactions", JOINS),
+        get_where(product=product, category=category, group=group, organic=organic, limit=limit),
+        get_sort(sort, organic),
+    ])
+
+def get_sort(sort, organic):
     organic_sort = f"{'organic,' if organic is not None else ''}"
     sort_sql = SQL('').join([
         SQL('{sort} {direction},').format(
@@ -118,16 +131,9 @@ ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
             direction=SQL('DESC' if sort == 'ts' else 'ASC')
         ),
     ]) if sort is not None else SQL('')
-
-    statement = SQL('\n').join([
-        get_select(get_selectors(unit, product, window)),
-        get_from("transactions", JOINS),
-        get_where(product=product, category=category, group=group, organic=organic, limit=limit),
-        SQL("""
+    return SQL("""
 ORDER BY {organic_sort} {sort} code, product, category, "group", "$/unit" ASC, ts DESC
 """).format(
-            sort=sort_sql,
-            organic_sort=SQL(organic_sort),
-       ),
-    ])
-    return statement
+        sort=sort_sql,
+        organic_sort=SQL(organic_sort),
+    )

+ 5 - 0
app/data/util.py

@@ -10,6 +10,11 @@ from psycopg.sql import (
     SQL,
     Composable,
 )
+def get_include_exclude(value):
+    include, exclude, *_ = [
+            *map(lambda x: x.split('|') if x else [], value.split('!')), []
+    ]
+    return list(set(include)), list(set(exclude))
 
 def get_select(alias_to_sql: dict[str,Composable]) -> Composable:
     select = SQL(""",

+ 2 - 7
app/rest/pyapi.py

@@ -4,9 +4,8 @@
 #
 # THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
 from bottle import route, request, run, response, abort
-from psycopg import Cursor, connect
+from psycopg import connect
 from psycopg.sql import SQL, Literal
-from datetime import date, datetime
 import os
 import matplotlib.pyplot as plt
 import seaborn as sns
@@ -14,11 +13,7 @@ from ..activities.Plot import (
     get_data,
 )
 from ..data.QueryManager import QueryManager, display_mapper
-def get_include_exclude(value):
-    include, exclude, *_ = [
-            *map(lambda x: x.split('|') if x else [], value.split('!')), []
-    ]
-    return list(set(include)), list(set(exclude))
+from ..data.util import get_include_exclude
 
 def line(pivot, ylabel=None, xlabel=None):
     ax = sns.lineplot(data=pivot, markers=True)