Files
ZFile/backend/main.go
2026-03-02 22:32:46 +03:00

2821 lines
77 KiB
Go

package main
import (
"archive/tar"
"archive/zip"
"compress/gzip"
"context"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"database/sql"
"embed"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log"
"mime"
"mime/multipart"
"net"
"net/http"
"os"
"os/exec"
"path"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/gorilla/mux"
"github.com/hirochachacha/go-smb2"
"github.com/joho/godotenv"
"golang.org/x/crypto/argon2"
"golang.org/x/crypto/bcrypt"
_ "modernc.org/sqlite"
)
// Config holds the runtime settings assembled by loadConfig from environment
// variables and built-in defaults.
type Config struct {
// Addr is the listen address given to http.Server (env ADDR).
Addr string
// DBPath is the database path passed to openDB (env DB_PATH).
DBPath string
// StorageRoot is the directory used by the local storage backend (env STORAGE_ROOT).
StorageRoot string
// AppDomain is the lower-cased public domain; it is the fallback for
// AllowedHost and is used to derive the default CORSOrigin.
AppDomain string
// AllowedHost defaults to AppDomain when ALLOWED_HOST is empty.
AllowedHost string
// CORSOrigin defaults to "https://" + AppDomain when CORS_ALLOWED_ORIGIN is empty.
CORSOrigin string
// CookieSecure is forwarded to setCookie/clearCookie.
CookieSecure bool
// MaxBodyBytes is MAX_BODY_MB converted to bytes.
MaxBodyBytes int64
// RateLimitPerMin and AuthRateLimitPerMin come from RATE_LIMIT_PER_MIN and
// AUTH_RATE_LIMIT_PER_MIN.
RateLimitPerMin int
AuthRateLimitPerMin int
// JWTSecret is the HMAC key used to sign and verify session tokens.
JWTSecret string
// AccessTTL, RefreshTTL, ShareDefaultTTL and AdminSessionTTL are fixed
// durations set in loadConfig (15m, 30d, 24h and 12h respectively).
AccessTTL time.Duration
RefreshTTL time.Duration
ShareDefaultTTL time.Duration
AdminSessionTTL time.Duration
// AdminLogin and AdminPasswordHash are the administrator credentials
// checked by handleAdminLogin; the hash is mandatory at startup.
AdminLogin string
AdminPasswordHash string
// StorageBackend selects the storage implementation; "smb"
// (case-insensitive) picks SMB, anything else the local backend.
StorageBackend string
// SMB* fields configure the SMB backend; host, share and user are required
// by buildStorage when that backend is selected.
SMBHost string
SMBShare string
SMBUser string
SMBPass string
SMBDomain string
SMBBasePath string
// SMBConnectTimout [sic] is set to 5 seconds by loadConfig. The misspelled
// name is part of the struct's interface and is kept as is.
SMBConnectTimout time.Duration
}
// embeddedWeb holds the web/dist tree compiled into the binary; staticHandler
// serves the front-end from it.
//go:embed web/dist
var embeddedWeb embed.FS
// Server bundles the dependencies shared by all HTTP handlers and middleware:
// the database handle, the loaded configuration, the storage backend and the
// request rate limiter. main builds the single instance.
type Server struct {
db *sql.DB
config Config
storage Storage
limiter *rateLimiter
}
// rateLimiter is a fixed-window request counter keyed by an arbitrary string.
// All state is guarded by mu.
type rateLimiter struct {
	mu      sync.Mutex
	entries map[string]*rateEntry
	// nextSweep is the earliest instant at which allow purges expired entries
	// again. The zero value makes the first call sweep immediately.
	nextSweep time.Time
}

// rateEntry is the counter of one key: how many requests were admitted in the
// current window and when that window ends.
type rateEntry struct {
	Count      int
	WindowEnds time.Time
}

// newRateLimiter returns an empty, ready-to-use limiter.
func newRateLimiter() *rateLimiter {
	return &rateLimiter{entries: make(map[string]*rateEntry)}
}

// allow reports whether one more request for key may proceed at time now,
// given at most limit requests per one-minute window.
//
// A key with no entry, or whose window ended strictly before now, starts a new
// window and is always admitted (even when limit is zero or negative).
// Otherwise the request is admitted only while fewer than limit requests were
// counted in the current window.
//
// Expired entries of all keys are purged at most once per minute so the map
// cannot grow without bound when many distinct keys are seen.
func (rl *rateLimiter) allow(key string, limit int, now time.Time) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	if rl.entries == nil {
		// Tolerate a zero-value limiter instead of panicking on map write.
		rl.entries = make(map[string]*rateEntry)
	}
	rl.sweepLocked(now)

	entry, ok := rl.entries[key]
	if !ok || now.After(entry.WindowEnds) {
		rl.entries[key] = &rateEntry{Count: 1, WindowEnds: now.Add(time.Minute)}
		return true
	}
	if entry.Count >= limit {
		return false
	}
	entry.Count++
	return true
}

