diff --git a/handlers/account.go b/handlers/account.go index cfebe11..7c8829a 100644 --- a/handlers/account.go +++ b/handlers/account.go @@ -9,6 +9,8 @@ import ( "encoding/json" "log" "net/http" + "os" + "time" ) var cfg, _ = config.LoadConfig() @@ -37,99 +39,98 @@ func Register(w http.ResponseWriter, r *http.Request) { user.ID = utils.GenerateUUID() user.Role = models.RoleUser - //if err := storage.AddUser(&user); err != nil { - // log.Println("POST [api/register] " + r.RemoteAddr + ": " + err.Error()) - // http.Error(w, "user already exists", http.StatusBadRequest) - // return - //} + if err := storage.AddUser(&user); err != nil { + log.Println("POST [api/register] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "user already exists", http.StatusBadRequest) + return + } w.WriteHeader(http.StatusCreated) log.Println("POST [api/register] " + r.RemoteAddr + ": Successfully created user") } func Login(w http.ResponseWriter, r *http.Request) { - //var creds struct { - // Username string `json:"username"` - // Password string `json:"password"` - //} - //if err := json.NewDecoder(r.Body).Decode(&creds); err != nil { - // log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) - // http.Error(w, "Invalid request", http.StatusBadRequest) - // return - //} - // - //user, err := storage.GetUserByUsername(creds.Username) - //if err != nil { - // log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) - // http.Error(w, "Invalid credentials", http.StatusUnauthorized) - // return - //} - // - //if !auth.CheckPasswordHash(creds.Password, user.Password) { - // log.Println("POST [api/login] " + r.RemoteAddr + ": Invalid credentials") - // http.Error(w, "Invalid credentials", http.StatusUnauthorized) - // return - //} - // - //secret := []byte(os.Getenv("SHAP_JWT_SECRET")) - //if len(secret) == 0 { - // log.Println("POST [api/login] " + r.RemoteAddr + ": Server misconfiguration") - // http.Error(w, "Server misconfiguration", http.StatusInternalServerError) - // return - //} - // - //accessToken, err := auth.GenerateJWT(user.ID, user.Role, secret) - //if err != nil { - // log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) - // http.Error(w, "Could not generate token", http.StatusInternalServerError) - // return - //} - // - //refreshTokenPlain, err := utils.GenerateRefreshToken() - //if err != nil { - // log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) - // http.Error(w, "could not generate refresh token", http.StatusInternalServerError) - // return - //} - //refreshHash := utils.HashToken(refreshTokenPlain) - //refreshID := utils.GenerateUUID() - //refreshExpires := time.Now().Add(7 * 24 * time.Hour).Unix() // expiry: 7 days - // - //deviceInfo := r.Header.Get("User-Agent") - // - //if err := storage.AddRefreshToken(&models.RefreshToken{ - // ID: refreshID, - // UserID: user.ID, - // Token: refreshHash, - // ExpiresAt: refreshExpires, - // DeviceInfo: deviceInfo, - // CreatedAt: time.Now().Unix(), - // Revoked: false, - //}); err != nil { - // log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) - // http.Error(w, "could not save refresh token", http.StatusInternalServerError) - // return - //} - // - //// Return access + refresh token (refresh in plain for client to store securely) - //resp := map[string]interface{}{ - // "access_token": accessToken, - // "refresh_token": refreshTokenPlain, - // "user": map[string]interface{}{ - // "id": user.ID, - // "username": user.Username, - // "role": user.Role, - // }, - // "wgName": cfg.HouseholdName, - //} - // - //w.Header().Set("Content-Type", "application/json") - //err = json.NewEncoder(w).Encode(resp) - //if err != nil { - // log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) - // http.Error(w, "Something went wrong", http.StatusInternalServerError) - // return - //} - //log.Println("POST [api/login] " + r.RemoteAddr + ": Successfully logged in") + var creds struct { + Username string `json:"username"` + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&creds); err != nil { + log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + user, err := storage.GetUserByUsername(creds.Username) + if err != nil { + log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Invalid credentials", http.StatusUnauthorized) + return + } + + if !auth.CheckPasswordHash(creds.Password, user.Password) { + log.Println("POST [api/login] " + r.RemoteAddr + ": Invalid credentials") + http.Error(w, "Invalid credentials", http.StatusUnauthorized) + return + } + + secret := []byte(os.Getenv("JWT_SECRET")) + if len(secret) == 0 { + log.Println("POST [api/login] " + r.RemoteAddr + ": Server misconfiguration") + http.Error(w, "Server misconfiguration", http.StatusInternalServerError) + return + } + + accessToken, err := auth.GenerateJWT(user.ID, user.Role, secret) + if err != nil { + log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Could not generate token", http.StatusInternalServerError) + return + } + + refreshTokenPlain, err := utils.GenerateRefreshToken() + if err != nil { + log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "could not generate refresh token", http.StatusInternalServerError) + return + } + refreshHash := utils.HashToken(refreshTokenPlain) + refreshID := utils.GenerateUUID() + refreshExpires := time.Now().Add(7 * 24 * time.Hour).Unix() // expiry: 7 days + + deviceInfo := r.Header.Get("User-Agent") + + if err := storage.AddRefreshToken(&models.RefreshToken{ + ID: refreshID, + UserID: user.ID, + Token: refreshHash, + ExpiresAt: refreshExpires, + DeviceInfo: deviceInfo, + CreatedAt: time.Now().Unix(), + Revoked: false, + }); err != nil { + log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "could not save refresh token", http.StatusInternalServerError) + return + } + + // Return access + refresh token (refresh in plain for client to store securely) + resp := map[string]interface{}{ + "access_token": accessToken, + "refresh_token": refreshTokenPlain, + "user": map[string]interface{}{ + "id": user.ID, + "username": user.Username, + "role": user.Role, + }, + } + + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(resp) + if err != nil { + log.Println("POST [api/login] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Something went wrong", http.StatusInternalServerError) + return + } + log.Println("POST [api/login] " + r.RemoteAddr + ": Successfully logged in") } func Logout(w http.ResponseWriter, r *http.Request) { claims := r.Context().Value(auth.UserContextKey).(*auth.Claims) @@ -157,88 +158,88 @@ func TestHandler(w http.ResponseWriter, r *http.Request) { log.Println("GET [api/login] " + r.RemoteAddr + ": Successfully tested connection") } func RefreshToken(w http.ResponseWriter, r *http.Request) { - //var req struct { - // RefreshToken string `json:"refresh_token"` - //} - //if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - // log.Println("POST [api/refresh] " + r.RemoteAddr + ": " + err.Error()) - // http.Error(w, "Invalid request", http.StatusBadRequest) - // return - //} - // - //hashed := utils.HashToken(req.RefreshToken) - // - //tokenRow, err := storage.GetRefreshToken(hashed) - //if err != nil || tokenRow.Revoked || tokenRow.ExpiresAt < time.Now().Unix() { - // log.Println("POST [api/refresh] " + r.RemoteAddr + ": Invalid refresh token") - // http.Error(w, "Invalid refresh token", http.StatusUnauthorized) - // return - //} - // - //if err := storage.RevokeRefreshToken(tokenRow.ID); err != nil { - // log.Println(err) - //} - // - //newToken, _ := utils.GenerateRefreshToken() - //newHash := utils.HashToken(newToken) - //newExpires := time.Now().Add(7 * 24 * time.Hour).Unix() //7 days - //newID := utils.GenerateUUID() - //deviceInfo := r.Header.Get("User-Agent") - //if err = storage.AddRefreshToken(&models.RefreshToken{ - // ID: newID, - // UserID: tokenRow.UserID, - // Token: newHash, - // ExpiresAt: newExpires, - // CreatedAt: time.Now().Unix(), - // Revoked: false, - // DeviceInfo: deviceInfo, - //}); err != nil { - // log.Println("POST [api/refresh] " + r.RemoteAddr + ": " + err.Error()) - // http.Error(w, "Could not generate new refresh token", http.StatusInternalServerError) - // return - //} - // - //user, err := storage.GetUserById(tokenRow.UserID) - //if err != nil { - // log.Println("POST [api/refresh] " + r.RemoteAddr + ": " + err.Error()) - // http.Error(w, "Internal server error", http.StatusInternalServerError) - // return - //} - //accessToken, _ := auth.GenerateJWT(tokenRow.UserID, user.Role, []byte(os.Getenv("SHAP_JWT_SECRET"))) - // - //if err = json.NewEncoder(w).Encode(map[string]string{ - // "access_token": accessToken, - // "refresh_token": newToken, - //}); err != nil { - // log.Println("POST [api/refresh] " + r.RemoteAddr + ": " + err.Error()) - // http.Error(w, "Internal server error", http.StatusInternalServerError) - // return - //} - //log.Println("POST [api/refresh] " + r.RemoteAddr + ": Successfully refreshed token") + var req struct { + RefreshToken string `json:"refresh_token"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + log.Println("POST [api/refresh] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + hashed := utils.HashToken(req.RefreshToken) + + tokenRow, err := storage.GetRefreshToken(hashed) + if err != nil || tokenRow.Revoked || tokenRow.ExpiresAt < time.Now().Unix() { + log.Println("POST [api/refresh] " + r.RemoteAddr + ": Invalid refresh token") + http.Error(w, "Invalid refresh token", http.StatusUnauthorized) + return + } + + if err := storage.RevokeRefreshToken(tokenRow.ID); err != nil { + log.Println(err) + } + + newToken, _ := utils.GenerateRefreshToken() + newHash := utils.HashToken(newToken) + newExpires := time.Now().Add(7 * 24 * time.Hour).Unix() //7 days + newID := utils.GenerateUUID() + deviceInfo := r.Header.Get("User-Agent") + if err = storage.AddRefreshToken(&models.RefreshToken{ + ID: newID, + UserID: tokenRow.UserID, + Token: newHash, + ExpiresAt: newExpires, + CreatedAt: time.Now().Unix(), + Revoked: false, + DeviceInfo: deviceInfo, + }); err != nil { + log.Println("POST [api/refresh] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Could not generate new refresh token", http.StatusInternalServerError) + return + } + + user, err := storage.GetUserById(tokenRow.UserID) + if err != nil { + log.Println("POST [api/refresh] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + accessToken, _ := auth.GenerateJWT(tokenRow.UserID, user.Role, []byte(os.Getenv("JWT_SECRET"))) + + if err = json.NewEncoder(w).Encode(map[string]string{ + "access_token": accessToken, + "refresh_token": newToken, + }); err != nil { + log.Println("POST [api/refresh] " + r.RemoteAddr + ": " + err.Error()) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + log.Println("POST [api/refresh] " + r.RemoteAddr + ": Successfully refreshed token") } func UserInfo(w http.ResponseWriter, r *http.Request) { - //if r.Method != http.MethodGet { - // log.Println("GET [api/userinfo] " + r.RemoteAddr + ": Method " + r.Method + " not allowed") - // http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - // return - //} - //query := r.URL.Query() - //idParam := query.Get("id") - //user, err := storage.GetUserById(idParam) - //if err != nil { - // log.Println("GET [api/userinfo] " + r.RemoteAddr + ": User " + idParam + " not found") - // http.Error(w, "User not found", http.StatusNotFound) - // return - //} - //w.Header().Set("Content-Type", "application/json") - //err = json.NewEncoder(w).Encode(map[string]interface{}{ - // "id": user.ID, - // "name": user.Username, - // "avatar_url": "", - //}) - //if err != nil { - // log.Println("GET [api/userinfo] " + r.RemoteAddr + ": " + err.Error()) - // return - //} - //log.Println("GET [api/userinfo] " + r.RemoteAddr + ": Successfully retrieved user info") + if r.Method != http.MethodGet { + log.Println("GET [api/userinfo] " + r.RemoteAddr + ": Method " + r.Method + " not allowed") + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + query := r.URL.Query() + idParam := query.Get("id") + user, err := storage.GetUserById(idParam) + if err != nil { + log.Println("GET [api/userinfo] " + r.RemoteAddr + ": User " + idParam + " not found") + http.Error(w, "User not found", http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(map[string]interface{}{ + "id": user.ID, + "name": user.Username, + "avatar_url": "", + }) + if err != nil { + log.Println("GET [api/userinfo] " + r.RemoteAddr + ": " + err.Error()) + return + } + log.Println("GET [api/userinfo] " + r.RemoteAddr + ": Successfully retrieved user info") } diff --git a/storage/storage.go b/storage/storage.go index 659e14d..684a21a 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "log" + "strings" _ "github.com/glebarez/go-sqlite" ) @@ -22,7 +23,7 @@ func InitDB(filepath string) error { schema := ` PRAGMA foreign_keys = ON; - CREATE TABLE IF NOT EXISTS users ( + CREATE TABLE IF NOT EXISTS users ( id TEXT PRIMARY KEY, username TEXT NOT NULL UNIQUE, password TEXT NOT NULL, @@ -74,6 +75,24 @@ func InitDB(filepath string) error { return err } +// Users +func AddUser(user *models.User) error { + _, err := DB.Exec("INSERT INTO users(id, username, password, role) VALUES (?, ?, ?, ?)", user.ID, strings.ToLower(user.Username), user.Password, user.Role) + return err +} +func GetUserByUsername(username string) (models.User, error) { + row := DB.QueryRow("SELECT * FROM users WHERE username = ?", strings.ToLower(username)) + var user models.User + err := row.Scan(&user.ID, &user.Username, &user.Password, &user.Role) + return user, err +} +func GetUserById(id string) (models.User, error) { + row := DB.QueryRow("SELECT * FROM users WHERE id = ?", id) + var user models.User + err := row.Scan(&user.ID, &user.Username, &user.Password, &user.Role) + return user, err +} + // Refresh Tokens func AddRefreshToken(token *models.RefreshToken) error { _, err := DB.Exec("INSERT INTO refresh_tokens(id, user_id, token_hash, expires_at, created_at, revoked, device_info) VALUES (?, ?, ?, ?, ?, ?, ?)",