Add Google OAuth, German locale, and ORM-backed user access

This commit is contained in:
mixa
2026-03-04 18:51:33 +03:00
parent 2fab944351
commit 1bdeddb2ff
26 changed files with 2488 additions and 583 deletions

249
backend/api_ftp_test.go Normal file
View File

@@ -0,0 +1,249 @@
package main
import (
"bytes"
"database/sql"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"time"
ftpserver "goftp.io/server/v2"
)
func makeTestServer(t *testing.T, mutate func(*Config)) *Server {
t.Helper()
root := t.TempDir()
cfg := Config{
Addr: ":0",
DBPath: filepath.Join(root, "app.db"),
StorageRoot: filepath.Join(root, "users"),
AppDomain: "file.example.com",
AllowedHost: "file.example.com",
CORSOrigin: "https://file.example.com",
CookieSecure: false,
MaxBodyBytes: 8 * 1024 * 1024,
RateLimitPerMin: 1000,
AuthRateLimitPerMin: 1000,
JWTSecret: "test-jwt-secret-very-long-value-1234567890",
AccessTTL: 15 * time.Minute,
RefreshTTL: 24 * time.Hour,
ShareDefaultTTL: 24 * time.Hour,
AdminSessionTTL: 12 * time.Hour,
AdminLogin: "admin",
AdminPasswordHash: "sha256:dummy",
}
if mutate != nil {
mutate(&cfg)
}
db, err := openDB(cfg.DBPath)
if err != nil {
t.Fatalf("openDB failed: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
if err := migrate(db); err != nil {
t.Fatalf("migrate failed: %v", err)
}
storage, err := buildStorage(cfg)
if err != nil {
t.Fatalf("buildStorage failed: %v", err)
}
orm, err := newORMRepo(cfg.DBPath)
if err != nil {
t.Fatalf("newORMRepo failed: %v", err)
}
return &Server{db: db, orm: orm, config: cfg, storage: storage, limiter: newRateLimiter()}
}
func decodeJSONBody[T any](t *testing.T, res *http.Response, out *T) {
t.Helper()
defer res.Body.Close()
if err := json.NewDecoder(res.Body).Decode(out); err != nil {
t.Fatalf("decode json failed: %v", err)
}
}
func cookieByName(cookies []*http.Cookie, name string) *http.Cookie {
for _, c := range cookies {
if c.Name == name {
return c
}
}
return nil
}
func TestAPILoginRefreshAndMe(t *testing.T) {
t.Parallel()
s := makeTestServer(t, nil)
user, err := s.createUser("alice", "password123", "dracula", "auto")
if err != nil {
t.Fatalf("createUser failed: %v", err)
}
loginReq := httptest.NewRequest(http.MethodPost, "/api/auth/login", strings.NewReader(`{"username":"alice","password":"password123"}`))
loginReq.Header.Set("Content-Type", "application/json")
loginRec := httptest.NewRecorder()
s.handleLogin(loginRec, loginReq)
loginRes := loginRec.Result()
if loginRes.StatusCode != http.StatusOK {
t.Fatalf("login status = %d, want %d", loginRes.StatusCode, http.StatusOK)
}
if cookieByName(loginRes.Cookies(), "access_token") == nil {
t.Fatal("login did not set access_token cookie")
}
refreshCookie := cookieByName(loginRes.Cookies(), "refresh_token")
if refreshCookie == nil {
t.Fatal("login did not set refresh_token cookie")
}
var meResp User
decodeJSONBody(t, loginRes, &meResp)
if meResp.ID != user.ID || meResp.Username != user.Username {
t.Fatalf("login response user mismatch: got %+v want id=%d username=%q", meResp, user.ID, user.Username)
}
refreshReq := httptest.NewRequest(http.MethodPost, "/api/auth/refresh", nil)
refreshReq.AddCookie(refreshCookie)
refreshRec := httptest.NewRecorder()
s.handleRefresh(refreshRec, refreshReq)
refreshRes := refreshRec.Result()
if refreshRes.StatusCode != http.StatusOK {
t.Fatalf("refresh status = %d, want %d", refreshRes.StatusCode, http.StatusOK)
}
newAccess := cookieByName(refreshRes.Cookies(), "access_token")
if newAccess == nil {
t.Fatal("refresh did not rotate access_token cookie")
}
meReq := httptest.NewRequest(http.MethodGet, "/api/auth/me", nil)
meReq.AddCookie(newAccess)
meRec := httptest.NewRecorder()
s.authMiddleware(http.HandlerFunc(s.handleMe)).ServeHTTP(meRec, meReq)
if meRec.Code != http.StatusOK {
t.Fatalf("/api/auth/me status = %d, want %d", meRec.Code, http.StatusOK)
}
var meAfter User
if err := json.NewDecoder(meRec.Body).Decode(&meAfter); err != nil {
t.Fatalf("decode /api/auth/me failed: %v", err)
}
if meAfter.Username != "alice" {
t.Fatalf("/api/auth/me username = %q, want %q", meAfter.Username, "alice")
}
}
func TestAPIUserProtocolsFTPS(t *testing.T) {
t.Parallel()
s := makeTestServer(t, func(cfg *Config) {
cfg.FTPSEnabled = true
cfg.FTPSHost = "0.0.0.0"
cfg.FTPSPort = 2990
cfg.FTPSPublicIP = "198.51.100.10"
cfg.FTPSExplicit = true
cfg.FTPSForceTLS = true
})
if _, err := s.createUser("bob", "password123", "dracula", "auto"); err != nil {
t.Fatalf("createUser failed: %v", err)
}
loginReq := httptest.NewRequest(http.MethodPost, "/api/auth/login", strings.NewReader(`{"username":"bob","password":"password123"}`))
loginReq.Header.Set("Content-Type", "application/json")
loginRec := httptest.NewRecorder()
s.handleLogin(loginRec, loginReq)
if loginRec.Code != http.StatusOK {
t.Fatalf("login status = %d", loginRec.Code)
}
access := cookieByName(loginRec.Result().Cookies(), "access_token")
if access == nil {
t.Fatal("missing access token cookie")
}
req := httptest.NewRequest(http.MethodGet, "/api/user/protocols", nil)
req.AddCookie(access)
rec := httptest.NewRecorder()
s.authMiddleware(http.HandlerFunc(s.handleUserProtocols)).ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("protocols status = %d, want %d", rec.Code, http.StatusOK)
}
var out userProtocolsResponse
if err := json.NewDecoder(rec.Body).Decode(&out); err != nil {
t.Fatalf("decode response failed: %v", err)
}
if out.FTPS == nil {
t.Fatal("expected FTPS profile in response")
}
if out.FTPS.Username != "bob" {
t.Fatalf("ftps username = %q, want %q", out.FTPS.Username, "bob")
}
if out.FTPS.Host != "198.51.100.10" || out.FTPS.Port != 2990 {
t.Fatalf("ftps endpoint mismatch: got %s:%d", out.FTPS.Host, out.FTPS.Port)
}
if !out.FTPS.ExplicitTLS || !out.FTPS.ForceTLS {
t.Fatal("expected explicit/forced TLS flags to be true")
}
}
func insertUserWithHash(t *testing.T, db *sql.DB, username, hash string) int64 {
t.Helper()
res, err := db.Exec(`INSERT INTO users(email, password_hash, theme, color_mode, archive_format) VALUES (?, ?, 'dracula', 'auto', 'zip')`, username, hash)
if err != nil {
t.Fatalf("insert user failed: %v", err)
}
id, err := res.LastInsertId()
if err != nil {
t.Fatalf("LastInsertId failed: %v", err)
}
return id
}
func testFTPContextWithUserID(userID int64) *ftpserver.Context {
return &ftpserver.Context{Sess: &ftpserver.Session{Data: map[string]interface{}{"filez_user_id": userID}}}
}
func TestFTPDriverPlainTransfer(t *testing.T) {
t.Parallel()
s := makeTestServer(t, nil)
hash, err := hashPasswordArgon2ID("password123")
if err != nil {
t.Fatalf("hash password failed: %v", err)
}
uid := insertUserWithHash(t, s.db, "dave", hash)
drv := &ftpUserDriver{db: s.db, root: s.config.StorageRoot}
ctx := testFTPContextWithUserID(uid)
plain := []byte("plain ftp payload")
if _, err := drv.PutFile(ctx, "/plain.txt", bytes.NewReader(plain), 0); err != nil {
t.Fatalf("PutFile failed: %v", err)
}
size, rc, err := drv.GetFile(ctx, "/plain.txt", 0)
if err != nil {
t.Fatalf("GetFile failed: %v", err)
}
defer rc.Close()
got, err := io.ReadAll(rc)
if err != nil {
t.Fatalf("ReadAll failed: %v", err)
}
if size != int64(len(plain)) || !bytes.Equal(got, plain) {
t.Fatalf("plain transfer mismatch: size=%d got=%q", size, string(got))
}
}

218
backend/config.go Normal file
View File

@@ -0,0 +1,218 @@
package main
import (
"log"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/joho/godotenv"
)
type Config struct {
Addr string
DBPath string
StorageRoot string
AppDomain string
AllowedHost string
CORSOrigin string
CookieSecure bool
MaxBodyBytes int64
GoogleAuthEnabled bool
GoogleClientID string
GoogleClientSecret string
GoogleRedirectURL string
GoogleAuthURL string
GoogleTokenURL string
GoogleUserInfoURL string
RateLimitPerMin int
AuthRateLimitPerMin int
JWTSecret string
AccessTTL time.Duration
RefreshTTL time.Duration
ShareDefaultTTL time.Duration
AdminSessionTTL time.Duration
AdminLogin string
AdminPasswordHash string
FTPEnabled bool
FTPHost string
FTPPort int
FTPPublicIP string
FTPPassivePorts string
FTPSEnabled bool
FTPSHost string
FTPSPort int
FTPSPublicIP string
FTPSPassivePorts string
FTPSCertFile string
FTPSKeyFile string
FTPSLEDomain string
FTPSLEDir string
FTPSExplicit bool
FTPSForceTLS bool
SFTPEnabled bool
SFTPHost string
SFTPPort int
SFTPHostKeyPath string
}
func loadConfig() Config {
_ = godotenv.Load(".env", "../.env")
dbPath := getEnv("DB_PATH", "./app.db")
storageRoot := getEnv("STORAGE_ROOT", "./users")
if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
log.Fatalf("failed to create db path parent: %v", err)
}
if err := os.MkdirAll(storageRoot, 0o755); err != nil {
log.Fatalf("failed to create storage root: %v", err)
}
cfg := Config{
Addr: getEnv("ADDR", ":8080"),
DBPath: dbPath,
StorageRoot: storageRoot,
AppDomain: strings.ToLower(strings.TrimSpace(getEnv("APP_DOMAIN", "file.example.com"))),
AllowedHost: strings.ToLower(getEnv("ALLOWED_HOST", "")),
CORSOrigin: getEnv("CORS_ALLOWED_ORIGIN", ""),
CookieSecure: getEnv("COOKIE_SECURE", "false") == "true",
MaxBodyBytes: int64(getEnvInt("MAX_BODY_MB", 8)) * 1024 * 1024,
GoogleAuthEnabled: getEnv("GOOGLE_AUTH_ENABLED", "false") == "true",
GoogleClientID: getEnv("GOOGLE_CLIENT_ID", ""),
GoogleClientSecret: getEnv("GOOGLE_CLIENT_SECRET", ""),
GoogleRedirectURL: getEnv("GOOGLE_REDIRECT_URL", ""),
GoogleAuthURL: getEnv("GOOGLE_AUTH_URL", "https://accounts.google.com/o/oauth2/v2/auth"),
GoogleTokenURL: getEnv("GOOGLE_TOKEN_URL", "https://oauth2.googleapis.com/token"),
GoogleUserInfoURL: getEnv("GOOGLE_USERINFO_URL", "https://openidconnect.googleapis.com/v1/userinfo"),
RateLimitPerMin: getEnvInt("RATE_LIMIT_PER_MIN", 240),
AuthRateLimitPerMin: getEnvInt("AUTH_RATE_LIMIT_PER_MIN", 30),
JWTSecret: getEnv("JWT_SECRET", "dev-change-me-immediately"),
AccessTTL: 15 * time.Minute,
RefreshTTL: 30 * 24 * time.Hour,
ShareDefaultTTL: 24 * time.Hour,
AdminSessionTTL: 12 * time.Hour,
AdminLogin: getEnv("ADMIN_LOGIN", "admin"),
AdminPasswordHash: getEnv("ADMIN_PASSWORD_HASH", ""),
FTPEnabled: getEnv("FTP_ENABLED", "false") == "true",
FTPHost: getEnv("FTP_HOST", "0.0.0.0"),
FTPPort: getEnvInt("FTP_PORT", 2121),
FTPPublicIP: getEnv("FTP_PUBLIC_IP", ""),
FTPPassivePorts: getEnv("FTP_PASSIVE_PORTS", ""),
FTPSEnabled: getEnv("FTPS_ENABLED", "false") == "true",
FTPSHost: getEnv("FTPS_HOST", "0.0.0.0"),
FTPSPort: getEnvInt("FTPS_PORT", 2990),
FTPSPublicIP: getEnv("FTPS_PUBLIC_IP", ""),
FTPSPassivePorts: getEnv("FTPS_PASSIVE_PORTS", ""),
FTPSCertFile: getEnv("FTPS_CERT_FILE", ""),
FTPSKeyFile: getEnv("FTPS_KEY_FILE", ""),
FTPSLEDomain: strings.ToLower(strings.TrimSpace(getEnv("FTPS_LETSENCRYPT_DOMAIN", ""))),
FTPSLEDir: getEnv("FTPS_LETSENCRYPT_DIR", "/etc/letsencrypt/live"),
FTPSExplicit: getEnv("FTPS_EXPLICIT", "true") != "false",
FTPSForceTLS: getEnv("FTPS_FORCE_TLS", "true") != "false",
SFTPEnabled: getEnv("SFTP_ENABLED", "false") == "true",
SFTPHost: getEnv("SFTP_HOST", "0.0.0.0"),
SFTPPort: getEnvInt("SFTP_PORT", 2022),
SFTPHostKeyPath: getEnv("SFTP_HOST_KEY_PATH", "./sftp_host_ed25519"),
}
if cfg.AllowedHost == "" {
cfg.AllowedHost = cfg.AppDomain
}
if cfg.CORSOrigin == "" {
cfg.CORSOrigin = "https://" + cfg.AppDomain
}
if cfg.JWTSecret == "dev-change-me-immediately" {
log.Println("warning: JWT_SECRET is using default development value")
}
if strings.TrimSpace(cfg.AdminPasswordHash) == "" {
log.Fatal("ADMIN_PASSWORD_HASH is required. Generate one with: go run . hash-admin <password>")
}
if cfg.GoogleAuthEnabled {
if strings.TrimSpace(cfg.GoogleClientID) == "" || strings.TrimSpace(cfg.GoogleClientSecret) == "" {
log.Fatal("GOOGLE_AUTH_ENABLED=true requires GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET")
}
}
if cfg.FTPEnabled {
if cfg.FTPPort < 1 || cfg.FTPPort > 65535 {
log.Fatal("FTP_PORT must be in range 1..65535")
}
}
if cfg.FTPSEnabled {
if cfg.FTPSPort < 1 || cfg.FTPSPort > 65535 {
log.Fatal("FTPS_PORT must be in range 1..65535")
}
applyFTPSLetsEncryptDefaults(&cfg)
if strings.TrimSpace(cfg.FTPSCertFile) == "" || strings.TrimSpace(cfg.FTPSKeyFile) == "" {
log.Fatal("FTPS_ENABLED=true requires FTPS_CERT_FILE/FTPS_KEY_FILE or FTPS_LETSENCRYPT_DOMAIN")
}
if _, err := os.Stat(cfg.FTPSCertFile); err != nil {
log.Fatalf("FTPS_CERT_FILE is invalid: %v", err)
}
if _, err := os.Stat(cfg.FTPSKeyFile); err != nil {
log.Fatalf("FTPS_KEY_FILE is invalid: %v", err)
}
}
if cfg.FTPEnabled && cfg.FTPSEnabled {
if cfg.FTPPort == cfg.FTPSPort && strings.EqualFold(cfg.FTPHost, cfg.FTPSHost) {
log.Fatal("FTP and FTPS cannot share the same host:port")
}
}
if cfg.SFTPEnabled {
if cfg.SFTPPort < 1 || cfg.SFTPPort > 65535 {
log.Fatal("SFTP_PORT must be in range 1..65535")
}
}
return cfg
}
func applyFTPSLetsEncryptDefaults(cfg *Config) {
if cfg == nil {
return
}
if strings.TrimSpace(cfg.FTPSLEDomain) == "" {
return
}
base := strings.TrimSpace(cfg.FTPSLEDir)
if base == "" {
base = "/etc/letsencrypt/live"
}
domainDir := filepath.Join(base, cfg.FTPSLEDomain)
if strings.TrimSpace(cfg.FTPSCertFile) == "" {
cfg.FTPSCertFile = filepath.Join(domainDir, "fullchain.pem")
}
if strings.TrimSpace(cfg.FTPSKeyFile) == "" {
cfg.FTPSKeyFile = filepath.Join(domainDir, "privkey.pem")
}
}
func getEnv(key, fallback string) string {
v := strings.TrimSpace(os.Getenv(key))
if v == "" {
return fallback
}
return v
}
func getEnvInt(key string, fallback int) int {
v := strings.TrimSpace(os.Getenv(key))
if v == "" {
return fallback
}
n, err := strconv.Atoi(v)
if err != nil {
return fallback
}
return n
}

