Browse Source

first commit

DESKTOP-MK04A0R\chuck 3 years ago
commit
6ee18e6c3c

+ 5 - 0
.gitignore

@@ -0,0 +1,5 @@
+.vscode
+.idea
+*.pyc
+logs/
+test.py

+ 0 - 0
READEME.md


+ 20 - 0
app/__init__.py

@@ -0,0 +1,20 @@
+from app.app import Flask
+
+#注册蓝图
+def register_blueprints(app):
+    from app.api.v1 import create_blueprint_v1
+    app.register_blueprint(create_blueprint_v1(), url_prefix='/api/v1')
+    
+# 初始化数据库
+def register_plugin(app):
+    from app.models.base import db
+    db.init_app(app)
+    with app.app_context():
+        db.create_all()
+
+def create_app():
+    app = Flask(__name__)
+    app.config.from_object('app.config.setting_dev')
+    register_blueprints(app)
+    register_plugin(app)
+    return app

+ 8 - 0
app/api/v1/__init__.py

@@ -0,0 +1,8 @@
+from flask import Blueprint
+from app.api.v1 import token, user
+
+def create_blueprint_v1():
+    bp_v1 = Blueprint('v1', __name__)
+    token.api.register(bp_v1)
+    user.api.register(bp_v1)
+    return bp_v1

+ 50 - 0
app/api/v1/token.py

@@ -0,0 +1,50 @@
+from flask import current_app
+from app.libs.error_code import AuthFailed, Success
+from app.libs.redprint import Redprint
+from app.models.user import User
+from app.validators.forms import ClientForm, TokenForm
+from authlib.jose import jwt, JoseError, errors
+from datetime import datetime
+
+api = Redprint('token')
+
+@api.route('', methods=['POST'])
+def get_token():
+    form = ClientForm().validate_for_api()
+    userinfo = User.verify(form.account.data, form.secret.data)
+    # Token
+    token = generate_token({'uid':userinfo['uid']})
+    t = {
+        'token': token.decode('utf8')
+    }
+    return Success(result=t)
+
+
+@api.route('/secret', methods=['POST'])
+def get_token_info():
+    """获取令牌信息"""
+    form = TokenForm().validate_for_api() 
+    key = current_app.config['SECRET_KEY']
+    try:
+        data = jwt.decode(s=form.token.data, key=key)
+        data.validate_exp(now=datetime.now().timestamp(), leeway=current_app.config['TOKEN_EXPIRATION'])
+    except errors.ExpiredTokenError:
+        return AuthFailed(message='token过期')
+    except JoseError:
+        return AuthFailed()
+
+    return Success(result=data)
+
+
+def generate_token(data):
+    """生成用于邮箱验证的JWT(json web token)"""
+    # 签名算法
+    header = {'alg': 'HS256'}
+    # 用于签名的**
+    key = current_app.config['SECRET_KEY']
+    # palyload
+    expire = datetime.now().timestamp() + current_app.config['TOKEN_EXPIRATION']
+    payload = {'exp': expire}
+    payload.update(data)
+    # 待签名的数据负载
+    return jwt.encode(header, payload, key)

+ 14 - 0
app/api/v1/user.py

@@ -0,0 +1,14 @@
+from app.libs.error_code import Success
+from app.libs.redprint import Redprint
+from app.models.user import User
+from app.validators.forms import UserEmailForm
+
+api = Redprint('user')
+
+@api.route('/register', methods=['POST'])
+def register():
+    form = UserEmailForm().validate_for_api()
+    User.register_by_email(form.nickname.data, form.account.data, form.secret.data)
+    return Success()
+
+    

+ 17 - 0
app/app.py

@@ -0,0 +1,17 @@
+from datetime import date
+from flask import Flask as _Flask
+from flask.json import JSONEncoder as _JSONEncoder
+
+# 重写default方法
+from app.libs.error_code import ServerError
+class JSONEncoder(_JSONEncoder):
+    def default(self, o):
+        if hasattr(o, 'key') and hasattr(o, '__getitem__'):
+            return dict(o)
+        if isinstance(o, date):
+            return o.strftime('%Y-%m-%d')
+        raise ServerError()
+    
+# 生效
+class Flask(_Flask):
+    json_encoder = JSONEncoder

+ 5 - 0
app/config/setting_dev.py

@@ -0,0 +1,5 @@
+SECRET_KEY = 'chuck2022'
+TOKEN_EXPIRATION = 30 * 24 * 3600
+SQLALCHEMY_DATABASE_URI = 'mysql+cymysql://root:mysql57-2020!d@39.100.75.63:33060/ginger'
+DEBUG = True
+SQLALCHEMY_TRACK_MODIFICATIONS = True

+ 3 - 0
app/libs/enums.py

@@ -0,0 +1,3 @@
+from enum import Enum
+class ClientTypeEnum(Enum):
+    TEST = 0

+ 25 - 0
app/libs/error.py

