diff --git a/src/app.py b/src/app.py index 1d7d1af..0a3a6b8 100644 --- a/src/app.py +++ b/src/app.py @@ -1,5 +1,4 @@ -from flask import Flask, flash, make_response, render_template, redirect, \ - abort, url_for, request +from flask import Flask, flash, render_template, redirect, abort, url_for from flask_ckeditor import CKEditor from flask_login import current_user, login_user, LoginManager, logout_user, \ login_required @@ -8,7 +7,7 @@ import os import config import content as con_gen -from database import Database, User +from database import Database from forms import LoginForm, RegisterForm, WriteForm @@ -109,11 +108,9 @@ def register(): if form.validate_on_submit(): user = db.get_user_by_name(form.username.data) if user is None: - user = User(form.username.data) - user.set_password(form.password.data) - ident = db.insert_user(user) + ident = db.insert_user(form.username.data, form.password.data) if ident is not None: - user.set_id(ident) + user = db.get_user_by_id(ident) login_user(user) return redirect(url_for("index")) flash("An error occured during registration.") @@ -139,7 +136,7 @@ def write_entry(): def delete_entry(ident): if not current_user.is_authenticated: return redirect(url_for("index")) - if current_user.id == db.get_entry_by_id(ident)[5]: + if current_user.id == db.get_entry_by_id(ident).user_id: db.delete_entry(ident) return redirect(url_for("index")) diff --git a/src/content.py b/src/content.py index 42cc593..467e34e 100644 --- a/src/content.py +++ b/src/content.py @@ -33,16 +33,16 @@ def gen_arch_string(): entries = db.get_entries() if entries is None: return "" - entries.sort(key=lambda y: y[1]) + entries.sort(key=lambda y: db.get_item_by_id(y.item_id).name) entries.reverse() - entries.sort(key=lambda y: y[2]) + entries.sort(key=lambda y: y.date) entries.reverse() for entry in entries: - ident = entry[0] - title = db.get_item_by_id(entry[1]).name - year = entry[2] - rating = entry[4] - username = db.get_user_by_id(entry[5]).name + ident = entry.id + title = db.get_item_by_id(entry.item_id).name + year = entry.date + rating = entry.rating + username = db.get_user_by_id(entry.user_id).name if year != last_year: if last_year != "": content_string += "\n" @@ -70,16 +70,16 @@ def gen_user_string(name): entries = db.get_entries_by_user(name) if entries is None: return "" - entries.sort(key=lambda y: y[1]) + entries.sort(key=lambda y: db.get_item_by_id(y.item_id).name) entries.reverse() - entries.sort(key=lambda y: y[2]) + entries.sort(key=lambda y: y.date) entries.reverse() for entry in entries: - ident = entry[0] - title = db.get_item_by_id(entry[1]).name - year = entry[2] - rating = entry[4] - username = db.get_user_by_id(entry[5]).name + ident = entry.id + title = db.get_item_by_id(entry.item_id).name + year = entry.date + rating = entry.rating + username = db.get_user_by_id(entry.user_id).name if year != last_year: if last_year != "": content_string += "\n" @@ -109,13 +109,13 @@ def gen_index_string(): return "" entries.reverse() for entry in entries: - ident = entry[0] - title = db.get_item_by_id(entry[1]).name - year = entry[2] - text = entry[3] - rating = entry[4] - username = db.get_user_by_id(entry[5]).name - reviewed = entry[6] + ident = entry.id + title = db.get_item_by_id(entry.item_id).name + year = entry.date + text = entry.text + rating = entry.rating + username = db.get_user_by_id(entry.user_id).name + reviewed = entry.reviewed content_string += "
\n" content_string += "

" + title + \ @@ -178,13 +178,13 @@ def get_rss_string(): return "" entries.reverse() for entry in entries: - ident = entry[0] - title = db.get_item_by_id(entry[1]).name - year = entry[2] - text = entry[3] - rating = entry[4] - username = db.get_user_by_id(entry[5]).name - reviewed = entry[6] + ident = entry.id + title = db.get_item_by_id(entry.item_id).name + year = entry.date + text = entry.text + rating = entry.rating + username = db.get_user_by_id(entry.user_id).name + reviewed = entry.reviewed content_string += "\n" content_string += "" + title + "(" + year + ") " + \ rating_to_star(rating) + " by " + username + "\n" diff --git a/src/database.py b/src/database.py index 7e79bba..0326398 100644 --- a/src/database.py +++ b/src/database.py @@ -15,17 +15,18 @@ class User(): self.pass_hash = pass_hash def set_password(self, password): - self.pass_hash = password + self.pass_hash = generate_password_hash(password) def set_id(self, ident): self.id = ident def check_password(self, password): - return self.pass_hash == password + return check_password_hash(self.pass_hash, password) def get_id(self): return self.id + class Item(): def __init__(self, name): @@ -35,6 +36,7 @@ class Item(): def set_id(self, ident): self.id = ident + class Entry(): def __init__(self, item_id, date, text, rating, user_id, reviewed): @@ -90,14 +92,15 @@ class Database: crs.execute(query) db.commit() - def insert_user(self, user): - if self.get_user_by_name(user.name) is None and user.pass_hash is not None: + def insert_user(self, username, password): + pass_hash = generate_password_hash(password) + if self.get_user_by_name(username) is None and pass_hash is not None: db = self.connect() crs = db.cursor() query = "INSERT INTO " + self.USER_TABLE_FILE + \ "(`name`,`password`)" + \ "VALUES (?, ?) ON CONFLICT DO NOTHING" - crs.execute(query, (user.name, user.pass_hash)) + crs.execute(query, (username, pass_hash)) db.commit() return crs.lastrowid return None @@ -133,7 +136,10 @@ class Database: crs = db.cursor() query = "SELECT * FROM " + self.ENTRY_TABLE_FILE crs.execute(query) - return crs.fetchall() + res = [] + for item in crs.fetchall(): + res.append(self.db_to_entry(*item)) + return res def get_entry_by_id(self, ident): db = self.connect() @@ -153,7 +159,10 @@ class Database: " WHERE user_id = (SELECT id FROM " + self.USER_TABLE_FILE + \ " WHERE name = ?)" crs.execute(query, (name, )) - return crs.fetchall() + res = [] + for item in crs.fetchall(): + res.append(self.db_to_entry(*item)) + return res def get_item_by_id(self, ident): db = self.connect()