Files
ZFile/backend/main.go
2026-03-06 21:32:58 +03:00

2602 lines
71 KiB
Go

package main
import (
"archive/tar"
"archive/zip"
"compress/gzip"
"context"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"database/sql"
"embed"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log"
"mime"
"mime/multipart"
"net"
"net/http"
"os"
"os/exec"
"path"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/gorilla/mux"
"golang.org/x/crypto/argon2"
"golang.org/x/crypto/bcrypt"
)
// embeddedWeb holds the web/dist directory tree compiled into the binary by
// the go:embed directive below. staticHandler serves the web UI from it.
//go:embed web/dist
var embeddedWeb embed.FS
// Server bundles the dependencies shared by the HTTP handlers and middleware
// that are defined as methods on it. main builds the single instance.
type Server struct {
// db is the database/sql handle the handlers use for direct SQL statements.
db *sql.DB
// orm is the repository returned by newORMRepo in main.
orm *ormRepo
// config is the configuration returned by loadConfig at startup.
config Config
// storage is the file backend returned by buildStorage.
storage Storage
// limiter is the in-memory rate limiter created by newRateLimiter.
limiter *rateLimiter
}
// rateLimiter is a fixed-window, in-memory request counter keyed by an
// arbitrary string. All access goes through mu, so it may be shared between
// goroutines. The zero value is not usable (entries would be a nil map);
// construct it with newRateLimiter.
type rateLimiter struct {
	mu      sync.Mutex
	entries map[string]*rateEntry
	// nextSweep is the earliest time at which allow purges expired entries
	// again, so the cost of a full sweep is paid at most once per window.
	nextSweep time.Time
}

// rateEntry tracks how many requests a single key has made in its current
// window and when that window ends.
type rateEntry struct {
	Count      int
	WindowEnds time.Time
}

// newRateLimiter returns an empty, ready-to-use rateLimiter.
func newRateLimiter() *rateLimiter {
	return &rateLimiter{entries: make(map[string]*rateEntry)}
}

// allow records one request for key at time now and reports whether it fits
// within limit requests per one-minute window.
//
// The first request for a key, or the first one after its window has ended,
// opens a new window and is always allowed. Subsequent requests are allowed
// until the window's count reaches limit.
//
// Expired entries are purged at most once per window; without this the map
// would keep one entry for every key ever seen and grow without bound.
func (rl *rateLimiter) allow(key string, limit int, now time.Time) bool {
	const window = time.Minute

	rl.mu.Lock()
	defer rl.mu.Unlock()

	if !now.Before(rl.nextSweep) {
		// Deleting during range is safe in Go; only entries whose window has
		// already ended are dropped, so live counters are never reset.
		for k, e := range rl.entries {
			if now.After(e.WindowEnds) {
				delete(rl.entries, k)
			}
		}
		rl.nextSweep = now.Add(window)
	}

	entry, ok := rl.entries[key]
	if !ok || now.After(entry.WindowEnds) {
		rl.entries[key] = &rateEntry{Count: 1, WindowEnds: now.Add(window)}
		return true
	}
	if entry.Count >= limit {
		return false
	}
	entry.Count++
	return true
}
// statusRecorder decorates an http.ResponseWriter and remembers the status
// code most recently passed to WriteHeader. status stays 0 until WriteHeader
// is called explicitly.
type statusRecorder struct {
	http.ResponseWriter
	status int
}

