Prechádzať zdrojové kódy

refactor price_plot for reuse

Daniel Sheffield pred 1 rokom
rodič
commit
8b8d425bc7
2 zmenené súbory, 134 pridaní a 93 odobraní
  1. 71 92
      app/activities/Plot.py
  2. 63 1
      price_plot.py

+ 71 - 92
app/activities/Plot.py

@@ -4,114 +4,93 @@
 # All rights reserved
 #
 # THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
-from app.data.QueryManager import QueryManager, display_mapper
-from datetime import date, datetime
-import seaborn as sns
+from datetime import date
 import pandas as pd
+import seaborn as sns
 import matplotlib.pyplot as plt
-import os
-import sys
-from sqlite3 import Cursor
-import psycopg
-import numpy as np
-from db_credentials import HOST, PASSWORD
-ALL_UNITS = {'g','kg','mL','L','Pieces','Bunches','Bags'}
-host = f'host={HOST}'
-password = f'password={PASSWORD}'
-user = os.getenv('USER')
-conn = psycopg.connect(f"{host} dbname=grocery user={user} {password}")
-cur: Cursor = conn.cursor()
-cur.execute("BEGIN")
+from app.data.QueryManager import QueryManager
+
+def get_data(query_manager: QueryManager, unit=None, **kwargs) -> pd.DataFrame:
+    d = pd.DataFrame(query_manager.get_historic_prices_data(unit, **kwargs))
+    if d.empty:
+        return d
+    print(d.info())
+    print(d)
+    d['ts_month'] = d['ts_raw'].apply(lambda x: date(x.date().year, x.date().month,1))
+    d[['price','quantity']] = d[['price','quantity']].apply(
+        lambda y: y.apply(lambda x: x and float(x)),
+    )
+    return d
+
+def pivot_data(data: QueryManager):
+    pivot = data.groupby(['ts_month','group',])['price', 'quantity'].sum()
+    pivot = pivot.reset_index().set_index('group')
+    print(pivot.info())
+    print(pivot)
+    return pivot
+
+def pie(p: pd.DataFrame, col=None, title=None):
+    ax = p.plot.pie(y=col, figsize=(5, 5))
+    ax.get_legend().remove()
+    ax.set_xlabel('')
+    ax.set_ylabel('')
+    ax.set_title(title)
+    plt.show()
+
+def line(pivot, ylabel=None, xlabel=None):
+    ax = sns.lineplot(data=pivot, markers=True)
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+    plt.show()
 
-query_manager = QueryManager(cur, display_mapper)
-fields = dict()
-for k in ('group', 'category', 'product', 'unit'):
-    options = query_manager.unique_suggestions(k, **fields)
-    names = [ o for o in options ]
-    matches = '\t'.join(names)
-    print(f'{k.title()} names: {matches}')
+def get_selection(
+    query_manager: QueryManager,
+    fields: dict[str, str],
+    units: set[str],
+    name: str
+):
+    options = query_manager.unique_suggestions(name, **fields)
+    matches = '\t'.join(list(options))
+    print(f'{name.title()} names: {matches}')
     while (len(options) >=1):
-        v = fields[k] if k in fields else ''
-        fields[k] = v + input(f"{k.title()}: {v}")
-        options = query_manager.unique_suggestions(k, **fields)
-        if k == 'unit':
-            options = sorted(set(options) | ALL_UNITS)
+        v = fields[name] if name in fields else ''
+        fields[name] = v + input(f"{name.title()}: {v}")
+        options = query_manager.unique_suggestions(name, **fields)
+        if name == 'unit':
+            options = sorted(set(options) | units)
 
-        if fields[k] == '':
+        if fields[name] == '':
             choice = ''
             break
         elif len(options) == 1:
             choice = options[0]
             break
-        elif fields[k].lower() in map(lambda x: x.lower(), options):
+        elif fields[name].lower() in map(lambda x: x.lower(), options):
             choice = next(
-                filter(lambda x: x.lower() == fields[k].lower(), options)
+                filter(lambda x: x.lower() == fields[name].lower(), options)
             )
             break
-        elif len(options) == 0 and k == 'unit':
-            choice = fields[k]
+        elif len(options) == 0 and name == 'unit':
+            choice = fields[name]
             break
         matches = '\t'.join(options)
-        print(f'Matches ({k}): {matches}')
-
-    if k != 'unit':
-        fields[k] = (fields[k] and choice) or ''
-    else:
-        fields[k] = choice
-    if fields[k] == '':
-        print(f'Ignoring {k} {choice} as it does not exist')
-    else:
-        print(f'Selected {k}: {fields[k]}')
-    print()
-
-fields = dict((k,v or None) for k,v in fields.items())
-unit = fields['unit'] or 'kg'
-if unit not in ALL_UNITS:
-    print(f'Invalid unit: {unit}')
-    exit(2)
+        print(f'Matches ({name}): {matches}')
 
