2821 lines
77 KiB
Go
2821 lines
77 KiB
Go
package main
|
|
|
|
import (
|
|
"archive/tar"
|
|
"archive/zip"
|
|
"compress/gzip"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"crypto/subtle"
|
|
"database/sql"
|
|
"embed"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"log"
|
|
"mime"
|
|
"mime/multipart"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"path"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/gorilla/mux"
|
|
"github.com/hirochachacha/go-smb2"
|
|
"github.com/joho/godotenv"
|
|
"golang.org/x/crypto/argon2"
|
|
"golang.org/x/crypto/bcrypt"
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
// Config holds the process-wide settings assembled by loadConfig from
// environment variables (optionally loaded from a .env file).
type Config struct {
	// Addr is the HTTP listen address (ADDR, default ":8080").
	Addr string
	// DBPath is the SQLite database file (DB_PATH, default "./app.db").
	DBPath string
	// StorageRoot is the root directory used by the local storage backend
	// (STORAGE_ROOT, default "./users").
	StorageRoot string
	// AppDomain is the trimmed, lower-cased public host name (APP_DOMAIN).
	// It is the fallback for AllowedHost and, as "https://"+AppDomain, for
	// CORSOrigin.
	AppDomain string
	// AllowedHost is the lower-cased host name the server accepts
	// (ALLOWED_HOST); presumably enforced by hostGuardMiddleware — confirm.
	AllowedHost string
	// CORSOrigin is the single allowed CORS origin (CORS_ALLOWED_ORIGIN).
	CORSOrigin string
	// CookieSecure sets the Secure attribute on cookies written through
	// setCookie/clearCookie (from COOKIE_SECURE).
	CookieSecure bool
	// MaxBodyBytes is MAX_BODY_MB (default 8) converted to bytes.
	MaxBodyBytes int64

	// Per-minute request budgets (RATE_LIMIT_PER_MIN default 240,
	// AUTH_RATE_LIMIT_PER_MIN default 30); presumably applied by
	// rateLimitMiddleware — confirm.
	RateLimitPerMin     int
	AuthRateLimitPerMin int

	// JWTSecret is the HMAC key used to sign session tokens (JWT_SECRET).
	JWTSecret string
	// AccessTTL, RefreshTTL, ShareDefaultTTL and AdminSessionTTL are fixed
	// lifetimes set in loadConfig (15m, 30d, 24h and 12h respectively).
	AccessTTL       time.Duration
	RefreshTTL      time.Duration
	ShareDefaultTTL time.Duration
	AdminSessionTTL time.Duration
	// AdminLogin and AdminPasswordHash are the single admin account's
	// credentials (ADMIN_LOGIN default "admin"; ADMIN_PASSWORD_HASH is required).
	AdminLogin        string
	AdminPasswordHash string
	// StorageBackend selects "smb" (case-insensitive); any other value means
	// local disk storage under StorageRoot.
	StorageBackend string
	// SMB* configure the SMB backend; SMBHost, SMBShare and SMBUser are
	// required when StorageBackend is "smb".
	SMBHost     string
	SMBShare    string
	SMBUser     string
	SMBPass     string
	SMBDomain   string
	SMBBasePath string
	// SMBConnectTimout (sic) is fixed at 5s by loadConfig. The misspelled
	// name is kept because other code may reference the field.
	SMBConnectTimout time.Duration
}
|
|
|
|
// embeddedWeb holds the bundled front-end build from web/dist; staticHandler
// serves it with fs.Sub(embeddedWeb, "web/dist").
//
//go:embed web/dist
var embeddedWeb embed.FS
|
|
|
|
// Server bundles the dependencies shared by every HTTP handler and middleware.
type Server struct {
	db      *sql.DB      // SQLite handle from openDB (migrated by migrate)
	config  Config       // settings from loadConfig
	storage Storage      // file backend chosen by buildStorage
	limiter *rateLimiter // request counters used by the rate-limit middleware
}
|
|
|
|
// rateLimiter is a fixed-window request counter keyed by an arbitrary string
// (for example a client IP). Each key gets its own one-minute window.
//
// Expired windows are swept out at most once a minute, so memory use is
// bounded by the number of keys active in roughly the last two minutes. It
// no longer grows with every distinct key ever seen. The zero value is
// ready to use.
type rateLimiter struct {
	mu        sync.Mutex
	entries   map[string]*rateEntry
	lastSweep time.Time // when expired entries were last evicted
}

// rateEntry counts the requests a key has made in its current window.
type rateEntry struct {
	Count      int
	WindowEnds time.Time
}

// newRateLimiter returns an empty rateLimiter.
func newRateLimiter() *rateLimiter {
	return &rateLimiter{entries: make(map[string]*rateEntry)}
}

// allow reports whether a request for key at time now fits within limit
// requests per one-minute window, and records the request if it does. A key
// with no window, or whose window has elapsed, starts a fresh window that
// counts this request.
func (rl *rateLimiter) allow(key string, limit int, now time.Time) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	if rl.entries == nil {
		rl.entries = make(map[string]*rateEntry)
	}

	// Periodically drop elapsed windows. Without this, every key ever seen
	// (e.g. every client IP) stays in the map for the life of the process.
	if now.Sub(rl.lastSweep) >= time.Minute {
		for k, e := range rl.entries {
			if now.After(e.WindowEnds) {
				delete(rl.entries, k)
			}
		}
		rl.lastSweep = now
	}

	entry, ok := rl.entries[key]
	if !ok || now.After(entry.WindowEnds) {
		rl.entries[key] = &rateEntry{Count: 1, WindowEnds: now.Add(time.Minute)}
		return true
	}
	if entry.Count >= limit {
		return false
	}
	entry.Count++
	return true
}
|
|
|
|
// statusRecorder wraps an http.ResponseWriter and remembers the status code
// most recently passed to WriteHeader, so middleware can inspect it after
// the handler has run.
type statusRecorder struct {
	http.ResponseWriter
	status int
}

// WriteHeader stores statusCode and then forwards it to the wrapped writer.
func (r *statusRecorder) WriteHeader(statusCode int) {
	inner := r.ResponseWriter
	r.status = statusCode
	inner.WriteHeader(statusCode)
}
|
|
|
|
// key is an unexported context-key type, so values this package stores in a
// context.Context cannot collide with keys defined by other packages.
type key int

// userIDKey is the context key for the authenticated user's ID.
const userIDKey key = 1
|
|
|
|
// AccessClaims is the JWT payload of a user access token (the access_token
// cookie): the user's ID under "uid" plus the standard registered claims.
type AccessClaims struct {
	UserID int64 `json:"uid"`
	jwt.RegisteredClaims
}

// AdminClaims is the JWT payload of the admin session token issued by
// handleAdminLogin: the admin login, a role (set to "admin" there) and the
// standard registered claims.
type AdminClaims struct {
	Login string `json:"login"`
	Role  string `json:"role"`
	jwt.RegisteredClaims
}
|
|
|
|
// User is the account view returned as JSON by handleLogin and handleMe. It
// carries display preferences but no credentials.
type User struct {
	ID int64 `json:"id"`
	// Username is the login name. NOTE(review): the users table has an
	// email column and no username column, so this is presumably read from
	// users.email — confirm in findUser.
	Username  string `json:"username"`
	Theme     string `json:"theme"`
	ColorMode string `json:"colorMode"`
	// Archive is the preferred archive format for folder downloads
	// (users.archive_format).
	Archive string `json:"archiveFormat"`
}

// FileEntry describes one item in a directory listing (see handleListFiles).
// Tags is filled from the file_tags table and omitted from JSON when empty.
type FileEntry struct {
	Name    string    `json:"name"`
	Path    string    `json:"path"`
	IsDir   bool      `json:"isDir"`
	Size    int64     `json:"size"`
	ModTime time.Time `json:"modTime"`
	Tags    []string  `json:"tags,omitempty"`
}

// FileMeta is the metadata returned by Storage.Stat.
type FileMeta struct {
	Name    string
	Size    int64
	ModTime time.Time
	IsDir   bool
}
|
|
|
|
// ReadSeekCloser is a seekable, closable reader. Storage.OpenReadSeeker
// returns one so files can be served with http.ServeContent, which needs
// Seek to answer range requests.
type ReadSeekCloser interface {
	io.ReadSeeker
	io.Closer
}

