Golang实现一个批量自动化执行树莓派指令的软件(4)上传

发布于:2024-04-25 ⋅ 阅读:(14) ⋅ 点赞:(0)

简介

话接上篇《Golang实现一个批量自动化执行树莓派指令的软件(3)下载》，本篇继续实现上传功能。

环境描述

运行环境: Windows, 基于Golang, 暂时没有使用什么不可跨平台接口, 理论上支持Linux/MacOS
目标终端: 树莓派 Debian OS (主要用它做测试)

实现

接口定义

// IUploader uploads a local file or directory tree to a remote host.
type IUploader interface {
	// Upload uploads synchronously and blocks until the transfer is done.
	//
	//	from: local path of the file or directory to upload
	//	to:   remote path under which the upload is saved
	Upload(from, to string) error

	// UploadWithCallback uploads either synchronously or asynchronously and
	// reports progress through callbacks.
	//
	//	from:             local path of the file or directory to upload
	//	to:               remote path under which the upload is saved
	//	process:          progress callback, invoked each time a file is uploaded, with
	//	                    from:     local path of the file just uploaded
	//	                    to:       remote path it was saved to
	//	                    num:      total number of files to upload
	//	                    uploaded: number of files uploaded so far
	//	finishedCallback: invoked when the upload completes, with its result
	//	background:       true to run asynchronously, false to block until done
	UploadWithCallback(from, to string,
		process func(from, to string, num, uploaded uint),
		finishedCallback func(err error), background bool) error
}

接口实现

package sshutil

import (
	"fmt"
	"github.com/pkg/sftp"
	"io"
	"os"
	"path"
	"time"
)

var oneTimeMaxSizeToWrite = 8 * 1024 // chunk size in bytes used by writeFile for each read/write

// IsDir reports whether path can be stat-ed (os.Stat, so symlinks are
// followed) and names a directory. Any Stat error yields false.
func IsDir(path string) bool {
	if info, err := os.Stat(path); err == nil {
		return info.IsDir()
	}
	return false
}

// IsFile reports whether path can be stat-ed (os.Stat, so symlinks are
// followed) and is not a directory. Any Stat error yields false.
func IsFile(path string) bool {
	info, err := os.Stat(path)
	return err == nil && !info.IsDir()
}

// uploader implements Upload/UploadWithCallback on top of an SFTP client.
type uploader struct {
	// client performs every remote operation (Create, MkdirAll).
	client *sftp.Client

	// Progress bookkeeping: total bytes and total files counted before the
	// transfer starts, and the number of files transferred so far.
	// NOTE(review): uploadSize is recorded but not read by any code in this file.
	uploadSize, uploadNumber, uploaded uint

	// started is true while an upload is running; Cancel only sends on
	// canceled in that state, and folder uploads poll canceled between entries.
	// NOTE(review): started is read and written from different goroutines
	// without synchronization when an upload runs in the background.
	started  bool
	canceled chan struct{}
}

// newUploader returns an uploader that performs its transfers through client.
// The returned error is currently always nil.
func newUploader(client *sftp.Client) (*uploader, error) {
	u := new(uploader)
	u.client = client
	u.canceled = make(chan struct{}) // unbuffered: Cancel hands the signal directly to the running upload
	return u, nil
}

// Upload uploads from (a local file or directory) to the remote path to,
// blocking until done and without progress or completion callbacks.
func (u *uploader) Upload(from, to string) error {
	var (
		noProgress func(from, to string, num, uploaded uint)
		noFinish   func(err error)
	)
	return u.upload(from, to, noProgress, noFinish)
}

// UploadWithCallback uploads from to to with optional progress (process) and
// completion (finishedCallback) callbacks. With background set, the upload
// runs on a new goroutine and nil is returned immediately; the outcome is
// then only observable through finishedCallback. Otherwise it blocks and
// returns the upload's error.
func (u *uploader) UploadWithCallback(from, to string,
	process func(from, to string, num, uploaded uint),
	finishedCallback func(err error), background bool) error {
	if background {
		go u.upload(from, to, process, finishedCallback)
		return nil
	}
	return u.upload(from, to, process, finishedCallback)
}