33
backend/config_test.go Normal file
View File

@@ -0,0 +1,33 @@
package main
import "testing"
func TestApplyFTPSLetsEncryptDefaults(t *testing.T) {
t.Parallel()
cfg := Config{FTPSLEDomain: "files.example.com"}
applyFTPSLetsEncryptDefaults(&cfg)
if cfg.FTPSCertFile != "/etc/letsencrypt/live/files.example.com/fullchain.pem" {
t.Fatalf("unexpected cert path: %q", cfg.FTPSCertFile)
}
if cfg.FTPSKeyFile != "/etc/letsencrypt/live/files.example.com/privkey.pem" {
t.Fatalf("unexpected key path: %q", cfg.FTPSKeyFile)
}
}
func TestApplyFTPSLetsEncryptDefaultsCustomDirAndPreserveManual(t *testing.T) {
t.Parallel()
cfg := Config{
FTPSLEDomain: "files.example.com",
FTPSLEDir: "/var/lib/acme/live",
FTPSCertFile: "/custom/cert.pem",
}
applyFTPSLetsEncryptDefaults(&cfg)
if cfg.FTPSCertFile != "/custom/cert.pem" {
t.Fatalf("manual cert should be preserved, got %q", cfg.FTPSCertFile)
}
if cfg.FTPSKeyFile != "/var/lib/acme/live/files.example.com/privkey.pem" {
t.Fatalf("unexpected key path: %q", cfg.FTPSKeyFile)
}
}