@@ -0,0 +1,25 @@
+from flask import json, Response
+from werkzeug.exceptions import HTTPException
+
+class APIException(HTTPException):
+    message = '出错了'
+    code = 999
+    result = None
+
+    def __init__(self, code=None, message=None, result=None):
+        if code:
+            self.code = code
+        if message:
+            self.message = message
+        if result:
+            self.result = result
+        super(APIException, self).__init__(response=self.__make_response())
+    
+    def __make_response(self):
+        r = {
+            'result': self.result,
+            'message': self.message,
+            'code': self.code
+        }
+        response = Response(json.dumps(r), mimetype='application/json')
+        return response

+ 26 - 0
app/libs/error_code.py

@@ -0,0 +1,26 @@
+from app.libs.error import APIException
+
+class ServerError(APIException):
+    code = 500
+    message = '服务端错误'
+
+class AuthFailed(APIException):
+    code = 401
+    message = '认证失败'
+
+class Forbidden(APIException):
+    code = 403
+    message = '禁止操作'
+
+class NotFound(APIException):
+    code = 404
+    msg = '未找到'
+
+class ParameterException(APIException):
+    code = 400
+    msg = '参数错误'
+    error_code = 1000
+
+class Success(APIException):
+    message = 'ok'
+    code = 0

+ 18 - 0
app/libs/redprint.py

@@ -0,0 +1,18 @@
+# 红图
+class Redprint:
+    def __init__(self, name):
+        self.name = name
+        self.mound = []
+    
+    def route(self, rule, **options):
+        def decorator(f):
+            self.mound.append((f, rule, options))
+            return f
+        return decorator
+    
+    def register(self, bp, url_prefix=None):
+        if url_prefix is None:
+            url_prefix = '/' + self.name
+        for f, rule, options in self.mound:
+            endpoint = self.name + '+' + options.pop("endpoint", f.__name__)
+            bp.add_url_rule(url_prefix + rule, endpoint, f, **options)

+ 7 - 0
app/libs/scope.py

@@ -0,0 +1,7 @@
+def is_in_scope(endpoint):
+    # 通过数据库查询出该用户的所有权限
+    permissions = ['']
+    if endpoint in permissions:
+        return True
+    else:
+        return False

+ 34 - 0
app/libs/token_auth.py

@@ -0,0 +1,34 @@
+from collections import namedtuple
+
+from flask import current_app, g, request
+from flask_httpauth import HTTPBasicAuth
+from itsdangerous import Serializer, BadSignature, SignatureExpired
+
+from app.libs.error_code import AuthFailed, Forbidden
+from app.libs.scope import is_in_scope
+
+auth = HTTPBasicAuth(scheme='JWT')
+User = namedtuple('User', ['uid'])
+
+@auth.verify_password
+def verify_password(token, password):
+    user_info = verify_auth_token(token)
+    if not user_info:
+        return False
+    else:
+        g.user = user_info
+        return True
+
+def verify_auth_token(token):
+    s = Serializer(current_app.config['SECRET_KEY'])
+    try:
+        data = s.loads(token)
+    except SignatureExpired:
+        raise AuthFailed(message='token无效', code=1003)
+    except BadSignature:
+        raise AuthFailed(message='token过期', code=1002)  
+    uid = data['uid']
+    allow = is_in_scope(request.endpoint)
+    if not allow:
+        raise Forbidden()
+    return User(uid)

+ 110 - 0
app/models/base.py

@@ -0,0 +1,110 @@
+from app.libs.error_code import NotFound
+from datetime import datetime
+
+from flask_sqlalchemy import SQLAlchemy as _SQLAlchemy, BaseQuery
+from sqlalchemy import inspect, Column, Integer, SmallInteger, orm
+from contextlib import contextmanager
+
+
+class SQLAlchemy(_SQLAlchemy):
+    @contextmanager
+    def auto_commit(self):
+        try:
+            yield
+            self.session.commit()
+        except Exception as e:
+            db.session.rollback()
+            raise e
+
+
+class Query(BaseQuery):
+    def filter_by(self, **kwargs):
+        if 'status' not in kwargs.keys():
+            kwargs['status'] = 1
+        return super(Query, self).filter_by(**kwargs)
+
+    def get_or_404(self, ident):
+        rv = self.get(ident)
+        if not rv:
+            raise NotFound()
+        return rv
+
+    def first_or_404(self):
+        rv = self.first()
+        if not rv:
+            raise NotFound()
+        return rv
+
+
+db = SQLAlchemy(query_class=Query)
+
+
+class Base(db.Model):
+    __abstract__ = True
+    create_time = Column(Integer)
+    status = Column(SmallInteger, default=1)
+
+    def __init__(self):
+        self.create_time = int(datetime.now().timestamp())
+
+    def __getitem__(self, item):
+        return getattr(self, item)
+
+    @property
+    def create_datetime(self):
+        if self.create_time:
+            return datetime.fromtimestamp(self.create_time)
+        else:
+            return None
+
+    def set_attrs(self, attrs_dict):
+        for key, value in attrs_dict.items():
+            if hasattr(self, key) and key != 'id':
+                setattr(self, key, value)
+
+    def delete(self):
+        self.status = 0
+
+    def keys(self):
+        return self.fields
+
+    def hide(self, *keys):
+        for key in keys:
+            self.fields.remove(key)
+        return self
+
+    def append(self, *keys):
+        for key in keys:
+            self.fields.append(key)
+        return self
+
+
+class MixinJSONSerializer:
+    @orm.reconstructor
+    def init_on_load(self):
+        self._fields = []
+        # self._include = []
+        self._exclude = []
+
+        self._set_fields()
+        self.__prune_fields()
+
+    def _set_fields(self):
+        pass
+
+    def __prune_fields(self):
+        columns = inspect(self.__class__).columns
+        if not self._fields:
+            all_columns = set(columns.keys())
+            self._fields = list(all_columns - set(self._exclude))
+
+    def hide(self, *args):
+        for key in args:
+            self._fields.remove(key)
+        return self
+
+    def keys(self):
+        return self._fields
+
+    def __getitem__(self, key):
+        return getattr(self, key)