// Cancel asks a running upload to stop. It is a no-op when no upload is in
// progress. Otherwise it hands a signal over canceled and returns an error
// if nobody takes it within two seconds. Folder uploads only poll for
// cancellation between directory entries.
// NOTE(review): sending on canceled after Destroy has closed it panics.
func (u *uploader) Cancel() error {
	if !u.started {
		return nil
	}
	deadline := time.After(2 * time.Second)
	select {
	case u.canceled <- struct{}{}:
		return nil
	case <-deadline:
		return fmt.Errorf("time out waiting for cancel")
	}
}

// Destroy cancels any running upload, then closes the cancellation channel
// so that any later poll of it returns immediately. The error from Cancel
// is returned. Destroy must be called at most once: closing an
// already-closed channel panics.
func (u *uploader) Destroy() (err error) {
	err = u.Cancel()
	close(u.canceled)
	return err
}

// uploadFolderCount walks the local directory localPath recursively. It
// returns how many non-directory entries it contains (needUpload) and the
// sum of their sizes in bytes (size).
//
// Errors from reading a directory or from stat-ing an entry are returned
// with zero counts. They are not dropped, so an unreadable directory is
// never mistaken for an empty one, and a failed stat never leaves a nil
// FileInfo to dereference.
func (u *uploader) uploadFolderCount(localPath string) (needUpload, size uint, err error) {
	entries, err := os.ReadDir(localPath)
	if err != nil {
		return 0, 0, err
	}

	for _, entry := range entries {
		entryPath := path.Join(localPath, entry.Name())
		if entry.IsDir() {
			subCount, subSize, subErr := u.uploadFolderCount(entryPath)
			if subErr != nil {
				return 0, 0, subErr
			}
			needUpload += subCount
			size += subSize
			continue
		}

		// DirEntry.Info can fail, e.g. when the file was removed after ReadDir.
		info, infoErr := entry.Info()
		if infoErr != nil {
			return 0, 0, infoErr
		}
		needUpload++
		size += uint(info.Size())
	}
	return needUpload, size, nil
}

// uploadFileCount returns how many files an upload of localpath would
// transfer, and their total size in bytes:
//   - a regular file gives (1, file size);
//   - a directory gives its recursive totals;
//   - a missing path gives (0, 0, nil).
//
// Any other os.Stat failure (permission denied, a path component that is a
// regular file, ...) is returned as an error, because there is no FileInfo
// to read a size from in that case.
func (u *uploader) uploadFileCount(localpath string) (uint, uint, error) {
	info, err := os.Stat(localpath)
	switch {
	case os.IsNotExist(err):
		// Nothing to count; the missing path surfaces later, when the upload
		// tries to open it.
		return 0, 0, nil
	case err != nil:
		return 0, 0, err
	case info.IsDir():
		return u.uploadFolderCount(localpath)
	default:
		return 1, uint(info.Size()), nil
	}
}

// upload is the shared implementation behind Upload and UploadWithCallback.
//
// It marks the uploader as started, resets the progress counter, and counts
// the files and bytes under localPath. It then dispatches to uploadFolder
// when localPath is a directory and to uploadFile otherwise.
//
// finishedCallback, when non-nil, is invoked on a new goroutine with the
// result. Here that happens only when counting fails; otherwise
// uploadFolder/uploadFile invoke it.
func (u *uploader) upload(localPath, remotePath string,
	process func(from, to string, num, uploaded uint),
	finishedCallback func(err error)) (err error) {

	whenErrorCall := func(e error) error {
		if nil != finishedCallback {
			go finishedCallback(e)
		}
		return e
	}

	u.started = true
	defer func() {
		u.started = false
	}()

	// Progress is per run: without this reset, a second upload on the same
	// uploader would report more files uploaded than it has in total.
	u.uploaded = 0
	u.uploadNumber, u.uploadSize, err = u.uploadFileCount(localPath)
	if err != nil {
		return whenErrorCall(err)
	}

	if IsDir(localPath) {
		return u.uploadFolder(localPath, remotePath, process, finishedCallback)
	}
	return u.uploadFile(localPath, remotePath, process, finishedCallback)
}