View File

@@ -3,24 +3,29 @@ module driveclone
go 1.25.0
require (
github.com/glebarez/sqlite v1.11.0
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/gorilla/mux v1.8.1
github.com/hirochachacha/go-smb2 v1.1.0
github.com/joho/godotenv v1.5.1
goftp.io/server/v2 v2.0.2
golang.org/x/crypto v0.48.0
modernc.org/sqlite v1.46.1
gorm.io/gorm v1.31.1
)
require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/geoffgarside/ber v1.1.0 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
modernc.org/sqlite v1.23.1 // indirect
)

View File

@@ -1,19 +1,23 @@
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/geoffgarside/ber v1.1.0 h1:qTmFG4jJbwiSzSXoNJeHcOprVzZ8Ulde2Rrrifu5U9w=
github.com/geoffgarside/ber v1.1.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc=
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/hirochachacha/go-smb2 v1.1.0 h1:b6hs9qKIql9eVXAiN0M2wSFY5xnhbHAQoCwRKbaRTZI=
github.com/hirochachacha/go-smb2 v1.1.0/go.mod h1:8F1A4d5EZzrGu5R7PU163UcMRDJQl4FtcxjBfsY8TZE=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
@@ -22,25 +26,25 @@ github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOF
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
goftp.io/server/v2 v2.0.2 h1:tkZpqyXys+vC15W5yGMi8Kzmbv1QSgeKr8qJXBnJbm8=
goftp.io/server/v2 v2.0.2/go.mod h1:Fl1WdcV7fx1pjOWx7jEHb7tsJ8VwE7+xHu6bVJ6r2qg=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
@@ -63,8 +67,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU=
modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM=
modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=

View File

