Files
ZFile/backend/oauth_google_test.go

109 lines
3.4 KiB
Go

package main
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
// TestGoogleOAuthCallbackCreatesSessionAndUser drives the Google OAuth flow
// against a fake provider served by httptest.
//
// It checks that the start handler answers 302 with a state cookie whose value
// matches the "state" query parameter of the redirect. It then replays that
// state to the callback handler with code "ok-code". The callback must redirect
// to /drive, set access_token and refresh_token cookies, and leave exactly one
// users row for alice@example.com with google_sub "google-sub-1".
func TestGoogleOAuthCallbackCreatesSessionAndUser(t *testing.T) {
	t.Parallel()

	// Fake Google provider. Its handlers run on the httptest server's own
	// goroutines, where t.Fatal/t.FailNow must not be called (they are only
	// valid on the test goroutine). Handlers therefore report with t.Errorf and
	// reply with an HTTP error. The callback then fails, and the assertions
	// below stop the test from the test goroutine.
	mux := http.NewServeMux()
	provider := httptest.NewServer(mux)
	defer provider.Close()

	mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseForm(); err != nil {
			t.Errorf("parse token form failed: %v", err)
			http.Error(w, "bad token form", http.StatusBadRequest)
			return
		}
		// Checked in a fixed order so failure output is deterministic.
		for _, field := range []struct{ key, want string }{
			{"code", "ok-code"},
			{"client_id", "client-id"},
			{"client_secret", "client-secret"},
		} {
			if got := r.Form.Get(field.key); got != field.want {
				t.Errorf("%s = %q, want %q", field.key, got, field.want)
				http.Error(w, "unexpected "+field.key, http.StatusBadRequest)
				return
			}
		}
		// Without an explicit type, net/http would sniff this body as text/plain.
		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(map[string]any{"access_token": "google-access-token"}); err != nil {
			t.Errorf("encode token response: %v", err)
		}
	})

	mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
		if got := r.Header.Get("Authorization"); got != "Bearer google-access-token" {
			t.Errorf("authorization = %q, want bearer token", got)
			http.Error(w, "unexpected authorization", http.StatusUnauthorized)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(map[string]any{
			"sub":            "google-sub-1",
			"email":          "alice@example.com",
			"email_verified": true,
		}); err != nil {
			t.Errorf("encode userinfo response: %v", err)
		}
	})

	s := makeTestServer(t, func(cfg *Config) {
		cfg.GoogleAuthEnabled = true
		cfg.GoogleClientID = "client-id"
		cfg.GoogleClientSecret = "client-secret"
		cfg.GoogleAuthURL = provider.URL + "/auth"
		cfg.GoogleTokenURL = provider.URL + "/token"
		cfg.GoogleUserInfoURL = provider.URL + "/userinfo"
	})

	// Step 1: start the flow and capture the state cookie plus redirect.
	startReq := httptest.NewRequest(http.MethodGet, "/api/auth/google/start", nil)
	startReq.Host = "file.example.com"
	startRec := httptest.NewRecorder()
	s.handleGoogleAuthStart(startRec, startReq)
	if startRec.Code != http.StatusFound {
		t.Fatalf("start status = %d, want %d", startRec.Code, http.StatusFound)
	}
	startRes := startRec.Result()
	defer startRes.Body.Close()

	stateCookie := cookieByName(startRes.Cookies(), googleOAuthStateCookie)
	if stateCookie == nil || stateCookie.Value == "" {
		t.Fatal("missing oauth state cookie")
	}
	parsed, err := url.Parse(startRes.Header.Get("Location"))
	if err != nil {
		t.Fatalf("parse redirect url failed: %v", err)
	}
	if got := parsed.Query().Get("state"); got != stateCookie.Value {
		t.Fatalf("redirect state = %q, want cookie value %q", got, stateCookie.Value)
	}

	// Step 2: return from the provider with the same state and an auth code.
	cbReq := httptest.NewRequest(http.MethodGet, "/api/auth/google/callback?code=ok-code&state="+url.QueryEscape(stateCookie.Value), nil)
	cbReq.Host = "file.example.com"
	cbReq.AddCookie(stateCookie)
	cbRec := httptest.NewRecorder()
	s.handleGoogleAuthCallback(cbRec, cbReq)
	if cbRec.Code != http.StatusFound {
		t.Fatalf("callback status = %d, want %d", cbRec.Code, http.StatusFound)
	}
	if got := cbRec.Header().Get("Location"); got != "/drive" {
		t.Fatalf("callback redirect = %q, want %q", got, "/drive")
	}
	cbRes := cbRec.Result()
	defer cbRes.Body.Close()

	cbCookies := cbRes.Cookies()
	if cookieByName(cbCookies, "access_token") == nil {
		t.Fatal("callback missing access_token cookie")
	}
	if cookieByName(cbCookies, "refresh_token") == nil {
		t.Fatal("callback missing refresh_token cookie")
	}

	// Step 3: exactly one user row for the email, linked to the provider subject.
	var count int
	var googleSub string
	err = s.db.QueryRow(`SELECT COUNT(*), COALESCE(MAX(google_sub), '') FROM users WHERE email = ?`, "alice@example.com").Scan(&count, &googleSub)
	if err != nil {
		t.Fatalf("query user failed: %v", err)
	}
	if count != 1 {
		t.Fatalf("users with google email = %d, want 1", count)
	}
	if strings.TrimSpace(googleSub) != "google-sub-1" {
		t.Fatalf("google_sub = %q, want %q", googleSub, "google-sub-1")
	}
}