// sweepLocked deletes every entry whose window has ended, at most once per
// minute. The caller must hold rl.mu.
func (rl *rateLimiter) sweepLocked(now time.Time) {
	if now.Before(rl.nextSweep) {
		return
	}
	for k, e := range rl.entries {
		// Same expiry test as allow, so a live window is never dropped.
		if now.After(e.WindowEnds) {
			delete(rl.entries, k)
		}
	}
	rl.nextSweep = now.Add(time.Minute)
}
// statusRecorder wraps an http.ResponseWriter and remembers the status code
// passed to WriteHeader. status stays at its initial value if the handler
// never calls WriteHeader explicitly.
type statusRecorder struct {
http.ResponseWriter
status int
}
// WriteHeader records code and then forwards it to the wrapped ResponseWriter.
func (r *statusRecorder) WriteHeader(code int) {
r.status = code
r.ResponseWriter.WriteHeader(code)
}
// key is a private type for context keys, so values stored by this package
// cannot collide with keys defined elsewhere.
type key int
const (
// userIDKey is presumably the context key under which the authenticated
// user's ID is stored — confirm against authMiddleware/userIDFromContext.
userIDKey key = 1
)
// AccessClaims is the JWT payload of a user access token: the user ID (JSON
// field "uid") plus the registered claims.
type AccessClaims struct {
UserID int64 `json:"uid"`
jwt.RegisteredClaims
}
// AdminClaims is the JWT payload of an administrator session token.
// handleAdminLogin sets Login to the configured admin login and Role to "admin".
type AdminClaims struct {
Login string `json:"login"`
Role string `json:"role"`
jwt.RegisteredClaims
}
// User is the account representation returned to clients as JSON (for example
// by handleLogin and handleMe). Archive is serialized as "archiveFormat".
type User struct {
ID int64 `json:"id"`
Username string `json:"username"`
Theme string `json:"theme"`
ColorMode string `json:"colorMode"`
Archive string `json:"archiveFormat"`
}
// FileEntry describes one item of a directory listing as returned by
// Storage.List and serialized to clients.
type FileEntry struct {
Name string `json:"name"`
Path string `json:"path"`
IsDir bool `json:"isDir"`
Size int64 `json:"size"`
ModTime time.Time `json:"modTime"`
// Tags is filled in by handleListFiles from the file_tags table and omitted
// from JSON when empty.
Tags []string `json:"tags,omitempty"`
}
// FileMeta is the metadata of a single path as returned by Storage.Stat.
type FileMeta struct {
Name string
Size int64
ModTime time.Time
IsDir bool
}
// ReadSeekCloser combines io.ReadSeeker and io.Closer; it is what
// Storage.OpenReadSeeker returns so files can be passed to http.ServeContent
// and closed afterwards.
type ReadSeekCloser interface {
io.ReadSeeker
io.Closer
}
// Storage abstracts the per-user file store. buildStorage returns either a
// local or an SMB implementation.
//
// Every method takes the owning user's ID and a path rel; presumably rel is
// relative to that user's root and validated by the implementation — confirm
// against LocalStorage/SMBStorage.
type Storage interface {
// List returns the entries of directory rel.
List(userID int64, rel string) ([]FileEntry, error)
// Mkdir creates directory rel.
Mkdir(userID int64, rel string) error
// Save writes the uploaded content src to rel.
Save(userID int64, rel string, src multipart.File) error
// SaveBytes writes data to rel.
SaveBytes(userID int64, rel string, data []byte) error
// Delete removes rel.
Delete(userID int64, rel string) error
// Stat returns the metadata of rel.
Stat(userID int64, rel string) (FileMeta, error)
// OpenReadSeeker opens rel for reading; the caller must close the result.
OpenReadSeeker(userID int64, rel string) (ReadSeekCloser, error)
}
// main wires the application together: it handles the optional hash
// sub-command, loads configuration, opens and migrates the database, builds
// the storage backend, registers middleware and routes, and serves HTTP until
// the server stops. Any startup failure terminates the process via log.Fatalf.
func main() {
	if maybeRunHashCommand() {
		return
	}

	cfg := loadConfig()
	db, err := openDB(cfg.DBPath)
	if err != nil {
		log.Fatalf("db open failed: %v", err)
	}
	if err := migrate(db); err != nil {
		log.Fatalf("migrate failed: %v", err)
	}
	storage, err := buildStorage(cfg)
	if err != nil {
		log.Fatalf("storage init failed: %v", err)
	}

	s := &Server{db: db, config: cfg, storage: storage, limiter: newRateLimiter()}

	r := mux.NewRouter()
	// Middleware runs in registration order for every matched route.
	r.Use(s.recoverMiddleware)
	r.Use(s.securityHeadersMiddleware)
	r.Use(s.hostGuardMiddleware)
	r.Use(s.corsMiddleware)
	r.Use(s.bodyLimitMiddleware)
	r.Use(s.rateLimitMiddleware)
	r.Use(s.requestLogMiddleware)

	// Public endpoints (no session required).
	r.HandleFunc("/api/health", s.handleHealth).Methods(http.MethodGet)
	r.HandleFunc("/api/auth/register", s.handleRegisterDisabled).Methods(http.MethodPost)
	r.HandleFunc("/api/auth/login", s.handleLogin).Methods(http.MethodPost)
	r.HandleFunc("/api/auth/refresh", s.handleRefresh).Methods(http.MethodPost)
	r.HandleFunc("/api/auth/logout", s.handleLogout).Methods(http.MethodPost)
	r.HandleFunc("/api/admin/login", s.handleAdminLogin).Methods(http.MethodPost)
	r.HandleFunc("/api/admin/logout", s.handleAdminLogout).Methods(http.MethodPost)
	r.HandleFunc("/api/share/{token}", s.handleSharedDownload).Methods(http.MethodGet, http.MethodHead)

	// User endpoints, guarded by authMiddleware.
	protected := r.PathPrefix("/api").Subrouter()
	protected.Use(s.authMiddleware)
	protected.HandleFunc("/auth/me", s.handleMe).Methods(http.MethodGet)
	protected.HandleFunc("/user/preferences", s.handleSetPreferences).Methods(http.MethodPost)
	protected.HandleFunc("/files", s.handleListFiles).Methods(http.MethodGet)
	protected.HandleFunc("/files/upload", s.handleUpload).Methods(http.MethodPost)
	protected.HandleFunc("/files/download", s.handleDownload).Methods(http.MethodGet, http.MethodHead)
	protected.HandleFunc("/files/download-batch", s.handleBatchDownload).Methods(http.MethodPost)
	protected.HandleFunc("/files/move-batch", s.handleBatchMove).Methods(http.MethodPost)
	protected.HandleFunc("/files/preview", s.handlePreview).Methods(http.MethodGet, http.MethodHead)
	protected.HandleFunc("/files/text", s.handleReadTextFile).Methods(http.MethodGet)
	protected.HandleFunc("/files/text", s.handleWriteTextFile).Methods(http.MethodPut)
	protected.HandleFunc("/files", s.handleDelete).Methods(http.MethodDelete)
	protected.HandleFunc("/files/folder", s.handleCreateFolder).Methods(http.MethodPost)
	protected.HandleFunc("/files/share", s.handleCreateShareLink).Methods(http.MethodPost)
	protected.HandleFunc("/files/tags", s.handleListFileTags).Methods(http.MethodGet)
	protected.HandleFunc("/files/tags", s.handleAddFileTag).Methods(http.MethodPost)
	protected.HandleFunc("/files/tags", s.handleDeleteFileTag).Methods(http.MethodDelete)

	// Administrator endpoints, guarded by adminMiddleware.
	admin := r.PathPrefix("/api/admin").Subrouter()
	admin.Use(s.adminMiddleware)
	admin.HandleFunc("/me", s.handleAdminMe).Methods(http.MethodGet)
	admin.HandleFunc("/users", s.handleAdminUsersList).Methods(http.MethodGet)
	admin.HandleFunc("/users", s.handleAdminUserCreate).Methods(http.MethodPost)
	admin.HandleFunc("/users/{id}", s.handleAdminUserDelete).Methods(http.MethodDelete)

	// Everything else is the embedded single-page front-end.
	r.PathPrefix("/").Handler(s.staticHandler())

	server := &http.Server{
		Addr:              cfg.Addr,
		Handler:           r,
		ReadHeaderTimeout: 10 * time.Second,
		// ReadTimeout/WriteTimeout stay unset so large uploads and downloads
		// are not cut off. Without them IdleTimeout would be unlimited too,
		// so bound idle keep-alive connections explicitly.
		IdleTimeout: 2 * time.Minute,
	}
	log.Printf("listening on %s", cfg.Addr)
	if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
		log.Fatalf("server failed: %v", err)
	}
}
// loadConfig builds the Config from environment variables (optionally seeded
// from .env or ../.env) and built-in defaults.
//
// Side effects: it creates the parent directory of the database file and the
// storage root, and terminates the process via log.Fatal/log.Fatalf when those
// cannot be created or when ADMIN_PASSWORD_HASH is empty. A default JWT secret
// only produces a warning.
func loadConfig() Config {
	// Load errors are ignored, so a missing .env file does not stop startup.
	_ = godotenv.Load(".env", "../.env")

	dbPath := getEnv("DB_PATH", "./app.db")
	storageRoot := getEnv("STORAGE_ROOT", "./users")
	if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
		log.Fatalf("failed to create db path parent: %v", err)
	}
	if err := os.MkdirAll(storageRoot, 0o755); err != nil {
		log.Fatalf("failed to create storage root: %v", err)
	}

	// Accept every spelling strconv.ParseBool understands ("true", "TRUE",
	// "1", ...). An exact comparison with "true" silently turned secure
	// cookies off for values such as "True" or "1".
	cookieSecure, err := strconv.ParseBool(strings.TrimSpace(getEnv("COOKIE_SECURE", "false")))
	if err != nil {
		log.Println("warning: COOKIE_SECURE is not a valid boolean; treating it as false")
		cookieSecure = false
	}

	cfg := Config{
		Addr:        getEnv("ADDR", ":8080"),
		DBPath:      dbPath,
		StorageRoot: storageRoot,
		AppDomain:   strings.ToLower(strings.TrimSpace(getEnv("APP_DOMAIN", "file.example.com"))),
		// Trimmed like AppDomain so a whitespace-only value falls back to the
		// defaults below instead of becoming an unmatchable host/origin.
		AllowedHost:         strings.ToLower(strings.TrimSpace(getEnv("ALLOWED_HOST", ""))),
		CORSOrigin:          strings.TrimSpace(getEnv("CORS_ALLOWED_ORIGIN", "")),
		CookieSecure:        cookieSecure,
		MaxBodyBytes:        int64(getEnvInt("MAX_BODY_MB", 8)) * 1024 * 1024,
		RateLimitPerMin:     getEnvInt("RATE_LIMIT_PER_MIN", 240),
		AuthRateLimitPerMin: getEnvInt("AUTH_RATE_LIMIT_PER_MIN", 30),
		JWTSecret:           getEnv("JWT_SECRET", "dev-change-me-immediately"),
		AccessTTL:           15 * time.Minute,
		RefreshTTL:          30 * 24 * time.Hour,
		ShareDefaultTTL:     24 * time.Hour,
		AdminSessionTTL:     12 * time.Hour,
		AdminLogin:          getEnv("ADMIN_LOGIN", "admin"),
		AdminPasswordHash:   getEnv("ADMIN_PASSWORD_HASH", ""),
		StorageBackend:      getEnv("STORAGE_BACKEND", "local"),
		SMBHost:             getEnv("SMB_HOST", ""),
		SMBShare:            getEnv("SMB_SHARE", ""),
		SMBUser:             getEnv("SMB_USER", ""),
		SMBPass:             getEnv("SMB_PASS", ""),
		SMBDomain:           getEnv("SMB_DOMAIN", ""),
		SMBBasePath:         getEnv("SMB_BASE_PATH", "driveflow"),
		SMBConnectTimout:    5 * time.Second,
	}

	if cfg.AllowedHost == "" {
		cfg.AllowedHost = cfg.AppDomain
	}
	if cfg.CORSOrigin == "" {
		cfg.CORSOrigin = "https://" + cfg.AppDomain
	}
	if cfg.JWTSecret == "dev-change-me-immediately" {
		log.Println("warning: JWT_SECRET is using default development value")
	}
	if strings.TrimSpace(cfg.AdminPasswordHash) == "" {
		log.Fatal("ADMIN_PASSWORD_HASH is required. Generate one with: go run . hash-admin <password>")
	}
	return cfg
}
// openDB opens the SQLite database at path and verifies the connection with a
// ping. On any failure it returns a nil handle and the error; a handle that
// was opened but could not be pinged is closed first so it does not leak.
func openDB(path string) (*sql.DB, error) {
	db, err := sql.Open("sqlite", path)
	if err != nil {
		return nil, err
	}
	if err := db.Ping(); err != nil {
		// The ping error is the one worth reporting; a close failure on an
		// unusable handle adds nothing for the caller.
		_ = db.Close()
		return nil, err
	}
	return db, nil
}
// migrate creates the tables and indexes the application needs (all
// statements are idempotent) and then makes sure the users table has the
// color_mode and archive_format columns. It stops at, and returns, the first
// unexpected error.
func migrate(db *sql.DB) error {
	schema := []string{
		`CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
email TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
theme TEXT NOT NULL DEFAULT 'dracula',
color_mode TEXT NOT NULL DEFAULT 'auto',
archive_format TEXT NOT NULL DEFAULT 'zip',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);`,
		`CREATE TABLE IF NOT EXISTS refresh_tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
token_hash TEXT NOT NULL,
expires_at TIMESTAMP NOT NULL,
revoked_at TIMESTAMP,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
user_agent TEXT,
ip TEXT,
FOREIGN KEY(user_id) REFERENCES users(id)
);`,
		`CREATE INDEX IF NOT EXISTS idx_refresh_tokens_hash ON refresh_tokens(token_hash);`,
		`CREATE TABLE IF NOT EXISTS share_links (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
rel_path TEXT NOT NULL,
token_hash TEXT NOT NULL UNIQUE,
expires_at TIMESTAMP NOT NULL,
max_downloads INTEGER,
download_count INTEGER NOT NULL DEFAULT 0,
revoked_at TIMESTAMP,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(user_id) REFERENCES users(id)
);`,
		`CREATE INDEX IF NOT EXISTS idx_share_links_hash ON share_links(token_hash);`,
		`CREATE TABLE IF NOT EXISTS file_tags (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
rel_path TEXT NOT NULL,
tag TEXT NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE(user_id, rel_path, tag),
FOREIGN KEY(user_id) REFERENCES users(id)
);`,
		`CREATE INDEX IF NOT EXISTS idx_file_tags_user_path ON file_tags(user_id, rel_path);`,
		`CREATE INDEX IF NOT EXISTS idx_file_tags_user_tag ON file_tags(user_id, tag);`,
	}
	for i := range schema {
		if _, err := db.Exec(schema[i]); err != nil {
			return err
		}
	}

	// The CREATE TABLE above already declares both columns, so on a database
	// created by it these statements fail with a duplicate-column error, which
	// is ignored. Presumably they exist to upgrade databases created before
	// the columns were introduced.
	addColumns := [...]string{
		`ALTER TABLE users ADD COLUMN color_mode TEXT NOT NULL DEFAULT 'auto'`,
		`ALTER TABLE users ADD COLUMN archive_format TEXT NOT NULL DEFAULT 'zip'`,
	}
	for _, ddl := range addColumns {
		_, err := db.Exec(ddl)
		if err == nil {
			continue
		}
		if strings.Contains(strings.ToLower(err.Error()), "duplicate column") {
			continue
		}
		return err
	}
	return nil
}
// buildStorage picks the storage backend named by cfg.StorageBackend.
//
// "smb" (any letter case) yields an SMBStorage and requires SMBHost, SMBShare
// and SMBUser to be non-empty. Every other value yields a LocalStorage rooted
// at cfg.StorageRoot, which is created if necessary.
func buildStorage(cfg Config) (Storage, error) {
	if !strings.EqualFold(cfg.StorageBackend, "smb") {
		localRoot := cfg.StorageRoot
		if err := os.MkdirAll(localRoot, 0o755); err != nil {
			return nil, err
		}
		return &LocalStorage{root: localRoot}, nil
	}

	incomplete := cfg.SMBHost == "" || cfg.SMBShare == "" || cfg.SMBUser == ""
	if incomplete {
		return nil, fmt.Errorf("SMB_HOST, SMB_SHARE, SMB_USER must be set for smb backend")
	}
	return &SMBStorage{cfg: cfg}, nil
}
// handleHealth answers with 200 and {"status":"ok"}; it does not touch the
// database or storage.
func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
}
// staticHandler serves the embedded single-page front-end.
//
// Requests under /api/ that reach this handler get 404. A request whose
// cleaned path names a regular file in the bundle is served by
// http.FileServer. Everything else, including paths that name a directory,
// receives index.html with status 200 so client-side routing can take over.
// If the bundle or index.html is missing, the handler answers 503.
func (s *Server) staticHandler() http.Handler {
	webRoot, err := fs.Sub(embeddedWeb, "web/dist")
	if err != nil {
		return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			http.Error(w, "web assets unavailable", http.StatusServiceUnavailable)
		})
	}
	fileServer := http.FileServer(http.FS(webRoot))

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if strings.HasPrefix(r.URL.Path, "/api/") {
			http.NotFound(w, r)
			return
		}

		requestPath := strings.TrimPrefix(path.Clean(r.URL.Path), "/")
		if requestPath == "." || requestPath == "" {
			requestPath = "index.html"
		}

		// Only regular files go to the file server. Handing it a directory
		// would make http.FileServer render a listing of the bundle.
		if info, statErr := fs.Stat(webRoot, requestPath); statErr == nil && !info.IsDir() {
			fileServer.ServeHTTP(w, r)
			return
		}

		index, readErr := fs.ReadFile(webRoot, "index.html")
		if readErr != nil {
			http.Error(w, "web app is not bundled", http.StatusServiceUnavailable)
			return
		}
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		w.WriteHeader(http.StatusOK)
		// A failed write means the client went away; nothing to do about it.
		_, _ = w.Write(index)
	})
}
// handleRegisterDisabled always answers 403: self-service registration is
// turned off and accounts are created through the admin API instead.
func (s *Server) handleRegisterDisabled(w http.ResponseWriter, _ *http.Request) {
writeErr(w, http.StatusForbidden, "public registration is disabled; ask an administrator")
}
// authInput is the JSON body expected by handleLogin.
type authInput struct {
Username string `json:"username"`
Password string `json:"password"`
}
// handleLogin authenticates a user from a JSON {username, password} body.
//
// Responses: 400 for an undecodable body, 401 with a generic message when the
// user lookup fails or the password does not verify, 500 when the session
// cannot be issued, and 200 with the user record on success. Every outcome
// after decoding is logged with the client IP.
func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
	var creds authInput
	if decodeErr := json.NewDecoder(r.Body).Decode(&creds); decodeErr != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}

	// reject logs the failure reason server-side but tells the client only
	// that the credentials are invalid.
	reject := func(reason string) {
		log.Printf("auth.login.failed username=%q ip=%q reason=%q", creds.Username, clientIP(r), reason)
		writeErr(w, http.StatusUnauthorized, "invalid credentials")
	}

	user, hash, err := s.findUserWithHash(creds.Username)
	switch {
	case err != nil:
		reject("user_not_found")
		return
	case !verifyPasswordHash(hash, creds.Password):
		reject("invalid_password")
		return
	}

	if sessionErr := s.issueUserSession(w, r, user.ID); sessionErr != nil {
		log.Printf("auth.login.failed username=%q user_id=%d ip=%q reason=%q", user.Username, user.ID, clientIP(r), "session_issue_failed")
		writeErr(w, http.StatusInternalServerError, "failed to create session")
		return
	}

	log.Printf("auth.login.success user_id=%d username=%q ip=%q", user.ID, user.Username, clientIP(r))
	writeJSON(w, http.StatusOK, user)
}
// handleRefresh exchanges the refresh_token cookie for a new session.
//
// Responses: 401 when the cookie is absent or empty, or when the token cannot
// be consumed; 500 when issuing the new session fails; 200 with
// {"status":"refreshed"} otherwise.
func (s *Server) handleRefresh(w http.ResponseWriter, r *http.Request) {
	var presented string
	if c, cookieErr := r.Cookie("refresh_token"); cookieErr == nil {
		presented = c.Value
	}
	if presented == "" {
		writeErr(w, http.StatusUnauthorized, "missing refresh token")
		return
	}

	userID, err := s.consumeRefreshToken(presented)
	if err != nil {
		writeErr(w, http.StatusUnauthorized, "invalid refresh token")
		return
	}

	if issueErr := s.issueUserSession(w, r, userID); issueErr != nil {
		writeErr(w, http.StatusInternalServerError, "failed to refresh session")
		return
	}
	writeJSON(w, http.StatusOK, map[string]string{"status": "refreshed"})
}
// handleLogout ends the user session: it marks the presented refresh token as
// revoked, clears both session cookies and answers 200. It never fails; a
// request without cookies is simply a no-op apart from the cookie clearing.
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
// The access token is parsed only to log which user logged out; an invalid
// or missing token just skips the log line.
if cookie, err := r.Cookie("access_token"); err == nil && cookie.Value != "" {
claims := &AccessClaims{}
if tkn, parseErr := jwt.ParseWithClaims(cookie.Value, claims, func(token *jwt.Token) (any, error) {
return []byte(s.config.JWTSecret), nil
}); parseErr == nil && tkn.Valid {
log.Printf("auth.logout user_id=%d ip=%q", claims.UserID, clientIP(r))
}
}
rt, _ := r.Cookie("refresh_token")
if rt != nil && rt.Value != "" {
// Best effort: the database stores token hashes, and a failed update must
// not prevent the cookies from being cleared.
_, _ = s.db.Exec(`UPDATE refresh_tokens SET revoked_at = CURRENT_TIMESTAMP WHERE token_hash = ?`, hashToken(rt.Value))
}
clearCookie(w, "access_token", s.config.CookieSecure)
clearCookie(w, "refresh_token", s.config.CookieSecure)
writeJSON(w, http.StatusOK, map[string]string{"status": "logged_out"})
}
// adminLoginInput is the JSON body expected by handleAdminLogin.
type adminLoginInput struct {
Login string `json:"login"`
Password string `json:"password"`
}
// handleAdminLogin authenticates the administrator from a JSON
// {login, password} body and, on success, sets the admin_token cookie holding
// an HS256-signed JWT valid for AdminSessionTTL.
//
// Responses: 400 for an undecodable body, 401 for wrong credentials, 500 when
// the token cannot be signed, 200 with {"login": ...} on success.
func (s *Server) handleAdminLogin(w http.ResponseWriter, r *http.Request) {
	var in adminLoginInput
	if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}

	// Evaluate both checks unconditionally. With a short-circuiting "||" the
	// password hash was only verified when the login matched, so the response
	// time revealed whether the submitted login name was correct.
	login := strings.TrimSpace(in.Login)
	loginOK := subtleConstantTimeEq(login, s.config.AdminLogin) != 0
	passwordOK := verifyAdminPasswordHash(s.config.AdminPasswordHash, in.Password)
	if !loginOK || !passwordOK {
		log.Printf("auth.admin_login.failed login=%q ip=%q", login, clientIP(r))
		writeErr(w, http.StatusUnauthorized, "invalid admin credentials")
		return
	}

	now := time.Now()
	claims := AdminClaims{
		Login: s.config.AdminLogin,
		Role:  "admin",
		RegisteredClaims: jwt.RegisteredClaims{
			IssuedAt:  jwt.NewNumericDate(now),
			ExpiresAt: jwt.NewNumericDate(now.Add(s.config.AdminSessionTTL)),
			Subject:   s.config.AdminLogin,
		},
	}
	token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(s.config.JWTSecret))
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to issue admin session")
		return
	}

	// The cookie lifetime mirrors the token's expiry.
	setCookie(w, "admin_token", token, int(s.config.AdminSessionTTL.Seconds()), s.config.CookieSecure)
	log.Printf("auth.admin_login.success login=%q ip=%q", s.config.AdminLogin, clientIP(r))
	writeJSON(w, http.StatusOK, map[string]string{"login": s.config.AdminLogin})
}
// handleAdminLogout clears the admin_token cookie and answers 200. The token
// itself is not revoked server-side here; only the cookie is removed.
func (s *Server) handleAdminLogout(w http.ResponseWriter, _ *http.Request) {
log.Printf("auth.admin_logout")
clearCookie(w, "admin_token", s.config.CookieSecure)
writeJSON(w, http.StatusOK, map[string]string{"status": "logged_out"})
}
// handleMe returns the record of the user identified by the request context,
// or 404 when that user cannot be loaded.
func (s *Server) handleMe(w http.ResponseWriter, r *http.Request) {
	current, err := s.findUser(userIDFromContext(r.Context()))
	if err == nil {
		writeJSON(w, http.StatusOK, current)
		return
	}
	writeErr(w, http.StatusNotFound, "user not found")
}
// userPrefInput is the JSON body expected by handleSetPreferences.
type userPrefInput struct {
Theme string `json:"theme"`
ColorMode string `json:"colorMode"`
ArchiveFmt string `json:"archiveFormat"`
}
// handleSetPreferences stores the caller's theme, colour mode and archive
// format after passing each through its normalizer; an archive format that
// normalizes to the empty string becomes "zip".
//
// Responses: 400 for an undecodable body, 500 when the update fails, 200 with
// the stored values otherwise.
func (s *Server) handleSetPreferences(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())

	var prefs userPrefInput
	if decodeErr := json.NewDecoder(r.Body).Decode(&prefs); decodeErr != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}

	theme := normalizeTheme(prefs.Theme)
	mode := normalizeColorMode(prefs.ColorMode)
	archive := normalizeArchiveFormat(prefs.ArchiveFmt)
	if archive == "" {
		archive = "zip"
	}

	const update = `UPDATE users SET theme = ?, color_mode = ?, archive_format = ? WHERE id = ?`
	if _, err := s.db.Exec(update, theme, mode, archive, uid); err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to update preferences")
		return
	}

	saved := map[string]string{"theme": theme, "colorMode": mode, "archiveFormat": archive}
	writeJSON(w, http.StatusOK, saved)
}
// handleListFiles lists the directory named by the "path" query parameter for
// the calling user and decorates each entry with its tags.
//
// A storage error is reported as 400 with the error text. Tag lookup is best
// effort: if it fails, entries are returned without tags.
func (s *Server) handleListFiles(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())
	rel := r.URL.Query().Get("path")
	shown := normalizePath(rel)
	log.Printf("file.list user_id=%d path=%q", uid, shown)

	entries, err := s.storage.List(uid, rel)
	if err != nil {
		writeErr(w, http.StatusBadRequest, err.Error())
		return
	}

	if n := len(entries); n > 0 {
		// Tags are keyed by normalized path, so normalize once per entry and
		// reuse the result for the lookup.
		keys := make([]string, n)
		for i := range entries {
			keys[i] = normalizePath(entries[i].Path)
		}
		if tagsByPath, tagErr := s.fileTagsForPaths(uid, keys); tagErr == nil {
			for i, k := range keys {
				entries[i].Tags = tagsByPath[k]
			}
		}
	}

	writeJSON(w, http.StatusOK, map[string]any{"path": shown, "entries": entries})
}
// handleUpload stores every part named "file" of a multipart request under the
// directory given by the "path" query parameter.
//
// Each part's filename is validated by normalizeUploadRelativePath and joined
// to the normalized target directory. Processing stops at the first failure
// with 400; files saved before the failure are not rolled back. On success the
// response is 201 {"status":"uploaded"}.
func (s *Server) handleUpload(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())
	relDir := r.URL.Query().Get("path")

	if err := r.ParseMultipartForm(256 << 20); err != nil {
		writeErr(w, http.StatusBadRequest, "failed to parse upload")
		return
	}
	// Parts beyond the in-memory threshold are spooled to temp files. net/http
	// removes them only for the *http.Request it created itself; this handler
	// presumably receives a copy (middleware attaching the user ID via
	// WithContext), so clean up explicitly to avoid leaking temp files.
	defer func() { _ = r.MultipartForm.RemoveAll() }()

	files := r.MultipartForm.File["file"]
	if len(files) == 0 {
		writeErr(w, http.StatusBadRequest, "no files provided")
		return
	}

	targetDir := normalizePath(relDir)
	for _, fh := range files {
		log.Printf("file.upload user_id=%d dir=%q name=%q size=%d", uid, targetDir, fh.Filename, fh.Size)

		src, err := fh.Open()
		if err != nil {
			writeErr(w, http.StatusBadRequest, fmt.Sprintf("cannot open %s", fh.Filename))
			return
		}

		relName, err := normalizeUploadRelativePath(fh.Filename)
		if err != nil {
			_ = src.Close()
			writeErr(w, http.StatusBadRequest, err.Error())
			return
		}

		err = s.storage.Save(uid, path.Join(targetDir, relName), src)
		// Closed right away rather than deferred so handles do not pile up
		// across a large batch.
		_ = src.Close()
		if err != nil {
			writeErr(w, http.StatusBadRequest, err.Error())
			return
		}
	}

	writeJSON(w, http.StatusCreated, map[string]string{"status": "uploaded"})
}
// handleDownload sends the file named by the "path" query parameter as an
// attachment. For a directory, serveFile streams an archive instead; the
// optional "archive" query parameter selects its format. Any error from
// serveFile is reported as 400 with the error text.
func (s *Server) handleDownload(w http.ResponseWriter, r *http.Request) {
uid := userIDFromContext(r.Context())
rel := r.URL.Query().Get("path")
log.Printf("file.download user_id=%d path=%q", uid, normalizePath(rel))
if err := s.serveFile(w, r, uid, rel, false, r.URL.Query().Get("archive")); err != nil {
writeErr(w, http.StatusBadRequest, err.Error())
}
}
// handleBatchDownload bundles the paths listed in the JSON body into a single
// archive and sends it as an attachment.
//
// The body is {"paths": [...], "archive": "<format>"}; between 1 and 200 paths
// are accepted. The archive format comes from the request or, if absent, from
// the user's stored preference (see resolveArchiveFormat). Validation and
// archive-creation errors yield 400; failures to open or stat the finished
// archive yield 500.
func (s *Server) handleBatchDownload(w http.ResponseWriter, r *http.Request) {
uid := userIDFromContext(r.Context())
var in struct {
Paths []string `json:"paths"`
Archive string `json:"archive"`
}
if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
writeErr(w, http.StatusBadRequest, "invalid payload")
return
}
if len(in.Paths) == 0 {
writeErr(w, http.StatusBadRequest, "no paths selected")
return
}
if len(in.Paths) > 200 {
writeErr(w, http.StatusBadRequest, "too many selected paths")
return
}
format, err := s.resolveArchiveFormat(uid, in.Archive)
if err != nil {
writeErr(w, http.StatusBadRequest, err.Error())
return
}
archivePath, name, ctype, err := s.createBatchArchiveTemp(uid, in.Paths, format)
if err != nil {
writeErr(w, http.StatusBadRequest, err.Error())
return
}
// The archive is a temp file owned by this request; delete it once served.
defer os.Remove(archivePath)
f, err := os.Open(archivePath)
if err != nil {
writeErr(w, http.StatusInternalServerError, "failed to open archive")
return
}
defer f.Close()
st, err := f.Stat()
if err != nil {
writeErr(w, http.StatusInternalServerError, "failed to stat archive")
return
}
w.Header().Set("Accept-Ranges", "bytes")
w.Header().Set("Content-Type", ctype)
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name))
// ServeContent handles Range and conditional requests.
http.ServeContent(w, r, name, st.ModTime(), f)
}
// handleBatchMove moves every path listed in the JSON body into the
// destination directory, which is created first.
//
// The body is {"paths": [...], "destination": "<dir>"}. Each move is performed
// as copy, then delete of the source, then a tag transfer; the final name is
// chosen by nextMoveTarget. Root ("/") entries are skipped silently. The first
// failure aborts with 400 and items already moved stay moved. On success the
// response is 200 {"status":"moved","count":N}.
func (s *Server) handleBatchMove(w http.ResponseWriter, r *http.Request) {
uid := userIDFromContext(r.Context())
var in struct {
Paths []string `json:"paths"`
Destination string `json:"destination"`
}
if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
writeErr(w, http.StatusBadRequest, "invalid payload")
return
}
if len(in.Paths) == 0 {
writeErr(w, http.StatusBadRequest, "no paths selected")
return
}
destDir := normalizePath(in.Destination)
if err := s.storage.Mkdir(uid, destDir); err != nil {
writeErr(w, http.StatusBadRequest, err.Error())
return
}
moved := 0
for _, p := range in.Paths {
src := normalizePath(p)
if src == "/" {
continue
}
meta, err := s.storage.Stat(uid, src)
if err != nil {
writeErr(w, http.StatusBadRequest, "invalid path")
return
}
if meta.IsDir {
// Refuse to move a directory into itself or one of its descendants; the
// trailing slash keeps "/a" from matching "/ab".
srcPrefix := strings.TrimSuffix(src, "/") + "/"
if destDir == src || strings.HasPrefix(destDir, srcPrefix) {
writeErr(w, http.StatusBadRequest, "cannot move a folder into itself")
return
}
}
base := path.Base(src)
dst := path.Join(destDir, base)
dst, err = s.nextMoveTarget(uid, dst)
if err != nil {
writeErr(w, http.StatusBadRequest, err.Error())
return
}
// Copy before delete so a failed copy never loses the source.
if err := s.copyPath(uid, src, dst); err != nil {
writeErr(w, http.StatusBadRequest, err.Error())
return
}
// NOTE(review): if this delete fails the item exists at both src and dst.
if err := s.storage.Delete(uid, src); err != nil {
writeErr(w, http.StatusBadRequest, err.Error())
return
}
s.moveTags(uid, src, dst)
moved++
}
writeJSON(w, http.StatusOK, map[string]any{"status": "moved", "count": moved})
}
// handlePreview sends the file named by the "path" query parameter with an
// inline Content-Disposition. Directories cannot be previewed; that and any
// other error from serveFile is reported as 400 with the error text.
func (s *Server) handlePreview(w http.ResponseWriter, r *http.Request) {
uid := userIDFromContext(r.Context())
rel := r.URL.Query().Get("path")
log.Printf("file.preview user_id=%d path=%q", uid, normalizePath(rel))
if err := s.serveFile(w, r, uid, rel, true, ""); err != nil {
writeErr(w, http.StatusBadRequest, err.Error())
}
}
// handleReadTextFile returns the content of a markdown file as JSON
// ({"path", "content", "size"}).
//
// The "path" query parameter must name an existing non-directory file with a
// markdown extension whose size does not exceed 2 MiB; violations yield 400.
// A read failure yields 500.
func (s *Server) handleReadTextFile(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())
	rel := normalizePath(r.URL.Query().Get("path"))

	switch {
	case rel == "/":
		writeErr(w, http.StatusBadRequest, "path is required")
		return
	case !isMarkdownExtension(path.Ext(rel)):
		writeErr(w, http.StatusBadRequest, "only markdown files are editable")
		return
	}

	if info, statErr := s.storage.Stat(uid, rel); statErr != nil || info.IsDir {
		writeErr(w, http.StatusBadRequest, "file not found")
		return
	}

	file, openErr := s.storage.OpenReadSeeker(uid, rel)
	if openErr != nil {
		writeErr(w, http.StatusBadRequest, "file not found")
		return
	}
	defer file.Close()

	// Read one byte past the limit so an oversized file is detectable without
	// loading it completely.
	const maxTextBytes = 2 << 20
	content, readErr := io.ReadAll(io.LimitReader(file, maxTextBytes+1))
	if readErr != nil {
		writeErr(w, http.StatusInternalServerError, "failed to read file")
		return
	}
	if len(content) > maxTextBytes {
		writeErr(w, http.StatusBadRequest, "markdown file is too large (max 2MB)")
		return
	}

	writeJSON(w, http.StatusOK, map[string]any{"path": rel, "content": string(content), "size": len(content)})
}
// handleWriteTextFile saves markdown content sent as JSON
// ({"path", "content"}).
//
// The path must not be the root, must carry a markdown extension, and the
// content must not exceed 2 MiB; violations and storage errors yield 400.
// Success yields 200 {"path", "status":"saved"}.
func (s *Server) handleWriteTextFile(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())

	var body struct {
		Path    string `json:"path"`
		Content string `json:"content"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}

	rel := normalizePath(body.Path)
	switch {
	case rel == "/":
		writeErr(w, http.StatusBadRequest, "path is required")
		return
	case !isMarkdownExtension(path.Ext(rel)):
		writeErr(w, http.StatusBadRequest, "only markdown files are editable")
		return
	case len(body.Content) > 2<<20:
		writeErr(w, http.StatusBadRequest, "markdown file is too large (max 2MB)")
		return
	}

	if saveErr := s.storage.SaveBytes(uid, rel, []byte(body.Content)); saveErr != nil {
		writeErr(w, http.StatusBadRequest, saveErr.Error())
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{"path": rel, "status": "saved"})
}
// handleListFileTags returns the caller's tags for the path given in the
// "path" query parameter, sorted ascending, as {"path", "tags"}. "tags" is an
// empty array, not null, when there are none. Database failures yield 500.
func (s *Server) handleListFileTags(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())
	rel := normalizePath(r.URL.Query().Get("path"))

	rows, err := s.db.Query(`SELECT tag FROM file_tags WHERE user_id = ? AND rel_path = ? ORDER BY tag ASC`, uid, rel)
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to load tags")
		return
	}
	defer rows.Close()

	// Non-nil on purpose so an empty result encodes as [] in JSON.
	tags := make([]string, 0)
	for rows.Next() {
		var tag string
		// A row that cannot be scanned is skipped, as before.
		if rows.Scan(&tag) == nil {
			tags = append(tags, tag)
		}
	}
	// rows.Next also returns false when iteration aborts on an error; without
	// this check a truncated list would be reported as a success.
	if err := rows.Err(); err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to load tags")
		return
	}

	writeJSON(w, http.StatusOK, map[string]any{"path": rel, "tags": tags})
}
// handleAddFileTag attaches a tag to an existing path of the calling user.
//
// The JSON body is {"path", "tag"}. The tag must pass normalizeTag and the
// path must exist in storage, otherwise the response is 400. Adding a tag that
// is already present is a no-op (INSERT OR IGNORE). Success yields 201 with
// the normalized path and tag.
func (s *Server) handleAddFileTag(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())

	var body struct {
		Path string `json:"path"`
		Tag  string `json:"tag"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}

	rel := normalizePath(body.Path)
	tag, valid := normalizeTag(body.Tag)
	if !valid {
		writeErr(w, http.StatusBadRequest, "tag must be 1-24 chars: a-z 0-9 dash underscore")
		return
	}

	if _, statErr := s.storage.Stat(uid, rel); statErr != nil {
		writeErr(w, http.StatusBadRequest, "file not found")
		return
	}

	const insert = `INSERT OR IGNORE INTO file_tags(user_id, rel_path, tag) VALUES (?, ?, ?)`
	if _, execErr := s.db.Exec(insert, uid, rel, tag); execErr != nil {
		writeErr(w, http.StatusInternalServerError, "failed to save tag")
		return
	}
	writeJSON(w, http.StatusCreated, map[string]any{"path": rel, "tag": tag})
}
// handleDeleteFileTag removes one tag from a path of the calling user. Path
// and tag come from the "path" and "tag" query parameters. An invalid tag
// yields 400, a database failure 500; deleting a tag that does not exist still
// answers 200.
func (s *Server) handleDeleteFileTag(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())
	query := r.URL.Query()
	rel := normalizePath(query.Get("path"))

	tag, valid := normalizeTag(query.Get("tag"))
	if !valid {
		writeErr(w, http.StatusBadRequest, "invalid tag")
		return
	}

	const remove = `DELETE FROM file_tags WHERE user_id = ? AND rel_path = ? AND tag = ?`
	if _, execErr := s.db.Exec(remove, uid, rel, tag); execErr != nil {
		writeErr(w, http.StatusInternalServerError, "failed to delete tag")
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{"path": rel, "tag": tag, "status": "deleted"})
}
// serveFile writes the file at rel (owned by uid) to w.
//
// Parameters:
//   - inline: true sends Content-Disposition "inline" (preview), false sends
//     "attachment" (download).
//   - archiveQuery: requested archive format, used only when rel is a
//     directory and inline is false; the directory is then streamed as an
//     archive.
//
// It returns an error, without having written a response, when rel cannot be
// stat'ed or opened, when a directory preview is requested, or when archive
// creation fails. Once content is being served it returns nil.
func (s *Server) serveFile(w http.ResponseWriter, r *http.Request, uid int64, rel string, inline bool, archiveQuery string) error {
	meta, err := s.storage.Stat(uid, rel)
	if err != nil {
		return err
	}

	if meta.IsDir {
		if inline {
			return fmt.Errorf("cannot preview a directory")
		}
		format, err := s.resolveArchiveFormat(uid, archiveQuery)
		if err != nil {
			return err
		}
		return s.serveDirectoryArchive(w, r, uid, rel, format)
	}

	rc, err := s.storage.OpenReadSeeker(uid, rel)
	if err != nil {
		return err
	}
	defer rc.Close()

	name := path.Base(normalizePath(rel))
	if name == "." || name == "/" || name == "" {
		name = "download"
	}

	// name is a slash-separated base name, so path.Ext is the matching helper.
	ctype := mime.TypeByExtension(strings.ToLower(path.Ext(name)))
	if ctype == "" {
		ctype = "application/octet-stream"
	}

	dispositionType := "attachment"
	if inline {
		dispositionType = "inline"
	}
	// File names are user-controlled. mime.FormatMediaType quotes them
	// correctly and uses the RFC 2231 filename*= form for non-ASCII names; %q
	// would put raw UTF-8 bytes and Go-style escapes into the header.
	disposition := mime.FormatMediaType(dispositionType, map[string]string{"filename": name})
	if disposition == "" {
		// Defensive fallback: FormatMediaType returns "" for input it cannot
		// serialize.
		disposition = fmt.Sprintf("%s; filename=%q", dispositionType, name)
	}

	w.Header().Set("Accept-Ranges", "bytes")
	w.Header().Set("Content-Type", ctype)
	w.Header().Set("Content-Disposition", disposition)
	// ServeContent handles Range, HEAD and conditional requests.
	http.ServeContent(w, r, name, meta.ModTime, rc)
	return nil
}
// resolveArchiveFormat decides which archive format to use for uid.
//
// Precedence: the requested format if normalizeArchiveFormat accepts it, then
// the user's stored archive_format if it is readable and valid, then "zip".
// The returned error is always nil in the current implementation.
func (s *Server) resolveArchiveFormat(uid int64, archiveQuery string) (string, error) {
	if explicit := normalizeArchiveFormat(archiveQuery); explicit != "" {
		return explicit, nil
	}

	const fallback = "zip"

	var stored string
	row := s.db.QueryRow(`SELECT archive_format FROM users WHERE id = ?`, uid)
	if scanErr := row.Scan(&stored); scanErr != nil {
		// A missing row or database error degrades to the default format.
		return fallback, nil
	}

	if saved := normalizeArchiveFormat(stored); saved != "" {
		return saved, nil
	}
	return fallback, nil
}
// serveDirectoryArchive packs the directory rel (owned by uid) into a
// temporary archive of the given format and sends it as an attachment named
// after the directory ("folder" for the root). The temporary archive is
// removed when the function returns.
//
// It returns an error, without having written a response, when the archive
// cannot be created, opened or stat'ed.
func (s *Server) serveDirectoryArchive(w http.ResponseWriter, r *http.Request, uid int64, rel, format string) error {
	base := path.Base(normalizePath(rel))
	if base == "/" || base == "." || base == "" {
		base = "folder"
	}

	archivePath, downloadName, ctype, err := s.createArchiveTemp(uid, rel, base, format)
	if err != nil {
		return err
	}
	defer os.Remove(archivePath)

	f, err := os.Open(archivePath)
	if err != nil {
		return err
	}
	defer f.Close()

	st, err := f.Stat()
	if err != nil {
		return err
	}

	// downloadName derives from a user-controlled folder name.
	// mime.FormatMediaType quotes it correctly and uses the RFC 2231
	// filename*= form for non-ASCII names; %q would put raw UTF-8 bytes and
	// Go-style escapes into the header.
	disposition := mime.FormatMediaType("attachment", map[string]string{"filename": downloadName})
	if disposition == "" {
		// Defensive fallback: FormatMediaType returns "" for input it cannot
		// serialize.
		disposition = fmt.Sprintf("attachment; filename=%q", downloadName)
	}

	w.Header().Set("Accept-Ranges", "bytes")
	w.Header().Set("Content-Type", ctype)
	w.Header().Set("Content-Disposition", disposition)
	http.ServeContent(w, r, downloadName, st.ModTime(), f)
	return nil
}
// createArchiveTemp dispatches to the archive builder for format ("rar",
// "tar.gz" or "lz4"); any other value, including "zip", produces a zip.
// The results are those of the selected builder; for the zip, rar and tar.gz
// builders that is the temp archive path, the download file name and the
// content type.
func (s *Server) createArchiveTemp(uid int64, rel, base, format string) (string, string, string, error) {
switch format {
case "rar":
return s.createRarTemp(uid, rel, base)
case "tar.gz":
return s.createTarGzTemp(uid, rel, base)
case "lz4":
return s.createTarLz4Temp(uid, rel, base)
default:
return s.createZipTemp(uid, rel, base)
}
}
// createBatchArchiveTemp copies the selected paths of uid into a temporary
// "selection" directory and archives that directory in the given format
// (zip unless format is "tar.gz", "lz4" or "rar").
//
// Root ("/") entries are skipped. Name collisions between selected items are
// resolved by uniqueArchiveName. The staging directory is removed before the
// function returns. The four results come from the create*FromLocalDir helper
// for the chosen format; handleBatchDownload uses them as archive path,
// download name and content type, and removes the archive itself.
func (s *Server) createBatchArchiveTemp(uid int64, paths []string, format string) (string, string, string, error) {
workdir, err := os.MkdirTemp("", "filez-batch-*")
if err != nil {
return "", "", "", err
}
// The staging copy is only needed while the archive is being built.
defer os.RemoveAll(workdir)
root := filepath.Join(workdir, "selection")
if err := os.MkdirAll(root, 0o755); err != nil {
return "", "", "", err
}
// used tracks names already taken inside the archive root.
used := map[string]int{}
for _, raw := range paths {
rel := normalizePath(raw)
if rel == "/" {
continue
}
meta, err := s.storage.Stat(uid, rel)
if err != nil {
return "", "", "", fmt.Errorf("invalid path: %s", rel)
}
base := path.Base(rel)
if base == "." || base == "/" || base == "" {
base = "item"
}
targetName := uniqueArchiveName(base, used)
target := filepath.Join(root, targetName)
if meta.IsDir {
if err := os.MkdirAll(target, 0o755); err != nil {
return "", "", "", err
}
}
if err := s.materializePath(uid, rel, target); err != nil {
return "", "", "", err
}
}
baseName := "files"
switch format {
case "tar.gz":
return createTarGzFromLocalDir(root, baseName)
case "lz4":
return createTarLz4FromLocalDir(root, baseName)
case "rar":
return createRarFromLocalDir(root, baseName)
default:
return createZipFromLocalDir(root, baseName)
}
}
// createZipTemp writes rel (owned by uid) into a new temporary zip file whose
// top-level entry is base.
//
// It returns the archive path, the download name (base + ".zip") and the
// content type. The caller owns the returned file and must delete it. On any
// failure the temporary file is removed, so nothing is left behind.
func (s *Server) createZipTemp(uid int64, rel, base string) (string, string, string, error) {
	tmp, err := os.CreateTemp("", "filez-*.zip")
	if err != nil {
		return "", "", "", err
	}
	archivePath := tmp.Name()

	// discard closes and deletes the partial archive and reports cause.
	discard := func(cause error) (string, string, string, error) {
		_ = tmp.Close()
		_ = os.Remove(archivePath)
		return "", "", "", cause
	}

	zw := zip.NewWriter(tmp)
	if err := s.addPathToZip(zw, uid, rel, base); err != nil {
		_ = zw.Close()
		return discard(err)
	}
	// Closing the zip writer emits the central directory; a failure here means
	// the archive is unusable.
	if err := zw.Close(); err != nil {
		return discard(err)
	}
	// A failed file close can hide a write error, so check it.
	if err := tmp.Close(); err != nil {
		_ = os.Remove(archivePath)
		return "", "", "", err
	}
	return archivePath, base + ".zip", "application/zip", nil
}
// addPathToZip adds rel (owned by uid) to zw under the archive name zipBase.
//
// A directory becomes an explicit "zipBase/" entry (when zipBase is non-empty)
// followed by its children, added recursively. A file is stored deflated with
// its modification time. The first error aborts the walk and is returned.
func (s *Server) addPathToZip(zw *zip.Writer, uid int64, rel string, zipBase string) error {
meta, err := s.storage.Stat(uid, rel)
if err != nil {
return err
}
if meta.IsDir {
if zipBase != "" {
// Explicit directory entry so empty folders survive in the archive.
hdr := &zip.FileHeader{Name: strings.TrimPrefix(zipBase, "/") + "/", Method: zip.Store}
hdr.Modified = meta.ModTime
if _, err := zw.CreateHeader(hdr); err != nil {
return err
}
}
entries, err := s.storage.List(uid, rel)
if err != nil {
return err
}
for _, e := range entries {
// childRel addresses storage; childBase is the name inside the archive.
childRel := path.Join(normalizePath(rel), e.Name)
childBase := path.Join(zipBase, e.Name)
if err := s.addPathToZip(zw, uid, childRel, childBase); err != nil {
return err
}
}
return nil
}
rc, err := s.storage.OpenReadSeeker(uid, rel)
if err != nil {
return err
}
defer rc.Close()
hdr := &zip.FileHeader{Name: strings.TrimPrefix(zipBase, "/"), Method: zip.Deflate}
hdr.Modified = meta.ModTime
writer, err := zw.CreateHeader(hdr)
if err != nil {
return err
}
_, err = io.Copy(writer, rc)
return err
}
// createRarTemp copies the stored entry at rel into a scratch directory named
// base and packs that directory with the external "rar" binary.
// It returns the archive path, the download name (base+".rar") and the MIME
// type. The scratch directory is removed before returning; the archive file
// is removed when rar fails.
func (s *Server) createRarTemp(uid int64, rel, base string) (string, string, string, error) {
	_, lookErr := exec.LookPath("rar")
	if lookErr != nil {
		return "", "", "", fmt.Errorf("rar format requires 'rar' binary on server; choose zip or tar.gz in settings")
	}
	scratch, err := os.MkdirTemp("", "filez-rar-*")
	if err != nil {
		return "", "", "", err
	}
	// One deferred cleanup covers every return below.
	defer os.RemoveAll(scratch)

	stage := filepath.Join(scratch, base)
	if err := os.MkdirAll(stage, 0o755); err != nil {
		return "", "", "", err
	}
	if err := s.materializePath(uid, rel, stage); err != nil {
		return "", "", "", err
	}
	placeholder, err := os.CreateTemp("", "filez-*.rar")
	if err != nil {
		return "", "", "", err
	}
	archive := placeholder.Name()
	placeholder.Close()

	output, runErr := exec.Command("rar", "a", "-idq", archive, stage).CombinedOutput()
	if runErr != nil {
		os.Remove(archive)
		return "", "", "", fmt.Errorf("rar creation failed: %s", strings.TrimSpace(string(output)))
	}
	return archive, base + ".rar", "application/vnd.rar", nil
}
// createTarGzTemp writes the stored file or directory at rel into a new
// temporary gzip-compressed tar archive whose entries are rooted at base.
//
// It returns the archive path, the download name (base+".tar.gz") and the
// MIME type. On failure the partially written temp file is removed instead
// of being left behind in the temp directory.
func (s *Server) createTarGzTemp(uid int64, rel, base string) (string, string, string, error) {
	tmp, err := os.CreateTemp("", "filez-*.tar.gz")
	if err != nil {
		return "", "", "", err
	}
	gzw := gzip.NewWriter(tmp)
	tw := tar.NewWriter(gzw)
	err = s.addPathToTar(tw, uid, rel, base)
	// Close inner-to-outer so each layer can flush; keep the first error.
	if closeErr := tw.Close(); err == nil {
		err = closeErr
	}
	if closeErr := gzw.Close(); err == nil {
		err = closeErr
	}
	if closeErr := tmp.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		_ = os.Remove(tmp.Name()) // best effort: the archive is unusable anyway
		return "", "", "", err
	}
	return tmp.Name(), base + ".tar.gz", "application/gzip", nil
}
// createTarLz4Temp builds an uncompressed tar of the stored entry at rel and
// compresses it with the external "lz4" binary.
//
// It returns the archive path, the download name (base+".tar.lz4") and the
// MIME type. The intermediate tar is always removed; the output file is
// removed when lz4 fails.
func (s *Server) createTarLz4Temp(uid int64, rel, base string) (string, string, string, error) {
	if _, err := exec.LookPath("lz4"); err != nil {
		return "", "", "", fmt.Errorf("lz4 format requires 'lz4' binary on server; choose zip or tar.gz in settings")
	}
	tarPath, err := s.createTarTemp(uid, rel, base)
	if err != nil {
		return "", "", "", err
	}
	defer os.Remove(tarPath)
	outTmp, err := os.CreateTemp("", "filez-*.tar.lz4")
	if err != nil {
		return "", "", "", err
	}
	outPath := outTmp.Name()
	outTmp.Close()
	// CreateTemp has already created outPath, so lz4 sees an existing target.
	// Per the lz4 manual it only overwrites an existing file without prompting
	// when -f is given; with -q and no terminal it would otherwise refuse.
	cmd := exec.Command("lz4", "-z", "-q", "-f", tarPath, outPath)
	if out, err := cmd.CombinedOutput(); err != nil {
		os.Remove(outPath)
		return "", "", "", fmt.Errorf("lz4 creation failed: %s", strings.TrimSpace(string(out)))
	}
	return outPath, base + ".tar.lz4", "application/x-lz4", nil
}
// createTarTemp writes the stored file or directory at rel into a new
// temporary uncompressed tar archive rooted at base and returns its path.
// On failure the partially written temp file is removed and "" is returned.
func (s *Server) createTarTemp(uid int64, rel, base string) (string, error) {
	tmp, err := os.CreateTemp("", "filez-*.tar")
	if err != nil {
		return "", err
	}
	tw := tar.NewWriter(tmp)
	err = s.addPathToTar(tw, uid, rel, base)
	// Always close both; keep the first error that occurred.
	if closeErr := tw.Close(); err == nil {
		err = closeErr
	}
	if closeErr := tmp.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		_ = os.Remove(tmp.Name()) // best effort: the archive is unusable anyway
		return "", err
	}
	return tmp.Name(), nil
}
// addPathToTar recursively adds the stored entry at rel to tw under the
// archive name tarBase (a leading "/" is stripped from entry names).
// Directories get a 0755 "name/" header, unless tarBase is empty, followed by
// their children; files get a 0644 regular-file header sized from the storage
// metadata and are then streamed into the archive.
func (s *Server) addPathToTar(tw *tar.Writer, uid int64, rel string, tarBase string) error {
	info, err := s.storage.Stat(uid, rel)
	if err != nil {
		return err
	}
	entryName := strings.TrimPrefix(tarBase, "/")

	if !info.IsDir {
		src, err := s.storage.OpenReadSeeker(uid, rel)
		if err != nil {
			return err
		}
		defer src.Close()
		fileHdr := &tar.Header{
			Name:     entryName,
			Mode:     0o644,
			Size:     info.Size,
			ModTime:  info.ModTime,
			Typeflag: tar.TypeReg,
		}
		if err := tw.WriteHeader(fileHdr); err != nil {
			return err
		}
		_, err = io.Copy(tw, src)
		return err
	}

	// Directory: emit its own header, except for an unnamed archive root.
	if tarBase != "" {
		dirHdr := &tar.Header{
			Name:     entryName + "/",
			Mode:     0o755,
			ModTime:  info.ModTime,
			Typeflag: tar.TypeDir,
		}
		if err := tw.WriteHeader(dirHdr); err != nil {
			return err
		}
	}
	children, err := s.storage.List(uid, rel)
	if err != nil {
		return err
	}
	for _, child := range children {
		storedChild := path.Join(normalizePath(rel), child.Name)
		archivedChild := path.Join(tarBase, child.Name)
		if err := s.addPathToTar(tw, uid, storedChild, archivedChild); err != nil {
			return err
		}
	}
	return nil
}
// materializePath copies the stored entry at rel onto the local filesystem
// at localPath. For a directory it walks the listing, creating each child
// directory locally before recursing into it; localPath itself is not
// created here. For a file it creates (or truncates) localPath and streams
// the content into it.
func (s *Server) materializePath(uid int64, rel string, localPath string) error {
	info, err := s.storage.Stat(uid, rel)
	if err != nil {
		return err
	}

	if !info.IsDir {
		src, err := s.storage.OpenReadSeeker(uid, rel)
		if err != nil {
			return err
		}
		defer src.Close()
		dst, err := os.Create(localPath)
		if err != nil {
			return err
		}
		defer dst.Close()
		_, err = io.Copy(dst, src)
		return err
	}

	children, err := s.storage.List(uid, rel)
	if err != nil {
		return err
	}
	for _, child := range children {
		remoteChild := path.Join(normalizePath(rel), child.Name)
		localChild := filepath.Join(localPath, child.Name)
		if child.IsDir {
			if err := os.MkdirAll(localChild, 0o755); err != nil {
				return err
			}
		}
		if err := s.materializePath(uid, remoteChild, localChild); err != nil {
			return err
		}
	}
	return nil
}
// copyPath duplicates the stored entry src to dst within one user's storage.
// Directories are created at dst and copied child by child; a file is read
// fully into memory and written with SaveBytes.
func (s *Server) copyPath(uid int64, src, dst string) error {
	info, err := s.storage.Stat(uid, src)
	if err != nil {
		return err
	}

	if !info.IsDir {
		reader, err := s.storage.OpenReadSeeker(uid, src)
		if err != nil {
			return err
		}
		defer reader.Close()
		payload, err := io.ReadAll(reader)
		if err != nil {
			return err
		}
		return s.storage.SaveBytes(uid, dst, payload)
	}

	if err := s.storage.Mkdir(uid, dst); err != nil {
		return err
	}
	children, err := s.storage.List(uid, src)
	if err != nil {
		return err
	}
	for i := range children {
		name := children[i].Name
		if err := s.copyPath(uid, path.Join(src, name), path.Join(dst, name)); err != nil {
			return err
		}
	}
	return nil
}
// nextMoveTarget returns a destination path that does not currently Stat
// successfully. It tries the normalized dst first, then "name-2.ext" up to
// "name-9999.ext" in the same directory. Any Stat error is treated as
// "name is free". It fails only when all candidates are taken.
func (s *Server) nextMoveTarget(uid int64, dst string) (string, error) {
	wanted := normalizePath(dst)
	taken := func(candidate string) bool {
		_, statErr := s.storage.Stat(uid, candidate)
		return statErr == nil
	}
	if !taken(wanted) {
		return wanted, nil
	}

	suffix := path.Ext(wanted)
	stem := strings.TrimSuffix(path.Base(wanted), suffix)
	parent := path.Dir(wanted)
	for n := 2; n < 10000; n++ {
		alt := path.Join(parent, fmt.Sprintf("%s-%d%s", stem, n, suffix))
		if !taken(alt) {
			return alt, nil
		}
	}
	return "", fmt.Errorf("cannot allocate destination name")
}
// moveTags re-points file tags after a move/rename from src to dst: the tag
// rows of src itself and of every path beneath src are rewritten to the
// corresponding path under dst. All database errors are ignored, so tags are
// moved on a best-effort basis.
func (s *Server) moveTags(uid int64, src, dst string) {
	_, _ = s.db.Exec(`UPDATE file_tags SET rel_path = ? WHERE user_id = ? AND rel_path = ?`, dst, uid, src)
	prefixFrom := strings.TrimSuffix(src, "/") + "/"
	prefixTo := strings.TrimSuffix(dst, "/") + "/"
	rows, err := s.db.Query(`SELECT rel_path, tag FROM file_tags WHERE user_id = ? AND rel_path LIKE ?`, uid, prefixFrom+"%")
	if err != nil {
		return
	}
	type rowItem struct{ p, t string }
	var items []rowItem
	for rows.Next() {
		var rp, tg string
		if rows.Scan(&rp, &tg) != nil {
			continue
		}
		// LIKE is only a coarse pre-filter: '%' and '_' inside the source
		// path act as wildcards (and SQLite's LIKE ignores ASCII case by
		// default), so it can return rows that are not under src at all.
		// Rewriting those would corrupt unrelated tags; keep true prefixes.
		if !strings.HasPrefix(rp, prefixFrom) {
			continue
		}
		items = append(items, rowItem{p: rp, t: tg})
	}
	// Release the result set before issuing the writes below.
	_ = rows.Close()
	for _, it := range items {
		next := prefixTo + strings.TrimPrefix(it.p, prefixFrom)
		_, _ = s.db.Exec(`DELETE FROM file_tags WHERE user_id = ? AND rel_path = ? AND tag = ?`, uid, it.p, it.t)
		_, _ = s.db.Exec(`INSERT OR IGNORE INTO file_tags(user_id, rel_path, tag) VALUES (?, ?, ?)`, uid, next, it.t)
	}
}
// uniqueArchiveName returns a name for base that has not been handed out
// before, recording every returned name in used.
//
// base is trimmed; an empty base becomes "item". The first request for a
// name returns it unchanged. Repeats get a "-N" suffix inserted before the
// extension, starting at N=2. Generated names are reserved as well, so
// "a.txt", "a.txt", "a-2.txt" yields "a.txt", "a-2.txt", "a-2-2.txt" rather
// than handing out "a-2.txt" twice.
func uniqueArchiveName(base string, used map[string]int) string {
	base = strings.TrimSpace(base)
	if base == "" {
		base = "item"
	}
	count := used[base]
	used[base] = count + 1
	if count == 0 {
		return base
	}
	ext := path.Ext(base)
	stem := strings.TrimSuffix(base, ext)
	for n := count + 1; ; n++ {
		candidate := fmt.Sprintf("%s-%d%s", stem, n, ext)
		if used[candidate] == 0 {
			// Reserve the generated name so a later literal use collides
			// with it and gets its own suffix.
			used[candidate] = 1
			return candidate
		}
	}
}
// createZipFromLocalDir packs everything beneath the local directory root
// (root itself excluded) into a new temporary zip archive.
//
// Entry names are slash-separated paths relative to root. Directories are
// stored as "name/" entries and files are deflated. It returns the archive
// path, the download name (baseName+".zip") and the MIME type. On failure
// the partially written temp file is removed.
func createZipFromLocalDir(root, baseName string) (string, string, string, error) {
	tmp, err := os.CreateTemp("", "filez-batch-*.zip")
	if err != nil {
		return "", "", "", err
	}
	zw := zip.NewWriter(tmp)
	err = filepath.WalkDir(root, func(p string, d fs.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		rel, err := filepath.Rel(root, p)
		if err != nil {
			return err
		}
		if rel == "." {
			return nil
		}
		rel = filepath.ToSlash(rel)
		info, err := d.Info()
		if err != nil {
			return err
		}
		if d.IsDir() {
			hdr := &zip.FileHeader{Name: rel + "/", Method: zip.Store}
			hdr.Modified = info.ModTime()
			_, err = zw.CreateHeader(hdr)
			return err
		}
		hdr, err := zip.FileInfoHeader(info)
		if err != nil {
			return err
		}
		// FileInfoHeader only fills in the base name; use the relative path.
		hdr.Name = rel
		hdr.Method = zip.Deflate
		w, err := zw.CreateHeader(hdr)
		if err != nil {
			return err
		}
		f, err := os.Open(p)
		if err != nil {
			return err
		}
		defer f.Close()
		_, err = io.Copy(w, f)
		return err
	})
	// Always close both writers; keep the first error that occurred.
	if closeErr := zw.Close(); err == nil {
		err = closeErr
	}
	if closeErr := tmp.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		_ = os.Remove(tmp.Name()) // best effort: the archive is unusable anyway
		return "", "", "", err
	}
	return tmp.Name(), baseName + ".zip", "application/zip", nil
}
// createTarGzFromLocalDir packs everything beneath the local directory root
// (root itself excluded) into a new temporary gzip-compressed tar archive.
//
// Entry names are slash-separated paths relative to root, with a trailing
// "/" for directories. It returns the archive path, the download name
// (baseName+".tar.gz") and the MIME type. On failure the partially written
// temp file is removed.
func createTarGzFromLocalDir(root, baseName string) (string, string, string, error) {
	tmp, err := os.CreateTemp("", "filez-batch-*.tar.gz")
	if err != nil {
		return "", "", "", err
	}
	gzw := gzip.NewWriter(tmp)
	tw := tar.NewWriter(gzw)
	err = filepath.WalkDir(root, func(p string, d fs.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		rel, err := filepath.Rel(root, p)
		if err != nil {
			return err
		}
		if rel == "." {
			return nil
		}
		rel = filepath.ToSlash(rel)
		info, err := d.Info()
		if err != nil {
			return err
		}
		hdr, err := tar.FileInfoHeader(info, "")
		if err != nil {
			return err
		}
		// FileInfoHeader only fills in the base name; use the relative path.
		hdr.Name = rel
		if d.IsDir() {
			hdr.Name += "/"
		}
		if err := tw.WriteHeader(hdr); err != nil {
			return err
		}
		if d.IsDir() {
			return nil
		}
		f, err := os.Open(p)
		if err != nil {
			return err
		}
		defer f.Close()
		_, err = io.Copy(tw, f)
		return err
	})
	// Close inner-to-outer so each layer can flush; keep the first error.
	if closeErr := tw.Close(); err == nil {
		err = closeErr
	}
	if closeErr := gzw.Close(); err == nil {
		err = closeErr
	}
	if closeErr := tmp.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		_ = os.Remove(tmp.Name()) // best effort: the archive is unusable anyway
		return "", "", "", err
	}
	return tmp.Name(), baseName + ".tar.gz", "application/gzip", nil
}
// createTarFromLocalDir packs everything beneath the local directory root
// (root itself excluded) into a new temporary uncompressed tar archive and
// returns its path. Entry names are slash-separated paths relative to root,
// with a trailing "/" for directories. On failure the partially written temp
// file is removed and "" is returned.
func createTarFromLocalDir(root string) (string, error) {
	tmp, err := os.CreateTemp("", "filez-batch-*.tar")
	if err != nil {
		return "", err
	}
	tw := tar.NewWriter(tmp)
	err = filepath.WalkDir(root, func(p string, d fs.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		rel, err := filepath.Rel(root, p)
		if err != nil {
			return err
		}
		if rel == "." {
			return nil
		}
		rel = filepath.ToSlash(rel)
		info, err := d.Info()
		if err != nil {
			return err
		}
		hdr, err := tar.FileInfoHeader(info, "")
		if err != nil {
			return err
		}
		// FileInfoHeader only fills in the base name; use the relative path.
		hdr.Name = rel
		if d.IsDir() {
			hdr.Name += "/"
		}
		if err := tw.WriteHeader(hdr); err != nil {
			return err
		}
		if d.IsDir() {
			return nil
		}
		f, err := os.Open(p)
		if err != nil {
			return err
		}
		defer f.Close()
		_, err = io.Copy(tw, f)
		return err
	})
	// Always close both; keep the first error that occurred.
	if closeErr := tw.Close(); err == nil {
		err = closeErr
	}
	if closeErr := tmp.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		_ = os.Remove(tmp.Name()) // best effort: the archive is unusable anyway
		return "", err
	}
	return tmp.Name(), nil
}
// createTarLz4FromLocalDir tars everything beneath the local directory root
// and compresses the tar with the external "lz4" binary.
//
// It returns the archive path, the download name (baseName+".tar.lz4") and
// the MIME type. The intermediate tar is always removed; the output file is
// removed when lz4 fails.
func createTarLz4FromLocalDir(root, baseName string) (string, string, string, error) {
	if _, err := exec.LookPath("lz4"); err != nil {
		return "", "", "", fmt.Errorf("lz4 format requires 'lz4' binary on server; choose zip or tar.gz in settings")
	}
	tarPath, err := createTarFromLocalDir(root)
	if err != nil {
		return "", "", "", err
	}
	defer os.Remove(tarPath)
	outTmp, err := os.CreateTemp("", "filez-batch-*.tar.lz4")
	if err != nil {
		return "", "", "", err
	}
	outPath := outTmp.Name()
	outTmp.Close()
	// CreateTemp has already created outPath, so lz4 sees an existing target.
	// Per the lz4 manual it only overwrites an existing file without prompting
	// when -f is given; with -q and no terminal it would otherwise refuse.
	cmd := exec.Command("lz4", "-z", "-q", "-f", tarPath, outPath)
	if out, err := cmd.CombinedOutput(); err != nil {
		os.Remove(outPath)
		return "", "", "", fmt.Errorf("lz4 creation failed: %s", strings.TrimSpace(string(out)))
	}
	return outPath, baseName + ".tar.lz4", "application/x-lz4", nil
}
// createRarFromLocalDir packs the local directory root with the external
// "rar" binary into a new temporary archive. It returns the archive path,
// the download name (baseName+".rar") and the MIME type. It fails up front
// when no rar binary is on PATH and removes the output file when rar fails.
func createRarFromLocalDir(root, baseName string) (string, string, string, error) {
	_, lookErr := exec.LookPath("rar")
	if lookErr != nil {
		return "", "", "", fmt.Errorf("rar format requires 'rar' binary on server; choose zip or tar.gz in settings")
	}
	placeholder, err := os.CreateTemp("", "filez-batch-*.rar")
	if err != nil {
		return "", "", "", err
	}
	archive := placeholder.Name()
	placeholder.Close()

	output, runErr := exec.Command("rar", "a", "-idq", archive, root).CombinedOutput()
	if runErr != nil {
		os.Remove(archive)
		return "", "", "", fmt.Errorf("rar creation failed: %s", strings.TrimSpace(string(output)))
	}
	return archive, baseName + ".rar", "application/vnd.rar", nil
}
// handleDelete removes the file or directory named by the "path" query
// parameter from the authenticated user's storage and drops the file tags of
// that path and of everything beneath it.
//
// Responds 400 when the parameter is missing or storage deletion fails,
// 200 {"status":"deleted"} otherwise. Tag cleanup is best effort.
func (s *Server) handleDelete(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())
	rel := r.URL.Query().Get("path")
	// Without this guard a request that simply omits ?path= falls through to
	// the storage root and deletes everything the user has.
	if strings.TrimSpace(rel) == "" {
		writeErr(w, http.StatusBadRequest, "missing path")
		return
	}
	norm := normalizePath(rel)
	log.Printf("file.delete user_id=%d path=%q", uid, norm)
	if err := s.storage.Delete(uid, norm); err != nil {
		writeErr(w, http.StatusBadRequest, err.Error())
		return
	}
	// instr(...) = 1 is an exact, case-sensitive prefix test. A LIKE pattern
	// built from the path would treat '%' and '_' in file names as wildcards
	// and remove tags that belong to unrelated paths.
	childPrefix := strings.TrimSuffix(norm, "/") + "/"
	_, _ = s.db.Exec(`DELETE FROM file_tags WHERE user_id = ? AND (rel_path = ? OR instr(rel_path, ?) = 1)`, uid, norm, childPrefix)
	writeJSON(w, http.StatusOK, map[string]string{"status": "deleted"})
}
// handleCreateFolder creates a folder called payload.name inside
// payload.path in the authenticated user's storage.
//
// Only the last path element of name is used. Names that are empty or
// resolve to ".", ".." or "/" are rejected with 400; previously they
// collapsed onto the parent (or an ancestor) and the handler still answered
// "created". Responds 201 {"status":"created"} on success.
func (s *Server) handleCreateFolder(w http.ResponseWriter, r *http.Request) {
	type payload struct {
		Path string `json:"path"`
		Name string `json:"name"`
	}
	var in payload
	if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}
	uid := userIDFromContext(r.Context())
	// path.Base("") is "." and path.Base("a/..") is "..". Joining either
	// onto the parent would not name a new child folder.
	leaf := path.Base(strings.TrimSpace(in.Name))
	if leaf == "." || leaf == ".." || leaf == "/" {
		writeErr(w, http.StatusBadRequest, "invalid folder name")
		return
	}
	rel := path.Join(normalizePath(in.Path), leaf)
	log.Printf("file.mkdir user_id=%d path=%q", uid, normalizePath(rel))
	if rel == "/" || rel == "." {
		writeErr(w, http.StatusBadRequest, "invalid folder name")
		return
	}
	if err := s.storage.Mkdir(uid, rel); err != nil {
		writeErr(w, http.StatusBadRequest, err.Error())
		return
	}
	writeJSON(w, http.StatusCreated, map[string]string{"status": "created"})
}
// shareInput is the JSON request body decoded by handleCreateShareLink.
type shareInput struct {
// Path is the stored file to share; it is normalized by the handler.
Path string `json:"path"`
// ExpiresMinutes is the requested lifetime in minutes; values <= 0 make
// the handler fall back to the configured default TTL.
ExpiresMinutes int `json:"expiresMinutes"`
// MaxDownloads optionally caps downloads; nil or a non-positive value is
// stored as "no limit".
MaxDownloads *int `json:"maxDownloads"`
// AllowPreview and PreferredInline are decoded but not read by
// handleCreateShareLink. NOTE(review): confirm whether they are used elsewhere.
AllowPreview bool `json:"allowPreview"`
PreferredInline bool `json:"preferredInline"`
}
// handleCreateShareLink creates a time-limited public download link for one
// of the authenticated user's files.
//
// The lifetime is the requested expiresMinutes (or the configured default
// when <= 0), clamped to between 5 minutes and 30 days. Only a hash of the
// random token is stored; the plain token is returned once in the response.
// Responds 400 for bad payloads, unknown paths and folders, 500 when token
// generation or the insert fails, and 201 with url/token/path/expiresAt on
// success.
func (s *Server) handleCreateShareLink(w http.ResponseWriter, r *http.Request) {
	const (
		minShareTTL = 5 * time.Minute
		maxShareTTL = 30 * 24 * time.Hour
	)
	uid := userIDFromContext(r.Context())
	var in shareInput
	if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}
	rel := normalizePath(in.Path)
	log.Printf("file.share.create user_id=%d path=%q", uid, rel)
	meta, err := s.storage.Stat(uid, rel)
	if err != nil {
		writeErr(w, http.StatusBadRequest, "file not found")
		return
	}
	if meta.IsDir {
		writeErr(w, http.StatusBadRequest, "cannot create share for folder")
		return
	}
	ttl := s.config.ShareDefaultTTL
	switch {
	case in.ExpiresMinutes > int(maxShareTTL/time.Minute):
		// Clamp before multiplying: a huge minute count would overflow
		// time.Duration and could wrap to a small or negative value.
		ttl = maxShareTTL
	case in.ExpiresMinutes > 0:
		ttl = time.Duration(in.ExpiresMinutes) * time.Minute
	}
	if ttl < minShareTTL {
		ttl = minShareTTL
	}
	if ttl > maxShareTTL {
		ttl = maxShareTTL
	}
	token, err := randomToken()
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to create token")
		return
	}
	// A nil interface value is stored as SQL NULL, i.e. "no download limit".
	var maxDownloads any
	if in.MaxDownloads != nil && *in.MaxDownloads > 0 {
		maxDownloads = *in.MaxDownloads
	}
	expiresAt := time.Now().Add(ttl)
	_, err = s.db.Exec(`INSERT INTO share_links(user_id, rel_path, token_hash, expires_at, max_downloads) VALUES (?, ?, ?, ?, ?)`,
		uid,
		rel,
		hashToken(token),
		expiresAt,
		maxDownloads,
	)
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to create share")
		return
	}
	shareURL := fmt.Sprintf("%s://%s/api/share/%s", schemeOf(r), r.Host, token)
	writeJSON(w, http.StatusCreated, map[string]any{
		"url":       shareURL,
		"token":     token,
		"path":      rel,
		"expiresAt": expiresAt,
	})
}
// handleSharedDownload serves the file behind a public share token.
//
// Responds 400 for a blank token, 404 for an unknown token, 410 when the
// link is revoked, expired or out of downloads, and 500 when the download
// counter cannot be updated. Each accepted request is counted before the
// file is served.
func (s *Server) handleSharedDownload(w http.ResponseWriter, r *http.Request) {
	token := mux.Vars(r)["token"]
	if strings.TrimSpace(token) == "" {
		writeErr(w, http.StatusBadRequest, "missing token")
		return
	}
	tokenHash := hashToken(token)
	var uid int64
	var rel string
	var expiresAt time.Time
	var revokedAt sql.NullTime
	var maxDownloads sql.NullInt64
	var downloadCount int64
	err := s.db.QueryRow(`SELECT user_id, rel_path, expires_at, revoked_at, max_downloads, download_count FROM share_links WHERE token_hash = ?`, tokenHash).
		Scan(&uid, &rel, &expiresAt, &revokedAt, &maxDownloads, &downloadCount)
	if err != nil {
		writeErr(w, http.StatusNotFound, "share link not found")
		return
	}
	if revokedAt.Valid || expiresAt.Before(time.Now()) {
		writeErr(w, http.StatusGone, "share link expired")
		return
	}
	if maxDownloads.Valid && downloadCount >= maxDownloads.Int64 {
		writeErr(w, http.StatusGone, "share link download limit reached")
		return
	}
	// Claim a download slot atomically. The check above only reads a
	// snapshot, so concurrent requests could all pass it; repeating the
	// limit in the UPDATE's WHERE clause lets only requests that still fit
	// under max_downloads bump the counter.
	res, err := s.db.Exec(`UPDATE share_links SET download_count = download_count + 1 WHERE token_hash = ? AND (max_downloads IS NULL OR download_count < max_downloads)`, tokenHash)
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to track download")
		return
	}
	if n, err := res.RowsAffected(); err == nil && n == 0 {
		writeErr(w, http.StatusGone, "share link download limit reached")
		return
	}
	log.Printf("file.share.download user_id=%d path=%q ip=%q", uid, normalizePath(rel), clientIP(r))
	if err := s.serveFile(w, r, uid, rel, true, ""); err != nil {
		writeErr(w, http.StatusBadRequest, err.Error())
	}
}
// handleAdminMe answers 200 with {"login": <configured admin login>}.
func (s *Server) handleAdminMe(w http.ResponseWriter, _ *http.Request) {
	body := map[string]string{"login": s.config.AdminLogin}
	writeJSON(w, http.StatusOK, body)
}
// handleAdminUsersList answers with {"users": [...]} listing every user
// ordered by id, with theme and color mode normalized. Rows that fail to
// scan are skipped. Responds 500 when the query fails or when iteration
// stops early with an error.
func (s *Server) handleAdminUsersList(w http.ResponseWriter, _ *http.Request) {
	rows, err := s.db.Query(`SELECT id, email, theme, color_mode, archive_format FROM users ORDER BY id ASC`)
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to load users")
		return
	}
	defer rows.Close()
	// Kept non-nil on purpose; encoding/json marshals a nil slice as null
	// (assuming writeJSON encodes with encoding/json).
	users := make([]User, 0)
	for rows.Next() {
		var u User
		if err := rows.Scan(&u.ID, &u.Username, &u.Theme, &u.ColorMode, &u.Archive); err != nil {
			continue
		}
		u.Theme = normalizeTheme(u.Theme)
		u.ColorMode = normalizeColorMode(u.ColorMode)
		users = append(users, u)
	}
	// rows.Next returning false can also mean the iteration failed; without
	// this check a truncated list would be reported as complete.
	if err := rows.Err(); err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to load users")
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{"users": users})
}
// adminCreateUserInput is the JSON request body decoded by
// handleAdminUserCreate. Theme and ColorMode are normalized by the handler
// before the user is created.
type adminCreateUserInput struct {
Username string `json:"username"`
Password string `json:"password"`
Theme string `json:"theme"`
ColorMode string `json:"colorMode"`
}
// handleAdminUserCreate creates a user from the JSON body.
// Responds 201 with the new user, 409 when the creation error mentions
// "exists", and 400 for an undecodable body or any other creation error.
func (s *Server) handleAdminUserCreate(w http.ResponseWriter, r *http.Request) {
	var in adminCreateUserInput
	if decodeErr := json.NewDecoder(r.Body).Decode(&in); decodeErr != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}
	in.Theme = normalizeTheme(in.Theme)
	in.ColorMode = normalizeColorMode(in.ColorMode)

	user, err := s.createUser(in.Username, in.Password, in.Theme, in.ColorMode)
	if err == nil {
		writeJSON(w, http.StatusCreated, user)
		log.Printf("admin.user.create user_id=%d username=%q", user.ID, user.Username)
		return
	}

	// Duplicate accounts are recognised by the error text only.
	status := http.StatusBadRequest
	if strings.Contains(strings.ToLower(err.Error()), "exists") {
		status = http.StatusConflict
	}
	writeErr(w, status, err.Error())
}
// handleAdminUserDelete deletes the user whose id is in the route, then
// cleans up what belonged to them: refresh tokens are deleted, share links
// revoked, file tags deleted and the storage root removed.
//
// Responds 400 for an invalid id and 500 when the user row cannot be
// deleted. The cleanup steps are best effort: a failure is logged and does
// not change the 200 {"status":"deleted"} response.
func (s *Server) handleAdminUserDelete(w http.ResponseWriter, r *http.Request) {
	idStr := mux.Vars(r)["id"]
	id, err := strconv.ParseInt(idStr, 10, 64)
	if err != nil || id <= 0 {
		writeErr(w, http.StatusBadRequest, "invalid user id")
		return
	}
	if _, err := s.db.Exec(`DELETE FROM users WHERE id = ?`, id); err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to delete user")
		return
	}
	cleanup := []struct{ what, query string }{
		{"refresh_tokens", `DELETE FROM refresh_tokens WHERE user_id = ?`},
		{"share_links", `UPDATE share_links SET revoked_at = CURRENT_TIMESTAMP WHERE user_id = ?`},
		// Tags are keyed by user_id as well; leaving them would orphan rows.
		{"file_tags", `DELETE FROM file_tags WHERE user_id = ?`},
	}
	for _, step := range cleanup {
		if _, err := s.db.Exec(step.query, id); err != nil {
			log.Printf("admin.user.delete cleanup=%s user_id=%d err=%v", step.what, id, err)
		}
	}
	if err := s.storage.Delete(id, "/"); err != nil {
		log.Printf("admin.user.delete cleanup=storage user_id=%d err=%v", id, err)
	}
	writeJSON(w, http.StatusOK, map[string]string{"status": "deleted"})
	log.Printf("admin.user.delete user_id=%d", id)
}
// recoverMiddleware wraps next so that a panic raised while serving a
// request is logged with the request path and answered with a 500 error.
func (s *Server) recoverMiddleware(next http.Handler) http.Handler {
	guarded := func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			cause := recover()
			if cause == nil {
				return
			}
			log.Printf("panic recovered path=%q err=%v", r.URL.Path, cause)
			writeErr(w, http.StatusInternalServerError, "internal server error")
		}()
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(guarded)
}
// securityHeadersMiddleware sets a fixed set of security response headers on
// every request, plus Strict-Transport-Security when the request arrived
// over TLS (r.TLS != nil), and then calls next.
func (s *Server) securityHeadersMiddleware(next http.Handler) http.Handler {
	// Header name/value pairs applied unconditionally.
	always := [][2]string{
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		{"Permissions-Policy", "camera=(), microphone=(), geolocation=()"},
		{"Content-Security-Policy", "default-src 'self'; connect-src 'self'; img-src 'self' data: blob:; style-src 'self' 'unsafe-inline'; script-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"},
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		for _, kv := range always {
			hdr.Set(kv[0], kv[1])
		}
		if r.TLS != nil {
			hdr.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
		}
		next.ServeHTTP(w, r)
	})
}
// bodyLimitMiddleware caps request bodies at config.MaxBodyBytes for every
// /api/ path except those under /api/files/upload, which pass through
// without a cap. Non-API paths are not limited either.
func (s *Server) bodyLimitMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reqPath := r.URL.Path
		isUpload := strings.HasPrefix(reqPath, "/api/files/upload")
		if !isUpload && strings.HasPrefix(reqPath, "/api/") {
			r.Body = http.MaxBytesReader(w, r.Body, s.config.MaxBodyBytes)
		}
		next.ServeHTTP(w, r)
	})
}
// rateLimitMiddleware throttles /api/ requests per client IP and exact URL
// path (the bucket key is "ip:path"). Paths containing "/auth/" or
// "/admin/login" use AuthRateLimitPerMin; all others use RateLimitPerMin.
// Over-limit requests get 429. Non-API paths are never limited.
func (s *Server) rateLimitMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reqPath := r.URL.Path
		if strings.HasPrefix(reqPath, "/api/") {
			bucket := clientIP(r) + ":" + reqPath
			perMinute := s.config.RateLimitPerMin
			sensitive := strings.Contains(reqPath, "/auth/") || strings.Contains(reqPath, "/admin/login")
			if sensitive {
				perMinute = s.config.AuthRateLimitPerMin
			}
			if !s.limiter.allow(bucket, perMinute, time.Now()) {
				writeErr(w, http.StatusTooManyRequests, "rate limit exceeded")
				return
			}
		}
		next.ServeHTTP(w, r)
	})
}
// requestLogMiddleware runs next with a status-recording response writer
// and, for /api/ paths, logs method, path, status, duration in milliseconds,
// client IP and the user agent truncated to 120 characters.
func (s *Server) requestLogMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		// The status defaults to 200 for handlers that never call WriteHeader.
		recorder := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
		next.ServeHTTP(recorder, r)
		if !strings.HasPrefix(r.URL.Path, "/api/") {
			return
		}
		elapsedMs := time.Since(began).Milliseconds()
		log.Printf(
			"api.request method=%s path=%q status=%d duration_ms=%d ip=%q ua=%q",
			r.Method, r.URL.Path, recorder.status, elapsedMs, clientIP(r), limitString(r.UserAgent(), 120),
		)
	})
}
// createUser validates the credentials, stores a new user and provisions the
// user's storage root.
//
// The username is trimmed and lower-cased, must be 3-32 characters long and
// may contain only a-z, 0-9, '.', '-' and '_'. The password must be at least
// 10 characters and is stored as an Argon2id hash. A uniqueness violation is
// reported as "account already exists". If storage provisioning fails, the
// freshly inserted row is deleted again so the username is not left occupied
// by an account without storage.
func (s *Server) createUser(username, password, theme, colorMode string) (User, error) {
	username = strings.ToLower(strings.TrimSpace(username))
	if len(username) < 3 || len(username) > 32 {
		return User{}, fmt.Errorf("username must be 3-32 characters")
	}
	for _, ch := range username {
		if (ch < 'a' || ch > 'z') && (ch < '0' || ch > '9') && ch != '_' && ch != '-' && ch != '.' {
			return User{}, fmt.Errorf("username can contain only a-z, 0-9, dot, dash, underscore")
		}
	}
	if len(password) < 10 {
		return User{}, fmt.Errorf("password must be at least 10 characters")
	}
	hash, err := hashPasswordArgon2ID(password)
	if err != nil {
		return User{}, fmt.Errorf("failed to hash password")
	}
	theme = normalizeTheme(theme)
	colorMode = normalizeColorMode(colorMode)
	res, err := s.db.Exec(`INSERT INTO users(email, password_hash, theme, color_mode, archive_format) VALUES (?, ?, ?, ?, ?)`, username, hash, theme, colorMode, "zip")
	if err != nil {
		if strings.Contains(strings.ToLower(err.Error()), "unique") {
			return User{}, fmt.Errorf("account already exists")
		}
		return User{}, err
	}
	// Without a trustworthy id we would provision storage for user 0.
	id, err := res.LastInsertId()
	if err != nil {
		return User{}, err
	}
	if err := s.storage.Mkdir(id, "/"); err != nil {
		// Roll the insert back; otherwise a retry fails with
		// "account already exists" for an account that cannot be used.
		if _, delErr := s.db.Exec(`DELETE FROM users WHERE id = ?`, id); delErr != nil {
			log.Printf("user.create rollback failed user_id=%d err=%v", id, delErr)
		}
		return User{}, fmt.Errorf("failed to provision user storage: %w", err)
	}
	return User{ID: id, Username: username, Theme: theme, ColorMode: colorMode, Archive: "zip"}, nil
}
// findUserWithHash looks a user up by trimmed, lower-cased username and
// returns the user (theme and color mode normalized) together with the
// stored password hash. On a lookup error it returns a zero User and "".
func (s *Server) findUserWithHash(username string) (User, string, error) {
	lookup := strings.ToLower(strings.TrimSpace(username))
	var (
		found  User
		pwHash string
	)
	row := s.db.QueryRow(`SELECT id, email, password_hash, theme, color_mode, archive_format FROM users WHERE email = ?`, lookup)
	if err := row.Scan(&found.ID, &found.Username, &pwHash, &found.Theme, &found.ColorMode, &found.Archive); err != nil {
		return User{}, "", err
	}
	found.Theme = normalizeTheme(found.Theme)
	found.ColorMode = normalizeColorMode(found.ColorMode)
	return found, pwHash, nil
}
// findUser loads a user by id and normalizes theme and color mode.
// The normalization also runs when the lookup fails, so callers must check
// the returned error before using the User.
func (s *Server) findUser(id int64) (User, error) {
	var found User
	row := s.db.QueryRow(`SELECT id, email, theme, color_mode, archive_format FROM users WHERE id = ?`, id)
	scanErr := row.Scan(&found.ID, &found.Username, &found.Theme, &found.ColorMode, &found.Archive)
	found.Theme = normalizeTheme(found.Theme)
	found.ColorMode = normalizeColorMode(found.ColorMode)
	return found, scanErr
}
// issueUserSession starts a session for userID: it signs an HS256 access JWT
// valid for AccessTTL, stores the hash of a fresh random refresh token
// (with expiry, truncated user agent and client IP) and sets both tokens as
// cookies. No cookie is set if signing, token generation or the insert fails.
func (s *Server) issueUserSession(w http.ResponseWriter, r *http.Request, userID int64) error {
	issuedAt := time.Now()
	accessTTL, refreshTTL := s.config.AccessTTL, s.config.RefreshTTL

	registered := jwt.RegisteredClaims{
		Subject:   strconv.FormatInt(userID, 10),
		IssuedAt:  jwt.NewNumericDate(issuedAt),
		ExpiresAt: jwt.NewNumericDate(issuedAt.Add(accessTTL)),
	}
	unsigned := jwt.NewWithClaims(jwt.SigningMethodHS256, AccessClaims{UserID: userID, RegisteredClaims: registered})
	accessToken, err := unsigned.SignedString([]byte(s.config.JWTSecret))
	if err != nil {
		return err
	}

	refreshToken, err := randomToken()
	if err != nil {
		return err
	}
	// Only the hash is persisted; the raw token lives in the cookie alone.
	_, err = s.db.Exec(`INSERT INTO refresh_tokens(user_id, token_hash, expires_at, user_agent, ip) VALUES (?, ?, ?, ?, ?)`,
		userID,
		hashToken(refreshToken),
		issuedAt.Add(refreshTTL),
		limitString(r.UserAgent(), 255),
		limitString(clientIP(r), 64),
	)
	if err != nil {
		return err
	}

	secure := s.config.CookieSecure
	setCookie(w, "access_token", accessToken, int(accessTTL.Seconds()), secure)
	setCookie(w, "refresh_token", refreshToken, int(refreshTTL.Seconds()), secure)
	return nil
}
// consumeRefreshToken redeems a refresh token exactly once and returns the
// id of the user it belongs to.
//
// It fails when the token is unknown, expired or already revoked. The token
// is revoked as part of redemption. The revoking UPDATE is conditional on the
// token still being unrevoked, so when two requests present the same token
// concurrently only one of them succeeds.
func (s *Server) consumeRefreshToken(token string) (int64, error) {
	var id int64
	var uid int64
	var expiresAt time.Time
	var revokedAt sql.NullTime
	err := s.db.QueryRow(`SELECT id, user_id, expires_at, revoked_at FROM refresh_tokens WHERE token_hash = ?`, hashToken(token)).
		Scan(&id, &uid, &expiresAt, &revokedAt)
	if err != nil {
		return 0, err
	}
	if revokedAt.Valid || expiresAt.Before(time.Now()) {
		return 0, fmt.Errorf("refresh token expired or revoked")
	}
	res, err := s.db.Exec(`UPDATE refresh_tokens SET revoked_at = CURRENT_TIMESTAMP WHERE id = ? AND revoked_at IS NULL`, id)
	if err != nil {
		return 0, err
	}
	// Zero affected rows means another request revoked the token between
	// the SELECT above and this UPDATE.
	if n, err := res.RowsAffected(); err == nil && n == 0 {
		return 0, fmt.Errorf("refresh token expired or revoked")
	}
	return uid, nil
}
// authMiddleware authenticates a request from the "access_token" cookie.
//
// The cookie must hold an HMAC-signed JWT that verifies against JWTSecret
// and carries a positive user id; that id is stored in the request context
// under userIDKey for downstream handlers. Anything else gets 401.
func (s *Server) authMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		cookie, err := r.Cookie("access_token")
		if err != nil || cookie.Value == "" {
			writeErr(w, http.StatusUnauthorized, "missing access token")
			return
		}
		claims := &AccessClaims{}
		tkn, err := jwt.ParseWithClaims(cookie.Value, claims, func(token *jwt.Token) (any, error) {
			// Refuse non-HMAC algorithms so the secret is never used as
			// another kind of key.
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, fmt.Errorf("unexpected signing method")
			}
			return []byte(s.config.JWTSecret), nil
		})
		if err != nil || !tkn.Valid {
			writeErr(w, http.StatusUnauthorized, "invalid access token")
			return
		}
		// Admin session tokens are verified with the same secret (see
		// adminMiddleware). A correctly signed token that carries no user id
		// must not be treated as a session for user 0.
		if claims.UserID <= 0 {
			writeErr(w, http.StatusUnauthorized, "invalid access token")
			return
		}
		next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), userIDKey, claims.UserID)))
	})
}
// adminMiddleware lets a request through only when the "admin_token" cookie
// holds an HMAC-signed JWT that verifies against JWTSecret, has role "admin"
// and names the configured admin login. Everything else gets 401.
func (s *Server) adminMiddleware(next http.Handler) http.Handler {
	// keyFor hands out the signing secret, but only to HMAC-signed tokens.
	keyFor := func(token *jwt.Token) (any, error) {
		if _, isHMAC := token.Method.(*jwt.SigningMethodHMAC); !isHMAC {
			return nil, fmt.Errorf("unexpected signing method")
		}
		return []byte(s.config.JWTSecret), nil
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		cookie, err := r.Cookie("admin_token")
		if err != nil || cookie.Value == "" {
			writeErr(w, http.StatusUnauthorized, "missing admin token")
			return
		}
		claims := &AdminClaims{}
		parsed, err := jwt.ParseWithClaims(cookie.Value, claims, keyFor)
		// Cases are evaluated left to right and stop at the first match,
		// so parsed is not touched when err is set.
		switch {
		case err != nil, !parsed.Valid, claims.Role != "admin", claims.Login != s.config.AdminLogin:
			writeErr(w, http.StatusUnauthorized, "invalid admin session")
			return
		}
		next.ServeHTTP(w, r)
	})
}
// corsMiddleware enforces and advertises the single configured CORS origin.
// A request whose Origin header differs from a non-empty CORSOrigin gets 403.
// Otherwise the CORS response headers are set, OPTIONS requests are answered
// with 204, and every other method is passed to next.
func (s *Server) corsMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		allowed := s.config.CORSOrigin
		if origin := strings.TrimSpace(r.Header.Get("Origin")); origin != "" && allowed != "" && origin != allowed {
			writeErr(w, http.StatusForbidden, "cors origin denied")
			return
		}

		hdr := w.Header()
		if allowed != "" {
			hdr.Set("Access-Control-Allow-Origin", allowed)
			hdr.Set("Vary", "Origin")
		}
		hdr.Set("Access-Control-Allow-Credentials", "true")
		hdr.Set("Access-Control-Allow-Headers", "Content-Type")
		hdr.Set("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS,HEAD")

		if r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}
		w.WriteHeader(http.StatusNoContent)
	})
}
// hostGuardMiddleware restricts the service to the configured AllowedHost.
// With no host configured, or when the request host matches it
// case-insensitively, the request passes through. Otherwise /api/ paths get
// a 403 JSON error and all other paths a 403 HTML page naming the host.
func (s *Server) hostGuardMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		allowed := s.config.AllowedHost
		if strings.TrimSpace(allowed) == "" || strings.ToLower(hostOnly(r.Host)) == strings.ToLower(allowed) {
			next.ServeHTTP(w, r)
			return
		}

		if strings.HasPrefix(r.URL.Path, "/api/") {
			writeErr(w, http.StatusForbidden, fmt.Sprintf("access denied: host must be %s", allowed))
			return
		}

		page := fmt.Sprintf("<!doctype html><html><body><h1>Access denied</h1><p>This service is only available at <b>%s</b>.</p></body></html>", allowed)
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		w.WriteHeader(http.StatusForbidden)
		_, _ = w.Write([]byte(page))
	})
}
// LocalStorage keeps user files on the local filesystem. Each user's files
// live in a directory named after the numeric user id beneath root.
type LocalStorage struct {
// root is the directory that contains the per-user directories.
root string
}
// userRoot returns the directory holding userID's files: <root>/<userID>.
func (l *LocalStorage) userRoot(userID int64) string {
	dirName := strconv.FormatInt(userID, 10)
	return filepath.Join(l.root, dirName)
}
// fullPath maps a user-relative path onto the local filesystem. It creates
// the user's root directory if needed and rejects with "invalid path" any
// result that is neither the root itself nor located beneath it.
func (l *LocalStorage) fullPath(userID int64, rel string) (string, error) {
	base := l.userRoot(userID)
	if err := os.MkdirAll(base, 0o755); err != nil {
		return "", err
	}
	relative := strings.TrimPrefix(normalizePath(rel), "/")
	candidate := filepath.Clean(filepath.Join(base, filepath.FromSlash(relative)))

	// Containment is a string prefix check on the cleaned path; the
	// separator suffix stops "/root/1" from matching "/root/10".
	inside := candidate == base || strings.HasPrefix(candidate, base+string(os.PathSeparator))
	if !inside {
		return "", fmt.Errorf("invalid path")
	}
	return candidate, nil
}
// List returns the entries of the directory rel in userID's storage.
// A directory that does not exist yields an empty, non-nil slice rather than
// an error. Entries whose file info cannot be read are left out.
func (l *LocalStorage) List(userID int64, rel string) ([]FileEntry, error) {
	dir, err := l.fullPath(userID, rel)
	if err != nil {
		return nil, err
	}
	dirEntries, err := os.ReadDir(dir)
	switch {
	case err == nil:
		// fall through to building the result
	case errors.Is(err, os.ErrNotExist):
		return []FileEntry{}, nil
	default:
		return nil, err
	}

	listing := make([]FileEntry, 0, len(dirEntries))
	for _, de := range dirEntries {
		info, infoErr := de.Info()
		if infoErr != nil {
			continue
		}
		entry := FileEntry{
			Name:    de.Name(),
			Path:    path.Join(normalizePath(rel), de.Name()),
			IsDir:   de.IsDir(),
			Size:    info.Size(),
			ModTime: info.ModTime(),
		}
		listing = append(listing, entry)
	}
	return listing, nil
}
// Mkdir creates the directory rel, including missing parents, in userID's
// storage. An already existing directory is not an error.
func (l *LocalStorage) Mkdir(userID int64, rel string) error {
	target, err := l.fullPath(userID, rel)
	if err == nil {
		err = os.MkdirAll(target, 0o755)
	}
	return err
}
// Save streams src into the file rel in userID's storage, creating parent
// directories as needed and truncating an existing file.
//
// The error from closing the destination is reported as well, because for a
// written file a failed Close can mean the data never reached disk. A copy
// error takes precedence over a close error.
func (l *LocalStorage) Save(userID int64, rel string, src multipart.File) error {
	full, err := l.fullPath(userID, rel)
	if err != nil {
		return err
	}
	if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil {
		return err
	}
	dst, err := os.Create(full)
	if err != nil {
		return err
	}
	_, copyErr := io.Copy(dst, src)
	closeErr := dst.Close()
	if copyErr != nil {
		return copyErr
	}
	return closeErr
}
// SaveBytes writes data to the file rel in userID's storage with mode 0644,
// creating parent directories as needed and replacing existing content.
func (l *LocalStorage) SaveBytes(userID int64, rel string, data []byte) error {
	target, err := l.fullPath(userID, rel)
	if err != nil {
		return err
	}
	parent := filepath.Dir(target)
	if mkErr := os.MkdirAll(parent, 0o755); mkErr != nil {
		return mkErr
	}
	return os.WriteFile(target, data, 0o644)
}
// Delete recursively removes rel from userID's storage. The root path "/" is
// handled the same way as any other path: the resolved location and
// everything beneath it is removed.
func (l *LocalStorage) Delete(userID int64, rel string) error {
	target, err := l.fullPath(userID, rel)
	if err != nil {
		return err
	}
	return os.RemoveAll(target)
}
// Stat returns name, size, modification time and directory flag for rel in
// userID's storage, or the underlying error when the path cannot be resolved
// or stat'ed.
func (l *LocalStorage) Stat(userID int64, rel string) (FileMeta, error) {
	var none FileMeta
	target, err := l.fullPath(userID, rel)
	if err != nil {
		return none, err
	}
	info, err := os.Stat(target)
	if err != nil {
		return none, err
	}
	meta := FileMeta{
		Name:    info.Name(),
		Size:    info.Size(),
		ModTime: info.ModTime(),
		IsDir:   info.IsDir(),
	}
	return meta, nil
}
// OpenReadSeeker opens the file rel in userID's storage for reading.
// The caller must close the returned value.
//
// On failure the returned interface is a true nil. Returning os.Open's
// result directly would wrap a nil *os.File in a non-nil interface, which
// defeats "!= nil" checks on the reader.
func (l *LocalStorage) OpenReadSeeker(userID int64, rel string) (ReadSeekCloser, error) {
	full, err := l.fullPath(userID, rel)
	if err != nil {
		return nil, err
	}
	f, err := os.Open(full)
	if err != nil {
		return nil, err
	}
	return f, nil
}
// SMBStorage keeps user files on an SMB share. Its methods read the SMB*
// fields of cfg (host, share, credentials, base path, connect timeout).
type SMBStorage struct {
cfg Config
}
// smbConn bundles the three layers of one SMB connection (TCP connection,
// authenticated session and mounted share) so they can be torn down together.
type smbConn struct {
	conn    net.Conn
	session *smb2.Session
	share   *smb2.Share
}