@@ -33,48 +33,16 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/gorilla/mux"
"github.com/hirochachacha/go-smb2"
"github.com/joho/godotenv"
"golang.org/x/crypto/argon2"
"golang.org/x/crypto/bcrypt"
_ "modernc.org/sqlite"
)
type Config struct {
Addr string
DBPath string
StorageRoot string
AppDomain string
AllowedHost string
CORSOrigin string
CookieSecure bool
MaxBodyBytes int64
RateLimitPerMin int
AuthRateLimitPerMin int
JWTSecret string
AccessTTL time.Duration
RefreshTTL time.Duration
ShareDefaultTTL time.Duration
AdminSessionTTL time.Duration
AdminLogin string
AdminPasswordHash string
StorageBackend string
SMBHost string
SMBShare string
SMBUser string
SMBPass string
SMBDomain string
SMBBasePath string
SMBConnectTimout time.Duration
}
//go:embed web/dist
var embeddedWeb embed.FS
type Server struct {
db *sql.DB
orm *ormRepo
config Config
storage Storage
limiter *rateLimiter
@@ -198,7 +166,16 @@ func main() {
log.Fatalf("storage init failed: %v", err)
}
s := &Server{db: db, config: cfg, storage: storage, limiter: newRateLimiter()}
orm, err := newORMRepo(cfg.DBPath)
if err != nil {
log.Fatalf("orm init failed: %v", err)
}
if err := startProtocolServers(cfg, db); err != nil {
log.Fatalf("protocol init failed: %v", err)
}
s := &Server{db: db, orm: orm, config: cfg, storage: storage, limiter: newRateLimiter()}
r := mux.NewRouter()
r.Use(s.recoverMiddleware)
r.Use(s.securityHeadersMiddleware)
@@ -212,6 +189,8 @@ func main() {
r.HandleFunc("/api/auth/register", s.handleRegisterDisabled).Methods(http.MethodPost)
r.HandleFunc("/api/auth/login", s.handleLogin).Methods(http.MethodPost)
r.HandleFunc("/api/auth/google/start", s.handleGoogleAuthStart).Methods(http.MethodGet)
r.HandleFunc("/api/auth/google/callback", s.handleGoogleAuthCallback).Methods(http.MethodGet)
r.HandleFunc("/api/auth/refresh", s.handleRefresh).Methods(http.MethodPost)
r.HandleFunc("/api/auth/logout", s.handleLogout).Methods(http.MethodPost)
@@ -224,6 +203,7 @@ func main() {
protected.Use(s.authMiddleware)
protected.HandleFunc("/auth/me", s.handleMe).Methods(http.MethodGet)
protected.HandleFunc("/user/preferences", s.handleSetPreferences).Methods(http.MethodPost)
protected.HandleFunc("/user/protocols", s.handleUserProtocols).Methods(http.MethodGet)
protected.HandleFunc("/files", s.handleListFiles).Methods(http.MethodGet)
protected.HandleFunc("/files/upload", s.handleUpload).Methods(http.MethodPost)
protected.HandleFunc("/files/download", s.handleDownload).Methods(http.MethodGet, http.MethodHead)
@@ -232,6 +212,7 @@ func main() {
protected.HandleFunc("/files/preview", s.handlePreview).Methods(http.MethodGet, http.MethodHead)
protected.HandleFunc("/files/text", s.handleReadTextFile).Methods(http.MethodGet)
protected.HandleFunc("/files/text", s.handleWriteTextFile).Methods(http.MethodPut)
protected.HandleFunc("/files/rename", s.handleRename).Methods(http.MethodPost)
protected.HandleFunc("/files", s.handleDelete).Methods(http.MethodDelete)
protected.HandleFunc("/files/folder", s.handleCreateFolder).Methods(http.MethodPost)
protected.HandleFunc("/files/share", s.handleCreateShareLink).Methods(http.MethodPost)
@@ -260,62 +241,6 @@ func main() {
}
}
func loadConfig() Config {
_ = godotenv.Load(".env", "../.env")
dbPath := getEnv("DB_PATH", "./app.db")
storageRoot := getEnv("STORAGE_ROOT", "./users")
if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
log.Fatalf("failed to create db path parent: %v", err)
}
if err := os.MkdirAll(storageRoot, 0o755); err != nil {
log.Fatalf("failed to create storage root: %v", err)
}
cfg := Config{
Addr: getEnv("ADDR", ":8080"),
DBPath: dbPath,
StorageRoot: storageRoot,
AppDomain: strings.ToLower(strings.TrimSpace(getEnv("APP_DOMAIN", "file.example.com"))),
AllowedHost: strings.ToLower(getEnv("ALLOWED_HOST", "")),
CORSOrigin: getEnv("CORS_ALLOWED_ORIGIN", ""),
CookieSecure: getEnv("COOKIE_SECURE", "false") == "true",
MaxBodyBytes: int64(getEnvInt("MAX_BODY_MB", 8)) * 1024 * 1024,
RateLimitPerMin: getEnvInt("RATE_LIMIT_PER_MIN", 240),
AuthRateLimitPerMin: getEnvInt("AUTH_RATE_LIMIT_PER_MIN", 30),
JWTSecret: getEnv("JWT_SECRET", "dev-change-me-immediately"),
AccessTTL: 15 * time.Minute,
RefreshTTL: 30 * 24 * time.Hour,
ShareDefaultTTL: 24 * time.Hour,
AdminSessionTTL: 12 * time.Hour,
AdminLogin: getEnv("ADMIN_LOGIN", "admin"),
AdminPasswordHash: getEnv("ADMIN_PASSWORD_HASH", ""),
StorageBackend: getEnv("STORAGE_BACKEND", "local"),
SMBHost: getEnv("SMB_HOST", ""),
SMBShare: getEnv("SMB_SHARE", ""),
SMBUser: getEnv("SMB_USER", ""),
SMBPass: getEnv("SMB_PASS", ""),
SMBDomain: getEnv("SMB_DOMAIN", ""),
SMBBasePath: getEnv("SMB_BASE_PATH", "driveflow"),
SMBConnectTimout: 5 * time.Second,
}
if cfg.AllowedHost == "" {
cfg.AllowedHost = cfg.AppDomain
}
if cfg.CORSOrigin == "" {
cfg.CORSOrigin = "https://" + cfg.AppDomain
}
if cfg.JWTSecret == "dev-change-me-immediately" {
log.Println("warning: JWT_SECRET is using default development value")
}
if strings.TrimSpace(cfg.AdminPasswordHash) == "" {
log.Fatal("ADMIN_PASSWORD_HASH is required. Generate one with: go run . hash-admin <password>")
}
return cfg
}
func openDB(path string) (*sql.DB, error) {
db, err := sql.Open("sqlite", path)
if err != nil {
@@ -392,18 +317,19 @@ func migrate(db *sql.DB) error {
return err
}
}
if _, err := db.Exec(`ALTER TABLE users ADD COLUMN google_sub TEXT`); err != nil {
if !strings.Contains(strings.ToLower(err.Error()), "duplicate column") {
return err
}
}
if _, err := db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_users_google_sub ON users(google_sub) WHERE google_sub IS NOT NULL`); err != nil {
return err
}
return nil
}
func buildStorage(cfg Config) (Storage, error) {
if strings.EqualFold(cfg.StorageBackend, "smb") {
if cfg.SMBHost == "" || cfg.SMBShare == "" || cfg.SMBUser == "" {
return nil, fmt.Errorf("SMB_HOST, SMB_SHARE, SMB_USER must be set for smb backend")
}
return &SMBStorage{cfg: cfg}, nil
}
root := cfg.StorageRoot
if err := os.MkdirAll(root, 0o755); err != nil {
return nil, err
@@ -1686,6 +1612,67 @@ func (s *Server) handleCreateFolder(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, map[string]string{"status": "created"})
}
func (s *Server) handleRename(w http.ResponseWriter, r *http.Request) {
uid := userIDFromContext(r.Context())
var in struct {
Path string `json:"path"`
Name string `json:"name"`
}
if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
writeErr(w, http.StatusBadRequest, "invalid payload")
return
}
src := normalizePath(in.Path)
if src == "/" {
writeErr(w, http.StatusBadRequest, "invalid path")
return
}
meta, err := s.storage.Stat(uid, src)
if err != nil {
writeErr(w, http.StatusBadRequest, "file not found")
return
}
name := path.Base(strings.TrimSpace(in.Name))
if name == "" || name == "." || name == ".." {
writeErr(w, http.StatusBadRequest, "invalid name")
return
}
dir := path.Dir(src)
if dir == "." {
dir = "/"
}
dst := normalizePath(path.Join(dir, name))
if dst == src {
writeJSON(w, http.StatusOK, map[string]any{"status": "renamed", "path": src})
return
}
if _, err := s.storage.Stat(uid, dst); err == nil {
writeErr(w, http.StatusBadRequest, "target already exists")
return
}
if meta.IsDir {
srcPrefix := strings.TrimSuffix(src, "/") + "/"
if dst == src || strings.HasPrefix(dst, srcPrefix) {
writeErr(w, http.StatusBadRequest, "cannot rename folder into itself")
return
}
}
if err := s.copyPath(uid, src, dst); err != nil {
writeErr(w, http.StatusBadRequest, err.Error())
return
}
if err := s.storage.Delete(uid, src); err != nil {
writeErr(w, http.StatusBadRequest, err.Error())
return
}
s.moveTags(uid, src, dst)
writeJSON(w, http.StatusOK, map[string]any{"status": "renamed", "path": dst})
}
type shareInput struct {
Path string `json:"path"`
ExpiresMinutes int `json:"expiresMinutes"`
@@ -1804,23 +1791,11 @@ func (s *Server) handleAdminMe(w http.ResponseWriter, _ *http.Request) {
}
func (s *Server) handleAdminUsersList(w http.ResponseWriter, _ *http.Request) {
rows, err := s.db.Query(`SELECT id, email, theme, color_mode, archive_format FROM users ORDER BY id ASC`)
users, err := s.orm.listUsers()
if err != nil {
writeErr(w, http.StatusInternalServerError, "failed to load users")
return
}
defer rows.Close()
users := make([]User, 0)
for rows.Next() {
var u User
if err := rows.Scan(&u.ID, &u.Username, &u.Theme, &u.ColorMode, &u.Archive); err != nil {
continue
}
u.Theme = normalizeTheme(u.Theme)
u.ColorMode = normalizeColorMode(u.ColorMode)
users = append(users, u)
}
writeJSON(w, http.StatusOK, map[string]any{"users": users})
}
@@ -1974,15 +1949,13 @@ func (s *Server) createUser(username, password, theme, colorMode string) (User,
return User{}, fmt.Errorf("failed to hash password")
}
res, err := s.db.Exec(`INSERT INTO users(email, password_hash, theme, color_mode, archive_format) VALUES (?, ?, ?, ?, ?)`, username, hash, normalizeTheme(theme), normalizeColorMode(colorMode), "zip")
id, err := s.orm.createUser(username, hash, normalizeTheme(theme), normalizeColorMode(colorMode), "zip", nil)
if err != nil {
if strings.Contains(strings.ToLower(err.Error()), "unique") {
return User{}, fmt.Errorf("account already exists")
}
return User{}, err
}
id, _ := res.LastInsertId()
if err := s.storage.Mkdir(id, "/"); err != nil {
return User{}, fmt.Errorf("failed to provision user storage: %w", err)
}
@@ -1991,26 +1964,11 @@ func (s *Server) createUser(username, password, theme, colorMode string) (User,
}
func (s *Server) findUserWithHash(username string) (User, string, error) {
username = strings.ToLower(strings.TrimSpace(username))
var u User
var hash string
err := s.db.QueryRow(`SELECT id, email, password_hash, theme, color_mode, archive_format FROM users WHERE email = ?`, username).
Scan(&u.ID, &u.Username, &hash, &u.Theme, &u.ColorMode, &u.Archive)
if err != nil {
return User{}, "", err
}
u.Theme = normalizeTheme(u.Theme)
u.ColorMode = normalizeColorMode(u.ColorMode)
return u, hash, nil
return s.orm.findUserWithHashByEmail(username)
}
func (s *Server) findUser(id int64) (User, error) {
var u User
err := s.db.QueryRow(`SELECT id, email, theme, color_mode, archive_format FROM users WHERE id = ?`, id).
Scan(&u.ID, &u.Username, &u.Theme, &u.ColorMode, &u.Archive)
u.Theme = normalizeTheme(u.Theme)
u.ColorMode = normalizeColorMode(u.ColorMode)
return u, err
return s.orm.findUserByID(id)
}
func (s *Server) issueUserSession(w http.ResponseWriter, r *http.Request, userID int64) error {
@@ -2290,197 +2248,6 @@ func (l *LocalStorage) OpenReadSeeker(userID int64, rel string) (ReadSeekCloser,
return os.Open(full)
}
type SMBStorage struct {
cfg Config
}
type smbConn struct {
conn net.Conn
session *smb2.Session
share *smb2.Share
}
func (c *smbConn) Close() {
if c.share != nil {
_ = c.share.Umount()
}
if c.session != nil {
_ = c.session.Logoff()
}
if c.conn != nil {
_ = c.conn.Close()
}
}
type smbReadSeekCloser struct {
file *smb2.File
conn *smbConn
}
func (s *smbReadSeekCloser) Read(p []byte) (int, error) {
return s.file.Read(p)
}
func (s *smbReadSeekCloser) Seek(offset int64, whence int) (int64, error) {
return s.file.Seek(offset, whence)
}
func (s *smbReadSeekCloser) Close() error {
_ = s.file.Close()
s.conn.Close()
return nil
}
func (smbs *SMBStorage) openConnection() (*smbConn, error) {
conn, err := net.DialTimeout("tcp", smbs.cfg.SMBHost, smbs.cfg.SMBConnectTimout)
if err != nil {
return nil, err
}
dialer := &smb2.Dialer{Initiator: &smb2.NTLMInitiator{
User: smbs.cfg.SMBUser,
Password: smbs.cfg.SMBPass,
Domain: smbs.cfg.SMBDomain,
}}
session, err := dialer.Dial(conn)
if err != nil {
_ = conn.Close()
return nil, err
}
share, err := session.Mount(smbs.cfg.SMBShare)
if err != nil {
_ = session.Logoff()
_ = conn.Close()
return nil, err
}
return &smbConn{conn: conn, session: session, share: share}, nil
}
func (smbs *SMBStorage) withShare(fn func(share *smb2.Share) error) error {
conn, err := smbs.openConnection()
if err != nil {
return err
}
defer conn.Close()
return fn(conn.share)
}
func (smbs *SMBStorage) userPath(userID int64, rel string) string {
base := path.Join(smbs.cfg.SMBBasePath, strconv.FormatInt(userID, 10))
clean := strings.TrimPrefix(normalizePath(rel), "/")
if clean == "" {
return base
}
return path.Join(base, clean)
}
func (smbs *SMBStorage) List(userID int64, rel string) ([]FileEntry, error) {
out := make([]FileEntry, 0)
err := smbs.withShare(func(share *smb2.Share) error {
target := smbs.userPath(userID, rel)
if err := share.MkdirAll(target, 0o755); err != nil {
return err
}
entries, err := share.ReadDir(target)
if err != nil {
return err
}
for _, e := range entries {
out = append(out, FileEntry{
Name: e.Name(),
Path: path.Join(normalizePath(rel), e.Name()),
IsDir: e.IsDir(),
Size: e.Size(),
ModTime: e.ModTime(),
})
}
return nil
})
return out, err
}
func (smbs *SMBStorage) Mkdir(userID int64, rel string) error {
return smbs.withShare(func(share *smb2.Share) error {
return share.MkdirAll(smbs.userPath(userID, rel), 0o755)
})
}
func (smbs *SMBStorage) Save(userID int64, rel string, src multipart.File) error {
return smbs.withShare(func(share *smb2.Share) error {
target := smbs.userPath(userID, rel)
if err := share.MkdirAll(path.Dir(target), 0o755); err != nil {
return err
}
f, err := share.Create(target)
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(f, src)
return err
})
}
func (smbs *SMBStorage) SaveBytes(userID int64, rel string, data []byte) error {
return smbs.withShare(func(share *smb2.Share) error {
target := smbs.userPath(userID, rel)
if err := share.MkdirAll(path.Dir(target), 0o755); err != nil {
return err
}
f, err := share.Create(target)
if err != nil {
return err
}
defer f.Close()
_, err = f.Write(data)
return err
})
}
func (smbs *SMBStorage) Delete(userID int64, rel string) error {
return smbs.withShare(func(share *smb2.Share) error {
target := smbs.userPath(userID, rel)
if normalizePath(rel) == "/" {
return share.RemoveAll(target)
}
if err := share.Remove(target); err == nil {
return nil
}
return share.RemoveAll(target)
})
}
func (smbs *SMBStorage) Stat(userID int64, rel string) (FileMeta, error) {
var out FileMeta
err := smbs.withShare(func(share *smb2.Share) error {
st, err := share.Stat(smbs.userPath(userID, rel))
if err != nil {
return err
}
out = FileMeta{Name: st.Name(), Size: st.Size(), ModTime: st.ModTime(), IsDir: st.IsDir()}
return nil
})
return out, err
}
func (smbs *SMBStorage) OpenReadSeeker(userID int64, rel string) (ReadSeekCloser, error) {
conn, err := smbs.openConnection()
if err != nil {
return nil, err
}
f, err := conn.share.Open(smbs.userPath(userID, rel))
if err != nil {
conn.Close()
return nil, err
}
return &smbReadSeekCloser{file: f, conn: conn}, nil
}
func normalizePath(rel string) string {
clean := path.Clean("/" + strings.TrimSpace(rel))
if clean == "." {
@@ -2792,26 +2559,6 @@ func writeErr(w http.ResponseWriter, status int, msg string) {
writeJSON(w, status, map[string]string{"error": msg})
}
func getEnv(key, fallback string) string {
v := strings.TrimSpace(os.Getenv(key))
if v == "" {
return fallback
}
return v
}
func getEnvInt(key string, fallback int) int {
v := strings.TrimSpace(os.Getenv(key))
if v == "" {
return fallback
}
n, err := strconv.Atoi(v)
if err != nil {
return fallback
}
return n
}
func limitString(s string, max int) string {
if len(s) <= max {
return s

247
backend/oauth_google.go Normal file
View File

@@ -0,0 +1,247 @@
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"strings"
)
const googleOAuthStateCookie = "google_oauth_state"
type googleTokenResponse struct {
AccessToken string `json:"access_token"`
}
type googleUserInfo struct {
Sub string `json:"sub"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
}
func (s *Server) handleGoogleAuthStart(w http.ResponseWriter, r *http.Request) {
if !s.config.GoogleAuthEnabled {
writeErr(w, http.StatusNotFound, "google auth is disabled")
return
}
state, err := randomToken()
if err != nil {
writeErr(w, http.StatusInternalServerError, "failed to initialize oauth")
return
}
setCookie(w, googleOAuthStateCookie, state, 600, s.config.CookieSecure)
u, err := url.Parse(s.config.GoogleAuthURL)
if err != nil {
writeErr(w, http.StatusInternalServerError, "invalid google auth config")
return
}
q := u.Query()
q.Set("client_id", strings.TrimSpace(s.config.GoogleClientID))
q.Set("redirect_uri", s.googleRedirectURL(r))
q.Set("response_type", "code")
q.Set("scope", "openid email profile")
q.Set("state", state)
q.Set("prompt", "select_account")
u.RawQuery = q.Encode()
http.Redirect(w, r, u.String(), http.StatusFound)
}
func (s *Server) handleGoogleAuthCallback(w http.ResponseWriter, r *http.Request) {
if !s.config.GoogleAuthEnabled {
writeErr(w, http.StatusNotFound, "google auth is disabled")
return
}
if oauthErr := strings.TrimSpace(r.URL.Query().Get("error")); oauthErr != "" {
writeErr(w, http.StatusUnauthorized, "google login was denied")
return
}
code := strings.TrimSpace(r.URL.Query().Get("code"))
if code == "" {
writeErr(w, http.StatusBadRequest, "missing oauth code")
return
}
state := strings.TrimSpace(r.URL.Query().Get("state"))
stateCookie, err := r.Cookie(googleOAuthStateCookie)
if err != nil || stateCookie == nil || stateCookie.Value == "" || subtleConstantTimeEq(stateCookie.Value, state) == 0 {
writeErr(w, http.StatusUnauthorized, "invalid oauth state")
return
}
clearCookie(w, googleOAuthStateCookie, s.config.CookieSecure)
token, err := s.exchangeGoogleCode(r.Context(), code, s.googleRedirectURL(r))
if err != nil {
log.Printf("auth.google.failed ip=%q reason=%q", clientIP(r), "token_exchange_failed")
writeErr(w, http.StatusUnauthorized, "google auth failed")
return
}
info, err := s.fetchGoogleUserInfo(r.Context(), token)
if err != nil {
log.Printf("auth.google.failed ip=%q reason=%q", clientIP(r), "userinfo_failed")
writeErr(w, http.StatusUnauthorized, "google auth failed")
return
}
user, err := s.findOrCreateGoogleUser(info.Sub, info.Email)
if err != nil {
log.Printf("auth.google.failed ip=%q reason=%q", clientIP(r), "user_provision_failed")
writeErr(w, http.StatusUnauthorized, "google auth failed")
return
}
if err := s.issueUserSession(w, r, user.ID); err != nil {
log.Printf("auth.google.failed ip=%q reason=%q", clientIP(r), "session_issue_failed")
writeErr(w, http.StatusInternalServerError, "failed to create session")
return
}
log.Printf("auth.google.success user_id=%d username=%q ip=%q", user.ID, user.Username, clientIP(r))
http.Redirect(w, r, "/drive", http.StatusFound)
}
func (s *Server) googleRedirectURL(r *http.Request) string {
if v := strings.TrimSpace(s.config.GoogleRedirectURL); v != "" {
return v
}
return fmt.Sprintf("%s://%s/api/auth/google/callback", schemeOf(r), r.Host)
}
func (s *Server) exchangeGoogleCode(ctx context.Context, code, redirectURI string) (string, error) {
values := url.Values{}
values.Set("code", code)
values.Set("client_id", strings.TrimSpace(s.config.GoogleClientID))
values.Set("client_secret", s.config.GoogleClientSecret)
values.Set("redirect_uri", redirectURI)
values.Set("grant_type", "authorization_code")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.config.GoogleTokenURL, strings.NewReader(values.Encode()))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("google token endpoint returned %d", resp.StatusCode)
}
var out googleTokenResponse
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return "", err
}
if strings.TrimSpace(out.AccessToken) == "" {
return "", fmt.Errorf("google token response missing access_token")
}
return out.AccessToken, nil
}
func (s *Server) fetchGoogleUserInfo(ctx context.Context, accessToken string) (googleUserInfo, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.config.GoogleUserInfoURL, nil)
if err != nil {
return googleUserInfo{}, err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return googleUserInfo{}, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return googleUserInfo{}, fmt.Errorf("google userinfo endpoint returned %d", resp.StatusCode)
}
var out googleUserInfo
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return googleUserInfo{}, err
}
out.Sub = strings.TrimSpace(out.Sub)
out.Email = strings.ToLower(strings.TrimSpace(out.Email))
if out.Sub == "" || out.Email == "" || !out.EmailVerified {
return googleUserInfo{}, fmt.Errorf("google account data is incomplete")
}
return out, nil
}
func (s *Server) findOrCreateGoogleUser(googleSub, email string) (User, error) {
googleSub = strings.TrimSpace(googleSub)
email = strings.ToLower(strings.TrimSpace(email))
if googleSub == "" || email == "" {
return User{}, fmt.Errorf("invalid google identity")
}
if user, err := s.findUserByGoogleSub(googleSub); err == nil {
return user, nil
} else if !isNoRows(err) {
return User{}, err
}
if user, err := s.findUserByEmail(email); err == nil {
if err := s.orm.updateGoogleSub(user.ID, googleSub); err != nil {
return User{}, err
}
return user, nil
} else if !isNoRows(err) {
return User{}, err
}
hash, err := hashPasswordArgon2ID(mustRandomPassword())
if err != nil {
return User{}, err
}
id, err := s.orm.createUser(email, hash, "dracula", "auto", "zip", &googleSub)
if err != nil {
if strings.Contains(strings.ToLower(err.Error()), "unique") {
if user, findErr := s.findUserByEmail(email); findErr == nil {
if linkErr := s.orm.updateGoogleSub(user.ID, googleSub); linkErr != nil {
return User{}, linkErr
}
return s.findUser(user.ID)
}
}
return User{}, err
}
if err := s.storage.Mkdir(id, "/"); err != nil {
return User{}, fmt.Errorf("failed to provision user storage: %w", err)
}
return User{ID: id, Username: email, Theme: "dracula", ColorMode: "auto", Archive: "zip"}, nil
}
func (s *Server) findUserByGoogleSub(googleSub string) (User, error) {
return s.orm.findUserByGoogleSub(googleSub)
}
func (s *Server) findUserByEmail(email string) (User, error) {
return s.orm.findUserByEmail(email)
}
func mustRandomPassword() string {
tok, err := randomToken()
if err != nil {
return "google-oauth-password-fallback"
}
if len(tok) < 16 {
return tok + "-google-oauth"
}
return tok
}
func isNoRows(err error) bool {
return isORMNotFound(err)
}

