diff --git a/webdav/memfs.go b/webdav/memfs.go index c1751ea..837fd9d 100644 --- a/webdav/memfs.go +++ b/webdav/memfs.go @@ -36,7 +36,13 @@ func (fs *MemFS) findNode(path string) (*memNode, error) { current := fs.root parts := strings.Split(path, "/") for _, part := range parts { - if part == "" { + if part == "" || part == "." { + continue + } + if part == ".." { + if current.parent != nil { + current = current.parent + } continue } if current.children == nil { @@ -105,6 +111,7 @@ func (fs *MemFS) OpenFile(ctx context.Context, name string, flag int, perm os.Fi if flag&os.O_TRUNC != 0 { node.data = nil + node.size = 0 } return &memFile{ @@ -234,14 +241,21 @@ func (f *memFile) Write(p []byte) (n int, err error) { func (f *memFile) Seek(offset int64, whence int) (int64, error) { f.fs.mu.Lock() defer f.fs.mu.Unlock() + var newOffset int64 switch whence { - case 0: - f.offset = offset - case 1: - f.offset += offset - case 2: - f.offset = int64(len(f.node.data)) + offset + case io.SeekStart: + newOffset = offset + case io.SeekCurrent: + newOffset = f.offset + offset + case io.SeekEnd: + newOffset = f.node.size + offset + default: + return 0, os.ErrInvalid } + if newOffset < 0 { + return 0, os.ErrInvalid + } + f.offset = newOffset return f.offset, nil } diff --git a/webdav/memlock.go b/webdav/memlock.go index 7c1074f..dabdd71 100644 --- a/webdav/memlock.go +++ b/webdav/memlock.go @@ -17,6 +17,7 @@ import ( type MemLock struct { mu sync.RWMutex locks map[string]*lock + stop chan struct{} } type lock struct { @@ -28,8 +29,35 @@ type lock struct { // NewMemLock creates a new in-memory lock system. func NewMemLock() *MemLock { - return &MemLock{ + l := &MemLock{ locks: make(map[string]*lock), + stop: make(chan struct{}), + } + go l.cleanup() + return l +} + +// Close stops the cleanup goroutine. +func (l *MemLock) Close() { + close(l.stop) +} + +func (l *MemLock) cleanup() { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + for { + select { + case <-ticker.C: + l.mu.Lock() + for token, lock := range l.locks { + if time.Now().After(lock.expires) { + delete(l.locks, token) + } + } + l.mu.Unlock() + case <-l.stop: + return + } } } @@ -39,7 +67,9 @@ func (l *MemLock) Create(ctx context.Context, path string, info LockInfo) (strin defer l.mu.Unlock() token := make([]byte, 16) - rand.Read(token) + if _, err := rand.Read(token); err != nil { + return "", err + } tokenStr := hex.EncodeToString(token) l.locks[tokenStr] = &lock{ diff --git a/webdav/osfs.go b/webdav/osfs.go index 6a68108..cf4c62e 100644 --- a/webdav/osfs.go +++ b/webdav/osfs.go @@ -26,10 +26,37 @@ func NewOSFS(rootDir string) (*OSFS, error) { } func (fs *OSFS) resolve(name string) (string, error) { + if filepath.IsAbs(name) || strings.Contains(name, "..") { + return "", os.ErrPermission + } + path := filepath.Join(fs.RootDir, name) + + // Evaluate symlinks, but only if the path exists. + if _, err := os.Lstat(path); err == nil { + path, err = filepath.EvalSymlinks(path) + if err != nil { + return "", err + } + } else if !os.IsNotExist(err) { + return "", err + // For non-existent paths (like for PUT or MKCOL), we can't EvalSymlinks the full path. + // Instead, we resolve the parent and ensure it's within the root. + } else { + parentDir := filepath.Dir(path) + if _, err := os.Stat(parentDir); err == nil { + parentDir, err = filepath.EvalSymlinks(parentDir) + if err != nil { + return "", err + } + path = filepath.Join(parentDir, filepath.Base(path)) + } + } + if !strings.HasPrefix(path, fs.RootDir) { return "", os.ErrPermission } + return path, nil } diff --git a/webdav/webdav.go b/webdav/webdav.go index 07accf7..2bad373 100644 --- a/webdav/webdav.go +++ b/webdav/webdav.go @@ -284,7 +284,39 @@ func (h *Handler) handleGetHead(c *touka.Context) { func (h *Handler) handleDelete(c *touka.Context) { path, _ := c.Get("webdav_path") - if err := h.FileSystem.RemoveAll(c.Context(), path.(string)); err != nil { + pathStr := path.(string) + + info, err := h.FileSystem.Stat(c.Context(), pathStr) + if err != nil { + if os.IsNotExist(err) { + c.Status(http.StatusNotFound) + } else { + c.Status(http.StatusInternalServerError) + } + return + } + + if info.IsDir() { + file, err := h.FileSystem.OpenFile(c.Context(), pathStr, os.O_RDONLY, 0) + if err != nil { + c.Status(http.StatusInternalServerError) + return + } + defer file.Close() + + // Check if the directory has any children. Readdir(1) is enough. + children, err := file.Readdir(1) + if err != nil && err != io.EOF { + c.Status(http.StatusInternalServerError) + return + } + if len(children) > 0 { + c.Status(http.StatusConflict) // 409 Conflict for non-empty collection + return + } + } + + if err := h.FileSystem.RemoveAll(c.Context(), pathStr); err != nil { if os.IsNotExist(err) { c.Status(http.StatusNotFound) } else { @@ -347,11 +379,13 @@ func (h *Handler) handleCopy(c *touka.Context) { overwrite = "T" // Default is to overwrite } - if overwrite == "F" { - if _, err := h.FileSystem.Stat(c.Context(), destPath); err == nil { - c.Status(http.StatusPreconditionFailed) - return - } + // Check for existence before the operation to determine status code later. + _, err = h.FileSystem.Stat(c.Context(), destPath) + existed := err == nil + + if overwrite == "F" && existed { + c.Status(http.StatusPreconditionFailed) + return } if err := h.copy(c.Context(), srcPath.(string), destPath); err != nil { @@ -359,7 +393,11 @@ func (h *Handler) handleCopy(c *touka.Context) { return } - c.Status(http.StatusCreated) + if existed { + c.Status(http.StatusNoContent) + } else { + c.Status(http.StatusCreated) + } } func (h *Handler) handleMove(c *touka.Context) { @@ -382,11 +420,13 @@ func (h *Handler) handleMove(c *touka.Context) { overwrite = "T" // Default is to overwrite } - if overwrite == "F" { - if _, err := h.FileSystem.Stat(c.Context(), destPath); err == nil { - c.Status(http.StatusPreconditionFailed) - return - } + // Check for existence before the operation to determine status code later. + _, err = h.FileSystem.Stat(c.Context(), destPath) + existed := err == nil + + if overwrite == "F" && existed { + c.Status(http.StatusPreconditionFailed) + return } if err := h.FileSystem.Rename(c.Context(), srcPath.(string), destPath); err != nil { @@ -394,7 +434,11 @@ func (h *Handler) handleMove(c *touka.Context) { return } - c.Status(http.StatusCreated) + if existed { + c.Status(http.StatusNoContent) + } else { + c.Status(http.StatusCreated) + } } func (h *Handler) copy(ctx context.Context, src, dest string) error { @@ -585,11 +629,11 @@ func (h *Handler) handleProppatch(c *touka.Context) { c.Status(http.StatusNotImplemented) } -func (h *Handler) stripPrefix(path string) string { +func (h *Handler) stripPrefix(p string) string { if h.Prefix == "/" { - return path + return p } - return "/" + strings.TrimPrefix(path, h.Prefix) + return strings.TrimPrefix(p, h.Prefix) } func (h *Handler) handleLock(c *touka.Context) { @@ -599,7 +643,15 @@ func (h *Handler) handleLock(c *touka.Context) { } path, _ := c.Get("webdav_path") - token := c.GetReqHeader("If") + tokenHeader := c.GetReqHeader("If") + var token string + if tokenHeader != "" { + // Basic parsing for + if strings.HasPrefix(tokenHeader, "(<") && strings.HasSuffix(tokenHeader, ">)") { + token = strings.TrimPrefix(tokenHeader, "(<") + token = strings.TrimSuffix(token, ">)") + } + } // Refresh lock if token != "" { @@ -666,7 +718,7 @@ func parseTimeout(timeoutStr string) (time.Duration, error) { return seconds, nil } } - return 0, nil + return 0, os.ErrInvalid } func (h *Handler) handleUnlock(c *touka.Context) { @@ -675,12 +727,16 @@ func (h *Handler) handleUnlock(c *touka.Context) { return } - token := c.GetReqHeader("Lock-Token") - if token == "" { + tokenHeader := c.GetReqHeader("Lock-Token") + if tokenHeader == "" { c.Status(http.StatusBadRequest) return } + // Basic parsing for + token := strings.TrimPrefix(tokenHeader, "<") + token = strings.TrimSuffix(token, ">") + if err := h.LockSystem.Unlock(c.Context(), token); err != nil { c.Status(http.StatusConflict) return