Browse Source

remove price_plot in favour of REST api

Daniel Sheffield 1 year ago
parent
commit
292d3e5d1a
3 changed files with 54 additions and 211 deletions
  1. 0 92
      app/activities/Plot.py
  2. 54 56
      app/rest/trend.py
  3. 0 63
      price_plot.py

+ 0 - 92
app/activities/Plot.py

@@ -1,92 +0,0 @@
-#
-# Copyright (c) Daniel Sheffield 2023
-#
-# All rights reserved
-#
-# THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
-from datetime import date
-import pandas as pd
-import seaborn as sns
-import matplotlib.pyplot as plt
-from app.data.QueryManager import QueryManager
-
-def get_data(query_manager: QueryManager, unit=None, **kwargs) -> pd.DataFrame:
-    d = pd.DataFrame(query_manager.get_historic_prices_data(unit, **kwargs))
-    if d.empty:
-        return d
-    d['ts_month'] = d['ts_raw'].apply(lambda x: date(x.date().year, x.date().month,1))
-    d[['price','quantity']] = d[['price','quantity']].apply(
-        lambda y: y.apply(lambda x: x and float(x)),
-    )
-    return d
-
-def pivot_data(data: QueryManager):
-    pivot = data.groupby(['ts_month','group',])['price', 'quantity'].sum()
-    pivot = pivot.reset_index().set_index('group')
-    return pivot
-
-def pie(p: pd.DataFrame, col=None, title=None):
-    ax = p.plot.pie(y=col, figsize=(5, 5))
-    ax.get_legend().remove()
-    ax.set_xlabel('')
-    ax.set_ylabel('')
-    ax.set_title(title)
-    plt.show()
-
-def line(pivot, ylabel=None, xlabel=None):
-    ax = sns.lineplot(data=pivot, markers=True)
-    ax.set_xlabel(xlabel)
-    ax.set_ylabel(ylabel)
-    plt.show()
-
-def get_selection(
-    query_manager: QueryManager,
-    fields: dict[str, str],
-    units: set[str],
-    name: str
-):
-    options = query_manager.unique_suggestions(name, **fields)
-    matches = '\t'.join(list(options))
-    print(f'{name.title()} names: {matches}')
-    while (len(options) >=1):
-        v = fields[name] if name in fields else ''
-        fields[name] = v + input(f"{name.title()}: {v}")
-        options = query_manager.unique_suggestions(name, **fields)
-        if name == 'unit':
-            options = sorted(set(options) | units)
-
-        if fields[name] == '':
-            choice = ''
-            break
-        elif len(options) == 1:
-            choice = options[0]
-            break
-        elif fields[name].lower() in map(lambda x: x.lower(), options):
-            choice = next(
-                filter(lambda x: x.lower() == fields[name].lower(), options)
-            )
-            break
-        elif len(options) == 0 and name == 'unit':
-            choice = fields[name]
-            break
-        matches = '\t'.join(options)
-        print(f'Matches ({name}): {matches}')
-
-    return choice
-
-def get_input(query_manager: QueryManager, units: set[str]) -> dict[str, str]:
-    fields = dict()
-    for k in ('group', 'category', 'product', 'unit'):
-        choice = get_selection(query_manager, fields, units, k)
-
-        if k != 'unit':
-            fields[k] = (fields[k] and choice) or ''
-        else:
-            fields[k] = choice
-        if fields[k] == '':
-            print(f'Ignoring {k} {choice} as it does not exist')
-        else:
-            print(f'Selected {k}: {fields[k]}')
-        print()
-    return fields
-

+ 54 - 56
app/rest/trend.py

@@ -28,15 +28,38 @@ from ..data.filter import (
     get_filter,
     get_query_param,
 )
-from ..activities.Plot import (
-    get_data,
-)
 from .form import(
     get_form,
 )
 
 matplotlib.use('agg')
 
