|
@@ -1,183 +1,183 @@
|
|
|
-package main
|
|
|
-
|
|
|
-import (
|
|
|
- "fmt"
|
|
|
- "os"
|
|
|
- "path/filepath"
|
|
|
- "strings"
|
|
|
- "time"
|
|
|
- "tryon/gen"
|
|
|
- "tryon/oss"
|
|
|
-
|
|
|
- "github.com/jessevdk/go-flags"
|
|
|
-)
|
|
|
-
|
|
|
-var GAppOption = &AppOption{}
|
|
|
-
|
|
|
-
|
|
|
-type AppOption struct {
|
|
|
- Model string `short:"m" long:"model" description:"model image"`
|
|
|
- Scale float64 `short:"c" long:"scale" description:"scale [2.0-8.0] default 5.0"`
|
|
|
- Shoes []string `short:"s" long:"shoes" description:"shoes images"`
|
|
|
-}
|
|
|
-
|
|
|
-func (o *AppOption) Parse() error {
|
|
|
- _, err := flags.NewParser(o, flags.Default|flags.IgnoreUnknown).Parse()
|
|
|
- return err
|
|
|
-
|
|
|
-}
|
|
|
-
|
|
|
-const (
|
|
|
- LOCAL_MODEL_DIR = "models"
|
|
|
- LOCAL_SHOES_DIR = "shoes"
|
|
|
- OUTPUT_DIR = "output"
|
|
|
-)
|
|
|
-
|
|
|
-func main() {
|
|
|
- os.Mkdir(OUTPUT_DIR, os.ModePerm)
|
|
|
-
|
|
|
-
|
|
|
- err := GAppOption.Parse()
|
|
|
- if err != nil {
|
|
|
- fmt.Println(err)
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
-
|
|
|
- if GAppOption.Scale == 0 {
|
|
|
- GAppOption.Scale = 5.01
|
|
|
- }
|
|
|
- if GAppOption.Scale < 2 {
|
|
|
- GAppOption.Scale = 2.01
|
|
|
- }
|
|
|
- if GAppOption.Scale > 8 {
|
|
|
- GAppOption.Scale = 7.99
|
|
|
- }
|
|
|
-
|
|
|
-
|
|
|
- if len(GAppOption.Shoes) == 0 {
|
|
|
- files, err := getAllFilesInDir(LOCAL_SHOES_DIR)
|
|
|
- if err != nil {
|
|
|
- fmt.Println("获取shoes文件失败: ", err)
|
|
|
- return
|
|
|
- }
|
|
|
- GAppOption.Shoes = append(GAppOption.Shoes, files...)
|
|
|
- }
|
|
|
-
|
|
|
- if len(GAppOption.Model) == 0 || len(GAppOption.Shoes) == 0 {
|
|
|
- fmt.Println("缺失图片参数!")
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- modelLocalImage := fmt.Sprintf("%s/%s", LOCAL_MODEL_DIR, GAppOption.Model)
|
|
|
- shoesLocalImages := []string{}
|
|
|
- fmt.Println(len(shoesLocalImages))
|
|
|
- for _, shoe := range GAppOption.Shoes {
|
|
|
- shoesLocalImages = append(shoesLocalImages, fmt.Sprintf("%s/%s", LOCAL_SHOES_DIR, shoe))
|
|
|
- }
|
|
|
- ossReslut, err := oss.UpladImages(modelLocalImage, shoesLocalImages)
|
|
|
- if err != nil {
|
|
|
- fmt.Println(err)
|
|
|
- return
|
|
|
- }
|
|
|
- fmt.Println("上传结果: ", ossReslut)
|
|
|
- if len(ossReslut.Model) == 0 || len(ossReslut.Shoes) == 0 {
|
|
|
- fmt.Println("上传图片出错")
|
|
|
- return
|
|
|
- }
|
|
|
- taskId, err := gen.Generate(ossReslut.Model, ossReslut.Shoes, GAppOption.Scale)
|
|
|
- fmt.Println("taskId: ", taskId)
|
|
|
- if err != nil {
|
|
|
- fmt.Println(err)
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
-
|
|
|
- fmt.Println("等待处理。。。")
|
|
|
- time.Sleep(3 * time.Second)
|
|
|
-
|
|
|
-
|
|
|
- url, err := gen.GetReslut(taskId, OUTPUT_DIR)
|
|
|
- fmt.Println("url: ", url)
|
|
|
- if err != nil {
|
|
|
- fmt.Println(err)
|
|
|
- if len(url) == 0 {
|
|
|
- return
|
|
|
- }
|
|
|
- if url == "FAILED" {
|
|
|
- fmt.Println("后台处理失败status[FAILED]")
|
|
|
- return
|
|
|
-
|
|
|
- }
|
|
|
-
|
|
|
- replay := 10
|
|
|
- for {
|
|
|
- if replay == 0 {
|
|
|
- fmt.Printf("status: %s,任务处理中,请多等待...", url)
|
|
|
- return
|
|
|
- }
|
|
|
- time.Sleep(5 * time.Second)
|
|
|
- url, err = gen.GetReslut(taskId, OUTPUT_DIR)
|
|
|
- if err != nil {
|
|
|
-
|
|
|
- if len(url) > 0 {
|
|
|
- if url == "FAILED" {
|
|
|
- fmt.Println("后台处理失败status[FAILED]")
|
|
|
- return
|
|
|
- }
|
|
|
- replay--
|
|
|
- } else {
|
|
|
-
|
|
|
- fmt.Println(err)
|
|
|
- return
|
|
|
- }
|
|
|
- } else {
|
|
|
-
|
|
|
- break
|
|
|
- }
|
|
|
-
|
|
|
- }
|
|
|
- }
|
|
|
- fmt.Println("url---> : ", url)
|
|
|
-
|
|
|
- dist, err := gen.Download(url, OUTPUT_DIR)
|
|
|
- if err != nil {
|
|
|
- fmt.Println(err)
|
|
|
- return
|
|
|
- }
|
|
|
- fmt.Println("获取结果: ", dist)
|
|
|
-
|
|
|
-}
|
|
|
-
|
|
|
-
|
|
|
-func getAllFilesInDir(baseDir string) ([]string, error) {
|
|
|
- var files []string
|
|
|
-
|
|
|
-
|
|
|
- err := filepath.Walk(baseDir, func(path string, info os.FileInfo, err error) error {
|
|
|
-
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
-
|
|
|
-
|
|
|
- if !info.IsDir() {
|
|
|
-
|
|
|
- relPath, err := filepath.Rel(baseDir, path)
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
-
|
|
|
-
|
|
|
- relPath = strings.ReplaceAll(relPath, "\\", "/")
|
|
|
-
|
|
|
-
|
|
|
- files = append(files, relPath)
|
|
|
- }
|
|
|
- return nil
|
|
|
- })
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- return files, nil
|
|
|
-}
|
|
|
+package main
|
|
|
+
|
|
|
+import (
|
|
|
+ "fmt"
|
|
|
+ "os"
|
|
|
+ "path/filepath"
|
|
|
+ "strings"
|
|
|
+ "time"
|
|
|
+ "tryon/gen"
|
|
|
+ "tryon/oss"
|
|
|
+
|
|
|
+ "github.com/jessevdk/go-flags"
|
|
|
+)
|
|
|
+
|
|
|
+var GAppOption = &AppOption{}
|
|
|
+
|
|
|
+
|
|
|
+type AppOption struct {
|
|
|
+ Model string `short:"m" long:"model" description:"model image"`
|
|
|
+ Scale float64 `short:"c" long:"scale" description:"scale [2.0-8.0] default 5.0"`
|
|
|
+ Shoes []string `short:"s" long:"shoes" description:"shoes images"`
|
|
|
+}
|
|
|
+
|
|
|
+func (o *AppOption) Parse() error {
|
|
|
+ _, err := flags.NewParser(o, flags.Default|flags.IgnoreUnknown).Parse()
|
|
|
+ return err
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+const (
|
|
|
+ LOCAL_MODEL_DIR = "models"
|
|
|
+ LOCAL_SHOES_DIR = "shoes"
|
|
|
+ OUTPUT_DIR = "output"
|
|
|
+)
|
|
|
+
|
|
|
+func main() {
|
|
|
+ os.Mkdir(OUTPUT_DIR, os.ModePerm)
|
|
|
+
|
|
|
+
|
|
|
+ err := GAppOption.Parse()
|
|
|
+ if err != nil {
|
|
|
+ fmt.Println(err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ if GAppOption.Scale == 0 {
|
|
|
+ GAppOption.Scale = 5.01
|
|
|
+ }
|
|
|
+ if GAppOption.Scale < 2 {
|
|
|
+ GAppOption.Scale = 2.01
|
|
|
+ }
|
|
|
+ if GAppOption.Scale > 8 {
|
|
|
+ GAppOption.Scale = 7.99
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ if len(GAppOption.Shoes) == 0 {
|
|
|
+ files, err := getAllFilesInDir(LOCAL_SHOES_DIR)
|
|
|
+ if err != nil {
|
|
|
+ fmt.Println("获取shoes文件失败: ", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ GAppOption.Shoes = append(GAppOption.Shoes, files...)
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(GAppOption.Model) == 0 || len(GAppOption.Shoes) == 0 {
|
|
|
+ fmt.Println("缺失图片参数!")
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ modelLocalImage := fmt.Sprintf("%s/%s", LOCAL_MODEL_DIR, GAppOption.Model)
|
|
|
+ shoesLocalImages := []string{}
|
|
|
+ fmt.Println(len(shoesLocalImages))
|
|
|
+ for _, shoe := range GAppOption.Shoes {
|
|
|
+ shoesLocalImages = append(shoesLocalImages, fmt.Sprintf("%s/%s", LOCAL_SHOES_DIR, shoe))
|
|
|
+ }
|
|
|
+ ossReslut, err := oss.UpladImages(modelLocalImage, shoesLocalImages)
|
|
|
+ if err != nil {
|
|
|
+ fmt.Println(err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ fmt.Println("上传结果: ", ossReslut)
|
|
|
+ if len(ossReslut.Model) == 0 || len(ossReslut.Shoes) == 0 {
|
|
|
+ fmt.Println("上传图片出错")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ taskId, err := gen.Generate(ossReslut.Model, ossReslut.Shoes, GAppOption.Scale)
|
|
|
+ fmt.Println("taskId: ", taskId)
|
|
|
+ if err != nil {
|
|
|
+ fmt.Println(err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ fmt.Println("等待处理。。。")
|
|
|
+ time.Sleep(3 * time.Second)
|
|
|
+
|
|
|
+
|
|
|
+ url, err := gen.GetReslut(taskId, OUTPUT_DIR)
|
|
|
+ fmt.Println("url: ", url)
|
|
|
+ if err != nil {
|
|
|
+ fmt.Println(err)
|
|
|
+ if len(url) == 0 {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if url == "FAILED" {
|
|
|
+ fmt.Println("后台处理失败status[FAILED]")
|
|
|
+ return
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ replay := 10
|
|
|
+ for {
|
|
|
+ if replay == 0 {
|
|
|
+ fmt.Printf("status: %s,任务处理中,请多等待...", url)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ time.Sleep(5 * time.Second)
|
|
|
+ url, err = gen.GetReslut(taskId, OUTPUT_DIR)
|
|
|
+ if err != nil {
|
|
|
+
|
|
|
+ if len(url) > 0 {
|
|
|
+ if url == "FAILED" {
|
|
|
+ fmt.Println("后台处理失败status[FAILED]")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ replay--
|
|
|
+ } else {
|
|
|
+
|
|
|
+ fmt.Println(err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+
|
|
|
+ break
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+ }
|
|
|
+ fmt.Println("url---> : ", url)
|
|
|
+
|
|
|
+ dist, err := gen.Download(url, OUTPUT_DIR)
|
|
|
+ if err != nil {
|
|
|
+ fmt.Println(err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ fmt.Println("获取结果: ", dist)
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+func getAllFilesInDir(baseDir string) ([]string, error) {
|
|
|
+ var files []string
|
|
|
+
|
|
|
+
|
|
|
+ err := filepath.Walk(baseDir, func(path string, info os.FileInfo, err error) error {
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ if !info.IsDir() {
|
|
|
+
|
|
|
+ relPath, err := filepath.Rel(baseDir, path)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ relPath = strings.ReplaceAll(relPath, "\\", "/")
|
|
|
+
|
|
|
+
|
|
|
+ files = append(files, relPath)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ })
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ return files, nil
|
|
|
+}
|