util.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. #
  2. # Copyright (c) Daniel Sheffield 2023
  3. #
  4. # All rights reserved
  5. #
  6. # THIS SOFTWARE IS PROVIDED AS IS WITHOUT WARRANTY
  7. from typing import List, Set, Tuple, Iterable
  8. from psycopg.sql import (
  9. Identifier,
  10. Literal,
  11. SQL,
  12. Composable,
  13. )
  14. def get_include_exclude(value) -> Tuple[List[str], List[str]]:
  15. value = value or ''
  16. include, exclude = [], []
  17. if isinstance(value, (list, tuple)):
  18. for v in value:
  19. inc, exc = get_include_exclude(v)
  20. include.extend(inc)
  21. exclude.extend(exc)
  22. else:
  23. inc, exc, *_ = [
  24. *map(lambda x: x.split('|') if x else [], value.split('!')), []
  25. ]
  26. include.extend(inc)
  27. exclude.extend(exc)
  28. return list(set(include)), list(set(exclude))
  29. def get_where_include_exclude(
  30. table: str,
  31. col: str,
  32. include: Iterable[str],
  33. exclude: Iterable[str]
  34. ) -> SQL:
  35. return SQL("""
  36. ({identifier} = ANY({include}) OR ARRAY[]::text[] @> {include}::text[])
  37. AND
  38. (NOT {identifier} = ANY({exclude}) OR {identifier} IS NULL)
  39. """).format(
  40. identifier=Identifier(table, col),
  41. include=Literal(include),
  42. exclude=Literal(exclude)
  43. )
  44. def get_select(alias_to_sql: dict[str,Composable]) -> Composable:
  45. select = SQL(""",
  46. """).join([
  47. SQL(' ').join([
  48. v, SQL('AS'), Identifier(k)
  49. ]) for k, v in alias_to_sql.items()
  50. ])
  51. return SQL("""
  52. """).join([SQL("SELECT"), *select])
  53. def get_from(
  54. base: str,
  55. table_to_join_on: dict[Tuple[str, Tuple[str,str]]]
  56. ) -> Composable:
  57. joins = [
  58. SQL("{table} ON {table_column} = {other_column}").format(
  59. table=Identifier(table),
  60. table_column=Identifier(table, table_column),
  61. other_column=Identifier(*other_column) if isinstance(other_column,tuple) else Identifier(other_column)
  62. ) for table, (table_column, other_column) in table_to_join_on.items()
  63. ]
  64. return SQL('').join([SQL("""FROM {base}
  65. LEFT JOIN """).format(base=Identifier(base)),
  66. SQL("""
  67. LEFT JOIN """).join(joins)])
  68. def get_groupby(alias_to_sql: dict[str, Composable], formatter=None) -> Composable:
  69. groupby = SQL(""",
  70. """).join([
  71. formatter(v) if formatter is not None and isinstance(
  72. v, SQL
  73. ) else v for k, v in alias_to_sql.items()
  74. ])
  75. return SQL("""
  76. """).join([SQL("GROUP BY"), *groupby])