123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
- package main
- import (
- "fmt"
- "os"
- "path/filepath"
- "strings"
- "time"
- "tryon/gen"
- "tryon/oss"
- "github.com/jessevdk/go-flags"
- )
// GAppOption is the process-wide option set; main populates it by calling Parse.
var GAppOption = &AppOption{}

// AppOption declares the command-line options, described to go-flags through
// the struct tags below.
// Note: a `short` option name can only be a single character.
type AppOption struct {
	// Model is the model image file name; main resolves it against LOCAL_MODEL_DIR.
	Model string   `short:"m" long:"model" description:"model image"`
	// Scale is forwarded to gen.Generate. main replaces 0 (not given) with 5.01,
	// values below 2 with 2.01 and values above 8 with 7.99. NOTE(review): the
	// help text advertises "default 5.0", but the value actually used is 5.01.
	Scale float64  `short:"c" long:"scale" description:"scale [2.0-8.0] default 5.0"`
	// Shoes lists shoe image file names (repeat -s for several); main resolves
	// them against LOCAL_SHOES_DIR and, when none are given, uses every file
	// found under that directory.
	Shoes []string `short:"s" long:"shoes" description:"shoes images"`
}
- func (o *AppOption) Parse() error {
- _, err := flags.NewParser(o, flags.Default|flags.IgnoreUnknown).Parse()
- return err
- }
// LOCAL_MODEL_DIR is the directory (relative to the working directory) that
// the -m model image name is resolved against.
const LOCAL_MODEL_DIR = "models"

// LOCAL_SHOES_DIR is the directory that -s shoe image names are resolved
// against; when no -s flag is given, every file under it is used.
const LOCAL_SHOES_DIR = "shoes"

// OUTPUT_DIR is the directory that generation results are written into.
const OUTPUT_DIR = "output"
- func main() {
- os.Mkdir(OUTPUT_DIR, os.ModePerm)
- // 读取命令行参数 获取图片
- // go run . -m xxx.png -s xx.png -s xx.png
- err := GAppOption.Parse()
- if err != nil {
- fmt.Println(err)
- return
- }
- // 适配/设置默认scale
- if GAppOption.Scale == 0 {
- GAppOption.Scale = 5.01
- }
- if GAppOption.Scale < 2 {
- GAppOption.Scale = 2.01
- }
- if GAppOption.Scale > 8 {
- GAppOption.Scale = 7.99
- }
- // 获取shoes目录中图片
- if len(GAppOption.Shoes) == 0 {
- files, err := getAllFilesInDir(LOCAL_SHOES_DIR)
- if err != nil {
- fmt.Println("获取shoes文件失败: ", err)
- return
- }
- GAppOption.Shoes = append(GAppOption.Shoes, files...)
- }
- if len(GAppOption.Model) == 0 || len(GAppOption.Shoes) == 0 {
- fmt.Println("缺失图片参数!")
- return
- }
- // 根据参数拼接图片所在当前完整路径
- modelLocalImage := fmt.Sprintf("%s/%s", LOCAL_MODEL_DIR, GAppOption.Model)
- shoesLocalImages := []string{}
- fmt.Println(len(shoesLocalImages))
- for _, shoe := range GAppOption.Shoes {
- shoesLocalImages = append(shoesLocalImages, fmt.Sprintf("%s/%s", LOCAL_SHOES_DIR, shoe))
- }
- ossReslut, err := oss.UpladImages(modelLocalImage, shoesLocalImages)
- if err != nil {
- fmt.Println(err)
- return
- }
- fmt.Println("上传结果: ", ossReslut)
- if len(ossReslut.Model) == 0 || len(ossReslut.Shoes) == 0 {
- fmt.Println("上传图片出错")
- return
- }
- taskId, err := gen.Generate(ossReslut.Model, ossReslut.Shoes, GAppOption.Scale)
- fmt.Println("taskId: ", taskId)
- if err != nil {
- fmt.Println(err)
- return
- }
- // 延迟
- fmt.Println("等待处理。。。")
- time.Sleep(3 * time.Second)
- // 这里成功返回url,失败返回""/status
- url, err := gen.GetReslut(taskId, OUTPUT_DIR)
- fmt.Println("url: ", url)
- if err != nil {
- fmt.Println(err)
- if len(url) == 0 {
- return
- }
- if url == "FAILED" {
- fmt.Println("后台处理失败status[FAILED]")
- return
- }
- // 重新查询
- replay := 10
- for {
- if replay == 0 {
- fmt.Printf("status: %s,任务处理中,请多等待...", url)
- return
- }
- time.Sleep(5 * time.Second)
- url, err = gen.GetReslut(taskId, OUTPUT_DIR)
- if err != nil {
- // 重试逻辑
- if len(url) > 0 {
- replay--
- } else {
- // 异常逻辑
- fmt.Println(err)
- return
- }
- } else {
- // 成功跳出
- break
- }
- }
- }
- fmt.Println("url---> : ", url)
- dist, err := gen.Download(url, OUTPUT_DIR)
- if err != nil {
- fmt.Println(err)
- return
- }
- fmt.Println("获取结果: ", dist)
- }
// getAllFilesInDir recursively walks baseDir and returns the path of every
// non-directory entry, relative to baseDir, with backslashes rewritten to "/"
// so the result looks the same on every OS. Directories themselves are not
// listed, and symbolic links are reported as entries rather than followed
// (filepath.Walk semantics).
//
// The walk stops at the first error (including a missing baseDir); that
// error is returned together with a nil slice. An empty tree yields a nil
// slice and a nil error.
func getAllFilesInDir(baseDir string) ([]string, error) {
	var collected []string
	walkErr := filepath.Walk(baseDir, func(p string, fi os.FileInfo, visitErr error) error {
		switch {
		case visitErr != nil:
			// Abort the walk on the first error reported for any entry.
			return visitErr
		case fi.IsDir():
			// Only files are collected; keep descending into directories.
			return nil
		}
		rel, relErr := filepath.Rel(baseDir, p)
		if relErr != nil {
			return relErr
		}
		collected = append(collected, strings.ReplaceAll(rel, `\`, "/"))
		return nil
	})
	if walkErr != nil {
		return nil, walkErr
	}
	return collected, nil
}
|