util.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. #
  2. # Copyright (c) Daniel Sheffield 2023
  3. #
  4. # All rights reserved
  5. #
  6. # THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
  7. from typing import List, Set, Tuple, Iterable
  8. from psycopg.sql import (
  9. Identifier,
  10. Literal,
  11. SQL,
  12. Composable,
  13. )
  14. def get_where_include_exclude(
  15. table: str,
  16. col: str,
  17. include: Iterable[str],
  18. exclude: Iterable[str]
  19. ) -> SQL:
  20. return SQL("""
  21. ({identifier} = ANY({include}) OR ARRAY[]::text[] @> {include}::text[])
  22. AND
  23. (NOT {identifier} = ANY({exclude}) OR {identifier} IS NULL)
  24. """).format(
  25. identifier=Identifier(table, col),
  26. include=Literal(include),
  27. exclude=Literal(exclude)
  28. )
  29. def get_select(alias_to_sql: dict[str,Composable]) -> Composable:
  30. select = SQL(""",
  31. """).join([
  32. SQL(' ').join([
  33. v, SQL('AS'), Identifier(k)
  34. ]) for k, v in alias_to_sql.items()
  35. ])
  36. return SQL("""
  37. """).join([SQL("SELECT"), *select])
  38. def get_from(
  39. base: str,
  40. table_to_join_on: dict[Tuple[str, Tuple[str,str]]]
  41. ) -> Composable:
  42. joins = [
  43. SQL("{table} ON {table_column} = {other_column}").format(
  44. table=Identifier(table),
  45. table_column=Identifier(table, table_column),
  46. other_column=Identifier(*other_column) if isinstance(other_column,tuple) else Identifier(other_column)
  47. ) for table, (table_column, other_column) in table_to_join_on.items()
  48. ]
  49. return SQL('').join([SQL("""FROM {base}
  50. LEFT JOIN """).format(base=Identifier(base)),
  51. SQL("""
  52. LEFT JOIN """).join(joins)])
  53. def get_groupby(alias_to_sql: dict[str, Composable], formatter=None) -> Composable:
  54. groupby = SQL(""",
  55. """).join([
  56. formatter(v) if formatter is not None and isinstance(
  57. v, SQL
  58. ) else v for k, v in alias_to_sql.items()
  59. ])
  60. return SQL("""
  61. """).join([SQL("GROUP BY"), *groupby])