View File

@@ -0,0 +1,108 @@
package main
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func TestGoogleOAuthCallbackCreatesSessionAndUser(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
provider := httptest.NewServer(mux)
defer provider.Close()
mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatalf("parse token form failed: %v", err)
}
if got := r.Form.Get("code"); got != "ok-code" {
t.Fatalf("code = %q, want %q", got, "ok-code")
}
if got := r.Form.Get("client_id"); got != "client-id" {
t.Fatalf("client_id = %q, want %q", got, "client-id")
}
if got := r.Form.Get("client_secret"); got != "client-secret" {
t.Fatalf("client_secret = %q, want %q", got, "client-secret")
}
_ = json.NewEncoder(w).Encode(map[string]any{"access_token": "google-access-token"})
})
mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Authorization"); got != "Bearer google-access-token" {
t.Fatalf("authorization = %q, want bearer token", got)
}
_ = json.NewEncoder(w).Encode(map[string]any{
"sub": "google-sub-1",
"email": "alice@example.com",
"email_verified": true,
})
})
s := makeTestServer(t, func(cfg *Config) {
cfg.GoogleAuthEnabled = true
cfg.GoogleClientID = "client-id"
cfg.GoogleClientSecret = "client-secret"
cfg.GoogleAuthURL = provider.URL + "/auth"
cfg.GoogleTokenURL = provider.URL + "/token"
cfg.GoogleUserInfoURL = provider.URL + "/userinfo"
})
startReq := httptest.NewRequest(http.MethodGet, "/api/auth/google/start", nil)
startReq.Host = "file.example.com"
startRec := httptest.NewRecorder()
s.handleGoogleAuthStart(startRec, startReq)
if startRec.Code != http.StatusFound {
t.Fatalf("start status = %d, want %d", startRec.Code, http.StatusFound)
}
stateCookie := cookieByName(startRec.Result().Cookies(), googleOAuthStateCookie)
if stateCookie == nil || stateCookie.Value == "" {
t.Fatal("missing oauth state cookie")
}
redir := startRec.Result().Header.Get("Location")
parsed, err := url.Parse(redir)
if err != nil {
t.Fatalf("parse redirect url failed: %v", err)
}
if parsed.Query().Get("state") != stateCookie.Value {
t.Fatalf("redirect state mismatch")
}
cbReq := httptest.NewRequest(http.MethodGet, "/api/auth/google/callback?code=ok-code&state="+url.QueryEscape(stateCookie.Value), nil)
cbReq.Host = "file.example.com"
cbReq.AddCookie(stateCookie)
cbRec := httptest.NewRecorder()
s.handleGoogleAuthCallback(cbRec, cbReq)
if cbRec.Code != http.StatusFound {
t.Fatalf("callback status = %d, want %d", cbRec.Code, http.StatusFound)
}
if got := cbRec.Header().Get("Location"); got != "/drive" {
t.Fatalf("callback redirect = %q, want %q", got, "/drive")
}
if cookieByName(cbRec.Result().Cookies(), "access_token") == nil {
t.Fatal("callback missing access_token cookie")
}
if cookieByName(cbRec.Result().Cookies(), "refresh_token") == nil {
t.Fatal("callback missing refresh_token cookie")
}
var count int
var googleSub string
err = s.db.QueryRow(`SELECT COUNT(*), COALESCE(MAX(google_sub), '') FROM users WHERE email = ?`, "alice@example.com").Scan(&count, &googleSub)
if err != nil {
t.Fatalf("query user failed: %v", err)
}
if count != 1 {
t.Fatalf("users with google email = %d, want 1", count)
}
if strings.TrimSpace(googleSub) != "google-sub-1" {
t.Fatalf("google_sub = %q, want %q", googleSub, "google-sub-1")
}
}

