package user
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"log/slog"
"net/http"
"regexp"
"strconv"
"time"
_ "github.com/lib/pq"
"github.com/pkg/errors"
"golang.org/x/crypto/argon2"
"git.allthings.red/~neallred/allredlib/graph/model"
"git.allthings.red/~neallred/allredlib/logging"
)
// SignupRequest is the JSON body accepted by the signup endpoint (decoded by
// ParseSignupFromRequest). Any subset of the fields may be present; Validate
// decides which combinations form a supported signup flow.
type SignupRequest struct {
Username string `json:"username"`
Password string `json:"password"` // plaintext as sent by the client; the UserPass flow hashes it with argon2id before storing
Email string `json:"email"`
}
// LoginRequest is the JSON body accepted by RestApi.Login.
type LoginRequest struct {
Username string `json:"username"`
Password string `json:"password"` // plaintext password as sent by the client
}
// SignupType identifies which combination of credentials a SignupRequest
// carries, as classified by SignupRequest.Validate. The values are assigned
// by position, so the order below must not change; 0 is the invalid case.
type SignupType int

const (
	InvalidSignup SignupType = iota // 0: no supported combination of fields
	UserPass                        // 1: username + password
	EmailOnly                       // 2: email only
	EmailPass                       // 3: email + password
	EmailUserPass                   // 4: email + username + password
	EmailUser                       // 5: email + username
)
// reEmail is a deliberately loose sanity check for email addresses: the whole
// string must be exactly one "@" with at least one non-space, non-"@"
// character on each side. It is anchored so surrounding junk is rejected, and
// it does not use \w, which in RE2 is ASCII-only and would reject valid
// addresses such as "josé@example.com".
var reEmail = regexp.MustCompile(`^[^@\s]+@[^@\s]+$`)

// Validate classifies a signup request by which of Username, Password and
// Email are non-empty. It returns an error when the combination is not a
// supported signup flow or when a supplied email is malformed.
//
// Supported combinations (present = non-empty):
//   - Username + Password          -> UserPass
//   - Username + Password + Email  -> EmailUserPass
//   - Email                        -> EmailOnly
//   - Email + Password             -> EmailPass
//   - Email + Username             -> EmailUser
//
// Any other combination, including a completely empty request, yields
// InvalidSignup and a non-nil error.
func (u SignupRequest) Validate() (SignupType, error) {
	hasUser, hasPass, hasEmail := u.Username != "", u.Password != "", u.Email != ""
	if !hasUser && !hasPass && !hasEmail {
		return InvalidSignup, errors.New("Empty user request")
	}
	// Every flow that carries an email must carry a plausible one. Previously
	// only the username+password+email flow checked it.
	if hasEmail && !reEmail.MatchString(u.Email) {
		return InvalidSignup, errors.New("bad email in signup")
	}
	switch {
	case hasUser && hasPass && hasEmail:
		return EmailUserPass, nil
	case hasUser && hasPass:
		return UserPass, nil
	case hasEmail && hasUser:
		return EmailUser, nil
	case hasEmail && hasPass:
		return EmailPass, nil
	case hasEmail:
		return EmailOnly, nil
	}
	return InvalidSignup, errors.New("Bad user signup request")
}
// ParseSignupFromRequest reads the JSON body of r into a SignupRequest and
// classifies it with SignupRequest.Validate.
//
// This runs on unauthenticated input, so the body is capped at 1 MiB. Larger
// bodies are rejected rather than buffered in full. A missing body is also
// rejected.
//
// If reading or decoding fails, the returned SignupType is InvalidSignup and
// the error names the failing step. Otherwise it returns the decoded request
// together with Validate's result unchanged.
func ParseSignupFromRequest(r *http.Request) (SignupRequest, SignupType, error) {
	// Signup payloads are three short strings; 1 MiB is far beyond any
	// legitimate request while still bounding memory per request.
	const maxSignupBodyBytes = 1 << 20

	if r.Body == nil {
		return SignupRequest{}, InvalidSignup, errors.New("missing signup request body")
	}
	// Read one byte past the limit so an oversized body is detected instead
	// of being silently truncated into a confusing JSON decode error.
	body, err := io.ReadAll(io.LimitReader(r.Body, maxSignupBodyBytes+1))
	if err != nil {
		return SignupRequest{}, InvalidSignup, errors.Wrap(err, "could not read body")
	}
	if len(body) > maxSignupBodyBytes {
		return SignupRequest{}, InvalidSignup, errors.New("signup request body too large")
	}
	var signupRequest SignupRequest
	if err := json.Unmarshal(body, &signupRequest); err != nil {
		return signupRequest, InvalidSignup, errors.Wrap(err, "could not decode body to signup request")
	}
	signupType, err := signupRequest.Validate()
	return signupRequest, signupType, err
}
// Signup is a placeholder for a transport-independent signup entry point.
// It currently ignores u and always returns an empty string and a nil error.
// NOTE(review): a nil error here does not mean a user was created.
func Signup(u SignupRequest) (string, error) {
	_ = u // not used yet
	var result string
	return result, nil
}
// RestApi groups the user/session HTTP handlers (Signup, Login, Whoami,
// Logout, Auth). They all run their queries through one *sql.DB.
type RestApi struct {
	db *sql.DB
}

