package main import ( "fmt" "os" "path/filepath" "strings" "time" "tryon/gen" "tryon/oss" "github.com/jessevdk/go-flags" ) var GAppOption = &AppOption{} // short只能是一个字符 type AppOption struct { Model string `short:"m" long:"model" description:"model image"` Scale float64 `short:"c" long:"scale" description:"scale [2.0-8.0] default 5.0"` Shoes []string `short:"s" long:"shoes" description:"shoes images"` } func (o *AppOption) Parse() error { _, err := flags.NewParser(o, flags.Default|flags.IgnoreUnknown).Parse() return err } const ( LOCAL_MODEL_DIR = "models" LOCAL_SHOES_DIR = "shoes" OUTPUT_DIR = "output" ) func main() { os.Mkdir(OUTPUT_DIR, os.ModePerm) // 读取命令行参数 获取图片 // go run . -m xxx.png -s xx.png -s xx.png err := GAppOption.Parse() if err != nil { fmt.Println(err) return } // 适配/设置默认scale if GAppOption.Scale == 0 { GAppOption.Scale = 5.01 } if GAppOption.Scale < 2 { GAppOption.Scale = 2.01 } if GAppOption.Scale > 8 { GAppOption.Scale = 7.99 } // 获取shoes目录中图片 if len(GAppOption.Shoes) == 0 { files, err := getAllFilesInDir(LOCAL_SHOES_DIR) if err != nil { fmt.Println("获取shoes文件失败: ", err) return } GAppOption.Shoes = append(GAppOption.Shoes, files...) } if len(GAppOption.Model) == 0 || len(GAppOption.Shoes) == 0 { fmt.Println("缺失图片参数!") return } // 根据参数拼接图片所在当前完整路径 modelLocalImage := fmt.Sprintf("%s/%s", LOCAL_MODEL_DIR, GAppOption.Model) shoesLocalImages := []string{} fmt.Println(len(shoesLocalImages)) for _, shoe := range GAppOption.Shoes { shoesLocalImages = append(shoesLocalImages, fmt.Sprintf("%s/%s", LOCAL_SHOES_DIR, shoe)) } ossReslut, err := oss.UpladImages(modelLocalImage, shoesLocalImages) if err != nil { fmt.Println(err) return } fmt.Println("上传结果: ", ossReslut) if len(ossReslut.Model) == 0 || len(ossReslut.Shoes) == 0 { fmt.Println("上传图片出错") return } taskId, err := gen.Generate(ossReslut.Model, ossReslut.Shoes, GAppOption.Scale) fmt.Println("taskId: ", taskId) if err != nil { fmt.Println(err) return } // 延迟 fmt.Println("等待处理。。。") time.Sleep(3 * time.Second) // 这里成功返回url,失败返回""/status url, err := gen.GetReslut(taskId, OUTPUT_DIR) fmt.Println("url: ", url) if err != nil { fmt.Println(err) if len(url) == 0 { return } if url == "FAILED" { fmt.Println("后台处理失败status[FAILED]") return } // 重新查询 replay := 10 for { if replay == 0 { fmt.Printf("status: %s,任务处理中,请多等待...", url) return } time.Sleep(5 * time.Second) url, err = gen.GetReslut(taskId, OUTPUT_DIR) if err != nil { // 重试逻辑 if len(url) > 0 { replay-- } else { // 异常逻辑 fmt.Println(err) return } } else { // 成功跳出 break } } } fmt.Println("url---> : ", url) dist, err := gen.Download(url, OUTPUT_DIR) if err != nil { fmt.Println(err) return } fmt.Println("获取结果: ", dist) } // 获取某个目录下的所有文件路径,递归遍历,并生成相对路径 func getAllFilesInDir(baseDir string) ([]string, error) { var files []string // filepath.Walk 会递归遍历目录及子目录 err := filepath.Walk(baseDir, func(path string, info os.FileInfo, err error) error { // 遇到错误直接返回 if err != nil { return err } // 跳过目录,只处理文件 if !info.IsDir() { // 计算相对路径 relPath, err := filepath.Rel(baseDir, path) if err != nil { return err } // 将路径中的反斜杠替换为正斜杠 relPath = strings.ReplaceAll(relPath, "\\", "/") // 将相对路径加入文件列表 files = append(files, relPath) } return nil }) if err != nil { return nil, err } return files, nil }