Autocommit test+ui polish
This commit is contained in:
@@ -203,6 +203,7 @@ func main() {
|
||||
protected.Use(s.authMiddleware)
|
||||
protected.HandleFunc("/auth/me", s.handleMe).Methods(http.MethodGet)
|
||||
protected.HandleFunc("/user/preferences", s.handleSetPreferences).Methods(http.MethodPost)
|
||||
protected.HandleFunc("/user/google/link/start", s.handleGoogleLinkStart).Methods(http.MethodGet)
|
||||
protected.HandleFunc("/user/protocols", s.handleUserProtocols).Methods(http.MethodGet)
|
||||
protected.HandleFunc("/files", s.handleListFiles).Methods(http.MethodGet)
|
||||
protected.HandleFunc("/files/upload", s.handleUpload).Methods(http.MethodPost)
|
||||
|
||||
@@ -8,9 +8,14 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
const googleOAuthStateCookie = "google_oauth_state"
|
||||
const (
|
||||
googleOAuthLoginStateCookie = "google_oauth_state"
|
||||
googleOAuthLinkStateCookie = "google_oauth_link_state"
|
||||
)
|
||||
|
||||
type googleTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
@@ -33,7 +38,39 @@ func (s *Server) handleGoogleAuthStart(w http.ResponseWriter, r *http.Request) {
|
||||
writeErr(w, http.StatusInternalServerError, "failed to initialize oauth")
|
||||
return
|
||||
}
|
||||
setCookie(w, googleOAuthStateCookie, state, 600, s.config.CookieSecure)
|
||||
setCookie(w, googleOAuthLoginStateCookie, state, 600, s.config.CookieSecure)
|
||||
clearCookie(w, googleOAuthLinkStateCookie, s.config.CookieSecure)
|
||||
|
||||
u, err := url.Parse(s.config.GoogleAuthURL)
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusInternalServerError, "invalid google auth config")
|
||||
return
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("client_id", strings.TrimSpace(s.config.GoogleClientID))
|
||||
q.Set("redirect_uri", s.googleRedirectURL(r))
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", "openid email profile")
|
||||
q.Set("state", state)
|
||||
q.Set("prompt", "select_account")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
http.Redirect(w, r, u.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
func (s *Server) handleGoogleLinkStart(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.config.GoogleAuthEnabled {
|
||||
writeErr(w, http.StatusNotFound, "google auth is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
state, err := randomToken()
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusInternalServerError, "failed to initialize oauth")
|
||||
return
|
||||
}
|
||||
setCookie(w, googleOAuthLinkStateCookie, state, 600, s.config.CookieSecure)
|
||||
clearCookie(w, googleOAuthLoginStateCookie, s.config.CookieSecure)
|
||||
|
||||
u, err := url.Parse(s.config.GoogleAuthURL)
|
||||
if err != nil {
|
||||
@@ -70,12 +107,11 @@ func (s *Server) handleGoogleAuthCallback(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
state := strings.TrimSpace(r.URL.Query().Get("state"))
|
||||
stateCookie, err := r.Cookie(googleOAuthStateCookie)
|
||||
if err != nil || stateCookie == nil || stateCookie.Value == "" || subtleConstantTimeEq(stateCookie.Value, state) == 0 {
|
||||
flow, uid, ok := s.resolveGoogleOAuthFlow(w, r, state)
|
||||
if !ok {
|
||||
writeErr(w, http.StatusUnauthorized, "invalid oauth state")
|
||||
return
|
||||
}
|
||||
clearCookie(w, googleOAuthStateCookie, s.config.CookieSecure)
|
||||
|
||||
token, err := s.exchangeGoogleCode(r.Context(), code, s.googleRedirectURL(r))
|
||||
if err != nil {
|
||||
@@ -91,6 +127,17 @@ func (s *Server) handleGoogleAuthCallback(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
if flow == "link" {
|
||||
if err := s.linkGoogleSubToUser(uid, info.Sub); err != nil {
|
||||
log.Printf("auth.google.failed ip=%q reason=%q", clientIP(r), "user_link_failed")
|
||||
writeErr(w, http.StatusUnauthorized, "google account link failed")
|
||||
return
|
||||
}
|
||||
log.Printf("auth.google.linked user_id=%d ip=%q", uid, clientIP(r))
|
||||
http.Redirect(w, r, "/drive", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := s.findOrCreateGoogleUser(info.Sub, info.Email)
|
||||
if err != nil {
|
||||
log.Printf("auth.google.failed ip=%q reason=%q", clientIP(r), "user_provision_failed")
|
||||
@@ -108,6 +155,83 @@ func (s *Server) handleGoogleAuthCallback(w http.ResponseWriter, r *http.Request
|
||||
http.Redirect(w, r, "/drive", http.StatusFound)
|
||||
}
|
||||
|
||||
func (s *Server) resolveGoogleOAuthFlow(w http.ResponseWriter, r *http.Request, state string) (string, int64, bool) {
|
||||
state = strings.TrimSpace(state)
|
||||
if state == "" {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
if linkCookie, err := r.Cookie(googleOAuthLinkStateCookie); err == nil && linkCookie != nil && linkCookie.Value != "" {
|
||||
if subtleConstantTimeEq(linkCookie.Value, state) == 1 {
|
||||
clearCookie(w, googleOAuthLinkStateCookie, s.config.CookieSecure)
|
||||
uid, uidErr := s.userIDFromAccessCookie(r)
|
||||
if uidErr != nil {
|
||||
return "", 0, false
|
||||
}
|
||||
return "link", uid, true
|
||||
}
|
||||
}
|
||||
|
||||
if loginCookie, err := r.Cookie(googleOAuthLoginStateCookie); err == nil && loginCookie != nil && loginCookie.Value != "" {
|
||||
if subtleConstantTimeEq(loginCookie.Value, state) == 1 {
|
||||
clearCookie(w, googleOAuthLoginStateCookie, s.config.CookieSecure)
|
||||
return "login", 0, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func (s *Server) userIDFromAccessCookie(r *http.Request) (int64, error) {
|
||||
cookie, err := r.Cookie("access_token")
|
||||
if err != nil || cookie == nil || cookie.Value == "" {
|
||||
return 0, fmt.Errorf("missing access token")
|
||||
}
|
||||
|
||||
claims := &AccessClaims{}
|
||||
tkn, err := jwt.ParseWithClaims(cookie.Value, claims, func(token *jwt.Token) (any, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method")
|
||||
}
|
||||
return []byte(s.config.JWTSecret), nil
|
||||
})
|
||||
if err != nil || !tkn.Valid {
|
||||
return 0, fmt.Errorf("invalid access token")
|
||||
}
|
||||
if claims.UserID <= 0 {
|
||||
return 0, fmt.Errorf("invalid access token claims")
|
||||
}
|
||||
return claims.UserID, nil
|
||||
}
|
||||
|
||||
func (s *Server) linkGoogleSubToUser(userID int64, googleSub string) error {
|
||||
googleSub = strings.TrimSpace(googleSub)
|
||||
if userID <= 0 || googleSub == "" {
|
||||
return fmt.Errorf("invalid link request")
|
||||
}
|
||||
|
||||
if _, err := s.findUser(userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if existing, err := s.findUserByGoogleSub(googleSub); err == nil {
|
||||
if existing.ID != userID {
|
||||
return fmt.Errorf("google account already linked")
|
||||
}
|
||||
return nil
|
||||
} else if !isNoRows(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.orm.updateGoogleSub(userID, googleSub); err != nil {
|
||||
if strings.Contains(strings.ToLower(err.Error()), "unique") {
|
||||
return fmt.Errorf("google account already linked")
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) googleRedirectURL(r *http.Request) string {
|
||||
if v := strings.TrimSpace(s.config.GoogleRedirectURL); v != "" {
|
||||
return v
|
||||
|
||||
@@ -60,7 +60,7 @@ func TestGoogleOAuthCallbackCreatesSessionAndUser(t *testing.T) {
|
||||
if startRec.Code != http.StatusFound {
|
||||
t.Fatalf("start status = %d, want %d", startRec.Code, http.StatusFound)
|
||||
}
|
||||
stateCookie := cookieByName(startRec.Result().Cookies(), googleOAuthStateCookie)
|
||||
stateCookie := cookieByName(startRec.Result().Cookies(), googleOAuthLoginStateCookie)
|
||||
if stateCookie == nil || stateCookie.Value == "" {
|
||||
t.Fatal("missing oauth state cookie")
|
||||
}
|
||||
@@ -106,3 +106,84 @@ func TestGoogleOAuthCallbackCreatesSessionAndUser(t *testing.T) {
|
||||
t.Fatalf("google_sub = %q, want %q", googleSub, "google-sub-1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleOAuthCallbackLinksGoogleToExistingUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
provider := httptest.NewServer(mux)
|
||||
defer provider.Close()
|
||||
|
||||
mux.HandleFunc("/token", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"access_token": "google-link-token"})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/userinfo", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"sub": "google-sub-link-1",
|
||||
"email": "other@example.com",
|
||||
"email_verified": true,
|
||||
})
|
||||
})
|
||||
|
||||
s := makeTestServer(t, func(cfg *Config) {
|
||||
cfg.GoogleAuthEnabled = true
|
||||
cfg.GoogleClientID = "client-id"
|
||||
cfg.GoogleClientSecret = "client-secret"
|
||||
cfg.GoogleAuthURL = provider.URL + "/auth"
|
||||
cfg.GoogleTokenURL = provider.URL + "/token"
|
||||
cfg.GoogleUserInfoURL = provider.URL + "/userinfo"
|
||||
})
|
||||
|
||||
user, err := s.createUser("alice", "password123", "dracula", "auto")
|
||||
if err != nil {
|
||||
t.Fatalf("create user failed: %v", err)
|
||||
}
|
||||
|
||||
loginReq := httptest.NewRequest(http.MethodGet, "/api/auth/login-link", nil)
|
||||
loginRec := httptest.NewRecorder()
|
||||
if err := s.issueUserSession(loginRec, loginReq, user.ID); err != nil {
|
||||
t.Fatalf("issueUserSession failed: %v", err)
|
||||
}
|
||||
accessCookie := cookieByName(loginRec.Result().Cookies(), "access_token")
|
||||
if accessCookie == nil || accessCookie.Value == "" {
|
||||
t.Fatal("missing access cookie")
|
||||
}
|
||||
|
||||
startReq := httptest.NewRequest(http.MethodGet, "/api/user/google/link/start", nil)
|
||||
startReq.Host = "file.example.com"
|
||||
startReq.AddCookie(accessCookie)
|
||||
startRec := httptest.NewRecorder()
|
||||
s.handleGoogleLinkStart(startRec, startReq)
|
||||
|
||||
if startRec.Code != http.StatusFound {
|
||||
t.Fatalf("link start status = %d, want %d", startRec.Code, http.StatusFound)
|
||||
}
|
||||
stateCookie := cookieByName(startRec.Result().Cookies(), googleOAuthLinkStateCookie)
|
||||
if stateCookie == nil || stateCookie.Value == "" {
|
||||
t.Fatal("missing link state cookie")
|
||||
}
|
||||
|
||||
cbReq := httptest.NewRequest(http.MethodGet, "/api/auth/google/callback?code=ok-code&state="+url.QueryEscape(stateCookie.Value), nil)
|
||||
cbReq.Host = "file.example.com"
|
||||
cbReq.AddCookie(stateCookie)
|
||||
cbReq.AddCookie(accessCookie)
|
||||
cbRec := httptest.NewRecorder()
|
||||
s.handleGoogleAuthCallback(cbRec, cbReq)
|
||||
|
||||
if cbRec.Code != http.StatusFound {
|
||||
t.Fatalf("callback status = %d, want %d", cbRec.Code, http.StatusFound)
|
||||
}
|
||||
if got := cbRec.Header().Get("Location"); got != "/drive" {
|
||||
t.Fatalf("callback redirect = %q, want %q", got, "/drive")
|
||||
}
|
||||
|
||||
var googleSub string
|
||||
err = s.db.QueryRow(`SELECT COALESCE(google_sub, '') FROM users WHERE id = ?`, user.ID).Scan(&googleSub)
|
||||
if err != nil {
|
||||
t.Fatalf("query user google_sub failed: %v", err)
|
||||
}
|
||||
if strings.TrimSpace(googleSub) != "google-sub-link-1" {
|
||||
t.Fatalf("google_sub = %q, want %q", googleSub, "google-sub-link-1")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user