Finished login system with refresh-tokens

This commit is contained in:
2026-02-27 15:58:05 +01:00
parent 1eb179dac1
commit 3ba8903de9
5 changed files with 186 additions and 26 deletions

View File

@@ -2,12 +2,14 @@ package handlers
import ( import (
"encoding/json" "encoding/json"
"log"
"net/http" "net/http"
"os" "os"
"shap-planner-backend/auth" "shap-planner-backend/auth"
"shap-planner-backend/models" "shap-planner-backend/models"
"shap-planner-backend/storage" "shap-planner-backend/storage"
"shap-planner-backend/utils" "shap-planner-backend/utils"
"time"
) )
func Register(w http.ResponseWriter, r *http.Request) { func Register(w http.ResponseWriter, r *http.Request) {
@@ -38,7 +40,6 @@ func Register(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusCreated) w.WriteHeader(http.StatusCreated)
} }
func Login(w http.ResponseWriter, r *http.Request) { func Login(w http.ResponseWriter, r *http.Request) {
var creds struct { var creds struct {
Username string `json:"username"` Username string `json:"username"`
@@ -51,11 +52,13 @@ func Login(w http.ResponseWriter, r *http.Request) {
user, err := storage.GetUserByUsername(creds.Username) user, err := storage.GetUserByUsername(creds.Username)
if err != nil { if err != nil {
println("user " + creds.Username + " not found")
http.Error(w, "Invalid credentials", http.StatusUnauthorized) http.Error(w, "Invalid credentials", http.StatusUnauthorized)
return return
} }
if !auth.CheckPasswordHash(creds.Password, user.Password) { if !auth.CheckPasswordHash(creds.Password, user.Password) {
println("invalid password")
http.Error(w, "Invalid credentials", http.StatusUnauthorized) http.Error(w, "Invalid credentials", http.StatusUnauthorized)
return return
} }
@@ -66,29 +69,55 @@ func Login(w http.ResponseWriter, r *http.Request) {
return return
} }
token, err := auth.GenerateJWT(user.ID, user.Role, secret) accessToken, err := auth.GenerateJWT(user.ID, user.Role, secret)
if err != nil { if err != nil {
http.Error(w, "Could not generate token", http.StatusInternalServerError) http.Error(w, "Could not generate token", http.StatusInternalServerError)
return return
} }
type userResp struct { refreshTokenPlain, err := utils.GenerateRefreshToken()
ID string `json:"id"` if err != nil {
Username string `json:"username"` http.Error(w, "could not generate refresh token", http.StatusInternalServerError)
Role string `json:"role"` return
}
refreshHash := utils.HashToken(refreshTokenPlain)
refreshID := utils.GenerateUUID()
refreshExpires := time.Now().Add(7 * 24 * time.Hour).Unix() // expiry: 7 days
deviceInfo := r.Header.Get("User-Agent")
if err := storage.AddRefreshToken(models.RefreshToken{
ID: refreshID,
UserID: user.ID,
Token: refreshHash,
ExpiresAt: refreshExpires,
DeviceInfo: deviceInfo,
CreatedAt: time.Now().Unix(),
Revoked: false,
}); err != nil {
http.Error(w, "could not save refresh token", http.StatusInternalServerError)
return
}
// Return access + refresh token (refresh in plain for client to store securely)
resp := map[string]interface{}{
"access_token": accessToken,
"refresh_token": refreshTokenPlain,
"user": map[string]interface{}{
"id": user.ID,
"username": user.Username,
"role": user.Role,
},
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{ json.NewEncoder(w).Encode(resp)
"token": token, }
"user": userResp{ func Logout(w http.ResponseWriter, r *http.Request) {
ID: user.ID, claims := r.Context().Value(auth.UserContextKey).(*auth.Claims)
Username: user.Username, storage.RevokeAllRefreshTokensForUser(claims.UserID)
Role: user.Role, w.WriteHeader(204)
},
})
} }
func TestHandler(w http.ResponseWriter, r *http.Request) { func TestHandler(w http.ResponseWriter, r *http.Request) {
claimsRaw := r.Context().Value(auth.UserContextKey) claimsRaw := r.Context().Value(auth.UserContextKey)
if claimsRaw == nil { if claimsRaw == nil {
@@ -109,3 +138,50 @@ func TestHandler(w http.ResponseWriter, r *http.Request) {
"msg": "access granted to protected endpoint", "msg": "access granted to protected endpoint",
}) })
} }
func RefreshToken(w http.ResponseWriter, r *http.Request) {
var req struct {
RefreshToken string `json:"refresh_token"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
hashed := utils.HashToken(req.RefreshToken)
tokenRow, err := storage.GetRefreshToken(hashed)
if err != nil || tokenRow.Revoked || tokenRow.ExpiresAt < time.Now().Unix() {
http.Error(w, "Invalid refresh token", http.StatusUnauthorized)
return
}
if err := storage.RevokeRefreshToken(tokenRow.ID); err != nil {
log.Println(err)
}
newToken, _ := utils.GenerateRefreshToken()
newHash := utils.HashToken(newToken)
newExpires := time.Now().Add(7 * 24 * time.Hour).Unix() //7 days
newID := utils.GenerateUUID()
deviceInfo := r.Header.Get("User-Agent")
if err = storage.AddRefreshToken(models.RefreshToken{
ID: newID,
UserID: tokenRow.UserID,
Token: newHash,
ExpiresAt: newExpires,
CreatedAt: time.Now().Unix(),
Revoked: false,
DeviceInfo: deviceInfo,
}); err != nil {
return
}
accessToken, _ := auth.GenerateJWT(tokenRow.UserID, "", []byte(os.Getenv("SHAP_JWT_SECRET")))
if err = json.NewEncoder(w).Encode(map[string]string{
"access_token": accessToken,
"refresh_token": newToken,
}); err != nil {
return
}
}

View File

@@ -1,10 +1,11 @@
package models package models
import "time"
type RefreshToken struct { type RefreshToken struct {
ID string `json:id` ID string `json:id`
UserID string `json:userid` UserID string `json:userid`
Token string `json:token` Token string `json:token`
ExpiresAt time.Time `json:expiresat` ExpiresAt int64 `json:expiresat`
CreatedAt int64 `json:createdat`
Revoked bool `json:revoked`
DeviceInfo string `json:deviceinfo`
} }

View File

@@ -48,6 +48,8 @@ func (server *Server) Run() {
// Public // Public
mux.HandleFunc("/api/login", handlers.Login) mux.HandleFunc("/api/login", handlers.Login)
mux.HandleFunc("/api/register", handlers.Register) mux.HandleFunc("/api/register", handlers.Register)
mux.HandleFunc("/api/refresh", handlers.RefreshToken)
mux.HandleFunc("/api/logout", handlers.Logout)
// Login required // Login required
mux.Handle("/api/expenses", auth.AuthMiddleware(server.JWTSecret)(http.HandlerFunc(handlers.GetExpenses))) mux.Handle("/api/expenses", auth.AuthMiddleware(server.JWTSecret)(http.HandlerFunc(handlers.GetExpenses)))

View File

@@ -2,11 +2,13 @@ package storage
import ( import (
"database/sql" "database/sql"
"errors"
"shap-planner-backend/models" "shap-planner-backend/models"
_ "github.com/glebarez/go-sqlite" _ "github.com/glebarez/go-sqlite"
) )
var ErrNotFound = sql.ErrNoRows
var DB *sql.DB var DB *sql.DB
func InitDB(filepath string) error { func InitDB(filepath string) error {
@@ -19,13 +21,33 @@ func InitDB(filepath string) error {
//Create Users-Table //Create Users-Table
_, err = DB.Exec(`CREATE TABLE IF NOT EXISTS users( _, err = DB.Exec(`CREATE TABLE IF NOT EXISTS users(
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
username TEXT UNIQUE, username TEXT UNIQUE NOT NULL,
password TEXT password TEXT NOT NULL,
role TEXT NOT NULL
);`) );`)
if err != nil { if err != nil {
return err return err
} }
//Create refresh token-table
_, err = DB.Exec(`CREATE TABLE IF NOT EXISTS refresh_tokens(
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
token_hash TEXT NOT NULL,
expires_at INTEGER NOT NULL,
created_at INTEGER NOT NULL,
revoked INTEGER NOT NULL DEFAULT 0,
device_info TEXT,
FOREIGN KEY(user_id) REFERENCES users(id)
)`)
if err != nil {
return err
}
_, err = DB.Exec(`CREATE INDEX IF NOT EXISTS idx_refresh_token_hash ON refresh_tokens(token_hash)`)
if err != nil {
return err
}
//Create Expenses-Table //Create Expenses-Table
_, err = DB.Exec(`CREATE TABLE IF NOT EXISTS expenses( _, err = DB.Exec(`CREATE TABLE IF NOT EXISTS expenses(
id TEXT PRIMARY KEY id TEXT PRIMARY KEY
@@ -34,21 +56,67 @@ func InitDB(filepath string) error {
return err return err
} }
// Users
func AddUser(user models.User) error { func AddUser(user models.User) error {
_, err := DB.Exec("INSERT INTO users(id, username, password, role) VALUES (?, ?, ?, ?)", user.ID, user.Username, user.Password, user.Role) _, err := DB.Exec("INSERT INTO users(id, username, password, role) VALUES (?, ?, ?, ?)", user.ID, user.Username, user.Password, user.Role)
return err return err
} }
func GetUserByUsername(username string) (models.User, error) { func GetUserByUsername(username string) (models.User, error) {
row := DB.QueryRow("SELECT * FROM users WHERE username = ?", username) row := DB.QueryRow("SELECT * FROM users WHERE username = ?", username)
var user models.User var user models.User
err := row.Scan(&user.ID, &user.Username, &user.Password) err := row.Scan(&user.ID, &user.Username, &user.Password, &user.Role)
return user, err return user, err
} }
func GetUserById(id string) (models.User, error) { func GetUserById(id string) (models.User, error) {
row := DB.QueryRow("SELECT * FROM users WHERE id = ?", id) row := DB.QueryRow("SELECT * FROM users WHERE id = ?", id)
var user models.User var user models.User
err := row.Scan(&user.ID, &user.Username, &user.Password) err := row.Scan(&user.ID, &user.Username, &user.Password)
return user, err return user, err
} }
// Refresh Tokens
func AddRefreshToken(token models.RefreshToken) error {
_, err := DB.Exec("INSERT INTO refresh_tokens(id, user_id, token_hash, expires_at, created_at, revoked, device_info) VALUES (?, ?, ?, ?, ?, ?, ?)",
token.ID, token.UserID, token.Token, token.ExpiresAt, token.CreatedAt, token.Revoked, token.DeviceInfo)
return err
}
func GetRefreshToken(token string) (models.RefreshToken, error) {
row := DB.QueryRow("SELECT * FROM refresh_tokens WHERE token_hash = ?", token)
var refresh_token models.RefreshToken
err := row.Scan(&refresh_token.ID, &refresh_token.UserID, &refresh_token.Token, &refresh_token.ExpiresAt, &refresh_token.CreatedAt, &refresh_token.Revoked, &refresh_token.DeviceInfo)
return refresh_token, err
}
func RevokeRefreshToken(tokenID string) error {
if DB == nil {
return errors.New("db not initialized")
}
res, err := DB.Exec(`
UPDATE refresh_tokens
SET revoked = 1
WHERE id = ?
`, tokenID)
if err != nil {
return err
}
n, err := res.RowsAffected()
if err != nil {
return err
}
if n == 0 {
return ErrNotFound
}
return nil
}
func RevokeAllRefreshTokensForUser(userID string) error {
if DB == nil {
return errors.New("db not initialized")
}
_, err := DB.Exec(`
UPDATE refresh_tokens
SET revoked = 1
WHERE user_id = ?
`, userID)
return err
}

View File

@@ -2,7 +2,9 @@ package utils
import ( import (
"crypto/rand" "crypto/rand"
"crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/hex"
"github.com/google/uuid" "github.com/google/uuid"
) )
@@ -18,3 +20,14 @@ func GenerateSecret() string {
} }
return base64.StdEncoding.EncodeToString(b) return base64.StdEncoding.EncodeToString(b)
} }
func GenerateRefreshToken() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
func HashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}