122
backend/orm.go Normal file
View File

@@ -0,0 +1,122 @@
package main
import (
"errors"
"strings"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
)
type ormUser struct {
ID int64 `gorm:"column:id;primaryKey;autoIncrement"`
Email string `gorm:"column:email"`
PasswordHash string `gorm:"column:password_hash"`
Theme string `gorm:"column:theme"`
ColorMode string `gorm:"column:color_mode"`
Archive string `gorm:"column:archive_format"`
GoogleSub *string `gorm:"column:google_sub"`
}
func (ormUser) TableName() string {
return "users"
}
type ormRepo struct {
db *gorm.DB
}
func newORMRepo(dbPath string) (*ormRepo, error) {
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
if err != nil {
return nil, err
}
return &ormRepo{db: db}, nil
}
func (o *ormRepo) listUsers() ([]User, error) {
var rows []ormUser
if err := o.db.Order("id asc").Find(&rows).Error; err != nil {
return nil, err
}
users := make([]User, 0, len(rows))
for _, row := range rows {
users = append(users, row.toPublicUser())
}
return users, nil
}
func (o *ormRepo) createUser(email, passwordHash, theme, colorMode, archive string, googleSub *string) (int64, error) {
row := ormUser{
Email: strings.ToLower(strings.TrimSpace(email)),
PasswordHash: passwordHash,
Theme: normalizeTheme(theme),
ColorMode: normalizeColorMode(colorMode),
Archive: archive,
GoogleSub: googleSub,
}
if err := o.db.Create(&row).Error; err != nil {
return 0, err
}
return row.ID, nil
}
func (o *ormRepo) findUserWithHashByEmail(email string) (User, string, error) {
var row ormUser
err := o.db.Where("email = ?", strings.ToLower(strings.TrimSpace(email))).First(&row).Error
if err != nil {
return User{}, "", err
}
user := row.toPublicUser()
return user, row.PasswordHash, nil
}
func (o *ormRepo) findUserByID(id int64) (User, error) {
var row ormUser
err := o.db.Where("id = ?", id).First(&row).Error
if err != nil {
return User{}, err
}
return row.toPublicUser(), nil
}
func (o *ormRepo) findUserByEmail(email string) (User, error) {
var row ormUser
err := o.db.Where("email = ?", strings.ToLower(strings.TrimSpace(email))).First(&row).Error
if err != nil {
return User{}, err
}
return row.toPublicUser(), nil
}
func (o *ormRepo) findUserByGoogleSub(googleSub string) (User, error) {
var row ormUser
err := o.db.Where("google_sub = ?", strings.TrimSpace(googleSub)).First(&row).Error
if err != nil {
return User{}, err
}
return row.toPublicUser(), nil
}
func (o *ormRepo) updateGoogleSub(userID int64, googleSub string) error {
return o.db.Model(&ormUser{}).Where("id = ?", userID).Update("google_sub", strings.TrimSpace(googleSub)).Error
}
func (u ormUser) toPublicUser() User {
archive := normalizeArchiveFormat(u.Archive)
if archive == "" {
archive = "zip"
}
return User{
ID: u.ID,
Username: strings.ToLower(strings.TrimSpace(u.Email)),
Theme: normalizeTheme(u.Theme),
ColorMode: normalizeColorMode(u.ColorMode),
Archive: archive,
}
}
func isORMNotFound(err error) bool {
return errors.Is(err, gorm.ErrRecordNotFound)
}

