123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- from app.libs.error_code import NotFound
- from datetime import datetime
- from flask_sqlalchemy import SQLAlchemy as _SQLAlchemy, BaseQuery
- from sqlalchemy import inspect, Column, Integer, SmallInteger, orm
- from contextlib import contextmanager
class SQLAlchemy(_SQLAlchemy):
    """Flask-SQLAlchemy extension with a transactional ``auto_commit`` helper."""

    @contextmanager
    def auto_commit(self):
        """Commit the session when the ``with`` body finishes without error.

        Usage::

            with db.auto_commit():
                db.session.add(obj)

        If the body or the commit itself raises, this instance's session is
        rolled back and the original exception is re-raised unchanged.
        """
        try:
            yield
            self.session.commit()
        except Exception:
            # Roll back the session we were committing on. Referencing the
            # module-level ``db`` here would target the wrong session for any
            # other instance of this class.
            self.session.rollback()
            raise
class Query(BaseQuery):
    """Query class used by ``db``.

    ``filter_by`` restricts results to ``status=1`` unless the caller passes a
    ``status`` explicitly. The ``*_or_404`` helpers raise the project's
    ``NotFound`` error instead of returning None.
    """

    def filter_by(self, **kwargs):
        """Like ``BaseQuery.filter_by`` but with ``status`` defaulting to 1.

        Pass ``status=...`` explicitly to override the default (e.g. ``0`` to
        query rows that ``Base.delete`` has soft-deleted).
        """
        # setdefault leaves an explicit caller-supplied status untouched.
        kwargs.setdefault('status', 1)
        return super().filter_by(**kwargs)

    def get_or_404(self, ident):
        """Return the row with primary key ``ident`` or raise ``NotFound``.

        NOTE(review): ``get`` is a primary-key lookup and does not go through
        ``filter_by``, so the ``status`` default is NOT applied here and
        soft-deleted rows are returned. Confirm this is intended.
        """
        rv = self.get(ident)
        if rv is None:
            raise NotFound()
        return rv

    def first_or_404(self):
        """Return the first row of the query or raise ``NotFound``."""
        rv = self.first()
        if rv is None:
            raise NotFound()
        return rv
# Application-wide extension instance. ``query_class=Query`` selects the
# ``Query`` subclass for queries built through this instance (e.g. model
# ``.query``), so ``filter_by`` defaults to ``status=1``.
db = SQLAlchemy(
    query_class=Query,
)
class Base(db.Model):
    """Abstract declarative base shared by the project's models.

    Adds:

    * ``create_time``: an integer Unix timestamp stamped in ``__init__``.
    * ``status``: defaults to 1. ``delete`` sets it to 0, which gives a
      soft-delete flag.
    * A small mapping protocol (``keys`` + ``__getitem__``), so that
      ``dict(instance)`` yields the attributes named in ``self.fields``.
    """
    __abstract__ = True

    create_time = Column(Integer)
    status = Column(SmallInteger, default=1)

    def __init__(self, **kwargs):
        """Stamp ``create_time`` with the current time, then forward any
        keyword arguments to the declarative constructor.

        The timestamp is set first so an explicit ``create_time=...`` keyword
        overrides it.
        """
        self.create_time = int(datetime.now().timestamp())
        # Skipping the parent constructor made ``Model(name=...)`` raise
        # TypeError. Forwarding keeps keyword construction working, and
        # calls with no arguments behave exactly as before.
        super().__init__(**kwargs)

    def __getitem__(self, item):
        """Return attribute ``item``. Paired with ``keys`` so ``dict(self)`` works."""
        return getattr(self, item)

    @property
    def create_datetime(self):
        """``create_time`` as a naive local ``datetime``, or None when unset.

        A falsy timestamp (None or 0) is treated as unset.
        """
        if self.create_time:
            return datetime.fromtimestamp(self.create_time)
        return None

    def set_attrs(self, attrs_dict):
        """Copy items from ``attrs_dict`` onto matching existing attributes.

        Keys that are not already attributes of the instance are ignored.
        ``id`` is never overwritten (presumably the primary key — confirm in
        subclasses).
        """
        for key, value in attrs_dict.items():
            if hasattr(self, key) and key != 'id':
                setattr(self, key, value)

    def delete(self):
        """Soft-delete: set ``status`` to 0 instead of removing the row.

        The change is persisted only when the session is committed.
        """
        self.status = 0

    def keys(self):
        """Return the attribute names exposed by ``dict(self)``.

        NOTE(review): ``fields`` is not defined on ``Base``. Subclasses are
        expected to provide it.
        """
        return self.fields

    def hide(self, *keys):
        """Remove ``keys`` from the exposed fields and return ``self``.

        Raises ValueError if a key is not currently in ``fields``. In that
        case ``fields`` is left unchanged.
        """
        # Work on a copy and rebind it on the instance. If a subclass declares
        # ``fields`` as a class attribute, an in-place ``remove`` would hide
        # the key on every instance of that class.
        fields = list(self.fields)
        for key in keys:
            fields.remove(key)
        self.fields = fields
        return self

    def append(self, *keys):
        """Add ``keys`` to the exposed fields and return ``self``."""
        # Rebind a new list rather than mutating in place, for the same
        # shared-class-attribute reason as ``hide``.
        self.fields = [*self.fields, *keys]
        return self
class MixinJSONSerializer:
    """Mixin exposing mapped column values via ``keys``/``__getitem__`` so that
    ``dict(instance)`` can be used for JSON serialisation.

    Subclasses can override ``_set_fields`` to fill either list:

    * ``self._fields``: an explicit list of names to expose.
    * ``self._exclude``: column names to drop.

    When ``_fields`` is left empty, every mapped column not in ``_exclude`` is
    used.
    """

    @orm.reconstructor
    def init_on_load(self):
        """Initialise the field lists for instances loaded by SQLAlchemy.

        ``orm.reconstructor`` makes SQLAlchemy call this when it loads an
        instance from the database. It is not run by the normal constructor.
        """
        self._fields = []
        self._exclude = []
        self._set_fields()
        self.__prune_fields()

    def _set_fields(self):
        """Hook for subclasses to populate ``_fields`` or ``_exclude``. No-op here."""

    def __prune_fields(self):
        """Default ``_fields`` to all mapped columns minus ``_exclude``.

        Column declaration order is kept. A ``set`` difference would make the
        serialised key order vary between interpreter runs, because string
        hashing is randomised.
        """
        if self._fields:
            return
        excluded = set(self._exclude)
        columns = inspect(self.__class__).columns
        self._fields = [name for name in columns.keys() if name not in excluded]

    def hide(self, *args):
        """Remove the given names from ``_fields`` and return ``self``.

        Raises ValueError if a name is not currently in ``_fields``.
        """
        for key in args:
            self._fields.remove(key)
        return self

    def keys(self):
        """Return the field names exposed by ``dict(self)``."""
        return self._fields

    def __getitem__(self, key):
        """Return attribute ``key``. Paired with ``keys`` so ``dict(self)`` works."""
        return getattr(self, key)
|