// writeFile copies everything from reader to writer in chunks of at most
// oneTimeMaxSizeToWrite bytes.
//
// Only the n bytes actually returned by each Read are written. A short read
// therefore never leaks stale buffer contents into the destination, which
// would pad the output to a multiple of the chunk size. A Read error other
// than io.EOF, or any Write error, aborts the copy and is returned.
//
// A manual loop is used instead of io.CopyBuffer because the latter may
// bypass the chunk size via ReaderFrom/WriterTo fast paths.
func (u *uploader) writeFile(reader io.Reader, writer io.Writer) error {
	buffer := make([]byte, oneTimeMaxSizeToWrite)
	for {
		n, readErr := reader.Read(buffer)
		// io.Reader contract: consume the n > 0 bytes before looking at the error.
		if n > 0 {
			if _, err := writer.Write(buffer[:n]); err != nil {
				return err
			}
		}
		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			return readErr
		}
	}
}

// uploadFile copies the local file localPath into the remote directory
// remotePath, keeping the file's base name.
//
// On success it increments u.uploaded, reports progress through process,
// and calls finishedCallback(nil). On failure it calls finishedCallback(err)
// and returns err. Both callbacks are optional and run on their own
// goroutines, so they may fire after uploadFile returns.
func (u *uploader) uploadFile(localPath, remotePath string,
	process func(from, to string, num, uploaded uint),
	finishedCallback func(err error)) error {
	whenErrorCall := func(e error) error {
		if nil != finishedCallback {
			go finishedCallback(e)
		}
		return e
	}

	srcFile, err := os.Open(localPath)
	if err != nil {
		return whenErrorCall(err)
	}
	defer srcFile.Close()

	// Take the base name from the opened file, not path.Base(localPath).
	// path.Base only splits on '/', so a Windows path such as C:\data\a.txt
	// would otherwise become the entire remote file name.
	srcInfo, err := srcFile.Stat()
	if err != nil {
		return whenErrorCall(err)
	}
	remoteFileName := path.Join(remotePath, srcInfo.Name())

	dstFile, err := u.client.Create(remoteFileName)
	if err != nil {
		return whenErrorCall(err)
	}
	err = u.writeFile(srcFile, dstFile)
	// Close explicitly so an error from closing the remote file is reported
	// rather than silently dropped by a deferred Close.
	if closeErr := dstFile.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		return whenErrorCall(err)
	}

	u.uploaded++
	if nil != process {
		go process(localPath, remoteFileName, u.uploadNumber, u.uploaded)
	}
	return whenErrorCall(nil)
}

// uploadFolder mirrors the contents of the local directory localPath into
// remotePath. It creates remotePath (and parents), uploads each file into
// it, and recurses into subdirectories, which map to remotePath/<name>.
//
// Between entries it polls canceled without blocking and aborts with
// "user canceled" if a cancellation has been signalled. Nested calls get a
// nil finishedCallback, so only the outermost call reports completion. That
// report (nil on success) is delivered on a new goroutine.
func (u *uploader) uploadFolder(localPath, remotePath string,
	process func(from, to string, num, uploaded uint),
	finishedCallback func(err error)) (err error) {

	// finish hands result to finishedCallback (if set) and returns it.
	finish := func(result error) error {
		if finishedCallback != nil {
			go finishedCallback(result)
		}
		return result
	}

	if err = u.client.MkdirAll(remotePath); err != nil {
		return finish(err)
	}

	entries, err := os.ReadDir(localPath)
	if err != nil {
		return finish(err)
	}

	for _, entry := range entries {
		name := entry.Name()
		localChild := path.Join(localPath, name)

		select {
		case <-u.canceled:
			return finish(fmt.Errorf("user canceled"))
		default:
		}

		if entry.IsDir() {
			err = u.uploadFolder(localChild, path.Join(remotePath, name), process, nil)
		} else {
			err = u.uploadFile(localChild, remotePath, process, nil)
		}
		if err != nil {
			return finish(err)
		}
	}

	return finish(err)
}

测试用例

package sshutil

import (
	"fmt"
	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
	"sync"
	"testing"
	"time"
)

// uploaderTest bundles what an integration test needs: the SSH connection,
// the SFTP session opened over it, and the uploader under test.
type uploaderTest struct {
	sshClient  *ssh.Client  // underlying SSH connection
	sftpClient *sftp.Client // SFTP session running on sshClient
	uploader   *uploader    // system under test, built on sftpClient
}