356
backend/protocol_ftp.go Normal file
View File

@@ -0,0 +1,356 @@
package main
import (
"database/sql"
"errors"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strconv"
"strings"
ftpserver "goftp.io/server/v2"
)
func startProtocolServers(cfg Config, db *sql.DB) error {
if cfg.SFTPEnabled {
return fmt.Errorf("SFTP_ENABLED=true is configured, but SFTP server is not implemented")
}
if cfg.FTPEnabled {
ftpSrv, err := buildFTPServer(cfg, db)
if err != nil {
return fmt.Errorf("build ftp server: %w", err)
}
go runFTPServer("FTP", ftpSrv)
}
if cfg.FTPSEnabled {
ftpsSrv, err := buildFTPSServer(cfg, db)
if err != nil {
return fmt.Errorf("build ftps server: %w", err)
}
go runFTPServer("FTPS", ftpsSrv)
}
return nil
}
func runFTPServer(name string, srv *ftpserver.Server) {
log.Printf("%s listening on %s:%d", name, srv.Hostname, srv.Port)
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, ftpserver.ErrServerClosed) {
log.Printf("%s server stopped: %v", name, err)
}
}
type ftpUserDriver struct {
db *sql.DB
root string
}
func buildFTPServer(cfg Config, db *sql.DB) (*ftpserver.Server, error) {
return buildFileZFTPServer(ftpServerOptions{
name: "FileZ FTP",
host: cfg.FTPHost,
port: cfg.FTPPort,
publicIP: cfg.FTPPublicIP,
passivePorts: cfg.FTPPassivePorts,
tls: false,
}, db, cfg.StorageRoot)
}
func buildFTPSServer(cfg Config, db *sql.DB) (*ftpserver.Server, error) {
return buildFileZFTPServer(ftpServerOptions{
name: "FileZ FTPS",
host: cfg.FTPSHost,
port: cfg.FTPSPort,
publicIP: cfg.FTPSPublicIP,
passivePorts: cfg.FTPSPassivePorts,
tls: true,
certFile: cfg.FTPSCertFile,
keyFile: cfg.FTPSKeyFile,
explicitFTPS: cfg.FTPSExplicit,
forceTLS: cfg.FTPSForceTLS,
}, db, cfg.StorageRoot)
}
type ftpServerOptions struct {
name string
host string
port int
publicIP string
passivePorts string
tls bool
certFile string
keyFile string
explicitFTPS bool
forceTLS bool
}
func buildFileZFTPServer(opts ftpServerOptions, db *sql.DB, storageRoot string) (*ftpserver.Server, error) {
drv := &ftpUserDriver{db: db, root: storageRoot}
serverOpts := &ftpserver.Options{
Name: opts.name,
Hostname: opts.host,
Port: opts.port,
PublicIP: opts.publicIP,
PassivePorts: opts.passivePorts,
WelcomeMessage: opts.name + " ready",
Driver: drv,
Auth: drv,
Perm: ftpserver.NewSimplePerm("filez", "filez"),
TLS: opts.tls,
CertFile: opts.certFile,
KeyFile: opts.keyFile,
ExplicitFTPS: opts.explicitFTPS,
ForceTLS: opts.forceTLS,
}
return ftpserver.NewServer(serverOpts)
}
func (d *ftpUserDriver) CheckPasswd(ctx *ftpserver.Context, username, password string) (bool, error) {
username = strings.TrimSpace(username)
if username == "" {
return false, nil
}
var userID int64
var hash string
err := d.db.QueryRow(`SELECT id, password_hash FROM users WHERE email = ?`, username).Scan(&userID, &hash)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return false, nil
}
return false, err
}
if !verifyPasswordHash(hash, password) {
return false, nil
}
if err := os.MkdirAll(d.userRoot(userID), 0o755); err != nil {
return false, err
}
d.setUserID(ctx, userID)
return true, nil
}
func (d *ftpUserDriver) userRoot(userID int64) string {
return filepath.Join(d.root, strconv.FormatInt(userID, 10))
}
func (d *ftpUserDriver) setUserID(ctx *ftpserver.Context, userID int64) {
if ctx == nil || ctx.Sess == nil {
return
}
if ctx.Sess.Data == nil {
ctx.Sess.Data = make(map[string]interface{})
}
ctx.Sess.Data["filez_user_id"] = userID
}
func (d *ftpUserDriver) userIDFromCtx(ctx *ftpserver.Context) (int64, error) {
if ctx == nil || ctx.Sess == nil {
return 0, fmt.Errorf("missing ftp session")
}
if v, ok := ctx.Sess.Data["filez_user_id"]; ok {
switch id := v.(type) {
case int64:
if id > 0 {
return id, nil
}
case int:
if id > 0 {
return int64(id), nil
}
}
}
username := strings.TrimSpace(ctx.Sess.LoginUser())
if username == "" {
return 0, fmt.Errorf("missing ftp username")
}
var userID int64
if err := d.db.QueryRow(`SELECT id FROM users WHERE email = ?`, username).Scan(&userID); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return 0, fmt.Errorf("user not found")
}
return 0, err
}
d.setUserID(ctx, userID)
return userID, nil
}
func (d *ftpUserDriver) fullPath(ctx *ftpserver.Context, rel string) (string, error) {
userID, err := d.userIDFromCtx(ctx)
if err != nil {
return "", err
}
root := d.userRoot(userID)
if err := os.MkdirAll(root, 0o755); err != nil {
return "", err
}
clean := filepath.FromSlash(strings.TrimPrefix(normalizePath(rel), "/"))
full := filepath.Clean(filepath.Join(root, clean))
if full != root && !strings.HasPrefix(full, root+string(os.PathSeparator)) {
return "", fmt.Errorf("invalid path")
}
return full, nil
}
func (d *ftpUserDriver) Stat(ctx *ftpserver.Context, rel string) (os.FileInfo, error) {
full, err := d.fullPath(ctx, rel)
if err != nil {
return nil, err
}
return os.Stat(full)
}
func (d *ftpUserDriver) ListDir(ctx *ftpserver.Context, rel string, callback func(os.FileInfo) error) error {
full, err := d.fullPath(ctx, rel)
if err != nil {
return err
}
entries, err := os.ReadDir(full)
if err != nil {
return err
}
for _, entry := range entries {
info, err := entry.Info()
if err != nil {
continue
}
if err := callback(info); err != nil {
return err
}
}
return nil
}
func (d *ftpUserDriver) DeleteDir(ctx *ftpserver.Context, rel string) error {
if normalizePath(rel) == "/" {
return fmt.Errorf("cannot remove root directory")
}
full, err := d.fullPath(ctx, rel)
if err != nil {
return err
}
st, err := os.Stat(full)
if err != nil {
return err
}
if !st.IsDir() {
return fmt.Errorf("not a directory")
}
return os.RemoveAll(full)
}
func (d *ftpUserDriver) DeleteFile(ctx *ftpserver.Context, rel string) error {
full, err := d.fullPath(ctx, rel)
if err != nil {
return err
}
st, err := os.Stat(full)
if err != nil {
return err
}
if st.IsDir() {
return fmt.Errorf("not a file")
}
return os.Remove(full)
}
func (d *ftpUserDriver) Rename(ctx *ftpserver.Context, fromPath, toPath string) error {
if normalizePath(fromPath) == "/" {
return fmt.Errorf("cannot rename root directory")
}
oldPath, err := d.fullPath(ctx, fromPath)
if err != nil {
return err
}
newPath, err := d.fullPath(ctx, toPath)
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil {
return err
}
return os.Rename(oldPath, newPath)
}
func (d *ftpUserDriver) MakeDir(ctx *ftpserver.Context, rel string) error {
full, err := d.fullPath(ctx, rel)
if err != nil {
return err
}
return os.MkdirAll(full, 0o755)
}
func (d *ftpUserDriver) GetFile(ctx *ftpserver.Context, rel string, offset int64) (int64, io.ReadCloser, error) {
full, err := d.fullPath(ctx, rel)
if err != nil {
return 0, nil, err
}
f, err := os.Open(full)
if err != nil {
return 0, nil, err
}
defer func() {
if err != nil {
_ = f.Close()
}
}()
st, err := f.Stat()
if err != nil {
return 0, nil, err
}
if offset < 0 {
offset = 0
}
if _, err := f.Seek(offset, io.SeekStart); err != nil {
return 0, nil, err
}
sz := st.Size() - offset
if sz < 0 {
sz = 0
}
return sz, f, nil
}
func (d *ftpUserDriver) PutFile(ctx *ftpserver.Context, rel string, data io.Reader, offset int64) (int64, error) {
full, err := d.fullPath(ctx, rel)
if err != nil {
return 0, err
}
if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil {
return 0, err
}
if offset < 0 {
f, err := os.Create(full)
if err != nil {
return 0, err
}
defer f.Close()
return io.Copy(f, data)
}
flags := os.O_CREATE | os.O_WRONLY
if offset == 0 {
flags |= os.O_TRUNC
}
f, err := os.OpenFile(full, flags, 0o644)
if err != nil {
return 0, err
}
defer f.Close()
if _, err := f.Seek(offset, io.SeekStart); err != nil {
return 0, err
}
return io.Copy(f, data)
}

