Compare commits

...

3 Commits

Author SHA1 Message Date
383d852545 Merge remote-tracking branch 'origin/main' 2026-02-27 15:58:31 +01:00
3ba8903de9 Finished login system with refresh-tokens 2026-02-27 15:58:05 +01:00
1eb179dac1 Optimized login system 2026-02-27 14:33:08 +01:00
6 changed files with 262 additions and 27 deletions

View File

@@ -2,50 +2,186 @@ package handlers
import ( import (
"encoding/json" "encoding/json"
"log"
"net/http" "net/http"
"os"
"shap-planner-backend/auth" "shap-planner-backend/auth"
"shap-planner-backend/models" "shap-planner-backend/models"
"shap-planner-backend/storage" "shap-planner-backend/storage"
"shap-planner-backend/utils" "shap-planner-backend/utils"
"time"
) )
func Register(w http.ResponseWriter, r *http.Request) { func Register(w http.ResponseWriter, r *http.Request) {
var user models.User var user models.User
_ = json.NewDecoder(r.Body).Decode(&user) if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
hashed, _ := auth.HashPassword(user.Password) http.Error(w, "Invalid request body", http.StatusBadRequest)
user.Password = hashed
user.ID = utils.GenerateUUID()
err := storage.AddUser(user)
if err != nil {
http.Error(w, "User exists", http.StatusBadRequest)
return return
} }
w.WriteHeader(http.StatusCreated)
if user.Username == "" || user.Password == "" {
http.Error(w, "username and password required", http.StatusBadRequest)
return
} }
hashed, err := auth.HashPassword(user.Password)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
user.Password = hashed
user.ID = utils.GenerateUUID()
user.Role = "user"
if err := storage.AddUser(user); err != nil {
http.Error(w, "user already exists", http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusCreated)
}
func Login(w http.ResponseWriter, r *http.Request) { func Login(w http.ResponseWriter, r *http.Request) {
var creds struct { var creds struct {
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
} }
_ = json.NewDecoder(r.Body).Decode(&creds) if err := json.NewDecoder(r.Body).Decode(&creds); err != nil {
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
user, err := storage.GetUserByUsername(creds.Username) user, err := storage.GetUserByUsername(creds.Username)
if err != nil { if err != nil {
http.Error(w, "User not found", http.StatusUnauthorized) println("user " + creds.Username + " not found")
http.Error(w, "Invalid credentials", http.StatusUnauthorized)
return return
} }
if !auth.CheckPasswordHash(creds.Password, user.Password) { if !auth.CheckPasswordHash(creds.Password, user.Password) {
http.Error(w, "Wrong password", http.StatusUnauthorized) println("invalid password")
http.Error(w, "Invalid credentials", http.StatusUnauthorized)
return return
} }
// TODO: JWT oder Session-Token erzeugen secret := []byte(os.Getenv("SHAP_JWT_SECRET"))
w.WriteHeader(http.StatusOK) if len(secret) == 0 {
err = json.NewEncoder(w).Encode(user) http.Error(w, "Server misconfiguration", http.StatusInternalServerError)
return
}
accessToken, err := auth.GenerateJWT(user.ID, user.Role, secret)
if err != nil { if err != nil {
http.Error(w, "Could not generate token", http.StatusInternalServerError)
return
}
refreshTokenPlain, err := utils.GenerateRefreshToken()
if err != nil {
http.Error(w, "could not generate refresh token", http.StatusInternalServerError)
return
}
refreshHash := utils.HashToken(refreshTokenPlain)
refreshID := utils.GenerateUUID()
refreshExpires := time.Now().Add(7 * 24 * time.Hour).Unix() // expiry: 7 days
deviceInfo := r.Header.Get("User-Agent")
if err := storage.AddRefreshToken(models.RefreshToken{
ID: refreshID,
UserID: user.ID,
Token: refreshHash,
ExpiresAt: refreshExpires,
DeviceInfo: deviceInfo,
CreatedAt: time.Now().Unix(),
Revoked: false,
}); err != nil {
http.Error(w, "could not save refresh token", http.StatusInternalServerError)
return
}
// Return access + refresh token (refresh in plain for client to store securely)
resp := map[string]interface{}{
"access_token": accessToken,
"refresh_token": refreshTokenPlain,
"user": map[string]interface{}{
"id": user.ID,
"username": user.Username,
"role": user.Role,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
func Logout(w http.ResponseWriter, r *http.Request) {
claims := r.Context().Value(auth.UserContextKey).(*auth.Claims)
storage.RevokeAllRefreshTokensForUser(claims.UserID)
w.WriteHeader(204)
}
func TestHandler(w http.ResponseWriter, r *http.Request) {
claimsRaw := r.Context().Value(auth.UserContextKey)
if claimsRaw == nil {
http.Error(w, "No claims in context", http.StatusUnauthorized)
return
}
claims, ok := claimsRaw.(*auth.Claims)
if !ok {
http.Error(w, "Invalid claims", http.StatusUnauthorized)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"user_id": claims.UserID,
"role": claims.Role,
"msg": "access granted to protected endpoint",
})
}
func RefreshToken(w http.ResponseWriter, r *http.Request) {
var req struct {
RefreshToken string `json:"refresh_token"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
hashed := utils.HashToken(req.RefreshToken)
tokenRow, err := storage.GetRefreshToken(hashed)
if err != nil || tokenRow.Revoked || tokenRow.ExpiresAt < time.Now().Unix() {
http.Error(w, "Invalid refresh token", http.StatusUnauthorized)
return
}
if err := storage.RevokeRefreshToken(tokenRow.ID); err != nil {
log.Println(err)
}
newToken, _ := utils.GenerateRefreshToken()
newHash := utils.HashToken(newToken)
newExpires := time.Now().Add(7 * 24 * time.Hour).Unix() //7 days
newID := utils.GenerateUUID()
deviceInfo := r.Header.Get("User-Agent")
if err = storage.AddRefreshToken(models.RefreshToken{
ID: newID,
UserID: tokenRow.UserID,
Token: newHash,
ExpiresAt: newExpires,
CreatedAt: time.Now().Unix(),
Revoked: false,
DeviceInfo: deviceInfo,
}); err != nil {
return
}
accessToken, _ := auth.GenerateJWT(tokenRow.UserID, "", []byte(os.Getenv("SHAP_JWT_SECRET")))
if err = json.NewEncoder(w).Encode(map[string]string{
"access_token": accessToken,
"refresh_token": newToken,
}); err != nil {
return return
} }
} }

View File

@@ -4,6 +4,7 @@ type User struct {
ID string `json:"id"` ID string `json:"id"`
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
Role string `json:"role"`
} }
type Expense struct { type Expense struct {

11
models/loginmodels.go Normal file
View File

@@ -0,0 +1,11 @@
package models
type RefreshToken struct {
ID string `json:id`
UserID string `json:userid`
Token string `json:token`
ExpiresAt int64 `json:expiresat`
CreatedAt int64 `json:createdat`
Revoked bool `json:revoked`
DeviceInfo string `json:deviceinfo`
}

View File

@@ -45,13 +45,18 @@ func InitServer() *Server {
func (server *Server) Run() { func (server *Server) Run() {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/login", handlers.Login) // Public
mux.HandleFunc("/api/login", handlers.Login)
mux.HandleFunc("/api/register", handlers.Register)
mux.HandleFunc("/api/refresh", handlers.RefreshToken)
mux.HandleFunc("/api/logout", handlers.Logout)
protected := auth.AuthMiddleware(server.JWTSecret)(http.HandlerFunc(handlers.GetExpenses)) // Login required
mux.Handle("/expenses", protected) mux.Handle("/api/expenses", auth.AuthMiddleware(server.JWTSecret)(http.HandlerFunc(handlers.GetExpenses)))
mux.Handle("/api/ping", auth.AuthMiddleware(server.JWTSecret)(http.HandlerFunc(handlers.TestHandler)))
adminOnly := auth.AuthMiddleware(server.JWTSecret)(auth.RequireRole("admin")(http.HandlerFunc(handlers.AdminPanel))) // Admin-only
mux.Handle("/admin", adminOnly) mux.Handle("/api/admin", auth.AuthMiddleware(server.JWTSecret)(auth.RequireRole("admin")(http.HandlerFunc(handlers.AdminPanel))))
log.Printf("Listening on port %s", server.Port) log.Printf("Listening on port %s", server.Port)
log.Fatal(http.ListenAndServe(":"+server.Port, mux)) log.Fatal(http.ListenAndServe(":"+server.Port, mux))

View File

@@ -2,10 +2,13 @@ package storage
import ( import (
"database/sql" "database/sql"
_ "github.com/glebarez/go-sqlite" "errors"
"shap-planner-backend/models" "shap-planner-backend/models"
_ "github.com/glebarez/go-sqlite"
) )
var ErrNotFound = sql.ErrNoRows
var DB *sql.DB var DB *sql.DB
func InitDB(filepath string) error { func InitDB(filepath string) error {
@@ -18,13 +21,33 @@ func InitDB(filepath string) error {
//Create Users-Table //Create Users-Table
_, err = DB.Exec(`CREATE TABLE IF NOT EXISTS users( _, err = DB.Exec(`CREATE TABLE IF NOT EXISTS users(
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
username TEXT UNIQUE, username TEXT UNIQUE NOT NULL,
password TEXT password TEXT NOT NULL,
role TEXT NOT NULL
);`) );`)
if err != nil { if err != nil {
return err return err
} }
//Create refresh token-table
_, err = DB.Exec(`CREATE TABLE IF NOT EXISTS refresh_tokens(
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
token_hash TEXT NOT NULL,
expires_at INTEGER NOT NULL,
created_at INTEGER NOT NULL,
revoked INTEGER NOT NULL DEFAULT 0,
device_info TEXT,
FOREIGN KEY(user_id) REFERENCES users(id)
)`)
if err != nil {
return err
}
_, err = DB.Exec(`CREATE INDEX IF NOT EXISTS idx_refresh_token_hash ON refresh_tokens(token_hash)`)
if err != nil {
return err
}
//Create Expenses-Table //Create Expenses-Table
_, err = DB.Exec(`CREATE TABLE IF NOT EXISTS expenses( _, err = DB.Exec(`CREATE TABLE IF NOT EXISTS expenses(
id TEXT PRIMARY KEY id TEXT PRIMARY KEY
@@ -33,21 +56,67 @@ func InitDB(filepath string) error {
return err return err
} }
// Users
func AddUser(user models.User) error { func AddUser(user models.User) error {
_, err := DB.Exec("INSERT INTO users(id, username, password) VALUES (?, ?, ?)", user.ID, user.Username, user.Password) _, err := DB.Exec("INSERT INTO users(id, username, password, role) VALUES (?, ?, ?, ?)", user.ID, user.Username, user.Password, user.Role)
return err return err
} }
func GetUserByUsername(username string) (models.User, error) { func GetUserByUsername(username string) (models.User, error) {
row := DB.QueryRow("SELECT * FROM users WHERE username = ?", username) row := DB.QueryRow("SELECT * FROM users WHERE username = ?", username)
var user models.User var user models.User
err := row.Scan(&user.ID, &user.Username, &user.Password) err := row.Scan(&user.ID, &user.Username, &user.Password, &user.Role)
return user, err return user, err
} }
func GetUserById(id string) (models.User, error) { func GetUserById(id string) (models.User, error) {
row := DB.QueryRow("SELECT * FROM users WHERE id = ?", id) row := DB.QueryRow("SELECT * FROM users WHERE id = ?", id)
var user models.User var user models.User
err := row.Scan(&user.ID, &user.Username, &user.Password) err := row.Scan(&user.ID, &user.Username, &user.Password)
return user, err return user, err
} }
// Refresh Tokens
func AddRefreshToken(token models.RefreshToken) error {
_, err := DB.Exec("INSERT INTO refresh_tokens(id, user_id, token_hash, expires_at, created_at, revoked, device_info) VALUES (?, ?, ?, ?, ?, ?, ?)",
token.ID, token.UserID, token.Token, token.ExpiresAt, token.CreatedAt, token.Revoked, token.DeviceInfo)
return err
}
func GetRefreshToken(token string) (models.RefreshToken, error) {
row := DB.QueryRow("SELECT * FROM refresh_tokens WHERE token_hash = ?", token)
var refresh_token models.RefreshToken
err := row.Scan(&refresh_token.ID, &refresh_token.UserID, &refresh_token.Token, &refresh_token.ExpiresAt, &refresh_token.CreatedAt, &refresh_token.Revoked, &refresh_token.DeviceInfo)
return refresh_token, err
}
func RevokeRefreshToken(tokenID string) error {
if DB == nil {
return errors.New("db not initialized")
}
res, err := DB.Exec(`
UPDATE refresh_tokens
SET revoked = 1
WHERE id = ?
`, tokenID)
if err != nil {
return err
}
n, err := res.RowsAffected()
if err != nil {
return err
}
if n == 0 {
return ErrNotFound
}
return nil
}
func RevokeAllRefreshTokensForUser(userID string) error {
if DB == nil {
return errors.New("db not initialized")
}
_, err := DB.Exec(`
UPDATE refresh_tokens
SET revoked = 1
WHERE user_id = ?
`, userID)
return err
}

View File

@@ -2,7 +2,9 @@ package utils
import ( import (
"crypto/rand" "crypto/rand"
"crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/hex"
"github.com/google/uuid" "github.com/google/uuid"
) )
@@ -18,3 +20,14 @@ func GenerateSecret() string {
} }
return base64.StdEncoding.EncodeToString(b) return base64.StdEncoding.EncodeToString(b)
} }
func GenerateRefreshToken() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
func HashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}