+ 47 - 0
app/models/user.py

@@ -0,0 +1,47 @@
+from sqlalchemy import Column, Integer, String
+from werkzeug.security import generate_password_hash, check_password_hash
+
+from app.libs.error_code import AuthFailed
+from app.models.base import Base, db
+
+class User(Base):
+    id = Column(Integer, primary_key=True)
+    email = Column(String(24), unique=True, nullable=False)
+    nickname = Column(String(24), unique=True)
+    _password = Column('password', String(225))
+
+    def keys(self):
+        return ['id', 'email', 'nickname']
+
+    @property
+    def password(self):
+        return self._password
+
+    @password.setter
+    def password(self, raw):
+        self._password = generate_password_hash(raw)
+
+    @staticmethod
+    def register_by_email(nickname, account, secret):
+        with db.auto_commit():
+            user = User()
+            user.nickname = nickname
+            user.email = account
+            user.password = secret
+            db.session.add(user)
+
+    @staticmethod
+    def verify(email, password):
+        user = User.query.filter_by(email=email).first_or_404()
+        if not user.check_password(password):
+            raise AuthFailed()
+        return {'uid': user.id}
+
+    def check_password(self, raw):
+        if not self._password:
+            return False
+        return check_password_hash(self._password, raw)
+
+    # def _set_fields(self):
+    #     # self._exclude = ['_password']
+    #     self._fields = ['_password', 'nickname']

+ 18 - 0
app/validators/base.py

@@ -0,0 +1,18 @@
+from flask import request
+from wtforms import Form
+
+from app.libs.error_code import ParameterException
+
+
+class BaseForm(Form):
+    def __init__(self, *args, **kwargs):
+        data = request.get_json(silent=True)
+        args = request.args.to_dict()
+        kwargs['csrf_enabled'] = False
+        super(BaseForm, self).__init__(data=data, **args, **kwargs)
+
+    def validate_for_api(self):
+        valid = super(BaseForm, self).validate()
+        if not valid:
+            raise ParameterException(message=self.errors)
+        return self

+ 30 - 0
app/validators/forms.py

@@ -0,0 +1,30 @@
+from wtforms import StringField
+from wtforms.validators import DataRequired, length, Email, Regexp
+from wtforms import ValidationError
+
+from app.models.user import User
+from app.validators.base import BaseForm as Form
+
+class ClientForm(Form):
+    account = StringField(validators=[DataRequired(message='不允许为空'), length(
+        min=5, max=32
+    )])
+    secret = StringField()
+
+class UserEmailForm(ClientForm):
+    account = StringField(validators=[
+        Email(message='invalidate email')
+    ])
+    secret = StringField(validators=[
+        DataRequired(),
+        Regexp(r'^[A-Za-z0-9_*&$#@]{6,22}$')
+    ])
+    nickname = StringField(validators=[DataRequired(), length(min=2, max=22)])
+
+    def validate_account(self, value):
+        if User.query.filter_by(email=value.data).first():
+            raise ValidationError()
+
+
+class TokenForm(Form):
+    token = StringField(validators=[DataRequired()])

+ 30 - 0
demo.py

@@ -0,0 +1,30 @@
+from venv import create
+from flask.json import jsonify
+from app import create_app
+from app.libs.error import APIException, HTTPException
+from app.libs.error_code import ServerError
+from wsgiref.simple_server import make_server
+
+app = create_app()
+
+@app.errorhandler(Exception)
+def framework_error(e):
+    if isinstance(e, APIException):
+        return e
+    if isinstance(e, HTTPException):
+        message = e.description
+        code = 1007
+        return APIException(code, message)
+    else:
+        if not app.debug:
+            return ServerError()
+        else:
+            raise e
+        
+if __name__ == '__main__':
+    # 启动方式一 适用于本地开发
+    app.run(debug=app.config['DEBUG'], host='0.0.0.0', port=5000)
+
+    # 启动方式二 适用于生产部署
+    # server = make_server('0.0.0.0', 5000, app)
+    # server.serve_forever()

+ 7 - 0
requirements.txt

@@ -0,0 +1,7 @@
+Flask==2.1.2
+flask-httpauth==4.6.0
+flask-sqlalchemy==2.5.1
+flask-wtf==1.0.1
+email-validator==1.2.1
+cymysql==0.9.18
+authlib==1.0.1