// Storage abstracts the per-user file store (local disk or SMB; see
// buildStorage). Every method takes the owning user's ID and a path relative
// to that user's space.
//
// NOTE(review): some callers pass the raw query value straight through (for
// example handleListFiles -> List and serveFile -> Stat). Implementations
// therefore must do their own normalization and reject paths that escape the
// user's root — confirm.
type Storage interface {
	// List returns the entries of directory rel.
	List(userID int64, rel string) ([]FileEntry, error)
	// Mkdir creates directory rel. handleBatchMove calls it on destinations
	// that may already exist, so it presumably tolerates existing dirs — confirm.
	Mkdir(userID int64, rel string) error
	// Save writes the contents of src to rel.
	Save(userID int64, rel string, src multipart.File) error
	// SaveBytes writes data to rel.
	SaveBytes(userID int64, rel string, data []byte) error
	// Delete removes rel (handleBatchMove also uses it on directories).
	Delete(userID int64, rel string) error
	// Stat returns metadata for rel.
	Stat(userID int64, rel string) (FileMeta, error)
	// OpenReadSeeker opens rel for reading; the caller must Close it.
	OpenReadSeeker(userID int64, rel string) (ReadSeekCloser, error)
}
|
|
|
|
func main() {
|
|
if maybeRunHashCommand() {
|
|
return
|
|
}
|
|
|
|
cfg := loadConfig()
|
|
db, err := openDB(cfg.DBPath)
|
|
if err != nil {
|
|
log.Fatalf("db open failed: %v", err)
|
|
}
|
|
|
|
if err := migrate(db); err != nil {
|
|
log.Fatalf("migrate failed: %v", err)
|
|
}
|
|
|
|
storage, err := buildStorage(cfg)
|
|
if err != nil {
|
|
log.Fatalf("storage init failed: %v", err)
|
|
}
|
|
|
|
s := &Server{db: db, config: cfg, storage: storage, limiter: newRateLimiter()}
|
|
r := mux.NewRouter()
|
|
r.Use(s.recoverMiddleware)
|
|
r.Use(s.securityHeadersMiddleware)
|
|
r.Use(s.hostGuardMiddleware)
|
|
r.Use(s.corsMiddleware)
|
|
r.Use(s.bodyLimitMiddleware)
|
|
r.Use(s.rateLimitMiddleware)
|
|
r.Use(s.requestLogMiddleware)
|
|
|
|
r.HandleFunc("/api/health", s.handleHealth).Methods(http.MethodGet)
|
|
|
|
r.HandleFunc("/api/auth/register", s.handleRegisterDisabled).Methods(http.MethodPost)
|
|
r.HandleFunc("/api/auth/login", s.handleLogin).Methods(http.MethodPost)
|
|
r.HandleFunc("/api/auth/refresh", s.handleRefresh).Methods(http.MethodPost)
|
|
r.HandleFunc("/api/auth/logout", s.handleLogout).Methods(http.MethodPost)
|
|
|
|
r.HandleFunc("/api/admin/login", s.handleAdminLogin).Methods(http.MethodPost)
|
|
r.HandleFunc("/api/admin/logout", s.handleAdminLogout).Methods(http.MethodPost)
|
|
|
|
r.HandleFunc("/api/share/{token}", s.handleSharedDownload).Methods(http.MethodGet, http.MethodHead)
|
|
|
|
protected := r.PathPrefix("/api").Subrouter()
|
|
protected.Use(s.authMiddleware)
|
|
protected.HandleFunc("/auth/me", s.handleMe).Methods(http.MethodGet)
|
|
protected.HandleFunc("/user/preferences", s.handleSetPreferences).Methods(http.MethodPost)
|
|
protected.HandleFunc("/files", s.handleListFiles).Methods(http.MethodGet)
|
|
protected.HandleFunc("/files/upload", s.handleUpload).Methods(http.MethodPost)
|
|
protected.HandleFunc("/files/download", s.handleDownload).Methods(http.MethodGet, http.MethodHead)
|
|
protected.HandleFunc("/files/download-batch", s.handleBatchDownload).Methods(http.MethodPost)
|
|
protected.HandleFunc("/files/move-batch", s.handleBatchMove).Methods(http.MethodPost)
|
|
protected.HandleFunc("/files/preview", s.handlePreview).Methods(http.MethodGet, http.MethodHead)
|
|
protected.HandleFunc("/files/text", s.handleReadTextFile).Methods(http.MethodGet)
|
|
protected.HandleFunc("/files/text", s.handleWriteTextFile).Methods(http.MethodPut)
|
|
protected.HandleFunc("/files", s.handleDelete).Methods(http.MethodDelete)
|
|
protected.HandleFunc("/files/folder", s.handleCreateFolder).Methods(http.MethodPost)
|
|
protected.HandleFunc("/files/share", s.handleCreateShareLink).Methods(http.MethodPost)
|
|
protected.HandleFunc("/files/tags", s.handleListFileTags).Methods(http.MethodGet)
|
|
protected.HandleFunc("/files/tags", s.handleAddFileTag).Methods(http.MethodPost)
|
|
protected.HandleFunc("/files/tags", s.handleDeleteFileTag).Methods(http.MethodDelete)
|
|
|
|
admin := r.PathPrefix("/api/admin").Subrouter()
|
|
admin.Use(s.adminMiddleware)
|
|
admin.HandleFunc("/me", s.handleAdminMe).Methods(http.MethodGet)
|
|
admin.HandleFunc("/users", s.handleAdminUsersList).Methods(http.MethodGet)
|
|
admin.HandleFunc("/users", s.handleAdminUserCreate).Methods(http.MethodPost)
|
|
admin.HandleFunc("/users/{id}", s.handleAdminUserDelete).Methods(http.MethodDelete)
|
|
|
|
r.PathPrefix("/").Handler(s.staticHandler())
|
|
|
|
server := &http.Server{
|
|
Addr: cfg.Addr,
|
|
Handler: r,
|
|
ReadHeaderTimeout: 10 * time.Second,
|
|
}
|
|
|
|
log.Printf("listening on %s", cfg.Addr)
|
|
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
log.Fatalf("server failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func loadConfig() Config {
|
|
_ = godotenv.Load(".env", "../.env")
|
|
|
|
dbPath := getEnv("DB_PATH", "./app.db")
|
|
storageRoot := getEnv("STORAGE_ROOT", "./users")
|
|
if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
|
|
log.Fatalf("failed to create db path parent: %v", err)
|
|
}
|
|
if err := os.MkdirAll(storageRoot, 0o755); err != nil {
|
|
log.Fatalf("failed to create storage root: %v", err)
|
|
}
|
|
|
|
cfg := Config{
|
|
Addr: getEnv("ADDR", ":8080"),
|
|
DBPath: dbPath,
|
|
StorageRoot: storageRoot,
|
|
AppDomain: strings.ToLower(strings.TrimSpace(getEnv("APP_DOMAIN", "file.example.com"))),
|
|
AllowedHost: strings.ToLower(getEnv("ALLOWED_HOST", "")),
|
|
CORSOrigin: getEnv("CORS_ALLOWED_ORIGIN", ""),
|
|
CookieSecure: getEnv("COOKIE_SECURE", "false") == "true",
|
|
MaxBodyBytes: int64(getEnvInt("MAX_BODY_MB", 8)) * 1024 * 1024,
|
|
RateLimitPerMin: getEnvInt("RATE_LIMIT_PER_MIN", 240),
|
|
AuthRateLimitPerMin: getEnvInt("AUTH_RATE_LIMIT_PER_MIN", 30),
|
|
JWTSecret: getEnv("JWT_SECRET", "dev-change-me-immediately"),
|
|
AccessTTL: 15 * time.Minute,
|
|
RefreshTTL: 30 * 24 * time.Hour,
|
|
ShareDefaultTTL: 24 * time.Hour,
|
|
AdminSessionTTL: 12 * time.Hour,
|
|
AdminLogin: getEnv("ADMIN_LOGIN", "admin"),
|
|
AdminPasswordHash: getEnv("ADMIN_PASSWORD_HASH", ""),
|
|
StorageBackend: getEnv("STORAGE_BACKEND", "local"),
|
|
SMBHost: getEnv("SMB_HOST", ""),
|
|
SMBShare: getEnv("SMB_SHARE", ""),
|
|
SMBUser: getEnv("SMB_USER", ""),
|
|
SMBPass: getEnv("SMB_PASS", ""),
|
|
SMBDomain: getEnv("SMB_DOMAIN", ""),
|
|
SMBBasePath: getEnv("SMB_BASE_PATH", "driveflow"),
|
|
SMBConnectTimout: 5 * time.Second,
|
|
}
|
|
|
|
if cfg.AllowedHost == "" {
|
|
cfg.AllowedHost = cfg.AppDomain
|
|
}
|
|
if cfg.CORSOrigin == "" {
|
|
cfg.CORSOrigin = "https://" + cfg.AppDomain
|
|
}
|
|
if cfg.JWTSecret == "dev-change-me-immediately" {
|
|
log.Println("warning: JWT_SECRET is using default development value")
|
|
}
|
|
if strings.TrimSpace(cfg.AdminPasswordHash) == "" {
|
|
log.Fatal("ADMIN_PASSWORD_HASH is required. Generate one with: go run . hash-admin <password>")
|
|
}
|
|
|
|
return cfg
|
|
}
|
|
|
|
// openDB opens the SQLite database at path and verifies it is reachable.
//
// If the connectivity check fails, the handle is closed before returning.
// That stops the pool's background goroutine and releases driver resources
// instead of leaking them. The returned error names the path.
func openDB(path string) (*sql.DB, error) {
	db, err := sql.Open("sqlite", path)
	if err != nil {
		return nil, fmt.Errorf("open sqlite %q: %w", path, err)
	}
	if err := db.Ping(); err != nil {
		_ = db.Close() // the ping failure is the error worth reporting
		return nil, fmt.Errorf("ping sqlite %q: %w", path, err)
	}
	return db, nil
}
|
|
|
|
// migrate brings the schema up to date. It creates any missing tables and
// indexes, then adds the users columns that were introduced after the first
// schema version. Every step is idempotent, so it runs on each startup.
func migrate(db *sql.DB) error {
	schema := []string{
		`CREATE TABLE IF NOT EXISTS users (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			email TEXT NOT NULL UNIQUE,
			password_hash TEXT NOT NULL,
			theme TEXT NOT NULL DEFAULT 'dracula',
			color_mode TEXT NOT NULL DEFAULT 'auto',
			archive_format TEXT NOT NULL DEFAULT 'zip',
			created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
		);`,
		`CREATE TABLE IF NOT EXISTS refresh_tokens (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			user_id INTEGER NOT NULL,
			token_hash TEXT NOT NULL,
			expires_at TIMESTAMP NOT NULL,
			revoked_at TIMESTAMP,
			created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
			user_agent TEXT,
			ip TEXT,
			FOREIGN KEY(user_id) REFERENCES users(id)
		);`,
		`CREATE INDEX IF NOT EXISTS idx_refresh_tokens_hash ON refresh_tokens(token_hash);`,
		`CREATE TABLE IF NOT EXISTS share_links (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			user_id INTEGER NOT NULL,
			rel_path TEXT NOT NULL,
			token_hash TEXT NOT NULL UNIQUE,
			expires_at TIMESTAMP NOT NULL,
			max_downloads INTEGER,
			download_count INTEGER NOT NULL DEFAULT 0,
			revoked_at TIMESTAMP,
			created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
			FOREIGN KEY(user_id) REFERENCES users(id)
		);`,
		`CREATE INDEX IF NOT EXISTS idx_share_links_hash ON share_links(token_hash);`,
		`CREATE TABLE IF NOT EXISTS file_tags (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			user_id INTEGER NOT NULL,
			rel_path TEXT NOT NULL,
			tag TEXT NOT NULL,
			created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
			UNIQUE(user_id, rel_path, tag),
			FOREIGN KEY(user_id) REFERENCES users(id)
		);`,
		`CREATE INDEX IF NOT EXISTS idx_file_tags_user_path ON file_tags(user_id, rel_path);`,
		`CREATE INDEX IF NOT EXISTS idx_file_tags_user_tag ON file_tags(user_id, tag);`,
	}
	for i := range schema {
		if _, err := db.Exec(schema[i]); err != nil {
			return err
		}
	}

	// Columns added after the initial users table. SQLite has no
	// "ADD COLUMN IF NOT EXISTS", so a "duplicate column" error just means
	// the column is already there.
	lateColumns := []string{
		`ALTER TABLE users ADD COLUMN color_mode TEXT NOT NULL DEFAULT 'auto'`,
		`ALTER TABLE users ADD COLUMN archive_format TEXT NOT NULL DEFAULT 'zip'`,
	}
	for _, ddl := range lateColumns {
		_, err := db.Exec(ddl)
		if err == nil {
			continue
		}
		if strings.Contains(strings.ToLower(err.Error()), "duplicate column") {
			continue
		}
		return err
	}

	return nil
}
|
|
|
|
func buildStorage(cfg Config) (Storage, error) {
|
|
if strings.EqualFold(cfg.StorageBackend, "smb") {
|
|
if cfg.SMBHost == "" || cfg.SMBShare == "" || cfg.SMBUser == "" {
|
|
return nil, fmt.Errorf("SMB_HOST, SMB_SHARE, SMB_USER must be set for smb backend")
|
|
}
|
|
return &SMBStorage{cfg: cfg}, nil
|
|
}
|
|
|
|
root := cfg.StorageRoot
|
|
if err := os.MkdirAll(root, 0o755); err != nil {
|
|
return nil, err
|
|
}
|
|
return &LocalStorage{root: root}, nil
|
|
}
|
|
|
|
func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) {
|
|
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
|
}
|
|
|
|
func (s *Server) staticHandler() http.Handler {
|
|
webRoot, err := fs.Sub(embeddedWeb, "web/dist")
|
|
if err != nil {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
http.Error(w, "web assets unavailable", http.StatusServiceUnavailable)
|
|
})
|
|
}
|
|
|
|
fileServer := http.FileServer(http.FS(webRoot))
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if strings.HasPrefix(r.URL.Path, "/api/") {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
|
|
requestPath := strings.TrimPrefix(path.Clean(r.URL.Path), "/")
|
|
if requestPath == "." || requestPath == "" {
|
|
requestPath = "index.html"
|
|
}
|
|
|
|
if _, statErr := fs.Stat(webRoot, requestPath); statErr == nil {
|
|
fileServer.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
index, readErr := fs.ReadFile(webRoot, "index.html")
|
|
if readErr != nil {
|
|
http.Error(w, "web app is not bundled", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write(index)
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleRegisterDisabled(w http.ResponseWriter, _ *http.Request) {
|
|
writeErr(w, http.StatusForbidden, "public registration is disabled; ask an administrator")
|
|
}
|
|
|
|
// authInput is the JSON body accepted by handleLogin. The values are passed
// to findUserWithHash and verifyPasswordHash as received (not trimmed here).
type authInput struct {
	Username string `json:"username"`
	Password string `json:"password"`
}
|
|
|
|
func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
|
|
var in authInput
|
|
if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
|
|
writeErr(w, http.StatusBadRequest, "invalid payload")
|
|
return
|
|
}
|
|
|
|
user, hash, err := s.findUserWithHash(in.Username)
|
|
if err != nil {
|
|
log.Printf("auth.login.failed username=%q ip=%q reason=%q", in.Username, clientIP(r), "user_not_found")
|
|
writeErr(w, http.StatusUnauthorized, "invalid credentials")
|
|
return
|
|
}
|
|
|
|
if !verifyPasswordHash(hash, in.Password) {
|
|
log.Printf("auth.login.failed username=%q ip=%q reason=%q", in.Username, clientIP(r), "invalid_password")
|
|
writeErr(w, http.StatusUnauthorized, "invalid credentials")
|
|
return
|
|
}
|
|
|
|
if err := s.issueUserSession(w, r, user.ID); err != nil {
|
|
log.Printf("auth.login.failed username=%q user_id=%d ip=%q reason=%q", user.Username, user.ID, clientIP(r), "session_issue_failed")
|
|
writeErr(w, http.StatusInternalServerError, "failed to create session")
|
|
return
|
|
}
|
|
|
|
log.Printf("auth.login.success user_id=%d username=%q ip=%q", user.ID, user.Username, clientIP(r))
|
|
|
|
writeJSON(w, http.StatusOK, user)
|
|
}
|
|
|
|
func (s *Server) handleRefresh(w http.ResponseWriter, r *http.Request) {
|
|
rt, err := r.Cookie("refresh_token")
|
|
if err != nil || rt.Value == "" {
|
|
writeErr(w, http.StatusUnauthorized, "missing refresh token")
|
|
return
|
|
}
|
|
|
|
uid, err := s.consumeRefreshToken(rt.Value)
|
|
if err != nil {
|
|
writeErr(w, http.StatusUnauthorized, "invalid refresh token")
|
|
return
|
|
}
|
|
|
|
if err := s.issueUserSession(w, r, uid); err != nil {
|
|
writeErr(w, http.StatusInternalServerError, "failed to refresh session")
|
|
return
|
|
}
|
|
|
|
writeJSON(w, http.StatusOK, map[string]string{"status": "refreshed"})
|
|
}
|
|
|
|
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
|
|
if cookie, err := r.Cookie("access_token"); err == nil && cookie.Value != "" {
|
|
claims := &AccessClaims{}
|
|
if tkn, parseErr := jwt.ParseWithClaims(cookie.Value, claims, func(token *jwt.Token) (any, error) {
|
|
return []byte(s.config.JWTSecret), nil
|
|
}); parseErr == nil && tkn.Valid {
|
|
log.Printf("auth.logout user_id=%d ip=%q", claims.UserID, clientIP(r))
|
|
}
|
|
}
|
|
|
|
rt, _ := r.Cookie("refresh_token")
|
|
if rt != nil && rt.Value != "" {
|
|
_, _ = s.db.Exec(`UPDATE refresh_tokens SET revoked_at = CURRENT_TIMESTAMP WHERE token_hash = ?`, hashToken(rt.Value))
|
|
}
|
|
|
|
clearCookie(w, "access_token", s.config.CookieSecure)
|
|
clearCookie(w, "refresh_token", s.config.CookieSecure)
|
|
writeJSON(w, http.StatusOK, map[string]string{"status": "logged_out"})
|
|
}
|
|
|
|
// adminLoginInput is the JSON body accepted by handleAdminLogin. Login is
// whitespace-trimmed before it is compared with the configured admin login.
type adminLoginInput struct {
	Login    string `json:"login"`
	Password string `json:"password"`
}
|
|
|
|
func (s *Server) handleAdminLogin(w http.ResponseWriter, r *http.Request) {
|
|
var in adminLoginInput
|
|
if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
|
|
writeErr(w, http.StatusBadRequest, "invalid payload")
|
|
return
|
|
}
|
|
|
|
if subtleConstantTimeEq(strings.TrimSpace(in.Login), s.config.AdminLogin) == 0 || !verifyAdminPasswordHash(s.config.AdminPasswordHash, in.Password) {
|
|
log.Printf("auth.admin_login.failed login=%q ip=%q", strings.TrimSpace(in.Login), clientIP(r))
|
|
writeErr(w, http.StatusUnauthorized, "invalid admin credentials")
|
|
return
|
|
}
|
|
|
|
now := time.Now()
|
|
claims := AdminClaims{
|
|
Login: s.config.AdminLogin,
|
|
Role: "admin",
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(s.config.AdminSessionTTL)),
|
|
Subject: s.config.AdminLogin,
|
|
},
|
|
}
|
|
|
|
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(s.config.JWTSecret))
|
|
if err != nil {
|
|
writeErr(w, http.StatusInternalServerError, "failed to issue admin session")
|
|
return
|
|
}
|
|
|
|
setCookie(w, "admin_token", token, int(s.config.AdminSessionTTL.Seconds()), s.config.CookieSecure)
|
|
log.Printf("auth.admin_login.success login=%q ip=%q", s.config.AdminLogin, clientIP(r))
|
|
writeJSON(w, http.StatusOK, map[string]string{"login": s.config.AdminLogin})
|
|
}
|
|
|
|
func (s *Server) handleAdminLogout(w http.ResponseWriter, _ *http.Request) {
|
|
log.Printf("auth.admin_logout")
|
|
clearCookie(w, "admin_token", s.config.CookieSecure)
|
|
writeJSON(w, http.StatusOK, map[string]string{"status": "logged_out"})
|
|
}
|
|
|
|
func (s *Server) handleMe(w http.ResponseWriter, r *http.Request) {
|
|
uid := userIDFromContext(r.Context())
|
|
user, err := s.findUser(uid)
|
|
if err != nil {
|
|
writeErr(w, http.StatusNotFound, "user not found")
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusOK, user)
|
|
}
|
|
|
|
// userPrefInput is the JSON body for handleSetPreferences. Each field is
// normalized before storage, and an unrecognized archive format becomes "zip".
type userPrefInput struct {
	Theme      string `json:"theme"`
	ColorMode  string `json:"colorMode"`
	ArchiveFmt string `json:"archiveFormat"`
}
|
|
|
|
func (s *Server) handleSetPreferences(w http.ResponseWriter, r *http.Request) {
|
|
uid := userIDFromContext(r.Context())
|
|
var in userPrefInput
|
|
if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
|
|
writeErr(w, http.StatusBadRequest, "invalid payload")
|
|
return
|
|
}
|
|
|
|
in.Theme = normalizeTheme(in.Theme)
|
|
in.ColorMode = normalizeColorMode(in.ColorMode)
|
|
in.ArchiveFmt = normalizeArchiveFormat(in.ArchiveFmt)
|
|
if in.ArchiveFmt == "" {
|
|
in.ArchiveFmt = "zip"
|
|
}
|
|
|
|
if _, err := s.db.Exec(`UPDATE users SET theme = ?, color_mode = ?, archive_format = ? WHERE id = ?`, in.Theme, in.ColorMode, in.ArchiveFmt, uid); err != nil {
|
|
writeErr(w, http.StatusInternalServerError, "failed to update preferences")
|
|
return
|
|
}
|
|
|
|
writeJSON(w, http.StatusOK, map[string]string{"theme": in.Theme, "colorMode": in.ColorMode, "archiveFormat": in.ArchiveFmt})
|
|
}
|
|
|
|
func (s *Server) handleListFiles(w http.ResponseWriter, r *http.Request) {
|
|
uid := userIDFromContext(r.Context())
|
|
rel := r.URL.Query().Get("path")
|
|
log.Printf("file.list user_id=%d path=%q", uid, normalizePath(rel))
|
|
entries, err := s.storage.List(uid, rel)
|
|
if err != nil {
|
|
writeErr(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
if len(entries) > 0 {
|
|
paths := make([]string, 0, len(entries))
|
|
for _, e := range entries {
|
|
paths = append(paths, normalizePath(e.Path))
|
|
}
|
|
tagsByPath, tagErr := s.fileTagsForPaths(uid, paths)
|
|
if tagErr == nil {
|
|
for i := range entries {
|
|
entries[i].Tags = tagsByPath[normalizePath(entries[i].Path)]
|
|
}
|
|
}
|
|
}
|
|
writeJSON(w, http.StatusOK, map[string]any{"path": normalizePath(rel), "entries": entries})
|
|
}
|
|
|
|
func (s *Server) handleUpload(w http.ResponseWriter, r *http.Request) {
|
|
uid := userIDFromContext(r.Context())
|
|
relDir := r.URL.Query().Get("path")
|
|
|
|
if err := r.ParseMultipartForm(256 << 20); err != nil {
|
|
writeErr(w, http.StatusBadRequest, "failed to parse upload")
|
|
return
|
|
}
|
|
|
|
files := r.MultipartForm.File["file"]
|
|
if len(files) == 0 {
|
|
writeErr(w, http.StatusBadRequest, "no files provided")
|
|
return
|
|
}
|
|
|
|
for _, fh := range files {
|
|
log.Printf("file.upload user_id=%d dir=%q name=%q size=%d", uid, normalizePath(relDir), fh.Filename, fh.Size)
|
|
src, err := fh.Open()
|
|
if err != nil {
|
|
writeErr(w, http.StatusBadRequest, fmt.Sprintf("cannot open %s", fh.Filename))
|
|
return
|
|
}
|
|
relName, err := normalizeUploadRelativePath(fh.Filename)
|
|
if err != nil {
|
|
_ = src.Close()
|
|
writeErr(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
target := path.Join(normalizePath(relDir), relName)
|
|
err = s.storage.Save(uid, target, src)
|
|
_ = src.Close()
|
|
if err != nil {
|
|
writeErr(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
}
|
|
|
|
writeJSON(w, http.StatusCreated, map[string]string{"status": "uploaded"})
|
|
}
|
|
|
|
func (s *Server) handleDownload(w http.ResponseWriter, r *http.Request) {
|
|
uid := userIDFromContext(r.Context())
|
|
rel := r.URL.Query().Get("path")
|
|
log.Printf("file.download user_id=%d path=%q", uid, normalizePath(rel))
|
|
if err := s.serveFile(w, r, uid, rel, false, r.URL.Query().Get("archive")); err != nil {
|
|
writeErr(w, http.StatusBadRequest, err.Error())
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleBatchDownload(w http.ResponseWriter, r *http.Request) {
|
|
uid := userIDFromContext(r.Context())
|
|
var in struct {
|
|
Paths []string `json:"paths"`
|
|
Archive string `json:"archive"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
|
|
writeErr(w, http.StatusBadRequest, "invalid payload")
|
|
return
|
|
}
|
|
if len(in.Paths) == 0 {
|
|
writeErr(w, http.StatusBadRequest, "no paths selected")
|
|
return
|
|
}
|
|
if len(in.Paths) > 200 {
|
|
writeErr(w, http.StatusBadRequest, "too many selected paths")
|
|
return
|
|
}
|
|
|
|
format, err := s.resolveArchiveFormat(uid, in.Archive)
|
|
if err != nil {
|
|
writeErr(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
archivePath, name, ctype, err := s.createBatchArchiveTemp(uid, in.Paths, format)
|
|
if err != nil {
|
|
writeErr(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
defer os.Remove(archivePath)
|
|
|
|
f, err := os.Open(archivePath)
|
|
if err != nil {
|
|
writeErr(w, http.StatusInternalServerError, "failed to open archive")
|
|
return
|
|
}
|
|
defer f.Close()
|
|
st, err := f.Stat()
|
|
if err != nil {
|
|
writeErr(w, http.StatusInternalServerError, "failed to stat archive")
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Accept-Ranges", "bytes")
|
|
w.Header().Set("Content-Type", ctype)
|
|
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name))
|
|
http.ServeContent(w, r, name, st.ModTime(), f)
|
|
}
|
|
|
|
func (s *Server) handleBatchMove(w http.ResponseWriter, r *http.Request) {
|
|
uid := userIDFromContext(r.Context())
|
|
var in struct {
|
|
Paths []string `json:"paths"`
|
|
Destination string `json:"destination"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
|
|
writeErr(w, http.StatusBadRequest, "invalid payload")
|
|
return
|
|
}
|
|
if len(in.Paths) == 0 {
|
|
writeErr(w, http.StatusBadRequest, "no paths selected")
|
|
return
|
|
}
|
|
destDir := normalizePath(in.Destination)
|
|
if err := s.storage.Mkdir(uid, destDir); err != nil {
|
|
writeErr(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
moved := 0
|
|
for _, p := range in.Paths {
|
|
src := normalizePath(p)
|
|
if src == "/" {
|
|
continue
|
|
}
|
|
meta, err := s.storage.Stat(uid, src)
|
|
if err != nil {
|
|
writeErr(w, http.StatusBadRequest, "invalid path")
|
|
return
|
|
}
|
|
if meta.IsDir {
|
|
srcPrefix := strings.TrimSuffix(src, "/") + "/"
|
|
if destDir == src || strings.HasPrefix(destDir, srcPrefix) {
|
|
writeErr(w, http.StatusBadRequest, "cannot move a folder into itself")
|
|
return
|
|
}
|
|
}
|
|
base := path.Base(src)
|
|
dst := path.Join(destDir, base)
|
|
dst, err = s.nextMoveTarget(uid, dst)
|
|
if err != nil {
|
|
writeErr(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
if err := s.copyPath(uid, src, dst); err != nil {
|
|
writeErr(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
if err := s.storage.Delete(uid, src); err != nil {
|
|
writeErr(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
s.moveTags(uid, src, dst)
|
|
moved++
|
|
}
|
|
|
|
writeJSON(w, http.StatusOK, map[string]any{"status": "moved", "count": moved})
|
|
}
|
|
|
|
func (s *Server) handlePreview(w http.ResponseWriter, r *http.Request) {
|
|
uid := userIDFromContext(r.Context())
|
|
rel := r.URL.Query().Get("path")
|
|
log.Printf("file.preview user_id=%d path=%q", uid, normalizePath(rel))
|
|
if err := s.serveFile(w, r, uid, rel, true, ""); err != nil {
|
|
writeErr(w, http.StatusBadRequest, err.Error())
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleReadTextFile(w http.ResponseWriter, r *http.Request) {
|
|
uid := userIDFromContext(r.Context())
|
|
rel := normalizePath(r.URL.Query().Get("path"))
|
|
if rel == "/" {
|
|
writeErr(w, http.StatusBadRequest, "path is required")
|
|
return
|
|
}
|
|
if !isMarkdownExtension(path.Ext(rel)) {
|
|
writeErr(w, http.StatusBadRequest, "only markdown files are editable")
|
|
return
|
|
}
|
|
meta, err := s.storage.Stat(uid, rel)
|
|
if err != nil || meta.IsDir {
|
|
writeErr(w, http.StatusBadRequest, "file not found")
|
|
return
|
|
}
|
|
rc, err := s.storage.OpenReadSeeker(uid, rel)
|
|
if err != nil {
|
|
writeErr(w, http.StatusBadRequest, "file not found")
|
|
return
|
|
}
|
|
defer rc.Close()
|
|
|
|
const maxTextBytes = 2 << 20
|
|
data, err := io.ReadAll(io.LimitReader(rc, maxTextBytes+1))
|
|
if err != nil {
|
|
writeErr(w, http.StatusInternalServerError, "failed to read file")
|
|
return
|
|
}
|
|
if len(data) > maxTextBytes {
|
|
writeErr(w, http.StatusBadRequest, "markdown file is too large (max 2MB)")
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusOK, map[string]any{"path": rel, "content": string(data), "size": len(data)})
|
|
}
|
|
|
|
func (s *Server) handleWriteTextFile(w http.ResponseWriter, r *http.Request) {
|
|
uid := userIDFromContext(r.Context())
|
|
var in struct {
|
|
Path string `json:"path"`
|
|
Content string `json:"content"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
|
|
writeErr(w, http.StatusBadRequest, "invalid payload")
|
|
return
|
|
}
|
|
rel := normalizePath(in.Path)
|
|
if rel == "/" {
|
|
writeErr(w, http.StatusBadRequest, "path is required")
|
|
return
|
|
}
|
|
if !isMarkdownExtension(path.Ext(rel)) {
|
|
writeErr(w, http.StatusBadRequest, "only markdown files are editable")
|
|
return
|
|
}
|
|
if len(in.Content) > 2<<20 {
|
|
writeErr(w, http.StatusBadRequest, "markdown file is too large (max 2MB)")
|
|
return
|
|
}
|
|
if err := s.storage.SaveBytes(uid, rel, []byte(in.Content)); err != nil {
|
|
writeErr(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusOK, map[string]any{"path": rel, "status": "saved"})
|
|
}
|
|
|
|
func (s *Server) handleListFileTags(w http.ResponseWriter, r *http.Request) {
|
|
uid := userIDFromContext(r.Context())
|
|
rel := normalizePath(r.URL.Query().Get("path"))
|
|
rows, err := s.db.Query(`SELECT tag FROM file_tags WHERE user_id = ? AND rel_path = ? ORDER BY tag ASC`, uid, rel)
|
|
if err != nil {
|
|
writeErr(w, http.StatusInternalServerError, "failed to load tags")
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
tags := make([]string, 0)
|
|
for rows.Next() {
|
|
var tag string
|
|
if rows.Scan(&tag) == nil {
|
|
tags = append(tags, tag)
|
|
}
|
|
}
|
|
writeJSON(w, http.StatusOK, map[string]any{"path": rel, "tags": tags})
|
|
}
|
|
|
|
func (s *Server) handleAddFileTag(w http.ResponseWriter, r *http.Request) {
|
|
uid := userIDFromContext(r.Context())
|
|
var in struct {
|
|
Path string `json:"path"`
|
|
Tag string `json:"tag"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
|
|
writeErr(w, http.StatusBadRequest, "invalid payload")
|
|
return
|
|
}
|
|
rel := normalizePath(in.Path)
|
|
tag, ok := normalizeTag(in.Tag)
|
|
if !ok {
|
|
writeErr(w, http.StatusBadRequest, "tag must be 1-24 chars: a-z 0-9 dash underscore")
|
|
return
|
|
}
|
|
if _, err := s.storage.Stat(uid, rel); err != nil {
|
|
writeErr(w, http.StatusBadRequest, "file not found")
|
|
return
|
|
}
|
|
if _, err := s.db.Exec(`INSERT OR IGNORE INTO file_tags(user_id, rel_path, tag) VALUES (?, ?, ?)`, uid, rel, tag); err != nil {
|
|
writeErr(w, http.StatusInternalServerError, "failed to save tag")
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusCreated, map[string]any{"path": rel, "tag": tag})
|
|
}
|
|
|
|
func (s *Server) handleDeleteFileTag(w http.ResponseWriter, r *http.Request) {
|
|
uid := userIDFromContext(r.Context())
|
|
rel := normalizePath(r.URL.Query().Get("path"))
|
|
tag, ok := normalizeTag(r.URL.Query().Get("tag"))
|
|
if !ok {
|
|
writeErr(w, http.StatusBadRequest, "invalid tag")
|
|
return
|
|
}
|
|
if _, err := s.db.Exec(`DELETE FROM file_tags WHERE user_id = ? AND rel_path = ? AND tag = ?`, uid, rel, tag); err != nil {
|
|
writeErr(w, http.StatusInternalServerError, "failed to delete tag")
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusOK, map[string]any{"path": rel, "tag": tag, "status": "deleted"})
|
|
}
|
|
|
|
func (s *Server) serveFile(w http.ResponseWriter, r *http.Request, uid int64, rel string, inline bool, archiveQuery string) error {
|
|
meta, err := s.storage.Stat(uid, rel)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if meta.IsDir {
|
|
if inline {
|
|
return fmt.Errorf("cannot preview a directory")
|
|
}
|
|
format, err := s.resolveArchiveFormat(uid, archiveQuery)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.serveDirectoryArchive(w, r, uid, rel, format)
|
|
}
|
|
|
|
rc, err := s.storage.OpenReadSeeker(uid, rel)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rc.Close()
|
|
|
|
name := path.Base(normalizePath(rel))
|
|
if name == "." || name == "/" || name == "" {
|
|
name = "download"
|
|
}
|
|
ctype := mime.TypeByExtension(strings.ToLower(filepath.Ext(name)))
|
|
if ctype == "" {
|
|
ctype = "application/octet-stream"
|
|
}
|
|
dispositionType := "attachment"
|
|
if inline {
|
|
dispositionType = "inline"
|
|
}
|
|
|
|
w.Header().Set("Accept-Ranges", "bytes")
|
|
w.Header().Set("Content-Type", ctype)
|
|
w.Header().Set("Content-Disposition", fmt.Sprintf("%s; filename=%q", dispositionType, name))
|
|
http.ServeContent(w, r, name, meta.ModTime, rc)
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) resolveArchiveFormat(uid int64, archiveQuery string) (string, error) {
|
|
requested := normalizeArchiveFormat(archiveQuery)
|
|
if requested != "" {
|
|
return requested, nil
|
|
}
|
|
var fmtPref string
|
|
err := s.db.QueryRow(`SELECT archive_format FROM users WHERE id = ?`, uid).Scan(&fmtPref)
|
|
if err != nil {
|
|
return "zip", nil
|
|
}
|
|
fmtPref = normalizeArchiveFormat(fmtPref)
|
|
if fmtPref == "" {
|
|
return "zip", nil
|
|
}
|
|
return fmtPref, nil
|
|
}
|
|
|
|
// serveDirectoryArchive builds an archive of directory rel in the given format,
// serves it as an attachment, and deletes the temporary archive afterwards.
//
// The archive's top-level folder (and download name stem) is the directory's
// base name, or "folder" for the root.
func (s *Server) serveDirectoryArchive(w http.ResponseWriter, r *http.Request, uid int64, rel, format string) error {
	base := path.Base(normalizePath(rel))
	if base == "/" || base == "." || base == "" {
		base = "folder"
	}

	archivePath, downloadName, ctype, err := s.createArchiveTemp(uid, rel, base, format)
	if err != nil {
		return err
	}
	defer os.Remove(archivePath)

	f, err := os.Open(archivePath)
	if err != nil {
		return err
	}
	// Deferred after os.Remove, so it runs first: close, then delete.
	defer f.Close()

	st, err := f.Stat()
	if err != nil {
		return err
	}

	// Encode the filename per RFC 6266/2231 (quoted-string or filename*=utf-8'')
	// instead of Go %q quoting, which is not a valid header encoding for non-ASCII.
	disposition := mime.FormatMediaType("attachment", map[string]string{"filename": downloadName})
	if disposition == "" {
		disposition = "attachment"
	}

	h := w.Header()
	h.Set("Accept-Ranges", "bytes")
	h.Set("Content-Type", ctype)
	h.Set("Content-Disposition", disposition)
	http.ServeContent(w, r, downloadName, st.ModTime(), f)
	return nil
}
|
|
|
|
func (s *Server) createArchiveTemp(uid int64, rel, base, format string) (string, string, string, error) {
|
|
switch format {
|
|
case "rar":
|
|
return s.createRarTemp(uid, rel, base)
|
|
case "tar.gz":
|
|
return s.createTarGzTemp(uid, rel, base)
|
|
case "lz4":
|
|
return s.createTarLz4Temp(uid, rel, base)
|
|
default:
|
|
return s.createZipTemp(uid, rel, base)
|
|
}
|
|
}
|
|
|
|
// createBatchArchiveTemp packs several user paths into one archive.
//
// Each selected item is copied (materialized) into a local staging directory
// under a unique top-level name, then the staging tree is packed in the
// requested format. Any format other than "tar.gz", "lz4" or "rar" is packed as
// zip, and the download name stem is "files". The root path "/" is skipped.
// The staging directory is removed when the function returns; the archive file
// itself is left for the caller to clean up.
func (s *Server) createBatchArchiveTemp(uid int64, paths []string, format string) (string, string, string, error) {
	scratch, err := os.MkdirTemp("", "filez-batch-*")
	if err != nil {
		return "", "", "", err
	}
	defer os.RemoveAll(scratch)

	staging := filepath.Join(scratch, "selection")
	if err = os.MkdirAll(staging, 0o755); err != nil {
		return "", "", "", err
	}

	taken := make(map[string]int)
	for _, p := range paths {
		rel := normalizePath(p)
		if rel == "/" {
			continue
		}

		info, statErr := s.storage.Stat(uid, rel)
		if statErr != nil {
			return "", "", "", fmt.Errorf("invalid path: %s", rel)
		}

		leaf := path.Base(rel)
		switch leaf {
		case ".", "/", "":
			leaf = "item"
		}
		dest := filepath.Join(staging, uniqueArchiveName(leaf, taken))

		if info.IsDir {
			if mkErr := os.MkdirAll(dest, 0o755); mkErr != nil {
				return "", "", "", mkErr
			}
		}
		if matErr := s.materializePath(uid, rel, dest); matErr != nil {
			return "", "", "", matErr
		}
	}

	packers := map[string]func(string, string) (string, string, string, error){
		"tar.gz": createTarGzFromLocalDir,
		"lz4":    createTarLz4FromLocalDir,
		"rar":    createRarFromLocalDir,
	}
	pack, known := packers[format]
	if !known {
		pack = createZipFromLocalDir
	}
	return pack(staging, "files")
}
|
|
|
|
// createZipTemp writes a zip of rel (stored under the top-level name base) to a
// new temp file. It returns the temp path, base+".zip", and "application/zip".
//
// On any error the partially written temp file is closed and deleted, so no
// orphan is left in the temp directory. The caller owns the returned file.
func (s *Server) createZipTemp(uid int64, rel, base string) (string, string, string, error) {
	tmp, err := os.CreateTemp("", "filez-*.zip")
	if err != nil {
		return "", "", "", err
	}
	keep := false
	defer func() {
		if !keep {
			tmp.Close()
			os.Remove(tmp.Name())
		}
	}()

	zw := zip.NewWriter(tmp)
	if err := s.addPathToZip(zw, uid, rel, base); err != nil {
		zw.Close()
		return "", "", "", err
	}
	if err := zw.Close(); err != nil {
		return "", "", "", err
	}
	// Close of a written file can surface deferred write errors; don't ignore it.
	if err := tmp.Close(); err != nil {
		return "", "", "", err
	}
	keep = true
	return tmp.Name(), base + ".zip", "application/zip", nil
}
|
|
|
|
// addPathToZip recursively adds rel from the user's storage to zw.
//
// zipBase is the entry name to use for rel inside the archive (a leading "/"
// is stripped). Directories become stored (uncompressed) "name/" entries,
// unless zipBase is empty. Their children are added under zipBase/<child>.
// Files are written with Deflate compression and their storage ModTime.
func (s *Server) addPathToZip(zw *zip.Writer, uid int64, rel string, zipBase string) error {
	meta, err := s.storage.Stat(uid, rel)
	if err != nil {
		return err
	}
	entryName := strings.TrimPrefix(zipBase, "/")

	if !meta.IsDir {
		src, err := s.storage.OpenReadSeeker(uid, rel)
		if err != nil {
			return err
		}
		defer src.Close()

		dst, err := zw.CreateHeader(&zip.FileHeader{
			Name:     entryName,
			Method:   zip.Deflate,
			Modified: meta.ModTime,
		})
		if err != nil {
			return err
		}
		_, err = io.Copy(dst, src)
		return err
	}

	if zipBase != "" {
		dirHeader := &zip.FileHeader{
			Name:     entryName + "/",
			Method:   zip.Store,
			Modified: meta.ModTime,
		}
		if _, err := zw.CreateHeader(dirHeader); err != nil {
			return err
		}
	}

	children, err := s.storage.List(uid, rel)
	if err != nil {
		return err
	}
	parent := normalizePath(rel)
	for _, child := range children {
		childRel := path.Join(parent, child.Name)
		childEntry := path.Join(zipBase, child.Name)
		if err := s.addPathToZip(zw, uid, childRel, childEntry); err != nil {
			return err
		}
	}
	return nil
}
|
|
|
|
// createRarTemp builds a RAR archive of rel using the external `rar` binary.
//
// The item is first materialized into workdir/<base> on local disk. rar is
// then run from workdir with the relative name base, so the archive contains
// "<base>/..." rather than the server's absolute temp path (by default rar
// keeps the command-line path in stored names; see its -ep/-ep1 switches).
// The staging directory is always removed. On failure, the output file is
// removed too.
//
// Returns the archive path, base+".rar", and "application/vnd.rar".
func (s *Server) createRarTemp(uid int64, rel, base string) (string, string, string, error) {
	if _, err := exec.LookPath("rar"); err != nil {
		return "", "", "", fmt.Errorf("rar format requires 'rar' binary on server; choose zip or tar.gz in settings")
	}

	workdir, err := os.MkdirTemp("", "filez-rar-*")
	if err != nil {
		return "", "", "", err
	}
	defer os.RemoveAll(workdir)

	root := filepath.Join(workdir, base)
	if err := os.MkdirAll(root, 0o755); err != nil {
		return "", "", "", err
	}
	if err := s.materializePath(uid, rel, root); err != nil {
		return "", "", "", err
	}

	outTmp, err := os.CreateTemp("", "filez-*.rar")
	if err != nil {
		return "", "", "", err
	}
	archivePath := outTmp.Name()
	outTmp.Close()
	// The reserved zero-byte file is not a valid RAR archive; let rar create
	// the file itself instead of trying to add to it.
	os.Remove(archivePath)

	// "--" stops switch parsing so a folder named e.g. "-df" is not read as an option.
	cmd := exec.Command("rar", "a", "-idq", "--", archivePath, base)
	cmd.Dir = workdir
	if out, err := cmd.CombinedOutput(); err != nil {
		os.Remove(archivePath)
		return "", "", "", fmt.Errorf("rar creation failed: %s", strings.TrimSpace(string(out)))
	}

	return archivePath, base + ".rar", "application/vnd.rar", nil
}
|
|
|
|
// createTarGzTemp writes a gzip-compressed tar of rel (top-level name base) to
// a new temp file. It returns the temp path, base+".tar.gz", and
// "application/gzip".
//
// On any error the partial temp file is closed and deleted. The caller owns
// the returned file on success.
func (s *Server) createTarGzTemp(uid int64, rel, base string) (string, string, string, error) {
	tmp, err := os.CreateTemp("", "filez-*.tar.gz")
	if err != nil {
		return "", "", "", err
	}
	keep := false
	defer func() {
		if !keep {
			tmp.Close()
			os.Remove(tmp.Name())
		}
	}()

	gzw := gzip.NewWriter(tmp)
	tw := tar.NewWriter(gzw)
	if err := s.addPathToTar(tw, uid, rel, base); err != nil {
		tw.Close()
		gzw.Close()
		return "", "", "", err
	}
	// Close order matters: tar trailer first, then the gzip footer, then the file.
	if err := tw.Close(); err != nil {
		gzw.Close()
		return "", "", "", err
	}
	if err := gzw.Close(); err != nil {
		return "", "", "", err
	}
	if err := tmp.Close(); err != nil {
		return "", "", "", err
	}
	keep = true
	return tmp.Name(), base + ".tar.gz", "application/gzip", nil
}
|
|
|
|
// createTarLz4Temp builds a tar of rel and compresses it with the external
// `lz4` binary. It returns the .tar.lz4 temp path, base+".tar.lz4", and
// "application/x-lz4". The intermediate tar is always deleted.
func (s *Server) createTarLz4Temp(uid int64, rel, base string) (string, string, string, error) {
	if _, err := exec.LookPath("lz4"); err != nil {
		return "", "", "", fmt.Errorf("lz4 format requires 'lz4' binary on server; choose zip or tar.gz in settings")
	}
	tarPath, err := s.createTarTemp(uid, rel, base)
	if err != nil {
		return "", "", "", err
	}
	defer os.Remove(tarPath)

	outTmp, err := os.CreateTemp("", "filez-*.tar.lz4")
	if err != nil {
		return "", "", "", err
	}
	outPath := outTmp.Name()
	outTmp.Close()

	// outPath already exists (it is the reserved temp file). Without -f, lz4
	// will not overwrite an existing destination, and in quiet mode it cannot
	// prompt, so compression would always fail.
	cmd := exec.Command("lz4", "-z", "-q", "-f", tarPath, outPath)
	if out, err := cmd.CombinedOutput(); err != nil {
		os.Remove(outPath)
		return "", "", "", fmt.Errorf("lz4 creation failed: %s", strings.TrimSpace(string(out)))
	}

	return outPath, base + ".tar.lz4", "application/x-lz4", nil
}
|
|
|
|
// createTarTemp writes an uncompressed tar of rel (top-level name base) to a
// new temp file and returns its path. On any error the partial file is closed
// and deleted. The caller owns (and must remove) the file on success.
func (s *Server) createTarTemp(uid int64, rel, base string) (string, error) {
	tmp, err := os.CreateTemp("", "filez-*.tar")
	if err != nil {
		return "", err
	}
	keep := false
	defer func() {
		if !keep {
			tmp.Close()
			os.Remove(tmp.Name())
		}
	}()

	tw := tar.NewWriter(tmp)
	if err := s.addPathToTar(tw, uid, rel, base); err != nil {
		tw.Close()
		return "", err
	}
	if err := tw.Close(); err != nil {
		return "", err
	}
	if err := tmp.Close(); err != nil {
		return "", err
	}
	keep = true
	return tmp.Name(), nil
}
|
|
|
|
// addPathToTar recursively writes rel from the user's storage into tw.
//
// tarBase is the entry name for rel inside the archive (a leading "/" is
// stripped). Directories get a TypeDir "name/" header with mode 0755, unless
// tarBase is empty. Files get a TypeReg header with mode 0644 and the size and
// ModTime reported by storage, followed by their contents.
func (s *Server) addPathToTar(tw *tar.Writer, uid int64, rel string, tarBase string) error {
	meta, err := s.storage.Stat(uid, rel)
	if err != nil {
		return err
	}
	entryName := strings.TrimPrefix(tarBase, "/")

	if !meta.IsDir {
		src, err := s.storage.OpenReadSeeker(uid, rel)
		if err != nil {
			return err
		}
		defer src.Close()

		fileHdr := &tar.Header{
			Typeflag: tar.TypeReg,
			Name:     entryName,
			Mode:     0o644,
			Size:     meta.Size,
			ModTime:  meta.ModTime,
		}
		if err := tw.WriteHeader(fileHdr); err != nil {
			return err
		}
		_, err = io.Copy(tw, src)
		return err
	}

	if tarBase != "" {
		dirHdr := &tar.Header{
			Typeflag: tar.TypeDir,
			Name:     entryName + "/",
			Mode:     0o755,
			ModTime:  meta.ModTime,
		}
		if err := tw.WriteHeader(dirHdr); err != nil {
			return err
		}
	}

	children, err := s.storage.List(uid, rel)
	if err != nil {
		return err
	}
	parent := normalizePath(rel)
	for _, child := range children {
		childRel := path.Join(parent, child.Name)
		childEntry := path.Join(tarBase, child.Name)
		if err := s.addPathToTar(tw, uid, childRel, childEntry); err != nil {
			return err
		}
	}
	return nil
}
|
|
|
|
// materializePath copies rel from the user's storage to localPath on the local
// filesystem.
//
// For a directory, localPath itself must already exist (callers create it).
// Child directories are created here, and children are copied recursively.
// For a file, localPath is created (or truncated) and filled with the
// contents.
func (s *Server) materializePath(uid int64, rel string, localPath string) error {
	meta, err := s.storage.Stat(uid, rel)
	if err != nil {
		return err
	}
	if meta.IsDir {
		entries, err := s.storage.List(uid, rel)
		if err != nil {
			return err
		}
		parent := normalizePath(rel)
		for _, e := range entries {
			childLocal := filepath.Join(localPath, e.Name)
			if e.IsDir {
				if err := os.MkdirAll(childLocal, 0o755); err != nil {
					return err
				}
			}
			if err := s.materializePath(uid, path.Join(parent, e.Name), childLocal); err != nil {
				return err
			}
		}
		return nil
	}

	rc, err := s.storage.OpenReadSeeker(uid, rel)
	if err != nil {
		return err
	}
	defer rc.Close()

	out, err := os.Create(localPath)
	if err != nil {
		return err
	}
	if _, err := io.Copy(out, rc); err != nil {
		out.Close()
		return err
	}
	// Close on a written file can report deferred write failures on some
	// filesystems; an archive built from a truncated copy would be silently
	// wrong, so propagate it.
	return out.Close()
}
|
|
|
|
// copyPath copies src to dst within one user's storage.
//
// Directories are copied recursively (dst is created with Mkdir). Copying a
// directory onto itself or into its own subtree is rejected. The original code
// listed src after creating dst, so dst appeared among src's children and the
// recursion never terminated.
//
// NOTE(review): file contents are buffered fully in memory (io.ReadAll +
// SaveBytes), so memory use is proportional to file size. Consider a
// streaming save if the Storage interface offers one.
func (s *Server) copyPath(uid int64, src, dst string) error {
	meta, err := s.storage.Stat(uid, src)
	if err != nil {
		return err
	}

	if !meta.IsDir {
		rc, err := s.storage.OpenReadSeeker(uid, src)
		if err != nil {
			return err
		}
		defer rc.Close()
		data, err := io.ReadAll(rc)
		if err != nil {
			return err
		}
		return s.storage.SaveBytes(uid, dst, data)
	}

	normSrc, normDst := normalizePath(src), normalizePath(dst)
	if normDst == normSrc || strings.HasPrefix(normDst, strings.TrimSuffix(normSrc, "/")+"/") {
		return fmt.Errorf("cannot copy a folder into itself")
	}

	// List before Mkdir so the new destination can never show up as a child.
	entries, err := s.storage.List(uid, src)
	if err != nil {
		return err
	}
	if err := s.storage.Mkdir(uid, dst); err != nil {
		return err
	}
	for _, e := range entries {
		if err := s.copyPath(uid, path.Join(src, e.Name), path.Join(dst, e.Name)); err != nil {
			return err
		}
	}
	return nil
}
|
|
|
|
// nextMoveTarget returns dst (normalized) if nothing exists there. Otherwise it
// returns the first free "<stem>-N<ext>" sibling, for N = 2..9999.
//
// Any Stat error is treated as "path is free"; this mirrors how the caller
// probes storage but means transient storage errors are not distinguished
// from not-found.
//
// Dotfiles such as ".env" keep the whole name as the stem (".env-2") instead
// of producing a leading-dash name ("-2.env").
func (s *Server) nextMoveTarget(uid int64, dst string) (string, error) {
	norm := normalizePath(dst)
	if _, err := s.storage.Stat(uid, norm); err != nil {
		return norm, nil
	}

	dir := path.Dir(norm)
	leaf := path.Base(norm)
	ext := path.Ext(leaf)
	stem := strings.TrimSuffix(leaf, ext)
	if stem == "" {
		stem, ext = leaf, ""
	}

	for i := 2; i <= 9999; i++ {
		candidate := path.Join(dir, fmt.Sprintf("%s-%d%s", stem, i, ext))
		if _, err := s.storage.Stat(uid, candidate); err != nil {
			return candidate, nil
		}
	}
	return "", fmt.Errorf("cannot allocate destination name")
}
|
|
|
|
// moveTags re-points a user's tags after src was moved/renamed to dst: tags on
// src itself and on every path under src/ are moved under dst/.
//
// It runs in one transaction, so tags move all-or-nothing. Failures are logged
// and the transaction is rolled back; the function stays best-effort, with no
// error returned. Descendants are matched with instr(...) = 1, an exact,
// case-sensitive prefix test. The previous LIKE pattern treated '%' and '_'
// in folder names as wildcards and compares ASCII case-insensitively by
// default in SQLite, so it could rewrite tags of unrelated paths.
func (s *Server) moveTags(uid int64, src, dst string) {
	if src == dst {
		return
	}
	logFail := func(err error) {
		log.Printf("file.tags.move user_id=%d src=%q dst=%q err=%v", uid, src, dst, err)
	}

	tx, err := s.db.Begin()
	if err != nil {
		logFail(err)
		return
	}
	// After a successful Commit this is a no-op (returns sql.ErrTxDone).
	defer func() { _ = tx.Rollback() }()

	// OR IGNORE keeps the statement from failing when dst already carries the
	// same tag; such leftover src rows are then dropped as duplicates.
	if _, err := tx.Exec(`UPDATE OR IGNORE file_tags SET rel_path = ? WHERE user_id = ? AND rel_path = ?`, dst, uid, src); err != nil {
		logFail(err)
		return
	}
	if _, err := tx.Exec(`DELETE FROM file_tags WHERE user_id = ? AND rel_path = ?`, uid, src); err != nil {
		logFail(err)
		return
	}

	prefixFrom := strings.TrimSuffix(src, "/") + "/"
	prefixTo := strings.TrimSuffix(dst, "/") + "/"
	rows, err := tx.Query(`SELECT rel_path, tag FROM file_tags WHERE user_id = ? AND instr(rel_path, ?) = 1`, uid, prefixFrom)
	if err != nil {
		logFail(err)
		return
	}
	type tagRow struct{ path, tag string }
	var items []tagRow
	for rows.Next() {
		var it tagRow
		if err := rows.Scan(&it.path, &it.tag); err == nil {
			items = append(items, it)
		}
	}
	iterErr := rows.Err()
	// Close before issuing writes on the same transaction connection.
	rows.Close()
	if iterErr != nil {
		logFail(iterErr)
		return
	}

	for _, it := range items {
		next := prefixTo + strings.TrimPrefix(it.path, prefixFrom)
		if _, err := tx.Exec(`DELETE FROM file_tags WHERE user_id = ? AND rel_path = ? AND tag = ?`, uid, it.path, it.tag); err != nil {
			logFail(err)
			return
		}
		if _, err := tx.Exec(`INSERT OR IGNORE INTO file_tags(user_id, rel_path, tag) VALUES (?, ?, ?)`, uid, next, it.tag); err != nil {
			logFail(err)
			return
		}
	}

	if err := tx.Commit(); err != nil {
		logFail(err)
	}
}
|
|
|
|
// uniqueArchiveName returns a name for base that has not been handed out yet
// for this used map, and records it.
//
// The first request for a name returns it unchanged. Later requests return
// "<stem>-N<ext>" with the smallest N >= 2 that is still free. Generated names
// are recorded too, so a real item literally named "a-2.txt" can no longer
// collide with a generated duplicate of "a.txt". A blank base becomes "item".
// Dotfiles such as ".env" keep the whole name as the stem (".env-2").
func uniqueArchiveName(base string, used map[string]int) string {
	base = strings.TrimSpace(base)
	if base == "" {
		base = "item"
	}
	if used[base] == 0 {
		used[base] = 1
		return base
	}

	ext := path.Ext(base)
	stem := strings.TrimSuffix(base, ext)
	if stem == "" {
		stem, ext = base, ""
	}
	for n := used[base] + 1; ; n++ {
		candidate := fmt.Sprintf("%s-%d%s", stem, n, ext)
		if used[candidate] == 0 {
			used[base] = n
			used[candidate] = 1
			return candidate
		}
	}
}
|
|
|
|
// createZipFromLocalDir zips the local directory tree under root into a temp
// file. Entry names are slash-separated paths relative to root (directories
// end in "/"); root itself is not stored. Files use Deflate.
//
// Returns the temp path, baseName+".zip", and "application/zip". On any error
// the partial temp file is closed and removed.
func createZipFromLocalDir(root, baseName string) (string, string, string, error) {
	tmp, err := os.CreateTemp("", "filez-batch-*.zip")
	if err != nil {
		return "", "", "", err
	}
	keep := false
	defer func() {
		if !keep {
			tmp.Close()
			os.Remove(tmp.Name())
		}
	}()

	zw := zip.NewWriter(tmp)
	walkErr := filepath.WalkDir(root, func(p string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}
		rel, err := filepath.Rel(root, p)
		if err != nil {
			return err
		}
		if rel == "." {
			return nil
		}
		rel = filepath.ToSlash(rel)
		info, err := d.Info()
		if err != nil {
			return err
		}

		if d.IsDir() {
			hdr := &zip.FileHeader{Name: rel + "/", Method: zip.Store}
			hdr.Modified = info.ModTime()
			_, err = zw.CreateHeader(hdr)
			return err
		}

		hdr, err := zip.FileInfoHeader(info)
		if err != nil {
			return err
		}
		hdr.Name = rel
		hdr.Method = zip.Deflate
		w, err := zw.CreateHeader(hdr)
		if err != nil {
			return err
		}
		f, err := os.Open(p)
		if err != nil {
			return err
		}
		// Runs when this callback returns, i.e. once per file.
		defer f.Close()
		_, err = io.Copy(w, f)
		return err
	})
	if walkErr != nil {
		zw.Close()
		return "", "", "", walkErr
	}
	if err := zw.Close(); err != nil {
		return "", "", "", err
	}
	if err := tmp.Close(); err != nil {
		return "", "", "", err
	}
	keep = true
	return tmp.Name(), baseName + ".zip", "application/zip", nil
}
|
|
|
|
// createTarGzFromLocalDir packs the local directory tree under root into a
// gzip-compressed tar temp file. Entry names are slash-separated paths
// relative to root (directories end in "/"); root itself is not stored.
//
// Returns the temp path, baseName+".tar.gz", and "application/gzip". On any
// error the partial temp file is closed and removed.
func createTarGzFromLocalDir(root, baseName string) (string, string, string, error) {
	tmp, err := os.CreateTemp("", "filez-batch-*.tar.gz")
	if err != nil {
		return "", "", "", err
	}
	keep := false
	defer func() {
		if !keep {
			tmp.Close()
			os.Remove(tmp.Name())
		}
	}()

	gzw := gzip.NewWriter(tmp)
	tw := tar.NewWriter(gzw)
	walkErr := filepath.WalkDir(root, func(p string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}
		rel, err := filepath.Rel(root, p)
		if err != nil {
			return err
		}
		if rel == "." {
			return nil
		}
		rel = filepath.ToSlash(rel)
		info, err := d.Info()
		if err != nil {
			return err
		}
		hdr, err := tar.FileInfoHeader(info, "")
		if err != nil {
			return err
		}
		hdr.Name = rel
		if d.IsDir() {
			hdr.Name += "/"
		}
		if err := tw.WriteHeader(hdr); err != nil {
			return err
		}
		if d.IsDir() {
			return nil
		}
		f, err := os.Open(p)
		if err != nil {
			return err
		}
		// Runs when this callback returns, i.e. once per file.
		defer f.Close()
		_, err = io.Copy(tw, f)
		return err
	})
	if walkErr != nil {
		tw.Close()
		gzw.Close()
		return "", "", "", walkErr
	}
	// Close order matters: tar trailer, gzip footer, then the file.
	if err := tw.Close(); err != nil {
		gzw.Close()
		return "", "", "", err
	}
	if err := gzw.Close(); err != nil {
		return "", "", "", err
	}
	if err := tmp.Close(); err != nil {
		return "", "", "", err
	}
	keep = true
	return tmp.Name(), baseName + ".tar.gz", "application/gzip", nil
}
|
|
|
|
// createTarFromLocalDir packs the local directory tree under root into an
// uncompressed tar temp file and returns its path. Entry names are
// slash-separated paths relative to root (directories end in "/").
// On any error the partial temp file is closed and removed.
func createTarFromLocalDir(root string) (string, error) {
	tmp, err := os.CreateTemp("", "filez-batch-*.tar")
	if err != nil {
		return "", err
	}
	keep := false
	defer func() {
		if !keep {
			tmp.Close()
			os.Remove(tmp.Name())
		}
	}()

	tw := tar.NewWriter(tmp)
	walkErr := filepath.WalkDir(root, func(p string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}
		rel, err := filepath.Rel(root, p)
		if err != nil {
			return err
		}
		if rel == "." {
			return nil
		}
		rel = filepath.ToSlash(rel)
		info, err := d.Info()
		if err != nil {
			return err
		}
		hdr, err := tar.FileInfoHeader(info, "")
		if err != nil {
			return err
		}
		hdr.Name = rel
		if d.IsDir() {
			hdr.Name += "/"
		}
		if err := tw.WriteHeader(hdr); err != nil {
			return err
		}
		if d.IsDir() {
			return nil
		}
		f, err := os.Open(p)
		if err != nil {
			return err
		}
		// Runs when this callback returns, i.e. once per file.
		defer f.Close()
		_, err = io.Copy(tw, f)
		return err
	})
	if walkErr != nil {
		tw.Close()
		return "", walkErr
	}
	if err := tw.Close(); err != nil {
		return "", err
	}
	if err := tmp.Close(); err != nil {
		return "", err
	}
	keep = true
	return tmp.Name(), nil
}
|
|
|
|
// createTarLz4FromLocalDir tars the local tree under root and compresses it
// with the external `lz4` binary. Returns the .tar.lz4 temp path,
// baseName+".tar.lz4", and "application/x-lz4". The intermediate tar is
// always deleted; the output is deleted on failure.
func createTarLz4FromLocalDir(root, baseName string) (string, string, string, error) {
	if _, err := exec.LookPath("lz4"); err != nil {
		return "", "", "", fmt.Errorf("lz4 format requires 'lz4' binary on server; choose zip or tar.gz in settings")
	}
	tarPath, err := createTarFromLocalDir(root)
	if err != nil {
		return "", "", "", err
	}
	defer os.Remove(tarPath)

	outTmp, err := os.CreateTemp("", "filez-batch-*.tar.lz4")
	if err != nil {
		return "", "", "", err
	}
	outPath := outTmp.Name()
	outTmp.Close()

	// -f is required: outPath already exists (reserved temp file), and lz4
	// refuses to overwrite an existing destination without it (in quiet mode
	// it cannot prompt either).
	cmd := exec.Command("lz4", "-z", "-q", "-f", tarPath, outPath)
	if out, err := cmd.CombinedOutput(); err != nil {
		os.Remove(outPath)
		return "", "", "", fmt.Errorf("lz4 creation failed: %s", strings.TrimSpace(string(out)))
	}
	return outPath, baseName + ".tar.lz4", "application/x-lz4", nil
}
|
|
|
|
// createRarFromLocalDir archives the local directory root with the external
// `rar` binary. Returns the archive path, baseName+".rar", and
// "application/vnd.rar".
//
// rar is run from root's parent with root's base name as a relative argument,
// so the archive stores "<root-name>/..." rather than the absolute server
// temp path (rar keeps the command-line path in stored names by default; see
// -ep/-ep1). "--" ends switch parsing so names beginning with '-' are never
// treated as options.
func createRarFromLocalDir(root, baseName string) (string, string, string, error) {
	if _, err := exec.LookPath("rar"); err != nil {
		return "", "", "", fmt.Errorf("rar format requires 'rar' binary on server; choose zip or tar.gz in settings")
	}
	outTmp, err := os.CreateTemp("", "filez-batch-*.rar")
	if err != nil {
		return "", "", "", err
	}
	outPath := outTmp.Name()
	outTmp.Close()
	// The reserved zero-byte file is not a valid RAR archive; let rar create it.
	os.Remove(outPath)

	clean := filepath.Clean(root)
	cmd := exec.Command("rar", "a", "-idq", "--", outPath, filepath.Base(clean))
	cmd.Dir = filepath.Dir(clean)
	if out, err := cmd.CombinedOutput(); err != nil {
		os.Remove(outPath)
		return "", "", "", fmt.Errorf("rar creation failed: %s", strings.TrimSpace(string(out)))
	}
	return outPath, baseName + ".rar", "application/vnd.rar", nil
}
|
|
|
|
// handleDelete deletes the file or folder given by the "path" query parameter
// from the caller's storage, then removes the tags on that path and on
// everything beneath it.
//
// Tag cleanup uses instr(rel_path, prefix) = 1, an exact, case-sensitive
// prefix test. The former LIKE pattern treated '%' and '_' in folder names as
// wildcards and matched ASCII case-insensitively by default, which could wipe
// tags of unrelated paths. Cleanup failures are logged but do not fail the
// request, since the files are already gone.
func (s *Server) handleDelete(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())
	norm := normalizePath(r.URL.Query().Get("path"))
	log.Printf("file.delete user_id=%d path=%q", uid, norm)

	if err := s.storage.Delete(uid, norm); err != nil {
		writeErr(w, http.StatusBadRequest, err.Error())
		return
	}

	prefix := strings.TrimSuffix(norm, "/") + "/"
	if _, err := s.db.Exec(`DELETE FROM file_tags WHERE user_id = ? AND (rel_path = ? OR instr(rel_path, ?) = 1)`, uid, norm, prefix); err != nil {
		log.Printf("file.delete tag cleanup failed user_id=%d path=%q err=%v", uid, norm, err)
	}
	writeJSON(w, http.StatusOK, map[string]string{"status": "deleted"})
}
|
|
|
|
// handleCreateFolder creates a folder named in.Name inside in.Path.
//
// Only the last path element of the trimmed name is used. Empty names (which
// path.Base turns into "."), "." and ".." are rejected. Previously they
// resolved to the parent or grandparent directory and issued a Mkdir on an
// existing folder.
func (s *Server) handleCreateFolder(w http.ResponseWriter, r *http.Request) {
	type payload struct {
		Path string `json:"path"`
		Name string `json:"name"`
	}
	var in payload
	if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}

	uid := userIDFromContext(r.Context())
	name := path.Base(strings.TrimSpace(in.Name))
	if name == "." || name == ".." || name == "/" {
		writeErr(w, http.StatusBadRequest, "invalid folder name")
		return
	}

	rel := path.Join(normalizePath(in.Path), name)
	log.Printf("file.mkdir user_id=%d path=%q", uid, normalizePath(rel))
	if rel == "/" || rel == "." {
		writeErr(w, http.StatusBadRequest, "invalid folder name")
		return
	}
	if err := s.storage.Mkdir(uid, rel); err != nil {
		writeErr(w, http.StatusBadRequest, err.Error())
		return
	}
	writeJSON(w, http.StatusCreated, map[string]string{"status": "created"})
}
|
|
|
|
// shareInput is the JSON request body accepted by handleCreateShareLink.
type shareInput struct {
	// Path is the file to share; the handler normalizes it with normalizePath.
	Path string `json:"path"`
	// ExpiresMinutes overrides the default share TTL when positive; the
	// handler clamps the result to the range 5 minutes .. 30 days.
	ExpiresMinutes int `json:"expiresMinutes"`
	// MaxDownloads limits downloads; nil or a value <= 0 means unlimited.
	MaxDownloads *int `json:"maxDownloads"`
	// AllowPreview is decoded but not persisted by handleCreateShareLink.
	AllowPreview bool `json:"allowPreview"`
	// PreferredInline is decoded but not persisted by handleCreateShareLink.
	PreferredInline bool `json:"preferredInline"`
}
|
|
|
|
// handleCreateShareLink creates a public, token-based download link for one of
// the caller's files (folders are rejected).
//
// The TTL is ExpiresMinutes when positive, otherwise config.ShareDefaultTTL.
// It is clamped to 5 minutes .. 30 days. Only the token hash is stored; the
// plaintext token is returned once in the response, along with the full URL.
func (s *Server) handleCreateShareLink(w http.ResponseWriter, r *http.Request) {
	const (
		minShareTTL = 5 * time.Minute
		maxShareTTL = 30 * 24 * time.Hour
	)

	uid := userIDFromContext(r.Context())
	var in shareInput
	if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}

	rel := normalizePath(in.Path)
	log.Printf("file.share.create user_id=%d path=%q", uid, rel)
	meta, err := s.storage.Stat(uid, rel)
	if err != nil {
		writeErr(w, http.StatusBadRequest, "file not found")
		return
	}
	if meta.IsDir {
		writeErr(w, http.StatusBadRequest, "cannot create share for folder")
		return
	}

	ttl := s.config.ShareDefaultTTL
	if in.ExpiresMinutes > 0 {
		// Clamp in minutes before converting: a huge client value multiplied
		// into time.Duration (int64 nanoseconds) would overflow and wrap.
		if in.ExpiresMinutes > int(maxShareTTL/time.Minute) {
			ttl = maxShareTTL
		} else {
			ttl = time.Duration(in.ExpiresMinutes) * time.Minute
		}
	}
	if ttl < minShareTTL {
		ttl = minShareTTL
	}
	if ttl > maxShareTTL {
		ttl = maxShareTTL
	}

	token, err := randomToken()
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to create token")
		return
	}

	// nil is stored as SQL NULL, meaning "no download limit".
	var maxDownloads any
	if in.MaxDownloads != nil && *in.MaxDownloads > 0 {
		maxDownloads = *in.MaxDownloads
	}

	expiresAt := time.Now().Add(ttl)
	_, err = s.db.Exec(`INSERT INTO share_links(user_id, rel_path, token_hash, expires_at, max_downloads) VALUES (?, ?, ?, ?, ?)`,
		uid,
		rel,
		hashToken(token),
		expiresAt,
		maxDownloads,
	)
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to create share")
		return
	}

	shareURL := fmt.Sprintf("%s://%s/api/share/%s", schemeOf(r), r.Host, token)
	writeJSON(w, http.StatusCreated, map[string]any{
		"url":       shareURL,
		"token":     token,
		"path":      rel,
		"expiresAt": expiresAt,
	})
}
|
|
|
|
// handleSharedDownload serves a shared file (inline) identified by the
// {token} route variable, enforcing revocation, expiry and the optional
// max_downloads limit.
//
// The download counter is incremented with a conditional UPDATE that re-checks
// revocation and the limit atomically. Previously a read-then-increment let
// concurrent requests all pass the check and exceed max_downloads.
func (s *Server) handleSharedDownload(w http.ResponseWriter, r *http.Request) {
	token := mux.Vars(r)["token"]
	if strings.TrimSpace(token) == "" {
		writeErr(w, http.StatusBadRequest, "missing token")
		return
	}
	tokenHash := hashToken(token)

	var (
		uid           int64
		rel           string
		expiresAt     time.Time
		revokedAt     sql.NullTime
		maxDownloads  sql.NullInt64
		downloadCount int64
	)
	err := s.db.QueryRow(`SELECT user_id, rel_path, expires_at, revoked_at, max_downloads, download_count FROM share_links WHERE token_hash = ?`, tokenHash).
		Scan(&uid, &rel, &expiresAt, &revokedAt, &maxDownloads, &downloadCount)
	if err != nil {
		writeErr(w, http.StatusNotFound, "share link not found")
		return
	}

	if revokedAt.Valid || expiresAt.Before(time.Now()) {
		writeErr(w, http.StatusGone, "share link expired")
		return
	}
	if maxDownloads.Valid && downloadCount >= maxDownloads.Int64 {
		writeErr(w, http.StatusGone, "share link download limit reached")
		return
	}

	res, err := s.db.Exec(`UPDATE share_links SET download_count = download_count + 1 WHERE token_hash = ? AND revoked_at IS NULL AND (max_downloads IS NULL OR download_count < max_downloads)`, tokenHash)
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to track download")
		return
	}
	// Zero rows means another request consumed the last download (or the link
	// was revoked) between the SELECT above and this UPDATE.
	if n, err := res.RowsAffected(); err == nil && n == 0 {
		writeErr(w, http.StatusGone, "share link download limit reached")
		return
	}

	log.Printf("file.share.download user_id=%d path=%q ip=%q", uid, normalizePath(rel), clientIP(r))

	if err := s.serveFile(w, r, uid, rel, true, ""); err != nil {
		writeErr(w, http.StatusBadRequest, err.Error())
	}
}
|
|
|
|
// handleAdminMe returns the configured admin login name as {"login": ...}.
func (s *Server) handleAdminMe(w http.ResponseWriter, _ *http.Request) {
	body := map[string]string{
		"login": s.config.AdminLogin,
	}
	writeJSON(w, http.StatusOK, body)
}
|
|
|
|
// handleAdminUsersList returns every user ordered by id, as {"users": [...]},
// with theme and color mode normalized.
//
// Rows that fail to scan are logged and skipped. An iteration error (which
// would otherwise silently truncate the list) yields a 500 response.
func (s *Server) handleAdminUsersList(w http.ResponseWriter, _ *http.Request) {
	rows, err := s.db.Query(`SELECT id, email, theme, color_mode, archive_format FROM users ORDER BY id ASC`)
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to load users")
		return
	}
	defer rows.Close()

	// Non-nil so an empty result encodes as [] rather than null.
	users := make([]User, 0)
	for rows.Next() {
		var u User
		if err := rows.Scan(&u.ID, &u.Username, &u.Theme, &u.ColorMode, &u.Archive); err != nil {
			log.Printf("admin.user.list scan failed err=%v", err)
			continue
		}
		u.Theme = normalizeTheme(u.Theme)
		u.ColorMode = normalizeColorMode(u.ColorMode)
		users = append(users, u)
	}
	if err := rows.Err(); err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to load users")
		return
	}

	writeJSON(w, http.StatusOK, map[string]any{"users": users})
}
|
|
|
|
// adminCreateUserInput is the JSON body accepted by handleAdminUserCreate.
// Validation (username charset/length, password length) happens in createUser.
type adminCreateUserInput struct {
	// Username is lower-cased and trimmed by createUser.
	Username string `json:"username"`
	// Password must be at least 10 characters (enforced by createUser).
	Password string `json:"password"`
	// Theme is passed through normalizeTheme.
	Theme string `json:"theme"`
	// ColorMode is passed through normalizeColorMode.
	ColorMode string `json:"colorMode"`
}
|
|
|
|
// handleAdminUserCreate creates a user from an adminCreateUserInput body.
//
// Responses: 400 for a malformed body or validation error, 409 when the error
// text mentions "exists" (duplicate account), 201 with the created User on
// success.
func (s *Server) handleAdminUserCreate(w http.ResponseWriter, r *http.Request) {
	var in adminCreateUserInput
	if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}

	theme := normalizeTheme(in.Theme)
	colorMode := normalizeColorMode(in.ColorMode)
	created, err := s.createUser(in.Username, in.Password, theme, colorMode)
	if err != nil {
		status := http.StatusBadRequest
		if strings.Contains(strings.ToLower(err.Error()), "exists") {
			status = http.StatusConflict
		}
		writeErr(w, status, err.Error())
		return
	}

	writeJSON(w, http.StatusCreated, created)
	log.Printf("admin.user.create user_id=%d username=%q", created.ID, created.Username)
}
|
|
|
|
// handleAdminUserDelete deletes the user with the {id} route variable and
// cleans up their data: refresh tokens are deleted, share links are revoked,
// tags are deleted, and their storage root is removed.
//
// Returns 404 when no such user exists (previously "deleted" was reported and
// cleanup ran for a non-existent id). Tags are now removed too; otherwise a
// later user who reuses the rowid could inherit them. Cleanup failures are
// logged rather than silently discarded; they do not fail the request because
// the account row is already gone.
func (s *Server) handleAdminUserDelete(w http.ResponseWriter, r *http.Request) {
	id, err := strconv.ParseInt(mux.Vars(r)["id"], 10, 64)
	if err != nil || id <= 0 {
		writeErr(w, http.StatusBadRequest, "invalid user id")
		return
	}

	res, err := s.db.Exec(`DELETE FROM users WHERE id = ?`, id)
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to delete user")
		return
	}
	if n, err := res.RowsAffected(); err == nil && n == 0 {
		writeErr(w, http.StatusNotFound, "user not found")
		return
	}

	cleanup := []struct{ what, query string }{
		{"refresh_tokens", `DELETE FROM refresh_tokens WHERE user_id = ?`},
		{"share_links", `UPDATE share_links SET revoked_at = CURRENT_TIMESTAMP WHERE user_id = ?`},
		{"file_tags", `DELETE FROM file_tags WHERE user_id = ?`},
	}
	for _, step := range cleanup {
		if _, err := s.db.Exec(step.query, id); err != nil {
			log.Printf("admin.user.delete cleanup=%s user_id=%d err=%v", step.what, id, err)
		}
	}
	if err := s.storage.Delete(id, "/"); err != nil {
		log.Printf("admin.user.delete cleanup=storage user_id=%d err=%v", id, err)
	}

	writeJSON(w, http.StatusOK, map[string]string{"status": "deleted"})
	log.Printf("admin.user.delete user_id=%d", id)
}
|
|
|
|
// recoverMiddleware converts handler panics into a logged 500 response.
//
// http.ErrAbortHandler is re-panicked: it is net/http's sentinel for
// deliberately aborting a response (e.g. from httputil.ReverseProxy), and the
// server relies on seeing it to drop the connection quietly instead of
// appending an error body to a half-written response.
func (s *Server) recoverMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			v := recover()
			if v == nil {
				return
			}
			if err, ok := v.(error); ok && errors.Is(err, http.ErrAbortHandler) {
				panic(v)
			}
			log.Printf("panic recovered path=%q err=%v", r.URL.Path, v)
			writeErr(w, http.StatusInternalServerError, "internal server error")
		}()
		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// securityHeadersMiddleware sets a fixed set of hardening headers on every
// response. HSTS is added only when the request itself arrived over TLS.
func (s *Server) securityHeadersMiddleware(next http.Handler) http.Handler {
	fixed := [...][2]string{
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		{"Permissions-Policy", "camera=(), microphone=(), geolocation=()"},
		{"Content-Security-Policy", "default-src 'self'; connect-src 'self'; img-src 'self' data: blob:; style-src 'self' 'unsafe-inline'; script-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"},
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		for _, kv := range fixed {
			hdr.Set(kv[0], kv[1])
		}
		if r.TLS != nil {
			hdr.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
		}
		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// bodyLimitMiddleware caps request bodies under /api/ at config.MaxBodyBytes.
// Upload endpoints (/api/files/upload...) are exempt so large files can be
// streamed.
func (s *Server) bodyLimitMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		p := r.URL.Path
		isAPI := strings.HasPrefix(p, "/api/")
		isUpload := strings.HasPrefix(p, "/api/files/upload")
		if isAPI && !isUpload {
			r.Body = http.MaxBytesReader(w, r.Body, s.config.MaxBodyBytes)
		}
		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// rateLimitMiddleware applies per-(client IP, request path) rate limits to
// /api/ requests. Paths containing "/auth/" or "/admin/login" use the stricter
// AuthRateLimitPerMin; all other API paths use RateLimitPerMin. Requests over
// the limit receive 429.
func (s *Server) rateLimitMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		p := r.URL.Path
		if strings.HasPrefix(p, "/api/") {
			limit := s.config.RateLimitPerMin
			isAuthRoute := strings.Contains(p, "/auth/") || strings.Contains(p, "/admin/login")
			if isAuthRoute {
				limit = s.config.AuthRateLimitPerMin
			}
			key := clientIP(r) + ":" + p
			if !s.limiter.allow(key, limit, time.Now()) {
				writeErr(w, http.StatusTooManyRequests, "rate limit exceeded")
				return
			}
		}
		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// requestLogMiddleware logs method, path, status, duration, client IP and a
// truncated user agent for every /api/ request after it has been handled.
// The status defaults to 200 when the handler never calls WriteHeader.
func (s *Server) requestLogMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		recorder := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
		next.ServeHTTP(recorder, r)

		if !strings.HasPrefix(r.URL.Path, "/api/") {
			return
		}
		elapsedMS := time.Since(began).Milliseconds()
		log.Printf(
			"api.request method=%s path=%q status=%d duration_ms=%d ip=%q ua=%q",
			r.Method,
			r.URL.Path,
			recorder.status,
			elapsedMS,
			clientIP(r),
			limitString(r.UserAgent(), 120),
		)
	})
}
|
|
|
|
// createUser validates credentials, inserts a new user (archive format "zip")
// and provisions their storage root.
//
// Usernames are trimmed and lower-cased, 3-32 bytes, limited to a-z, 0-9, '.',
// '-', '_'. Passwords must be at least 10 bytes and are hashed with Argon2id.
// A unique-constraint violation is reported as "account already exists".
//
// If storage provisioning fails, the just-inserted row is deleted again.
// Otherwise the username would stay claimed by an account with no storage,
// and every retry would fail with "already exists".
func (s *Server) createUser(username, password, theme, colorMode string) (User, error) {
	username = strings.ToLower(strings.TrimSpace(username))
	if len(username) < 3 || len(username) > 32 {
		return User{}, fmt.Errorf("username must be 3-32 characters")
	}
	for _, ch := range username {
		if (ch < 'a' || ch > 'z') && (ch < '0' || ch > '9') && ch != '_' && ch != '-' && ch != '.' {
			return User{}, fmt.Errorf("username can contain only a-z, 0-9, dot, dash, underscore")
		}
	}
	if len(password) < 10 {
		return User{}, fmt.Errorf("password must be at least 10 characters")
	}

	theme = normalizeTheme(theme)
	colorMode = normalizeColorMode(colorMode)

	hash, err := hashPasswordArgon2ID(password)
	if err != nil {
		return User{}, fmt.Errorf("failed to hash password")
	}

	res, err := s.db.Exec(`INSERT INTO users(email, password_hash, theme, color_mode, archive_format) VALUES (?, ?, ?, ?, ?)`, username, hash, theme, colorMode, "zip")
	if err != nil {
		if strings.Contains(strings.ToLower(err.Error()), "unique") {
			return User{}, fmt.Errorf("account already exists")
		}
		return User{}, err
	}

	id, err := res.LastInsertId()
	if err != nil {
		return User{}, fmt.Errorf("failed to read new user id: %w", err)
	}
	if err := s.storage.Mkdir(id, "/"); err != nil {
		if _, delErr := s.db.Exec(`DELETE FROM users WHERE id = ?`, id); delErr != nil {
			log.Printf("user.create rollback failed user_id=%d err=%v", id, delErr)
		}
		return User{}, fmt.Errorf("failed to provision user storage: %w", err)
	}

	return User{ID: id, Username: username, Theme: theme, ColorMode: colorMode, Archive: "zip"}, nil
}
|
|
|
|
// findUserWithHash loads a user by login name (trimmed, lower-cased; stored
// in the users.email column) together with their password hash. Theme and
// color mode are normalized. Any query error, including sql.ErrNoRows, is
// returned with zero values.
func (s *Server) findUserWithHash(username string) (User, string, error) {
	login := strings.ToLower(strings.TrimSpace(username))

	var (
		u    User
		hash string
	)
	row := s.db.QueryRow(`SELECT id, email, password_hash, theme, color_mode, archive_format FROM users WHERE email = ?`, login)
	if err := row.Scan(&u.ID, &u.Username, &hash, &u.Theme, &u.ColorMode, &u.Archive); err != nil {
		return User{}, "", err
	}

	u.Theme, u.ColorMode = normalizeTheme(u.Theme), normalizeColorMode(u.ColorMode)
	return u, hash, nil
}
|
|
|
|
func (s *Server) findUser(id int64) (User, error) {
|
|
var u User
|
|
err := s.db.QueryRow(`SELECT id, email, theme, color_mode, archive_format FROM users WHERE id = ?`, id).
|
|
Scan(&u.ID, &u.Username, &u.Theme, &u.ColorMode, &u.Archive)
|
|
u.Theme = normalizeTheme(u.Theme)
|
|
u.ColorMode = normalizeColorMode(u.ColorMode)
|
|
return u, err
|
|
}
|
|
|
|
func (s *Server) issueUserSession(w http.ResponseWriter, r *http.Request, userID int64) error {
|
|
now := time.Now()
|
|
claims := AccessClaims{
|
|
UserID: userID,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: strconv.FormatInt(userID, 10),
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(s.config.AccessTTL)),
|
|
},
|
|
}
|
|
|
|
access, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(s.config.JWTSecret))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
refresh, err := randomToken()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err := s.db.Exec(`INSERT INTO refresh_tokens(user_id, token_hash, expires_at, user_agent, ip) VALUES (?, ?, ?, ?, ?)`,
|
|
userID,
|
|
hashToken(refresh),
|
|
now.Add(s.config.RefreshTTL),
|
|
limitString(r.UserAgent(), 255),
|
|
limitString(clientIP(r), 64),
|
|
); err != nil {
|
|
return err
|
|
}
|
|
|
|
setCookie(w, "access_token", access, int(s.config.AccessTTL.Seconds()), s.config.CookieSecure)
|
|
setCookie(w, "refresh_token", refresh, int(s.config.RefreshTTL.Seconds()), s.config.CookieSecure)
|
|
return nil
|
|
}
|
|
|
|
// consumeRefreshToken redeems a presented refresh token exactly once and
// returns the id of the user it belongs to.
//
// The token is looked up by its hash and rejected if it is already revoked or
// past expires_at. It is then revoked with a conditional UPDATE
// (revoked_at IS NULL) whose affected-row count is checked. Two concurrent
// requests presenting the same token therefore cannot both pass the SELECT and
// both succeed; only the request that actually sets revoked_at wins.
func (s *Server) consumeRefreshToken(token string) (int64, error) {
	var id, uid int64
	var expiresAt time.Time
	var revokedAt sql.NullTime
	err := s.db.QueryRow(`SELECT id, user_id, expires_at, revoked_at FROM refresh_tokens WHERE token_hash = ?`, hashToken(token)).
		Scan(&id, &uid, &expiresAt, &revokedAt)
	if err != nil {
		return 0, err
	}

	if revokedAt.Valid || expiresAt.Before(time.Now()) {
		return 0, fmt.Errorf("refresh token expired or revoked")
	}

	res, err := s.db.Exec(`UPDATE refresh_tokens SET revoked_at = CURRENT_TIMESTAMP WHERE id = ? AND revoked_at IS NULL`, id)
	if err != nil {
		return 0, err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return 0, err
	}
	if n == 0 {
		// Another request revoked this token between our SELECT and UPDATE.
		return 0, fmt.Errorf("refresh token expired or revoked")
	}

	return uid, nil
}
|
|
|
|
// authMiddleware admits only requests whose access_token cookie holds a JWT
// that is valid and signed with an HMAC method using JWTSecret. The token's
// UserID is stored in the request context under userIDKey for downstream
// handlers. A missing or invalid token gets a 401 JSON error.
func (s *Server) authMiddleware(next http.Handler) http.Handler {
	hmacKey := func(token *jwt.Token) (any, error) {
		if _, isHMAC := token.Method.(*jwt.SigningMethodHMAC); !isHMAC {
			return nil, fmt.Errorf("unexpected signing method")
		}
		return []byte(s.config.JWTSecret), nil
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		c, err := r.Cookie("access_token")
		if err != nil || c.Value == "" {
			writeErr(w, http.StatusUnauthorized, "missing access token")
			return
		}

		claims := new(AccessClaims)
		parsed, err := jwt.ParseWithClaims(c.Value, claims, hmacKey)
		if err != nil || !parsed.Valid {
			writeErr(w, http.StatusUnauthorized, "invalid access token")
			return
		}

		ctx := context.WithValue(r.Context(), userIDKey, claims.UserID)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
|
|
|
|
// adminMiddleware admits a request only when its admin_token cookie holds a
// valid HMAC-signed JWT. The token is verified with JWTSecret, the same key
// used for user access tokens, and its claims must carry Role "admin" and a
// Login equal to the configured AdminLogin. Anything else gets a 401 JSON
// error. Admitted requests are passed to next unchanged.
func (s *Server) adminMiddleware(next http.Handler) http.Handler {
	hmacKey := func(token *jwt.Token) (any, error) {
		if _, isHMAC := token.Method.(*jwt.SigningMethodHMAC); !isHMAC {
			return nil, fmt.Errorf("unexpected signing method")
		}
		return []byte(s.config.JWTSecret), nil
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		c, err := r.Cookie("admin_token")
		if err != nil || c.Value == "" {
			writeErr(w, http.StatusUnauthorized, "missing admin token")
			return
		}

		claims := new(AdminClaims)
		parsed, err := jwt.ParseWithClaims(c.Value, claims, hmacKey)
		isAdmin := err == nil && parsed.Valid &&
			claims.Role == "admin" && claims.Login == s.config.AdminLogin
		if !isAdmin {
			writeErr(w, http.StatusUnauthorized, "invalid admin session")
			return
		}

		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// corsMiddleware applies the server's single-origin CORS policy:
//
//   - When CORSOrigin is configured, a request with an Origin header that
//     differs from it is rejected with 403.
//   - Otherwise, if CORSOrigin is set, the response advertises it in
//     Access-Control-Allow-Origin and adds Origin to Vary.
//   - The credential, header and method allowances are sent on every response
//     that gets this far.
//   - OPTIONS requests are answered with 204 without reaching next.
func (s *Server) corsMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		allowed := s.config.CORSOrigin
		origin := strings.TrimSpace(r.Header.Get("Origin"))
		if origin != "" && allowed != "" && origin != allowed {
			writeErr(w, http.StatusForbidden, "cors origin denied")
			return
		}

		h := w.Header()
		if allowed != "" {
			h.Set("Access-Control-Allow-Origin", allowed)
			// Add (not Set) so Vary values already placed on the response by an
			// outer handler are kept alongside Origin.
			h.Add("Vary", "Origin")
		}
		h.Set("Access-Control-Allow-Credentials", "true")
		h.Set("Access-Control-Allow-Headers", "Content-Type")
		h.Set("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS,HEAD")

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// hostGuardMiddleware restricts the service to the configured AllowedHost.
// When AllowedHost is blank the guard is disabled. Otherwise the request's
// Host (port stripped, case-insensitive) must match it. Rejected /api/
// requests get a 403 JSON error; every other rejected request gets a 403
// HTML page naming the allowed host.
func (s *Server) hostGuardMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		allowed := s.config.AllowedHost
		guardDisabled := strings.TrimSpace(allowed) == ""
		if guardDisabled || strings.ToLower(hostOnly(r.Host)) == strings.ToLower(allowed) {
			next.ServeHTTP(w, r)
			return
		}

		if strings.HasPrefix(r.URL.Path, "/api/") {
			writeErr(w, http.StatusForbidden, fmt.Sprintf("access denied: host must be %s", allowed))
			return
		}

		page := fmt.Sprintf("<!doctype html><html><body><h1>Access denied</h1><p>This service is only available at <b>%s</b>.</p></body></html>", allowed)
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		w.WriteHeader(http.StatusForbidden)
		_, _ = w.Write([]byte(page))
	})
}
|
|
|
|
type LocalStorage struct {
|
|
root string
|
|
}
|
|
|
|
func (l *LocalStorage) userRoot(userID int64) string {
|
|
return filepath.Join(l.root, strconv.FormatInt(userID, 10))
|
|
}
|
|
|
|
// fullPath maps the slash-separated path rel to an OS path inside userID's
// root directory, creating that root if needed. rel is first rooted and
// cleaned by normalizePath. The result is rejected with "invalid path" unless
// it is the root itself or lies beneath it.
func (l *LocalStorage) fullPath(userID int64, rel string) (string, error) {
	root := l.userRoot(userID)
	if err := os.MkdirAll(root, 0o755); err != nil {
		return "", err
	}

	relOS := filepath.FromSlash(strings.TrimPrefix(normalizePath(rel), "/"))
	candidate := filepath.Clean(filepath.Join(root, relOS))
	insideRoot := candidate == root || strings.HasPrefix(candidate, root+string(os.PathSeparator))
	if !insideRoot {
		return "", fmt.Errorf("invalid path")
	}
	return candidate, nil
}
|
|
|
|
// List returns the entries of directory rel in userID's storage.
// A directory that does not exist yields an empty (non-nil) slice. Entries
// whose metadata cannot be read are skipped. Each entry's Path is rel
// (normalized) joined with the entry name.
func (l *LocalStorage) List(userID int64, rel string) ([]FileEntry, error) {
	dir, err := l.fullPath(userID, rel)
	if err != nil {
		return nil, err
	}

	dirEntries, err := os.ReadDir(dir)
	switch {
	case errors.Is(err, os.ErrNotExist):
		return []FileEntry{}, nil
	case err != nil:
		return nil, err
	}

	base := normalizePath(rel)
	result := make([]FileEntry, 0, len(dirEntries))
	for _, de := range dirEntries {
		info, infoErr := de.Info()
		if infoErr != nil {
			// Info can fail if the entry was removed after ReadDir; skip it.
			continue
		}
		result = append(result, FileEntry{
			Name:    de.Name(),
			Path:    path.Join(base, de.Name()),
			IsDir:   de.IsDir(),
			Size:    info.Size(),
			ModTime: info.ModTime(),
		})
	}
	return result, nil
}
|
|
|
|
// Mkdir creates directory rel (and any missing parents) in userID's storage.
func (l *LocalStorage) Mkdir(userID int64, rel string) error {
	dir, err := l.fullPath(userID, rel)
	if err == nil {
		err = os.MkdirAll(dir, 0o755)
	}
	return err
}
|
|
|
|
// Save writes the contents of src to rel in userID's storage. It creates
// missing parent directories and truncates any existing file.
//
// The destination's Close error is returned when the copy itself succeeded.
// Close can report write failures the OS deferred, and ignoring it could
// report a truncated file as saved.
func (l *LocalStorage) Save(userID int64, rel string, src multipart.File) (err error) {
	full, err := l.fullPath(userID, rel)
	if err != nil {
		return err
	}
	if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil {
		return err
	}

	dst, err := os.Create(full)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := dst.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	_, err = io.Copy(dst, src)
	return err
}
|
|
|
|
// SaveBytes writes data to rel in userID's storage with mode 0644. It creates
// missing parent directories and replaces any existing content.
func (l *LocalStorage) SaveBytes(userID int64, rel string, data []byte) error {
	target, err := l.fullPath(userID, rel)
	if err != nil {
		return err
	}
	if err = os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
		return err
	}
	return os.WriteFile(target, data, 0o644)
}
|
|
|
|
// Delete removes the file or directory tree at rel in userID's storage.
// rel "/" resolves to the user's root directory, so it removes all of the
// user's files. A path that does not exist is not an error (os.RemoveAll
// semantics).
//
// The root and non-root cases used to be separate branches containing
// identical code; they are a single path now.
func (l *LocalStorage) Delete(userID int64, rel string) error {
	full, err := l.fullPath(userID, rel)
	if err != nil {
		return err
	}
	return os.RemoveAll(full)
}
|
|
|
|
// Stat reports name, size, modification time and directory flag for rel in
// userID's storage. Symlinks are followed (os.Stat).
func (l *LocalStorage) Stat(userID int64, rel string) (FileMeta, error) {
	target, err := l.fullPath(userID, rel)
	if err != nil {
		return FileMeta{}, err
	}

	info, err := os.Stat(target)
	if err != nil {
		return FileMeta{}, err
	}

	meta := FileMeta{
		Name:    info.Name(),
		Size:    info.Size(),
		ModTime: info.ModTime(),
		IsDir:   info.IsDir(),
	}
	return meta, nil
}
|
|
|
|
// OpenReadSeeker opens rel in userID's storage for reading. The caller must
// Close the returned reader.
func (l *LocalStorage) OpenReadSeeker(userID int64, rel string) (ReadSeekCloser, error) {
	full, err := l.fullPath(userID, rel)
	if err != nil {
		return nil, err
	}

	f, err := os.Open(full)
	if err != nil {
		// Return an untyped nil. Returning f directly would wrap a nil *os.File
		// in a non-nil ReadSeekCloser interface value.
		return nil, err
	}
	return f, nil
}
|
|
|
|
type SMBStorage struct {
|
|
cfg Config
|
|
}
|
|
|
|
type smbConn struct {
|
|
conn net.Conn
|
|
session *smb2.Session
|
|
share *smb2.Share
|
|
}
|
|
|
|
func (c *smbConn) Close() {
|
|
if c.share != nil {
|
|
_ = c.share.Umount()
|
|
}
|
|
if c.session != nil {
|
|
_ = c.session.Logoff()
|
|
}
|
|
if c.conn != nil {
|
|
_ = c.conn.Close()
|
|
}
|
|
}
|
|
|
|
type smbReadSeekCloser struct {
|
|
file *smb2.File
|
|
conn *smbConn
|
|
}
|
|
|
|
func (s *smbReadSeekCloser) Read(p []byte) (int, error) {
|
|
return s.file.Read(p)
|
|
}
|
|
|
|
func (s *smbReadSeekCloser) Seek(offset int64, whence int) (int64, error) {
|
|
return s.file.Seek(offset, whence)
|
|
}
|
|
|
|
func (s *smbReadSeekCloser) Close() error {
|
|
_ = s.file.Close()
|
|
s.conn.Close()
|
|
return nil
|
|
}
|
|
|
|
// openConnection dials SMBHost over TCP, using SMBConnectTimout as the dial
// timeout. It then authenticates with NTLM using the configured
// user/password/domain and mounts SMBShare. If any step fails, the layers
// already established are released before the error is returned. The caller
// owns the returned *smbConn and must Close it.
func (smbs *SMBStorage) openConnection() (*smbConn, error) {
	cfg := smbs.cfg

	tcp, err := net.DialTimeout("tcp", cfg.SMBHost, cfg.SMBConnectTimout)
	if err != nil {
		return nil, err
	}

	initiator := &smb2.NTLMInitiator{
		User:     cfg.SMBUser,
		Password: cfg.SMBPass,
		Domain:   cfg.SMBDomain,
	}
	session, err := (&smb2.Dialer{Initiator: initiator}).Dial(tcp)
	if err != nil {
		_ = tcp.Close()
		return nil, err
	}

	share, err := session.Mount(cfg.SMBShare)
	if err != nil {
		_ = session.Logoff()
		_ = tcp.Close()
		return nil, err
	}

	return &smbConn{conn: tcp, session: session, share: share}, nil
}
|
|
|
|
func (smbs *SMBStorage) withShare(fn func(share *smb2.Share) error) error {
|
|
conn, err := smbs.openConnection()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
return fn(conn.share)
|
|
}
|
|
|
|
// userPath maps rel to a share-relative path under
// <SMBBasePath>/<userID>. normalizePath roots and cleans rel, so it cannot
// climb out of the user directory. An empty remainder (rel "/") yields the
// user directory itself, because path.Join ignores empty elements.
func (smbs *SMBStorage) userPath(userID int64, rel string) string {
	userDir := path.Join(smbs.cfg.SMBBasePath, strconv.FormatInt(userID, 10))
	return path.Join(userDir, strings.TrimPrefix(normalizePath(rel), "/"))
}
|
|
|
|
// List returns the entries of directory rel in userID's area of the share.
// The directory is created first (MkdirAll), so listing a path that does not
// yet exist returns no entries and leaves the directory behind. On error, any
// entries gathered so far are returned together with the error.
func (smbs *SMBStorage) List(userID int64, rel string) ([]FileEntry, error) {
	entries := make([]FileEntry, 0)
	base := normalizePath(rel)

	err := smbs.withShare(func(share *smb2.Share) error {
		dir := smbs.userPath(userID, rel)
		if err := share.MkdirAll(dir, 0o755); err != nil {
			return err
		}

		infos, err := share.ReadDir(dir)
		if err != nil {
			return err
		}
		for _, fi := range infos {
			entries = append(entries, FileEntry{
				Name:    fi.Name(),
				Path:    path.Join(base, fi.Name()),
				IsDir:   fi.IsDir(),
				Size:    fi.Size(),
				ModTime: fi.ModTime(),
			})
		}
		return nil
	})
	return entries, err
}
|
|
|
|
func (smbs *SMBStorage) Mkdir(userID int64, rel string) error {
|
|
return smbs.withShare(func(share *smb2.Share) error {
|
|
return share.MkdirAll(smbs.userPath(userID, rel), 0o755)
|
|
})
|
|
}
|
|
|
|
// Save copies src to rel in userID's area of the share. It creates missing
// parent directories and replaces any existing file. When the copy succeeds,
// the remote file's Close error is returned, so a failed close is not reported
// as a successful save.
func (smbs *SMBStorage) Save(userID int64, rel string, src multipart.File) error {
	return smbs.withShare(func(share *smb2.Share) (err error) {
		target := smbs.userPath(userID, rel)
		if err := share.MkdirAll(path.Dir(target), 0o755); err != nil {
			return err
		}

		f, err := share.Create(target)
		if err != nil {
			return err
		}
		defer func() {
			if cerr := f.Close(); cerr != nil && err == nil {
				err = cerr
			}
		}()

		_, err = io.Copy(f, src)
		return err
	})
}
|
|
|
|
// SaveBytes writes data to rel in userID's area of the share. It creates
// missing parent directories and replaces any existing file. The remote file's
// Close error is returned when the write succeeded, matching how the local
// backend's os.WriteFile reports close failures.
func (smbs *SMBStorage) SaveBytes(userID int64, rel string, data []byte) error {
	return smbs.withShare(func(share *smb2.Share) (err error) {
		target := smbs.userPath(userID, rel)
		if err := share.MkdirAll(path.Dir(target), 0o755); err != nil {
			return err
		}

		f, err := share.Create(target)
		if err != nil {
			return err
		}
		defer func() {
			if cerr := f.Close(); cerr != nil && err == nil {
				err = cerr
			}
		}()

		_, err = f.Write(data)
		return err
	})
}
|
|
|
|
func (smbs *SMBStorage) Delete(userID int64, rel string) error {
|
|
return smbs.withShare(func(share *smb2.Share) error {
|
|
target := smbs.userPath(userID, rel)
|
|
if normalizePath(rel) == "/" {
|
|
return share.RemoveAll(target)
|
|
}
|
|
if err := share.Remove(target); err == nil {
|
|
return nil
|
|
}
|
|
return share.RemoveAll(target)
|
|
})
|
|
}
|
|
|
|
func (smbs *SMBStorage) Stat(userID int64, rel string) (FileMeta, error) {
|
|
var out FileMeta
|
|
err := smbs.withShare(func(share *smb2.Share) error {
|
|
st, err := share.Stat(smbs.userPath(userID, rel))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
out = FileMeta{Name: st.Name(), Size: st.Size(), ModTime: st.ModTime(), IsDir: st.IsDir()}
|
|
return nil
|
|
})
|
|
return out, err
|
|
}
|
|
|
|
func (smbs *SMBStorage) OpenReadSeeker(userID int64, rel string) (ReadSeekCloser, error) {
|
|
conn, err := smbs.openConnection()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
f, err := conn.share.Open(smbs.userPath(userID, rel))
|
|
if err != nil {
|
|
conn.Close()
|
|
return nil, err
|
|
}
|
|
|
|
return &smbReadSeekCloser{file: f, conn: conn}, nil
|
|
}
|
|
|
|
// normalizePath turns a user-supplied path into a clean, absolute
// slash-separated path. Surrounding whitespace is trimmed. The path is rooted
// at "/" before cleaning, so ".." elements can never climb above the root.
// Because a rooted path never cleans to ".", the result always starts with
// "/".
func normalizePath(rel string) string {
	rooted := "/" + strings.TrimSpace(rel)
	return path.Clean(rooted)
}
|
|
|
|
// normalizeUploadRelativePath validates a client-supplied relative upload path
// and returns it cleaned and slash-separated, with no leading slash.
// Backslashes are treated as separators. A path is rejected if it is empty,
// names the base directory itself, is still absolute after one leading slash
// is stripped, or would escape the base via "..".
func normalizeUploadRelativePath(v string) (string, error) {
	v = strings.ReplaceAll(strings.TrimSpace(v), "\\", "/")
	v = strings.TrimPrefix(v, "/")
	v = path.Clean(v)
	if v == "." || v == "" {
		return "", fmt.Errorf("invalid upload path")
	}
	// After path.Clean, ".." can only appear as leading elements, so a bare
	// ".." or a "../" prefix covers every escape. A bare ".." was previously
	// accepted.
	if v == ".." || strings.HasPrefix(v, "../") || strings.HasPrefix(v, "/") {
		return "", fmt.Errorf("invalid upload path")
	}
	return v, nil
}
|
|
|
|
// normalizeTheme maps a stored or requested theme name to a supported one.
// Matching is case-insensitive and ignores surrounding whitespace.
// Anything that is not a supported theme, including "material", "glass",
// "desktop" and "auto", falls back to "dracula".
func normalizeTheme(v string) string {
	switch theme := strings.TrimSpace(strings.ToLower(v)); theme {
	case "dracula", "nord", "monokai", "solarized", "github":
		return theme
	default:
		return "dracula"
	}
}
|
|
|
|
// normalizeColorMode returns "light", "dark" or "auto". Input is matched
// case-insensitively after trimming; unrecognized values become "auto".
func normalizeColorMode(v string) string {
	mode := strings.TrimSpace(strings.ToLower(v))
	if mode == "light" || mode == "dark" || mode == "auto" {
		return mode
	}
	return "auto"
}
|
|
|
|
// normalizeArchiveFormat returns the canonical lower-case archive format
// ("zip", "rar", "tar.gz" or "lz4"). It returns "" for anything else.
func normalizeArchiveFormat(v string) string {
	format := strings.TrimSpace(strings.ToLower(v))
	for _, known := range [...]string{"zip", "rar", "tar.gz", "lz4"} {
		if format == known {
			return format
		}
	}
	return ""
}
|
|
|
|
// normalizeTag lower-cases and trims a tag. It reports false unless the
// result is 1-24 bytes long and contains only a-z, 0-9, '-' and '_'.
func normalizeTag(v string) (string, bool) {
	tag := strings.ToLower(strings.TrimSpace(v))
	if n := len(tag); n == 0 || n > 24 {
		return "", false
	}

	badRune := strings.IndexFunc(tag, func(r rune) bool {
		isAllowed := (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' || r == '_'
		return !isAllowed
	})
	if badRune >= 0 {
		return "", false
	}
	return tag, true
}
|
|
|
|
// isMarkdownExtension reports whether ext (with or without one leading dot,
// any case) is "md" or "markdown".
func isMarkdownExtension(ext string) bool {
	bare := strings.ToLower(strings.TrimPrefix(ext, "."))
	return bare == "md" || bare == "markdown"
}
|
|
|
|
// fileTagsForPaths returns the tags user uid has attached to each of paths.
// The map is keyed by the normalized path, and each tag list is sorted.
// Paths without tags are absent from the map.
//
// Rows that fail to Scan (e.g. a NULL column, which cannot scan into a
// string) are skipped, as before. Iteration errors are reported via
// rows.Err(), so a result cut short by a driver or I/O error is no longer
// returned as if complete. On error the map holds whatever was read so far.
func (s *Server) fileTagsForPaths(uid int64, paths []string) (map[string][]string, error) {
	out := make(map[string][]string, len(paths))
	if len(paths) == 0 {
		return out, nil
	}

	args := make([]any, 0, len(paths)+1)
	args = append(args, uid)
	placeholders := make([]string, len(paths))
	for i, p := range paths {
		placeholders[i] = "?"
		args = append(args, normalizePath(p))
	}
	q := `SELECT rel_path, tag FROM file_tags WHERE user_id = ? AND rel_path IN (` + strings.Join(placeholders, ",") + `) ORDER BY rel_path, tag`

	rows, err := s.db.Query(q, args...)
	if err != nil {
		return out, err
	}
	defer rows.Close()

	for rows.Next() {
		var rel, tag string
		if rows.Scan(&rel, &tag) == nil {
			out[rel] = append(out[rel], tag)
		}
	}
	return out, rows.Err()
}
|
|
|
|
// setCookie sets a site-wide (Path "/"), HttpOnly, SameSite=Lax cookie on w.
// maxAge is in seconds; secure controls the Secure attribute.
func setCookie(w http.ResponseWriter, name, value string, maxAge int, secure bool) {
	c := &http.Cookie{Name: name, Value: value, Path: "/", MaxAge: maxAge}
	c.HttpOnly = true
	c.Secure = secure
	c.SameSite = http.SameSiteLaxMode
	http.SetCookie(w, c)
}
|
|
|
|
// clearCookie tells the client to delete cookie name. It sends an empty value
// with Max-Age=0, using the same Path, HttpOnly and SameSite attributes that
// setCookie uses.
func clearCookie(w http.ResponseWriter, name string, secure bool) {
	expired := http.Cookie{Name: name, Path: "/", MaxAge: -1}
	expired.HttpOnly = true
	expired.Secure = secure
	expired.SameSite = http.SameSiteLaxMode
	http.SetCookie(w, &expired)
}
|
|
|
|
func randomToken() (string, error) {
|
|
b := make([]byte, 32)
|
|
if _, err := rand.Read(b); err != nil {
|
|
return "", err
|
|
}
|
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
|
}
|
|
|
|
// hashToken returns the SHA-256 digest of token, encoded as unpadded URL-safe
// base64. It is used so raw refresh tokens are never stored.
func hashToken(token string) string {
	h := sha256.New()
	_, _ = io.WriteString(h, token) // hash.Hash writes never fail
	return base64.RawURLEncoding.EncodeToString(h.Sum(nil))
}
|
|
|
|
func subtleConstantTimeEq(a, b string) int {
|
|
if len(a) != len(b) {
|
|
return 0
|
|
}
|
|
var out byte
|
|
for i := 0; i < len(a); i++ {
|
|
out |= a[i] ^ b[i]
|
|
}
|
|
if out == 0 {
|
|
return 1
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// maybeRunHashCommand handles the "hash-admin <password>" CLI subcommand.
// It returns false when the first argument is not "hash-admin". Otherwise it
// prints the Argon2id hash of the password and returns true. A missing
// password or a hashing failure prints a message and exits with status 1.
func maybeRunHashCommand() bool {
	if len(os.Args) < 2 || os.Args[1] != "hash-admin" {
		return false
	}
	if len(os.Args) < 3 {
		fmt.Println("usage: go run . hash-admin <password>")
		os.Exit(1)
	}

	encoded, err := hashPasswordArgon2ID(os.Args[2])
	if err != nil {
		fmt.Println("failed to hash password:", err)
		os.Exit(1)
	}
	fmt.Println(encoded)
	return true
}
|
|
|
|
func hashPasswordArgon2ID(password string) (string, error) {
|
|
salt := make([]byte, 16)
|
|
if _, err := rand.Read(salt); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
timeCost := uint32(3)
|
|
memoryCost := uint32(64 * 1024)
|
|
threads := uint8(2)
|
|
keyLen := uint32(32)
|
|
derived := argon2.IDKey([]byte(password), salt, timeCost, memoryCost, threads, keyLen)
|
|
|
|
saltB64 := base64.RawStdEncoding.EncodeToString(salt)
|
|
hashB64 := base64.RawStdEncoding.EncodeToString(derived)
|
|
encoded := fmt.Sprintf("argon2id:v=19,m=%d,t=%d,p=%d:%s:%s", memoryCost, timeCost, threads, saltB64, hashB64)
|
|
return encoded, nil
|
|
}
|
|
|
|
// verifyPasswordHash reports whether password matches storedHash. The encoding
// is chosen by prefix:
//
//   - "argon2id:v=19,m=<KiB>,t=<passes>,p=<lanes>:<salt>:<tag>" is the format
//     produced by hashPasswordArgon2ID; salt and tag are unpadded standard
//     base64.
//   - "$argon2id$[v=19$]m=..,t=..,p=..$<salt>$<tag>" is the PHC string format.
//     It is also accepted without the leading "$".
//   - "sha256:<hex>" is unsalted SHA-256 (legacy).
//   - bcrypt "$2a$", "$2b$" or "$2y$" is accepted, optionally prefixed with
//     "bcrypt:".
//
// Anything unrecognized or malformed yields false. Argon2 parameters are
// validated before hashing for two reasons. argon2.IDKey panics when passes or
// lanes is zero. And an empty stored tag would compare equal to a zero-length
// derived key, accepting any password.
func verifyPasswordHash(storedHash, password string) bool {
	switch {
	case strings.HasPrefix(storedHash, "argon2id:"):
		return verifyColonArgon2ID(storedHash, password)
	case strings.HasPrefix(storedHash, "$argon2id$"), strings.HasPrefix(storedHash, "argon2id$"):
		return verifyPHCArgon2ID(storedHash, password)
	case strings.HasPrefix(storedHash, "sha256:"):
		// NOTE(review): unsalted fast hash; presumably kept only to verify
		// legacy credentials.
		expected := strings.TrimPrefix(storedHash, "sha256:")
		sum := sha256.Sum256([]byte(password))
		actual := hex.EncodeToString(sum[:])
		return subtle.ConstantTimeCompare([]byte(expected), []byte(actual)) == 1
	}

	bcryptHash := strings.TrimPrefix(storedHash, "bcrypt:")
	if strings.HasPrefix(bcryptHash, "$2a$") || strings.HasPrefix(bcryptHash, "$2b$") || strings.HasPrefix(bcryptHash, "$2y$") {
		return bcrypt.CompareHashAndPassword([]byte(bcryptHash), []byte(password)) == nil
	}
	return false
}

// verifyColonArgon2ID checks a hash in hashPasswordArgon2ID's colon-separated
// format.
func verifyColonArgon2ID(storedHash, password string) bool {
	parts := strings.Split(storedHash, ":")
	if len(parts) != 4 {
		return false
	}
	var mem, timeCost uint32
	var threads uint8
	if _, err := fmt.Sscanf(parts[1], "v=19,m=%d,t=%d,p=%d", &mem, &timeCost, &threads); err != nil {
		return false
	}
	return argon2IDTagMatches(password, parts[2], parts[3], mem, timeCost, threads)
}

// verifyPHCArgon2ID checks a "$"-separated argon2id hash. The six-field case
// covers "$argon2id$v=19$<params>$<salt>$<tag>". The five-field case covers
// "argon2id$v=19$<params>$<salt>$<tag>" and the version-less
// "$argon2id$<params>$<salt>$<tag>".
func verifyPHCArgon2ID(storedHash, password string) bool {
	parts := strings.Split(storedHash, "$")
	var params, saltB64, tagB64 string
	switch len(parts) {
	case 6:
		params, saltB64, tagB64 = parts[3], parts[4], parts[5]
	case 5:
		params, saltB64, tagB64 = parts[2], parts[3], parts[4]
	default:
		return false
	}
	var mem, timeCost uint32
	var threads uint8
	if _, err := fmt.Sscanf(params, "m=%d,t=%d,p=%d", &mem, &timeCost, &threads); err != nil {
		return false
	}
	return argon2IDTagMatches(password, saltB64, tagB64, mem, timeCost, threads)
}

// argon2IDTagMatches decodes the unpadded standard-base64 salt and tag,
// recomputes the Argon2id tag for password and compares it in constant time.
// It returns false, without hashing, for parameters that make argon2.IDKey
// panic (zero passes or lanes) and for tags shorter than RFC 9106's 4-byte
// minimum.
func argon2IDTagMatches(password, saltB64, tagB64 string, mem, timeCost uint32, threads uint8) bool {
	if timeCost < 1 || threads < 1 {
		return false
	}
	salt, err := base64.RawStdEncoding.DecodeString(saltB64)
	if err != nil {
		return false
	}
	expected, err := base64.RawStdEncoding.DecodeString(tagB64)
	if err != nil || len(expected) < 4 {
		return false
	}
	actual := argon2.IDKey([]byte(password), salt, timeCost, mem, threads, uint32(len(expected)))
	return subtle.ConstantTimeCompare(expected, actual) == 1
}
|
|
|
|
func verifyAdminPasswordHash(storedHash, password string) bool {
|
|
return verifyPasswordHash(storedHash, password)
|
|
}
|
|
|
|
// userIDFromContext returns the user id that authMiddleware stored under
// userIDKey. It returns 0 when the value is absent or not an int64.
func userIDFromContext(ctx context.Context) int64 {
	if id, ok := ctx.Value(userIDKey).(int64); ok {
		return id
	}
	return 0
}
|
|
|
|
// schemeOf reports the scheme the client used, checking in order:
//
//   - The first comma-separated entry of X-Forwarded-Proto, lower-cased,
//     when it is "http" or "https". Proxy chains may append entries.
//   - "https" when the request arrived over TLS.
//   - "http" otherwise.
//
// Unrecognized header values used to be returned verbatim (e.g.
// "https, http"); they now fall through to the TLS check.
//
// NOTE(review): X-Forwarded-Proto is client-controlled unless a trusted proxy
// overwrites it; it is honored here as before.
func schemeOf(r *http.Request) string {
	if fwd := r.Header.Get("X-Forwarded-Proto"); fwd != "" {
		first, _, _ := strings.Cut(fwd, ",")
		switch proto := strings.ToLower(strings.TrimSpace(first)); proto {
		case "http", "https":
			return proto
		}
	}
	if r.TLS != nil {
		return "https"
	}
	return "http"
}
|
|
|
|
// hostOnly returns the host part of a Host header value, dropping any port.
// Surrounding whitespace is trimmed. Bracketed IPv6 literals are unwrapped
// whether or not a port is present, so "[::1]" and "[::1]:8080" both yield
// "::1". Values that cannot be split are returned trimmed but otherwise
// unchanged.
func hostOnly(hostport string) string {
	h := strings.TrimSpace(hostport)
	if h == "" {
		return ""
	}
	if strings.Contains(h, ":") {
		if host, _, err := net.SplitHostPort(h); err == nil {
			return host
		}
	}
	// SplitHostPort rejects a bracketed IPv6 literal without a port; unwrap it
	// so the result matches the with-port case.
	if len(h) > 2 && h[0] == '[' && h[len(h)-1] == ']' {
		return h[1 : len(h)-1]
	}
	return h
}
|
|
|
|
// clientIP returns the caller's address. It uses the first non-empty header
// among CF-Connecting-IP, X-Forwarded-For and X-Real-IP, taking only the first
// comma-separated entry. Without those headers it uses the host part of
// RemoteAddr, or RemoteAddr verbatim if it has no port.
//
// NOTE(review): these headers are client-controlled unless a trusted proxy
// overwrites them. Anything keyed on this value (rate limits, stored session
// IPs) can be spoofed when the server is reachable directly.
func clientIP(r *http.Request) string {
	proxyHeaders := [...]string{"CF-Connecting-IP", "X-Forwarded-For", "X-Real-IP"}
	for _, name := range proxyHeaders {
		value := strings.TrimSpace(r.Header.Get(name))
		if value == "" {
			continue
		}
		first, _, _ := strings.Cut(value, ",")
		return strings.TrimSpace(first)
	}

	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
|
|
|
|
// writeJSON sends payload as a JSON response with the given status code.
// An encoding error is ignored: the status and headers have already been sent
// by then.
func writeJSON(w http.ResponseWriter, status int, payload any) {
	headers := w.Header()
	headers.Set("Content-Type", "application/json")
	w.WriteHeader(status)

	enc := json.NewEncoder(w)
	_ = enc.Encode(payload)
}
|
|
|
|
// writeErr sends {"error": msg} as a JSON response with the given status.
func writeErr(w http.ResponseWriter, status int, msg string) {
	body := map[string]string{"error": msg}
	writeJSON(w, status, body)
}
|
|
|
|
// getEnv returns the whitespace-trimmed value of environment variable key.
// It returns fallback when the variable is unset or blank.
func getEnv(key, fallback string) string {
	if v := strings.TrimSpace(os.Getenv(key)); v != "" {
		return v
	}
	return fallback
}
|
|
|
|
// getEnvInt parses environment variable key as a decimal int after trimming
// whitespace. It returns fallback when the variable is unset, blank or not a
// valid int.
func getEnvInt(key string, fallback int) int {
	raw := strings.TrimSpace(os.Getenv(key))
	if raw == "" {
		return fallback
	}
	if n, err := strconv.Atoi(raw); err == nil {
		return n
	}
	return fallback
}
|
|
|
|
// limitString truncates s to at most maxBytes bytes without splitting a
// multi-byte UTF-8 sequence. If the byte limit falls inside a rune, the cut
// moves back to that rune's first byte, so stored values stay valid UTF-8.
// A non-positive limit yields "" instead of panicking on a negative slice
// bound.
func limitString(s string, maxBytes int) string {
	if len(s) <= maxBytes {
		return s
	}
	cut := 0
	// Ranging over a string visits rune start offsets in increasing order;
	// keep the last one that does not exceed the limit.
	for i := range s {
		if i > maxBytes {
			break
		}
		cut = i
	}
	return s[:cut]
}
|