diff --git a/examples/webdav/main.go b/examples/webdav/main.go index 968d6ba..87bb00e 100644 --- a/examples/webdav/main.go +++ b/examples/webdav/main.go @@ -18,11 +18,9 @@ func main() { } // Serve the "public" directory on the "/webdav/" route. - closer, err := webdav.Serve(r, "/webdav", "public") - if err != nil { + if err := webdav.Serve(r, "/webdav", "public"); err != nil { log.Fatal(err) } - defer closer.Close() log.Println("Touka WebDAV Server starting on :8080...") if err := r.RunShutdown(":8080", 10*time.Second); err != nil { diff --git a/webdav/easy.go b/webdav/easy.go index f61fbfd..56d1eb0 100644 --- a/webdav/easy.go +++ b/webdav/easy.go @@ -5,7 +5,6 @@ package webdav import ( - "io" "log" "os" @@ -21,6 +20,10 @@ type Config struct { // Register registers a WebDAV handler on the given router. func Register(engine *touka.Engine, prefix string, cfg *Config) { + if cfg.LockSystem == nil { + cfg.LockSystem = NewMemLock() + } + handler := NewHandler(prefix, cfg.FileSystem, cfg.LockSystem, cfg.Logger) webdavMethods := []string{ @@ -30,18 +33,16 @@ func Register(engine *touka.Engine, prefix string, cfg *Config) { } // Serve serves a local directory via WebDAV. -func Serve(engine *touka.Engine, prefix string, rootDir string) (io.Closer, error) { +func Serve(engine *touka.Engine, prefix string, rootDir string) error { fs, err := NewOSFS(rootDir) if err != nil { - return nil, err + return err } - ls := NewMemLock() cfg := &Config{ FileSystem: fs, - LockSystem: ls, Logger: log.New(os.Stdout, "", 0), } Register(engine, prefix, cfg) - return ls, nil + return nil } diff --git a/webdav/easy_test.go b/webdav/easy_test.go index 2d00566..bf44441 100644 --- a/webdav/easy_test.go +++ b/webdav/easy_test.go @@ -17,7 +17,6 @@ func TestRegister(t *testing.T) { r := touka.New() cfg := &Config{ FileSystem: NewMemFS(), - LockSystem: NewMemLock(), } Register(r, "/dav", cfg) @@ -36,11 +35,9 @@ func TestServe(t *testing.T) { dir, _ := os.MkdirTemp("", "webdav") defer os.RemoveAll(dir) - closer, err := Serve(r, "/serve", dir) - if err != nil { + if err := Serve(r, "/serve", dir); err != nil { t.Fatalf("Serve failed: %v", err) } - defer closer.Close() // Check if a WebDAV method is registered req, _ := http.NewRequest("OPTIONS", "/serve/", nil) diff --git a/webdav/memfs.go b/webdav/memfs.go index e263d67..6333eb1 100644 --- a/webdav/memfs.go +++ b/webdav/memfs.go @@ -131,35 +131,16 @@ func (fs *MemFS) RemoveAll(ctx context.Context, name string) error { fs.mu.Lock() defer fs.mu.Unlock() - cleanPath := path.Clean(name) - if cleanPath == "/" { - return os.ErrInvalid - } - - dir, base := path.Split(cleanPath) + dir, base := path.Split(name) parent, err := fs.findNode(dir) if err != nil { return err } - node, exists := parent.children[base] - if !exists { + if _, exists := parent.children[base]; !exists { return os.ErrNotExist } - var recursiveDelete func(*memNode) - recursiveDelete = func(n *memNode) { - if n.isDir { - for _, child := range n.children { - recursiveDelete(child) - } - } - n.parent = nil - n.children = nil - n.data = nil - } - recursiveDelete(node) - delete(parent.children, base) return nil } @@ -259,34 +240,17 @@ func (f *memFile) Read(p []byte) (n int, err error) { func (f *memFile) Write(p []byte) (n int, err error) { f.fs.mu.Lock() defer f.fs.mu.Unlock() - - writeEnd := f.offset + int64(len(p)) - - // Grow slice if necessary - if writeEnd > int64(cap(f.node.data)) { - newCap := int64(cap(f.node.data)) * 2 - if newCap < writeEnd { - newCap = writeEnd - } - newData := make([]byte, len(f.node.data), newCap) + newSize := f.offset + int64(len(p)) + if newSize > int64(cap(f.node.data)) { + newData := make([]byte, newSize) copy(newData, f.node.data) f.node.data = newData + } else { + f.node.data = f.node.data[:newSize] } - - // Extend slice length if write goes past the end - if writeEnd > int64(len(f.node.data)) { - f.node.data = f.node.data[:writeEnd] - } - n = copy(f.node.data[f.offset:], p) f.offset += int64(n) - - // Update size only if the file has grown - if f.offset > atomic.LoadInt64(&f.node.size) { - atomic.StoreInt64(&f.node.size, f.offset) - } - f.node.modTime = time.Now() - + atomic.StoreInt64(&f.node.size, newSize) return n, nil } diff --git a/webdav/memlock.go b/webdav/memlock.go index 6b9ebd9..dabdd71 100644 --- a/webdav/memlock.go +++ b/webdav/memlock.go @@ -38,9 +38,8 @@ func NewMemLock() *MemLock { } // Close stops the cleanup goroutine. -func (l *MemLock) Close() error { +func (l *MemLock) Close() { close(l.stop) - return nil } func (l *MemLock) cleanup() { @@ -67,13 +66,6 @@ func (l *MemLock) Create(ctx context.Context, path string, info LockInfo) (strin l.mu.Lock() defer l.mu.Unlock() - // Check for conflicting locks - for _, v := range l.locks { - if v.path == path { - return "", os.ErrExist - } - } - token := make([]byte, 16) if _, err := rand.Read(token); err != nil { return "", err diff --git a/webdav/osfs.go b/webdav/osfs.go index 152927a..3d54a30 100644 --- a/webdav/osfs.go +++ b/webdav/osfs.go @@ -28,7 +28,7 @@ func NewOSFS(rootDir string) (*OSFS, error) { } func (fs *OSFS) resolve(name string) (string, error) { - if strings.Contains(name, "..") { + if filepath.IsAbs(name) || strings.Contains(name, "..") { return "", os.ErrPermission } diff --git a/webdav/webdav.go b/webdav/webdav.go index f255628..b438c24 100644 --- a/webdav/webdav.go +++ b/webdav/webdav.go @@ -588,8 +588,11 @@ func (h *Handler) handlePropfind(c *touka.Context) { } -func (h *Handler) createPropfindResponse(p string, info ObjectInfo, propfind Propfind) *Response { - fullPath := path.Join(h.Prefix, p) +func (h *Handler) createPropfindResponse(path string, info ObjectInfo, propfind Propfind) *Response { + fullPath := path + if h.Prefix != "/" { + fullPath = h.Prefix + path + } resp := &Response{ Href: []string{fullPath}, @@ -638,7 +641,10 @@ func (h *Handler) handleProppatch(c *touka.Context) { } func (h *Handler) stripPrefix(p string) string { - return strings.TrimPrefix(strings.TrimPrefix(p, h.Prefix), "/") + if h.Prefix == "/" { + return p + } + return strings.TrimPrefix(p, h.Prefix) } func (h *Handler) handleLock(c *touka.Context) {