package sftp import ( "context" "io" "os" "path" "path/filepath" "strings" "sync" "syscall" "github.com/pkg/errors" ) // MaxFilelist is the max number of files to return in a readdir batch. var MaxFilelist int64 = 100 // Request contains the data and state for the incoming service request. type Request struct { // Get, Put, Setstat, Stat, Rename, Remove // Rmdir, Mkdir, List, Readlink, Link, Symlink Method string Filepath string Flags uint32 Attrs []byte // convert to sub-struct Target string // for renames and sym-links handle string // reader/writer/readdir from handlers state state // context lasts duration of request ctx context.Context cancelCtx context.CancelFunc } type state struct { *sync.RWMutex writerAt io.WriterAt readerAt io.ReaderAt writerReaderAt WriterAtReaderAt listerAt ListerAt lsoffset int64 } // New Request initialized based on packet data func requestFromPacket(ctx context.Context, pkt hasPath) *Request { method := requestMethod(pkt) request := NewRequest(method, pkt.getPath()) request.ctx, request.cancelCtx = context.WithCancel(ctx) switch p := pkt.(type) { case *sshFxpOpenPacket: request.Flags = p.Pflags case *sshFxpSetstatPacket: request.Flags = p.Flags request.Attrs = p.Attrs.([]byte) case *sshFxpRenamePacket: request.Target = cleanPath(p.Newpath) case *sshFxpSymlinkPacket: // NOTE: given a POSIX compliant signature: symlink(target, linkpath string) // this makes Request.Target the linkpath, and Request.Filepath the target. request.Target = cleanPath(p.Linkpath) case *sshFxpExtendedPacketHardlink: request.Target = cleanPath(p.Newpath) } return request } // NewRequest creates a new Request object. func NewRequest(method, path string) *Request { return &Request{Method: method, Filepath: cleanPath(path), state: state{RWMutex: new(sync.RWMutex)}} } // shallow copy of existing request func (r *Request) copy() *Request { r.state.Lock() defer r.state.Unlock() r2 := new(Request) *r2 = *r return r2 } // Context returns the request's context. To change the context, // use WithContext. // // The returned context is always non-nil; it defaults to the // background context. // // For incoming server requests, the context is canceled when the // request is complete or the client's connection closes. func (r *Request) Context() context.Context { if r.ctx != nil { return r.ctx } return context.Background() } // WithContext returns a copy of r with its context changed to ctx. // The provided ctx must be non-nil. func (r *Request) WithContext(ctx context.Context) *Request { if ctx == nil { panic("nil context") } r2 := r.copy() r2.ctx = ctx r2.cancelCtx = nil return r2 } // Returns current offset for file list func (r *Request) lsNext() int64 { r.state.RLock() defer r.state.RUnlock() return r.state.lsoffset } // Increases next offset func (r *Request) lsInc(offset int64) { r.state.Lock() defer r.state.Unlock() r.state.lsoffset = r.state.lsoffset + offset } // manage file read/write state func (r *Request) setListerState(la ListerAt) { r.state.Lock() defer r.state.Unlock() r.state.listerAt = la } func (r *Request) getLister() ListerAt { r.state.RLock() defer r.state.RUnlock() return r.state.listerAt } // Close reader/writer if possible func (r *Request) close() error { defer func() { if r.cancelCtx != nil { r.cancelCtx() } }() r.state.RLock() wr := r.state.writerAt rd := r.state.readerAt rw := r.state.writerReaderAt r.state.RUnlock() var err error // Close errors on a Writer are far more likely to be the important one. // As they can be information that there was a loss of data. if c, ok := wr.(io.Closer); ok { if err2 := c.Close(); err == nil { // update error if it is still nil err = err2 } } if c, ok := rw.(io.Closer); ok { if err2 := c.Close(); err == nil { // update error if it is still nil err = err2 r.state.writerReaderAt = nil } } if c, ok := rd.(io.Closer); ok { if err2 := c.Close(); err == nil { // update error if it is still nil err = err2 } } return err } // Notify transfer error if any func (r *Request) transferError(err error) { if err == nil { return } r.state.RLock() wr := r.state.writerAt rd := r.state.readerAt rw := r.state.writerReaderAt r.state.RUnlock() if t, ok := wr.(TransferError); ok { t.TransferError(err) } if t, ok := rw.(TransferError); ok { t.TransferError(err) } if t, ok := rd.(TransferError); ok { t.TransferError(err) } } // called from worker to handle packet/request func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { switch r.Method { case "Get": return fileget(handlers.FileGet, r, pkt, alloc, orderID) case "Put": return fileput(handlers.FilePut, r, pkt, alloc, orderID) case "Open": return fileputget(handlers.FilePut, r, pkt, alloc, orderID) case "Setstat", "Rename", "Rmdir", "Mkdir", "Link", "Symlink", "Remove", "PosixRename", "StatVFS": return filecmd(handlers.FileCmd, r, pkt) case "List": return filelist(handlers.FileList, r, pkt) case "Stat", "Lstat", "Readlink": return filestat(handlers.FileList, r, pkt) default: return statusFromError(pkt.id(), errors.Errorf("unexpected method: %s", r.Method)) } } // Additional initialization for Open packets func (r *Request) open(h Handlers, pkt requestPacket) responsePacket { flags := r.Pflags() id := pkt.id() switch { case flags.Write, flags.Append, flags.Creat, flags.Trunc: if flags.Read { if openFileWriter, ok := h.FilePut.(OpenFileWriter); ok { r.Method = "Open" rw, err := openFileWriter.OpenFile(r) if err != nil { return statusFromError(id, err) } r.state.writerReaderAt = rw return &sshFxpHandlePacket{ID: id, Handle: r.handle} } } r.Method = "Put" wr, err := h.FilePut.Filewrite(r) if err != nil { return statusFromError(id, err) } r.state.writerAt = wr case flags.Read: r.Method = "Get" rd, err := h.FileGet.Fileread(r) if err != nil { return statusFromError(id, err) } r.state.readerAt = rd default: return statusFromError(id, errors.New("bad file flags")) } return &sshFxpHandlePacket{ID: id, Handle: r.handle} } func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket { r.Method = "List" la, err := h.FileList.Filelist(r) if err != nil { return statusFromError(pkt.id(), wrapPathError(r.Filepath, err)) } r.state.listerAt = la return &sshFxpHandlePacket{ID: pkt.id(), Handle: r.handle} } // wrap FileReader handler func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { r.state.RLock() reader := r.state.readerAt r.state.RUnlock() if reader == nil { return statusFromError(pkt.id(), errors.New("unexpected read packet")) } data, offset, _ := packetData(pkt, alloc, orderID) n, err := reader.ReadAt(data, offset) // only return EOF error if no data left to read if err != nil && (err != io.EOF || n == 0) { return statusFromError(pkt.id(), err) } return &sshFxpDataPacket{ ID: pkt.id(), Length: uint32(n), Data: data[:n], } } // wrap FileWriter handler func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { r.state.RLock() writer := r.state.writerAt r.state.RUnlock() if writer == nil { return statusFromError(pkt.id(), errors.New("unexpected write packet")) } data, offset, _ := packetData(pkt, alloc, orderID) _, err := writer.WriteAt(data, offset) return statusFromError(pkt.id(), err) } // wrap OpenFileWriter handler func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { r.state.RLock() writerReader := r.state.writerReaderAt r.state.RUnlock() if writerReader == nil { return statusFromError(pkt.id(), errors.New("unexpected write and read packet")) } switch p := pkt.(type) { case *sshFxpReadPacket: data, offset := p.getDataSlice(alloc, orderID), int64(p.Offset) n, err := writerReader.ReadAt(data, offset) // only return EOF error if no data left to read if err != nil && (err != io.EOF || n == 0) { return statusFromError(pkt.id(), err) } return &sshFxpDataPacket{ ID: pkt.id(), Length: uint32(n), Data: data[:n], } case *sshFxpWritePacket: data, offset := p.Data, int64(p.Offset) _, err := writerReader.WriteAt(data, offset) return statusFromError(pkt.id(), err) default: return statusFromError(pkt.id(), errors.New("unexpected packet type for read or write")) } } // file data for additional read/write packets func packetData(p requestPacket, alloc *allocator, orderID uint32) (data []byte, offset int64, length uint32) { switch p := p.(type) { case *sshFxpReadPacket: return p.getDataSlice(alloc, orderID), int64(p.Offset), p.Len case *sshFxpWritePacket: return p.Data, int64(p.Offset), p.Length } return } // wrap FileCmder handler func filecmd(h FileCmder, r *Request, pkt requestPacket) responsePacket { switch p := pkt.(type) { case *sshFxpFsetstatPacket: r.Flags = p.Flags r.Attrs = p.Attrs.([]byte) } if r.Method == "PosixRename" { if posixRenamer, ok := h.(PosixRenameFileCmder); ok { err := posixRenamer.PosixRename(r) return statusFromError(pkt.id(), err) } // PosixRenameFileCmder not implemented handle this request as a Rename r.Method = "Rename" err := h.Filecmd(r) return statusFromError(pkt.id(), err) } if r.Method == "StatVFS" { if statVFSCmdr, ok := h.(StatVFSFileCmder); ok { stat, err := statVFSCmdr.StatVFS(r) if err != nil { return statusFromError(pkt.id(), err) } stat.ID = pkt.id() return stat } return statusFromError(pkt.id(), ErrSSHFxOpUnsupported) } err := h.Filecmd(r) return statusFromError(pkt.id(), err) } // wrap FileLister handler func filelist(h FileLister, r *Request, pkt requestPacket) responsePacket { var err error lister := r.getLister() if lister == nil { return statusFromError(pkt.id(), errors.New("unexpected dir packet")) } offset := r.lsNext() finfo := make([]os.FileInfo, MaxFilelist) n, err := lister.ListAt(finfo, offset) r.lsInc(int64(n)) // ignore EOF as we only return it when there are no results finfo = finfo[:n] // avoid need for nil tests below switch r.Method { case "List": if err != nil && err != io.EOF { return statusFromError(pkt.id(), err) } if err == io.EOF && n == 0 { return statusFromError(pkt.id(), io.EOF) } dirname := filepath.ToSlash(path.Base(r.Filepath)) ret := &sshFxpNamePacket{ID: pkt.id()} for _, fi := range finfo { ret.NameAttrs = append(ret.NameAttrs, &sshFxpNameAttr{ Name: fi.Name(), LongName: runLs(dirname, fi), Attrs: []interface{}{fi}, }) } return ret default: err = errors.Errorf("unexpected method: %s", r.Method) return statusFromError(pkt.id(), err) } } func filestat(h FileLister, r *Request, pkt requestPacket) responsePacket { var lister ListerAt var err error if r.Method == "Lstat" { if lstatFileLister, ok := h.(LstatFileLister); ok { lister, err = lstatFileLister.Lstat(r) } else { // LstatFileLister not implemented handle this request as a Stat r.Method = "Stat" lister, err = h.Filelist(r) } } else { lister, err = h.Filelist(r) } if err != nil { return statusFromError(pkt.id(), err) } finfo := make([]os.FileInfo, 1) n, err := lister.ListAt(finfo, 0) finfo = finfo[:n] // avoid need for nil tests below switch r.Method { case "Stat", "Lstat": if err != nil && err != io.EOF { return statusFromError(pkt.id(), err) } if n == 0 { err = &os.PathError{Op: strings.ToLower(r.Method), Path: r.Filepath, Err: syscall.ENOENT} return statusFromError(pkt.id(), err) } return &sshFxpStatResponse{ ID: pkt.id(), info: finfo[0], } case "Readlink": if err != nil && err != io.EOF { return statusFromError(pkt.id(), err) } if n == 0 { err = &os.PathError{Op: "readlink", Path: r.Filepath, Err: syscall.ENOENT} return statusFromError(pkt.id(), err) } filename := finfo[0].Name() return &sshFxpNamePacket{ ID: pkt.id(), NameAttrs: []*sshFxpNameAttr{ { Name: filename, LongName: filename, Attrs: emptyFileStat, }, }, } default: err = errors.Errorf("unexpected method: %s", r.Method) return statusFromError(pkt.id(), err) } } // init attributes of request object from packet data func requestMethod(p requestPacket) (method string) { switch p.(type) { case *sshFxpReadPacket, *sshFxpWritePacket, *sshFxpOpenPacket: // set in open() above case *sshFxpOpendirPacket, *sshFxpReaddirPacket: // set in opendir() above case *sshFxpSetstatPacket, *sshFxpFsetstatPacket: method = "Setstat" case *sshFxpRenamePacket: method = "Rename" case *sshFxpSymlinkPacket: method = "Symlink" case *sshFxpRemovePacket: method = "Remove" case *sshFxpStatPacket, *sshFxpFstatPacket: method = "Stat" case *sshFxpLstatPacket: method = "Lstat" case *sshFxpRmdirPacket: method = "Rmdir" case *sshFxpReadlinkPacket: method = "Readlink" case *sshFxpMkdirPacket: method = "Mkdir" case *sshFxpExtendedPacketHardlink: method = "Link" } return method }