Add Google OAuth, German locale, and ORM-backed user access
This commit is contained in:
249
backend/api_ftp_test.go
Normal file
249
backend/api_ftp_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
ftpserver "goftp.io/server/v2"
|
||||
)
|
||||
|
||||
func makeTestServer(t *testing.T, mutate func(*Config)) *Server {
|
||||
t.Helper()
|
||||
|
||||
root := t.TempDir()
|
||||
cfg := Config{
|
||||
Addr: ":0",
|
||||
DBPath: filepath.Join(root, "app.db"),
|
||||
StorageRoot: filepath.Join(root, "users"),
|
||||
AppDomain: "file.example.com",
|
||||
AllowedHost: "file.example.com",
|
||||
CORSOrigin: "https://file.example.com",
|
||||
CookieSecure: false,
|
||||
MaxBodyBytes: 8 * 1024 * 1024,
|
||||
RateLimitPerMin: 1000,
|
||||
AuthRateLimitPerMin: 1000,
|
||||
JWTSecret: "test-jwt-secret-very-long-value-1234567890",
|
||||
AccessTTL: 15 * time.Minute,
|
||||
RefreshTTL: 24 * time.Hour,
|
||||
ShareDefaultTTL: 24 * time.Hour,
|
||||
AdminSessionTTL: 12 * time.Hour,
|
||||
AdminLogin: "admin",
|
||||
AdminPasswordHash: "sha256:dummy",
|
||||
}
|
||||
if mutate != nil {
|
||||
mutate(&cfg)
|
||||
}
|
||||
|
||||
db, err := openDB(cfg.DBPath)
|
||||
if err != nil {
|
||||
t.Fatalf("openDB failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
if err := migrate(db); err != nil {
|
||||
t.Fatalf("migrate failed: %v", err)
|
||||
}
|
||||
|
||||
storage, err := buildStorage(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("buildStorage failed: %v", err)
|
||||
}
|
||||
orm, err := newORMRepo(cfg.DBPath)
|
||||
if err != nil {
|
||||
t.Fatalf("newORMRepo failed: %v", err)
|
||||
}
|
||||
|
||||
return &Server{db: db, orm: orm, config: cfg, storage: storage, limiter: newRateLimiter()}
|
||||
}
|
||||
|
||||
func decodeJSONBody[T any](t *testing.T, res *http.Response, out *T) {
|
||||
t.Helper()
|
||||
defer res.Body.Close()
|
||||
if err := json.NewDecoder(res.Body).Decode(out); err != nil {
|
||||
t.Fatalf("decode json failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func cookieByName(cookies []*http.Cookie, name string) *http.Cookie {
|
||||
for _, c := range cookies {
|
||||
if c.Name == name {
|
||||
return c
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAPILoginRefreshAndMe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := makeTestServer(t, nil)
|
||||
user, err := s.createUser("alice", "password123", "dracula", "auto")
|
||||
if err != nil {
|
||||
t.Fatalf("createUser failed: %v", err)
|
||||
}
|
||||
|
||||
loginReq := httptest.NewRequest(http.MethodPost, "/api/auth/login", strings.NewReader(`{"username":"alice","password":"password123"}`))
|
||||
loginReq.Header.Set("Content-Type", "application/json")
|
||||
loginRec := httptest.NewRecorder()
|
||||
s.handleLogin(loginRec, loginReq)
|
||||
|
||||
loginRes := loginRec.Result()
|
||||
if loginRes.StatusCode != http.StatusOK {
|
||||
t.Fatalf("login status = %d, want %d", loginRes.StatusCode, http.StatusOK)
|
||||
}
|
||||
if cookieByName(loginRes.Cookies(), "access_token") == nil {
|
||||
t.Fatal("login did not set access_token cookie")
|
||||
}
|
||||
refreshCookie := cookieByName(loginRes.Cookies(), "refresh_token")
|
||||
if refreshCookie == nil {
|
||||
t.Fatal("login did not set refresh_token cookie")
|
||||
}
|
||||
|
||||
var meResp User
|
||||
decodeJSONBody(t, loginRes, &meResp)
|
||||
if meResp.ID != user.ID || meResp.Username != user.Username {
|
||||
t.Fatalf("login response user mismatch: got %+v want id=%d username=%q", meResp, user.ID, user.Username)
|
||||
}
|
||||
|
||||
refreshReq := httptest.NewRequest(http.MethodPost, "/api/auth/refresh", nil)
|
||||
refreshReq.AddCookie(refreshCookie)
|
||||
refreshRec := httptest.NewRecorder()
|
||||
s.handleRefresh(refreshRec, refreshReq)
|
||||
|
||||
refreshRes := refreshRec.Result()
|
||||
if refreshRes.StatusCode != http.StatusOK {
|
||||
t.Fatalf("refresh status = %d, want %d", refreshRes.StatusCode, http.StatusOK)
|
||||
}
|
||||
newAccess := cookieByName(refreshRes.Cookies(), "access_token")
|
||||
if newAccess == nil {
|
||||
t.Fatal("refresh did not rotate access_token cookie")
|
||||
}
|
||||
|
||||
meReq := httptest.NewRequest(http.MethodGet, "/api/auth/me", nil)
|
||||
meReq.AddCookie(newAccess)
|
||||
meRec := httptest.NewRecorder()
|
||||
s.authMiddleware(http.HandlerFunc(s.handleMe)).ServeHTTP(meRec, meReq)
|
||||
|
||||
if meRec.Code != http.StatusOK {
|
||||
t.Fatalf("/api/auth/me status = %d, want %d", meRec.Code, http.StatusOK)
|
||||
}
|
||||
var meAfter User
|
||||
if err := json.NewDecoder(meRec.Body).Decode(&meAfter); err != nil {
|
||||
t.Fatalf("decode /api/auth/me failed: %v", err)
|
||||
}
|
||||
if meAfter.Username != "alice" {
|
||||
t.Fatalf("/api/auth/me username = %q, want %q", meAfter.Username, "alice")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIUserProtocolsFTPS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := makeTestServer(t, func(cfg *Config) {
|
||||
cfg.FTPSEnabled = true
|
||||
cfg.FTPSHost = "0.0.0.0"
|
||||
cfg.FTPSPort = 2990
|
||||
cfg.FTPSPublicIP = "198.51.100.10"
|
||||
cfg.FTPSExplicit = true
|
||||
cfg.FTPSForceTLS = true
|
||||
})
|
||||
|
||||
if _, err := s.createUser("bob", "password123", "dracula", "auto"); err != nil {
|
||||
t.Fatalf("createUser failed: %v", err)
|
||||
}
|
||||
|
||||
loginReq := httptest.NewRequest(http.MethodPost, "/api/auth/login", strings.NewReader(`{"username":"bob","password":"password123"}`))
|
||||
loginReq.Header.Set("Content-Type", "application/json")
|
||||
loginRec := httptest.NewRecorder()
|
||||
s.handleLogin(loginRec, loginReq)
|
||||
if loginRec.Code != http.StatusOK {
|
||||
t.Fatalf("login status = %d", loginRec.Code)
|
||||
}
|
||||
access := cookieByName(loginRec.Result().Cookies(), "access_token")
|
||||
if access == nil {
|
||||
t.Fatal("missing access token cookie")
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/user/protocols", nil)
|
||||
req.AddCookie(access)
|
||||
rec := httptest.NewRecorder()
|
||||
s.authMiddleware(http.HandlerFunc(s.handleUserProtocols)).ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("protocols status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var out userProtocolsResponse
|
||||
if err := json.NewDecoder(rec.Body).Decode(&out); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
if out.FTPS == nil {
|
||||
t.Fatal("expected FTPS profile in response")
|
||||
}
|
||||
if out.FTPS.Username != "bob" {
|
||||
t.Fatalf("ftps username = %q, want %q", out.FTPS.Username, "bob")
|
||||
}
|
||||
if out.FTPS.Host != "198.51.100.10" || out.FTPS.Port != 2990 {
|
||||
t.Fatalf("ftps endpoint mismatch: got %s:%d", out.FTPS.Host, out.FTPS.Port)
|
||||
}
|
||||
if !out.FTPS.ExplicitTLS || !out.FTPS.ForceTLS {
|
||||
t.Fatal("expected explicit/forced TLS flags to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func insertUserWithHash(t *testing.T, db *sql.DB, username, hash string) int64 {
|
||||
t.Helper()
|
||||
res, err := db.Exec(`INSERT INTO users(email, password_hash, theme, color_mode, archive_format) VALUES (?, ?, 'dracula', 'auto', 'zip')`, username, hash)
|
||||
if err != nil {
|
||||
t.Fatalf("insert user failed: %v", err)
|
||||
}
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
t.Fatalf("LastInsertId failed: %v", err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func testFTPContextWithUserID(userID int64) *ftpserver.Context {
|
||||
return &ftpserver.Context{Sess: &ftpserver.Session{Data: map[string]interface{}{"filez_user_id": userID}}}
|
||||
}
|
||||
|
||||
func TestFTPDriverPlainTransfer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := makeTestServer(t, nil)
|
||||
hash, err := hashPasswordArgon2ID("password123")
|
||||
if err != nil {
|
||||
t.Fatalf("hash password failed: %v", err)
|
||||
}
|
||||
uid := insertUserWithHash(t, s.db, "dave", hash)
|
||||
|
||||
drv := &ftpUserDriver{db: s.db, root: s.config.StorageRoot}
|
||||
ctx := testFTPContextWithUserID(uid)
|
||||
|
||||
plain := []byte("plain ftp payload")
|
||||
if _, err := drv.PutFile(ctx, "/plain.txt", bytes.NewReader(plain), 0); err != nil {
|
||||
t.Fatalf("PutFile failed: %v", err)
|
||||
}
|
||||
|
||||
size, rc, err := drv.GetFile(ctx, "/plain.txt", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("GetFile failed: %v", err)
|
||||
}
|
||||
defer rc.Close()
|
||||
got, err := io.ReadAll(rc)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll failed: %v", err)
|
||||
}
|
||||
if size != int64(len(plain)) || !bytes.Equal(got, plain) {
|
||||
t.Fatalf("plain transfer mismatch: size=%d got=%q", size, string(got))
|
||||
}
|
||||
}
|
||||
218
backend/config.go
Normal file
218
backend/config.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Addr string
|
||||
DBPath string
|
||||
StorageRoot string
|
||||
AppDomain string
|
||||
AllowedHost string
|
||||
CORSOrigin string
|
||||
CookieSecure bool
|
||||
MaxBodyBytes int64
|
||||
|
||||
GoogleAuthEnabled bool
|
||||
GoogleClientID string
|
||||
GoogleClientSecret string
|
||||
GoogleRedirectURL string
|
||||
GoogleAuthURL string
|
||||
GoogleTokenURL string
|
||||
GoogleUserInfoURL string
|
||||
|
||||
RateLimitPerMin int
|
||||
AuthRateLimitPerMin int
|
||||
|
||||
JWTSecret string
|
||||
AccessTTL time.Duration
|
||||
RefreshTTL time.Duration
|
||||
ShareDefaultTTL time.Duration
|
||||
AdminSessionTTL time.Duration
|
||||
AdminLogin string
|
||||
AdminPasswordHash string
|
||||
|
||||
FTPEnabled bool
|
||||
FTPHost string
|
||||
FTPPort int
|
||||
FTPPublicIP string
|
||||
FTPPassivePorts string
|
||||
|
||||
FTPSEnabled bool
|
||||
FTPSHost string
|
||||
FTPSPort int
|
||||
FTPSPublicIP string
|
||||
FTPSPassivePorts string
|
||||
FTPSCertFile string
|
||||
FTPSKeyFile string
|
||||
FTPSLEDomain string
|
||||
FTPSLEDir string
|
||||
FTPSExplicit bool
|
||||
FTPSForceTLS bool
|
||||
|
||||
SFTPEnabled bool
|
||||
SFTPHost string
|
||||
SFTPPort int
|
||||
SFTPHostKeyPath string
|
||||
}
|
||||
|
||||
func loadConfig() Config {
|
||||
_ = godotenv.Load(".env", "../.env")
|
||||
|
||||
dbPath := getEnv("DB_PATH", "./app.db")
|
||||
storageRoot := getEnv("STORAGE_ROOT", "./users")
|
||||
if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
|
||||
log.Fatalf("failed to create db path parent: %v", err)
|
||||
}
|
||||
if err := os.MkdirAll(storageRoot, 0o755); err != nil {
|
||||
log.Fatalf("failed to create storage root: %v", err)
|
||||
}
|
||||
|
||||
cfg := Config{
|
||||
Addr: getEnv("ADDR", ":8080"),
|
||||
DBPath: dbPath,
|
||||
StorageRoot: storageRoot,
|
||||
AppDomain: strings.ToLower(strings.TrimSpace(getEnv("APP_DOMAIN", "file.example.com"))),
|
||||
AllowedHost: strings.ToLower(getEnv("ALLOWED_HOST", "")),
|
||||
CORSOrigin: getEnv("CORS_ALLOWED_ORIGIN", ""),
|
||||
CookieSecure: getEnv("COOKIE_SECURE", "false") == "true",
|
||||
MaxBodyBytes: int64(getEnvInt("MAX_BODY_MB", 8)) * 1024 * 1024,
|
||||
GoogleAuthEnabled: getEnv("GOOGLE_AUTH_ENABLED", "false") == "true",
|
||||
GoogleClientID: getEnv("GOOGLE_CLIENT_ID", ""),
|
||||
GoogleClientSecret: getEnv("GOOGLE_CLIENT_SECRET", ""),
|
||||
GoogleRedirectURL: getEnv("GOOGLE_REDIRECT_URL", ""),
|
||||
GoogleAuthURL: getEnv("GOOGLE_AUTH_URL", "https://accounts.google.com/o/oauth2/v2/auth"),
|
||||
GoogleTokenURL: getEnv("GOOGLE_TOKEN_URL", "https://oauth2.googleapis.com/token"),
|
||||
GoogleUserInfoURL: getEnv("GOOGLE_USERINFO_URL", "https://openidconnect.googleapis.com/v1/userinfo"),
|
||||
RateLimitPerMin: getEnvInt("RATE_LIMIT_PER_MIN", 240),
|
||||
AuthRateLimitPerMin: getEnvInt("AUTH_RATE_LIMIT_PER_MIN", 30),
|
||||
JWTSecret: getEnv("JWT_SECRET", "dev-change-me-immediately"),
|
||||
AccessTTL: 15 * time.Minute,
|
||||
RefreshTTL: 30 * 24 * time.Hour,
|
||||
ShareDefaultTTL: 24 * time.Hour,
|
||||
AdminSessionTTL: 12 * time.Hour,
|
||||
AdminLogin: getEnv("ADMIN_LOGIN", "admin"),
|
||||
AdminPasswordHash: getEnv("ADMIN_PASSWORD_HASH", ""),
|
||||
FTPEnabled: getEnv("FTP_ENABLED", "false") == "true",
|
||||
FTPHost: getEnv("FTP_HOST", "0.0.0.0"),
|
||||
FTPPort: getEnvInt("FTP_PORT", 2121),
|
||||
FTPPublicIP: getEnv("FTP_PUBLIC_IP", ""),
|
||||
FTPPassivePorts: getEnv("FTP_PASSIVE_PORTS", ""),
|
||||
FTPSEnabled: getEnv("FTPS_ENABLED", "false") == "true",
|
||||
FTPSHost: getEnv("FTPS_HOST", "0.0.0.0"),
|
||||
FTPSPort: getEnvInt("FTPS_PORT", 2990),
|
||||
FTPSPublicIP: getEnv("FTPS_PUBLIC_IP", ""),
|
||||
FTPSPassivePorts: getEnv("FTPS_PASSIVE_PORTS", ""),
|
||||
FTPSCertFile: getEnv("FTPS_CERT_FILE", ""),
|
||||
FTPSKeyFile: getEnv("FTPS_KEY_FILE", ""),
|
||||
FTPSLEDomain: strings.ToLower(strings.TrimSpace(getEnv("FTPS_LETSENCRYPT_DOMAIN", ""))),
|
||||
FTPSLEDir: getEnv("FTPS_LETSENCRYPT_DIR", "/etc/letsencrypt/live"),
|
||||
FTPSExplicit: getEnv("FTPS_EXPLICIT", "true") != "false",
|
||||
FTPSForceTLS: getEnv("FTPS_FORCE_TLS", "true") != "false",
|
||||
SFTPEnabled: getEnv("SFTP_ENABLED", "false") == "true",
|
||||
SFTPHost: getEnv("SFTP_HOST", "0.0.0.0"),
|
||||
SFTPPort: getEnvInt("SFTP_PORT", 2022),
|
||||
SFTPHostKeyPath: getEnv("SFTP_HOST_KEY_PATH", "./sftp_host_ed25519"),
|
||||
}
|
||||
|
||||
if cfg.AllowedHost == "" {
|
||||
cfg.AllowedHost = cfg.AppDomain
|
||||
}
|
||||
if cfg.CORSOrigin == "" {
|
||||
cfg.CORSOrigin = "https://" + cfg.AppDomain
|
||||
}
|
||||
if cfg.JWTSecret == "dev-change-me-immediately" {
|
||||
log.Println("warning: JWT_SECRET is using default development value")
|
||||
}
|
||||
if strings.TrimSpace(cfg.AdminPasswordHash) == "" {
|
||||
log.Fatal("ADMIN_PASSWORD_HASH is required. Generate one with: go run . hash-admin <password>")
|
||||
}
|
||||
if cfg.GoogleAuthEnabled {
|
||||
if strings.TrimSpace(cfg.GoogleClientID) == "" || strings.TrimSpace(cfg.GoogleClientSecret) == "" {
|
||||
log.Fatal("GOOGLE_AUTH_ENABLED=true requires GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET")
|
||||
}
|
||||
}
|
||||
if cfg.FTPEnabled {
|
||||
if cfg.FTPPort < 1 || cfg.FTPPort > 65535 {
|
||||
log.Fatal("FTP_PORT must be in range 1..65535")
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.FTPSEnabled {
|
||||
if cfg.FTPSPort < 1 || cfg.FTPSPort > 65535 {
|
||||
log.Fatal("FTPS_PORT must be in range 1..65535")
|
||||
}
|
||||
applyFTPSLetsEncryptDefaults(&cfg)
|
||||
if strings.TrimSpace(cfg.FTPSCertFile) == "" || strings.TrimSpace(cfg.FTPSKeyFile) == "" {
|
||||
log.Fatal("FTPS_ENABLED=true requires FTPS_CERT_FILE/FTPS_KEY_FILE or FTPS_LETSENCRYPT_DOMAIN")
|
||||
}
|
||||
if _, err := os.Stat(cfg.FTPSCertFile); err != nil {
|
||||
log.Fatalf("FTPS_CERT_FILE is invalid: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(cfg.FTPSKeyFile); err != nil {
|
||||
log.Fatalf("FTPS_KEY_FILE is invalid: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.FTPEnabled && cfg.FTPSEnabled {
|
||||
if cfg.FTPPort == cfg.FTPSPort && strings.EqualFold(cfg.FTPHost, cfg.FTPSHost) {
|
||||
log.Fatal("FTP and FTPS cannot share the same host:port")
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.SFTPEnabled {
|
||||
if cfg.SFTPPort < 1 || cfg.SFTPPort > 65535 {
|
||||
log.Fatal("SFTP_PORT must be in range 1..65535")
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func applyFTPSLetsEncryptDefaults(cfg *Config) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(cfg.FTPSLEDomain) == "" {
|
||||
return
|
||||
}
|
||||
base := strings.TrimSpace(cfg.FTPSLEDir)
|
||||
if base == "" {
|
||||
base = "/etc/letsencrypt/live"
|
||||
}
|
||||
domainDir := filepath.Join(base, cfg.FTPSLEDomain)
|
||||
if strings.TrimSpace(cfg.FTPSCertFile) == "" {
|
||||
cfg.FTPSCertFile = filepath.Join(domainDir, "fullchain.pem")
|
||||
}
|
||||
if strings.TrimSpace(cfg.FTPSKeyFile) == "" {
|
||||
cfg.FTPSKeyFile = filepath.Join(domainDir, "privkey.pem")
|
||||
}
|
||||
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
v := strings.TrimSpace(os.Getenv(key))
|
||||
if v == "" {
|
||||
return fallback
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func getEnvInt(key string, fallback int) int {
|
||||
v := strings.TrimSpace(os.Getenv(key))
|
||||
if v == "" {
|
||||
return fallback
|
||||
}
|
||||
n, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return n
|
||||
}
|
||||
33
backend/config_test.go
Normal file
33
backend/config_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package main
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestApplyFTPSLetsEncryptDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := Config{FTPSLEDomain: "files.example.com"}
|
||||
applyFTPSLetsEncryptDefaults(&cfg)
|
||||
if cfg.FTPSCertFile != "/etc/letsencrypt/live/files.example.com/fullchain.pem" {
|
||||
t.Fatalf("unexpected cert path: %q", cfg.FTPSCertFile)
|
||||
}
|
||||
if cfg.FTPSKeyFile != "/etc/letsencrypt/live/files.example.com/privkey.pem" {
|
||||
t.Fatalf("unexpected key path: %q", cfg.FTPSKeyFile)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyFTPSLetsEncryptDefaultsCustomDirAndPreserveManual(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := Config{
|
||||
FTPSLEDomain: "files.example.com",
|
||||
FTPSLEDir: "/var/lib/acme/live",
|
||||
FTPSCertFile: "/custom/cert.pem",
|
||||
}
|
||||
applyFTPSLetsEncryptDefaults(&cfg)
|
||||
if cfg.FTPSCertFile != "/custom/cert.pem" {
|
||||
t.Fatalf("manual cert should be preserved, got %q", cfg.FTPSCertFile)
|
||||
}
|
||||
if cfg.FTPSKeyFile != "/var/lib/acme/live/files.example.com/privkey.pem" {
|
||||
t.Fatalf("unexpected key path: %q", cfg.FTPSKeyFile)
|
||||
}
|
||||
}
|
||||
@@ -3,24 +3,29 @@ module driveclone
|
||||
go 1.25.0
|
||||
|
||||
require (
|
||||
github.com/glebarez/sqlite v1.11.0
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1
|
||||
github.com/gorilla/mux v1.8.1
|
||||
github.com/hirochachacha/go-smb2 v1.1.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
goftp.io/server/v2 v2.0.2
|
||||
golang.org/x/crypto v0.48.0
|
||||
modernc.org/sqlite v1.46.1
|
||||
gorm.io/gorm v1.31.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/geoffgarside/ber v1.1.0 // indirect
|
||||
github.com/glebarez/go-sqlite v1.21.2 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
modernc.org/libc v1.67.6 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.23.1 // indirect
|
||||
)
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/geoffgarside/ber v1.1.0 h1:qTmFG4jJbwiSzSXoNJeHcOprVzZ8Ulde2Rrrifu5U9w=
|
||||
github.com/geoffgarside/ber v1.1.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
|
||||
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
|
||||
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
|
||||
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
|
||||
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/hirochachacha/go-smb2 v1.1.0 h1:b6hs9qKIql9eVXAiN0M2wSFY5xnhbHAQoCwRKbaRTZI=
|
||||
github.com/hirochachacha/go-smb2 v1.1.0/go.mod h1:8F1A4d5EZzrGu5R7PU163UcMRDJQl4FtcxjBfsY8TZE=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
@@ -22,25 +26,25 @@ github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOF
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
goftp.io/server/v2 v2.0.2 h1:tkZpqyXys+vC15W5yGMi8Kzmbv1QSgeKr8qJXBnJbm8=
|
||||
goftp.io/server/v2 v2.0.2/go.mod h1:Fl1WdcV7fx1pjOWx7jEHb7tsJ8VwE7+xHu6bVJ6r2qg=
|
||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
|
||||
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
|
||||
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
|
||||
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
||||
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
|
||||
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
|
||||
@@ -63,8 +67,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU=
|
||||
modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
|
||||
modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM=
|
||||
modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
|
||||
429
backend/main.go
429
backend/main.go
@@ -33,48 +33,16 @@ import (
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/hirochachacha/go-smb2"
|
||||
"github.com/joho/godotenv"
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Addr string
|
||||
DBPath string
|
||||
StorageRoot string
|
||||
AppDomain string
|
||||
AllowedHost string
|
||||
CORSOrigin string
|
||||
CookieSecure bool
|
||||
MaxBodyBytes int64
|
||||
|
||||
RateLimitPerMin int
|
||||
AuthRateLimitPerMin int
|
||||
|
||||
JWTSecret string
|
||||
AccessTTL time.Duration
|
||||
RefreshTTL time.Duration
|
||||
ShareDefaultTTL time.Duration
|
||||
AdminSessionTTL time.Duration
|
||||
AdminLogin string
|
||||
AdminPasswordHash string
|
||||
StorageBackend string
|
||||
SMBHost string
|
||||
SMBShare string
|
||||
SMBUser string
|
||||
SMBPass string
|
||||
SMBDomain string
|
||||
SMBBasePath string
|
||||
SMBConnectTimout time.Duration
|
||||
}
|
||||
|
||||
//go:embed web/dist
|
||||
var embeddedWeb embed.FS
|
||||
|
||||
type Server struct {
|
||||
db *sql.DB
|
||||
orm *ormRepo
|
||||
config Config
|
||||
storage Storage
|
||||
limiter *rateLimiter
|
||||
@@ -198,7 +166,16 @@ func main() {
|
||||
log.Fatalf("storage init failed: %v", err)
|
||||
}
|
||||
|
||||
s := &Server{db: db, config: cfg, storage: storage, limiter: newRateLimiter()}
|
||||
orm, err := newORMRepo(cfg.DBPath)
|
||||
if err != nil {
|
||||
log.Fatalf("orm init failed: %v", err)
|
||||
}
|
||||
|
||||
if err := startProtocolServers(cfg, db); err != nil {
|
||||
log.Fatalf("protocol init failed: %v", err)
|
||||
}
|
||||
|
||||
s := &Server{db: db, orm: orm, config: cfg, storage: storage, limiter: newRateLimiter()}
|
||||
r := mux.NewRouter()
|
||||
r.Use(s.recoverMiddleware)
|
||||
r.Use(s.securityHeadersMiddleware)
|
||||
@@ -212,6 +189,8 @@ func main() {
|
||||
|
||||
r.HandleFunc("/api/auth/register", s.handleRegisterDisabled).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/auth/login", s.handleLogin).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/auth/google/start", s.handleGoogleAuthStart).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/auth/google/callback", s.handleGoogleAuthCallback).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/auth/refresh", s.handleRefresh).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/auth/logout", s.handleLogout).Methods(http.MethodPost)
|
||||
|
||||
@@ -224,6 +203,7 @@ func main() {
|
||||
protected.Use(s.authMiddleware)
|
||||
protected.HandleFunc("/auth/me", s.handleMe).Methods(http.MethodGet)
|
||||
protected.HandleFunc("/user/preferences", s.handleSetPreferences).Methods(http.MethodPost)
|
||||
protected.HandleFunc("/user/protocols", s.handleUserProtocols).Methods(http.MethodGet)
|
||||
protected.HandleFunc("/files", s.handleListFiles).Methods(http.MethodGet)
|
||||
protected.HandleFunc("/files/upload", s.handleUpload).Methods(http.MethodPost)
|
||||
protected.HandleFunc("/files/download", s.handleDownload).Methods(http.MethodGet, http.MethodHead)
|
||||
@@ -232,6 +212,7 @@ func main() {
|
||||
protected.HandleFunc("/files/preview", s.handlePreview).Methods(http.MethodGet, http.MethodHead)
|
||||
protected.HandleFunc("/files/text", s.handleReadTextFile).Methods(http.MethodGet)
|
||||
protected.HandleFunc("/files/text", s.handleWriteTextFile).Methods(http.MethodPut)
|
||||
protected.HandleFunc("/files/rename", s.handleRename).Methods(http.MethodPost)
|
||||
protected.HandleFunc("/files", s.handleDelete).Methods(http.MethodDelete)
|
||||
protected.HandleFunc("/files/folder", s.handleCreateFolder).Methods(http.MethodPost)
|
||||
protected.HandleFunc("/files/share", s.handleCreateShareLink).Methods(http.MethodPost)
|
||||
@@ -260,62 +241,6 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
func loadConfig() Config {
|
||||
_ = godotenv.Load(".env", "../.env")
|
||||
|
||||
dbPath := getEnv("DB_PATH", "./app.db")
|
||||
storageRoot := getEnv("STORAGE_ROOT", "./users")
|
||||
if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
|
||||
log.Fatalf("failed to create db path parent: %v", err)
|
||||
}
|
||||
if err := os.MkdirAll(storageRoot, 0o755); err != nil {
|
||||
log.Fatalf("failed to create storage root: %v", err)
|
||||
}
|
||||
|
||||
cfg := Config{
|
||||
Addr: getEnv("ADDR", ":8080"),
|
||||
DBPath: dbPath,
|
||||
StorageRoot: storageRoot,
|
||||
AppDomain: strings.ToLower(strings.TrimSpace(getEnv("APP_DOMAIN", "file.example.com"))),
|
||||
AllowedHost: strings.ToLower(getEnv("ALLOWED_HOST", "")),
|
||||
CORSOrigin: getEnv("CORS_ALLOWED_ORIGIN", ""),
|
||||
CookieSecure: getEnv("COOKIE_SECURE", "false") == "true",
|
||||
MaxBodyBytes: int64(getEnvInt("MAX_BODY_MB", 8)) * 1024 * 1024,
|
||||
RateLimitPerMin: getEnvInt("RATE_LIMIT_PER_MIN", 240),
|
||||
AuthRateLimitPerMin: getEnvInt("AUTH_RATE_LIMIT_PER_MIN", 30),
|
||||
JWTSecret: getEnv("JWT_SECRET", "dev-change-me-immediately"),
|
||||
AccessTTL: 15 * time.Minute,
|
||||
RefreshTTL: 30 * 24 * time.Hour,
|
||||
ShareDefaultTTL: 24 * time.Hour,
|
||||
AdminSessionTTL: 12 * time.Hour,
|
||||
AdminLogin: getEnv("ADMIN_LOGIN", "admin"),
|
||||
AdminPasswordHash: getEnv("ADMIN_PASSWORD_HASH", ""),
|
||||
StorageBackend: getEnv("STORAGE_BACKEND", "local"),
|
||||
SMBHost: getEnv("SMB_HOST", ""),
|
||||
SMBShare: getEnv("SMB_SHARE", ""),
|
||||
SMBUser: getEnv("SMB_USER", ""),
|
||||
SMBPass: getEnv("SMB_PASS", ""),
|
||||
SMBDomain: getEnv("SMB_DOMAIN", ""),
|
||||
SMBBasePath: getEnv("SMB_BASE_PATH", "driveflow"),
|
||||
SMBConnectTimout: 5 * time.Second,
|
||||
}
|
||||
|
||||
if cfg.AllowedHost == "" {
|
||||
cfg.AllowedHost = cfg.AppDomain
|
||||
}
|
||||
if cfg.CORSOrigin == "" {
|
||||
cfg.CORSOrigin = "https://" + cfg.AppDomain
|
||||
}
|
||||
if cfg.JWTSecret == "dev-change-me-immediately" {
|
||||
log.Println("warning: JWT_SECRET is using default development value")
|
||||
}
|
||||
if strings.TrimSpace(cfg.AdminPasswordHash) == "" {
|
||||
log.Fatal("ADMIN_PASSWORD_HASH is required. Generate one with: go run . hash-admin <password>")
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func openDB(path string) (*sql.DB, error) {
|
||||
db, err := sql.Open("sqlite", path)
|
||||
if err != nil {
|
||||
@@ -392,18 +317,19 @@ func migrate(db *sql.DB) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := db.Exec(`ALTER TABLE users ADD COLUMN google_sub TEXT`); err != nil {
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "duplicate column") {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_users_google_sub ON users(google_sub) WHERE google_sub IS NOT NULL`); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildStorage(cfg Config) (Storage, error) {
|
||||
if strings.EqualFold(cfg.StorageBackend, "smb") {
|
||||
if cfg.SMBHost == "" || cfg.SMBShare == "" || cfg.SMBUser == "" {
|
||||
return nil, fmt.Errorf("SMB_HOST, SMB_SHARE, SMB_USER must be set for smb backend")
|
||||
}
|
||||
return &SMBStorage{cfg: cfg}, nil
|
||||
}
|
||||
|
||||
root := cfg.StorageRoot
|
||||
if err := os.MkdirAll(root, 0o755); err != nil {
|
||||
return nil, err
|
||||
@@ -1686,6 +1612,67 @@ func (s *Server) handleCreateFolder(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusCreated, map[string]string{"status": "created"})
|
||||
}
|
||||
|
||||
func (s *Server) handleRename(w http.ResponseWriter, r *http.Request) {
|
||||
uid := userIDFromContext(r.Context())
|
||||
var in struct {
|
||||
Path string `json:"path"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "invalid payload")
|
||||
return
|
||||
}
|
||||
|
||||
src := normalizePath(in.Path)
|
||||
if src == "/" {
|
||||
writeErr(w, http.StatusBadRequest, "invalid path")
|
||||
return
|
||||
}
|
||||
meta, err := s.storage.Stat(uid, src)
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusBadRequest, "file not found")
|
||||
return
|
||||
}
|
||||
|
||||
name := path.Base(strings.TrimSpace(in.Name))
|
||||
if name == "" || name == "." || name == ".." {
|
||||
writeErr(w, http.StatusBadRequest, "invalid name")
|
||||
return
|
||||
}
|
||||
|
||||
dir := path.Dir(src)
|
||||
if dir == "." {
|
||||
dir = "/"
|
||||
}
|
||||
dst := normalizePath(path.Join(dir, name))
|
||||
if dst == src {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"status": "renamed", "path": src})
|
||||
return
|
||||
}
|
||||
if _, err := s.storage.Stat(uid, dst); err == nil {
|
||||
writeErr(w, http.StatusBadRequest, "target already exists")
|
||||
return
|
||||
}
|
||||
if meta.IsDir {
|
||||
srcPrefix := strings.TrimSuffix(src, "/") + "/"
|
||||
if dst == src || strings.HasPrefix(dst, srcPrefix) {
|
||||
writeErr(w, http.StatusBadRequest, "cannot rename folder into itself")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.copyPath(uid, src, dst); err != nil {
|
||||
writeErr(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
if err := s.storage.Delete(uid, src); err != nil {
|
||||
writeErr(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
s.moveTags(uid, src, dst)
|
||||
writeJSON(w, http.StatusOK, map[string]any{"status": "renamed", "path": dst})
|
||||
}
|
||||
|
||||
type shareInput struct {
|
||||
Path string `json:"path"`
|
||||
ExpiresMinutes int `json:"expiresMinutes"`
|
||||
@@ -1804,23 +1791,11 @@ func (s *Server) handleAdminMe(w http.ResponseWriter, _ *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) handleAdminUsersList(w http.ResponseWriter, _ *http.Request) {
|
||||
rows, err := s.db.Query(`SELECT id, email, theme, color_mode, archive_format FROM users ORDER BY id ASC`)
|
||||
users, err := s.orm.listUsers()
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusInternalServerError, "failed to load users")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
users := make([]User, 0)
|
||||
for rows.Next() {
|
||||
var u User
|
||||
if err := rows.Scan(&u.ID, &u.Username, &u.Theme, &u.ColorMode, &u.Archive); err != nil {
|
||||
continue
|
||||
}
|
||||
u.Theme = normalizeTheme(u.Theme)
|
||||
u.ColorMode = normalizeColorMode(u.ColorMode)
|
||||
users = append(users, u)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{"users": users})
|
||||
}
|
||||
@@ -1974,15 +1949,13 @@ func (s *Server) createUser(username, password, theme, colorMode string) (User,
|
||||
return User{}, fmt.Errorf("failed to hash password")
|
||||
}
|
||||
|
||||
res, err := s.db.Exec(`INSERT INTO users(email, password_hash, theme, color_mode, archive_format) VALUES (?, ?, ?, ?, ?)`, username, hash, normalizeTheme(theme), normalizeColorMode(colorMode), "zip")
|
||||
id, err := s.orm.createUser(username, hash, normalizeTheme(theme), normalizeColorMode(colorMode), "zip", nil)
|
||||
if err != nil {
|
||||
if strings.Contains(strings.ToLower(err.Error()), "unique") {
|
||||
return User{}, fmt.Errorf("account already exists")
|
||||
}
|
||||
return User{}, err
|
||||
}
|
||||
|
||||
id, _ := res.LastInsertId()
|
||||
if err := s.storage.Mkdir(id, "/"); err != nil {
|
||||
return User{}, fmt.Errorf("failed to provision user storage: %w", err)
|
||||
}
|
||||
@@ -1991,26 +1964,11 @@ func (s *Server) createUser(username, password, theme, colorMode string) (User,
|
||||
}
|
||||
|
||||
func (s *Server) findUserWithHash(username string) (User, string, error) {
|
||||
username = strings.ToLower(strings.TrimSpace(username))
|
||||
var u User
|
||||
var hash string
|
||||
err := s.db.QueryRow(`SELECT id, email, password_hash, theme, color_mode, archive_format FROM users WHERE email = ?`, username).
|
||||
Scan(&u.ID, &u.Username, &hash, &u.Theme, &u.ColorMode, &u.Archive)
|
||||
if err != nil {
|
||||
return User{}, "", err
|
||||
}
|
||||
u.Theme = normalizeTheme(u.Theme)
|
||||
u.ColorMode = normalizeColorMode(u.ColorMode)
|
||||
return u, hash, nil
|
||||
return s.orm.findUserWithHashByEmail(username)
|
||||
}
|
||||
|
||||
func (s *Server) findUser(id int64) (User, error) {
|
||||
var u User
|
||||
err := s.db.QueryRow(`SELECT id, email, theme, color_mode, archive_format FROM users WHERE id = ?`, id).
|
||||
Scan(&u.ID, &u.Username, &u.Theme, &u.ColorMode, &u.Archive)
|
||||
u.Theme = normalizeTheme(u.Theme)
|
||||
u.ColorMode = normalizeColorMode(u.ColorMode)
|
||||
return u, err
|
||||
return s.orm.findUserByID(id)
|
||||
}
|
||||
|
||||
func (s *Server) issueUserSession(w http.ResponseWriter, r *http.Request, userID int64) error {
|
||||
@@ -2290,197 +2248,6 @@ func (l *LocalStorage) OpenReadSeeker(userID int64, rel string) (ReadSeekCloser,
|
||||
return os.Open(full)
|
||||
}
|
||||
|
||||
type SMBStorage struct {
|
||||
cfg Config
|
||||
}
|
||||
|
||||
type smbConn struct {
|
||||
conn net.Conn
|
||||
session *smb2.Session
|
||||
share *smb2.Share
|
||||
}
|
||||
|
||||
func (c *smbConn) Close() {
|
||||
if c.share != nil {
|
||||
_ = c.share.Umount()
|
||||
}
|
||||
if c.session != nil {
|
||||
_ = c.session.Logoff()
|
||||
}
|
||||
if c.conn != nil {
|
||||
_ = c.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
type smbReadSeekCloser struct {
|
||||
file *smb2.File
|
||||
conn *smbConn
|
||||
}
|
||||
|
||||
func (s *smbReadSeekCloser) Read(p []byte) (int, error) {
|
||||
return s.file.Read(p)
|
||||
}
|
||||
|
||||
func (s *smbReadSeekCloser) Seek(offset int64, whence int) (int64, error) {
|
||||
return s.file.Seek(offset, whence)
|
||||
}
|
||||
|
||||
func (s *smbReadSeekCloser) Close() error {
|
||||
_ = s.file.Close()
|
||||
s.conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (smbs *SMBStorage) openConnection() (*smbConn, error) {
|
||||
conn, err := net.DialTimeout("tcp", smbs.cfg.SMBHost, smbs.cfg.SMBConnectTimout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialer := &smb2.Dialer{Initiator: &smb2.NTLMInitiator{
|
||||
User: smbs.cfg.SMBUser,
|
||||
Password: smbs.cfg.SMBPass,
|
||||
Domain: smbs.cfg.SMBDomain,
|
||||
}}
|
||||
|
||||
session, err := dialer.Dial(conn)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
share, err := session.Mount(smbs.cfg.SMBShare)
|
||||
if err != nil {
|
||||
_ = session.Logoff()
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smbConn{conn: conn, session: session, share: share}, nil
|
||||
}
|
||||
|
||||
func (smbs *SMBStorage) withShare(fn func(share *smb2.Share) error) error {
|
||||
conn, err := smbs.openConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
return fn(conn.share)
|
||||
}
|
||||
|
||||
func (smbs *SMBStorage) userPath(userID int64, rel string) string {
|
||||
base := path.Join(smbs.cfg.SMBBasePath, strconv.FormatInt(userID, 10))
|
||||
clean := strings.TrimPrefix(normalizePath(rel), "/")
|
||||
if clean == "" {
|
||||
return base
|
||||
}
|
||||
return path.Join(base, clean)
|
||||
}
|
||||
|
||||
func (smbs *SMBStorage) List(userID int64, rel string) ([]FileEntry, error) {
|
||||
out := make([]FileEntry, 0)
|
||||
err := smbs.withShare(func(share *smb2.Share) error {
|
||||
target := smbs.userPath(userID, rel)
|
||||
if err := share.MkdirAll(target, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
entries, err := share.ReadDir(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, e := range entries {
|
||||
out = append(out, FileEntry{
|
||||
Name: e.Name(),
|
||||
Path: path.Join(normalizePath(rel), e.Name()),
|
||||
IsDir: e.IsDir(),
|
||||
Size: e.Size(),
|
||||
ModTime: e.ModTime(),
|
||||
})
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return out, err
|
||||
}
|
||||
|
||||
func (smbs *SMBStorage) Mkdir(userID int64, rel string) error {
|
||||
return smbs.withShare(func(share *smb2.Share) error {
|
||||
return share.MkdirAll(smbs.userPath(userID, rel), 0o755)
|
||||
})
|
||||
}
|
||||
|
||||
func (smbs *SMBStorage) Save(userID int64, rel string, src multipart.File) error {
|
||||
return smbs.withShare(func(share *smb2.Share) error {
|
||||
target := smbs.userPath(userID, rel)
|
||||
if err := share.MkdirAll(path.Dir(target), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
f, err := share.Create(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
_, err = io.Copy(f, src)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (smbs *SMBStorage) SaveBytes(userID int64, rel string, data []byte) error {
|
||||
return smbs.withShare(func(share *smb2.Share) error {
|
||||
target := smbs.userPath(userID, rel)
|
||||
if err := share.MkdirAll(path.Dir(target), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
f, err := share.Create(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
_, err = f.Write(data)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (smbs *SMBStorage) Delete(userID int64, rel string) error {
|
||||
return smbs.withShare(func(share *smb2.Share) error {
|
||||
target := smbs.userPath(userID, rel)
|
||||
if normalizePath(rel) == "/" {
|
||||
return share.RemoveAll(target)
|
||||
}
|
||||
if err := share.Remove(target); err == nil {
|
||||
return nil
|
||||
}
|
||||
return share.RemoveAll(target)
|
||||
})
|
||||
}
|
||||
|
||||
func (smbs *SMBStorage) Stat(userID int64, rel string) (FileMeta, error) {
|
||||
var out FileMeta
|
||||
err := smbs.withShare(func(share *smb2.Share) error {
|
||||
st, err := share.Stat(smbs.userPath(userID, rel))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
out = FileMeta{Name: st.Name(), Size: st.Size(), ModTime: st.ModTime(), IsDir: st.IsDir()}
|
||||
return nil
|
||||
})
|
||||
return out, err
|
||||
}
|
||||
|
||||
func (smbs *SMBStorage) OpenReadSeeker(userID int64, rel string) (ReadSeekCloser, error) {
|
||||
conn, err := smbs.openConnection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := conn.share.Open(smbs.userPath(userID, rel))
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smbReadSeekCloser{file: f, conn: conn}, nil
|
||||
}
|
||||
|
||||
func normalizePath(rel string) string {
|
||||
clean := path.Clean("/" + strings.TrimSpace(rel))
|
||||
if clean == "." {
|
||||
@@ -2792,26 +2559,6 @@ func writeErr(w http.ResponseWriter, status int, msg string) {
|
||||
writeJSON(w, status, map[string]string{"error": msg})
|
||||
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
v := strings.TrimSpace(os.Getenv(key))
|
||||
if v == "" {
|
||||
return fallback
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func getEnvInt(key string, fallback int) int {
|
||||
v := strings.TrimSpace(os.Getenv(key))
|
||||
if v == "" {
|
||||
return fallback
|
||||
}
|
||||
n, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func limitString(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
|
||||
247
backend/oauth_google.go
Normal file
247
backend/oauth_google.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const googleOAuthStateCookie = "google_oauth_state"
|
||||
|
||||
type googleTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
|
||||
type googleUserInfo struct {
|
||||
Sub string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
}
|
||||
|
||||
func (s *Server) handleGoogleAuthStart(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.config.GoogleAuthEnabled {
|
||||
writeErr(w, http.StatusNotFound, "google auth is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
state, err := randomToken()
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusInternalServerError, "failed to initialize oauth")
|
||||
return
|
||||
}
|
||||
setCookie(w, googleOAuthStateCookie, state, 600, s.config.CookieSecure)
|
||||
|
||||
u, err := url.Parse(s.config.GoogleAuthURL)
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusInternalServerError, "invalid google auth config")
|
||||
return
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("client_id", strings.TrimSpace(s.config.GoogleClientID))
|
||||
q.Set("redirect_uri", s.googleRedirectURL(r))
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", "openid email profile")
|
||||
q.Set("state", state)
|
||||
q.Set("prompt", "select_account")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
http.Redirect(w, r, u.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
func (s *Server) handleGoogleAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.config.GoogleAuthEnabled {
|
||||
writeErr(w, http.StatusNotFound, "google auth is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
if oauthErr := strings.TrimSpace(r.URL.Query().Get("error")); oauthErr != "" {
|
||||
writeErr(w, http.StatusUnauthorized, "google login was denied")
|
||||
return
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(r.URL.Query().Get("code"))
|
||||
if code == "" {
|
||||
writeErr(w, http.StatusBadRequest, "missing oauth code")
|
||||
return
|
||||
}
|
||||
|
||||
state := strings.TrimSpace(r.URL.Query().Get("state"))
|
||||
stateCookie, err := r.Cookie(googleOAuthStateCookie)
|
||||
if err != nil || stateCookie == nil || stateCookie.Value == "" || subtleConstantTimeEq(stateCookie.Value, state) == 0 {
|
||||
writeErr(w, http.StatusUnauthorized, "invalid oauth state")
|
||||
return
|
||||
}
|
||||
clearCookie(w, googleOAuthStateCookie, s.config.CookieSecure)
|
||||
|
||||
token, err := s.exchangeGoogleCode(r.Context(), code, s.googleRedirectURL(r))
|
||||
if err != nil {
|
||||
log.Printf("auth.google.failed ip=%q reason=%q", clientIP(r), "token_exchange_failed")
|
||||
writeErr(w, http.StatusUnauthorized, "google auth failed")
|
||||
return
|
||||
}
|
||||
|
||||
info, err := s.fetchGoogleUserInfo(r.Context(), token)
|
||||
if err != nil {
|
||||
log.Printf("auth.google.failed ip=%q reason=%q", clientIP(r), "userinfo_failed")
|
||||
writeErr(w, http.StatusUnauthorized, "google auth failed")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := s.findOrCreateGoogleUser(info.Sub, info.Email)
|
||||
if err != nil {
|
||||
log.Printf("auth.google.failed ip=%q reason=%q", clientIP(r), "user_provision_failed")
|
||||
writeErr(w, http.StatusUnauthorized, "google auth failed")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.issueUserSession(w, r, user.ID); err != nil {
|
||||
log.Printf("auth.google.failed ip=%q reason=%q", clientIP(r), "session_issue_failed")
|
||||
writeErr(w, http.StatusInternalServerError, "failed to create session")
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("auth.google.success user_id=%d username=%q ip=%q", user.ID, user.Username, clientIP(r))
|
||||
http.Redirect(w, r, "/drive", http.StatusFound)
|
||||
}
|
||||
|
||||
func (s *Server) googleRedirectURL(r *http.Request) string {
|
||||
if v := strings.TrimSpace(s.config.GoogleRedirectURL); v != "" {
|
||||
return v
|
||||
}
|
||||
return fmt.Sprintf("%s://%s/api/auth/google/callback", schemeOf(r), r.Host)
|
||||
}
|
||||
|
||||
func (s *Server) exchangeGoogleCode(ctx context.Context, code, redirectURI string) (string, error) {
|
||||
values := url.Values{}
|
||||
values.Set("code", code)
|
||||
values.Set("client_id", strings.TrimSpace(s.config.GoogleClientID))
|
||||
values.Set("client_secret", s.config.GoogleClientSecret)
|
||||
values.Set("redirect_uri", redirectURI)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.config.GoogleTokenURL, strings.NewReader(values.Encode()))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("google token endpoint returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var out googleTokenResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(out.AccessToken) == "" {
|
||||
return "", fmt.Errorf("google token response missing access_token")
|
||||
}
|
||||
return out.AccessToken, nil
|
||||
}
|
||||
|
||||
func (s *Server) fetchGoogleUserInfo(ctx context.Context, accessToken string) (googleUserInfo, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.config.GoogleUserInfoURL, nil)
|
||||
if err != nil {
|
||||
return googleUserInfo{}, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return googleUserInfo{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return googleUserInfo{}, fmt.Errorf("google userinfo endpoint returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var out googleUserInfo
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
return googleUserInfo{}, err
|
||||
}
|
||||
out.Sub = strings.TrimSpace(out.Sub)
|
||||
out.Email = strings.ToLower(strings.TrimSpace(out.Email))
|
||||
if out.Sub == "" || out.Email == "" || !out.EmailVerified {
|
||||
return googleUserInfo{}, fmt.Errorf("google account data is incomplete")
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Server) findOrCreateGoogleUser(googleSub, email string) (User, error) {
|
||||
googleSub = strings.TrimSpace(googleSub)
|
||||
email = strings.ToLower(strings.TrimSpace(email))
|
||||
if googleSub == "" || email == "" {
|
||||
return User{}, fmt.Errorf("invalid google identity")
|
||||
}
|
||||
|
||||
if user, err := s.findUserByGoogleSub(googleSub); err == nil {
|
||||
return user, nil
|
||||
} else if !isNoRows(err) {
|
||||
return User{}, err
|
||||
}
|
||||
|
||||
if user, err := s.findUserByEmail(email); err == nil {
|
||||
if err := s.orm.updateGoogleSub(user.ID, googleSub); err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
return user, nil
|
||||
} else if !isNoRows(err) {
|
||||
return User{}, err
|
||||
}
|
||||
|
||||
hash, err := hashPasswordArgon2ID(mustRandomPassword())
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
|
||||
id, err := s.orm.createUser(email, hash, "dracula", "auto", "zip", &googleSub)
|
||||
if err != nil {
|
||||
if strings.Contains(strings.ToLower(err.Error()), "unique") {
|
||||
if user, findErr := s.findUserByEmail(email); findErr == nil {
|
||||
if linkErr := s.orm.updateGoogleSub(user.ID, googleSub); linkErr != nil {
|
||||
return User{}, linkErr
|
||||
}
|
||||
return s.findUser(user.ID)
|
||||
}
|
||||
}
|
||||
return User{}, err
|
||||
}
|
||||
if err := s.storage.Mkdir(id, "/"); err != nil {
|
||||
return User{}, fmt.Errorf("failed to provision user storage: %w", err)
|
||||
}
|
||||
return User{ID: id, Username: email, Theme: "dracula", ColorMode: "auto", Archive: "zip"}, nil
|
||||
}
|
||||
|
||||
func (s *Server) findUserByGoogleSub(googleSub string) (User, error) {
|
||||
return s.orm.findUserByGoogleSub(googleSub)
|
||||
}
|
||||
|
||||
func (s *Server) findUserByEmail(email string) (User, error) {
|
||||
return s.orm.findUserByEmail(email)
|
||||
}
|
||||
|
||||
func mustRandomPassword() string {
|
||||
tok, err := randomToken()
|
||||
if err != nil {
|
||||
return "google-oauth-password-fallback"
|
||||
}
|
||||
if len(tok) < 16 {
|
||||
return tok + "-google-oauth"
|
||||
}
|
||||
return tok
|
||||
}
|
||||
|
||||
func isNoRows(err error) bool {
|
||||
return isORMNotFound(err)
|
||||
}
|
||||
108
backend/oauth_google_test.go
Normal file
108
backend/oauth_google_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGoogleOAuthCallbackCreatesSessionAndUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
provider := httptest.NewServer(mux)
|
||||
defer provider.Close()
|
||||
|
||||
mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("parse token form failed: %v", err)
|
||||
}
|
||||
if got := r.Form.Get("code"); got != "ok-code" {
|
||||
t.Fatalf("code = %q, want %q", got, "ok-code")
|
||||
}
|
||||
if got := r.Form.Get("client_id"); got != "client-id" {
|
||||
t.Fatalf("client_id = %q, want %q", got, "client-id")
|
||||
}
|
||||
if got := r.Form.Get("client_secret"); got != "client-secret" {
|
||||
t.Fatalf("client_secret = %q, want %q", got, "client-secret")
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"access_token": "google-access-token"})
|
||||
})
|
||||
|
||||
mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer google-access-token" {
|
||||
t.Fatalf("authorization = %q, want bearer token", got)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"sub": "google-sub-1",
|
||||
"email": "alice@example.com",
|
||||
"email_verified": true,
|
||||
})
|
||||
})
|
||||
|
||||
s := makeTestServer(t, func(cfg *Config) {
|
||||
cfg.GoogleAuthEnabled = true
|
||||
cfg.GoogleClientID = "client-id"
|
||||
cfg.GoogleClientSecret = "client-secret"
|
||||
cfg.GoogleAuthURL = provider.URL + "/auth"
|
||||
cfg.GoogleTokenURL = provider.URL + "/token"
|
||||
cfg.GoogleUserInfoURL = provider.URL + "/userinfo"
|
||||
})
|
||||
|
||||
startReq := httptest.NewRequest(http.MethodGet, "/api/auth/google/start", nil)
|
||||
startReq.Host = "file.example.com"
|
||||
startRec := httptest.NewRecorder()
|
||||
s.handleGoogleAuthStart(startRec, startReq)
|
||||
|
||||
if startRec.Code != http.StatusFound {
|
||||
t.Fatalf("start status = %d, want %d", startRec.Code, http.StatusFound)
|
||||
}
|
||||
stateCookie := cookieByName(startRec.Result().Cookies(), googleOAuthStateCookie)
|
||||
if stateCookie == nil || stateCookie.Value == "" {
|
||||
t.Fatal("missing oauth state cookie")
|
||||
}
|
||||
|
||||
redir := startRec.Result().Header.Get("Location")
|
||||
parsed, err := url.Parse(redir)
|
||||
if err != nil {
|
||||
t.Fatalf("parse redirect url failed: %v", err)
|
||||
}
|
||||
if parsed.Query().Get("state") != stateCookie.Value {
|
||||
t.Fatalf("redirect state mismatch")
|
||||
}
|
||||
|
||||
cbReq := httptest.NewRequest(http.MethodGet, "/api/auth/google/callback?code=ok-code&state="+url.QueryEscape(stateCookie.Value), nil)
|
||||
cbReq.Host = "file.example.com"
|
||||
cbReq.AddCookie(stateCookie)
|
||||
cbRec := httptest.NewRecorder()
|
||||
s.handleGoogleAuthCallback(cbRec, cbReq)
|
||||
|
||||
if cbRec.Code != http.StatusFound {
|
||||
t.Fatalf("callback status = %d, want %d", cbRec.Code, http.StatusFound)
|
||||
}
|
||||
if got := cbRec.Header().Get("Location"); got != "/drive" {
|
||||
t.Fatalf("callback redirect = %q, want %q", got, "/drive")
|
||||
}
|
||||
if cookieByName(cbRec.Result().Cookies(), "access_token") == nil {
|
||||
t.Fatal("callback missing access_token cookie")
|
||||
}
|
||||
if cookieByName(cbRec.Result().Cookies(), "refresh_token") == nil {
|
||||
t.Fatal("callback missing refresh_token cookie")
|
||||
}
|
||||
|
||||
var count int
|
||||
var googleSub string
|
||||
err = s.db.QueryRow(`SELECT COUNT(*), COALESCE(MAX(google_sub), '') FROM users WHERE email = ?`, "alice@example.com").Scan(&count, &googleSub)
|
||||
if err != nil {
|
||||
t.Fatalf("query user failed: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("users with google email = %d, want 1", count)
|
||||
}
|
||||
if strings.TrimSpace(googleSub) != "google-sub-1" {
|
||||
t.Fatalf("google_sub = %q, want %q", googleSub, "google-sub-1")
|
||||
}
|
||||
}
|
||||
122
backend/orm.go
Normal file
122
backend/orm.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ormUser struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
Email string `gorm:"column:email"`
|
||||
PasswordHash string `gorm:"column:password_hash"`
|
||||
Theme string `gorm:"column:theme"`
|
||||
ColorMode string `gorm:"column:color_mode"`
|
||||
Archive string `gorm:"column:archive_format"`
|
||||
GoogleSub *string `gorm:"column:google_sub"`
|
||||
}
|
||||
|
||||
func (ormUser) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
type ormRepo struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func newORMRepo(dbPath string) (*ormRepo, error) {
|
||||
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ormRepo{db: db}, nil
|
||||
}
|
||||
|
||||
func (o *ormRepo) listUsers() ([]User, error) {
|
||||
var rows []ormUser
|
||||
if err := o.db.Order("id asc").Find(&rows).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users := make([]User, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
users = append(users, row.toPublicUser())
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (o *ormRepo) createUser(email, passwordHash, theme, colorMode, archive string, googleSub *string) (int64, error) {
|
||||
row := ormUser{
|
||||
Email: strings.ToLower(strings.TrimSpace(email)),
|
||||
PasswordHash: passwordHash,
|
||||
Theme: normalizeTheme(theme),
|
||||
ColorMode: normalizeColorMode(colorMode),
|
||||
Archive: archive,
|
||||
GoogleSub: googleSub,
|
||||
}
|
||||
if err := o.db.Create(&row).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return row.ID, nil
|
||||
}
|
||||
|
||||
func (o *ormRepo) findUserWithHashByEmail(email string) (User, string, error) {
|
||||
var row ormUser
|
||||
err := o.db.Where("email = ?", strings.ToLower(strings.TrimSpace(email))).First(&row).Error
|
||||
if err != nil {
|
||||
return User{}, "", err
|
||||
}
|
||||
user := row.toPublicUser()
|
||||
return user, row.PasswordHash, nil
|
||||
}
|
||||
|
||||
func (o *ormRepo) findUserByID(id int64) (User, error) {
|
||||
var row ormUser
|
||||
err := o.db.Where("id = ?", id).First(&row).Error
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
return row.toPublicUser(), nil
|
||||
}
|
||||
|
||||
func (o *ormRepo) findUserByEmail(email string) (User, error) {
|
||||
var row ormUser
|
||||
err := o.db.Where("email = ?", strings.ToLower(strings.TrimSpace(email))).First(&row).Error
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
return row.toPublicUser(), nil
|
||||
}
|
||||
|
||||
func (o *ormRepo) findUserByGoogleSub(googleSub string) (User, error) {
|
||||
var row ormUser
|
||||
err := o.db.Where("google_sub = ?", strings.TrimSpace(googleSub)).First(&row).Error
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
return row.toPublicUser(), nil
|
||||
}
|
||||
|
||||
func (o *ormRepo) updateGoogleSub(userID int64, googleSub string) error {
|
||||
return o.db.Model(&ormUser{}).Where("id = ?", userID).Update("google_sub", strings.TrimSpace(googleSub)).Error
|
||||
}
|
||||
|
||||
func (u ormUser) toPublicUser() User {
|
||||
archive := normalizeArchiveFormat(u.Archive)
|
||||
if archive == "" {
|
||||
archive = "zip"
|
||||
}
|
||||
|
||||
return User{
|
||||
ID: u.ID,
|
||||
Username: strings.ToLower(strings.TrimSpace(u.Email)),
|
||||
Theme: normalizeTheme(u.Theme),
|
||||
ColorMode: normalizeColorMode(u.ColorMode),
|
||||
Archive: archive,
|
||||
}
|
||||
}
|
||||
|
||||
func isORMNotFound(err error) bool {
|
||||
return errors.Is(err, gorm.ErrRecordNotFound)
|
||||
}
|
||||
356
backend/protocol_ftp.go
Normal file
356
backend/protocol_ftp.go
Normal file
@@ -0,0 +1,356 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
ftpserver "goftp.io/server/v2"
|
||||
)
|
||||
|
||||
func startProtocolServers(cfg Config, db *sql.DB) error {
|
||||
if cfg.SFTPEnabled {
|
||||
return fmt.Errorf("SFTP_ENABLED=true is configured, but SFTP server is not implemented")
|
||||
}
|
||||
|
||||
if cfg.FTPEnabled {
|
||||
ftpSrv, err := buildFTPServer(cfg, db)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build ftp server: %w", err)
|
||||
}
|
||||
go runFTPServer("FTP", ftpSrv)
|
||||
}
|
||||
|
||||
if cfg.FTPSEnabled {
|
||||
ftpsSrv, err := buildFTPSServer(cfg, db)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build ftps server: %w", err)
|
||||
}
|
||||
go runFTPServer("FTPS", ftpsSrv)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runFTPServer(name string, srv *ftpserver.Server) {
|
||||
log.Printf("%s listening on %s:%d", name, srv.Hostname, srv.Port)
|
||||
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, ftpserver.ErrServerClosed) {
|
||||
log.Printf("%s server stopped: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
type ftpUserDriver struct {
|
||||
db *sql.DB
|
||||
root string
|
||||
}
|
||||
|
||||
func buildFTPServer(cfg Config, db *sql.DB) (*ftpserver.Server, error) {
|
||||
return buildFileZFTPServer(ftpServerOptions{
|
||||
name: "FileZ FTP",
|
||||
host: cfg.FTPHost,
|
||||
port: cfg.FTPPort,
|
||||
publicIP: cfg.FTPPublicIP,
|
||||
passivePorts: cfg.FTPPassivePorts,
|
||||
tls: false,
|
||||
}, db, cfg.StorageRoot)
|
||||
}
|
||||
|
||||
func buildFTPSServer(cfg Config, db *sql.DB) (*ftpserver.Server, error) {
|
||||
return buildFileZFTPServer(ftpServerOptions{
|
||||
name: "FileZ FTPS",
|
||||
host: cfg.FTPSHost,
|
||||
port: cfg.FTPSPort,
|
||||
publicIP: cfg.FTPSPublicIP,
|
||||
passivePorts: cfg.FTPSPassivePorts,
|
||||
tls: true,
|
||||
certFile: cfg.FTPSCertFile,
|
||||
keyFile: cfg.FTPSKeyFile,
|
||||
explicitFTPS: cfg.FTPSExplicit,
|
||||
forceTLS: cfg.FTPSForceTLS,
|
||||
}, db, cfg.StorageRoot)
|
||||
}
|
||||
|
||||
type ftpServerOptions struct {
|
||||
name string
|
||||
host string
|
||||
port int
|
||||
publicIP string
|
||||
passivePorts string
|
||||
tls bool
|
||||
certFile string
|
||||
keyFile string
|
||||
explicitFTPS bool
|
||||
forceTLS bool
|
||||
}
|
||||
|
||||
func buildFileZFTPServer(opts ftpServerOptions, db *sql.DB, storageRoot string) (*ftpserver.Server, error) {
|
||||
drv := &ftpUserDriver{db: db, root: storageRoot}
|
||||
serverOpts := &ftpserver.Options{
|
||||
Name: opts.name,
|
||||
Hostname: opts.host,
|
||||
Port: opts.port,
|
||||
PublicIP: opts.publicIP,
|
||||
PassivePorts: opts.passivePorts,
|
||||
WelcomeMessage: opts.name + " ready",
|
||||
Driver: drv,
|
||||
Auth: drv,
|
||||
Perm: ftpserver.NewSimplePerm("filez", "filez"),
|
||||
TLS: opts.tls,
|
||||
CertFile: opts.certFile,
|
||||
KeyFile: opts.keyFile,
|
||||
ExplicitFTPS: opts.explicitFTPS,
|
||||
ForceTLS: opts.forceTLS,
|
||||
}
|
||||
return ftpserver.NewServer(serverOpts)
|
||||
}
|
||||
|
||||
func (d *ftpUserDriver) CheckPasswd(ctx *ftpserver.Context, username, password string) (bool, error) {
|
||||
username = strings.TrimSpace(username)
|
||||
if username == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var userID int64
|
||||
var hash string
|
||||
err := d.db.QueryRow(`SELECT id, password_hash FROM users WHERE email = ?`, username).Scan(&userID, &hash)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
if !verifyPasswordHash(hash, password) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(d.userRoot(userID), 0o755); err != nil {
|
||||
return false, err
|
||||
}
|
||||
d.setUserID(ctx, userID)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (d *ftpUserDriver) userRoot(userID int64) string {
|
||||
return filepath.Join(d.root, strconv.FormatInt(userID, 10))
|
||||
}
|
||||
|
||||
func (d *ftpUserDriver) setUserID(ctx *ftpserver.Context, userID int64) {
|
||||
if ctx == nil || ctx.Sess == nil {
|
||||
return
|
||||
}
|
||||
if ctx.Sess.Data == nil {
|
||||
ctx.Sess.Data = make(map[string]interface{})
|
||||
}
|
||||
ctx.Sess.Data["filez_user_id"] = userID
|
||||
}
|
||||
|
||||
func (d *ftpUserDriver) userIDFromCtx(ctx *ftpserver.Context) (int64, error) {
|
||||
if ctx == nil || ctx.Sess == nil {
|
||||
return 0, fmt.Errorf("missing ftp session")
|
||||
}
|
||||
|
||||
if v, ok := ctx.Sess.Data["filez_user_id"]; ok {
|
||||
switch id := v.(type) {
|
||||
case int64:
|
||||
if id > 0 {
|
||||
return id, nil
|
||||
}
|
||||
case int:
|
||||
if id > 0 {
|
||||
return int64(id), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
username := strings.TrimSpace(ctx.Sess.LoginUser())
|
||||
if username == "" {
|
||||
return 0, fmt.Errorf("missing ftp username")
|
||||
}
|
||||
|
||||
var userID int64
|
||||
if err := d.db.QueryRow(`SELECT id FROM users WHERE email = ?`, username).Scan(&userID); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, fmt.Errorf("user not found")
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
d.setUserID(ctx, userID)
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func (d *ftpUserDriver) fullPath(ctx *ftpserver.Context, rel string) (string, error) {
|
||||
userID, err := d.userIDFromCtx(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
root := d.userRoot(userID)
|
||||
if err := os.MkdirAll(root, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
clean := filepath.FromSlash(strings.TrimPrefix(normalizePath(rel), "/"))
|
||||
full := filepath.Clean(filepath.Join(root, clean))
|
||||
if full != root && !strings.HasPrefix(full, root+string(os.PathSeparator)) {
|
||||
return "", fmt.Errorf("invalid path")
|
||||
}
|
||||
return full, nil
|
||||
}
|
||||
|
||||
func (d *ftpUserDriver) Stat(ctx *ftpserver.Context, rel string) (os.FileInfo, error) {
|
||||
full, err := d.fullPath(ctx, rel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return os.Stat(full)
|
||||
}
|
||||
|
||||
func (d *ftpUserDriver) ListDir(ctx *ftpserver.Context, rel string, callback func(os.FileInfo) error) error {
|
||||
full, err := d.fullPath(ctx, rel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
entries, err := os.ReadDir(full)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, entry := range entries {
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if err := callback(info); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *ftpUserDriver) DeleteDir(ctx *ftpserver.Context, rel string) error {
|
||||
if normalizePath(rel) == "/" {
|
||||
return fmt.Errorf("cannot remove root directory")
|
||||
}
|
||||
full, err := d.fullPath(ctx, rel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
st, err := os.Stat(full)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !st.IsDir() {
|
||||
return fmt.Errorf("not a directory")
|
||||
}
|
||||
return os.RemoveAll(full)
|
||||
}
|
||||
|
||||
func (d *ftpUserDriver) DeleteFile(ctx *ftpserver.Context, rel string) error {
|
||||
full, err := d.fullPath(ctx, rel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
st, err := os.Stat(full)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if st.IsDir() {
|
||||
return fmt.Errorf("not a file")
|
||||
}
|
||||
return os.Remove(full)
|
||||
}
|
||||
|
||||
func (d *ftpUserDriver) Rename(ctx *ftpserver.Context, fromPath, toPath string) error {
|
||||
if normalizePath(fromPath) == "/" {
|
||||
return fmt.Errorf("cannot rename root directory")
|
||||
}
|
||||
oldPath, err := d.fullPath(ctx, fromPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newPath, err := d.fullPath(ctx, toPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Rename(oldPath, newPath)
|
||||
}
|
||||
|
||||
func (d *ftpUserDriver) MakeDir(ctx *ftpserver.Context, rel string) error {
|
||||
full, err := d.fullPath(ctx, rel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.MkdirAll(full, 0o755)
|
||||
}
|
||||
|
||||
func (d *ftpUserDriver) GetFile(ctx *ftpserver.Context, rel string, offset int64) (int64, io.ReadCloser, error) {
|
||||
full, err := d.fullPath(ctx, rel)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
f, err := os.Open(full)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = f.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
st, err := f.Stat()
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
if _, err := f.Seek(offset, io.SeekStart); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
sz := st.Size() - offset
|
||||
if sz < 0 {
|
||||
sz = 0
|
||||
}
|
||||
return sz, f, nil
|
||||
}
|
||||
|
||||
func (d *ftpUserDriver) PutFile(ctx *ftpserver.Context, rel string, data io.Reader, offset int64) (int64, error) {
|
||||
full, err := d.fullPath(ctx, rel)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if offset < 0 {
|
||||
f, err := os.Create(full)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
return io.Copy(f, data)
|
||||
}
|
||||
|
||||
flags := os.O_CREATE | os.O_WRONLY
|
||||
if offset == 0 {
|
||||
flags |= os.O_TRUNC
|
||||
}
|
||||
f, err := os.OpenFile(full, flags, 0o644)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
if _, err := f.Seek(offset, io.SeekStart); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return io.Copy(f, data)
|
||||
}
|
||||
71
backend/protocol_info.go
Normal file
71
backend/protocol_info.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type protocolProfile struct {
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username"`
|
||||
PublicIP string `json:"publicIP,omitempty"`
|
||||
PassivePorts string `json:"passivePorts,omitempty"`
|
||||
ExplicitTLS bool `json:"explicitTLS,omitempty"`
|
||||
ForceTLS bool `json:"forceTLS,omitempty"`
|
||||
}
|
||||
|
||||
type userProtocolsResponse struct {
|
||||
FTP *protocolProfile `json:"ftp,omitempty"`
|
||||
FTPS *protocolProfile `json:"ftps,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Server) handleUserProtocols(w http.ResponseWriter, r *http.Request) {
|
||||
uid := userIDFromContext(r.Context())
|
||||
user, err := s.findUser(uid)
|
||||
if err != nil {
|
||||
writeErr(w, http.StatusNotFound, "user not found")
|
||||
return
|
||||
}
|
||||
|
||||
out := userProtocolsResponse{}
|
||||
if s.config.FTPEnabled {
|
||||
out.FTP = &protocolProfile{
|
||||
Host: protocolHostForClient(s.config.FTPHost, s.config.FTPPublicIP, s.config.AppDomain, r),
|
||||
Port: s.config.FTPPort,
|
||||
Username: user.Username,
|
||||
PublicIP: strings.TrimSpace(s.config.FTPPublicIP),
|
||||
PassivePorts: strings.TrimSpace(s.config.FTPPassivePorts),
|
||||
}
|
||||
}
|
||||
if s.config.FTPSEnabled {
|
||||
out.FTPS = &protocolProfile{
|
||||
Host: protocolHostForClient(s.config.FTPSHost, s.config.FTPSPublicIP, s.config.AppDomain, r),
|
||||
Port: s.config.FTPSPort,
|
||||
Username: user.Username,
|
||||
PublicIP: strings.TrimSpace(s.config.FTPSPublicIP),
|
||||
PassivePorts: strings.TrimSpace(s.config.FTPSPassivePorts),
|
||||
ExplicitTLS: s.config.FTPSExplicit,
|
||||
ForceTLS: s.config.FTPSForceTLS,
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, out)
|
||||
}
|
||||
|
||||
func protocolHostForClient(bindHost, publicIP, appDomain string, r *http.Request) string {
|
||||
if v := strings.TrimSpace(publicIP); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(appDomain); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(hostOnly(r.Host)); v != "" {
|
||||
return v
|
||||
}
|
||||
v := strings.TrimSpace(bindHost)
|
||||
if v == "" || v == "0.0.0.0" || v == "::" || v == "[::]" {
|
||||
return "localhost"
|
||||
}
|
||||
return v
|
||||
}
|
||||
Reference in New Issue
Block a user