package main import ( "archive/tar" "archive/zip" "compress/gzip" "context" "crypto/rand" "crypto/sha256" "crypto/subtle" "database/sql" "embed" "encoding/base64" "encoding/hex" "encoding/json" "errors" "fmt" "io" "io/fs" "log" "mime" "mime/multipart" "net" "net/http" "os" "os/exec" "path" "path/filepath" "strconv" "strings" "sync" "time" "github.com/golang-jwt/jwt/v5" "github.com/gorilla/mux" "github.com/hirochachacha/go-smb2" "github.com/joho/godotenv" "golang.org/x/crypto/argon2" "golang.org/x/crypto/bcrypt" _ "modernc.org/sqlite" ) type Config struct { Addr string DBPath string StorageRoot string AppDomain string AllowedHost string CORSOrigin string CookieSecure bool MaxBodyBytes int64 RateLimitPerMin int AuthRateLimitPerMin int JWTSecret string AccessTTL time.Duration RefreshTTL time.Duration ShareDefaultTTL time.Duration AdminSessionTTL time.Duration AdminLogin string AdminPasswordHash string StorageBackend string SMBHost string SMBShare string SMBUser string SMBPass string SMBDomain string SMBBasePath string SMBConnectTimout time.Duration } //go:embed web/dist var embeddedWeb embed.FS type Server struct { db *sql.DB config Config storage Storage limiter *rateLimiter } type rateLimiter struct { mu sync.Mutex entries map[string]*rateEntry } type rateEntry struct { Count int WindowEnds time.Time } func newRateLimiter() *rateLimiter { return &rateLimiter{entries: make(map[string]*rateEntry)} } func (rl *rateLimiter) allow(key string, limit int, now time.Time) bool { rl.mu.Lock() defer rl.mu.Unlock() entry, ok := rl.entries[key] if !ok || now.After(entry.WindowEnds) { rl.entries[key] = &rateEntry{Count: 1, WindowEnds: now.Add(time.Minute)} return true } if entry.Count >= limit { return false } entry.Count++ return true } type statusRecorder struct { http.ResponseWriter status int } func (r *statusRecorder) WriteHeader(code int) { r.status = code r.ResponseWriter.WriteHeader(code) } type key int const ( userIDKey key = 1 ) type AccessClaims struct { UserID int64 `json:"uid"` jwt.RegisteredClaims } type AdminClaims struct { Login string `json:"login"` Role string `json:"role"` jwt.RegisteredClaims } type User struct { ID int64 `json:"id"` Username string `json:"username"` Theme string `json:"theme"` ColorMode string `json:"colorMode"` Archive string `json:"archiveFormat"` } type FileEntry struct { Name string `json:"name"` Path string `json:"path"` IsDir bool `json:"isDir"` Size int64 `json:"size"` ModTime time.Time `json:"modTime"` Tags []string `json:"tags,omitempty"` } type FileMeta struct { Name string Size int64 ModTime time.Time IsDir bool } type ReadSeekCloser interface { io.ReadSeeker io.Closer } type Storage interface { List(userID int64, rel string) ([]FileEntry, error) Mkdir(userID int64, rel string) error Save(userID int64, rel string, src multipart.File) error SaveBytes(userID int64, rel string, data []byte) error Delete(userID int64, rel string) error Stat(userID int64, rel string) (FileMeta, error) OpenReadSeeker(userID int64, rel string) (ReadSeekCloser, error) } func main() { if maybeRunHashCommand() { return } cfg := loadConfig() db, err := openDB(cfg.DBPath) if err != nil { log.Fatalf("db open failed: %v", err) } if err := migrate(db); err != nil { log.Fatalf("migrate failed: %v", err) } storage, err := buildStorage(cfg) if err != nil { log.Fatalf("storage init failed: %v", err) } s := &Server{db: db, config: cfg, storage: storage, limiter: newRateLimiter()} r := mux.NewRouter() r.Use(s.recoverMiddleware) r.Use(s.securityHeadersMiddleware) r.Use(s.hostGuardMiddleware) r.Use(s.corsMiddleware) r.Use(s.bodyLimitMiddleware) r.Use(s.rateLimitMiddleware) r.Use(s.requestLogMiddleware) r.HandleFunc("/api/health", s.handleHealth).Methods(http.MethodGet) r.HandleFunc("/api/auth/register", s.handleRegisterDisabled).Methods(http.MethodPost) r.HandleFunc("/api/auth/login", s.handleLogin).Methods(http.MethodPost) r.HandleFunc("/api/auth/refresh", s.handleRefresh).Methods(http.MethodPost) r.HandleFunc("/api/auth/logout", s.handleLogout).Methods(http.MethodPost) r.HandleFunc("/api/admin/login", s.handleAdminLogin).Methods(http.MethodPost) r.HandleFunc("/api/admin/logout", s.handleAdminLogout).Methods(http.MethodPost) r.HandleFunc("/api/share/{token}", s.handleSharedDownload).Methods(http.MethodGet, http.MethodHead) protected := r.PathPrefix("/api").Subrouter() protected.Use(s.authMiddleware) protected.HandleFunc("/auth/me", s.handleMe).Methods(http.MethodGet) protected.HandleFunc("/user/preferences", s.handleSetPreferences).Methods(http.MethodPost) protected.HandleFunc("/files", s.handleListFiles).Methods(http.MethodGet) protected.HandleFunc("/files/upload", s.handleUpload).Methods(http.MethodPost) protected.HandleFunc("/files/download", s.handleDownload).Methods(http.MethodGet, http.MethodHead) protected.HandleFunc("/files/download-batch", s.handleBatchDownload).Methods(http.MethodPost) protected.HandleFunc("/files/move-batch", s.handleBatchMove).Methods(http.MethodPost) protected.HandleFunc("/files/preview", s.handlePreview).Methods(http.MethodGet, http.MethodHead) protected.HandleFunc("/files/text", s.handleReadTextFile).Methods(http.MethodGet) protected.HandleFunc("/files/text", s.handleWriteTextFile).Methods(http.MethodPut) protected.HandleFunc("/files", s.handleDelete).Methods(http.MethodDelete) protected.HandleFunc("/files/folder", s.handleCreateFolder).Methods(http.MethodPost) protected.HandleFunc("/files/share", s.handleCreateShareLink).Methods(http.MethodPost) protected.HandleFunc("/files/tags", s.handleListFileTags).Methods(http.MethodGet) protected.HandleFunc("/files/tags", s.handleAddFileTag).Methods(http.MethodPost) protected.HandleFunc("/files/tags", s.handleDeleteFileTag).Methods(http.MethodDelete) admin := r.PathPrefix("/api/admin").Subrouter() admin.Use(s.adminMiddleware) admin.HandleFunc("/me", s.handleAdminMe).Methods(http.MethodGet) admin.HandleFunc("/users", s.handleAdminUsersList).Methods(http.MethodGet) admin.HandleFunc("/users", s.handleAdminUserCreate).Methods(http.MethodPost) admin.HandleFunc("/users/{id}", s.handleAdminUserDelete).Methods(http.MethodDelete) r.PathPrefix("/").Handler(s.staticHandler()) server := &http.Server{ Addr: cfg.Addr, Handler: r, ReadHeaderTimeout: 10 * time.Second, } log.Printf("listening on %s", cfg.Addr) if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("server failed: %v", err) } } func loadConfig() Config { _ = godotenv.Load(".env", "../.env") dbPath := getEnv("DB_PATH", "./app.db") storageRoot := getEnv("STORAGE_ROOT", "./users") if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil { log.Fatalf("failed to create db path parent: %v", err) } if err := os.MkdirAll(storageRoot, 0o755); err != nil { log.Fatalf("failed to create storage root: %v", err) } cfg := Config{ Addr: getEnv("ADDR", ":8080"), DBPath: dbPath, StorageRoot: storageRoot, AppDomain: strings.ToLower(strings.TrimSpace(getEnv("APP_DOMAIN", "file.example.com"))), AllowedHost: strings.ToLower(getEnv("ALLOWED_HOST", "")), CORSOrigin: getEnv("CORS_ALLOWED_ORIGIN", ""), CookieSecure: getEnv("COOKIE_SECURE", "false") == "true", MaxBodyBytes: int64(getEnvInt("MAX_BODY_MB", 8)) * 1024 * 1024, RateLimitPerMin: getEnvInt("RATE_LIMIT_PER_MIN", 240), AuthRateLimitPerMin: getEnvInt("AUTH_RATE_LIMIT_PER_MIN", 30), JWTSecret: getEnv("JWT_SECRET", "dev-change-me-immediately"), AccessTTL: 15 * time.Minute, RefreshTTL: 30 * 24 * time.Hour, ShareDefaultTTL: 24 * time.Hour, AdminSessionTTL: 12 * time.Hour, AdminLogin: getEnv("ADMIN_LOGIN", "admin"), AdminPasswordHash: getEnv("ADMIN_PASSWORD_HASH", ""), StorageBackend: getEnv("STORAGE_BACKEND", "local"), SMBHost: getEnv("SMB_HOST", ""), SMBShare: getEnv("SMB_SHARE", ""), SMBUser: getEnv("SMB_USER", ""), SMBPass: getEnv("SMB_PASS", ""), SMBDomain: getEnv("SMB_DOMAIN", ""), SMBBasePath: getEnv("SMB_BASE_PATH", "driveflow"), SMBConnectTimout: 5 * time.Second, } if cfg.AllowedHost == "" { cfg.AllowedHost = cfg.AppDomain } if cfg.CORSOrigin == "" { cfg.CORSOrigin = "https://" + cfg.AppDomain } if cfg.JWTSecret == "dev-change-me-immediately" { log.Println("warning: JWT_SECRET is using default development value") } if strings.TrimSpace(cfg.AdminPasswordHash) == "" { log.Fatal("ADMIN_PASSWORD_HASH is required. Generate one with: go run . hash-admin ") } return cfg } func openDB(path string) (*sql.DB, error) { db, err := sql.Open("sqlite", path) if err != nil { return nil, err } if err := db.Ping(); err != nil { return nil, err } return db, nil } func migrate(db *sql.DB) error { stmts := []string{ `CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, email TEXT NOT NULL UNIQUE, password_hash TEXT NOT NULL, theme TEXT NOT NULL DEFAULT 'dracula', color_mode TEXT NOT NULL DEFAULT 'auto', archive_format TEXT NOT NULL DEFAULT 'zip', created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP );`, `CREATE TABLE IF NOT EXISTS refresh_tokens ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL, token_hash TEXT NOT NULL, expires_at TIMESTAMP NOT NULL, revoked_at TIMESTAMP, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, user_agent TEXT, ip TEXT, FOREIGN KEY(user_id) REFERENCES users(id) );`, `CREATE INDEX IF NOT EXISTS idx_refresh_tokens_hash ON refresh_tokens(token_hash);`, `CREATE TABLE IF NOT EXISTS share_links ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL, rel_path TEXT NOT NULL, token_hash TEXT NOT NULL UNIQUE, expires_at TIMESTAMP NOT NULL, max_downloads INTEGER, download_count INTEGER NOT NULL DEFAULT 0, revoked_at TIMESTAMP, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY(user_id) REFERENCES users(id) );`, `CREATE INDEX IF NOT EXISTS idx_share_links_hash ON share_links(token_hash);`, `CREATE TABLE IF NOT EXISTS file_tags ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL, rel_path TEXT NOT NULL, tag TEXT NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, UNIQUE(user_id, rel_path, tag), FOREIGN KEY(user_id) REFERENCES users(id) );`, `CREATE INDEX IF NOT EXISTS idx_file_tags_user_path ON file_tags(user_id, rel_path);`, `CREATE INDEX IF NOT EXISTS idx_file_tags_user_tag ON file_tags(user_id, tag);`, } for _, stmt := range stmts { if _, err := db.Exec(stmt); err != nil { return err } } if _, err := db.Exec(`ALTER TABLE users ADD COLUMN color_mode TEXT NOT NULL DEFAULT 'auto'`); err != nil { if !strings.Contains(strings.ToLower(err.Error()), "duplicate column") { return err } } if _, err := db.Exec(`ALTER TABLE users ADD COLUMN archive_format TEXT NOT NULL DEFAULT 'zip'`); err != nil { if !strings.Contains(strings.ToLower(err.Error()), "duplicate column") { return err } } return nil } func buildStorage(cfg Config) (Storage, error) { if strings.EqualFold(cfg.StorageBackend, "smb") { if cfg.SMBHost == "" || cfg.SMBShare == "" || cfg.SMBUser == "" { return nil, fmt.Errorf("SMB_HOST, SMB_SHARE, SMB_USER must be set for smb backend") } return &SMBStorage{cfg: cfg}, nil } root := cfg.StorageRoot if err := os.MkdirAll(root, 0o755); err != nil { return nil, err } return &LocalStorage{root: root}, nil } func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) { writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) } func (s *Server) staticHandler() http.Handler { webRoot, err := fs.Sub(embeddedWeb, "web/dist") if err != nil { return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "web assets unavailable", http.StatusServiceUnavailable) }) } fileServer := http.FileServer(http.FS(webRoot)) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if strings.HasPrefix(r.URL.Path, "/api/") { http.NotFound(w, r) return } requestPath := strings.TrimPrefix(path.Clean(r.URL.Path), "/") if requestPath == "." || requestPath == "" { requestPath = "index.html" } if _, statErr := fs.Stat(webRoot, requestPath); statErr == nil { fileServer.ServeHTTP(w, r) return } index, readErr := fs.ReadFile(webRoot, "index.html") if readErr != nil { http.Error(w, "web app is not bundled", http.StatusServiceUnavailable) return } w.Header().Set("Content-Type", "text/html; charset=utf-8") w.WriteHeader(http.StatusOK) _, _ = w.Write(index) }) } func (s *Server) handleRegisterDisabled(w http.ResponseWriter, _ *http.Request) { writeErr(w, http.StatusForbidden, "public registration is disabled; ask an administrator") } type authInput struct { Username string `json:"username"` Password string `json:"password"` } func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) { var in authInput if err := json.NewDecoder(r.Body).Decode(&in); err != nil { writeErr(w, http.StatusBadRequest, "invalid payload") return } user, hash, err := s.findUserWithHash(in.Username) if err != nil { log.Printf("auth.login.failed username=%q ip=%q reason=%q", in.Username, clientIP(r), "user_not_found") writeErr(w, http.StatusUnauthorized, "invalid credentials") return } if !verifyPasswordHash(hash, in.Password) { log.Printf("auth.login.failed username=%q ip=%q reason=%q", in.Username, clientIP(r), "invalid_password") writeErr(w, http.StatusUnauthorized, "invalid credentials") return } if err := s.issueUserSession(w, r, user.ID); err != nil { log.Printf("auth.login.failed username=%q user_id=%d ip=%q reason=%q", user.Username, user.ID, clientIP(r), "session_issue_failed") writeErr(w, http.StatusInternalServerError, "failed to create session") return } log.Printf("auth.login.success user_id=%d username=%q ip=%q", user.ID, user.Username, clientIP(r)) writeJSON(w, http.StatusOK, user) } func (s *Server) handleRefresh(w http.ResponseWriter, r *http.Request) { rt, err := r.Cookie("refresh_token") if err != nil || rt.Value == "" { writeErr(w, http.StatusUnauthorized, "missing refresh token") return } uid, err := s.consumeRefreshToken(rt.Value) if err != nil { writeErr(w, http.StatusUnauthorized, "invalid refresh token") return } if err := s.issueUserSession(w, r, uid); err != nil { writeErr(w, http.StatusInternalServerError, "failed to refresh session") return } writeJSON(w, http.StatusOK, map[string]string{"status": "refreshed"}) } func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) { if cookie, err := r.Cookie("access_token"); err == nil && cookie.Value != "" { claims := &AccessClaims{} if tkn, parseErr := jwt.ParseWithClaims(cookie.Value, claims, func(token *jwt.Token) (any, error) { return []byte(s.config.JWTSecret), nil }); parseErr == nil && tkn.Valid { log.Printf("auth.logout user_id=%d ip=%q", claims.UserID, clientIP(r)) } } rt, _ := r.Cookie("refresh_token") if rt != nil && rt.Value != "" { _, _ = s.db.Exec(`UPDATE refresh_tokens SET revoked_at = CURRENT_TIMESTAMP WHERE token_hash = ?`, hashToken(rt.Value)) } clearCookie(w, "access_token", s.config.CookieSecure) clearCookie(w, "refresh_token", s.config.CookieSecure) writeJSON(w, http.StatusOK, map[string]string{"status": "logged_out"}) } type adminLoginInput struct { Login string `json:"login"` Password string `json:"password"` } func (s *Server) handleAdminLogin(w http.ResponseWriter, r *http.Request) { var in adminLoginInput if err := json.NewDecoder(r.Body).Decode(&in); err != nil { writeErr(w, http.StatusBadRequest, "invalid payload") return } if subtleConstantTimeEq(strings.TrimSpace(in.Login), s.config.AdminLogin) == 0 || !verifyAdminPasswordHash(s.config.AdminPasswordHash, in.Password) { log.Printf("auth.admin_login.failed login=%q ip=%q", strings.TrimSpace(in.Login), clientIP(r)) writeErr(w, http.StatusUnauthorized, "invalid admin credentials") return } now := time.Now() claims := AdminClaims{ Login: s.config.AdminLogin, Role: "admin", RegisteredClaims: jwt.RegisteredClaims{ IssuedAt: jwt.NewNumericDate(now), ExpiresAt: jwt.NewNumericDate(now.Add(s.config.AdminSessionTTL)), Subject: s.config.AdminLogin, }, } token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(s.config.JWTSecret)) if err != nil { writeErr(w, http.StatusInternalServerError, "failed to issue admin session") return } setCookie(w, "admin_token", token, int(s.config.AdminSessionTTL.Seconds()), s.config.CookieSecure) log.Printf("auth.admin_login.success login=%q ip=%q", s.config.AdminLogin, clientIP(r)) writeJSON(w, http.StatusOK, map[string]string{"login": s.config.AdminLogin}) } func (s *Server) handleAdminLogout(w http.ResponseWriter, _ *http.Request) { log.Printf("auth.admin_logout") clearCookie(w, "admin_token", s.config.CookieSecure) writeJSON(w, http.StatusOK, map[string]string{"status": "logged_out"}) } func (s *Server) handleMe(w http.ResponseWriter, r *http.Request) { uid := userIDFromContext(r.Context()) user, err := s.findUser(uid) if err != nil { writeErr(w, http.StatusNotFound, "user not found") return } writeJSON(w, http.StatusOK, user) } type userPrefInput struct { Theme string `json:"theme"` ColorMode string `json:"colorMode"` ArchiveFmt string `json:"archiveFormat"` } func (s *Server) handleSetPreferences(w http.ResponseWriter, r *http.Request) { uid := userIDFromContext(r.Context()) var in userPrefInput if err := json.NewDecoder(r.Body).Decode(&in); err != nil { writeErr(w, http.StatusBadRequest, "invalid payload") return } in.Theme = normalizeTheme(in.Theme) in.ColorMode = normalizeColorMode(in.ColorMode) in.ArchiveFmt = normalizeArchiveFormat(in.ArchiveFmt) if in.ArchiveFmt == "" { in.ArchiveFmt = "zip" } if _, err := s.db.Exec(`UPDATE users SET theme = ?, color_mode = ?, archive_format = ? WHERE id = ?`, in.Theme, in.ColorMode, in.ArchiveFmt, uid); err != nil { writeErr(w, http.StatusInternalServerError, "failed to update preferences") return } writeJSON(w, http.StatusOK, map[string]string{"theme": in.Theme, "colorMode": in.ColorMode, "archiveFormat": in.ArchiveFmt}) } func (s *Server) handleListFiles(w http.ResponseWriter, r *http.Request) { uid := userIDFromContext(r.Context()) rel := r.URL.Query().Get("path") log.Printf("file.list user_id=%d path=%q", uid, normalizePath(rel)) entries, err := s.storage.List(uid, rel) if err != nil { writeErr(w, http.StatusBadRequest, err.Error()) return } if len(entries) > 0 { paths := make([]string, 0, len(entries)) for _, e := range entries { paths = append(paths, normalizePath(e.Path)) } tagsByPath, tagErr := s.fileTagsForPaths(uid, paths) if tagErr == nil { for i := range entries { entries[i].Tags = tagsByPath[normalizePath(entries[i].Path)] } } } writeJSON(w, http.StatusOK, map[string]any{"path": normalizePath(rel), "entries": entries}) } func (s *Server) handleUpload(w http.ResponseWriter, r *http.Request) { uid := userIDFromContext(r.Context()) relDir := r.URL.Query().Get("path") if err := r.ParseMultipartForm(256 << 20); err != nil { writeErr(w, http.StatusBadRequest, "failed to parse upload") return } files := r.MultipartForm.File["file"] if len(files) == 0 { writeErr(w, http.StatusBadRequest, "no files provided") return } for _, fh := range files { log.Printf("file.upload user_id=%d dir=%q name=%q size=%d", uid, normalizePath(relDir), fh.Filename, fh.Size) src, err := fh.Open() if err != nil { writeErr(w, http.StatusBadRequest, fmt.Sprintf("cannot open %s", fh.Filename)) return } relName, err := normalizeUploadRelativePath(fh.Filename) if err != nil { _ = src.Close() writeErr(w, http.StatusBadRequest, err.Error()) return } target := path.Join(normalizePath(relDir), relName) err = s.storage.Save(uid, target, src) _ = src.Close() if err != nil { writeErr(w, http.StatusBadRequest, err.Error()) return } } writeJSON(w, http.StatusCreated, map[string]string{"status": "uploaded"}) } func (s *Server) handleDownload(w http.ResponseWriter, r *http.Request) { uid := userIDFromContext(r.Context()) rel := r.URL.Query().Get("path") log.Printf("file.download user_id=%d path=%q", uid, normalizePath(rel)) if err := s.serveFile(w, r, uid, rel, false, r.URL.Query().Get("archive")); err != nil { writeErr(w, http.StatusBadRequest, err.Error()) } } func (s *Server) handleBatchDownload(w http.ResponseWriter, r *http.Request) { uid := userIDFromContext(r.Context()) var in struct { Paths []string `json:"paths"` Archive string `json:"archive"` } if err := json.NewDecoder(r.Body).Decode(&in); err != nil { writeErr(w, http.StatusBadRequest, "invalid payload") return } if len(in.Paths) == 0 { writeErr(w, http.StatusBadRequest, "no paths selected") return } if len(in.Paths) > 200 { writeErr(w, http.StatusBadRequest, "too many selected paths") return } format, err := s.resolveArchiveFormat(uid, in.Archive) if err != nil { writeErr(w, http.StatusBadRequest, err.Error()) return } archivePath, name, ctype, err := s.createBatchArchiveTemp(uid, in.Paths, format) if err != nil { writeErr(w, http.StatusBadRequest, err.Error()) return } defer os.Remove(archivePath) f, err := os.Open(archivePath) if err != nil { writeErr(w, http.StatusInternalServerError, "failed to open archive") return } defer f.Close() st, err := f.Stat() if err != nil { writeErr(w, http.StatusInternalServerError, "failed to stat archive") return } w.Header().Set("Accept-Ranges", "bytes") w.Header().Set("Content-Type", ctype) w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name)) http.ServeContent(w, r, name, st.ModTime(), f) } func (s *Server) handleBatchMove(w http.ResponseWriter, r *http.Request) { uid := userIDFromContext(r.Context()) var in struct { Paths []string `json:"paths"` Destination string `json:"destination"` } if err := json.NewDecoder(r.Body).Decode(&in); err != nil { writeErr(w, http.StatusBadRequest, "invalid payload") return } if len(in.Paths) == 0 { writeErr(w, http.StatusBadRequest, "no paths selected") return } destDir := normalizePath(in.Destination) if err := s.storage.Mkdir(uid, destDir); err != nil { writeErr(w, http.StatusBadRequest, err.Error()) return } moved := 0 for _, p := range in.Paths { src := normalizePath(p) if src == "/" { continue } meta, err := s.storage.Stat(uid, src) if err != nil { writeErr(w, http.StatusBadRequest, "invalid path") return } if meta.IsDir { srcPrefix := strings.TrimSuffix(src, "/") + "/" if destDir == src || strings.HasPrefix(destDir, srcPrefix) { writeErr(w, http.StatusBadRequest, "cannot move a folder into itself") return } } base := path.Base(src) dst := path.Join(destDir, base) dst, err = s.nextMoveTarget(uid, dst) if err != nil { writeErr(w, http.StatusBadRequest, err.Error()) return } if err := s.copyPath(uid, src, dst); err != nil { writeErr(w, http.StatusBadRequest, err.Error()) return } if err := s.storage.Delete(uid, src); err != nil { writeErr(w, http.StatusBadRequest, err.Error()) return } s.moveTags(uid, src, dst) moved++ } writeJSON(w, http.StatusOK, map[string]any{"status": "moved", "count": moved}) } func (s *Server) handlePreview(w http.ResponseWriter, r *http.Request) { uid := userIDFromContext(r.Context()) rel := r.URL.Query().Get("path") log.Printf("file.preview user_id=%d path=%q", uid, normalizePath(rel)) if err := s.serveFile(w, r, uid, rel, true, ""); err != nil { writeErr(w, http.StatusBadRequest, err.Error()) } } func (s *Server) handleReadTextFile(w http.ResponseWriter, r *http.Request) { uid := userIDFromContext(r.Context()) rel := normalizePath(r.URL.Query().Get("path")) if rel == "/" { writeErr(w, http.StatusBadRequest, "path is required") return } if !isMarkdownExtension(path.Ext(rel)) { writeErr(w, http.StatusBadRequest, "only markdown files are editable") return } meta, err := s.storage.Stat(uid, rel) if err != nil || meta.IsDir { writeErr(w, http.StatusBadRequest, "file not found") return } rc, err := s.storage.OpenReadSeeker(uid, rel) if err != nil { writeErr(w, http.StatusBadRequest, "file not found") return } defer rc.Close() const maxTextBytes = 2 << 20 data, err := io.ReadAll(io.LimitReader(rc, maxTextBytes+1)) if err != nil { writeErr(w, http.StatusInternalServerError, "failed to read file") return } if len(data) > maxTextBytes { writeErr(w, http.StatusBadRequest, "markdown file is too large (max 2MB)") return } writeJSON(w, http.StatusOK, map[string]any{"path": rel, "content": string(data), "size": len(data)}) } func (s *Server) handleWriteTextFile(w http.ResponseWriter, r *http.Request) { uid := userIDFromContext(r.Context()) var in struct { Path string `json:"path"` Content string `json:"content"` } if err := json.NewDecoder(r.Body).Decode(&in); err != nil { writeErr(w, http.StatusBadRequest, "invalid payload") return } rel := normalizePath(in.Path) if rel == "/" { writeErr(w, http.StatusBadRequest, "path is required") return } if !isMarkdownExtension(path.Ext(rel)) { writeErr(w, http.StatusBadRequest, "only markdown files are editable") return } if len(in.Content) > 2<<20 { writeErr(w, http.StatusBadRequest, "markdown file is too large (max 2MB)") return } if err := s.storage.SaveBytes(uid, rel, []byte(in.Content)); err != nil { writeErr(w, http.StatusBadRequest, err.Error()) return } writeJSON(w, http.StatusOK, map[string]any{"path": rel, "status": "saved"}) } func (s *Server) handleListFileTags(w http.ResponseWriter, r *http.Request) { uid := userIDFromContext(r.Context()) rel := normalizePath(r.URL.Query().Get("path")) rows, err := s.db.Query(`SELECT tag FROM file_tags WHERE user_id = ? AND rel_path = ? ORDER BY tag ASC`, uid, rel) if err != nil { writeErr(w, http.StatusInternalServerError, "failed to load tags") return } defer rows.Close() tags := make([]string, 0) for rows.Next() { var tag string if rows.Scan(&tag) == nil { tags = append(tags, tag) } } writeJSON(w, http.StatusOK, map[string]any{"path": rel, "tags": tags}) } func (s *Server) handleAddFileTag(w http.ResponseWriter, r *http.Request) { uid := userIDFromContext(r.Context()) var in struct { Path string `json:"path"` Tag string `json:"tag"` } if err := json.NewDecoder(r.Body).Decode(&in); err != nil { writeErr(w, http.StatusBadRequest, "invalid payload") return } rel := normalizePath(in.Path) tag, ok := normalizeTag(in.Tag) if !ok { writeErr(w, http.StatusBadRequest, "tag must be 1-24 chars: a-z 0-9 dash underscore") return } if _, err := s.storage.Stat(uid, rel); err != nil { writeErr(w, http.StatusBadRequest, "file not found") return } if _, err := s.db.Exec(`INSERT OR IGNORE INTO file_tags(user_id, rel_path, tag) VALUES (?, ?, ?)`, uid, rel, tag); err != nil { writeErr(w, http.StatusInternalServerError, "failed to save tag") return } writeJSON(w, http.StatusCreated, map[string]any{"path": rel, "tag": tag}) } func (s *Server) handleDeleteFileTag(w http.ResponseWriter, r *http.Request) { uid := userIDFromContext(r.Context()) rel := normalizePath(r.URL.Query().Get("path")) tag, ok := normalizeTag(r.URL.Query().Get("tag")) if !ok { writeErr(w, http.StatusBadRequest, "invalid tag") return } if _, err := s.db.Exec(`DELETE FROM file_tags WHERE user_id = ? AND rel_path = ? AND tag = ?`, uid, rel, tag); err != nil { writeErr(w, http.StatusInternalServerError, "failed to delete tag") return } writeJSON(w, http.StatusOK, map[string]any{"path": rel, "tag": tag, "status": "deleted"}) } func (s *Server) serveFile(w http.ResponseWriter, r *http.Request, uid int64, rel string, inline bool, archiveQuery string) error { meta, err := s.storage.Stat(uid, rel) if err != nil { return err } if meta.IsDir { if inline { return fmt.Errorf("cannot preview a directory") } format, err := s.resolveArchiveFormat(uid, archiveQuery) if err != nil { return err } return s.serveDirectoryArchive(w, r, uid, rel, format) } rc, err := s.storage.OpenReadSeeker(uid, rel) if err != nil { return err } defer rc.Close() name := path.Base(normalizePath(rel)) if name == "." || name == "/" || name == "" { name = "download" } ctype := mime.TypeByExtension(strings.ToLower(filepath.Ext(name))) if ctype == "" { ctype = "application/octet-stream" } dispositionType := "attachment" if inline { dispositionType = "inline" } w.Header().Set("Accept-Ranges", "bytes") w.Header().Set("Content-Type", ctype) w.Header().Set("Content-Disposition", fmt.Sprintf("%s; filename=%q", dispositionType, name)) http.ServeContent(w, r, name, meta.ModTime, rc) return nil } func (s *Server) resolveArchiveFormat(uid int64, archiveQuery string) (string, error) { requested := normalizeArchiveFormat(archiveQuery) if requested != "" { return requested, nil } var fmtPref string err := s.db.QueryRow(`SELECT archive_format FROM users WHERE id = ?`, uid).Scan(&fmtPref) if err != nil { return "zip", nil } fmtPref = normalizeArchiveFormat(fmtPref) if fmtPref == "" { return "zip", nil } return fmtPref, nil } func (s *Server) serveDirectoryArchive(w http.ResponseWriter, r *http.Request, uid int64, rel, format string) error { base := path.Base(normalizePath(rel)) if base == "/" || base == "." || base == "" { base = "folder" } archivePath, downloadName, ctype, err := s.createArchiveTemp(uid, rel, base, format) if err != nil { return err } defer os.Remove(archivePath) f, err := os.Open(archivePath) if err != nil { return err } defer f.Close() st, err := f.Stat() if err != nil { return err } w.Header().Set("Accept-Ranges", "bytes") w.Header().Set("Content-Type", ctype) w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", downloadName)) http.ServeContent(w, r, downloadName, st.ModTime(), f) return nil } func (s *Server) createArchiveTemp(uid int64, rel, base, format string) (string, string, string, error) { switch format { case "rar": return s.createRarTemp(uid, rel, base) case "tar.gz": return s.createTarGzTemp(uid, rel, base) case "lz4": return s.createTarLz4Temp(uid, rel, base) default: return s.createZipTemp(uid, rel, base) } } func (s *Server) createBatchArchiveTemp(uid int64, paths []string, format string) (string, string, string, error) { workdir, err := os.MkdirTemp("", "filez-batch-*") if err != nil { return "", "", "", err } defer os.RemoveAll(workdir) root := filepath.Join(workdir, "selection") if err := os.MkdirAll(root, 0o755); err != nil { return "", "", "", err } used := map[string]int{} for _, raw := range paths { rel := normalizePath(raw) if rel == "/" { continue } meta, err := s.storage.Stat(uid, rel) if err != nil { return "", "", "", fmt.Errorf("invalid path: %s", rel) } base := path.Base(rel) if base == "." || base == "/" || base == "" { base = "item" } targetName := uniqueArchiveName(base, used) target := filepath.Join(root, targetName) if meta.IsDir { if err := os.MkdirAll(target, 0o755); err != nil { return "", "", "", err } } if err := s.materializePath(uid, rel, target); err != nil { return "", "", "", err } } baseName := "files" switch format { case "tar.gz": return createTarGzFromLocalDir(root, baseName) case "lz4": return createTarLz4FromLocalDir(root, baseName) case "rar": return createRarFromLocalDir(root, baseName) default: return createZipFromLocalDir(root, baseName) } } func (s *Server) createZipTemp(uid int64, rel, base string) (string, string, string, error) { tmp, err := os.CreateTemp("", "filez-*.zip") if err != nil { return "", "", "", err } defer tmp.Close() zw := zip.NewWriter(tmp) if err := s.addPathToZip(zw, uid, rel, base); err != nil { zw.Close() return "", "", "", err } if err := zw.Close(); err != nil { return "", "", "", err } return tmp.Name(), base + ".zip", "application/zip", nil } func (s *Server) addPathToZip(zw *zip.Writer, uid int64, rel string, zipBase string) error { meta, err := s.storage.Stat(uid, rel) if err != nil { return err } if meta.IsDir { if zipBase != "" { hdr := &zip.FileHeader{Name: strings.TrimPrefix(zipBase, "/") + "/", Method: zip.Store} hdr.Modified = meta.ModTime if _, err := zw.CreateHeader(hdr); err != nil { return err } } entries, err := s.storage.List(uid, rel) if err != nil { return err } for _, e := range entries { childRel := path.Join(normalizePath(rel), e.Name) childBase := path.Join(zipBase, e.Name) if err := s.addPathToZip(zw, uid, childRel, childBase); err != nil { return err } } return nil } rc, err := s.storage.OpenReadSeeker(uid, rel) if err != nil { return err } defer rc.Close() hdr := &zip.FileHeader{Name: strings.TrimPrefix(zipBase, "/"), Method: zip.Deflate} hdr.Modified = meta.ModTime writer, err := zw.CreateHeader(hdr) if err != nil { return err } _, err = io.Copy(writer, rc) return err } func (s *Server) createRarTemp(uid int64, rel, base string) (string, string, string, error) { if _, err := exec.LookPath("rar"); err != nil { return "", "", "", fmt.Errorf("rar format requires 'rar' binary on server; choose zip or tar.gz in settings") } workdir, err := os.MkdirTemp("", "filez-rar-*") if err != nil { return "", "", "", err } root := filepath.Join(workdir, base) if err := os.MkdirAll(root, 0o755); err != nil { os.RemoveAll(workdir) return "", "", "", err } if err := s.materializePath(uid, rel, root); err != nil { os.RemoveAll(workdir) return "", "", "", err } outTmp, err := os.CreateTemp("", "filez-*.rar") if err != nil { os.RemoveAll(workdir) return "", "", "", err } archivePath := outTmp.Name() outTmp.Close() cmd := exec.Command("rar", "a", "-idq", archivePath, root) if out, err := cmd.CombinedOutput(); err != nil { os.Remove(archivePath) os.RemoveAll(workdir) return "", "", "", fmt.Errorf("rar creation failed: %s", strings.TrimSpace(string(out))) } os.RemoveAll(workdir) return archivePath, base + ".rar", "application/vnd.rar", nil } func (s *Server) createTarGzTemp(uid int64, rel, base string) (string, string, string, error) { tmp, err := os.CreateTemp("", "filez-*.tar.gz") if err != nil { return "", "", "", err } defer tmp.Close() gzw := gzip.NewWriter(tmp) tw := tar.NewWriter(gzw) if err := s.addPathToTar(tw, uid, rel, base); err != nil { tw.Close() gzw.Close() return "", "", "", err } if err := tw.Close(); err != nil { gzw.Close() return "", "", "", err } if err := gzw.Close(); err != nil { return "", "", "", err } return tmp.Name(), base + ".tar.gz", "application/gzip", nil } func (s *Server) createTarLz4Temp(uid int64, rel, base string) (string, string, string, error) { if _, err := exec.LookPath("lz4"); err != nil { return "", "", "", fmt.Errorf("lz4 format requires 'lz4' binary on server; choose zip or tar.gz in settings") } tarPath, err := s.createTarTemp(uid, rel, base) if err != nil { return "", "", "", err } defer os.Remove(tarPath) outTmp, err := os.CreateTemp("", "filez-*.tar.lz4") if err != nil { return "", "", "", err } outPath := outTmp.Name() outTmp.Close() cmd := exec.Command("lz4", "-z", "-q", tarPath, outPath) if out, err := cmd.CombinedOutput(); err != nil { os.Remove(outPath) return "", "", "", fmt.Errorf("lz4 creation failed: %s", strings.TrimSpace(string(out))) } return outPath, base + ".tar.lz4", "application/x-lz4", nil } func (s *Server) createTarTemp(uid int64, rel, base string) (string, error) { tmp, err := os.CreateTemp("", "filez-*.tar") if err != nil { return "", err } defer tmp.Close() tw := tar.NewWriter(tmp) if err := s.addPathToTar(tw, uid, rel, base); err != nil { tw.Close() return "", err } if err := tw.Close(); err != nil { return "", err } return tmp.Name(), nil } func (s *Server) addPathToTar(tw *tar.Writer, uid int64, rel string, tarBase string) error { meta, err := s.storage.Stat(uid, rel) if err != nil { return err } if meta.IsDir { if tarBase != "" { hdr := &tar.Header{Name: strings.TrimPrefix(tarBase, "/") + "/", Mode: 0o755, ModTime: meta.ModTime, Typeflag: tar.TypeDir} if err := tw.WriteHeader(hdr); err != nil { return err } } entries, err := s.storage.List(uid, rel) if err != nil { return err } for _, e := range entries { childRel := path.Join(normalizePath(rel), e.Name) childBase := path.Join(tarBase, e.Name) if err := s.addPathToTar(tw, uid, childRel, childBase); err != nil { return err } } return nil } rc, err := s.storage.OpenReadSeeker(uid, rel) if err != nil { return err } defer rc.Close() hdr := &tar.Header{Name: strings.TrimPrefix(tarBase, "/"), Mode: 0o644, Size: meta.Size, ModTime: meta.ModTime, Typeflag: tar.TypeReg} if err := tw.WriteHeader(hdr); err != nil { return err } _, err = io.Copy(tw, rc) return err } func (s *Server) materializePath(uid int64, rel string, localPath string) error { meta, err := s.storage.Stat(uid, rel) if err != nil { return err } if meta.IsDir { entries, err := s.storage.List(uid, rel) if err != nil { return err } for _, e := range entries { childRel := path.Join(normalizePath(rel), e.Name) childLocal := filepath.Join(localPath, e.Name) if e.IsDir { if err := os.MkdirAll(childLocal, 0o755); err != nil { return err } } if err := s.materializePath(uid, childRel, childLocal); err != nil { return err } } return nil } rc, err := s.storage.OpenReadSeeker(uid, rel) if err != nil { return err } defer rc.Close() out, err := os.Create(localPath) if err != nil { return err } defer out.Close() _, err = io.Copy(out, rc) return err } func (s *Server) copyPath(uid int64, src, dst string) error { meta, err := s.storage.Stat(uid, src) if err != nil { return err } if meta.IsDir { if err := s.storage.Mkdir(uid, dst); err != nil { return err } entries, err := s.storage.List(uid, src) if err != nil { return err } for _, e := range entries { childSrc := path.Join(src, e.Name) childDst := path.Join(dst, e.Name) if err := s.copyPath(uid, childSrc, childDst); err != nil { return err } } return nil } rc, err := s.storage.OpenReadSeeker(uid, src) if err != nil { return err } defer rc.Close() data, err := io.ReadAll(rc) if err != nil { return err } return s.storage.SaveBytes(uid, dst, data) } func (s *Server) nextMoveTarget(uid int64, dst string) (string, error) { norm := normalizePath(dst) if _, err := s.storage.Stat(uid, norm); err != nil { return norm, nil } ext := path.Ext(norm) base := strings.TrimSuffix(path.Base(norm), ext) dir := path.Dir(norm) for i := 2; i <= 9999; i++ { candidate := path.Join(dir, fmt.Sprintf("%s-%d%s", base, i, ext)) if _, err := s.storage.Stat(uid, candidate); err != nil { return candidate, nil } } return "", fmt.Errorf("cannot allocate destination name") } func (s *Server) moveTags(uid int64, src, dst string) { _, _ = s.db.Exec(`UPDATE file_tags SET rel_path = ? WHERE user_id = ? AND rel_path = ?`, dst, uid, src) prefixFrom := strings.TrimSuffix(src, "/") + "/" prefixTo := strings.TrimSuffix(dst, "/") + "/" rows, err := s.db.Query(`SELECT rel_path, tag FROM file_tags WHERE user_id = ? AND rel_path LIKE ?`, uid, prefixFrom+"%") if err != nil { return } defer rows.Close() type rowItem struct{ p, t string } items := make([]rowItem, 0) for rows.Next() { var rp, tg string if rows.Scan(&rp, &tg) == nil { items = append(items, rowItem{p: rp, t: tg}) } } for _, it := range items { next := prefixTo + strings.TrimPrefix(it.p, prefixFrom) _, _ = s.db.Exec(`DELETE FROM file_tags WHERE user_id = ? AND rel_path = ? AND tag = ?`, uid, it.p, it.t) _, _ = s.db.Exec(`INSERT OR IGNORE INTO file_tags(user_id, rel_path, tag) VALUES (?, ?, ?)`, uid, next, it.t) } } func uniqueArchiveName(base string, used map[string]int) string { base = strings.TrimSpace(base) if base == "" { base = "item" } count := used[base] used[base] = count + 1 if count == 0 { return base } ext := path.Ext(base) name := strings.TrimSuffix(base, ext) return fmt.Sprintf("%s-%d%s", name, count+1, ext) } func createZipFromLocalDir(root, baseName string) (string, string, string, error) { tmp, err := os.CreateTemp("", "filez-batch-*.zip") if err != nil { return "", "", "", err } defer tmp.Close() zw := zip.NewWriter(tmp) err = filepath.WalkDir(root, func(p string, d fs.DirEntry, walkErr error) error { if walkErr != nil { return walkErr } rel, err := filepath.Rel(root, p) if err != nil { return err } if rel == "." { return nil } rel = filepath.ToSlash(rel) info, err := d.Info() if err != nil { return err } if d.IsDir() { hdr := &zip.FileHeader{Name: rel + "/", Method: zip.Store} hdr.Modified = info.ModTime() _, err = zw.CreateHeader(hdr) return err } hdr, err := zip.FileInfoHeader(info) if err != nil { return err } hdr.Name = rel hdr.Method = zip.Deflate w, err := zw.CreateHeader(hdr) if err != nil { return err } f, err := os.Open(p) if err != nil { return err } defer f.Close() _, err = io.Copy(w, f) return err }) if err != nil { zw.Close() return "", "", "", err } if err := zw.Close(); err != nil { return "", "", "", err } return tmp.Name(), baseName + ".zip", "application/zip", nil } func createTarGzFromLocalDir(root, baseName string) (string, string, string, error) { tmp, err := os.CreateTemp("", "filez-batch-*.tar.gz") if err != nil { return "", "", "", err } defer tmp.Close() gzw := gzip.NewWriter(tmp) tw := tar.NewWriter(gzw) err = filepath.WalkDir(root, func(p string, d fs.DirEntry, walkErr error) error { if walkErr != nil { return walkErr } rel, err := filepath.Rel(root, p) if err != nil { return err } if rel == "." { return nil } rel = filepath.ToSlash(rel) info, err := d.Info() if err != nil { return err } hdr, err := tar.FileInfoHeader(info, "") if err != nil { return err } if d.IsDir() { hdr.Name = rel + "/" } else { hdr.Name = rel } if err := tw.WriteHeader(hdr); err != nil { return err } if d.IsDir() { return nil } f, err := os.Open(p) if err != nil { return err } defer f.Close() _, err = io.Copy(tw, f) return err }) if err != nil { tw.Close() gzw.Close() return "", "", "", err } if err := tw.Close(); err != nil { gzw.Close() return "", "", "", err } if err := gzw.Close(); err != nil { return "", "", "", err } return tmp.Name(), baseName + ".tar.gz", "application/gzip", nil } func createTarFromLocalDir(root string) (string, error) { tmp, err := os.CreateTemp("", "filez-batch-*.tar") if err != nil { return "", err } defer tmp.Close() tw := tar.NewWriter(tmp) err = filepath.WalkDir(root, func(p string, d fs.DirEntry, walkErr error) error { if walkErr != nil { return walkErr } rel, err := filepath.Rel(root, p) if err != nil { return err } if rel == "." { return nil } rel = filepath.ToSlash(rel) info, err := d.Info() if err != nil { return err } hdr, err := tar.FileInfoHeader(info, "") if err != nil { return err } if d.IsDir() { hdr.Name = rel + "/" } else { hdr.Name = rel } if err := tw.WriteHeader(hdr); err != nil { return err } if d.IsDir() { return nil } f, err := os.Open(p) if err != nil { return err } defer f.Close() _, err = io.Copy(tw, f) return err }) if err != nil { tw.Close() return "", err } if err := tw.Close(); err != nil { return "", err } return tmp.Name(), nil } func createTarLz4FromLocalDir(root, baseName string) (string, string, string, error) { if _, err := exec.LookPath("lz4"); err != nil { return "", "", "", fmt.Errorf("lz4 format requires 'lz4' binary on server; choose zip or tar.gz in settings") } tarPath, err := createTarFromLocalDir(root) if err != nil { return "", "", "", err } defer os.Remove(tarPath) outTmp, err := os.CreateTemp("", "filez-batch-*.tar.lz4") if err != nil { return "", "", "", err } outPath := outTmp.Name() outTmp.Close() cmd := exec.Command("lz4", "-z", "-q", tarPath, outPath) if out, err := cmd.CombinedOutput(); err != nil { os.Remove(outPath) return "", "", "", fmt.Errorf("lz4 creation failed: %s", strings.TrimSpace(string(out))) } return outPath, baseName + ".tar.lz4", "application/x-lz4", nil } func createRarFromLocalDir(root, baseName string) (string, string, string, error) { if _, err := exec.LookPath("rar"); err != nil { return "", "", "", fmt.Errorf("rar format requires 'rar' binary on server; choose zip or tar.gz in settings") } outTmp, err := os.CreateTemp("", "filez-batch-*.rar") if err != nil { return "", "", "", err } outPath := outTmp.Name() outTmp.Close() cmd := exec.Command("rar", "a", "-idq", outPath, root) if out, err := cmd.CombinedOutput(); err != nil { os.Remove(outPath) return "", "", "", fmt.Errorf("rar creation failed: %s", strings.TrimSpace(string(out))) } return outPath, baseName + ".rar", "application/vnd.rar", nil } func (s *Server) handleDelete(w http.ResponseWriter, r *http.Request) { uid := userIDFromContext(r.Context()) rel := r.URL.Query().Get("path") norm := normalizePath(rel) log.Printf("file.delete user_id=%d path=%q", uid, norm) if err := s.storage.Delete(uid, norm); err != nil { writeErr(w, http.StatusBadRequest, err.Error()) return } _, _ = s.db.Exec(`DELETE FROM file_tags WHERE user_id = ? AND (rel_path = ? OR rel_path LIKE ?)`, uid, norm, strings.TrimSuffix(norm, "/")+"/%") writeJSON(w, http.StatusOK, map[string]string{"status": "deleted"}) } func (s *Server) handleCreateFolder(w http.ResponseWriter, r *http.Request) { type payload struct { Path string `json:"path"` Name string `json:"name"` } var in payload if err := json.NewDecoder(r.Body).Decode(&in); err != nil { writeErr(w, http.StatusBadRequest, "invalid payload") return } uid := userIDFromContext(r.Context()) rel := path.Join(normalizePath(in.Path), path.Base(strings.TrimSpace(in.Name))) log.Printf("file.mkdir user_id=%d path=%q", uid, normalizePath(rel)) if rel == "/" || rel == "." { writeErr(w, http.StatusBadRequest, "invalid folder name") return } if err := s.storage.Mkdir(uid, rel); err != nil { writeErr(w, http.StatusBadRequest, err.Error()) return } writeJSON(w, http.StatusCreated, map[string]string{"status": "created"}) } type shareInput struct { Path string `json:"path"` ExpiresMinutes int `json:"expiresMinutes"` MaxDownloads *int `json:"maxDownloads"` AllowPreview bool `json:"allowPreview"` PreferredInline bool `json:"preferredInline"` } func (s *Server) handleCreateShareLink(w http.ResponseWriter, r *http.Request) { uid := userIDFromContext(r.Context()) var in shareInput if err := json.NewDecoder(r.Body).Decode(&in); err != nil { writeErr(w, http.StatusBadRequest, "invalid payload") return } rel := normalizePath(in.Path) log.Printf("file.share.create user_id=%d path=%q", uid, rel) meta, err := s.storage.Stat(uid, rel) if err != nil { writeErr(w, http.StatusBadRequest, "file not found") return } if meta.IsDir { writeErr(w, http.StatusBadRequest, "cannot create share for folder") return } ttl := s.config.ShareDefaultTTL if in.ExpiresMinutes > 0 { ttl = time.Duration(in.ExpiresMinutes) * time.Minute } if ttl < 5*time.Minute { ttl = 5 * time.Minute } if ttl > 30*24*time.Hour { ttl = 30 * 24 * time.Hour } token, err := randomToken() if err != nil { writeErr(w, http.StatusInternalServerError, "failed to create token") return } var maxDownloads any if in.MaxDownloads != nil && *in.MaxDownloads > 0 { maxDownloads = *in.MaxDownloads } expiresAt := time.Now().Add(ttl) _, err = s.db.Exec(`INSERT INTO share_links(user_id, rel_path, token_hash, expires_at, max_downloads) VALUES (?, ?, ?, ?, ?)`, uid, rel, hashToken(token), expiresAt, maxDownloads, ) if err != nil { writeErr(w, http.StatusInternalServerError, "failed to create share") return } shareURL := fmt.Sprintf("%s://%s/api/share/%s", schemeOf(r), r.Host, token) writeJSON(w, http.StatusCreated, map[string]any{ "url": shareURL, "token": token, "path": rel, "expiresAt": expiresAt, }) } func (s *Server) handleSharedDownload(w http.ResponseWriter, r *http.Request) { token := mux.Vars(r)["token"] if strings.TrimSpace(token) == "" { writeErr(w, http.StatusBadRequest, "missing token") return } var uid int64 var rel string var expiresAt time.Time var revokedAt sql.NullTime var maxDownloads sql.NullInt64 var downloadCount int64 err := s.db.QueryRow(`SELECT user_id, rel_path, expires_at, revoked_at, max_downloads, download_count FROM share_links WHERE token_hash = ?`, hashToken(token)). Scan(&uid, &rel, &expiresAt, &revokedAt, &maxDownloads, &downloadCount) if err != nil { writeErr(w, http.StatusNotFound, "share link not found") return } if revokedAt.Valid || expiresAt.Before(time.Now()) { writeErr(w, http.StatusGone, "share link expired") return } if maxDownloads.Valid && downloadCount >= maxDownloads.Int64 { writeErr(w, http.StatusGone, "share link download limit reached") return } if _, err := s.db.Exec(`UPDATE share_links SET download_count = download_count + 1 WHERE token_hash = ?`, hashToken(token)); err != nil { writeErr(w, http.StatusInternalServerError, "failed to track download") return } log.Printf("file.share.download user_id=%d path=%q ip=%q", uid, normalizePath(rel), clientIP(r)) if err := s.serveFile(w, r, uid, rel, true, ""); err != nil { writeErr(w, http.StatusBadRequest, err.Error()) } } func (s *Server) handleAdminMe(w http.ResponseWriter, _ *http.Request) { writeJSON(w, http.StatusOK, map[string]string{"login": s.config.AdminLogin}) } func (s *Server) handleAdminUsersList(w http.ResponseWriter, _ *http.Request) { rows, err := s.db.Query(`SELECT id, email, theme, color_mode, archive_format FROM users ORDER BY id ASC`) if err != nil { writeErr(w, http.StatusInternalServerError, "failed to load users") return } defer rows.Close() users := make([]User, 0) for rows.Next() { var u User if err := rows.Scan(&u.ID, &u.Username, &u.Theme, &u.ColorMode, &u.Archive); err != nil { continue } u.Theme = normalizeTheme(u.Theme) u.ColorMode = normalizeColorMode(u.ColorMode) users = append(users, u) } writeJSON(w, http.StatusOK, map[string]any{"users": users}) } type adminCreateUserInput struct { Username string `json:"username"` Password string `json:"password"` Theme string `json:"theme"` ColorMode string `json:"colorMode"` } func (s *Server) handleAdminUserCreate(w http.ResponseWriter, r *http.Request) { var in adminCreateUserInput if err := json.NewDecoder(r.Body).Decode(&in); err != nil { writeErr(w, http.StatusBadRequest, "invalid payload") return } in.Theme = normalizeTheme(in.Theme) in.ColorMode = normalizeColorMode(in.ColorMode) user, err := s.createUser(in.Username, in.Password, in.Theme, in.ColorMode) if err != nil { if strings.Contains(strings.ToLower(err.Error()), "exists") { writeErr(w, http.StatusConflict, err.Error()) return } writeErr(w, http.StatusBadRequest, err.Error()) return } writeJSON(w, http.StatusCreated, user) log.Printf("admin.user.create user_id=%d username=%q", user.ID, user.Username) } func (s *Server) handleAdminUserDelete(w http.ResponseWriter, r *http.Request) { idStr := mux.Vars(r)["id"] id, err := strconv.ParseInt(idStr, 10, 64) if err != nil || id <= 0 { writeErr(w, http.StatusBadRequest, "invalid user id") return } if _, err := s.db.Exec(`DELETE FROM users WHERE id = ?`, id); err != nil { writeErr(w, http.StatusInternalServerError, "failed to delete user") return } _, _ = s.db.Exec(`DELETE FROM refresh_tokens WHERE user_id = ?`, id) _, _ = s.db.Exec(`UPDATE share_links SET revoked_at = CURRENT_TIMESTAMP WHERE user_id = ?`, id) _ = s.storage.Delete(id, "/") writeJSON(w, http.StatusOK, map[string]string{"status": "deleted"}) log.Printf("admin.user.delete user_id=%d", id) } func (s *Server) recoverMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { if v := recover(); v != nil { log.Printf("panic recovered path=%q err=%v", r.URL.Path, v) writeErr(w, http.StatusInternalServerError, "internal server error") } }() next.ServeHTTP(w, r) }) } func (s *Server) securityHeadersMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-Frame-Options", "DENY") w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()") w.Header().Set("Content-Security-Policy", "default-src 'self'; connect-src 'self'; img-src 'self' data: blob:; style-src 'self' 'unsafe-inline'; script-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self'") if r.TLS != nil { w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") } next.ServeHTTP(w, r) }) } func (s *Server) bodyLimitMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if strings.HasPrefix(r.URL.Path, "/api/files/upload") { next.ServeHTTP(w, r) return } if strings.HasPrefix(r.URL.Path, "/api/") { r.Body = http.MaxBytesReader(w, r.Body, s.config.MaxBodyBytes) } next.ServeHTTP(w, r) }) } func (s *Server) rateLimitMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !strings.HasPrefix(r.URL.Path, "/api/") { next.ServeHTTP(w, r) return } ip := clientIP(r) limit := s.config.RateLimitPerMin if strings.Contains(r.URL.Path, "/auth/") || strings.Contains(r.URL.Path, "/admin/login") { limit = s.config.AuthRateLimitPerMin } if !s.limiter.allow(ip+":"+r.URL.Path, limit, time.Now()) { writeErr(w, http.StatusTooManyRequests, "rate limit exceeded") return } next.ServeHTTP(w, r) }) } func (s *Server) requestLogMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK} next.ServeHTTP(rec, r) if strings.HasPrefix(r.URL.Path, "/api/") { log.Printf( "api.request method=%s path=%q status=%d duration_ms=%d ip=%q ua=%q", r.Method, r.URL.Path, rec.status, time.Since(start).Milliseconds(), clientIP(r), limitString(r.UserAgent(), 120), ) } }) } func (s *Server) createUser(username, password, theme, colorMode string) (User, error) { username = strings.ToLower(strings.TrimSpace(username)) if len(username) < 3 || len(username) > 32 { return User{}, fmt.Errorf("username must be 3-32 characters") } for _, ch := range username { if (ch < 'a' || ch > 'z') && (ch < '0' || ch > '9') && ch != '_' && ch != '-' && ch != '.' { return User{}, fmt.Errorf("username can contain only a-z, 0-9, dot, dash, underscore") } } if len(password) < 10 { return User{}, fmt.Errorf("password must be at least 10 characters") } hash, err := hashPasswordArgon2ID(password) if err != nil { return User{}, fmt.Errorf("failed to hash password") } res, err := s.db.Exec(`INSERT INTO users(email, password_hash, theme, color_mode, archive_format) VALUES (?, ?, ?, ?, ?)`, username, hash, normalizeTheme(theme), normalizeColorMode(colorMode), "zip") if err != nil { if strings.Contains(strings.ToLower(err.Error()), "unique") { return User{}, fmt.Errorf("account already exists") } return User{}, err } id, _ := res.LastInsertId() if err := s.storage.Mkdir(id, "/"); err != nil { return User{}, fmt.Errorf("failed to provision user storage: %w", err) } return User{ID: id, Username: username, Theme: normalizeTheme(theme), ColorMode: normalizeColorMode(colorMode), Archive: "zip"}, nil } func (s *Server) findUserWithHash(username string) (User, string, error) { username = strings.ToLower(strings.TrimSpace(username)) var u User var hash string err := s.db.QueryRow(`SELECT id, email, password_hash, theme, color_mode, archive_format FROM users WHERE email = ?`, username). Scan(&u.ID, &u.Username, &hash, &u.Theme, &u.ColorMode, &u.Archive) if err != nil { return User{}, "", err } u.Theme = normalizeTheme(u.Theme) u.ColorMode = normalizeColorMode(u.ColorMode) return u, hash, nil } func (s *Server) findUser(id int64) (User, error) { var u User err := s.db.QueryRow(`SELECT id, email, theme, color_mode, archive_format FROM users WHERE id = ?`, id). Scan(&u.ID, &u.Username, &u.Theme, &u.ColorMode, &u.Archive) u.Theme = normalizeTheme(u.Theme) u.ColorMode = normalizeColorMode(u.ColorMode) return u, err } func (s *Server) issueUserSession(w http.ResponseWriter, r *http.Request, userID int64) error { now := time.Now() claims := AccessClaims{ UserID: userID, RegisteredClaims: jwt.RegisteredClaims{ Subject: strconv.FormatInt(userID, 10), IssuedAt: jwt.NewNumericDate(now), ExpiresAt: jwt.NewNumericDate(now.Add(s.config.AccessTTL)), }, } access, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(s.config.JWTSecret)) if err != nil { return err } refresh, err := randomToken() if err != nil { return err } if _, err := s.db.Exec(`INSERT INTO refresh_tokens(user_id, token_hash, expires_at, user_agent, ip) VALUES (?, ?, ?, ?, ?)`, userID, hashToken(refresh), now.Add(s.config.RefreshTTL), limitString(r.UserAgent(), 255), limitString(clientIP(r), 64), ); err != nil { return err } setCookie(w, "access_token", access, int(s.config.AccessTTL.Seconds()), s.config.CookieSecure) setCookie(w, "refresh_token", refresh, int(s.config.RefreshTTL.Seconds()), s.config.CookieSecure) return nil } func (s *Server) consumeRefreshToken(token string) (int64, error) { var id int64 var uid int64 var expiresAt time.Time var revokedAt sql.NullTime err := s.db.QueryRow(`SELECT id, user_id, expires_at, revoked_at FROM refresh_tokens WHERE token_hash = ?`, hashToken(token)). Scan(&id, &uid, &expiresAt, &revokedAt) if err != nil { return 0, err } if revokedAt.Valid || expiresAt.Before(time.Now()) { return 0, fmt.Errorf("refresh token expired or revoked") } if _, err := s.db.Exec(`UPDATE refresh_tokens SET revoked_at = CURRENT_TIMESTAMP WHERE id = ?`, id); err != nil { return 0, err } return uid, nil } func (s *Server) authMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("access_token") if err != nil || cookie.Value == "" { writeErr(w, http.StatusUnauthorized, "missing access token") return } claims := &AccessClaims{} tkn, err := jwt.ParseWithClaims(cookie.Value, claims, func(token *jwt.Token) (any, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method") } return []byte(s.config.JWTSecret), nil }) if err != nil || !tkn.Valid { writeErr(w, http.StatusUnauthorized, "invalid access token") return } next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), userIDKey, claims.UserID))) }) } func (s *Server) adminMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("admin_token") if err != nil || cookie.Value == "" { writeErr(w, http.StatusUnauthorized, "missing admin token") return } claims := &AdminClaims{} tkn, err := jwt.ParseWithClaims(cookie.Value, claims, func(token *jwt.Token) (any, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method") } return []byte(s.config.JWTSecret), nil }) if err != nil || !tkn.Valid || claims.Role != "admin" || claims.Login != s.config.AdminLogin { writeErr(w, http.StatusUnauthorized, "invalid admin session") return } next.ServeHTTP(w, r) }) } func (s *Server) corsMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := strings.TrimSpace(r.Header.Get("Origin")) if origin != "" && s.config.CORSOrigin != "" && origin != s.config.CORSOrigin { writeErr(w, http.StatusForbidden, "cors origin denied") return } if s.config.CORSOrigin != "" { w.Header().Set("Access-Control-Allow-Origin", s.config.CORSOrigin) w.Header().Set("Vary", "Origin") } w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Access-Control-Allow-Headers", "Content-Type") w.Header().Set("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS,HEAD") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) return } next.ServeHTTP(w, r) }) } func (s *Server) hostGuardMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if strings.TrimSpace(s.config.AllowedHost) == "" { next.ServeHTTP(w, r) return } host := strings.ToLower(hostOnly(r.Host)) if host == strings.ToLower(s.config.AllowedHost) { next.ServeHTTP(w, r) return } if strings.HasPrefix(r.URL.Path, "/api/") { writeErr(w, http.StatusForbidden, fmt.Sprintf("access denied: host must be %s", s.config.AllowedHost)) return } w.Header().Set("Content-Type", "text/html; charset=utf-8") w.WriteHeader(http.StatusForbidden) _, _ = w.Write([]byte(fmt.Sprintf("