-print(f'Getting data for selection:\n  ')
-print(f'\n  '.join([f'{k.title()}: {v}' for k,v in fields.items()]))
-d = pd.DataFrame(query_manager.get_historic_prices_data(unit, **dict((k,v) for k,v in fields.items() if k != 'unit')))
-if d.empty:
-    sys.exit(1)
-print(d.info())
-print(d)
-sns.set_theme()
-now = datetime.now().date()
-d['ts_month'] = d['ts_raw'].apply(lambda x: date(x.date().year, x.date().month,1))
-d[['price','quantity']] = d[['price','quantity']].apply(
-    lambda y,axis=None: y.apply(lambda x: x and float(x)), #if y.name in ('price','quantity')
-    axis=0
-)
-p = d[d['ts_month'] == date(now.year,now.month,1) ].groupby(['ts_month','group',])['price', 'quantity'].sum()
+    return choice
 
-if not p.empty:
-  p = p.reset_index().set_index('group')
-  p['price'] = p['price'].apply(float)
-  p['quantity'] = p['quantity'].apply(float)
-  print(p.info())
-  print(p)
-  ax = p.plot.pie(y='quantity', figsize=(5, 5))
-  ax.get_legend().remove()
-  ax.set_xlabel('')
-  ax.set_ylabel('')
-  ax.set_title(f'Quantity ({unit})')
-  plt.show()
-  ax = p.plot.pie(y='price', figsize=(5, 5))
-  ax.get_legend().remove()
-  ax.set_xlabel('')
-  ax.set_ylabel('')
-  ax.set_title(f'Price ($)')
-  plt.show()
-
-p = d.pivot_table(index=['ts_raw',], columns=['product',], values=['$/unit'], aggfunc='mean')
-p.columns = p.columns.droplevel()
-print(p.info())
-print(p)
-ax = sns.lineplot(data=p, markers=True)
-ax.set_xlabel('Time')
-ax.set_ylabel(f'$ / {unit}')
-plt.show()
+def get_input(query_manager: QueryManager, units: set[str]) -> dict[str, str]:
+    fields = dict()
+    for k in ('group', 'category', 'product', 'unit'):
+        choice = get_selection(query_manager, fields, units, k)
 
+        if k != 'unit':
+            fields[k] = (fields[k] and choice) or ''
+        else:
+            fields[k] = choice
+        if fields[k] == '':
+            print(f'Ignoring {k} {choice} as it does not exist')
+        else:
+            print(f'Selected {k}: {fields[k]}')
+        print()
+    return fields
 

+ 63 - 1
price_plot.py

@@ -1 +1,63 @@
-from app.activities.Plot import *
+#
+# Copyright (c) Daniel Sheffield 2023
+#
+# All rights reserved
+#
+# THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
+from datetime import date, datetime
+import os
+import sys
+from psycopg import connect, Cursor
+import seaborn as sns
+from db_credentials import HOST, PASSWORD
+from app.activities.Plot import (
+    get_data,
+    get_input,
+    pivot_data,
+    pie,
+    line,
+)
+from app.data.QueryManager import QueryManager, display_mapper
+
+ALL_UNITS = {'g','kg','mL','L','Pieces','Bunches','Bags'}
+host = f'host={HOST}'
+password = f'password={PASSWORD}'
+user = os.getenv('USER')
+conn = connect(f"{host} dbname=grocery user={user} {password}")
+cur: Cursor = conn.cursor()
+cur.execute("BEGIN")
+
+query_manager = QueryManager(cur, display_mapper)
+
+fields = get_input(query_manager, ALL_UNITS)
+
+unit = fields['unit'] = fields['unit'] or 'kg'
+fields = { k: v or None for k,v in fields.items() }
+if unit not in ALL_UNITS:
+    print(f'Invalid unit: {unit}')
+    exit(2)
+
+print('Getting data for selection:\n  ')
+print('\n  '.join([
+    f'{k.title()}: {v}' for k,v in fields.items()
+]))
+
+data = get_data(query_manager, **fields)
+if data.empty:
+    sys.exit(1)
+
+now = datetime.now().date()
+pivot = pivot_data(data[data['ts_month'] == date(now.year,now.month,1)])
+
+sns.set_theme()
+
+if not pivot.empty:
+    pie(pivot, col='quantity', title=f'Quantity ({unit})')
+    pie(pivot, col='price', title='Price ($)')
+
+pivot = data.pivot_table(index=['ts_raw',], columns=['product',], values=['$/unit'], aggfunc='mean')
+pivot.columns = pivot.columns.droplevel()
+print(pivot.info())
+print(pivot)
+
+line(pivot, xlabel='Time', ylabel=f'$ / {unit}')