diff --git a/webdav/memfs.go b/webdav/memfs.go index c1751ea..837fd9d 100644 --- a/webdav/memfs.go +++ b/webdav/memfs.go @@ -36,7 +36,13 @@ func (fs *MemFS) findNode(path string) (*memNode, error) { current := fs.root parts := strings.Split(path, "/") for _, part := range parts { - if part == "" { + if part == "" || part == "." { + continue + } + if part == ".." { + if current.parent != nil { + current = current.parent + } continue } if current.children == nil { @@ -105,6 +111,7 @@ func (fs *MemFS) OpenFile(ctx context.Context, name string, flag int, perm os.Fi if flag&os.O_TRUNC != 0 { node.data = nil + node.size = 0 } return &memFile{ @@ -234,14 +241,21 @@ func (f *memFile) Write(p []byte) (n int, err error) { func (f *memFile) Seek(offset int64, whence int) (int64, error) { f.fs.mu.Lock() defer f.fs.mu.Unlock() + var newOffset int64 switch whence { - case 0: - f.offset = offset - case 1: - f.offset += offset - case 2: - f.offset = int64(len(f.node.data)) + offset + case io.SeekStart: + newOffset = offset + case io.SeekCurrent: + newOffset = f.offset + offset + case io.SeekEnd: + newOffset = f.node.size + offset + default: + return 0, os.ErrInvalid } + if newOffset < 0 { + return 0, os.ErrInvalid + } + f.offset = newOffset return f.offset, nil } diff --git a/webdav/memlock.go b/webdav/memlock.go index 7c1074f..276b798 100644 --- a/webdav/memlock.go +++ b/webdav/memlock.go @@ -28,9 +28,24 @@ type lock struct { // NewMemLock creates a new in-memory lock system. func NewMemLock() *MemLock { - return &MemLock{ + l := &MemLock{ locks: make(map[string]*lock), } + go l.cleanup() + return l +} + +func (l *MemLock) cleanup() { + for { + time.Sleep(1 * time.Minute) + l.mu.Lock() + for token, lock := range l.locks { + if time.Now().After(lock.expires) { + delete(l.locks, token) + } + } + l.mu.Unlock() + } } // Create creates a new lock. @@ -39,7 +54,9 @@ func (l *MemLock) Create(ctx context.Context, path string, info LockInfo) (strin defer l.mu.Unlock() token := make([]byte, 16) - rand.Read(token) + if _, err := rand.Read(token); err != nil { + return "", err + } tokenStr := hex.EncodeToString(token) l.locks[tokenStr] = &lock{ diff --git a/webdav/osfs.go b/webdav/osfs.go index 6a68108..a4dfb4f 100644 --- a/webdav/osfs.go +++ b/webdav/osfs.go @@ -26,8 +26,16 @@ func NewOSFS(rootDir string) (*OSFS, error) { } func (fs *OSFS) resolve(name string) (string, error) { + if filepath.IsAbs(name) { + return "", os.ErrPermission + } path := filepath.Join(fs.RootDir, name) - if !strings.HasPrefix(path, fs.RootDir) { + + rel, err := filepath.Rel(fs.RootDir, path) + if err != nil { + return "", err + } + if strings.HasPrefix(rel, "..") { return "", os.ErrPermission } return path, nil diff --git a/webdav/webdav.go b/webdav/webdav.go index 07accf7..e6f4d5f 100644 --- a/webdav/webdav.go +++ b/webdav/webdav.go @@ -585,11 +585,11 @@ func (h *Handler) handleProppatch(c *touka.Context) { c.Status(http.StatusNotImplemented) } -func (h *Handler) stripPrefix(path string) string { +func (h *Handler) stripPrefix(p string) string { if h.Prefix == "/" { - return path + return p } - return "/" + strings.TrimPrefix(path, h.Prefix) + return strings.TrimPrefix(p, h.Prefix) } func (h *Handler) handleLock(c *touka.Context) { @@ -599,7 +599,15 @@ func (h *Handler) handleLock(c *touka.Context) { } path, _ := c.Get("webdav_path") - token := c.GetReqHeader("If") + tokenHeader := c.GetReqHeader("If") + var token string + if tokenHeader != "" { + // Basic parsing for + if strings.HasPrefix(tokenHeader, "(<") && strings.HasSuffix(tokenHeader, ">)") { + token = strings.TrimPrefix(tokenHeader, "(<") + token = strings.TrimSuffix(token, ">)") + } + } // Refresh lock if token != "" { @@ -666,7 +674,7 @@ func parseTimeout(timeoutStr string) (time.Duration, error) { return seconds, nil } } - return 0, nil + return 0, os.ErrInvalid } func (h *Handler) handleUnlock(c *touka.Context) { @@ -675,12 +683,16 @@ func (h *Handler) handleUnlock(c *touka.Context) { return } - token := c.GetReqHeader("Lock-Token") - if token == "" { + tokenHeader := c.GetReqHeader("Lock-Token") + if tokenHeader == "" { c.Status(http.StatusBadRequest) return } + // Basic parsing for + token := strings.TrimPrefix(tokenHeader, "<") + token = strings.TrimSuffix(token, ">") + if err := h.LockSystem.Unlock(c.Context(), token); err != nil { c.Status(http.StatusConflict) return