// Close tears the connection down from the innermost layer outwards:
// unmount the share, log the session off, close the TCP connection.
// Layers that are nil are skipped and teardown errors are ignored.
func (c *smbConn) Close() {
	if share := c.share; share != nil {
		_ = share.Umount()
	}
	if session := c.session; session != nil {
		_ = session.Logoff()
	}
	if tcp := c.conn; tcp != nil {
		_ = tcp.Close()
	}
}
// smbReadSeekCloser is an open SMB file together with the connection it was
// opened on, so that closing the file also releases the connection.
type smbReadSeekCloser struct {
	file *smb2.File
	conn *smbConn
}

// Read reads from the underlying SMB file.
func (s *smbReadSeekCloser) Read(p []byte) (int, error) {
	n, err := s.file.Read(p)
	return n, err
}

// Seek repositions the underlying SMB file.
func (s *smbReadSeekCloser) Seek(offset int64, whence int) (int64, error) {
	pos, err := s.file.Seek(offset, whence)
	return pos, err
}

// Close closes the file and then its connection. It always returns nil; the
// file's close error is discarded.
func (s *smbReadSeekCloser) Close() error {
	_ = s.file.Close()
	s.conn.Close()
	return nil
}
// openConnection dials the configured SMB host within SMBConnectTimout,
// authenticates with NTLM and mounts the configured share. On failure the
// layers that were already established are released before returning.
func (smbs *SMBStorage) openConnection() (*smbConn, error) {
	cfg := smbs.cfg
	tcp, err := net.DialTimeout("tcp", cfg.SMBHost, cfg.SMBConnectTimout)
	if err != nil {
		return nil, err
	}

	creds := &smb2.NTLMInitiator{
		User:     cfg.SMBUser,
		Password: cfg.SMBPass,
		Domain:   cfg.SMBDomain,
	}
	sess, err := (&smb2.Dialer{Initiator: creds}).Dial(tcp)
	if err != nil {
		_ = tcp.Close()
		return nil, err
	}

	mounted, err := sess.Mount(cfg.SMBShare)
	if err != nil {
		_ = sess.Logoff()
		_ = tcp.Close()
		return nil, err
	}
	return &smbConn{conn: tcp, session: sess, share: mounted}, nil
}
// withShare opens a fresh SMB connection, runs fn against the mounted share
// and closes the connection afterwards. It returns the connection error, or
// otherwise whatever fn returns.
func (smbs *SMBStorage) withShare(fn func(share *smb2.Share) error) error {
	link, openErr := smbs.openConnection()
	if openErr != nil {
		return openErr
	}
	defer link.Close()
	return fn(link.share)
}
// userPath maps a user-relative path to its location on the share:
// <SMBBasePath>/<userID>[/<rel>]. The user's root maps to the base itself.
func (smbs *SMBStorage) userPath(userID int64, rel string) string {
	userBase := path.Join(smbs.cfg.SMBBasePath, strconv.FormatInt(userID, 10))
	if sub := strings.TrimPrefix(normalizePath(rel), "/"); sub != "" {
		return path.Join(userBase, sub)
	}
	return userBase
}
// List returns the entries of the directory rel in the user's area of the
// share. The directory is created first (MkdirAll) if it does not exist, so
// listing a fresh location yields an empty, non-nil slice rather than an
// error. Each entry's Path is the normalized rel joined with the entry name.
func (smbs *SMBStorage) List(userID int64, rel string) ([]FileEntry, error) {
	listing := make([]FileEntry, 0)
	listErr := smbs.withShare(func(share *smb2.Share) error {
		dir := smbs.userPath(userID, rel)
		if mkErr := share.MkdirAll(dir, 0o755); mkErr != nil {
			return mkErr
		}
		infos, readErr := share.ReadDir(dir)
		if readErr != nil {
			return readErr
		}
		// normalizePath is pure, so the parent is computed once for all entries.
		parent := normalizePath(rel)
		for i := range infos {
			info := infos[i]
			listing = append(listing, FileEntry{
				Name:    info.Name(),
				Path:    path.Join(parent, info.Name()),
				IsDir:   info.IsDir(),
				Size:    info.Size(),
				ModTime: info.ModTime(),
			})
		}
		return nil
	})
	return listing, listErr
}
// Mkdir creates the directory rel (and any missing parents) inside the user's
// area of the share with mode 0o755.
func (smbs *SMBStorage) Mkdir(userID int64, rel string) error {
	dir := smbs.userPath(userID, rel)
	return smbs.withShare(func(share *smb2.Share) error {
		return share.MkdirAll(dir, 0o755)
	})
}
// Save streams src into the file rel inside the user's area of the share,
// creating missing parent directories (mode 0o755) first.
//
// It returns the first error from creating the directories, creating the
// file, copying the data, or closing the file. The close error is reported
// because a failed close on a written file can mean the data was not stored.
func (smbs *SMBStorage) Save(userID int64, rel string, src multipart.File) error {
	return smbs.withShare(func(share *smb2.Share) error {
		target := smbs.userPath(userID, rel)
		if err := share.MkdirAll(path.Dir(target), 0o755); err != nil {
			return err
		}
		f, err := share.Create(target)
		if err != nil {
			return err
		}
		if _, err := io.Copy(f, src); err != nil {
			// The copy error is the one worth reporting; the close is cleanup only.
			_ = f.Close()
			return err
		}
		return f.Close()
	})
}
// SaveBytes writes data to the file rel inside the user's area of the share,
// creating missing parent directories (mode 0o755) first.
//
// It returns the first error from creating the directories, creating the
// file, writing the data, or closing the file. The close error is reported
// because a failed close on a written file can mean the data was not stored.
func (smbs *SMBStorage) SaveBytes(userID int64, rel string, data []byte) error {
	return smbs.withShare(func(share *smb2.Share) error {
		target := smbs.userPath(userID, rel)
		if err := share.MkdirAll(path.Dir(target), 0o755); err != nil {
			return err
		}
		f, err := share.Create(target)
		if err != nil {
			return err
		}
		if _, err := f.Write(data); err != nil {
			// The write error is the one worth reporting; the close is cleanup only.
			_ = f.Close()
			return err
		}
		return f.Close()
	})
}
// Delete removes rel from the user's area of the share.
//
// When rel normalizes to "/", the user's whole directory is removed
// recursively. Otherwise a plain Remove is tried first and, if it fails for
// any reason, the call falls back to a recursive RemoveAll whose result is
// returned.
func (smbs *SMBStorage) Delete(userID int64, rel string) error {
	return smbs.withShare(func(share *smb2.Share) error {
		target := smbs.userPath(userID, rel)
		isRoot := normalizePath(rel) == "/"
		if !isRoot && share.Remove(target) == nil {
			return nil
		}
		return share.RemoveAll(target)
	})
}
// Stat returns name, size, modification time and directory flag for rel in
// the user's area of the share. On any error the zero FileMeta is returned
// together with that error.
func (smbs *SMBStorage) Stat(userID int64, rel string) (FileMeta, error) {
	var meta FileMeta
	statErr := smbs.withShare(func(share *smb2.Share) error {
		info, err := share.Stat(smbs.userPath(userID, rel))
		if err == nil {
			meta = FileMeta{
				Name:    info.Name(),
				Size:    info.Size(),
				ModTime: info.ModTime(),
				IsDir:   info.IsDir(),
			}
		}
		return err
	})
	return meta, statErr
}
// OpenReadSeeker opens rel in the user's area of the share for reading and
// seeking. A dedicated connection is opened for the file and handed over to
// the returned reader, whose Close releases both. If the file cannot be
// opened the connection is closed before the error is returned.
func (smbs *SMBStorage) OpenReadSeeker(userID int64, rel string) (ReadSeekCloser, error) {
	c, err := smbs.openConnection()
	if err != nil {
		return nil, err
	}
	remote := smbs.userPath(userID, rel)
	file, openErr := c.share.Open(remote)
	if openErr != nil {
		c.Close()
		return nil, openErr
	}
	return &smbReadSeekCloser{file: file, conn: c}, nil
}
// normalizePath converts rel into a rooted, cleaned, slash-separated path.
//
// Surrounding whitespace is trimmed, a leading "/" is prepended and the
// result is passed through path.Clean. Because the path is rooted before
// cleaning, ".." elements cannot climb above "/" and the result always starts
// with "/" (an empty rel yields "/"). The former `== "."` check could never
// be true for a rooted path and has been dropped.
func normalizePath(rel string) string {
	return path.Clean("/" + strings.TrimSpace(rel))
}
// normalizeUploadRelativePath validates a client-supplied relative upload
// path and returns it in cleaned, slash-separated form.
//
// Backslashes are treated as separators, surrounding whitespace and one
// leading "/" are dropped, and the result is cleaned with path.Clean. The
// path is rejected with an "invalid upload path" error when it is empty,
// refers to the current directory, is still absolute, or would escape its
// base directory. After cleaning, a relative path can only escape through
// leading ".." elements, so a bare ".." must be rejected as well as a "../"
// prefix.
func normalizeUploadRelativePath(v string) (string, error) {
	v = strings.ReplaceAll(strings.TrimSpace(v), "\\", "/")
	v = path.Clean(strings.TrimPrefix(v, "/"))
	switch {
	case v == "." || v == "":
		return "", fmt.Errorf("invalid upload path")
	case v == ".." || strings.HasPrefix(v, "../"):
		return "", fmt.Errorf("invalid upload path")
	case strings.HasPrefix(v, "/"):
		// More than one leading slash was supplied (e.g. "//etc/passwd").
		return "", fmt.Errorf("invalid upload path")
	}
	return v, nil
}
// normalizeTheme lower-cases and trims v and returns it when it names one of
// the supported themes. Every other value falls back to "dracula", including
// the names "material", "glass", "desktop" and "auto", which the original
// switch listed explicitly with the same fallback.
func normalizeTheme(v string) string {
	theme := strings.TrimSpace(strings.ToLower(v))
	for _, known := range [...]string{"dracula", "nord", "monokai", "solarized", "github"} {
		if theme == known {
			return theme
		}
	}
	return "dracula"
}
// normalizeColorMode lower-cases and trims v and returns it when it is one of
// "light", "dark" or "auto"; any other value yields "auto".
func normalizeColorMode(v string) string {
	mode := strings.TrimSpace(strings.ToLower(v))
	if mode == "light" || mode == "dark" || mode == "auto" {
		return mode
	}
	return "auto"
}
// normalizeArchiveFormat lower-cases and trims v and returns it when it is
// one of "zip", "rar", "tar.gz" or "lz4". Any other value yields the empty
// string.
func normalizeArchiveFormat(v string) string {
	format := strings.TrimSpace(strings.ToLower(v))
	for _, known := range [...]string{"zip", "rar", "tar.gz", "lz4"} {
		if format == known {
			return format
		}
	}
	return ""
}
// normalizeTag trims and lower-cases v and reports whether the result is a
// valid tag: 1 to 24 bytes long and made only of the characters a-z, 0-9,
// '-' and '_'. Invalid input yields ("", false).
func normalizeTag(v string) (string, bool) {
	tag := strings.ToLower(strings.TrimSpace(v))
	if n := len(tag); n == 0 || n > 24 {
		return "", false
	}
	for _, r := range tag {
		switch {
		case r >= 'a' && r <= 'z':
		case r >= '0' && r <= '9':
		case r == '-', r == '_':
		default:
			return "", false
		}
	}
	return tag, true
}
// isMarkdownExtension reports whether ext denotes a Markdown file. A single
// leading dot is optional and the comparison is case-insensitive; the
// recognized extensions are "md" and "markdown".
func isMarkdownExtension(ext string) bool {
	name := strings.ToLower(strings.TrimPrefix(ext, "."))
	return name == "md" || name == "markdown"
}
// fileTagsForPaths loads the tags stored in file_tags for the given user and
// paths and returns them grouped by rel_path, each group ordered by tag.
//
// Every path is passed through normalizePath before it is bound to the query,
// so the keys of the result are the normalized paths as stored. Paths without
// tags have no entry. An empty paths slice returns an empty map without
// touching the database.
//
// On error the (possibly partially filled) map is returned together with the
// error. Failures that end the row iteration early are reported via
// rows.Err; previously they were silently treated as "no more rows".
func (s *Server) fileTagsForPaths(uid int64, paths []string) (map[string][]string, error) {
	out := make(map[string][]string, len(paths))
	if len(paths) == 0 {
		return out, nil
	}

	// One bound parameter for the user plus one per path; the IN list is made
	// of "?" placeholders only, never of path text.
	args := make([]any, 0, len(paths)+1)
	args = append(args, uid)
	for _, p := range paths {
		args = append(args, normalizePath(p))
	}
	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(paths)), ",")

	q := `SELECT rel_path, tag FROM file_tags WHERE user_id = ? AND rel_path IN (` + placeholders + `) ORDER BY rel_path, tag`
	rows, err := s.db.Query(q, args...)
	if err != nil {
		return out, err
	}
	defer rows.Close()

	for rows.Next() {
		var rel, tag string
		// A row that cannot be scanned is skipped, matching the existing
		// best-effort behaviour.
		if rows.Scan(&rel, &tag) == nil {
			out[rel] = append(out[rel], tag)
		}
	}
	return out, rows.Err()
}
// setCookie adds a Set-Cookie header for name=value to w. The cookie is
// site-wide (Path "/"), HttpOnly and SameSite=Lax; maxAge is passed through
// as the cookie's MaxAge and secure controls the Secure attribute.
func setCookie(w http.ResponseWriter, name, value string, maxAge int, secure bool) {
	cookie := http.Cookie{
		Name:     name,
		Value:    value,
		Path:     "/",
		MaxAge:   maxAge,
		Secure:   secure,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, &cookie)
}
// clearCookie asks the client to delete the cookie called name by sending it
// with an empty value and MaxAge -1. The remaining attributes (Path "/",
// HttpOnly, SameSite=Lax, Secure as given) are the same ones setCookie uses.
func clearCookie(w http.ResponseWriter, name string, secure bool) {
	expired := http.Cookie{
		Name:     name,
		Value:    "",
		Path:     "/",
		MaxAge:   -1,
		Secure:   secure,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, &expired)
}
// randomToken returns 32 bytes from crypto/rand encoded as unpadded URL-safe
// base64 (43 characters). The error from the random source is returned
// unchanged.
func randomToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
// hashToken returns the SHA-256 digest of token encoded as unpadded URL-safe
// base64.
func hashToken(token string) string {
	h := sha256.New()
	h.Write([]byte(token)) // hash.Hash documents that Write never returns an error
	return base64.RawURLEncoding.EncodeToString(h.Sum(nil))
}
// subtleConstantTimeEq returns 1 when a and b are equal and 0 otherwise.
//
// The comparison is delegated to crypto/subtle.ConstantTimeCompare instead
// of a hand-written loop. As before, strings of different length return 0
// immediately, so only the content of equal-length inputs is protected
// against timing analysis.
func subtleConstantTimeEq(a, b string) int {
	return subtle.ConstantTimeCompare([]byte(a), []byte(b))
}
// maybeRunHashCommand implements the "hash-admin <password>" command-line
// mode. It returns false when the first argument is absent or is not
// "hash-admin". Otherwise it prints the Argon2id hash of the given password
// to stdout and returns true. A missing password or a hashing failure prints
// a message and exits the process with status 1.
func maybeRunHashCommand() bool {
	args := os.Args
	if len(args) < 2 || args[1] != "hash-admin" {
		return false
	}
	if len(args) < 3 {
		fmt.Println("usage: go run . hash-admin <password>")
		os.Exit(1)
	}
	encoded, err := hashPasswordArgon2ID(args[2])
	if err != nil {
		fmt.Println("failed to hash password:", err)
		os.Exit(1)
	}
	fmt.Println(encoded)
	return true
}
// hashPasswordArgon2ID derives an Argon2id key from password with a fresh
// 16-byte random salt and returns it in the form
//
//	argon2id:v=19,m=<memory KiB>,t=<passes>,p=<threads>:<salt>:<hash>
//
// where salt and hash are unpadded standard base64. The cost parameters are
// fixed at m=65536, t=3, p=2 with a 32-byte key. The only error source is the
// random number generator.
func hashPasswordArgon2ID(password string) (string, error) {
	const (
		saltLen           = 16
		timeCost   uint32 = 3
		memoryCost uint32 = 64 * 1024
		threads    uint8  = 2
		keyLen     uint32 = 32
	)

	salt := make([]byte, saltLen)
	if _, err := rand.Read(salt); err != nil {
		return "", err
	}
	derived := argon2.IDKey([]byte(password), salt, timeCost, memoryCost, threads, keyLen)

	b64 := base64.RawStdEncoding
	return fmt.Sprintf(
		"argon2id:v=19,m=%d,t=%d,p=%d:%s:%s",
		memoryCost, timeCost, threads,
		b64.EncodeToString(salt), b64.EncodeToString(derived),
	), nil
}
func verifyPasswordHash(storedHash, password string) bool {
if strings.HasPrefix(storedHash, "argon2id:") {
parts := strings.Split(storedHash, ":")
if len(parts) != 4 {
return false
}
var mem, timeCost uint32
var threads uint8
if _, err := fmt.Sscanf(parts[1], "v=19,m=%d,t=%d,p=%d", &mem, &timeCost, &threads); err != nil {
return false
}
salt, err := base64.RawStdEncoding.DecodeString(parts[2])
if err != nil {
return false
}
expected, err := base64.RawStdEncoding.DecodeString(parts[3])
if err != nil {
return false
}
actual := argon2.IDKey([]byte(password), salt, timeCost, mem, threads, uint32(len(expected)))
return subtle.ConstantTimeCompare(expected, actual) == 1
}
if strings.HasPrefix(storedHash, "argon2id$") {
parts := strings.Split(storedHash, "$")
if len(parts) == 6 {
var mem, timeCost uint32
var threads uint8
if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &mem, &timeCost, &threads); err == nil {
salt, errSalt := base64.RawStdEncoding.DecodeString(parts[4])
expected, errExpected := base64.RawStdEncoding.DecodeString(parts[5])
if errSalt == nil && errExpected == nil {
actual := argon2.IDKey([]byte(password), salt, timeCost, mem, threads, uint32(len(expected)))
return subtle.ConstantTimeCompare(expected, actual) == 1
}
}
}
}
if strings.HasPrefix(storedHash, "sha256:") {
expected := strings.TrimPrefix(storedHash, "sha256:")
sum := sha256.Sum256([]byte(password))
actual := hex.EncodeToString(sum[:])
return subtle.ConstantTimeCompare([]byte(expected), []byte(actual)) == 1
}
bcryptHash := storedHash
if strings.HasPrefix(storedHash, "bcrypt:") {
bcryptHash = strings.TrimPrefix(storedHash, "bcrypt:")
}
if strings.HasPrefix(bcryptHash, "$2a$") || strings.HasPrefix(bcryptHash, "$2b$") || strings.HasPrefix(bcryptHash, "$2y$") {
return bcrypt.CompareHashAndPassword([]byte(bcryptHash), []byte(password)) == nil
}
return false
}
// verifyAdminPasswordHash reports whether password matches storedHash. It
// simply forwards to verifyPasswordHash, so the admin password accepts the
// same hash encodings.
func verifyAdminPasswordHash(storedHash, password string) bool {
return verifyPasswordHash(storedHash, password)
}
// userIDFromContext returns the int64 stored in ctx under userIDKey, or 0
// when the key is absent or holds a value of another type.
func userIDFromContext(ctx context.Context) int64 {
	if id, ok := ctx.Value(userIDKey).(int64); ok {
		return id
	}
	return 0
}
// schemeOf returns the URL scheme ("http" or "https") the client used for r.
//
// The X-Forwarded-Proto header takes precedence when it is usable: only the
// first element of a comma-separated list is considered, it is trimmed and
// lower-cased, and it is honoured only when it is exactly "http" or "https".
// The header is client-controlled unless a proxy overwrites it, so any other
// value is ignored rather than echoed back as a scheme. Without a usable
// header the scheme is "https" when the request arrived over TLS and "http"
// otherwise.
func schemeOf(r *http.Request) string {
	if fwd := r.Header.Get("X-Forwarded-Proto"); fwd != "" {
		if i := strings.IndexByte(fwd, ','); i >= 0 {
			fwd = fwd[:i]
		}
		switch proto := strings.ToLower(strings.TrimSpace(fwd)); proto {
		case "http", "https":
			return proto
		}
	}
	if r.TLS != nil {
		return "https"
	}
	return "http"
}
// hostOnly strips the port from a "host:port" value and returns the host.
// Surrounding whitespace is trimmed first. Input without a colon, or input
// that net.SplitHostPort cannot split (for example a bare IPv6 address), is
// returned trimmed but otherwise unchanged.
func hostOnly(hostport string) string {
	trimmed := strings.TrimSpace(hostport)
	if !strings.Contains(trimmed, ":") {
		return trimmed
	}
	host, _, err := net.SplitHostPort(trimmed)
	if err != nil {
		return trimmed
	}
	return host
}
// clientIP returns the best available client address for r.
//
// The headers CF-Connecting-IP, X-Forwarded-For and X-Real-IP are consulted
// in that order; the first one whose leading comma-separated element is
// non-empty after trimming wins. A header whose first element is empty (for
// example ", 5.6.7.8") is skipped instead of producing an empty address.
// When no header is usable, the host part of r.RemoteAddr is returned, or
// RemoteAddr itself if it has no port.
//
// NOTE(review): these headers are client-controlled unless a trusted proxy
// overwrites them, so the result is only trustworthy behind such a proxy.
func clientIP(r *http.Request) string {
	for _, name := range [...]string{"CF-Connecting-IP", "X-Forwarded-For", "X-Real-IP"} {
		v := r.Header.Get(name)
		if i := strings.IndexByte(v, ','); i >= 0 {
			v = v[:i]
		}
		if ip := strings.TrimSpace(v); ip != "" {
			return ip
		}
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}
// writeJSON sends payload as a JSON response with the given status code and
// a Content-Type of application/json.
func writeJSON(w http.ResponseWriter, status int, payload any) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)

	// The status line has already been sent, so an encoding or write failure
	// cannot be reported to the client; the error is deliberately dropped.
	enc := json.NewEncoder(w)
	_ = enc.Encode(payload)
}
// writeErr sends a JSON error response of the form {"error": msg} with the
// given status code, using writeJSON.
func writeErr(w http.ResponseWriter, status int, msg string) {
writeJSON(w, status, map[string]string{"error": msg})
}
// getEnv returns the whitespace-trimmed value of the environment variable
// key, or fallback when the variable is unset, empty or only whitespace.
func getEnv(key, fallback string) string {
	if v := strings.TrimSpace(os.Getenv(key)); v != "" {
		return v
	}
	return fallback
}
// getEnvInt returns the environment variable key parsed as a decimal int
// after trimming whitespace. fallback is returned when the variable is
// unset, blank or not a valid integer.
func getEnvInt(key string, fallback int) int {
	raw := strings.TrimSpace(os.Getenv(key))
	// strconv.Atoi rejects the empty string, so blank values fall back too.
	if n, err := strconv.Atoi(raw); err == nil {
		return n
	}
	return fallback
}
// limitString returns s truncated to at most max bytes.
//
// Truncation never splits a multi-byte UTF-8 character: the cut is moved
// back to the nearest character boundary at or before max, so the result may
// be a few bytes shorter than max. Strings that already fit are returned
// unchanged, and a max of zero or less yields the empty string instead of
// panicking on a negative slice bound.
func limitString(s string, max int) string {
	if max <= 0 {
		return ""
	}
	if len(s) <= max {
		return s
	}
	// Ranging over a string yields the byte offset at which each character
	// starts; keep the last such offset that does not exceed max.
	cut := 0
	for i := range s {
		if i > max {
			break
		}
		cut = i
	}
	return s[:cut]
}