71
backend/protocol_info.go Normal file
View File

@@ -0,0 +1,71 @@
package main
import (
"net/http"
"strings"
)
type protocolProfile struct {
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
PublicIP string `json:"publicIP,omitempty"`
PassivePorts string `json:"passivePorts,omitempty"`
ExplicitTLS bool `json:"explicitTLS,omitempty"`
ForceTLS bool `json:"forceTLS,omitempty"`
}
type userProtocolsResponse struct {
FTP *protocolProfile `json:"ftp,omitempty"`
FTPS *protocolProfile `json:"ftps,omitempty"`
}
func (s *Server) handleUserProtocols(w http.ResponseWriter, r *http.Request) {
uid := userIDFromContext(r.Context())
user, err := s.findUser(uid)
if err != nil {
writeErr(w, http.StatusNotFound, "user not found")
return
}
out := userProtocolsResponse{}
if s.config.FTPEnabled {
out.FTP = &protocolProfile{
Host: protocolHostForClient(s.config.FTPHost, s.config.FTPPublicIP, s.config.AppDomain, r),
Port: s.config.FTPPort,
Username: user.Username,
PublicIP: strings.TrimSpace(s.config.FTPPublicIP),
PassivePorts: strings.TrimSpace(s.config.FTPPassivePorts),
}
}
if s.config.FTPSEnabled {
out.FTPS = &protocolProfile{
Host: protocolHostForClient(s.config.FTPSHost, s.config.FTPSPublicIP, s.config.AppDomain, r),
Port: s.config.FTPSPort,
Username: user.Username,
PublicIP: strings.TrimSpace(s.config.FTPSPublicIP),
PassivePorts: strings.TrimSpace(s.config.FTPSPassivePorts),
ExplicitTLS: s.config.FTPSExplicit,
ForceTLS: s.config.FTPSForceTLS,
}
}
writeJSON(w, http.StatusOK, out)
}
func protocolHostForClient(bindHost, publicIP, appDomain string, r *http.Request) string {
if v := strings.TrimSpace(publicIP); v != "" {
return v
}
if v := strings.TrimSpace(appDomain); v != "" {
return v
}
if v := strings.TrimSpace(hostOnly(r.Host)); v != "" {
return v
}
v := strings.TrimSpace(bindHost)
if v == "" || v == "0.0.0.0" || v == "::" || v == "[::]" {
return "localhost"
}
return v
}