Access denied

This service is only available at %s.

", s.config.AllowedHost))) }) } type LocalStorage struct { root string } func (l *LocalStorage) userRoot(userID int64) string { return filepath.Join(l.root, strconv.FormatInt(userID, 10)) } func (l *LocalStorage) fullPath(userID int64, rel string) (string, error) { root := l.userRoot(userID) if err := os.MkdirAll(root, 0o755); err != nil { return "", err } clean := filepath.FromSlash(strings.TrimPrefix(normalizePath(rel), "/")) full := filepath.Clean(filepath.Join(root, clean)) if full != root && !strings.HasPrefix(full, root+string(os.PathSeparator)) { return "", fmt.Errorf("invalid path") } return full, nil } func (l *LocalStorage) List(userID int64, rel string) ([]FileEntry, error) { full, err := l.fullPath(userID, rel) if err != nil { return nil, err } entries, err := os.ReadDir(full) if err != nil { if errors.Is(err, os.ErrNotExist) { return []FileEntry{}, nil } return nil, err } out := make([]FileEntry, 0, len(entries)) for _, e := range entries { info, err := e.Info() if err != nil { continue } out = append(out, FileEntry{ Name: e.Name(), Path: path.Join(normalizePath(rel), e.Name()), IsDir: e.IsDir(), Size: info.Size(), ModTime: info.ModTime(), }) } return out, nil } func (l *LocalStorage) Mkdir(userID int64, rel string) error { full, err := l.fullPath(userID, rel) if err != nil { return err } return os.MkdirAll(full, 0o755) } func (l *LocalStorage) Save(userID int64, rel string, src multipart.File) error { full, err := l.fullPath(userID, rel) if err != nil { return err } if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil { return err } dst, err := os.Create(full) if err != nil { return err } defer dst.Close() _, err = io.Copy(dst, src) return err } func (l *LocalStorage) SaveBytes(userID int64, rel string, data []byte) error { full, err := l.fullPath(userID, rel) if err != nil { return err } if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil { return err } return os.WriteFile(full, data, 0o644) } func (l *LocalStorage) Delete(userID int64, rel string) error { if normalizePath(rel) == "/" { full, err := l.fullPath(userID, rel) if err != nil { return err } return os.RemoveAll(full) } full, err := l.fullPath(userID, rel) if err != nil { return err } return os.RemoveAll(full) } func (l *LocalStorage) Stat(userID int64, rel string) (FileMeta, error) { full, err := l.fullPath(userID, rel) if err != nil { return FileMeta{}, err } st, err := os.Stat(full) if err != nil { return FileMeta{}, err } return FileMeta{Name: st.Name(), Size: st.Size(), ModTime: st.ModTime(), IsDir: st.IsDir()}, nil } func (l *LocalStorage) OpenReadSeeker(userID int64, rel string) (ReadSeekCloser, error) { full, err := l.fullPath(userID, rel) if err != nil { return nil, err } return os.Open(full) } type SMBStorage struct { cfg Config } type smbConn struct { conn net.Conn session *smb2.Session share *smb2.Share } func (c *smbConn) Close() { if c.share != nil { _ = c.share.Umount() } if c.session != nil { _ = c.session.Logoff() } if c.conn != nil { _ = c.conn.Close() } } type smbReadSeekCloser struct { file *smb2.File conn *smbConn } func (s *smbReadSeekCloser) Read(p []byte) (int, error) { return s.file.Read(p) } func (s *smbReadSeekCloser) Seek(offset int64, whence int) (int64, error) { return s.file.Seek(offset, whence) } func (s *smbReadSeekCloser) Close() error { _ = s.file.Close() s.conn.Close() return nil } func (smbs *SMBStorage) openConnection() (*smbConn, error) { conn, err := net.DialTimeout("tcp", smbs.cfg.SMBHost, smbs.cfg.SMBConnectTimout) if err != nil { return nil, err } dialer := &smb2.Dialer{Initiator: &smb2.NTLMInitiator{ User: smbs.cfg.SMBUser, Password: smbs.cfg.SMBPass, Domain: smbs.cfg.SMBDomain, }} session, err := dialer.Dial(conn) if err != nil { _ = conn.Close() return nil, err } share, err := session.Mount(smbs.cfg.SMBShare) if err != nil { _ = session.Logoff() _ = conn.Close() return nil, err } return &smbConn{conn: conn, session: session, share: share}, nil } func (smbs *SMBStorage) withShare(fn func(share *smb2.Share) error) error { conn, err := smbs.openConnection() if err != nil { return err } defer conn.Close() return fn(conn.share) } func (smbs *SMBStorage) userPath(userID int64, rel string) string { base := path.Join(smbs.cfg.SMBBasePath, strconv.FormatInt(userID, 10)) clean := strings.TrimPrefix(normalizePath(rel), "/") if clean == "" { return base } return path.Join(base, clean) } func (smbs *SMBStorage) List(userID int64, rel string) ([]FileEntry, error) { out := make([]FileEntry, 0) err := smbs.withShare(func(share *smb2.Share) error { target := smbs.userPath(userID, rel) if err := share.MkdirAll(target, 0o755); err != nil { return err } entries, err := share.ReadDir(target) if err != nil { return err } for _, e := range entries { out = append(out, FileEntry{ Name: e.Name(), Path: path.Join(normalizePath(rel), e.Name()), IsDir: e.IsDir(), Size: e.Size(), ModTime: e.ModTime(), }) } return nil }) return out, err } func (smbs *SMBStorage) Mkdir(userID int64, rel string) error { return smbs.withShare(func(share *smb2.Share) error { return share.MkdirAll(smbs.userPath(userID, rel), 0o755) }) } func (smbs *SMBStorage) Save(userID int64, rel string, src multipart.File) error { return smbs.withShare(func(share *smb2.Share) error { target := smbs.userPath(userID, rel) if err := share.MkdirAll(path.Dir(target), 0o755); err != nil { return err } f, err := share.Create(target) if err != nil { return err } defer f.Close() _, err = io.Copy(f, src) return err }) } func (smbs *SMBStorage) SaveBytes(userID int64, rel string, data []byte) error { return smbs.withShare(func(share *smb2.Share) error { target := smbs.userPath(userID, rel) if err := share.MkdirAll(path.Dir(target), 0o755); err != nil { return err } f, err := share.Create(target) if err != nil { return err } defer f.Close() _, err = f.Write(data) return err }) } func (smbs *SMBStorage) Delete(userID int64, rel string) error { return smbs.withShare(func(share *smb2.Share) error { target := smbs.userPath(userID, rel) if normalizePath(rel) == "/" { return share.RemoveAll(target) } if err := share.Remove(target); err == nil { return nil } return share.RemoveAll(target) }) } func (smbs *SMBStorage) Stat(userID int64, rel string) (FileMeta, error) { var out FileMeta err := smbs.withShare(func(share *smb2.Share) error { st, err := share.Stat(smbs.userPath(userID, rel)) if err != nil { return err } out = FileMeta{Name: st.Name(), Size: st.Size(), ModTime: st.ModTime(), IsDir: st.IsDir()} return nil }) return out, err } func (smbs *SMBStorage) OpenReadSeeker(userID int64, rel string) (ReadSeekCloser, error) { conn, err := smbs.openConnection() if err != nil { return nil, err } f, err := conn.share.Open(smbs.userPath(userID, rel)) if err != nil { conn.Close() return nil, err } return &smbReadSeekCloser{file: f, conn: conn}, nil } func normalizePath(rel string) string { clean := path.Clean("/" + strings.TrimSpace(rel)) if clean == "." { return "/" } return clean } func normalizeUploadRelativePath(v string) (string, error) { v = strings.ReplaceAll(strings.TrimSpace(v), "\\", "/") v = strings.TrimPrefix(v, "/") v = path.Clean(v) if v == "." || v == "" { return "", fmt.Errorf("invalid upload path") } if strings.HasPrefix(v, "../") || strings.Contains(v, "/../") || strings.HasPrefix(v, "/") { return "", fmt.Errorf("invalid upload path") } return v, nil } func normalizeTheme(v string) string { v = strings.TrimSpace(strings.ToLower(v)) switch v { case "dracula", "nord", "monokai", "solarized", "github": return v case "material", "glass", "desktop", "auto": return "dracula" default: return "dracula" } } func normalizeColorMode(v string) string { v = strings.TrimSpace(strings.ToLower(v)) switch v { case "light", "dark", "auto": return v default: return "auto" } } func normalizeArchiveFormat(v string) string { v = strings.TrimSpace(strings.ToLower(v)) switch v { case "zip", "rar", "tar.gz", "lz4": return v default: return "" } } func normalizeTag(v string) (string, bool) { v = strings.ToLower(strings.TrimSpace(v)) if len(v) < 1 || len(v) > 24 { return "", false } for _, ch := range v { if (ch < 'a' || ch > 'z') && (ch < '0' || ch > '9') && ch != '-' && ch != '_' { return "", false } } return v, true } func isMarkdownExtension(ext string) bool { ext = strings.ToLower(strings.TrimPrefix(ext, ".")) switch ext { case "md", "markdown": return true default: return false } } func (s *Server) fileTagsForPaths(uid int64, paths []string) (map[string][]string, error) { out := make(map[string][]string, len(paths)) if len(paths) == 0 { return out, nil } args := make([]any, 0, len(paths)+1) args = append(args, uid) pl := make([]string, len(paths)) for i, p := range paths { norm := normalizePath(p) pl[i] = "?" args = append(args, norm) } q := `SELECT rel_path, tag FROM file_tags WHERE user_id = ? AND rel_path IN (` + strings.Join(pl, ",") + `) ORDER BY rel_path, tag` rows, err := s.db.Query(q, args...) if err != nil { return out, err } defer rows.Close() for rows.Next() { var rel string var tag string if rows.Scan(&rel, &tag) == nil { out[rel] = append(out[rel], tag) } } return out, nil } func setCookie(w http.ResponseWriter, name, value string, maxAge int, secure bool) { http.SetCookie(w, &http.Cookie{ Name: name, Value: value, Path: "/", HttpOnly: true, Secure: secure, SameSite: http.SameSiteLaxMode, MaxAge: maxAge, }) } func clearCookie(w http.ResponseWriter, name string, secure bool) { http.SetCookie(w, &http.Cookie{ Name: name, Value: "", Path: "/", HttpOnly: true, Secure: secure, SameSite: http.SameSiteLaxMode, MaxAge: -1, }) } func randomToken() (string, error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { return "", err } return base64.RawURLEncoding.EncodeToString(b), nil } func hashToken(token string) string { h := sha256.Sum256([]byte(token)) return base64.RawURLEncoding.EncodeToString(h[:]) } func subtleConstantTimeEq(a, b string) int { if len(a) != len(b) { return 0 } var out byte for i := 0; i < len(a); i++ { out |= a[i] ^ b[i] } if out == 0 { return 1 } return 0 } func maybeRunHashCommand() bool { if len(os.Args) < 2 { return false } if os.Args[1] != "hash-admin" { return false } if len(os.Args) < 3 { fmt.Println("usage: go run . hash-admin ") os.Exit(1) } h, err := hashPasswordArgon2ID(os.Args[2]) if err != nil { fmt.Println("failed to hash password:", err) os.Exit(1) } fmt.Println(h) return true } func hashPasswordArgon2ID(password string) (string, error) { salt := make([]byte, 16) if _, err := rand.Read(salt); err != nil { return "", err } timeCost := uint32(3) memoryCost := uint32(64 * 1024) threads := uint8(2) keyLen := uint32(32) derived := argon2.IDKey([]byte(password), salt, timeCost, memoryCost, threads, keyLen) saltB64 := base64.RawStdEncoding.EncodeToString(salt) hashB64 := base64.RawStdEncoding.EncodeToString(derived) encoded := fmt.Sprintf("argon2id:v=19,m=%d,t=%d,p=%d:%s:%s", memoryCost, timeCost, threads, saltB64, hashB64) return encoded, nil } func verifyPasswordHash(storedHash, password string) bool { if strings.HasPrefix(storedHash, "argon2id:") { parts := strings.Split(storedHash, ":") if len(parts) != 4 { return false } var mem, timeCost uint32 var threads uint8 if _, err := fmt.Sscanf(parts[1], "v=19,m=%d,t=%d,p=%d", &mem, &timeCost, &threads); err != nil { return false } salt, err := base64.RawStdEncoding.DecodeString(parts[2]) if err != nil { return false } expected, err := base64.RawStdEncoding.DecodeString(parts[3]) if err != nil { return false } actual := argon2.IDKey([]byte(password), salt, timeCost, mem, threads, uint32(len(expected))) return subtle.ConstantTimeCompare(expected, actual) == 1 } if strings.HasPrefix(storedHash, "argon2id$") { parts := strings.Split(storedHash, "$") if len(parts) == 6 { var mem, timeCost uint32 var threads uint8 if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &mem, &timeCost, &threads); err == nil { salt, errSalt := base64.RawStdEncoding.DecodeString(parts[4]) expected, errExpected := base64.RawStdEncoding.DecodeString(parts[5]) if errSalt == nil && errExpected == nil { actual := argon2.IDKey([]byte(password), salt, timeCost, mem, threads, uint32(len(expected))) return subtle.ConstantTimeCompare(expected, actual) == 1 } } } } if strings.HasPrefix(storedHash, "sha256:") { expected := strings.TrimPrefix(storedHash, "sha256:") sum := sha256.Sum256([]byte(password)) actual := hex.EncodeToString(sum[:]) return subtle.ConstantTimeCompare([]byte(expected), []byte(actual)) == 1 } bcryptHash := storedHash if strings.HasPrefix(storedHash, "bcrypt:") { bcryptHash = strings.TrimPrefix(storedHash, "bcrypt:") } if strings.HasPrefix(bcryptHash, "$2a$") || strings.HasPrefix(bcryptHash, "$2b$") || strings.HasPrefix(bcryptHash, "$2y$") { return bcrypt.CompareHashAndPassword([]byte(bcryptHash), []byte(password)) == nil } return false } func verifyAdminPasswordHash(storedHash, password string) bool { return verifyPasswordHash(storedHash, password) } func userIDFromContext(ctx context.Context) int64 { v, _ := ctx.Value(userIDKey).(int64) return v } func schemeOf(r *http.Request) string { if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" { return proto } if r.TLS != nil { return "https" } return "http" } func hostOnly(hostport string) string { h := strings.TrimSpace(hostport) if h == "" { return "" } if strings.Contains(h, ":") { if host, _, err := net.SplitHostPort(h); err == nil { return host } } return h } func clientIP(r *http.Request) string { for _, h := range []string{"CF-Connecting-IP", "X-Forwarded-For", "X-Real-IP"} { v := strings.TrimSpace(r.Header.Get(h)) if v != "" { if strings.Contains(v, ",") { return strings.TrimSpace(strings.Split(v, ",")[0]) } return v } } host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return r.RemoteAddr } return host } func writeJSON(w http.ResponseWriter, status int, payload any) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) _ = json.NewEncoder(w).Encode(payload) } func writeErr(w http.ResponseWriter, status int, msg string) { writeJSON(w, status, map[string]string{"error": msg}) } func getEnv(key, fallback string) string { v := strings.TrimSpace(os.Getenv(key)) if v == "" { return fallback } return v } func getEnvInt(key string, fallback int) int { v := strings.TrimSpace(os.Getenv(key)) if v == "" { return fallback } n, err := strconv.Atoi(v) if err != nil { return fallback } return n } func limitString(s string, max int) string { if len(s) <= max { return s } return s[:max] }