package user

import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"log/slog"
	"net/http"
	"regexp"
	"slices"
	"strconv"
	"time"

	_ "github.com/lib/pq"
	"github.com/pkg/errors"
	"golang.org/x/crypto/argon2"

	"git.allthings.red/~neallred/allredlib/graph/model"
	"git.allthings.red/~neallred/allredlib/logging"
)

// SignupRequest is the JSON body accepted by RestApi.Signup. Which fields are
// non-empty decides its SignupType; see Validate.
type SignupRequest struct {
	Username string `json:"username"`
	Password string `json:"password"`
	Email    string `json:"email"`
}

// LoginRequest is the JSON body accepted by RestApi.Login.
type LoginRequest struct {
	Username string `json:"username"`
	Password string `json:"password"`
}

// SignupType classifies a SignupRequest by which credential fields it carries.
type SignupType int

// Signup kinds returned by SignupRequest.Validate.
const (
	InvalidSignup SignupType = 0
	UserPass      SignupType = 1
	EmailOnly     SignupType = 2
	EmailPass     SignupType = 3
	EmailUserPass SignupType = 4
	EmailUser     SignupType = 5
)

// reEmail is a deliberately loose sanity check: it only requires a word
// character on each side of an "@".
var reEmail = regexp.MustCompile(`\w@\w`)

// maxRequestBodyBytes caps how much of a signup or login body is buffered so a
// client cannot make the server read an arbitrarily large payload into memory.
// A longer body is truncated and then fails JSON decoding.
const maxRequestBodyBytes = 1 << 20

// Queries in this file use "?" placeholders.
// NOTE(review): lib/pq (imported above) only understands "$1"-style
// placeholders, so these queries presumably run against a different driver
// registered elsewhere — confirm.
const qInsertUserPass = "INSERT INTO users (username,password,salt) VALUES (?,?,?) RETURNING id"

// Validate classifies u by which of Username, Password and Email are set.
//
// It returns InvalidSignup with an error when every field is empty, when a
// non-empty Email does not match reEmail, or when only a username or only a
// password is given. Otherwise the result is:
//
//	username+password+email -> EmailUserPass
//	username+password       -> UserPass
//	email+username          -> EmailUser
//	email+password          -> EmailPass
//	email                   -> EmailOnly
func (u SignupRequest) Validate() (SignupType, error) {
	hasUser, hasPass, hasEmail := u.Username != "", u.Password != "", u.Email != ""
	if !hasUser && !hasPass && !hasEmail {
		return InvalidSignup, errors.New("Empty user request")
	}
	// Check the email for every signup kind that carries one, not only
	// EmailUserPass.
	if hasEmail && !reEmail.MatchString(u.Email) {
		return InvalidSignup, errors.New("bad email in signup")
	}
	switch {
	case hasUser && hasPass && hasEmail:
		return EmailUserPass, nil
	case hasUser && hasPass:
		return UserPass, nil
	case hasEmail && hasUser:
		return EmailUser, nil
	case hasEmail && hasPass:
		return EmailPass, nil
	case hasEmail:
		return EmailOnly, nil
	}
	// Only a username or only a password was supplied.
	return InvalidSignup, errors.New("Bad user signup request")
}

// ParseSignupFromRequest decodes r's JSON body (read up to
// maxRequestBodyBytes) into a SignupRequest and classifies it with Validate.
// On a decode failure the partially decoded request is returned together with
// InvalidSignup and the error.
func ParseSignupFromRequest(r *http.Request) (SignupRequest, SignupType, error) {
	body, err := io.ReadAll(io.LimitReader(r.Body, maxRequestBodyBytes))
	if err != nil {
		return SignupRequest{}, InvalidSignup, errors.Wrap(err, "could not read body")
	}
	var signupRequest SignupRequest
	if err := json.Unmarshal(body, &signupRequest); err != nil {
		return signupRequest, InvalidSignup, errors.Wrap(err, "could not decode body to signup request")
	}
	signupType, err := signupRequest.Validate()
	return signupRequest, signupType, err
}

// Signup is a placeholder: it ignores u and always returns ("", nil).
func Signup(u SignupRequest) (string, error) {
	return "", nil
}