+plot_style = {
+    "lines.color": "#ffffff",
+    "patch.edgecolor": "#ffffff",
+    "text.color": "#ffffff",
+    "axes.facecolor": "#7f7f7f",
+    "axes.edgecolor": "#ffffff",
+    "axes.labelcolor": "#ffffff",
+    "xtick.color": "#ffffff",
+    "ytick.color": "#ffffff",
+    "grid.color": "#ffffff",
+    "figure.facecolor": "#7f7f7f",
+    "figure.edgecolor": "#7f7f7f",
+    "savefig.facecolor": "#7f7f7f",
+    "savefig.edgecolor": "#7f7f7f",
+}
+
+def get_data(query_manager: QueryManager, unit=None, **kwargs) -> pd.DataFrame:
+    d = pd.DataFrame(query_manager.get_historic_prices_data(unit, **kwargs))
+    if d.empty:
+        return d
+    d['ts_month'] = d['ts_raw'].apply(lambda x: date(x.date().year, x.date().month,1))
+    d[['price','quantity']] = d[['price','quantity']].apply(
+        lambda y: y.apply(lambda x: x and float(x)),
+    )
+    return d
+
 def abort(code, text):
     raise HTTPError(code, text)
 
@@ -83,20 +106,7 @@ def trend_internal(conn: Connection[TupleRow], path: str, query: FormsDict):
             pivot.columns = pivot.columns.droplevel()
             sns.set_theme(style='darkgrid', palette='pastel', context="talk")
             plt.style.use("dark_background")
-            plt.rcParams.update({
-    "lines.color": "#ffffff",
-    "patch.edgecolor": "#ffffff",
-    "text.color": "#ffffff",
-    "axes.facecolor": "#7f7f7f",
-    "axes.edgecolor": "#ffffff",
-    "axes.labelcolor": "#ffffff",
-    "xtick.color": "#ffffff",
-    "ytick.color": "#ffffff",
-    "grid.color": "#ffffff",
-    "figure.facecolor": "#7f7f7f",
-    "figure.edgecolor": "#7f7f7f",
-    "savefig.facecolor": "#7f7f7f",
-    "savefig.edgecolor": "#7f7f7f"})
+            plt.rcParams.update(plot_style)
             plt.rcParams.update({"grid.linewidth":0.2, "grid.alpha":0.5})
             plt.figure(figsize=[16, 9], layout="tight")
             xlabel='Time'
@@ -179,53 +189,41 @@ def volume_internal(conn: Connection[TupleRow], path: str, query: FormsDict):
             
             now = datetime.now().date()
             data = data[data['ts_month'] == date(now.year,now.month-1,1)]
-            pivot = data.groupby(['ts_month','group',])['price', 'quantity'].sum()
-            pivot = pivot.reset_index().set_index('group')
+            group = 'group'
+            for g, _g in zip(
+                ('category', 'group'),
+                ('product', 'category')
+            ):
+                if g and len(data[g].unique()) != 1:
+                    continue
+                group = _g
+                break
+
+            pivot = data.groupby([group,])[['price', 'quantity']].sum()
             
             if pivot.empty:
                 abort(404, f"No data.")
 
             sns.set_theme(style='darkgrid', palette='pastel', context="talk")
             plt.style.use("dark_background")
-            plt.rcParams.update({
-    "lines.color": "#ffffff",
-    "patch.edgecolor": "#ffffff",
-    "text.color": "#ffffff",
-    "axes.facecolor": "#7f7f7f",
-    "axes.edgecolor": "#ffffff",
-    "axes.labelcolor": "#ffffff",
-    "xtick.color": "#ffffff",
-    "ytick.color": "#ffffff",
-    "grid.color": "#ffffff",
-    "figure.facecolor": "#7f7f7f",
-    "figure.edgecolor": "#7f7f7f",
-    "savefig.facecolor": "#7f7f7f",
-    "savefig.edgecolor": "#7f7f7f"})
+            plt.rcParams.update(plot_style)
             plt.rcParams.update({"grid.linewidth":0.2, "grid.alpha":0.5})
             plt.figure(figsize=[16, 9], layout="tight")
