base.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. from app.libs.error_code import NotFound
  2. from datetime import datetime
  3. from flask_sqlalchemy import SQLAlchemy as _SQLAlchemy, BaseQuery
  4. from sqlalchemy import inspect, Column, Integer, SmallInteger, orm
  5. from contextlib import contextmanager
  6. class SQLAlchemy(_SQLAlchemy):
  7. @contextmanager
  8. def auto_commit(self):
  9. try:
  10. yield
  11. self.session.commit()
  12. except Exception as e:
  13. db.session.rollback()
  14. raise e
  15. class Query(BaseQuery):
  16. def filter_by(self, **kwargs):
  17. if 'status' not in kwargs.keys():
  18. kwargs['status'] = 1
  19. return super(Query, self).filter_by(**kwargs)
  20. def get_or_404(self, ident):
  21. rv = self.get(ident)
  22. if not rv:
  23. raise NotFound()
  24. return rv
  25. def first_or_404(self):
  26. rv = self.first()
  27. if not rv:
  28. raise NotFound()
  29. return rv
# Module-level extension instance. ``query_class=Query`` makes the custom
# ``Query`` class (status filter + NotFound helpers) the query class used
# by this extension.
db = SQLAlchemy(query_class=Query)
  31. class Base(db.Model):
  32. __abstract__ = True
  33. create_time = Column(Integer)
  34. status = Column(SmallInteger, default=1)
  35. def __init__(self):
  36. self.create_time = int(datetime.now().timestamp())
  37. def __getitem__(self, item):
  38. return getattr(self, item)
  39. @property
  40. def create_datetime(self):
  41. if self.create_time:
  42. return datetime.fromtimestamp(self.create_time)
  43. else:
  44. return None
  45. def set_attrs(self, attrs_dict):
  46. for key, value in attrs_dict.items():
  47. if hasattr(self, key) and key != 'id':
  48. setattr(self, key, value)
  49. def delete(self):
  50. self.status = 0
  51. def keys(self):
  52. return self.fields
  53. def hide(self, *keys):
  54. for key in keys:
  55. self.fields.remove(key)
  56. return self
  57. def append(self, *keys):
  58. for key in keys:
  59. self.fields.append(key)
  60. return self
  61. class MixinJSONSerializer:
  62. @orm.reconstructor
  63. def init_on_load(self):
  64. self._fields = []
  65. # self._include = []
  66. self._exclude = []
  67. self._set_fields()
  68. self.__prune_fields()
  69. def _set_fields(self):
  70. pass
  71. def __prune_fields(self):
  72. columns = inspect(self.__class__).columns
  73. if not self._fields:
  74. all_columns = set(columns.keys())
  75. self._fields = list(all_columns - set(self._exclude))
  76. def hide(self, *args):
  77. for key in args:
  78. self._fields.remove(key)
  79. return self
  80. def keys(self):
  81. return self._fields
  82. def __getitem__(self, key):
  83. return getattr(self, key)