// newUploaderTest dials the test device over SSH, opens an SFTP session on
// that connection, and wraps the session in an uploader. If dialing or
// opening SFTP fails, whatever was opened is closed and (nil, err) is
// returned. Otherwise the value is returned together with newUploader's error.
func newUploaderTest() (*uploaderTest, error) {
	config := &ssh.ClientConfig{
		User:            "pi",                                      // username
		Auth:            []ssh.AuthMethod{ssh.Password("a123456")}, // password
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),               // test-only: accepts any host key
		Timeout:         10 * time.Second,
	}
	sshClient, err := ssh.Dial("tcp", "192.168.3.2:22", config) // host:port of the test device
	if err != nil {
		fmt.Print(err)
		return nil, err
	}

	ut := &uploaderTest{sshClient: sshClient}
	if ut.sftpClient, err = sftp.NewClient(sshClient); err != nil {
		ut.destroy()
		return nil, err
	}
	ut.uploader, err = newUploader(ut.sftpClient)
	return ut, err
}

// destroy closes the SFTP session and then the SSH connection beneath it,
// clearing each field so that a second call is a no-op. Close errors are
// ignored.
func (d *uploaderTest) destroy() {
	if c := d.sftpClient; c != nil {
		d.sftpClient = nil
		c.Close()
	}
	if c := d.sshClient; c != nil {
		d.sshClient = nil
		c.Close()
	}
}

// TestUploader_Upload is an integration test: it uploads the local
// ./download directory to /home/pi/upload/ on the test device. It is
// skipped when the device cannot be reached and fails if the upload errors.
func TestUploader_Upload(t *testing.T) {
	dTest, err := newUploaderTest()
	if err != nil {
		t.Skipf("fail to new uploader test: %v", err)
	}
	defer dTest.destroy()

	if err = dTest.uploader.Upload("./download", "/home/pi/upload/"); err != nil {
		t.Fatalf("upload: %v", err)
	}
}

// TestUploader_UploadWithCallback is an integration test for the synchronous
// callback path. It prints progress and fails if the upload or the finished
// callback reports an error. It is skipped when the device is unreachable.
func TestUploader_UploadWithCallback(t *testing.T) {
	dTest, err := newUploaderTest()
	if err != nil {
		t.Skipf("fail to new uploader test: %v", err)
	}
	defer dTest.destroy()

	// The uploader invokes finishedCallback on its own goroutine even for
	// synchronous uploads, so wait for it explicitly instead of sleeping.
	finished := make(chan error, 1)
	err = dTest.uploader.UploadWithCallback("./download", "/home/pi/upload1/", func(from, to string, num, uploaded uint) {
		fmt.Println(from, to, num, uploaded)
	}, func(err error) {
		finished <- err
	}, false)
	if err != nil {
		t.Fatalf("upload: %v", err)
	}

	select {
	case err = <-finished:
		if err != nil {
			t.Fatalf("finished callback reported: %v", err)
		}
		fmt.Println("finished!!!")
	case <-time.After(10 * time.Second):
		t.Fatal("finished callback was not called")
	}
}

// TestUploader_UploadWithCallbackAsync is an integration test for the
// background path. It blocks until the finished callback fires, then fails
// if that callback reported an error. It is skipped when the device is
// unreachable.
func TestUploader_UploadWithCallbackAsync(t *testing.T) {
	var waiter sync.WaitGroup
	dTest, err := newUploaderTest()
	if err != nil {
		t.Skipf("fail to new uploader test: %v", err)
	}
	defer dTest.destroy()

	// Written by the callback before waiter.Done and read after waiter.Wait,
	// so the WaitGroup orders the access.
	var finishedErr error
	waiter.Add(1)
	err = dTest.uploader.UploadWithCallback("./download", "/home/pi/upload2/", func(from, to string, num, uploaded uint) {
		fmt.Println(from, to, num, uploaded)
	}, func(err error) {
		defer waiter.Done()
		finishedErr = err
		fmt.Println("finished!!!")
	}, true)
	if err != nil {
		t.Fatalf("upload: %v", err)
	}

	fmt.Println("waiting finish...")
	waiter.Wait()
	if finishedErr != nil {
		t.Fatalf("finished callback reported: %v", finishedErr)
	}
	fmt.Println("had finished!")
}

代码源

https://gitee.com/grayhsu/ssh_remote_access

其他

参考