-            ax = pivot.plot.pie(y='quantity', figsize=(5, 5))
-            #ax.get_legend().remove()
-            ax.set_title(f'Quantity ({unit})')
-            #pie(pivot, col='price', title='Price ($)')
-            xlabel=''
-            ylabel=''
-            # if pivot.columns.size > 50:
-            #     ax = sns.scatterplot(data=pivot, markers=True)
-            # else:
-            #     ax = sns.lineplot(data=pivot, markers=True)
-            #     legend = plt.figlegend(
-            #         loc='upper center', ncol=6,
-            #         title_fontsize="14", fontsize="12", labelcolor='#ffffff',
-            #         framealpha=0.5
-            #     )
-            #     legend.set_title(title="Products")
-            ax.legend().set_visible(False)
-
-            ax.set_xlabel(xlabel, fontsize="14")
-            ax.set_ylabel(ylabel, fontsize="14")
-            ax.axes.tick_params(labelsize="12", which='both')
-            for _, spine in ax.spines.items():
-                spine.set_color('#ffffff')
+            axes = pivot.plot.pie(subplots=True, figsize=(11, 5))
+
+            for ax, title in zip(axes, (
+                f'Expenditure ($)',
+                f'Quantity ({unit})',
+            )):
+                ax.set_title(title)
+                xlabel=''
+                ylabel=''
+                ax.legend().set_visible(False)
+                ax.legend().set_visible(False)
+
+                ax.set_xlabel(xlabel, fontsize="14")
+                ax.set_ylabel(ylabel, fontsize="14")
+                ax.axes.tick_params(labelsize="12", which='both')
             
             progress.update({ "stage": "Rendering chart", "percent": "50"})
             yield template("done") + template("progress", **progress)

+ 0 - 63
price_plot.py

@@ -1,63 +0,0 @@
-#
-# Copyright (c) Daniel Sheffield 2023
-#
-# All rights reserved
-#
-# THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
-from datetime import date, datetime
-import os
-import sys
-from psycopg import connect, Cursor
-import seaborn as sns
-from db_credentials import HOST, PASSWORD
-from app.activities.Plot import (
-    get_data,
-    get_input,
-    pivot_data,
-    pie,
-    line,
-)
-from app.data.QueryManager import QueryManager, display_mapper
-
-ALL_UNITS = {'g','kg','mL','L','Pieces','Bunches','Bags'}
-host = f'host={HOST}'
-password = f'password={PASSWORD}'
-user = os.getenv('USER')
-conn = connect(f"{host} dbname=grocery user={user} {password}")
-cur: Cursor = conn.cursor()
-cur.execute("BEGIN")
-
-query_manager = QueryManager(cur, display_mapper)
-
-fields = get_input(query_manager, ALL_UNITS)
-
-unit = fields['unit'] = fields['unit'] or 'kg'
-fields = { k: v or None for k,v in fields.items() }
-if unit not in ALL_UNITS:
-    print(f'Invalid unit: {unit}')
-    exit(2)
-
-print('Getting data for selection:\n  ')
-print('\n  '.join([
-    f'{k.title()}: {v}' for k,v in fields.items()
-]))
-
-data = get_data(query_manager, **fields)
-if data.empty:
-    sys.exit(1)
-
-now = datetime.now().date()
-pivot = pivot_data(data[data['ts_month'] == date(now.year,now.month,1)])
-
-sns.set_theme()
-
-if not pivot.empty:
-    pie(pivot, col='quantity', title=f'Quantity ({unit})')
-    pie(pivot, col='price', title='Price ($)')
-
-pivot = data.pivot_table(index=['ts_raw',], columns=['product',], values=['$/unit'], aggfunc='mean')
-pivot.columns = pivot.columns.droplevel()
-print(pivot.info())
-print(pivot)
-
-line(pivot, xlabel='Time', ylabel=f'$ / {unit}')