// RestApi holds the dependencies of the user HTTP handlers.
type RestApi struct {
	db *sql.DB
}

// NewRestApi returns a RestApi whose handlers use db.
func NewRestApi(db *sql.DB) RestApi {
	return RestApi{
		db: db,
	}
}

// getToken returns a random string of exactly length characters: the standard
// base64 encoding of length bytes from crypto/rand, truncated. It panics if the
// system randomness source fails.
func getToken(length int) string {
	randomBytes := make([]byte, length)
	if _, err := rand.Read(randomBytes); err != nil {
		panic(err)
	}
	return base64.StdEncoding.EncodeToString(randomBytes)[:length]
}

// derivePasswordHash returns the stored form of password: Argon2id (time=1,
// memory=64*1024 KiB, threads=4, 256-byte key) over password and salt, encoded
// as standard base64. Signup and Login must use identical parameters, so both
// go through this function.
func derivePasswordHash(password, salt string) string {
	key := argon2.IDKey([]byte(password), []byte(salt), 1, 64*1024, 4, 256)
	return base64.StdEncoding.EncodeToString(key)
}

// Signup handles a signup request.
//
// An unparsable or invalid body gets 400. EmailOnly inserts a user whose
// username is the email. UserPass creates the user and a session in one
// transaction, sets the auth cookie and writes a WhoamiResponse. Database
// failures get 500; signup kinds that are not implemented yet get 501.
func (api RestApi) Signup(w http.ResponseWriter, r *http.Request) {
	signupRequest, signupType, err := ParseSignupFromRequest(r)
	if err != nil || signupType == InvalidSignup {
		w.WriteHeader(http.StatusBadRequest)
		return
	}

	switch signupType {
	case EmailOnly:
		email := signupRequest.Email
		_, err := api.db.ExecContext(r.Context(), "INSERT INTO users (username, email) VALUES (?, ?)", email, email)
		if err500(err, "Error inserting email", w, r) {
			return
		}
	case UserPass:
		userID, sessionToken, expires, err := api.createUserWithSession(r.Context(), signupRequest.Username, signupRequest.Password)
		if err != nil {
			slog.Error("error signing up with user pass", "error", err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		api.writeAuthResponse(w, userID, sessionToken, expires)
	default:
		// TODO: EmailPass, EmailUserPass and EmailUser signups are not
		// implemented; say so instead of answering 200 without doing anything.
		w.WriteHeader(http.StatusNotImplemented)
	}
}

// createUserWithSession inserts a username/password user with a fresh salt and
// opens a session for it. Both inserts run in one transaction, so if either
// fails no partial user row is left behind. It returns the new user's id and
// the session token and expiry.
func (api RestApi) createUserWithSession(ctx context.Context, username, password string) (int64, string, time.Time, error) {
	salt := getToken(256)
	hashedPassword := derivePasswordHash(password, salt)

	tx, err := api.db.BeginTx(ctx, nil)
	if err != nil {
		return 0, "", time.Time{}, errors.Wrap(err, "begin signup transaction")
	}
	// Rollback is a no-op after a successful Commit; on every early return it
	// releases the transaction and its connection. Its error is irrelevant here.
	defer func() { _ = tx.Rollback() }()

	var userID int64
	if err := tx.QueryRowContext(ctx, qInsertUserPass, username, hashedPassword, salt).Scan(&userID); err != nil {
		return 0, "", time.Time{}, errors.Wrap(err, "insert user")
	}
	sessionToken, expires, err := createSession(tx, userID)
	if err != nil {
		return 0, "", time.Time{}, errors.Wrap(err, "create session")
	}
	if err := tx.Commit(); err != nil {
		return 0, "", time.Time{}, errors.Wrap(err, "commit signup")
	}
	return userID, sessionToken, expires, nil
}

// writeAuthResponse sets the auth cookie for a freshly authenticated user and
// writes its WhoamiResponse as JSON. A failure to load roles is logged and the
// response is sent without them; a marshal failure is logged and only the
// cookie is sent.
func (api RestApi) writeAuthResponse(w http.ResponseWriter, userID int64, sessionToken string, expires time.Time) {
	whoAmI := WhoamiResponse{
		UserId:     userID,
		LogoutTime: expires,
	}
	roles, rolesErr := getUserRoles(api.db, strconv.FormatInt(userID, 10))
	if rolesErr != nil {
		slog.Error("could not fetch roles for newly logged in user", "user", userID, "err", rolesErr)
	} else {
		whoAmI.Roles = roles
	}
	whoAmIBytes, marshalErr := json.Marshal(whoAmI)
	if marshalErr != nil {
		slog.Error("could not marshal whoami response", "err", marshalErr)
	}
	writeSession(w, sessionToken, expires)
	if marshalErr == nil {
		w.Header().Set("Content-Type", "application/json")
		w.Write(whoAmIBytes)
	}
}

// UserLoginCandidate is a users row matched by username during login.
type UserLoginCandidate struct {
	Id       int64
	Username string
	Password string
	Salt     string
}

// createSession inserts a session row for userId with a new 256-character
// token that expires 14 days from now, and returns the token and expiry.
// It takes an Execer so callers can pass either *sql.DB or *sql.Tx.
func createSession(db Execer, userId int64) (string, time.Time, error) {
	sessionToken := getToken(256)
	expiry := time.Now().AddDate(0, 0, 14)
	qSession := "INSERT INTO session (users_id,token,expires) VALUES (?,?,?);"
	if _, err := db.Exec(qSession, userId, sessionToken, expiry); err != nil {
		return "", time.Time{}, err
	}
	return sessionToken, expiry, nil
}

// writeSession sets the HttpOnly, SameSite=Strict "auth" cookie to token with
// the given expiry.
func writeSession(w http.ResponseWriter, token string, expires time.Time) {
	sessionCookie := http.Cookie{
		HttpOnly: true,
		Path:     "/",
		SameSite: http.SameSiteStrictMode,
		Name:     "auth",
		Expires:  expires,
		Value:    token,
	}
	http.SetCookie(w, &sessionCookie)
}

// clearSession tells the client to delete the "auth" cookie. A zero Expires
// would be emitted with no Expires attribute at all (a browser-session cookie
// that survives), so a negative MaxAge is used; net/http sends it as Max-Age=0.
func clearSession(w http.ResponseWriter) {
	http.SetCookie(w, &http.Cookie{
		HttpOnly: true,
		Path:     "/",
		SameSite: http.SameSiteStrictMode,
		Name:     "auth",
		Value:    "",
		MaxAge:   -1,
	})
}

// Execer is the subset of *sql.DB / *sql.Tx used to run statements.
type Execer interface {
	Exec(query string, args ...any) (sql.Result, error)
}

const qCheckUsers = "SELECT id,username,password,salt FROM users WHERE username=?;"

// Login authenticates a username/password JSON body.
//
// It responds 400 for an unreadable body or when the username does not match
// exactly one user, 403 when the password hash does not match, and 500 on
// database errors. On success it creates a session, sets the auth cookie and
// writes a WhoamiResponse.
func (api RestApi) Login(w http.ResponseWriter, r *http.Request) {
	body, err := io.ReadAll(io.LimitReader(r.Body, maxRequestBodyBytes))
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	var loginRequest LoginRequest
	if err := json.Unmarshal(body, &loginRequest); err != nil {
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	// Never log loginRequest: it carries the plaintext password.

	candidateUsers, err := api.findLoginCandidates(r.Context(), loginRequest.Username)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		slog.Error("error looking up login candidates", "err", err)
		return
	}
	if numUsers := len(candidateUsers); numUsers != 1 {
		w.WriteHeader(http.StatusBadRequest)
		log.Printf("Expected %d users for login, got %d", 1, numUsers)
		return
	}
	candidate := candidateUsers[0]

	loginPasswordHash := derivePasswordHash(loginRequest.Password, candidate.Salt)
	// Constant-time comparison so response timing does not reveal how much of
	// the stored hash matched.
	if subtle.ConstantTimeCompare([]byte(candidate.Password), []byte(loginPasswordHash)) != 1 {
		w.WriteHeader(http.StatusForbidden)
		log.Printf("Expected password hash did not match provided hash")
		return
	}

	sessionToken, expires, err := createSession(api.db, candidate.Id)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		slog.Error("Error creating session", "err", err)
		return
	}
	api.writeAuthResponse(w, candidate.Id, sessionToken, expires)
}

// findLoginCandidates returns every users row whose username equals username.
// The rows are always closed, and iteration errors are reported.
func (api RestApi) findLoginCandidates(ctx context.Context, username string) ([]UserLoginCandidate, error) {
	rows, err := api.db.QueryContext(ctx, qCheckUsers, username)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var candidates []UserLoginCandidate
	for rows.Next() {
		var c UserLoginCandidate
		if err := rows.Scan(&c.Id, &c.Username, &c.Password, &c.Salt); err != nil {
			return nil, err
		}
		candidates = append(candidates, c)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return candidates, nil
}

// WhoamiToken is a row of the session table.
type WhoamiToken struct {
	UsersId int64 `db:"users_id"`
	Token   string
	Expires time.Time
}

// WhoamiResponse is the JSON body describing the logged-in user.
type WhoamiResponse struct {
	UserId     int64        `json:"userId" db:"users_id"`
	Roles      []model.Role `json:"roles"`
	LogoutTime time.Time    `json:"logoutTime"`
}

// Whoami writes the WhoamiResponse for the user behind the request's auth
// cookie. A missing, cleared, unknown or expired session gets 401; any other
// failure gets 500.
func (api RestApi) Whoami(w http.ResponseWriter, r *http.Request) {
	userAuth, err := api.Auth(w, r, false)
	if errors.Is(err, errFailGetAuthCookie) ||
		errors.Is(err, errNoAuthCookie) ||
		errors.Is(err, errExpiredToken) ||
		errors.Is(err, errUnknownToken) {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		slog.Error("Non-standard error path when attempting to access logged in information", "err", err)
		return
	}
	body, err := json.Marshal(WhoamiResponse{
		UserId:     userAuth.UserId,
		Roles:      userAuth.Roles,
		LogoutTime: userAuth.Expires,
	})
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		slog.Error("could not write userAuth response", "err", err)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.Write(body)
}

// priveleges lists the permission flags a user may hold.
type priveleges struct {
	anonymous   bool
	viewer      bool
	borrower    bool
	lender      bool
	serverAdmin bool
	addBooks    bool
	editBooks   bool
	deleteBooks bool
	roleGranter bool
}

// UserAuth describes an authenticated request: the user, their roles, when
// the session expires and, if populated, the privileges those roles grant.
type UserAuth struct {
	UserId     int64
	Roles      []model.Role
	Expires    time.Time
	priveleges map[model.Role]bool
}

// Can reports whether ua holds the privilege for role. It is false for every
// role unless the privileges were populated (Auth with needsPriveleges=true).
func (ua UserAuth) Can(role model.Role) bool {
	// Indexing a nil map yields the zero value, false.
	return ua.priveleges[role]
}

// findInList reports whether role is in list.
func findInList(list []model.Role, role model.Role) bool {
	return slices.Contains(list, role)
}

// populatePrivelges fills ua.priveleges from ua.Roles. RoleAnonymous is always
// granted, and RoleServerAdmin implies every other role. It needs a pointer
// receiver: with a value receiver the map would be stored on a copy and lost.
func (ua *UserAuth) populatePrivelges() {
	isAdmin := findInList(ua.Roles, model.RoleServerAdmin)
	has := func(role model.Role) bool { return isAdmin || findInList(ua.Roles, role) }
	ua.priveleges = map[model.Role]bool{
		model.RoleAnonymous:   true,
		model.RoleViewer:      has(model.RoleViewer),
		model.RoleBorrower:    has(model.RoleBorrower),
		model.RoleLender:      has(model.RoleLender),
		model.RoleServerAdmin: isAdmin,
		model.RoleAddBooks:    has(model.RoleAddBooks),
		model.RoleEditBooks:   has(model.RoleEditBooks),
		model.RoleDeleteBooks: has(model.RoleDeleteBooks),
		model.RoleRoleGranter: has(model.RoleRoleGranter),
	}
}

var (
	errFailGetAuthCookie = fmt.Errorf("fail fetching auth cookie")
	errNoAuthCookie      = fmt.Errorf("no auth cookie")
	errBadQuery          = func(err error) error { return errors.Wrap(err, "error forming sql query") }
	errExpiredToken      = fmt.Errorf("token expired, need to log in again")
	// errUnknownToken means the cookie's token has no session row, e.g. it was
	// deleted by Logout.
	errUnknownToken = fmt.Errorf("unknown session token")
)

const qToken = "SELECT users_id,token,expires FROM session WHERE token=? LIMIT 1;"

// Auth resolves the request's "auth" cookie to a UserAuth by looking up the
// session token and the user's roles. When needsPriveleges is true, the
// privilege map is populated too, so Can gives real answers. w is unused.
//
// Errors: errFailGetAuthCookie (cookie missing), errNoAuthCookie (empty
// value), errUnknownToken (no matching session), errExpiredToken, or a wrapped
// database error.
func (api RestApi) Auth(w http.ResponseWriter, r *http.Request, needsPriveleges bool) (*UserAuth, error) {
	authCookie, err := r.Cookie("auth")
	if err != nil {
		return nil, errFailGetAuthCookie
	}
	// A cleared cookie carries an empty value; treat it as absent rather than
	// querying for token "".
	if authCookie == nil || authCookie.Value == "" {
		return nil, errNoAuthCookie
	}

	var whoamiToken WhoamiToken
	err = api.db.QueryRowContext(r.Context(), qToken, authCookie.Value).
		Scan(&whoamiToken.UsersId, &whoamiToken.Token, &whoamiToken.Expires)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, errUnknownToken
	}
	if err != nil {
		return nil, fmt.Errorf("fail selecting whoami token: %w", err)
	}
	if whoamiToken.Expires.Before(time.Now()) {
		return nil, errExpiredToken
	}

	userId := whoamiToken.UsersId
	roles, err := getUserRoles(api.db, strconv.FormatInt(userId, 10))
	if err != nil {
		return nil, fmt.Errorf("could not get roles for userId %d: %w", userId, err)
	}
	userAuth := UserAuth{UserId: userId, Roles: roles, Expires: whoamiToken.Expires}
	if needsPriveleges {
		userAuth.populatePrivelges()
	}
	return &userAuth, nil
}

const getUserRolesQuery = "SELECT r.role FROM role AS r INNER JOIN users_role AS ur ON ur.role_id = r.id AND ur.users_id = ?;"

// getUserRoles returns the roles granted to userId (nil when there are none).
// The result rows are always closed, and iteration errors are reported.
func getUserRoles(db *sql.DB, userId string) ([]model.Role, error) {
	rolesRows, err := db.Query(getUserRolesQuery, userId)
	if err != nil {
		return nil, err
	}
	defer rolesRows.Close()

	var roles []model.Role
	for rolesRows.Next() {
		var role model.Role
		if err := rolesRows.Scan(&role); err != nil {
			return nil, err
		}
		roles = append(roles, role)
	}
	if err := rolesRows.Err(); err != nil {
		return nil, err
	}
	return roles, nil
}

// AuthCtxKey is the context key under which a *UserAuth may be stored.
var AuthCtxKey = &contextKey{"userAuth"}

type contextKey struct {
	name string
}

const qDelToken = "DELETE FROM session WHERE token = ?;"

// err500 logs msg with err via the request logger and responds 500 when err
// is non-nil. It reports whether it did so, so callers can return early.
func err500(err error, msg string, w http.ResponseWriter, r *http.Request) bool {
	if err == nil {
		return false
	}
	logging.FromReq(r).Error(msg, "err", err)
	w.WriteHeader(http.StatusInternalServerError)
	return true
}

// Logout deletes the session named by the "auth" cookie (if any), tells the
// client to drop the cookie, and responds 200. A request without a session
// cookie is already logged out and also gets 200; a database failure gets 500.
func (api RestApi) Logout(w http.ResponseWriter, r *http.Request) {
	authCookie, err := r.Cookie("auth")
	if errors.Is(err, http.ErrNoCookie) || (err == nil && authCookie.Value == "") {
		clearSession(w)
		w.WriteHeader(http.StatusOK)
		return
	}
	if err500(err, "fail getting session from cookie to delete", w, r) {
		return
	}

	_, err = api.db.ExecContext(r.Context(), qDelToken, authCookie.Value)
	if err500(err, "fail deleting session", w, r) {
		return
	}
	clearSession(w)
	w.WriteHeader(http.StatusOK)
}