// NewRestApi returns a RestApi whose handlers query through db.
func NewRestApi(db *sql.DB) RestApi {
	var api RestApi
	api.db = db
	return api
}
// getToken returns a random string of exactly length characters drawn from
// the standard base64 alphabet (A-Z, a-z, 0-9, '+', '/').
//
// It reads length bytes from crypto/rand, base64-encodes them and keeps the
// first length characters. Padding never reaches that prefix, and each kept
// character carries 6 random bits. It panics if the random source fails.
func getToken(length int) string {
	buf := make([]byte, length)
	if _, err := rand.Read(buf); err != nil {
		panic(err)
	}
	encoded := base64.StdEncoding.EncodeToString(buf)
	return encoded[:length]
}
// Signup handles a signup request whose JSON body is parsed and classified by
// ParseSignupFromRequest. Unparseable or invalid requests get 400.
//
// Implemented flows:
//   - EmailOnly: inserts a users row whose username and email are both the
//     supplied email, then responds 200 with an empty body.
//   - UserPass: creates the user and its first session in one transaction,
//     sets the "auth" session cookie and responds with a JSON WhoamiResponse
//     (see signupUserPass).
//
// The other signup types (EmailPass, EmailUserPass, EmailUser) are not
// implemented yet. They get 501 Not Implemented instead of a misleading 200.
// Database failures get 500.
//
// NOTE(review): the queries use `?` placeholders, while this file
// blank-imports github.com/lib/pq, which expects `$1`-style placeholders.
// Confirm which driver backs api.db.
func (api RestApi) Signup(w http.ResponseWriter, r *http.Request) {
	signupRequest, signupType, err := ParseSignupFromRequest(r)
	if err != nil || signupType == InvalidSignup {
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	switch signupType {
	case EmailOnly:
		email := signupRequest.Email
		_, err := api.db.ExecContext(r.Context(), "INSERT INTO users (username, email) VALUES (?, ?)", email, email)
		if err500(err, "Error inserting email", w, r) {
			return
		}
	case UserPass:
		api.signupUserPass(r.Context(), w, signupRequest)
	default:
		// TODO: implement EmailPass, EmailUserPass and EmailUser signups.
		w.WriteHeader(http.StatusNotImplemented)
	}
}

// signupUserPass creates a username/password user and its initial session in
// a single transaction, then writes the session cookie and a JSON
// WhoamiResponse to w. Every failure before the commit responds 500 and rolls
// the transaction back.
func (api RestApi) signupUserPass(ctx context.Context, w http.ResponseWriter, req SignupRequest) {
	salt := getToken(256)
	// These argon2id parameters must stay identical to the ones Login uses,
	// otherwise stored hashes can never be verified.
	hashedPasswordBytes := argon2.IDKey([]byte(req.Password), []byte(salt), 1, 64*1024, 4, 256)
	hashedPassword := base64.StdEncoding.EncodeToString(hashedPasswordBytes)

	tx, err := api.db.BeginTx(ctx, nil)
	if err != nil {
		// Must return here: continuing would dereference a nil *sql.Tx.
		slog.Error("error signing up with user pass", "error", err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	// Rollback becomes a no-op once Commit has succeeded. On every early
	// return it releases the transaction and its connection.
	defer tx.Rollback()

	const query = "INSERT INTO users (username,password,salt) VALUES (?,?,?) RETURNING id"
	var userId int64
	if err := tx.QueryRowContext(ctx, query, req.Username, hashedPassword, salt).Scan(&userId); err != nil {
		slog.Error("error signing up with user pass", "error", err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	sessionToken, expires, err := createSession(tx, userId)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		slog.Error("Error creating session", "err", err)
		return
	}
	// Only hand out a session cookie once the user and session are durable.
	if err := tx.Commit(); err != nil {
		slog.Error("error committing user pass signup", "error", err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	whoAmI := WhoamiResponse{UserId: userId, LogoutTime: expires}
	// Roles are best-effort: the account and session already exist, so a
	// failure is logged and the response is sent without roles.
	roles, rolesErr := getUserRoles(api.db, strconv.FormatInt(userId, 10))
	if rolesErr != nil {
		slog.Error("could not fetch roles for newly logged in user", "user", userId, "err", rolesErr)
	} else {
		whoAmI.Roles = roles
	}
	whoAmIBytes, whoAmIErr := json.Marshal(whoAmI)
	if whoAmIErr != nil {
		slog.Error("could not marshal whoami response", "err", whoAmIErr)
	}
	writeSession(w, sessionToken, expires)
	if whoAmIErr == nil {
		w.Write(whoAmIBytes)
	}
}
// UserLoginCandidate is one users row whose username matched during Login
// (selected by qCheckUsers).
type UserLoginCandidate struct {
Id int64
Username string
Password string // stored password hash: base64-encoded argon2id output
Salt string // per-user salt that was fed to argon2id together with the password
}
// interface usage allows passing *sql.DB or *sql.Tx
func createSession(db Execer, userId int64) (string, time.Time, error) {
// create a session
sessionToken := getToken(256)
expiry := time.Now().AddDate(0, 0, 14)
qSession := "INSERT INTO session (users_id,token,expires) VALUES (?,?,?);"
if _, err := db.Exec(qSession, userId, sessionToken, expiry); err != nil {
return "", time.Now(), err
}
return sessionToken, expiry, nil
}
// writeSession sets the "auth" session cookie on w.
//
// The cookie is HttpOnly (not readable from page scripts), scoped to Path "/"
// and SameSite=Strict. A non-zero expires becomes the cookie's Expires
// attribute.
//
// A zero expires (what Logout passes) marks the cookie for immediate deletion
// with Max-Age=0. Otherwise net/http would omit the zero Expires, and the
// browser would keep an empty "auth" session cookie instead of dropping it.
//
// NOTE(review): Secure is not set, so the cookie is also sent over plain HTTP.
// Consider setting it when the site is served only over TLS.
func writeSession(w http.ResponseWriter, token string, expires time.Time) {
	cookie := &http.Cookie{
		Name:     "auth",
		Value:    token,
		Path:     "/",
		Expires:  expires,
		HttpOnly: true,
		SameSite: http.SameSiteStrictMode,
	}
	if expires.IsZero() {
		// MaxAge < 0 is serialized as "Max-Age=0", i.e. delete now.
		cookie.MaxAge = -1
	}
	http.SetCookie(w, cookie)
}
// Execer is the subset of *sql.DB and *sql.Tx that createSession needs. It lets
// a session be created standalone (Login passes api.db) or inside a signup
// transaction (Signup passes its *sql.Tx).
type Execer interface {
Exec(query string, args ...any) (sql.Result, error)
}
const qCheckUsers = "SELECT id,username,password,salt FROM users WHERE username=?;"

// Login authenticates a username/password pair sent as a JSON LoginRequest
// body and, on success, starts a new session.
//
// Status codes:
//   - 400 if the body is missing, too large or not valid JSON, or if the
//     username does not match exactly one users row;
//   - 403 if the password does not match the stored hash;
//   - 500 on database or session-creation failures.
//
// On success the "auth" session cookie is set and the body is a JSON
// WhoamiResponse. Role lookup is best-effort: if it fails, the response is
// sent without roles.
//
// NOTE(review): answering 400 for an unknown user but 403 for a wrong
// password reveals which usernames exist.
// NOTE(review): the hash comparison uses !=, which is not constant-time.
// crypto/subtle.ConstantTimeCompare would remove that timing signal.
func (api RestApi) Login(w http.ResponseWriter, r *http.Request) {
	loginRequest, ok := readLoginRequest(r)
	if !ok {
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	// The request is deliberately never logged: it holds the plaintext password.
	candidateUsers, err := api.loginCandidates(r.Context(), loginRequest.Username)
	if err != nil {
		slog.Error("could not look up login candidates", "err", err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	if numUsers := len(candidateUsers); numUsers != 1 {
		w.WriteHeader(http.StatusBadRequest)
		log.Printf("Expected %d users for login, got %d", 1, numUsers)
		return
	}
	candidate := candidateUsers[0]
	if loginPasswordHash(loginRequest.Password, candidate.Salt) != candidate.Password {
		w.WriteHeader(http.StatusForbidden)
		log.Printf("Expected password hash did not match provided hash")
		return
	}

	sessionToken, expires, err := createSession(api.db, candidate.Id)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		slog.Error("Error creating session", "err", err)
		return
	}
	whoAmI := WhoamiResponse{UserId: candidate.Id, LogoutTime: expires}
	userId := strconv.FormatInt(candidate.Id, 10)
	roles, rolesErr := getUserRoles(api.db, userId)
	if rolesErr != nil {
		slog.Error("could not fetch roles for newly logged in user", "user", userId, "err", rolesErr)
	} else {
		whoAmI.Roles = roles
	}
	whoAmIBytes, whoAmIErr := json.Marshal(whoAmI)
	if whoAmIErr != nil {
		slog.Error("could not marshal whoami response", "err", whoAmIErr)
	}
	writeSession(w, sessionToken, expires)
	if whoAmIErr == nil {
		w.Write(whoAmIBytes)
	}
}

// readLoginRequest decodes the JSON body of r, capped at 1 MiB, into a
// LoginRequest. It reports false if the body is missing, too large,
// unreadable or not valid JSON for a LoginRequest.
func readLoginRequest(r *http.Request) (LoginRequest, bool) {
	// Login bodies are two short strings; the cap bounds memory spent on
	// unauthenticated input.
	const maxLoginBodyBytes = 1 << 20
	if r.Body == nil {
		return LoginRequest{}, false
	}
	body, err := io.ReadAll(io.LimitReader(r.Body, maxLoginBodyBytes+1))
	if err != nil || len(body) > maxLoginBodyBytes {
		return LoginRequest{}, false
	}
	var req LoginRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return LoginRequest{}, false
	}
	return req, true
}

// loginCandidates returns every users row whose username equals username.
// The rows are always closed, and iteration errors are reported rather than
// treated as "no more users".
func (api RestApi) loginCandidates(ctx context.Context, username string) ([]UserLoginCandidate, error) {
	rows, err := api.db.QueryContext(ctx, qCheckUsers, username)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var candidates []UserLoginCandidate
	for rows.Next() {
		var c UserLoginCandidate
		if err := rows.Scan(&c.Id, &c.Username, &c.Password, &c.Salt); err != nil {
			return nil, err
		}
		candidates = append(candidates, c)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return candidates, nil
}

// loginPasswordHash derives the base64-encoded argon2id hash of password with
// salt. The parameters must match the ones used at signup, or stored hashes
// will never verify.
func loginPasswordHash(password, salt string) string {
	hashed := argon2.IDKey([]byte(password), []byte(salt), 1, 64*1024, 4, 256)
	return base64.StdEncoding.EncodeToString(hashed)
}
// WhoamiToken is one row of the session table as read by RestApi.Auth via
// qToken (users_id, token, expires).
// NOTE(review): the `db` tags are not used by the positional Scan calls in
// this file; presumably they serve another mapper elsewhere. Confirm.
type WhoamiToken struct {
UsersId int64 `db:"users_id"`
Token string
Expires time.Time
}
// WhoamiResponse is the JSON body returned by Whoami, and by Signup/Login
// right after a session is created.
type WhoamiResponse struct {
UserId int64 `json:"userId" db:"users_id"`
Roles []model.Role `json:"roles"` // encodes as JSON null when the user has no roles or they could not be loaded
LogoutTime time.Time `json:"logoutTime"` // the session's expiry time
}
// Whoami reports the caller's identity as a JSON WhoamiResponse, based on the
// "auth" session cookie resolved by Auth.
//
// A missing or unreadable cookie, or an expired token, gets 401 Unauthorized.
// Any other Auth failure, or a marshaling failure, gets 500.
func (api RestApi) Whoami(w http.ResponseWriter, r *http.Request) {
	userAuth, err := api.Auth(w, r, false)
	switch {
	case errors.Is(err, errFailGetAuthCookie), errors.Is(err, errNoAuthCookie), errors.Is(err, errExpiredToken):
		w.WriteHeader(http.StatusUnauthorized)
		return
	case err != nil:
		w.WriteHeader(http.StatusInternalServerError)
		slog.Error("Non-standard error path when attempting to access logged in information", "err", err)
		return
	}

	resp := WhoamiResponse{
		UserId:     userAuth.UserId,
		Roles:      userAuth.Roles,
		LogoutTime: userAuth.Expires,
	}
	payload, marshalErr := json.Marshal(resp)
	if marshalErr != nil {
		w.WriteHeader(http.StatusInternalServerError)
		slog.Error("could not write userAuth response", "err", marshalErr)
		return
	}
	w.Write(payload)
}
// priveleges lists, as boolean flags, the privileges a user may hold.
// NOTE(review): this struct is not referenced in the visible code. UserAuth
// tracks privileges in a map[model.Role]bool instead. Confirm whether it is
// still needed.
type priveleges struct {
anonymous bool
viewer bool
borrower bool
lender bool
serverAdmin bool
addBooks bool
editBooks bool
deleteBooks bool
roleGranter bool
}
// UserAuth describes an authenticated session as resolved by RestApi.Auth.
type UserAuth struct {
UserId int64
Roles []model.Role // roles granted to the user, as loaded by getUserRoles
Expires time.Time // when the session token expires
priveleges map[model.Role]bool // effective privileges consulted by Can; a nil map means no privileges
}
// Can reports whether ua holds the privilege for role. Roles absent from the
// privilege map, or a nil map, yield false. Reading a nil map is safe in Go and
// returns the zero value.
func (ua UserAuth) Can(role model.Role) bool {
	return ua.priveleges[role]
}
// findInList reports whether role occurs anywhere in list. An empty or nil
// list yields false.
func findInList(list []model.Role, role model.Role) bool {
	for i := range list {
		if list[i] == role {
			return true
		}
	}
	return false
}
// populatePrivelges fills ua.priveleges from ua.Roles so that Can can answer
// privilege checks. RoleAnonymous is always granted, and RoleServerAdmin
// implies every other role.
//
// The pointer receiver is required. With a value receiver the map was assigned
// to a copy and discarded, which left Can returning false for every role.
func (ua *UserAuth) populatePrivelges() {
	isAdmin := findInList(ua.Roles, model.RoleServerAdmin)
	// grants reports whether the user holds role directly or via server admin.
	grants := func(role model.Role) bool {
		return isAdmin || findInList(ua.Roles, role)
	}
	ua.priveleges = map[model.Role]bool{
		model.RoleAnonymous:   true,
		model.RoleViewer:      grants(model.RoleViewer),
		model.RoleBorrower:    grants(model.RoleBorrower),
		model.RoleLender:      grants(model.RoleLender),
		model.RoleServerAdmin: isAdmin,
		model.RoleAddBooks:    grants(model.RoleAddBooks),
		model.RoleEditBooks:   grants(model.RoleEditBooks),
		model.RoleDeleteBooks: grants(model.RoleDeleteBooks),
		model.RoleRoleGranter: grants(model.RoleRoleGranter),
	}
}
// Errors used by session authentication. RestApi.Auth returns the cookie and
// token sentinels, and Whoami maps them to 401 Unauthorized via errors.Is.
// errBadQuery is not used in the visible code.
var (
errFailGetAuthCookie = fmt.Errorf("fail fetching auth cookie") // Auth: the request's "auth" cookie could not be read
errNoAuthCookie = fmt.Errorf("no auth cookie") // Auth: the cookie lookup returned no cookie
errBadQuery = func(err error) error { return errors.Wrap(err, "error forming sql query") } // wraps err with a "forming sql query" message
errExpiredToken = fmt.Errorf("token expired, need to log in again") // Auth: the session's expiry time has passed
)
const qToken = "SELECT users_id,token,expires FROM session WHERE token=? LIMIT 1;"

// Auth resolves the "auth" session cookie on r to the signed-in user.
//
// It returns:
//   - errFailGetAuthCookie when r has no "auth" cookie (net/http's
//     Request.Cookie reports only ErrNoCookie);
//   - errExpiredToken, possibly wrapped (match with errors.Is), when the
//     token is past its expiry or matches no session row, e.g. after Logout;
//   - a wrapped error for any other token or role lookup failure.
//
// When needsPriveleges is true, populatePrivelges is called on the result
// before it is returned. w is currently unused.
func (api RestApi) Auth(w http.ResponseWriter, r *http.Request, needsPriveleges bool) (*UserAuth, error) {
	authCookie, err := r.Cookie("auth")
	if err != nil {
		return nil, errFailGetAuthCookie
	}
	if authCookie == nil {
		// Defensive: Request.Cookie does not return a nil cookie with a nil error.
		return nil, errNoAuthCookie
	}

	var whoamiToken WhoamiToken
	err = api.db.QueryRowContext(r.Context(), qToken, authCookie.Value).
		Scan(&whoamiToken.UsersId, &whoamiToken.Token, &whoamiToken.Expires)
	if errors.Is(err, sql.ErrNoRows) {
		// An unknown token (deleted by Logout, or never issued) calls for a
		// fresh login exactly like an expired one, not a server error.
		return nil, fmt.Errorf("unknown session token: %w", errExpiredToken)
	}
	if err != nil {
		return nil, fmt.Errorf("fail selecting whoami token: %w", err)
	}
	if whoamiToken.Expires.Before(time.Now()) {
		return nil, errExpiredToken
	}

	userId := whoamiToken.UsersId
	roles, err := getUserRoles(api.db, strconv.FormatInt(userId, 10))
	if err != nil {
		return nil, fmt.Errorf("could not get roles for userId %d: %w", userId, err)
	}
	userAuth := UserAuth{UserId: userId, Roles: roles, Expires: whoamiToken.Expires}
	if needsPriveleges {
		userAuth.populatePrivelges()
	}
	return &userAuth, nil
}
const getUserRolesQuery = "SELECT r.role FROM role AS r INNER JOIN users_role AS ur ON ur.role_id = r.id AND ur.users_id = ?;"

// getUserRoles returns the roles granted to userId through the users_role
// join table. A user with no roles yields a nil slice and a nil error.
// Any query, scan or iteration error is returned with a nil slice.
func getUserRoles(db *sql.DB, userId string) ([]model.Role, error) {
	rows, err := db.Query(getUserRolesQuery, userId)
	if err != nil {
		return nil, err
	}
	// Next closes the rows only after a full iteration. Close also covers the
	// early return on a Scan error, which would otherwise hold the connection.
	defer rows.Close()

	var roles []model.Role
	for rows.Next() {
		var role model.Role
		if err := rows.Scan(&role); err != nil {
			return nil, err
		}
		roles = append(roles, role)
	}
	// Next also returns false on a failed iteration; Err tells that apart
	// from a clean end of results.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return roles, nil
}
// AuthCtxKey is a context key, presumably for storing the request's UserAuth
// (it is not used in the visible code; confirm with its callers).
var AuthCtxKey = &contextKey{"userAuth"}
// contextKey is an unexported key type. Keys of this type cannot collide with
// context keys defined in other packages.
type contextKey struct {
name string
}
// qDelToken deletes a session row by its token; Logout uses it.
const qDelToken = "DELETE FROM session WHERE token = ?;"
// err500 is a guard for HTTP handlers. When err is non-nil it logs msg and
// err through the logger returned by logging.FromReq(r), writes
// 500 Internal Server Error and returns true, so the caller can return
// immediately. A nil err does nothing and returns false.
func err500(err error, msg string, w http.ResponseWriter, r *http.Request) bool {
	failed := err != nil
	if failed {
		logging.FromReq(r).Error(msg, "err", err)
		w.WriteHeader(http.StatusInternalServerError)
	}
	return failed
}
func (api RestApi) Logout(w http.ResponseWriter, r *http.Request) {
authCookie, err := r.Cookie("auth")
if err500(err, "fail getting session from cookie to delete", w, r) {
return
}
var zeroTime time.Time
if authCookie == nil {
writeSession(w, "", zeroTime)
w.WriteHeader(http.StatusOK)
return
}
_, err = api.db.Exec(qDelToken, authCookie.Value)
if err500(err, "fail deleting session", w, r) {
return
}
writeSession(w, "", zeroTime)
w.WriteHeader(http.StatusOK)
}