// WriteHeader stores code on the recorder and then forwards it to the wrapped
// ResponseWriter.
func (rec *statusRecorder) WriteHeader(code int) {
	rec.status = code
	rec.ResponseWriter.WriteHeader(code)
}
// key is an unexported integer type used for the constants below, so their
// values cannot collide with identically-valued keys of other types.
type key int
const (
// userIDKey is presumably the context key under which the authenticated
// user's ID is stored (see userIDFromContext) — confirm against authMiddleware.
userIDKey key = 1
)
// AccessClaims is the JWT claim set parsed from the "access_token" cookie
// (see handleLogout). It carries the user ID alongside the registered claims.
type AccessClaims struct {
// UserID is serialized as "uid".
UserID int64 `json:"uid"`
jwt.RegisteredClaims
}
// AdminClaims is the JWT claim set that handleAdminLogin signs into the
// "admin_token" cookie.
type AdminClaims struct {
// Login is the admin login name; handleAdminLogin sets it from the config.
Login string `json:"login"`
// Role is set to "admin" by handleAdminLogin.
Role string `json:"role"`
jwt.RegisteredClaims
}
// User is the JSON representation of an account together with its UI
// preferences, as returned by handleLogin and handleMe.
type User struct {
ID int64 `json:"id"`
Username string `json:"username"`
Theme string `json:"theme"`
ColorMode string `json:"colorMode"`
// Archive is serialized as "archiveFormat".
Archive string `json:"archiveFormat"`
}
// FileEntry describes one item of a directory listing returned by
// Storage.List and serialized to API clients.
type FileEntry struct {
Name string `json:"name"`
Path string `json:"path"`
IsDir bool `json:"isDir"`
Size int64 `json:"size"`
ModTime time.Time `json:"modTime"`
// Tags is filled in by handleListFiles from the file_tags table and is
// omitted from the JSON when empty.
Tags []string `json:"tags,omitempty"`
}
// FileMeta is the metadata returned by Storage.Stat for a single path.
// It has no JSON tags; it is used internally (e.g. for archive headers and
// http.ServeContent modification times).
type FileMeta struct {
Name string
Size int64
ModTime time.Time
IsDir bool
}
// ReadSeekCloser combines io.ReadSeeker and io.Closer. It is what
// Storage.OpenReadSeeker returns; seeking is needed because the handle is
// passed to http.ServeContent.
type ReadSeekCloser interface {
io.ReadSeeker
io.Closer
}
// Storage abstracts the file backend used by the HTTP handlers. Every method
// is scoped to a user ID and a path rel.
// NOTE(review): rel is presumably relative to that user's storage root and
// validated by the implementation; the only implementation constructed in
// this part of the file is LocalStorage (see buildStorage) — confirm there.
type Storage interface {
// List returns the entries of the directory at rel.
List(userID int64, rel string) ([]FileEntry, error)
// Mkdir creates the directory at rel.
Mkdir(userID int64, rel string) error
// Save stores the uploaded multipart file src at rel.
Save(userID int64, rel string, src multipart.File) error
// SaveBytes stores data at rel.
SaveBytes(userID int64, rel string, data []byte) error
// Delete removes the file or directory at rel.
Delete(userID int64, rel string) error
// Stat returns metadata for rel.
Stat(userID int64, rel string) (FileMeta, error)
// OpenReadSeeker opens the file at rel for reading; the caller must close it.
OpenReadSeeker(userID int64, rel string) (ReadSeekCloser, error)
}
// main wires the application together: it loads the configuration, opens and
// migrates the database, builds the storage backend and ORM repository,
// starts the protocol servers, registers all HTTP routes and then blocks in
// ListenAndServe. Any startup failure terminates the process via log.Fatalf.
func main() {
// maybeRunHashCommand reports true when it has handled the invocation
// itself; in that case the HTTP server is not started.
if maybeRunHashCommand() {
return
}
cfg := loadConfig()
db, err := openDB(cfg.DBPath)
if err != nil {
log.Fatalf("db open failed: %v", err)
}
if err := migrate(db); err != nil {
log.Fatalf("migrate failed: %v", err)
}
storage, err := buildStorage(cfg)
if err != nil {
log.Fatalf("storage init failed: %v", err)
}
orm, err := newORMRepo(cfg.DBPath)
if err != nil {
log.Fatalf("orm init failed: %v", err)
}
if err := startProtocolServers(cfg, db); err != nil {
log.Fatalf("protocol init failed: %v", err)
}
s := &Server{db: db, orm: orm, config: cfg, storage: storage, limiter: newRateLimiter()}
r := mux.NewRouter()
// Global middleware, registered on the root router so it wraps every
// matched route below. Registration order is the wrapping order.
r.Use(s.recoverMiddleware)
r.Use(s.securityHeadersMiddleware)
r.Use(s.hostGuardMiddleware)
r.Use(s.corsMiddleware)
r.Use(s.bodyLimitMiddleware)
r.Use(s.rateLimitMiddleware)
r.Use(s.requestLogMiddleware)
// Routes registered directly on the root router: only the global
// middleware applies, not authMiddleware or adminMiddleware.
r.HandleFunc("/api/health", s.handleHealth).Methods(http.MethodGet)
r.HandleFunc("/api/auth/register", s.handleRegisterDisabled).Methods(http.MethodPost)
r.HandleFunc("/api/auth/login", s.handleLogin).Methods(http.MethodPost)
r.HandleFunc("/api/auth/google/start", s.handleGoogleAuthStart).Methods(http.MethodGet)
r.HandleFunc("/api/auth/google/callback", s.handleGoogleAuthCallback).Methods(http.MethodGet)
r.HandleFunc("/api/auth/refresh", s.handleRefresh).Methods(http.MethodPost)
r.HandleFunc("/api/auth/logout", s.handleLogout).Methods(http.MethodPost)
r.HandleFunc("/api/admin/login", s.handleAdminLogin).Methods(http.MethodPost)
r.HandleFunc("/api/admin/logout", s.handleAdminLogout).Methods(http.MethodPost)
r.HandleFunc("/api/share/{token}", s.handleSharedDownload).Methods(http.MethodGet, http.MethodHead)
// User-facing API under /api, additionally wrapped by authMiddleware.
protected := r.PathPrefix("/api").Subrouter()
protected.Use(s.authMiddleware)
protected.HandleFunc("/auth/me", s.handleMe).Methods(http.MethodGet)
protected.HandleFunc("/user/preferences", s.handleSetPreferences).Methods(http.MethodPost)
protected.HandleFunc("/user/google/link/start", s.handleGoogleLinkStart).Methods(http.MethodGet)
protected.HandleFunc("/user/protocols", s.handleUserProtocols).Methods(http.MethodGet)
protected.HandleFunc("/files", s.handleListFiles).Methods(http.MethodGet)
protected.HandleFunc("/files/upload", s.handleUpload).Methods(http.MethodPost)
protected.HandleFunc("/files/download", s.handleDownload).Methods(http.MethodGet, http.MethodHead)
protected.HandleFunc("/files/download-batch", s.handleBatchDownload).Methods(http.MethodPost)
protected.HandleFunc("/files/move-batch", s.handleBatchMove).Methods(http.MethodPost)
protected.HandleFunc("/files/preview", s.handlePreview).Methods(http.MethodGet, http.MethodHead)
protected.HandleFunc("/files/text", s.handleReadTextFile).Methods(http.MethodGet)
protected.HandleFunc("/files/text", s.handleWriteTextFile).Methods(http.MethodPut)
protected.HandleFunc("/files/rename", s.handleRename).Methods(http.MethodPost)
protected.HandleFunc("/files", s.handleDelete).Methods(http.MethodDelete)
protected.HandleFunc("/files/folder", s.handleCreateFolder).Methods(http.MethodPost)
protected.HandleFunc("/files/share", s.handleCreateShareLink).Methods(http.MethodPost)
protected.HandleFunc("/files/tags", s.handleListFileTags).Methods(http.MethodGet)
protected.HandleFunc("/files/tags", s.handleAddFileTag).Methods(http.MethodPost)
protected.HandleFunc("/files/tags", s.handleDeleteFileTag).Methods(http.MethodDelete)
// Admin API under /api/admin, wrapped by adminMiddleware. The admin
// login/logout endpoints are registered on the root router above instead.
admin := r.PathPrefix("/api/admin").Subrouter()
admin.Use(s.adminMiddleware)
admin.HandleFunc("/me", s.handleAdminMe).Methods(http.MethodGet)
admin.HandleFunc("/users", s.handleAdminUsersList).Methods(http.MethodGet)
admin.HandleFunc("/users", s.handleAdminUserCreate).Methods(http.MethodPost)
admin.HandleFunc("/users/{id}", s.handleAdminUserDelete).Methods(http.MethodDelete)
// Catch-all registered last: anything not matched above goes to the
// embedded web UI handler.
r.PathPrefix("/").Handler(s.staticHandler())
// Only ReadHeaderTimeout is configured on the server; no other timeouts
// are set here.
server := &http.Server{
Addr: cfg.Addr,
Handler: r,
ReadHeaderTimeout: 10 * time.Second,
}
log.Printf("listening on %s", cfg.Addr)
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("server failed: %v", err)
}
}
// openDB opens the SQLite database at path using the "sqlite" driver and
// verifies the connection with a ping.
//
// On failure it returns a nil *sql.DB and an error that wraps the underlying
// cause and names the path. When the ping fails the handle is closed before
// returning, so a failed open does not leak the pool created by sql.Open.
func openDB(path string) (*sql.DB, error) {
	db, err := sql.Open("sqlite", path)
	if err != nil {
		return nil, fmt.Errorf("open sqlite database %q: %w", path, err)
	}
	if err := db.Ping(); err != nil {
		// sql.Open does not necessarily connect; Ping is the first real use.
		// The Close error is irrelevant next to the ping failure.
		_ = db.Close()
		return nil, fmt.Errorf("ping sqlite database %q: %w", path, err)
	}
	return db, nil
}
// migrate brings the database schema up to date. It is written to be safe to
// run on every start:
//
//  1. the tables and indexes are created with IF NOT EXISTS;
//  2. columns added after the first release are appended with ALTER TABLE,
//     where an error whose message contains "duplicate column" (compared
//     case-insensitively) means the column is already there and is ignored;
//  3. the partial unique index on users.google_sub is created.
//
// The statements run one by one in order; the first unexpected error is
// returned as-is and stops the migration.
func migrate(db *sql.DB) error {
	schema := []string{
		`CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
email TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
theme TEXT NOT NULL DEFAULT 'dracula',
color_mode TEXT NOT NULL DEFAULT 'auto',
archive_format TEXT NOT NULL DEFAULT 'zip',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);`,
		`CREATE TABLE IF NOT EXISTS refresh_tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
token_hash TEXT NOT NULL,
expires_at TIMESTAMP NOT NULL,
revoked_at TIMESTAMP,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
user_agent TEXT,
ip TEXT,
FOREIGN KEY(user_id) REFERENCES users(id)
);`,
		`CREATE INDEX IF NOT EXISTS idx_refresh_tokens_hash ON refresh_tokens(token_hash);`,
		`CREATE TABLE IF NOT EXISTS share_links (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
rel_path TEXT NOT NULL,
token_hash TEXT NOT NULL UNIQUE,
expires_at TIMESTAMP NOT NULL,
max_downloads INTEGER,
download_count INTEGER NOT NULL DEFAULT 0,
revoked_at TIMESTAMP,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(user_id) REFERENCES users(id)
);`,
		`CREATE INDEX IF NOT EXISTS idx_share_links_hash ON share_links(token_hash);`,
		`CREATE TABLE IF NOT EXISTS file_tags (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
rel_path TEXT NOT NULL,
tag TEXT NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE(user_id, rel_path, tag),
FOREIGN KEY(user_id) REFERENCES users(id)
);`,
		`CREATE INDEX IF NOT EXISTS idx_file_tags_user_path ON file_tags(user_id, rel_path);`,
		`CREATE INDEX IF NOT EXISTS idx_file_tags_user_tag ON file_tags(user_id, tag);`,
	}
	for i := range schema {
		if _, execErr := db.Exec(schema[i]); execErr != nil {
			return execErr
		}
	}

	// Columns introduced after the initial schema. On a fresh database the
	// first two already exist from CREATE TABLE above, which is exactly the
	// "duplicate column" case tolerated below.
	addColumns := []string{
		`ALTER TABLE users ADD COLUMN color_mode TEXT NOT NULL DEFAULT 'auto'`,
		`ALTER TABLE users ADD COLUMN archive_format TEXT NOT NULL DEFAULT 'zip'`,
		`ALTER TABLE users ADD COLUMN google_sub TEXT`,
	}
	for _, ddl := range addColumns {
		_, execErr := db.Exec(ddl)
		if execErr == nil {
			continue
		}
		alreadyThere := strings.Contains(strings.ToLower(execErr.Error()), "duplicate column")
		if !alreadyThere {
			return execErr
		}
	}

	_, indexErr := db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_users_google_sub ON users(google_sub) WHERE google_sub IS NOT NULL`)
	return indexErr
}
// buildStorage makes sure the configured storage root exists (creating it and
// any missing parents with mode 0755) and returns a LocalStorage rooted there.
// The MkdirAll error is returned unchanged.
func buildStorage(cfg Config) (Storage, error) {
	if mkErr := os.MkdirAll(cfg.StorageRoot, 0o755); mkErr != nil {
		return nil, mkErr
	}
	local := &LocalStorage{root: cfg.StorageRoot}
	return local, nil
}
// handleHealth answers GET /api/health with 200 and {"status":"ok"}. It does
// not look at the request or probe any dependency.
func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) {
	payload := map[string]string{"status": "ok"}
	writeJSON(w, http.StatusOK, payload)
}
// staticHandler returns the handler for the embedded web UI.
//
// Requests whose path starts with "/api/" get a 404 so unknown API routes
// never receive HTML. A path that exists under web/dist is served by the
// standard file server. Every other path gets index.html with status 200,
// which lets the single-page app handle its own client-side routes.
//
// If the web/dist subtree cannot be opened, the returned handler answers
// every request with 503.
func (s *Server) staticHandler() http.Handler {
	dist, subErr := fs.Sub(embeddedWeb, "web/dist")
	if subErr != nil {
		unavailable := func(w http.ResponseWriter, _ *http.Request) {
			http.Error(w, "web assets unavailable", http.StatusServiceUnavailable)
		}
		return http.HandlerFunc(unavailable)
	}
	assets := http.FileServer(http.FS(dist))

	serve := func(w http.ResponseWriter, r *http.Request) {
		if strings.HasPrefix(r.URL.Path, "/api/") {
			http.NotFound(w, r)
			return
		}

		// Map the URL path to a name inside the embedded tree; the root maps
		// to index.html.
		name := strings.TrimPrefix(path.Clean(r.URL.Path), "/")
		switch name {
		case ".", "":
			name = "index.html"
		}
		if _, err := fs.Stat(dist, name); err == nil {
			assets.ServeHTTP(w, r)
			return
		}

		// Unknown path: fall back to the application shell.
		page, err := fs.ReadFile(dist, "index.html")
		if err != nil {
			http.Error(w, "web app is not bundled", http.StatusServiceUnavailable)
			return
		}
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write(page)
	}
	return http.HandlerFunc(serve)
}
// handleRegisterDisabled answers POST /api/auth/register with 403; accounts
// are created through the admin API (see handleAdminUserCreate in main's
// route table).
func (s *Server) handleRegisterDisabled(w http.ResponseWriter, _ *http.Request) {
	const reason = "public registration is disabled; ask an administrator"
	writeErr(w, http.StatusForbidden, reason)
}
// authInput is the JSON request body decoded by handleLogin.
type authInput struct {
Username string `json:"username"`
Password string `json:"password"`
}
// handleLogin authenticates a user from a JSON {username, password} body.
//
// Responses:
//   - 400 when the body is not valid JSON;
//   - 401 "invalid credentials" when the user is unknown or the password does
//     not match (the same message in both cases; the log line records which);
//   - 500 when issueUserSession fails;
//   - 200 with the user record otherwise.
func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
	var creds authInput
	if decodeErr := json.NewDecoder(r.Body).Decode(&creds); decodeErr != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}
	ip := clientIP(r)

	// reject logs the internal reason and sends the uniform 401 response.
	reject := func(reason string) {
		log.Printf("auth.login.failed username=%q ip=%q reason=%q", creds.Username, ip, reason)
		writeErr(w, http.StatusUnauthorized, "invalid credentials")
	}

	user, hash, lookupErr := s.findUserWithHash(creds.Username)
	if lookupErr != nil {
		reject("user_not_found")
		return
	}
	if !verifyPasswordHash(hash, creds.Password) {
		reject("invalid_password")
		return
	}

	if sessionErr := s.issueUserSession(w, r, user.ID); sessionErr != nil {
		log.Printf("auth.login.failed username=%q user_id=%d ip=%q reason=%q", user.Username, user.ID, ip, "session_issue_failed")
		writeErr(w, http.StatusInternalServerError, "failed to create session")
		return
	}

	log.Printf("auth.login.success user_id=%d username=%q ip=%q", user.ID, user.Username, ip)
	writeJSON(w, http.StatusOK, user)
}
// handleRefresh exchanges the "refresh_token" cookie for a new session.
// It answers 401 when the cookie is missing, empty, or rejected by
// consumeRefreshToken, 500 when issueUserSession fails, and 200 with
// {"status":"refreshed"} on success.
func (s *Server) handleRefresh(w http.ResponseWriter, r *http.Request) {
	cookie, cookieErr := r.Cookie("refresh_token")
	if cookieErr != nil || cookie.Value == "" {
		writeErr(w, http.StatusUnauthorized, "missing refresh token")
		return
	}

	userID, consumeErr := s.consumeRefreshToken(cookie.Value)
	if consumeErr != nil {
		writeErr(w, http.StatusUnauthorized, "invalid refresh token")
		return
	}

	if sessionErr := s.issueUserSession(w, r, userID); sessionErr != nil {
		writeErr(w, http.StatusInternalServerError, "failed to refresh session")
		return
	}

	body := map[string]string{"status": "refreshed"}
	writeJSON(w, http.StatusOK, body)
}
// handleLogout ends the caller's session. It always answers 200 with
// {"status":"logged_out"}, even when no valid session was presented.
//
// Steps:
//  1. if the "access_token" cookie holds a valid JWT, the logout is logged
//     with the user ID from its claims (the token is used for logging only);
//  2. if a "refresh_token" cookie is present, the matching row is marked
//     revoked; a database failure is logged rather than silently dropped,
//     because an unrevoked refresh token stays usable after the cookies are
//     cleared on this client;
//  3. both cookies are cleared.
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
	if cookie, err := r.Cookie("access_token"); err == nil && cookie.Value != "" {
		claims := &AccessClaims{}
		tkn, parseErr := jwt.ParseWithClaims(cookie.Value, claims, func(token *jwt.Token) (any, error) {
			return []byte(s.config.JWTSecret), nil
		})
		if parseErr == nil && tkn.Valid {
			log.Printf("auth.logout user_id=%d ip=%q", claims.UserID, clientIP(r))
		}
	}

	if rt, err := r.Cookie("refresh_token"); err == nil && rt.Value != "" {
		_, revokeErr := s.db.Exec(`UPDATE refresh_tokens SET revoked_at = CURRENT_TIMESTAMP WHERE token_hash = ?`, hashToken(rt.Value))
		if revokeErr != nil {
			// Logout still succeeds for the client; the failure must be
			// visible to operators.
			log.Printf("auth.logout.revoke_failed ip=%q err=%v", clientIP(r), revokeErr)
		}
	}

	clearCookie(w, "access_token", s.config.CookieSecure)
	clearCookie(w, "refresh_token", s.config.CookieSecure)
	writeJSON(w, http.StatusOK, map[string]string{"status": "logged_out"})
}
// adminLoginInput is the JSON request body decoded by handleAdminLogin.
type adminLoginInput struct {
Login string `json:"login"`
Password string `json:"password"`
}
// handleAdminLogin authenticates the administrator configured in
// s.config (AdminLogin / AdminPasswordHash) from a JSON {login, password}
// body and, on success, sets an HS256-signed "admin_token" cookie that
// expires after s.config.AdminSessionTTL.
//
// Responses: 400 for an undecodable body, 401 for wrong credentials, 500 when
// the token cannot be signed, 200 with {"login": ...} on success.
func (s *Server) handleAdminLogin(w http.ResponseWriter, r *http.Request) {
	var in adminLoginInput
	if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}
	login := strings.TrimSpace(in.Login)

	// Evaluate both checks unconditionally. With a short-circuiting "||" the
	// (deliberately slow) password hash verification only ran when the login
	// matched, so response time revealed whether the login name was correct.
	loginOK := subtleConstantTimeEq(login, s.config.AdminLogin) != 0
	passwordOK := verifyAdminPasswordHash(s.config.AdminPasswordHash, in.Password)
	if !loginOK || !passwordOK {
		log.Printf("auth.admin_login.failed login=%q ip=%q", login, clientIP(r))
		writeErr(w, http.StatusUnauthorized, "invalid admin credentials")
		return
	}

	now := time.Now()
	claims := AdminClaims{
		Login: s.config.AdminLogin,
		Role:  "admin",
		RegisteredClaims: jwt.RegisteredClaims{
			IssuedAt:  jwt.NewNumericDate(now),
			ExpiresAt: jwt.NewNumericDate(now.Add(s.config.AdminSessionTTL)),
			Subject:   s.config.AdminLogin,
		},
	}
	token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(s.config.JWTSecret))
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to issue admin session")
		return
	}

	// The cookie max-age mirrors the token lifetime.
	setCookie(w, "admin_token", token, int(s.config.AdminSessionTTL.Seconds()), s.config.CookieSecure)
	log.Printf("auth.admin_login.success login=%q ip=%q", s.config.AdminLogin, clientIP(r))
	writeJSON(w, http.StatusOK, map[string]string{"login": s.config.AdminLogin})
}
// handleAdminLogout clears the "admin_token" cookie and answers 200 with
// {"status":"logged_out"}. The request itself is not inspected.
func (s *Server) handleAdminLogout(w http.ResponseWriter, _ *http.Request) {
	log.Printf("auth.admin_logout")
	clearCookie(w, "admin_token", s.config.CookieSecure)

	reply := map[string]string{"status": "logged_out"}
	writeJSON(w, http.StatusOK, reply)
}
// handleMe returns the record of the user whose ID is stored in the request
// context, or 404 "user not found" when findUser fails.
func (s *Server) handleMe(w http.ResponseWriter, r *http.Request) {
	user, lookupErr := s.findUser(userIDFromContext(r.Context()))
	if lookupErr == nil {
		writeJSON(w, http.StatusOK, user)
		return
	}
	writeErr(w, http.StatusNotFound, "user not found")
}
// userPrefInput is the JSON request body decoded by handleSetPreferences.
type userPrefInput struct {
Theme string `json:"theme"`
ColorMode string `json:"colorMode"`
// ArchiveFmt is read from the "archiveFormat" key.
ArchiveFmt string `json:"archiveFormat"`
}
// handleSetPreferences stores the caller's theme, color mode and archive
// format. Each value is passed through its normalize* helper first; an
// archive format that normalizes to "" becomes "zip". The normalized values
// are written to the users row and echoed back in the 200 response.
// It answers 400 for an undecodable body and 500 when the UPDATE fails.
func (s *Server) handleSetPreferences(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())

	var prefs userPrefInput
	if decodeErr := json.NewDecoder(r.Body).Decode(&prefs); decodeErr != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}

	theme := normalizeTheme(prefs.Theme)
	mode := normalizeColorMode(prefs.ColorMode)
	archive := normalizeArchiveFormat(prefs.ArchiveFmt)
	if archive == "" {
		archive = "zip"
	}

	_, updateErr := s.db.Exec(`UPDATE users SET theme = ?, color_mode = ?, archive_format = ? WHERE id = ?`, theme, mode, archive, uid)
	if updateErr != nil {
		writeErr(w, http.StatusInternalServerError, "failed to update preferences")
		return
	}

	saved := map[string]string{
		"theme":         theme,
		"colorMode":     mode,
		"archiveFormat": archive,
	}
	writeJSON(w, http.StatusOK, saved)
}
// handleListFiles lists the directory named by the "path" query parameter for
// the calling user and attaches each entry's tags.
//
// A storage error is reported as 400 with the error text. Tag lookup is best
// effort: if fileTagsForPaths fails, the listing is still returned, just
// without tags.
func (s *Server) handleListFiles(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())
	rel := r.URL.Query().Get("path")
	log.Printf("file.list user_id=%d path=%q", uid, normalizePath(rel))

	entries, listErr := s.storage.List(uid, rel)
	if listErr != nil {
		writeErr(w, http.StatusBadRequest, listErr.Error())
		return
	}

	if len(entries) > 0 {
		// Tags are keyed by normalized path, so normalize once and reuse the
		// result for both the query and the assignment.
		normalized := make([]string, len(entries))
		for i := range entries {
			normalized[i] = normalizePath(entries[i].Path)
		}
		if tagsByPath, tagErr := s.fileTagsForPaths(uid, normalized); tagErr == nil {
			for i, p := range normalized {
				entries[i].Tags = tagsByPath[p]
			}
		}
	}

	writeJSON(w, http.StatusOK, map[string]any{"path": normalizePath(rel), "entries": entries})
}
// handleUpload stores every file from the multipart field "file" below the
// directory named by the "path" query parameter.
//
// Each part's filename is validated by normalizeUploadRelativePath and joined
// onto the normalized target directory. Files are saved in order; the first
// failure aborts with 400 and files already saved are left in place. On
// success it answers 201 with {"status":"uploaded"}.
//
// ParseMultipartForm is given 256 MiB as its in-memory limit.
func (s *Server) handleUpload(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())
	relDir := r.URL.Query().Get("path")

	if parseErr := r.ParseMultipartForm(256 << 20); parseErr != nil {
		writeErr(w, http.StatusBadRequest, "failed to parse upload")
		return
	}
	parts := r.MultipartForm.File["file"]
	if len(parts) == 0 {
		writeErr(w, http.StatusBadRequest, "no files provided")
		return
	}

	// saveOne stores a single part. The returned error's text is exactly the
	// message sent to the client.
	saveOne := func(fh *multipart.FileHeader) error {
		log.Printf("file.upload user_id=%d dir=%q name=%q size=%d", uid, normalizePath(relDir), fh.Filename, fh.Size)
		src, openErr := fh.Open()
		if openErr != nil {
			return fmt.Errorf("cannot open %s", fh.Filename)
		}
		defer src.Close()

		relName, nameErr := normalizeUploadRelativePath(fh.Filename)
		if nameErr != nil {
			return nameErr
		}
		return s.storage.Save(uid, path.Join(normalizePath(relDir), relName), src)
	}

	for _, fh := range parts {
		if saveErr := saveOne(fh); saveErr != nil {
			writeErr(w, http.StatusBadRequest, saveErr.Error())
			return
		}
	}
	writeJSON(w, http.StatusCreated, map[string]string{"status": "uploaded"})
}
// handleDownload sends the file named by the "path" query parameter as an
// attachment. For a directory, serveFile builds an archive; the optional
// "archive" query parameter selects its format. Any error from serveFile is
// reported as 400 with the error text.
func (s *Server) handleDownload(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())
	rel := r.URL.Query().Get("path")
	log.Printf("file.download user_id=%d path=%q", uid, normalizePath(rel))

	serveErr := s.serveFile(w, r, uid, rel, false, r.URL.Query().Get("archive"))
	if serveErr == nil {
		return
	}
	writeErr(w, http.StatusBadRequest, serveErr.Error())
}
// handleBatchDownload packs the paths from a JSON {paths, archive} body into
// one temporary archive and streams it as an attachment.
//
// Between 1 and 200 paths are accepted. The format comes from
// resolveArchiveFormat (explicit request value, else the user's preference).
// The temporary archive is removed when the handler returns.
//
// Responses: 400 for a bad body, a bad path count, or an archive build
// failure; 500 when the finished archive cannot be opened or stat'ed.
func (s *Server) handleBatchDownload(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())

	var req struct {
		Paths   []string `json:"paths"`
		Archive string   `json:"archive"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&req); decodeErr != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}
	switch count := len(req.Paths); {
	case count == 0:
		writeErr(w, http.StatusBadRequest, "no paths selected")
		return
	case count > 200:
		writeErr(w, http.StatusBadRequest, "too many selected paths")
		return
	}

	format, err := s.resolveArchiveFormat(uid, req.Archive)
	if err != nil {
		writeErr(w, http.StatusBadRequest, err.Error())
		return
	}
	archivePath, name, ctype, err := s.createBatchArchiveTemp(uid, req.Paths, format)
	if err != nil {
		writeErr(w, http.StatusBadRequest, err.Error())
		return
	}
	// Deferred calls run last-in-first-out: the file is closed before it is
	// removed.
	defer os.Remove(archivePath)

	archive, err := os.Open(archivePath)
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to open archive")
		return
	}
	defer archive.Close()

	info, err := archive.Stat()
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to stat archive")
		return
	}

	hdr := w.Header()
	hdr.Set("Accept-Ranges", "bytes")
	hdr.Set("Content-Type", ctype)
	hdr.Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name))
	http.ServeContent(w, r, name, info.ModTime(), archive)
}
// handleBatchMove moves every path from a JSON {paths, destination} body into
// the destination directory, which is created first via Storage.Mkdir.
//
// Per path, in order:
//   - "/" is skipped;
//   - a directory may not be moved into itself or one of its descendants;
//   - the target name comes from nextMoveTarget, applied to
//     destination/basename;
//   - the move is performed as copyPath followed by Storage.Delete, then the
//     path's tags are re-pointed with moveTags.
//
// The first failure aborts with 400; paths moved before it stay moved. On
// success it answers 200 with the number of moved paths.
func (s *Server) handleBatchMove(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())

	var req struct {
		Paths       []string `json:"paths"`
		Destination string   `json:"destination"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&req); decodeErr != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}

	// Every failure below is reported with the same status code.
	badRequest := func(msg string) {
		writeErr(w, http.StatusBadRequest, msg)
	}

	if len(req.Paths) == 0 {
		badRequest("no paths selected")
		return
	}
	destDir := normalizePath(req.Destination)
	if mkErr := s.storage.Mkdir(uid, destDir); mkErr != nil {
		badRequest(mkErr.Error())
		return
	}

	moved := 0
	for _, raw := range req.Paths {
		src := normalizePath(raw)
		if src == "/" {
			continue
		}

		meta, err := s.storage.Stat(uid, src)
		if err != nil {
			badRequest("invalid path")
			return
		}
		// The trailing slash keeps "/a" from matching a sibling like "/ab".
		if meta.IsDir && (destDir == src || strings.HasPrefix(destDir, strings.TrimSuffix(src, "/")+"/")) {
			badRequest("cannot move a folder into itself")
			return
		}

		dst, err := s.nextMoveTarget(uid, path.Join(destDir, path.Base(src)))
		if err != nil {
			badRequest(err.Error())
			return
		}
		if err = s.copyPath(uid, src, dst); err != nil {
			badRequest(err.Error())
			return
		}
		if err = s.storage.Delete(uid, src); err != nil {
			badRequest(err.Error())
			return
		}
		s.moveTags(uid, src, dst)
		moved++
	}

	writeJSON(w, http.StatusOK, map[string]any{"status": "moved", "count": moved})
}
// handlePreview sends the file named by the "path" query parameter with an
// inline Content-Disposition. Directories cannot be previewed. Any error from
// serveFile is reported as 400 with the error text.
func (s *Server) handlePreview(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())
	rel := r.URL.Query().Get("path")
	log.Printf("file.preview user_id=%d path=%q", uid, normalizePath(rel))

	serveErr := s.serveFile(w, r, uid, rel, true, "")
	if serveErr == nil {
		return
	}
	writeErr(w, http.StatusBadRequest, serveErr.Error())
}
// handleReadTextFile returns the content of a markdown file as JSON
// {path, content, size}.
//
// The "path" query parameter must name an existing non-directory file whose
// extension passes isMarkdownExtension, and the file may be at most 2 MiB.
// All validation failures answer 400; a read error answers 500.
func (s *Server) handleReadTextFile(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())
	rel := normalizePath(r.URL.Query().Get("path"))

	switch {
	case rel == "/":
		writeErr(w, http.StatusBadRequest, "path is required")
		return
	case !isMarkdownExtension(path.Ext(rel)):
		writeErr(w, http.StatusBadRequest, "only markdown files are editable")
		return
	}

	if meta, statErr := s.storage.Stat(uid, rel); statErr != nil || meta.IsDir {
		writeErr(w, http.StatusBadRequest, "file not found")
		return
	}
	src, openErr := s.storage.OpenReadSeeker(uid, rel)
	if openErr != nil {
		writeErr(w, http.StatusBadRequest, "file not found")
		return
	}
	defer src.Close()

	// Read one byte past the limit so an oversized file can be told apart
	// from one that is exactly at the limit.
	const limit = 2 << 20
	content, readErr := io.ReadAll(io.LimitReader(src, limit+1))
	if readErr != nil {
		writeErr(w, http.StatusInternalServerError, "failed to read file")
		return
	}
	if len(content) > limit {
		writeErr(w, http.StatusBadRequest, "markdown file is too large (max 2MB)")
		return
	}

	writeJSON(w, http.StatusOK, map[string]any{"path": rel, "content": string(content), "size": len(content)})
}
// handleWriteTextFile saves markdown content from a JSON {path, content}
// body via Storage.SaveBytes.
//
// The path must not be "/", its extension must pass isMarkdownExtension, and
// the content may be at most 2 MiB. Every failure, including a storage error,
// answers 400; success answers 200 with {path, status:"saved"}.
func (s *Server) handleWriteTextFile(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())

	var req struct {
		Path    string `json:"path"`
		Content string `json:"content"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&req); decodeErr != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}

	rel := normalizePath(req.Path)
	switch {
	case rel == "/":
		writeErr(w, http.StatusBadRequest, "path is required")
		return
	case !isMarkdownExtension(path.Ext(rel)):
		writeErr(w, http.StatusBadRequest, "only markdown files are editable")
		return
	case len(req.Content) > 2<<20:
		writeErr(w, http.StatusBadRequest, "markdown file is too large (max 2MB)")
		return
	}

	if saveErr := s.storage.SaveBytes(uid, rel, []byte(req.Content)); saveErr != nil {
		writeErr(w, http.StatusBadRequest, saveErr.Error())
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{"path": rel, "status": "saved"})
}
// handleListFileTags returns the tags stored for the normalized "path" query
// parameter of the calling user, sorted ascending, as JSON {path, tags}.
//
// Any database failure — running the query, scanning a row, or an error that
// ends the iteration early — answers 500 instead of returning a silently
// truncated tag list with status 200.
func (s *Server) handleListFileTags(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())
	rel := normalizePath(r.URL.Query().Get("path"))

	rows, err := s.db.Query(`SELECT tag FROM file_tags WHERE user_id = ? AND rel_path = ? ORDER BY tag ASC`, uid, rel)
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to load tags")
		return
	}
	defer rows.Close()

	// Non-nil empty slice so a path without tags encodes as [] rather than null.
	tags := make([]string, 0)
	for rows.Next() {
		var tag string
		if err := rows.Scan(&tag); err != nil {
			writeErr(w, http.StatusInternalServerError, "failed to load tags")
			return
		}
		tags = append(tags, tag)
	}
	// rows.Next returns false both at the end of the result set and on error;
	// rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to load tags")
		return
	}

	writeJSON(w, http.StatusOK, map[string]any{"path": rel, "tags": tags})
}
// handleAddFileTag attaches a tag to an existing path from a JSON {path, tag}
// body. The tag is validated and canonicalized by normalizeTag, and the path
// must exist in storage. The insert uses INSERT OR IGNORE, so adding a tag
// that is already present still answers 201.
//
// Responses: 400 for a bad body, an invalid tag, or a missing path; 500 when
// the insert fails.
func (s *Server) handleAddFileTag(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())

	var req struct {
		Path string `json:"path"`
		Tag  string `json:"tag"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&req); decodeErr != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}

	rel := normalizePath(req.Path)
	tag, valid := normalizeTag(req.Tag)
	if !valid {
		writeErr(w, http.StatusBadRequest, "tag must be 1-24 chars: a-z 0-9 dash underscore")
		return
	}
	if _, statErr := s.storage.Stat(uid, rel); statErr != nil {
		writeErr(w, http.StatusBadRequest, "file not found")
		return
	}

	_, insertErr := s.db.Exec(`INSERT OR IGNORE INTO file_tags(user_id, rel_path, tag) VALUES (?, ?, ?)`, uid, rel, tag)
	if insertErr != nil {
		writeErr(w, http.StatusInternalServerError, "failed to save tag")
		return
	}
	writeJSON(w, http.StatusCreated, map[string]any{"path": rel, "tag": tag})
}
// handleDeleteFileTag removes one tag from one path; both come from the query
// string ("path" and "tag"). The path is not checked against storage, and
// deleting a tag that does not exist still answers 200.
//
// Responses: 400 for an invalid tag, 500 when the DELETE fails.
func (s *Server) handleDeleteFileTag(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())
	query := r.URL.Query()

	rel := normalizePath(query.Get("path"))
	tag, valid := normalizeTag(query.Get("tag"))
	if !valid {
		writeErr(w, http.StatusBadRequest, "invalid tag")
		return
	}

	_, deleteErr := s.db.Exec(`DELETE FROM file_tags WHERE user_id = ? AND rel_path = ? AND tag = ?`, uid, rel, tag)
	if deleteErr != nil {
		writeErr(w, http.StatusInternalServerError, "failed to delete tag")
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{"path": rel, "tag": tag, "status": "deleted"})
}
// serveFile writes the file or directory at rel (owned by uid) to w.
//
// Regular files are streamed with http.ServeContent, which handles Range and
// conditional requests. The Content-Type is derived from the file extension
// and falls back to application/octet-stream. inline selects an "inline"
// Content-Disposition (preview) instead of "attachment" (download).
//
// Directories cannot be previewed; for downloads they are packed into an
// archive whose format is resolved from archiveQuery or the user's stored
// preference.
//
// A non-nil error means nothing has been written to w yet, so the caller can
// still send an error response.
func (s *Server) serveFile(w http.ResponseWriter, r *http.Request, uid int64, rel string, inline bool, archiveQuery string) error {
	meta, err := s.storage.Stat(uid, rel)
	if err != nil {
		return err
	}
	if meta.IsDir {
		if inline {
			return fmt.Errorf("cannot preview a directory")
		}
		format, err := s.resolveArchiveFormat(uid, archiveQuery)
		if err != nil {
			return err
		}
		return s.serveDirectoryArchive(w, r, uid, rel, format)
	}

	rc, err := s.storage.OpenReadSeeker(uid, rel)
	if err != nil {
		return err
	}
	defer rc.Close()

	name := path.Base(normalizePath(rel))
	if name == "." || name == "/" || name == "" {
		name = "download"
	}
	ctype := mime.TypeByExtension(strings.ToLower(filepath.Ext(name)))
	if ctype == "" {
		ctype = "application/octet-stream"
	}
	dispositionType := "attachment"
	if inline {
		dispositionType = "inline"
	}

	// mime.FormatMediaType quotes and escapes the filename per the MIME rules
	// and switches to the RFC 2231 "filename*=utf-8''..." form for non-ASCII
	// names. Go's %q verb is not an HTTP quoting scheme: it emits raw UTF-8
	// bytes and Go-style escapes. FormatMediaType returns "" when it cannot
	// encode its input; the previous %q form is kept as the fallback then.
	disposition := mime.FormatMediaType(dispositionType, map[string]string{"filename": name})
	if disposition == "" {
		disposition = fmt.Sprintf("%s; filename=%q", dispositionType, name)
	}

	w.Header().Set("Accept-Ranges", "bytes")
	w.Header().Set("Content-Type", ctype)
	w.Header().Set("Content-Disposition", disposition)
	http.ServeContent(w, r, name, meta.ModTime, rc)
	return nil
}
// resolveArchiveFormat picks the archive format for a download. Precedence:
//
//  1. archiveQuery, when normalizeArchiveFormat accepts it;
//  2. the archive_format stored on the user's row, when it normalizes to a
//     non-empty value;
//  3. "zip".
//
// A failed preference lookup also falls back to "zip", so the returned error
// is currently always nil.
func (s *Server) resolveArchiveFormat(uid int64, archiveQuery string) (string, error) {
	if explicit := normalizeArchiveFormat(archiveQuery); explicit != "" {
		return explicit, nil
	}

	var stored string
	row := s.db.QueryRow(`SELECT archive_format FROM users WHERE id = ?`, uid)
	if scanErr := row.Scan(&stored); scanErr != nil {
		return "zip", nil
	}
	if preferred := normalizeArchiveFormat(stored); preferred != "" {
		return preferred, nil
	}
	return "zip", nil
}
// serveDirectoryArchive packs the directory at rel into a temporary archive
// of the given format and streams it to w as an attachment. The archive is
// named after the directory ("folder" for the root) and the temporary file is
// removed when the function returns.
//
// A non-nil error means nothing has been written to w yet.
func (s *Server) serveDirectoryArchive(w http.ResponseWriter, r *http.Request, uid int64, rel, format string) error {
	base := path.Base(normalizePath(rel))
	if base == "/" || base == "." || base == "" {
		base = "folder"
	}
	archivePath, downloadName, ctype, err := s.createArchiveTemp(uid, rel, base, format)
	if err != nil {
		return err
	}
	// Deferred calls run last-in-first-out: the file is closed before it is
	// removed.
	defer os.Remove(archivePath)

	f, err := os.Open(archivePath)
	if err != nil {
		return err
	}
	defer f.Close()
	st, err := f.Stat()
	if err != nil {
		return err
	}

	// The download name is derived from a user-chosen folder name, so it may
	// contain non-ASCII or special characters. mime.FormatMediaType encodes
	// it per the MIME rules (RFC 2231 for non-ASCII); Go's %q is not an HTTP
	// quoting scheme. FormatMediaType returns "" when it cannot encode its
	// input; the previous %q form is kept as the fallback then.
	disposition := mime.FormatMediaType("attachment", map[string]string{"filename": downloadName})
	if disposition == "" {
		disposition = fmt.Sprintf("attachment; filename=%q", downloadName)
	}

	w.Header().Set("Accept-Ranges", "bytes")
	w.Header().Set("Content-Type", ctype)
	w.Header().Set("Content-Disposition", disposition)
	http.ServeContent(w, r, downloadName, st.ModTime(), f)
	return nil
}
// createArchiveTemp builds a temporary archive of rel in the requested format
// and returns (archive path on disk, download file name, content type, error).
// "rar", "tar.gz" and "lz4" select their respective builders; every other
// value, including "", produces a zip.
func (s *Server) createArchiveTemp(uid int64, rel, base, format string) (string, string, string, error) {
	build := s.createZipTemp
	switch format {
	case "rar":
		build = s.createRarTemp
	case "tar.gz":
		build = s.createTarGzTemp
	case "lz4":
		build = s.createTarLz4Temp
	}
	return build(uid, rel, base)
}
// createBatchArchiveTemp copies the selected paths into a scratch directory
// and packs that directory into one archive named "files".
//
// Each path is normalized; "/" is skipped. Items are placed side by side
// under "<scratch>/selection", with uniqueArchiveName resolving clashes
// between equal base names. The scratch directory is removed when the
// function returns; the archive itself is created by the create*FromLocalDir
// helpers and is the caller's to delete.
//
// Returns (archive path on disk, download file name, content type, error).
func (s *Server) createBatchArchiveTemp(uid int64, paths []string, format string) (string, string, string, error) {
	scratch, err := os.MkdirTemp("", "filez-batch-*")
	if err != nil {
		return "", "", "", err
	}
	defer os.RemoveAll(scratch)

	root := filepath.Join(scratch, "selection")
	if err := os.MkdirAll(root, 0o755); err != nil {
		return "", "", "", err
	}

	taken := map[string]int{}
	for _, raw := range paths {
		rel := normalizePath(raw)
		if rel == "/" {
			continue
		}
		meta, statErr := s.storage.Stat(uid, rel)
		if statErr != nil {
			return "", "", "", fmt.Errorf("invalid path: %s", rel)
		}

		itemName := path.Base(rel)
		switch itemName {
		case ".", "/", "":
			itemName = "item"
		}
		target := filepath.Join(root, uniqueArchiveName(itemName, taken))

		// Directories are created up front; materializePath then fills them.
		if meta.IsDir {
			if mkErr := os.MkdirAll(target, 0o755); mkErr != nil {
				return "", "", "", mkErr
			}
		}
		if copyErr := s.materializePath(uid, rel, target); copyErr != nil {
			return "", "", "", copyErr
		}
	}

	const archiveBase = "files"
	switch format {
	case "rar":
		return createRarFromLocalDir(root, archiveBase)
	case "lz4":
		return createTarLz4FromLocalDir(root, archiveBase)
	case "tar.gz":
		return createTarGzFromLocalDir(root, archiveBase)
	}
	return createZipFromLocalDir(root, archiveBase)
}
// createZipTemp writes the file or directory at rel into a new temporary zip
// archive, using base as the top-level name inside the archive.
//
// It returns (archive path on disk, download file name, content type, error).
// On success the caller owns the temporary file and must remove it. On any
// failure the partially written file is closed and removed here, so failed
// downloads do not accumulate files in the temp directory.
func (s *Server) createZipTemp(uid int64, rel, base string) (string, string, string, error) {
	tmp, err := os.CreateTemp("", "filez-*.zip")
	if err != nil {
		return "", "", "", err
	}
	archivePath := tmp.Name()

	// discard releases the partial archive and reports cause. Cleanup errors
	// are ignored on purpose: cause is the error worth surfacing.
	discard := func(cause error) (string, string, string, error) {
		_ = tmp.Close()
		_ = os.Remove(archivePath)
		return "", "", "", cause
	}

	zw := zip.NewWriter(tmp)
	if err := s.addPathToZip(zw, uid, rel, base); err != nil {
		_ = zw.Close()
		return discard(err)
	}
	// Close writes the central directory; the archive is invalid without it.
	if err := zw.Close(); err != nil {
		return discard(err)
	}
	// A failed file Close can mean the data never reached the disk.
	if err := tmp.Close(); err != nil {
		_ = os.Remove(archivePath)
		return "", "", "", err
	}
	return archivePath, base + ".zip", "application/zip", nil
}
// addPathToZip adds the file or directory at rel to zw under the archive name
// zipBase, recursing into directories.
//
// A regular file becomes a deflate-compressed entry. A directory becomes an
// explicit "name/" entry (skipped when zipBase is empty) followed by its
// children, in the order Storage.List returns them. Entry names never start
// with "/", and modification times are taken from storage metadata.
func (s *Server) addPathToZip(zw *zip.Writer, uid int64, rel string, zipBase string) error {
	info, statErr := s.storage.Stat(uid, rel)
	if statErr != nil {
		return statErr
	}
	entryName := strings.TrimPrefix(zipBase, "/")

	if !info.IsDir {
		src, openErr := s.storage.OpenReadSeeker(uid, rel)
		if openErr != nil {
			return openErr
		}
		defer src.Close()

		fileHdr := &zip.FileHeader{Name: entryName, Method: zip.Deflate, Modified: info.ModTime}
		dst, hdrErr := zw.CreateHeader(fileHdr)
		if hdrErr != nil {
			return hdrErr
		}
		_, copyErr := io.Copy(dst, src)
		return copyErr
	}

	if zipBase != "" {
		// The trailing slash marks the entry as a directory.
		dirHdr := &zip.FileHeader{Name: entryName + "/", Method: zip.Store, Modified: info.ModTime}
		if _, hdrErr := zw.CreateHeader(dirHdr); hdrErr != nil {
			return hdrErr
		}
	}

	children, listErr := s.storage.List(uid, rel)
	if listErr != nil {
		return listErr
	}
	parent := normalizePath(rel)
	for i := range children {
		child := children[i].Name
		if err := s.addPathToZip(zw, uid, path.Join(parent, child), path.Join(zipBase, child)); err != nil {
			return err
		}
	}
	return nil
}
// createRarTemp builds a RAR archive of rel with the external "rar" binary.
//
// The content is first copied into "<staging>/<base>" by materializePath,
// then "rar a -idq <archive> <staging>/<base>" is run. The staging directory
// is removed on every return path. On success the caller owns the archive
// file; on a rar failure the archive file is removed and the tool's combined
// output becomes the error text.
//
// Returns (archive path on disk, download file name, content type, error).
func (s *Server) createRarTemp(uid int64, rel, base string) (string, string, string, error) {
	if _, lookErr := exec.LookPath("rar"); lookErr != nil {
		return "", "", "", fmt.Errorf("rar format requires 'rar' binary on server; choose zip or tar.gz in settings")
	}

	staging, err := os.MkdirTemp("", "filez-rar-*")
	if err != nil {
		return "", "", "", err
	}
	// One deferred cleanup covers the success path and every failure below.
	defer os.RemoveAll(staging)

	root := filepath.Join(staging, base)
	if err := os.MkdirAll(root, 0o755); err != nil {
		return "", "", "", err
	}
	if err := s.materializePath(uid, rel, root); err != nil {
		return "", "", "", err
	}

	// CreateTemp is only used to reserve a unique name; rar writes to it.
	placeholder, err := os.CreateTemp("", "filez-*.rar")
	if err != nil {
		return "", "", "", err
	}
	archivePath := placeholder.Name()
	placeholder.Close()

	output, runErr := exec.Command("rar", "a", "-idq", archivePath, root).CombinedOutput()
	if runErr != nil {
		os.Remove(archivePath)
		return "", "", "", fmt.Errorf("rar creation failed: %s", strings.TrimSpace(string(output)))
	}
	return archivePath, base + ".rar", "application/vnd.rar", nil
}
// createTarGzTemp writes the file or directory at rel into a new temporary
// gzip-compressed tar archive, using base as the top-level name inside it.
//
// It returns (archive path on disk, download file name, content type, error).
// On success the caller owns the temporary file and must remove it. On any
// failure the partially written file is closed and removed here, so failed
// downloads do not accumulate files in the temp directory.
func (s *Server) createTarGzTemp(uid int64, rel, base string) (string, string, string, error) {
	tmp, err := os.CreateTemp("", "filez-*.tar.gz")
	if err != nil {
		return "", "", "", err
	}
	archivePath := tmp.Name()

	// discard releases the partial archive and reports cause. Cleanup errors
	// are ignored on purpose: cause is the error worth surfacing.
	discard := func(cause error) (string, string, string, error) {
		_ = tmp.Close()
		_ = os.Remove(archivePath)
		return "", "", "", cause
	}

	gzw := gzip.NewWriter(tmp)
	tw := tar.NewWriter(gzw)
	if err := s.addPathToTar(tw, uid, rel, base); err != nil {
		_ = tw.Close()
		_ = gzw.Close()
		return discard(err)
	}
	// Writers are closed inner to outer: tar writes its trailer into gzip,
	// then gzip flushes into the file.
	if err := tw.Close(); err != nil {
		_ = gzw.Close()
		return discard(err)
	}
	if err := gzw.Close(); err != nil {
		return discard(err)
	}
	// A failed file Close can mean the data never reached the disk.
	if err := tmp.Close(); err != nil {
		_ = os.Remove(archivePath)
		return "", "", "", err
	}
	return archivePath, base + ".tar.gz", "application/gzip", nil
}
// createTarLz4Temp builds an uncompressed tar of rel (owned by uid) and
// compresses it with the external "lz4" binary.
//
// Before doing any work it sweeps temp archives older than one hour. The
// intermediate tar is removed when the function returns. It returns the
// compressed archive path, the download file name (base + ".tar.lz4") and the
// MIME type.
func (s *Server) createTarLz4Temp(uid int64, rel, base string) (string, string, string, error) {
	_, lookErr := exec.LookPath("lz4")
	if lookErr != nil {
		return "", "", "", fmt.Errorf("lz4 format requires 'lz4' binary on server; choose zip or tar.gz in settings")
	}
	cleanupStaleTempArchives(time.Hour)

	plainTar, err := s.createTarTemp(uid, rel, base)
	if err != nil {
		return "", "", "", err
	}
	defer os.Remove(plainTar)

	placeholder, err := os.CreateTemp("", "filez-*.tar.lz4")
	if err != nil {
		return "", "", "", err
	}
	dest := placeholder.Name()
	placeholder.Close()
	// Only the unique name is wanted; the empty file is deleted before lz4
	// is asked to write to that path.
	_ = os.Remove(dest)

	output, runErr := exec.Command("lz4", "-z", "-q", plainTar, dest).CombinedOutput()
	if runErr != nil {
		os.Remove(dest)
		return "", "", "", fmt.Errorf("lz4 creation failed: %s", strings.TrimSpace(string(output)))
	}
	return dest, base + ".tar.lz4", "application/x-lz4", nil
}
// createTarTemp writes the stored file or folder at rel (owned by uid) into an
// uncompressed tar archive in the temp directory, with entries rooted at base,
// and returns the archive path.
//
// On any failure the partially written temp file is removed.
func (s *Server) createTarTemp(uid int64, rel, base string) (string, error) {
	tmp, err := os.CreateTemp("", "filez-*.tar")
	if err != nil {
		return "", err
	}
	tw := tar.NewWriter(tmp)

	err = s.addPathToTar(tw, uid, rel, base)
	// Always close both writers; report the first error encountered.
	if cerr := tw.Close(); err == nil {
		err = cerr
	}
	if cerr := tmp.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		os.Remove(tmp.Name())
		return "", err
	}
	return tmp.Name(), nil
}
// addPathToTar appends the stored item at rel (owned by uid) to tw under the
// archive name tarBase, recursing into directories.
//
// Files are written as regular entries with mode 0644 and the size reported by
// storage. Directories get a mode 0755 entry with a trailing slash (skipped
// when tarBase is empty) followed by each child.
func (s *Server) addPathToTar(tw *tar.Writer, uid int64, rel string, tarBase string) error {
	meta, err := s.storage.Stat(uid, rel)
	if err != nil {
		return err
	}
	entryName := strings.TrimPrefix(tarBase, "/")

	if !meta.IsDir {
		src, err := s.storage.OpenReadSeeker(uid, rel)
		if err != nil {
			return err
		}
		defer src.Close()
		fileHdr := &tar.Header{
			Name:     entryName,
			Mode:     0o644,
			Size:     meta.Size,
			ModTime:  meta.ModTime,
			Typeflag: tar.TypeReg,
		}
		if err := tw.WriteHeader(fileHdr); err != nil {
			return err
		}
		_, err = io.Copy(tw, src)
		return err
	}

	if tarBase != "" {
		dirHdr := &tar.Header{
			Name:     entryName + "/",
			Mode:     0o755,
			ModTime:  meta.ModTime,
			Typeflag: tar.TypeDir,
		}
		if err := tw.WriteHeader(dirHdr); err != nil {
			return err
		}
	}
	children, err := s.storage.List(uid, rel)
	if err != nil {
		return err
	}
	parent := normalizePath(rel)
	for _, child := range children {
		if err := s.addPathToTar(tw, uid, path.Join(parent, child.Name), path.Join(tarBase, child.Name)); err != nil {
			return err
		}
	}
	return nil
}
// materializePath copies the stored item at rel (owned by uid) onto the local
// filesystem at localPath.
//
// For a directory, localPath must already exist; each child directory is
// created with mode 0755 and each child is copied recursively. For a file,
// localPath is created (or truncated) and filled with the stored content. The
// error from closing the written file is returned, so a write that only fails
// at close time is not reported as success.
func (s *Server) materializePath(uid int64, rel string, localPath string) error {
	meta, err := s.storage.Stat(uid, rel)
	if err != nil {
		return err
	}
	if meta.IsDir {
		entries, err := s.storage.List(uid, rel)
		if err != nil {
			return err
		}
		for _, e := range entries {
			childRel := path.Join(normalizePath(rel), e.Name)
			childLocal := filepath.Join(localPath, e.Name)
			if e.IsDir {
				if err := os.MkdirAll(childLocal, 0o755); err != nil {
					return err
				}
			}
			if err := s.materializePath(uid, childRel, childLocal); err != nil {
				return err
			}
		}
		return nil
	}
	rc, err := s.storage.OpenReadSeeker(uid, rel)
	if err != nil {
		return err
	}
	defer rc.Close()
	out, err := os.Create(localPath)
	if err != nil {
		return err
	}
	if _, err := io.Copy(out, rc); err != nil {
		out.Close()
		return err
	}
	return out.Close()
}
// copyPath duplicates the stored item at src to dst within uid's storage.
//
// A file is read fully into memory and written with SaveBytes. A directory is
// created at dst and each of its children is copied recursively.
func (s *Server) copyPath(uid int64, src, dst string) error {
	info, err := s.storage.Stat(uid, src)
	if err != nil {
		return err
	}

	if !info.IsDir {
		reader, err := s.storage.OpenReadSeeker(uid, src)
		if err != nil {
			return err
		}
		defer reader.Close()
		content, err := io.ReadAll(reader)
		if err != nil {
			return err
		}
		return s.storage.SaveBytes(uid, dst, content)
	}

	if err := s.storage.Mkdir(uid, dst); err != nil {
		return err
	}
	children, err := s.storage.List(uid, src)
	if err != nil {
		return err
	}
	for _, child := range children {
		if err := s.copyPath(uid, path.Join(src, child.Name), path.Join(dst, child.Name)); err != nil {
			return err
		}
	}
	return nil
}
// nextMoveTarget picks a destination path that storage does not currently
// report as existing.
//
// It returns the normalized dst when Stat on it fails; otherwise it tries
// "<name>-2<ext>" through "<name>-9999<ext>" in the same directory and returns
// the first candidate for which Stat fails. Any Stat error is treated as
// "free". An error is returned only when every candidate is taken.
func (s *Server) nextMoveTarget(uid int64, dst string) (string, error) {
	target := normalizePath(dst)
	if _, statErr := s.storage.Stat(uid, target); statErr != nil {
		return target, nil
	}

	suffix := path.Ext(target)
	stem := strings.TrimSuffix(path.Base(target), suffix)
	parent := path.Dir(target)

	for n := 2; n < 10000; n++ {
		alt := path.Join(parent, fmt.Sprintf("%s-%d%s", stem, n, suffix))
		_, statErr := s.storage.Stat(uid, alt)
		if statErr != nil {
			return alt, nil
		}
	}
	return "", fmt.Errorf("cannot allocate destination name")
}
// moveTags re-keys file tags after src has been moved or renamed to dst.
//
// The tag rows for src itself are updated in place; rows for paths underneath
// src are re-created under dst with the same relative suffix. All database
// errors are ignored: tag migration is best-effort and never fails the move.
func (s *Server) moveTags(uid int64, src, dst string) {
	_, _ = s.db.Exec(`UPDATE file_tags SET rel_path = ? WHERE user_id = ? AND rel_path = ?`, dst, uid, src)
	prefixFrom := strings.TrimSuffix(src, "/") + "/"
	prefixTo := strings.TrimSuffix(dst, "/") + "/"
	rows, err := s.db.Query(`SELECT rel_path, tag FROM file_tags WHERE user_id = ? AND rel_path LIKE ?`, uid, prefixFrom+"%")
	if err != nil {
		return
	}
	defer rows.Close()
	type rowItem struct{ p, t string }
	var items []rowItem
	for rows.Next() {
		var rp, tg string
		if rows.Scan(&rp, &tg) != nil {
			continue
		}
		// LIKE interprets '%' and '_' inside the path as wildcards (and may
		// compare case-insensitively), so it can return rows that are not
		// actually below src. Re-keying those would corrupt unrelated tags.
		if !strings.HasPrefix(rp, prefixFrom) {
			continue
		}
		items = append(items, rowItem{p: rp, t: tg})
	}
	for _, it := range items {
		next := prefixTo + strings.TrimPrefix(it.p, prefixFrom)
		_, _ = s.db.Exec(`DELETE FROM file_tags WHERE user_id = ? AND rel_path = ? AND tag = ?`, uid, it.p, it.t)
		_, _ = s.db.Exec(`INSERT OR IGNORE INTO file_tags(user_id, rel_path, tag) VALUES (?, ?, ?)`, uid, next, it.t)
	}
}
// uniqueArchiveName returns a name for base that has not yet been handed out
// through used, and records it there.
//
// base is trimmed; an empty base becomes "item". The first request for a name
// returns it unchanged. Later requests return "<stem>-<n><ext>" with n starting
// at 2. Generated names are reserved in used as well, so a suffixed name can
// never collide with a real entry that happens to be called the same (for
// example inputs "a.txt", "a.txt", "a-2.txt").
func uniqueArchiveName(base string, used map[string]int) string {
	base = strings.TrimSpace(base)
	if base == "" {
		base = "item"
	}
	count := used[base]
	if count == 0 {
		used[base] = 1
		return base
	}
	ext := path.Ext(base)
	stem := strings.TrimSuffix(base, ext)
	for {
		count++
		candidate := fmt.Sprintf("%s-%d%s", stem, count, ext)
		if used[candidate] != 0 {
			// Already taken by an earlier entry; try the next number.
			continue
		}
		used[base] = count
		used[candidate] = 1
		return candidate
	}
}
// createZipFromLocalDir zips everything below the local directory root into a
// temp file. Entry names are slash-separated paths relative to root; root
// itself gets no entry. Directories are stored as explicit "name/" entries and
// files are deflated.
//
// It returns the archive path, the download file name (baseName + ".zip") and
// the MIME type. On any failure the partially written temp file is removed.
func createZipFromLocalDir(root, baseName string) (string, string, string, error) {
	tmp, err := os.CreateTemp("", "filez-batch-*.zip")
	if err != nil {
		return "", "", "", err
	}
	zw := zip.NewWriter(tmp)
	err = filepath.WalkDir(root, func(p string, d fs.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		rel, err := filepath.Rel(root, p)
		if err != nil {
			return err
		}
		if rel == "." {
			return nil
		}
		rel = filepath.ToSlash(rel)
		info, err := d.Info()
		if err != nil {
			return err
		}
		if d.IsDir() {
			hdr := &zip.FileHeader{Name: rel + "/", Method: zip.Store}
			hdr.Modified = info.ModTime()
			_, err = zw.CreateHeader(hdr)
			return err
		}
		hdr, err := zip.FileInfoHeader(info)
		if err != nil {
			return err
		}
		hdr.Name = rel
		hdr.Method = zip.Deflate
		w, err := zw.CreateHeader(hdr)
		if err != nil {
			return err
		}
		f, err := os.Open(p)
		if err != nil {
			return err
		}
		// Runs when this callback returns, i.e. once per visited file.
		defer f.Close()
		_, err = io.Copy(w, f)
		return err
	})
	// Always close the writer and the file; keep the first error seen.
	if cerr := zw.Close(); err == nil {
		err = cerr
	}
	if cerr := tmp.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		os.Remove(tmp.Name())
		return "", "", "", err
	}
	return tmp.Name(), baseName + ".zip", "application/zip", nil
}
// createTarGzFromLocalDir packs everything below the local directory root into
// a gzip-compressed tar temp file. Entry names are slash-separated paths
// relative to root (directories carry a trailing slash); root itself gets no
// entry.
//
// It returns the archive path, the download file name (baseName + ".tar.gz")
// and the MIME type. On any failure the partially written temp file is removed.
func createTarGzFromLocalDir(root, baseName string) (string, string, string, error) {
	tmp, err := os.CreateTemp("", "filez-batch-*.tar.gz")
	if err != nil {
		return "", "", "", err
	}
	gzw := gzip.NewWriter(tmp)
	tw := tar.NewWriter(gzw)
	err = filepath.WalkDir(root, func(p string, d fs.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		rel, err := filepath.Rel(root, p)
		if err != nil {
			return err
		}
		if rel == "." {
			return nil
		}
		rel = filepath.ToSlash(rel)
		info, err := d.Info()
		if err != nil {
			return err
		}
		hdr, err := tar.FileInfoHeader(info, "")
		if err != nil {
			return err
		}
		if d.IsDir() {
			hdr.Name = rel + "/"
		} else {
			hdr.Name = rel
		}
		if err := tw.WriteHeader(hdr); err != nil {
			return err
		}
		if d.IsDir() {
			return nil
		}
		f, err := os.Open(p)
		if err != nil {
			return err
		}
		// Runs when this callback returns, i.e. once per visited file.
		defer f.Close()
		_, err = io.Copy(tw, f)
		return err
	})
	// Close innermost first (tar, gzip, file) and keep the first error seen.
	if cerr := tw.Close(); err == nil {
		err = cerr
	}
	if cerr := gzw.Close(); err == nil {
		err = cerr
	}
	if cerr := tmp.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		os.Remove(tmp.Name())
		return "", "", "", err
	}
	return tmp.Name(), baseName + ".tar.gz", "application/gzip", nil
}
// createTarFromLocalDir packs everything below the local directory root into an
// uncompressed tar temp file and returns its path. Entry names are
// slash-separated paths relative to root (directories carry a trailing slash);
// root itself gets no entry.
//
// On any failure the partially written temp file is removed.
func createTarFromLocalDir(root string) (string, error) {
	tmp, err := os.CreateTemp("", "filez-batch-*.tar")
	if err != nil {
		return "", err
	}
	tw := tar.NewWriter(tmp)
	err = filepath.WalkDir(root, func(p string, d fs.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		rel, err := filepath.Rel(root, p)
		if err != nil {
			return err
		}
		if rel == "." {
			return nil
		}
		rel = filepath.ToSlash(rel)
		info, err := d.Info()
		if err != nil {
			return err
		}
		hdr, err := tar.FileInfoHeader(info, "")
		if err != nil {
			return err
		}
		if d.IsDir() {
			hdr.Name = rel + "/"
		} else {
			hdr.Name = rel
		}
		if err := tw.WriteHeader(hdr); err != nil {
			return err
		}
		if d.IsDir() {
			return nil
		}
		f, err := os.Open(p)
		if err != nil {
			return err
		}
		// Runs when this callback returns, i.e. once per visited file.
		defer f.Close()
		_, err = io.Copy(tw, f)
		return err
	})
	// Always close both writers; keep the first error seen.
	if cerr := tw.Close(); err == nil {
		err = cerr
	}
	if cerr := tmp.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		os.Remove(tmp.Name())
		return "", err
	}
	return tmp.Name(), nil
}
// createTarLz4FromLocalDir tars the local directory root and compresses the
// result with the external "lz4" binary.
//
// Temp archives older than one hour are swept first. The intermediate tar is
// removed when the function returns. It returns the compressed archive path,
// the download file name (baseName + ".tar.lz4") and the MIME type.
func createTarLz4FromLocalDir(root, baseName string) (string, string, string, error) {
	_, lookErr := exec.LookPath("lz4")
	if lookErr != nil {
		return "", "", "", fmt.Errorf("lz4 format requires 'lz4' binary on server; choose zip or tar.gz in settings")
	}
	cleanupStaleTempArchives(time.Hour)

	plainTar, err := createTarFromLocalDir(root)
	if err != nil {
		return "", "", "", err
	}
	defer os.Remove(plainTar)

	placeholder, err := os.CreateTemp("", "filez-batch-*.tar.lz4")
	if err != nil {
		return "", "", "", err
	}
	dest := placeholder.Name()
	placeholder.Close()
	// Only the unique name is wanted; the empty file is deleted before lz4
	// is asked to write to that path.
	_ = os.Remove(dest)

	output, runErr := exec.Command("lz4", "-z", "-q", plainTar, dest).CombinedOutput()
	if runErr != nil {
		os.Remove(dest)
		return "", "", "", fmt.Errorf("lz4 creation failed: %s", strings.TrimSpace(string(output)))
	}
	return dest, baseName + ".tar.lz4", "application/x-lz4", nil
}
// cleanupStaleTempArchives deletes leftover archive files from the system temp
// directory.
//
// A file is removed when its name starts with "filez-", ends in one of the
// archive suffixes this server produces, is not a directory, and was last
// modified at or before now minus maxIdle. Errors are ignored; the sweep is
// best-effort.
func cleanupStaleTempArchives(maxIdle time.Duration) {
	dir := os.TempDir()
	listing, err := os.ReadDir(dir)
	if err != nil {
		return
	}
	deadline := time.Now().Add(-maxIdle)
	archiveSuffixes := []string{".tar.lz4", ".tar", ".tar.gz", ".zip", ".rar"}

	for _, item := range listing {
		fileName := item.Name()
		// "filez-batch-" names are covered by the shorter "filez-" prefix.
		if !strings.HasPrefix(fileName, "filez-") {
			continue
		}
		isArchive := false
		for _, sfx := range archiveSuffixes {
			if strings.HasSuffix(fileName, sfx) {
				isArchive = true
				break
			}
		}
		if !isArchive {
			continue
		}
		st, err := item.Info()
		if err != nil || st.IsDir() {
			continue
		}
		if st.ModTime().After(deadline) {
			continue
		}
		_ = os.Remove(filepath.Join(dir, fileName))
	}
}
// createRarFromLocalDir archives the local directory root with the external
// "rar" binary into a temp file.
//
// It returns the archive path, the download file name (baseName + ".rar") and
// the MIME type. The temp file is removed when rar fails.
//
// NOTE(review): root is passed to rar exactly as given; rar presumably records
// that path inside the archive, unlike the zip/tar builders which store names
// relative to root. Confirm this is intended.
func createRarFromLocalDir(root, baseName string) (string, string, string, error) {
	_, lookErr := exec.LookPath("rar")
	if lookErr != nil {
		return "", "", "", fmt.Errorf("rar format requires 'rar' binary on server; choose zip or tar.gz in settings")
	}

	placeholder, err := os.CreateTemp("", "filez-batch-*.rar")
	if err != nil {
		return "", "", "", err
	}
	dest := placeholder.Name()
	placeholder.Close()

	output, runErr := exec.Command("rar", "a", "-idq", dest, root).CombinedOutput()
	if runErr != nil {
		os.Remove(dest)
		return "", "", "", fmt.Errorf("rar creation failed: %s", strings.TrimSpace(string(output)))
	}
	return dest, baseName + ".rar", "application/vnd.rar", nil
}
// handleDelete removes the file or folder named by the "path" query parameter
// from the authenticated user's storage and drops the tags attached to it and
// to everything beneath it.
//
// Storage errors are reported as 400. Tag cleanup is best-effort and never
// fails the request.
func (s *Server) handleDelete(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())
	norm := normalizePath(r.URL.Query().Get("path"))
	log.Printf("file.delete user_id=%d path=%q", uid, norm)
	if err := s.storage.Delete(uid, norm); err != nil {
		writeErr(w, http.StatusBadRequest, err.Error())
		return
	}
	// Compare the prefix literally instead of with LIKE: '%' and '_' are legal
	// in file names but act as wildcards in a LIKE pattern, which would delete
	// the tags of unrelated paths.
	prefix := strings.TrimSuffix(norm, "/") + "/"
	_, _ = s.db.Exec(
		`DELETE FROM file_tags WHERE user_id = ? AND (rel_path = ? OR substr(rel_path, 1, length(?)) = ?)`,
		uid, norm, prefix, prefix,
	)
	writeJSON(w, http.StatusOK, map[string]string{"status": "deleted"})
}
// handleCreateFolder creates a folder called Name inside Path in the
// authenticated user's storage.
//
// Only the last path element of Name is used. Names that do not denote a new
// child (empty, ".", ".." or "/") are rejected with 400 instead of silently
// resolving to an existing parent folder. Storage errors are reported as 400;
// success is 201.
func (s *Server) handleCreateFolder(w http.ResponseWriter, r *http.Request) {
	type payload struct {
		Path string `json:"path"`
		Name string `json:"name"`
	}
	var in payload
	if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}
	uid := userIDFromContext(r.Context())
	// path.Base maps "" to "." and keeps ".." and "/" as they are; joining any
	// of those onto Path would address Path itself or its parent.
	name := path.Base(strings.TrimSpace(in.Name))
	if name == "." || name == ".." || name == "/" {
		writeErr(w, http.StatusBadRequest, "invalid folder name")
		return
	}
	rel := path.Join(normalizePath(in.Path), name)
	log.Printf("file.mkdir user_id=%d path=%q", uid, normalizePath(rel))
	if err := s.storage.Mkdir(uid, rel); err != nil {
		writeErr(w, http.StatusBadRequest, err.Error())
		return
	}
	writeJSON(w, http.StatusCreated, map[string]string{"status": "created"})
}
// handleRename gives the item at Path a new Name inside the same folder.
//
// The rename is carried out as copy, then delete of the source, then tag
// migration. Requests are rejected with 400 when the payload is malformed, the
// source is the root or missing, the name is empty/"."/"..", the target already
// exists, or a folder would end up inside itself. Renaming to the current name
// is answered with 200 without touching storage.
func (s *Server) handleRename(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())

	var req struct {
		Path string `json:"path"`
		Name string `json:"name"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}

	src := normalizePath(req.Path)
	if src == "/" {
		writeErr(w, http.StatusBadRequest, "invalid path")
		return
	}
	meta, err := s.storage.Stat(uid, src)
	if err != nil {
		writeErr(w, http.StatusBadRequest, "file not found")
		return
	}

	name := path.Base(strings.TrimSpace(req.Name))
	switch name {
	case "", ".", "..":
		writeErr(w, http.StatusBadRequest, "invalid name")
		return
	}

	parent := path.Dir(src)
	if parent == "." {
		parent = "/"
	}
	dst := normalizePath(path.Join(parent, name))
	if dst == src {
		// Nothing to do: the item already has this name.
		writeJSON(w, http.StatusOK, map[string]any{"status": "renamed", "path": src})
		return
	}
	if _, statErr := s.storage.Stat(uid, dst); statErr == nil {
		writeErr(w, http.StatusBadRequest, "target already exists")
		return
	}
	if meta.IsDir && (dst == src || strings.HasPrefix(dst, strings.TrimSuffix(src, "/")+"/")) {
		writeErr(w, http.StatusBadRequest, "cannot rename folder into itself")
		return
	}

	if err := s.copyPath(uid, src, dst); err != nil {
		writeErr(w, http.StatusBadRequest, err.Error())
		return
	}
	if err := s.storage.Delete(uid, src); err != nil {
		writeErr(w, http.StatusBadRequest, err.Error())
		return
	}
	s.moveTags(uid, src, dst)
	writeJSON(w, http.StatusOK, map[string]any{"status": "renamed", "path": dst})
}
// shareInput is the JSON request body accepted when creating a share link.
type shareInput struct {
	// Path is the file to share, relative to the user's storage root.
	Path string `json:"path"`
	// ExpiresMinutes is the requested lifetime in minutes.
	ExpiresMinutes int `json:"expiresMinutes"`
	// MaxDownloads is an optional download cap; a nil pointer means the
	// client did not send one.
	MaxDownloads *int `json:"maxDownloads"`

	AllowPreview    bool `json:"allowPreview"`
	PreferredInline bool `json:"preferredInline"`
}
// handleCreateShareLink issues a share token for a single file of the
// authenticated user.
//
// Folders cannot be shared. The lifetime is the configured default unless the
// request asks for a positive number of minutes, and is always clamped to the
// range 5 minutes .. 30 days. Only a hash of the token is stored; the token and
// the resulting URL are returned once in the 201 response.
func (s *Server) handleCreateShareLink(w http.ResponseWriter, r *http.Request) {
	uid := userIDFromContext(r.Context())

	var in shareInput
	if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}

	rel := normalizePath(in.Path)
	log.Printf("file.share.create user_id=%d path=%q", uid, rel)
	meta, err := s.storage.Stat(uid, rel)
	if err != nil {
		writeErr(w, http.StatusBadRequest, "file not found")
		return
	}
	if meta.IsDir {
		writeErr(w, http.StatusBadRequest, "cannot create share for folder")
		return
	}

	const (
		minTTL = 5 * time.Minute
		maxTTL = 30 * 24 * time.Hour
	)
	ttl := s.config.ShareDefaultTTL
	if in.ExpiresMinutes > 0 {
		ttl = time.Duration(in.ExpiresMinutes) * time.Minute
	}
	switch {
	case ttl < minTTL:
		ttl = minTTL
	case ttl > maxTTL:
		ttl = maxTTL
	}

	token, err := randomToken()
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to create token")
		return
	}

	// A nil interface value is passed to the INSERT when no positive cap was
	// requested.
	var maxDownloads any
	if in.MaxDownloads != nil && *in.MaxDownloads > 0 {
		maxDownloads = *in.MaxDownloads
	}
	expiresAt := time.Now().Add(ttl)
	_, err = s.db.Exec(`INSERT INTO share_links(user_id, rel_path, token_hash, expires_at, max_downloads) VALUES (?, ?, ?, ?, ?)`,
		uid, rel, hashToken(token), expiresAt, maxDownloads)
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to create share")
		return
	}

	response := map[string]any{
		"url":       fmt.Sprintf("%s://%s/api/share/%s", schemeOf(r), r.Host, token),
		"token":     token,
		"path":      rel,
		"expiresAt": expiresAt,
	}
	writeJSON(w, http.StatusCreated, response)
}
// handleSharedDownload serves the file behind a share token without requiring
// a user session.
//
// Responses: 400 for a blank token, 404 when no link matches, 410 when the link
// is revoked, expired or has used up its download allowance, 500 when the
// download counter cannot be updated. The counter is incremented before the
// file is served.
func (s *Server) handleSharedDownload(w http.ResponseWriter, r *http.Request) {
	token := mux.Vars(r)["token"]
	if strings.TrimSpace(token) == "" {
		writeErr(w, http.StatusBadRequest, "missing token")
		return
	}
	tokenHash := hashToken(token)
	var uid int64
	var rel string
	var expiresAt time.Time
	var revokedAt sql.NullTime
	var maxDownloads sql.NullInt64
	var downloadCount int64
	err := s.db.QueryRow(`SELECT user_id, rel_path, expires_at, revoked_at, max_downloads, download_count FROM share_links WHERE token_hash = ?`, tokenHash).
		Scan(&uid, &rel, &expiresAt, &revokedAt, &maxDownloads, &downloadCount)
	if err != nil {
		writeErr(w, http.StatusNotFound, "share link not found")
		return
	}
	if revokedAt.Valid || expiresAt.Before(time.Now()) {
		writeErr(w, http.StatusGone, "share link expired")
		return
	}
	if maxDownloads.Valid && downloadCount >= maxDownloads.Int64 {
		writeErr(w, http.StatusGone, "share link download limit reached")
		return
	}
	// The limit is enforced again inside the UPDATE itself: the check above
	// works on a snapshot, so concurrent requests could otherwise all pass it
	// and push download_count past max_downloads.
	res, err := s.db.Exec(`UPDATE share_links SET download_count = download_count + 1 WHERE token_hash = ? AND (max_downloads IS NULL OR download_count < max_downloads)`, tokenHash)
	if err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to track download")
		return
	}
	if n, raErr := res.RowsAffected(); raErr == nil && n == 0 {
		writeErr(w, http.StatusGone, "share link download limit reached")
		return
	}
	log.Printf("file.share.download user_id=%d path=%q ip=%q", uid, normalizePath(rel), clientIP(r))
	if err := s.serveFile(w, r, uid, rel, true, ""); err != nil {
		writeErr(w, http.StatusBadRequest, err.Error())
	}
}
// handleAdminMe reports the configured administrator login as JSON.
func (s *Server) handleAdminMe(w http.ResponseWriter, _ *http.Request) {
	body := map[string]string{"login": s.config.AdminLogin}
	writeJSON(w, http.StatusOK, body)
}
// handleAdminUsersList returns all user accounts under the "users" key, or 500
// when they cannot be loaded.
func (s *Server) handleAdminUsersList(w http.ResponseWriter, _ *http.Request) {
	users, err := s.orm.listUsers()
	if err == nil {
		writeJSON(w, http.StatusOK, map[string]any{"users": users})
		return
	}
	writeErr(w, http.StatusInternalServerError, "failed to load users")
}
// adminCreateUserInput is the JSON request body accepted by the admin
// "create user" endpoint.
type adminCreateUserInput struct {
	// Credentials of the account to create.
	Username string `json:"username"`
	Password string `json:"password"`

	// Initial appearance preferences.
	Theme     string `json:"theme"`
	ColorMode string `json:"colorMode"`
}
// handleAdminUserCreate creates a user account from an admin request.
//
// Theme and colour mode are normalized before the account is created. A
// creation error whose text contains "exists" is reported as 409; any other
// creation error is 400. On success the new user is returned with 201 and the
// event is logged.
func (s *Server) handleAdminUserCreate(w http.ResponseWriter, r *http.Request) {
	var in adminCreateUserInput
	if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
		writeErr(w, http.StatusBadRequest, "invalid payload")
		return
	}

	user, err := s.createUser(in.Username, in.Password, normalizeTheme(in.Theme), normalizeColorMode(in.ColorMode))
	if err != nil {
		status := http.StatusBadRequest
		if strings.Contains(strings.ToLower(err.Error()), "exists") {
			status = http.StatusConflict
		}
		writeErr(w, status, err.Error())
		return
	}

	writeJSON(w, http.StatusCreated, user)
	log.Printf("admin.user.create user_id=%d username=%q", user.ID, user.Username)
}
// handleAdminUserDelete removes the user identified by the "id" route variable.
//
// Deleting the user row is the only step whose failure is reported (500).
// Dropping refresh tokens, revoking share links and wiping the user's storage
// are best-effort follow-ups whose errors are ignored.
func (s *Server) handleAdminUserDelete(w http.ResponseWriter, r *http.Request) {
	id, err := strconv.ParseInt(mux.Vars(r)["id"], 10, 64)
	if err != nil || id <= 0 {
		writeErr(w, http.StatusBadRequest, "invalid user id")
		return
	}

	if _, err := s.db.Exec(`DELETE FROM users WHERE id = ?`, id); err != nil {
		writeErr(w, http.StatusInternalServerError, "failed to delete user")
		return
	}

	followUps := []string{
		`DELETE FROM refresh_tokens WHERE user_id = ?`,
		`UPDATE share_links SET revoked_at = CURRENT_TIMESTAMP WHERE user_id = ?`,
	}
	for _, stmt := range followUps {
		_, _ = s.db.Exec(stmt, id)
	}
	_ = s.storage.Delete(id, "/")

	writeJSON(w, http.StatusOK, map[string]string{"status": "deleted"})
	log.Printf("admin.user.delete user_id=%d", id)
}
// recoverMiddleware turns a panic raised by a downstream handler into a logged
// 500 response instead of letting it propagate.
func (s *Server) recoverMiddleware(next http.Handler) http.Handler {
	guarded := func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			cause := recover()
			if cause == nil {
				return
			}
			log.Printf("panic recovered path=%q err=%v", r.URL.Path, cause)
			writeErr(w, http.StatusInternalServerError, "internal server error")
		}()
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(guarded)
}
// securityHeadersMiddleware sets a fixed group of security response headers on
// every response, and adds Strict-Transport-Security when the request arrived
// over TLS.
func (s *Server) securityHeadersMiddleware(next http.Handler) http.Handler {
	fixed := [][2]string{
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		{"Permissions-Policy", "camera=(), microphone=(), geolocation=()"},
		{"Content-Security-Policy", "default-src 'self'; connect-src 'self'; img-src 'self' data: blob:; style-src 'self' 'unsafe-inline'; script-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"},
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		for _, kv := range fixed {
			hdr.Set(kv[0], kv[1])
		}
		if r.TLS != nil {
			hdr.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
		}
		next.ServeHTTP(w, r)
	})
}
// bodyLimitMiddleware caps the request body of API calls at the configured
// MaxBodyBytes. Paths starting with /api/files/upload and non-API paths are
// passed through with their body untouched.
func (s *Server) bodyLimitMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reqPath := r.URL.Path
		isUpload := strings.HasPrefix(reqPath, "/api/files/upload")
		isAPI := strings.HasPrefix(reqPath, "/api/")
		if isAPI && !isUpload {
			r.Body = http.MaxBytesReader(w, r.Body, s.config.MaxBodyBytes)
		}
		next.ServeHTTP(w, r)
	})
}
// rateLimitMiddleware throttles API requests per client IP and request path.
//
// Paths containing "/auth/" or "/admin/login" use AuthRateLimitPerMin; all
// other API paths use RateLimitPerMin. Requests over the limit receive 429.
// Non-API paths are never limited.
func (s *Server) rateLimitMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reqPath := r.URL.Path
		if !strings.HasPrefix(reqPath, "/api/") {
			next.ServeHTTP(w, r)
			return
		}

		ip := clientIP(r)
		limit := s.config.RateLimitPerMin
		sensitive := strings.Contains(reqPath, "/auth/") || strings.Contains(reqPath, "/admin/login")
		if sensitive {
			limit = s.config.AuthRateLimitPerMin
		}

		if s.limiter.allow(ip+":"+reqPath, limit, time.Now()) {
			next.ServeHTTP(w, r)
			return
		}
		writeErr(w, http.StatusTooManyRequests, "rate limit exceeded")
	})
}
// requestLogMiddleware logs one line per API request after the downstream
// handler returns: method, path, status, duration in milliseconds, client IP
// and the user agent cut to 120 characters. Non-API requests are served but
// not logged.
func (s *Server) requestLogMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		// The recorder starts at 200 and is overwritten when the handler
		// calls WriteHeader.
		recorder := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
		next.ServeHTTP(recorder, r)

		if !strings.HasPrefix(r.URL.Path, "/api/") {
			return
		}
		elapsedMs := time.Since(began).Milliseconds()
		log.Printf(
			"api.request method=%s path=%q status=%d duration_ms=%d ip=%q ua=%q",
			r.Method, r.URL.Path, recorder.status, elapsedMs, clientIP(r), limitString(r.UserAgent(), 120),
		)
	})
}
// createUser validates the credentials, stores a new account and creates the
// root folder of its storage.
//
// The username is trimmed and lower-cased; it must be 3-32 bytes of a-z, 0-9,
// '.', '-' or '_'. The password must be at least 10 bytes long and is stored as
// an Argon2id hash. New accounts always start with the "zip" archive format. A
// store error mentioning "unique" is reported as "account already exists".
func (s *Server) createUser(username, password, theme, colorMode string) (User, error) {
	username = strings.ToLower(strings.TrimSpace(username))
	if n := len(username); n < 3 || n > 32 {
		return User{}, fmt.Errorf("username must be 3-32 characters")
	}
	for _, ch := range username {
		switch {
		case ch >= 'a' && ch <= 'z',
			ch >= '0' && ch <= '9',
			ch == '_', ch == '-', ch == '.':
			continue
		}
		return User{}, fmt.Errorf("username can contain only a-z, 0-9, dot, dash, underscore")
	}
	if len(password) < 10 {
		return User{}, fmt.Errorf("password must be at least 10 characters")
	}

	hash, err := hashPasswordArgon2ID(password)
	if err != nil {
		return User{}, fmt.Errorf("failed to hash password")
	}

	cleanTheme := normalizeTheme(theme)
	cleanMode := normalizeColorMode(colorMode)
	id, err := s.orm.createUser(username, hash, cleanTheme, cleanMode, "zip", nil)
	if err != nil {
		if strings.Contains(strings.ToLower(err.Error()), "unique") {
			return User{}, fmt.Errorf("account already exists")
		}
		return User{}, err
	}

	if err := s.storage.Mkdir(id, "/"); err != nil {
		return User{}, fmt.Errorf("failed to provision user storage: %w", err)
	}
	created := User{
		ID:        id,
		Username:  username,
		Theme:     cleanTheme,
		ColorMode: cleanMode,
		Archive:   "zip",
	}
	return created, nil
}
// findUserWithHash looks a user up by username and returns the account
// together with its stored password hash. It delegates to the repository's
// findUserWithHashByEmail.
func (s *Server) findUserWithHash(username string) (User, string, error) {
	user, hash, err := s.orm.findUserWithHashByEmail(username)
	return user, hash, err
}
// findUser loads the account with the given id from the repository.
func (s *Server) findUser(id int64) (User, error) {
	user, err := s.orm.findUserByID(id)
	return user, err
}
// issueUserSession starts a session for userID: it signs a short-lived HS256
// access token, stores the hash of a fresh random refresh token together with
// the client's user agent and IP, and sets both tokens as cookies.
//
// No cookie is written when signing, token generation or the database insert
// fails.
func (s *Server) issueUserSession(w http.ResponseWriter, r *http.Request, userID int64) error {
	issuedAt := time.Now()

	registered := jwt.RegisteredClaims{
		Subject:   strconv.FormatInt(userID, 10),
		IssuedAt:  jwt.NewNumericDate(issuedAt),
		ExpiresAt: jwt.NewNumericDate(issuedAt.Add(s.config.AccessTTL)),
	}
	unsigned := jwt.NewWithClaims(jwt.SigningMethodHS256, AccessClaims{UserID: userID, RegisteredClaims: registered})
	access, err := unsigned.SignedString([]byte(s.config.JWTSecret))
	if err != nil {
		return err
	}

	refresh, err := randomToken()
	if err != nil {
		return err
	}
	_, err = s.db.Exec(`INSERT INTO refresh_tokens(user_id, token_hash, expires_at, user_agent, ip) VALUES (?, ?, ?, ?, ?)`,
		userID,
		hashToken(refresh),
		issuedAt.Add(s.config.RefreshTTL),
		limitString(r.UserAgent(), 255),
		limitString(clientIP(r), 64),
	)
	if err != nil {
		return err
	}

	secure := s.config.CookieSecure
	setCookie(w, "access_token", access, int(s.config.AccessTTL.Seconds()), secure)
	setCookie(w, "refresh_token", refresh, int(s.config.RefreshTTL.Seconds()), secure)
	return nil
}
// consumeRefreshToken redeems a refresh token exactly once and returns the id
// of the user it belongs to.
//
// It fails when the token is unknown, expired or already revoked. The token is
// marked revoked as part of redemption, and the revoke only succeeds if the row
// was still unrevoked at that moment, so two concurrent requests presenting the
// same token cannot both obtain a session.
func (s *Server) consumeRefreshToken(token string) (int64, error) {
	var id int64
	var uid int64
	var expiresAt time.Time
	var revokedAt sql.NullTime
	err := s.db.QueryRow(`SELECT id, user_id, expires_at, revoked_at FROM refresh_tokens WHERE token_hash = ?`, hashToken(token)).
		Scan(&id, &uid, &expiresAt, &revokedAt)
	if err != nil {
		return 0, err
	}
	if revokedAt.Valid || expiresAt.Before(time.Now()) {
		return 0, fmt.Errorf("refresh token expired or revoked")
	}
	res, err := s.db.Exec(`UPDATE refresh_tokens SET revoked_at = CURRENT_TIMESTAMP WHERE id = ? AND revoked_at IS NULL`, id)
	if err != nil {
		return 0, err
	}
	// Zero affected rows means another request revoked the token between the
	// SELECT above and this UPDATE.
	if n, raErr := res.RowsAffected(); raErr == nil && n == 0 {
		return 0, fmt.Errorf("refresh token expired or revoked")
	}
	return uid, nil
}
// authMiddleware admits only requests carrying a valid "access_token" cookie.
//
// The cookie must hold a JWT signed with an HMAC method and the server secret.
// On success the user id from the claims is stored in the request context
// under userIDKey; otherwise the request is answered with 401.
func (s *Server) authMiddleware(next http.Handler) http.Handler {
	// Pinning the algorithm family rejects tokens that declare a different
	// signing method.
	keyFor := func(token *jwt.Token) (any, error) {
		if _, isHMAC := token.Method.(*jwt.SigningMethodHMAC); !isHMAC {
			return nil, fmt.Errorf("unexpected signing method")
		}
		return []byte(s.config.JWTSecret), nil
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		cookie, err := r.Cookie("access_token")
		if err != nil || cookie.Value == "" {
			writeErr(w, http.StatusUnauthorized, "missing access token")
			return
		}

		var claims AccessClaims
		parsed, err := jwt.ParseWithClaims(cookie.Value, &claims, keyFor)
		if err != nil || !parsed.Valid {
			writeErr(w, http.StatusUnauthorized, "invalid access token")
			return
		}

		ctx := context.WithValue(r.Context(), userIDKey, claims.UserID)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// adminMiddleware admits only requests carrying a valid "admin_token" cookie.
//
// The cookie must hold a JWT signed with an HMAC method and the server secret,
// with role "admin" and a login equal to the configured admin login. Anything
// else is answered with 401.
func (s *Server) adminMiddleware(next http.Handler) http.Handler {
	keyFor := func(token *jwt.Token) (any, error) {
		if _, isHMAC := token.Method.(*jwt.SigningMethodHMAC); !isHMAC {
			return nil, fmt.Errorf("unexpected signing method")
		}
		return []byte(s.config.JWTSecret), nil
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		cookie, err := r.Cookie("admin_token")
		if err != nil || cookie.Value == "" {
			writeErr(w, http.StatusUnauthorized, "missing admin token")
			return
		}

		var claims AdminClaims
		parsed, err := jwt.ParseWithClaims(cookie.Value, &claims, keyFor)
		authorized := err == nil && parsed.Valid && claims.Role == "admin" && claims.Login == s.config.AdminLogin
		if !authorized {
			writeErr(w, http.StatusUnauthorized, "invalid admin session")
			return
		}
		next.ServeHTTP(w, r)
	})
}
// corsMiddleware applies the server's CORS policy.
//
// When a CORS origin is configured, requests whose Origin header names any
// other origin are rejected with 403, and the configured origin is advertised
// in the response. Credential, header and method allowances are always sent.
// OPTIONS requests are answered directly with 204.
func (s *Server) corsMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		allowed := s.config.CORSOrigin
		origin := strings.TrimSpace(r.Header.Get("Origin"))
		if origin != "" && allowed != "" && origin != allowed {
			writeErr(w, http.StatusForbidden, "cors origin denied")
			return
		}

		hdr := w.Header()
		if allowed != "" {
			hdr.Set("Access-Control-Allow-Origin", allowed)
			hdr.Set("Vary", "Origin")
		}
		hdr.Set("Access-Control-Allow-Credentials", "true")
		hdr.Set("Access-Control-Allow-Headers", "Content-Type")
		hdr.Set("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS,HEAD")

		if r.Method == http.MethodOptions {
			// Preflight: the headers are the whole answer.
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next.ServeHTTP(w, r)
	})
}
// hostGuardMiddleware restricts the service to the configured AllowedHost.
//
// With no host configured every request passes. Otherwise the request's host
// (as returned by hostOnly) must match case-insensitively. Mismatching API
// requests receive a JSON 403; all other mismatches receive a small HTML 403
// page.
func (s *Server) hostGuardMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		allowed := s.config.AllowedHost
		if strings.TrimSpace(allowed) == "" {
			next.ServeHTTP(w, r)
			return
		}

		requested := strings.ToLower(hostOnly(r.Host))
		if requested == strings.ToLower(allowed) {
			next.ServeHTTP(w, r)
			return
		}

		if strings.HasPrefix(r.URL.Path, "/api/") {
			writeErr(w, http.StatusForbidden, fmt.Sprintf("access denied: host must be %s", allowed))
			return
		}
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		w.WriteHeader(http.StatusForbidden)
		_, _ = fmt.Fprintf(w, "<!doctype html><html><body><h1>Access denied</h1><p>This service is only available at <b>%s</b>.</p></body></html>", allowed)
	})
}
// LocalStorage keeps user files on the local filesystem, one directory per
// user id underneath root.
type LocalStorage struct {
	// root is the directory that contains all per-user directories.
	root string
}
// userRoot returns the directory that holds userID's files:
// <root>/<decimal user id>.
func (l *LocalStorage) userRoot(userID int64) string {
	dirName := strconv.FormatInt(userID, 10)
	return filepath.Join(l.root, dirName)
}
// fullPath maps a storage-relative path to an absolute location inside the
// user's directory, creating that directory if it does not exist yet.
//
// rel is normalized first; the resolved path must be the user directory itself
// or lie beneath it, otherwise "invalid path" is returned.
func (l *LocalStorage) fullPath(userID int64, rel string) (string, error) {
	base := l.userRoot(userID)
	if err := os.MkdirAll(base, 0o755); err != nil {
		return "", err
	}

	relOS := filepath.FromSlash(strings.TrimPrefix(normalizePath(rel), "/"))
	resolved := filepath.Clean(filepath.Join(base, relOS))

	// Containment check: equal to the user directory, or below it.
	inside := resolved == base || strings.HasPrefix(resolved, base+string(os.PathSeparator))
	if !inside {
		return "", fmt.Errorf("invalid path")
	}
	return resolved, nil
}
// List returns the direct children of the directory at rel.
//
// A directory that does not exist yields an empty, non-nil slice rather than
// an error. Children whose file info cannot be read are left out of the
// result.
func (l *LocalStorage) List(userID int64, rel string) ([]FileEntry, error) {
	dir, err := l.fullPath(userID, rel)
	if err != nil {
		return nil, err
	}

	items, err := os.ReadDir(dir)
	if errors.Is(err, os.ErrNotExist) {
		return []FileEntry{}, nil
	}
	if err != nil {
		return nil, err
	}

	parent := normalizePath(rel)
	listing := make([]FileEntry, 0, len(items))
	for _, item := range items {
		st, infoErr := item.Info()
		if infoErr != nil {
			continue
		}
		entry := FileEntry{
			Name:    item.Name(),
			Path:    path.Join(parent, item.Name()),
			IsDir:   item.IsDir(),
			Size:    st.Size(),
			ModTime: st.ModTime(),
		}
		listing = append(listing, entry)
	}
	return listing, nil
}
// Mkdir creates the directory at rel, including any missing parents, with mode
// 0755. Creating a directory that already exists is not an error.
func (l *LocalStorage) Mkdir(userID int64, rel string) error {
	target, err := l.fullPath(userID, rel)
	if err == nil {
		err = os.MkdirAll(target, 0o755)
	}
	return err
}
// Save streams an uploaded file into the user's storage at rel, creating
// missing parent directories and replacing any existing file.
//
// The error from closing the destination is returned as well, because a failed
// close after writing means the data may not have reached the disk.
func (l *LocalStorage) Save(userID int64, rel string, src multipart.File) error {
	full, err := l.fullPath(userID, rel)
	if err != nil {
		return err
	}
	if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil {
		return err
	}
	dst, err := os.Create(full)
	if err != nil {
		return err
	}
	if _, err := io.Copy(dst, src); err != nil {
		dst.Close()
		return err
	}
	return dst.Close()
}
// SaveBytes writes data to rel in the user's storage with mode 0644, creating
// missing parent directories and replacing any existing file.
func (l *LocalStorage) SaveBytes(userID int64, rel string, data []byte) error {
	target, err := l.fullPath(userID, rel)
	if err != nil {
		return err
	}
	parent := filepath.Dir(target)
	if mkErr := os.MkdirAll(parent, 0o755); mkErr != nil {
		return mkErr
	}
	return os.WriteFile(target, data, 0o644)
}
// Delete removes the file or directory tree at rel. Passing "/" removes the
// user's whole directory. A path that does not exist is not an error.
func (l *LocalStorage) Delete(userID int64, rel string) error {
	// The root and every other path are handled the same way: resolve, then
	// remove recursively.
	target, err := l.fullPath(userID, rel)
	if err != nil {
		return err
	}
	return os.RemoveAll(target)
}
// Stat returns name, size, modification time and directory flag of the item at
// rel. The zero FileMeta is returned together with any error.
func (l *LocalStorage) Stat(userID int64, rel string) (FileMeta, error) {
	target, err := l.fullPath(userID, rel)
	if err != nil {
		return FileMeta{}, err
	}
	info, err := os.Stat(target)
	if err != nil {
		return FileMeta{}, err
	}
	meta := FileMeta{
		Name:    info.Name(),
		Size:    info.Size(),
		ModTime: info.ModTime(),
		IsDir:   info.IsDir(),
	}
	return meta, nil
}
// OpenReadSeeker opens the file at rel for reading. The caller must close the
// returned value.
//
// On failure the returned interface is a true nil. Returning os.Open's results
// directly would wrap a nil *os.File in the interface, which compares unequal
// to nil.
func (l *LocalStorage) OpenReadSeeker(userID int64, rel string) (ReadSeekCloser, error) {
	full, err := l.fullPath(userID, rel)
	if err != nil {
		return nil, err
	}
	f, err := os.Open(full)
	if err != nil {
		return nil, err
	}
	return f, nil
}
// normalizePath converts a client-supplied path into a cleaned, rooted,
// slash-separated path. Surrounding whitespace is dropped, and because the
// path is rooted before cleaning, ".." elements cannot climb above "/".
// Empty input yields "/".
func normalizePath(rel string) string {
	trimmed := strings.TrimSpace(rel)
	// path.Clean on a rooted path always returns a rooted path, never ".".
	return path.Clean("/" + trimmed)
}
// normalizeUploadRelativePath validates the relative path sent along with an
// uploaded file and returns it in cleaned, slash-separated form.
//
// Backslashes are treated as separators and one leading slash is dropped. The
// result must name something strictly below the upload folder: paths that are
// empty, clean to ".", still start with "/", or clean to ".." or "../..." are
// rejected with "invalid upload path".
func normalizeUploadRelativePath(v string) (string, error) {
	v = strings.ReplaceAll(strings.TrimSpace(v), "\\", "/")
	v = strings.TrimPrefix(v, "/")
	v = path.Clean(v)
	if v == "." || v == "" {
		return "", fmt.Errorf("invalid upload path")
	}
	// After Clean, ".." elements of a relative path can only be leading. The
	// bare ".." case needs its own test because it has no trailing slash.
	if v == ".." || strings.HasPrefix(v, "../") || strings.HasPrefix(v, "/") {
		return "", fmt.Errorf("invalid upload path")
	}
	return v, nil
}
// normalizeTheme maps a theme name onto one of the supported themes. Matching
// ignores case and surrounding whitespace; every other value, including the
// names "material", "glass", "desktop" and "auto", falls back to "dracula".
func normalizeTheme(v string) string {
	const fallback = "dracula"
	name := strings.TrimSpace(strings.ToLower(v))
	for _, known := range []string{"dracula", "nord", "monokai", "solarized", "github"} {
		if name == known {
			return name
		}
	}
	return fallback
}
// normalizeColorMode returns "light", "dark" or "auto". Matching ignores case
// and surrounding whitespace; anything unrecognized becomes "auto".
func normalizeColorMode(v string) string {
	mode := strings.TrimSpace(strings.ToLower(v))
	if mode == "light" || mode == "dark" || mode == "auto" {
		return mode
	}
	return "auto"
}
// normalizeArchiveFormat returns the canonical archive format name ("zip",
// "rar", "tar.gz" or "lz4"), ignoring case and surrounding whitespace. An
// unsupported value yields the empty string.
func normalizeArchiveFormat(v string) string {
	format := strings.TrimSpace(strings.ToLower(v))
	for _, supported := range []string{"zip", "rar", "tar.gz", "lz4"} {
		if format == supported {
			return format
		}
	}
	return ""
}
// normalizeTag trims and lower-cases a tag and reports whether it is
// acceptable: 1 to 24 bytes long, made only of a-z, 0-9, '-' and '_'. For a
// rejected tag the returned string is empty.
func normalizeTag(v string) (string, bool) {
	tag := strings.ToLower(strings.TrimSpace(v))
	if n := len(tag); n == 0 || n > 24 {
		return "", false
	}
	forbidden := func(ch rune) bool {
		switch {
		case ch >= 'a' && ch <= 'z':
			return false
		case ch >= '0' && ch <= '9':
			return false
		case ch == '-', ch == '_':
			return false
		}
		return true
	}
	if strings.IndexFunc(tag, forbidden) >= 0 {
		return "", false
	}
	return tag, true
}
// isMarkdownExtension reports whether ext is "md" or "markdown", ignoring case
// and a single optional leading dot.
func isMarkdownExtension(ext string) bool {
	bare := strings.ToLower(strings.TrimPrefix(ext, "."))
	return bare == "md" || bare == "markdown"
}
// fileTagsForPaths loads the tags of the given paths for user uid in a single
// query and returns them keyed by stored path.
//
// Paths are normalized before the lookup; paths without tags have no entry in
// the result. An empty paths slice returns an empty map without querying. Rows
// that fail to scan are skipped, but an error that ends the iteration early is
// returned together with what was collected so far, so a truncated result is
// not mistaken for a complete one.
func (s *Server) fileTagsForPaths(uid int64, paths []string) (map[string][]string, error) {
	out := make(map[string][]string, len(paths))
	if len(paths) == 0 {
		return out, nil
	}
	// One placeholder per path; the user id is always the first argument.
	args := make([]any, 0, len(paths)+1)
	args = append(args, uid)
	pl := make([]string, len(paths))
	for i, p := range paths {
		pl[i] = "?"
		args = append(args, normalizePath(p))
	}
	q := `SELECT rel_path, tag FROM file_tags WHERE user_id = ? AND rel_path IN (` + strings.Join(pl, ",") + `) ORDER BY rel_path, tag`
	rows, err := s.db.Query(q, args...)
	if err != nil {
		return out, err
	}
	defer rows.Close()
	for rows.Next() {
		var rel string
		var tag string
		if rows.Scan(&rel, &tag) == nil {
			out[rel] = append(out[rel], tag)
		}
	}
	if err := rows.Err(); err != nil {
		return out, err
	}
	return out, nil
}
// setCookie adds a Set-Cookie header to w for a site-wide (Path "/"),
// HttpOnly, SameSite=Lax cookie with the given name, value and Max-Age.
// The Secure attribute is set according to secure.
func setCookie(w http.ResponseWriter, name, value string, maxAge int, secure bool) {
	cookie := http.Cookie{
		Name:     name,
		Value:    value,
		Path:     "/",
		MaxAge:   maxAge,
		Secure:   secure,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, &cookie)
}
// clearCookie adds a Set-Cookie header to w that removes the cookie called
// name: the value is empty and MaxAge is negative. Path, HttpOnly and SameSite
// use the same values as setCookie; Secure follows the secure argument.
func clearCookie(w http.ResponseWriter, name string, secure bool) {
	expired := http.Cookie{
		Name:     name,
		Value:    "",
		Path:     "/",
		MaxAge:   -1,
		Secure:   secure,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, &expired)
}
// randomToken returns 32 bytes read from crypto/rand, encoded as unpadded
// URL-safe base64 (43 characters). The error from the random source is
// returned unchanged.
func randomToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	token := base64.RawURLEncoding.EncodeToString(raw[:])
	return token, nil
}
// hashToken returns the SHA-256 digest of token encoded as unpadded URL-safe
// base64.
func hashToken(token string) string {
	digest := sha256.New()
	digest.Write([]byte(token)) // hash.Hash.Write never returns an error
	return base64.RawURLEncoding.EncodeToString(digest.Sum(nil))
}
// subtleConstantTimeEq returns 1 when a and b are equal and 0 otherwise.
// It delegates to subtle.ConstantTimeCompare, which returns 0 immediately
// when the lengths differ and otherwise takes time that depends only on the
// length of the inputs, not on their contents.
func subtleConstantTimeEq(a, b string) int {
	return subtle.ConstantTimeCompare([]byte(a), []byte(b))
}
// maybeRunHashCommand implements the "hash-admin <password>" command-line
// mode. When the first argument is not "hash-admin" it does nothing and
// returns false. Otherwise it prints the Argon2id hash of the password to
// standard output and returns true. A missing password or a hashing failure
// prints a message and exits the process with status 1.
func maybeRunHashCommand() bool {
	args := os.Args
	if len(args) < 2 || args[1] != "hash-admin" {
		return false
	}
	if len(args) < 3 {
		fmt.Println("usage: go run . hash-admin <password>")
		os.Exit(1)
	}

	encoded, err := hashPasswordArgon2ID(args[2])
	if err != nil {
		fmt.Println("failed to hash password:", err)
		os.Exit(1)
	}
	fmt.Println(encoded)
	return true
}
// hashPasswordArgon2ID derives an Argon2id key from password with a fresh
// 16-byte random salt and returns it in the form
//
//	argon2id:v=19,m=<memory>,t=<time>,p=<threads>:<salt>:<key>
//
// where salt and key are unpadded standard base64. This is the colon-separated
// layout parsed by verifyPasswordHash. An error is returned only when the
// random source fails.
func hashPasswordArgon2ID(password string) (string, error) {
	const (
		saltLen              = 16
		passes        uint32 = 3
		memoryCost    uint32 = 64 * 1024
		lanes         uint8  = 2
		derivedKeyLen uint32 = 32
	)

	var salt [saltLen]byte
	if _, err := rand.Read(salt[:]); err != nil {
		return "", err
	}

	key := argon2.IDKey([]byte(password), salt[:], passes, memoryCost, lanes, derivedKeyLen)

	enc := base64.RawStdEncoding
	return fmt.Sprintf(
		"argon2id:v=19,m=%d,t=%d,p=%d:%s:%s",
		memoryCost, passes, lanes,
		enc.EncodeToString(salt[:]), enc.EncodeToString(key),
	), nil
}
func verifyPasswordHash(storedHash, password string) bool {
if strings.HasPrefix(storedHash, "argon2id:") {
parts := strings.Split(storedHash, ":")
if len(parts) != 4 {
return false
}
var mem, timeCost uint32
var threads uint8
if _, err := fmt.Sscanf(parts[1], "v=19,m=%d,t=%d,p=%d", &mem, &timeCost, &threads); err != nil {
return false
}
salt, err := base64.RawStdEncoding.DecodeString(parts[2])
if err != nil {
return false
}
expected, err := base64.RawStdEncoding.DecodeString(parts[3])
if err != nil {
return false
}
actual := argon2.IDKey([]byte(password), salt, timeCost, mem, threads, uint32(len(expected)))
return subtle.ConstantTimeCompare(expected, actual) == 1
}
if strings.HasPrefix(storedHash, "argon2id$") {
parts := strings.Split(storedHash, "$")
if len(parts) == 6 {
var mem, timeCost uint32
var threads uint8
if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &mem, &timeCost, &threads); err == nil {
salt, errSalt := base64.RawStdEncoding.DecodeString(parts[4])
expected, errExpected := base64.RawStdEncoding.DecodeString(parts[5])
if errSalt == nil && errExpected == nil {
actual := argon2.IDKey([]byte(password), salt, timeCost, mem, threads, uint32(len(expected)))
return subtle.ConstantTimeCompare(expected, actual) == 1
}
}
}
}
if strings.HasPrefix(storedHash, "sha256:") {
expected := strings.TrimPrefix(storedHash, "sha256:")
sum := sha256.Sum256([]byte(password))
actual := hex.EncodeToString(sum[:])
return subtle.ConstantTimeCompare([]byte(expected), []byte(actual)) == 1
}
bcryptHash := storedHash
if strings.HasPrefix(storedHash, "bcrypt:") {
bcryptHash = strings.TrimPrefix(storedHash, "bcrypt:")
}
if strings.HasPrefix(bcryptHash, "$2a$") || strings.HasPrefix(bcryptHash, "$2b$") || strings.HasPrefix(bcryptHash, "$2y$") {
return bcrypt.CompareHashAndPassword([]byte(bcryptHash), []byte(password)) == nil
}
return false
}
// verifyAdminPasswordHash reports whether password matches storedHash.
// It delegates to verifyPasswordHash and therefore accepts exactly the same
// hash formats.
func verifyAdminPasswordHash(storedHash, password string) bool {
	matched := verifyPasswordHash(storedHash, password)
	return matched
}
// userIDFromContext returns the int64 stored in ctx under userIDKey.
// It returns 0 when no value is stored or the stored value is not an int64.
func userIDFromContext(ctx context.Context) int64 {
	if id, ok := ctx.Value(userIDKey).(int64); ok {
		return id
	}
	return 0
}
// schemeOf returns the URL scheme of the request, either "http" or "https".
//
// The X-Forwarded-Proto header is consulted first. Because the header is
// client-controllable unless a proxy overwrites it, it is only honoured when
// its first comma-separated entry is "http" or "https" (case-insensitive,
// surrounding whitespace ignored). Otherwise the scheme is "https" when the
// request carries TLS connection state and "http" when it does not.
func schemeOf(r *http.Request) string {
	// Some proxy chains send a comma-separated list; only the first entry is used.
	first, _, _ := strings.Cut(r.Header.Get("X-Forwarded-Proto"), ",")
	switch proto := strings.ToLower(strings.TrimSpace(first)); proto {
	case "http", "https":
		return proto
	}
	if r.TLS != nil {
		return "https"
	}
	return "http"
}
// hostOnly returns the host part of a "host:port" string. Surrounding
// whitespace is trimmed first. When the value contains no colon, or
// net.SplitHostPort cannot split it, the trimmed input is returned unchanged.
func hostOnly(hostport string) string {
	trimmed := strings.TrimSpace(hostport)
	if trimmed == "" || !strings.Contains(trimmed, ":") {
		return trimmed
	}
	host, _, err := net.SplitHostPort(trimmed)
	if err != nil {
		return trimmed
	}
	return host
}
// clientIP returns the best available client address for the request.
//
// The CF-Connecting-IP, X-Forwarded-For and X-Real-IP headers are checked in
// that order. The first header that yields a non-empty value wins; for a
// comma-separated list only the first entry is considered. A header whose
// first entry is empty (for example ", 10.0.0.1") is skipped rather than
// producing an empty address. When no header applies, the host part of
// r.RemoteAddr is returned, or RemoteAddr itself when it has no port.
//
// These headers are supplied by the client unless a trusted proxy overwrites
// them, so the result must not be treated as authenticated.
func clientIP(r *http.Request) string {
	for _, name := range []string{"CF-Connecting-IP", "X-Forwarded-For", "X-Real-IP"} {
		first, _, _ := strings.Cut(r.Header.Get(name), ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}
// writeJSON sends payload as a JSON response with the given status code and a
// Content-Type of application/json.
func writeJSON(w http.ResponseWriter, status int, payload any) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(status)

	enc := json.NewEncoder(w)
	// The status line has already been written, so an encoding failure can no
	// longer be reported to the client; the error is deliberately dropped.
	_ = enc.Encode(payload)
}
// writeErr sends a JSON error response of the form {"error": msg} with the
// given status code, using writeJSON.
func writeErr(w http.ResponseWriter, status int, msg string) {
	body := map[string]string{"error": msg}
	writeJSON(w, status, body)
}
// limitString returns s truncated to at most max bytes.
//
// Strings that already fit are returned unchanged. Truncation never splits a
// multi-byte UTF-8 sequence: the cut is moved back to the nearest rune
// boundary at or before max, so the result can be shorter than max bytes.
// A zero or negative max yields the empty string.
func limitString(s string, max int) string {
	if len(s) <= max {
		return s
	}
	if max <= 0 {
		return ""
	}
	// Ranging over a string yields the byte offset at which each rune starts;
	// keep the largest such offset that does not exceed max.
	cut := 0
	for i := range s {
		if i > max {
			break
		}
		cut = i
	}
	return s[:cut]
}