이전에 PyJWT 의 사용법을 배웠으니 이제부터 인증 엔드포인트와 인증 decorator를 구현해 보도록 하겠다.
import jwt
import bcrypt
from flask import Flask, request, jsonify, current_app, Response, g
from flask.json import JSONEncoder
from sqlalchemy import create_engine, text
from datetime import datetime, timedelta
from functools import wraps
class CustomJSONEncoder(JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
return list(obj)
return JSONEncoder.default(self, obj)
def get_user(user_id):
user = current_app.database.execute(text("""
SELECT
id,
name,
email,
profile
FROM users
WHERE id = :user_id
"""), {
'user_id' : user_id
}).fetchone()
return {
'id' : user['id'],
'name' : user['name'],
'email' : user['email'],
'profile' : user['profile']
} if user else None
def insert_user(user):
return current_app.database.excute(text("""
INSERT INTO users (
name,
email,
profile,
hashed_password
) VALUES (
:name,
:email,
:profile,
:password
)
"""), user).lastrowid # lastrowid를 통해 새로 생성된 사용자의 id를 읽어들인다.
def insert_tweet(user_tweet):
return current_app.database.excute(text("""
INSERT INTO tweets (
user_id,
tweet
) VALUES (
:id,
:tweet
)
"""), user_tweet).rowcount
def insert_follow(user_follow):
return current_app.database.execute(text("""
INSERT INTO users_follow_list (
user_id,
follow_user_id
) VALUES (
:id,
:follow
)
"""), user_follow).rowcount
def insert_unfollow(user_unfollow):
return current_app.database.execute(text("""
DELETE FROM users_follow_list
WHERE user_id = :id
AND follow_user_id = :unfollow
"""), user_unfollow).rowcount
def get_timeline(user_id):
timeline = current_app.database.execute(text("""
SELECT
t.user_id,
t.tweet
FROM tweets t
LEFT JOIN users_follow_list ufl ON ufl.user_id = :user_id
WHERE t.user_id = :user_id
"""), {
'user_id' : user_id
}).fetchall()
return [{
'user_id' : tweet['user_id'],
'tweet' : tweet['tweet']
} for tweet in timeline]
def get_user_id_and_password(email):
row = current_app.database.execute(text("""
SELECT
id,
hashed_password
FROM users
WHERE email = :email
"""),{'email' : email}).fetchone()
return {
'id' : row['id'],
'hashed_password' : row['hashed_password']
} if row else None
##########################################
# Decorators
##########################################
def login_required(f):
@wraps(f)
def decorated_function(*args, **kwargs):
access_token = request.headers.get('Authorization')
if access_token is not None:
try:
payload = jwt.decode(access_token, current_app.config['JWT_SECRET_KEY'], 'HS256')
except jwt.InvalidTokenError:
payload = None
if payload is None: return Response(status=401)
user_id = payload['user_id']
g.user_id = user_id
g.user = get_user(user_id) if user_id else None
else:
return Response(status = 401)
return f(*args, **kwargs)
return decorated_function
def create_app(test_config = None):
app = Flask(__name__)
app.json_encoder = CustomJSONEncoder
if test_config is None:
app.config.from_pyfile("config.py")
else:
app.config.update(test_config)
database = create_engine(app.config['DB_URL'], encoding = 'utf-8', max_overflow = 0)
app.database = database
@app.route("/ping", methods=['GET'])
def ping():
return "pong"
@app.route("/sign-up", methods=['POST'])
def sign_up():
new_user = request.json
new_user['password'] = bcrypt.hashpw(
new_user['password'].encode('UTF-8'),
bcrypt.gensalt()
)
new_user_id = insert_user(new_user)
new_user = get_user(new_user_id)
return jsonify(new_user)
@app.route("/login", methods=['POST'])
def login():
credential = request.json
email = credential['email']
password = credential['password']
user_credential = get_user_id_and_password(email)
if user_credential and bcrypt.checkpw(password.encode(UTF-8), user_credential['hashed_password'].encode('UTF-8')):
user_id = user_credential['id']
payload = {
'user_id' : user_id,
'exp' : datetime.utcnow() + timedelta(seconds = 60 * 60 * 24)
}
token = jwt.encode(payload, app.config['JWT_SECRET_KEY'], 'HS256')
return jsonify({
'access_token' : token.decode('UTF-8')
})
else:
return '', 401
@app.route("/tweet", methods=['POST'])
@login_required
def tweet():
user_tweet = request.json
user_tweet['id'] = g.user_id
tweet = user_tweet['tweet']
if len(tweet) > 300:
return '300자를 초과했습니다.', 400
insert_tweet(user_tweet)
return '', 200
@app.route("/follow", methods=['POST'])
@login_required
def follow():
payload = request.json
payload['id'] = g.user_id
insert_follow(payload)
return '', 200
@app.route("/unfollow", methods=['POST'])
@login_required
def unfollow():
payload = request.json
payload['id'] = g.user_id
insert_unfollow(payload)
return '', 200
@app.route("/timeline/<int:user_id>", methods=['GET'])
def timeline(user_id):
return jsonify({
'user_id' : user_id,
'timeline' : get_timeline(user_id)
})
return app