sun-pc 2 semanas atrás
commit
780067ddae
13 arquivos alterados com 540 adições e 0 exclusões
  1. 175 0
      gen/http.go
  2. 14 0
      go.mod
  3. 15 0
      go.sum
  4. 179 0
      main.go
  5. 36 0
      model/result.go
  6. BIN
      models/model2.jpg
  7. 66 0
      oss/aliyun.go
  8. 19 0
      readme.md
  9. BIN
      shoes/shoe2.png
  10. BIN
      shoes/shoe3.png
  11. BIN
      tryon-new.zip
  12. 5 0
      tryonpy/try.bat
  13. 31 0
      tryonpy/tryOn.py

+ 175 - 0
gen/http.go

@@ -0,0 +1,175 @@
+package gen
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"time"
+	"tryon/model"
+)
+
+const (
+	API_URL           = "https://dashscope.aliyuncs.com/api/v1/services/aigc/virtualmodel/generation/"
+	DASHSCOPE_API_KEY = "sk-36a7725c51be4ffe8d3374ea534014b6"
+)
+
+type OutPutRes struct {
+	OutPut    OutPut `json:"output"`
+	RequestId string `json:"request_id"`
+}
+type OutPut struct {
+	TaskStatus string `json:"task_status"`
+	TaskId     string `json:"task_id"`
+}
+
+func Generate(model string, shoes []string, scale float64) (string, error) {
+	fmt.Println(scale)
+	// 设置请求体
+	data := map[string]interface{}{
+		"model": "shoemodel-v1",
+		"input": map[string]interface{}{
+			"template_image_url": model,
+			"shoe_image_url":     shoes,
+			"scale":              scale,
+		},
+		"parameters": map[string]interface{}{
+			"n": 1,
+		},
+	}
+	// 将请求体编码为 JSON
+	jsonData, err := json.Marshal(data)
+	if err != nil {
+		return "", nil
+	}
+
+	req, err := http.NewRequest("POST", API_URL, bytes.NewBuffer(jsonData))
+	if err != nil {
+		return "", nil
+	}
+
+	// 设置请求头
+	req.Header.Set("X-DashScope-Async", "enable")
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", DASHSCOPE_API_KEY))
+	req.Header.Set("Content-Type", "application/json")
+
+	// 发起 POST 请求
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+
+	// 读取响应
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return "", err
+	}
+
+	bodyObj := OutPutRes{}
+	err = json.Unmarshal(body, &bodyObj)
+	if err != nil {
+		return "", err
+	}
+	fmt.Println(bodyObj)
+	return bodyObj.OutPut.TaskId, nil
+
+}
+
+func GetReslut(taskId string, outPut string) (string, error) {
+	if len(taskId) == 0 {
+		return "", errors.New("taskId is empty")
+	}
+
+	// 获取任务 ID
+	// taskID := "6488c66d-2a0c-4134-90bb-39ab957d6cdf"
+	returl := fmt.Sprintf("https://dashscope.aliyuncs.com/api/v1/tasks/%s", taskId)
+
+	// 创建 GET 请求
+	getReq, err := http.NewRequest("GET", returl, nil)
+	if err != nil {
+		return "", err
+	}
+
+	// 设置 GET 请求头
+	getReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", DASHSCOPE_API_KEY))
+
+	// 发起 GET 请求
+	client := &http.Client{}
+	getResp, err := client.Do(getReq)
+	if err != nil {
+		return "", err
+	}
+	defer getResp.Body.Close()
+
+	// 读取响应
+	getBody, err := io.ReadAll(getResp.Body)
+	if err != nil {
+		return "", err
+	}
+	bodyObj := model.Response{}
+	err = json.Unmarshal(getBody, &bodyObj)
+	if err != nil {
+		return "", err
+	}
+	fmt.Println(bodyObj)
+	if bodyObj.Output.TaskStatus == "SUCCEEDED" {
+		// 下载到本地
+		return bodyObj.Output.Results[0].URL, nil
+	}
+	return bodyObj.Output.TaskStatus, errors.New("后台处理中。。。")
+
+	// {
+	//     "request_id": "a052b714-29d5-92c0-b981-43dc027523db",
+	//     "output": {
+	//         "task_id": "5f6316e3-ef6c-4143-abb5-81ffefdbfcd0",
+	//         "task_status": "SUCCEEDED",
+	//         "submit_time": "2024-11-04 16:27:38.824",
+	//         "scheduled_time": "2024-11-04 16:27:38.852",
+	//         "end_time": "2024-11-04 16:27:54.092",
+	//         "results": [
+	//             {
+	//                 "url": "https://dashscope-result-bj.oss-cn-beijing.aliyuncs.com/7d/08/20241104/18d9de26/2024-11-04/42b28ff3-8e67-44fa-805e-72cab9e204fc-1/res_img.png?Expires=1730795274&OSSAccessKeyId=LTAI5tQZd8AEcZX6KZV4G8qL&Signature=KlU2NsBfFKDcl9NwklzH1i2X8bk%3D"
+	//             }
+	//         ],
+	//         "task_metrics": {
+	//             "TOTAL": 1,
+	//             "SUCCEEDED": 1,
+	//             "FAILED": 0
+	//         }
+	//     },
+	//     "usage": {
+	//         "image_count": 1
+	//     }
+	// }
+}
+
+func Download(url string, outPut string) (string, error) {
+	// 创建 GET 请求
+	resp, err := http.Get(url)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close() // 确保在函数返回时关闭响应体
+
+	key := time.Now().Format("20060102_150405")
+	outPutFile := fmt.Sprintf("%s/%s.png", outPut, key)
+
+	out, err := os.Create(outPutFile)
+	if err != nil {
+		return "", err
+	}
+	defer out.Close()
+
+	_, err = io.Copy(out, resp.Body)
+	if err != nil {
+		return "", err
+	}
+
+	return outPutFile, nil
+
+}

+ 14 - 0
go.mod

@@ -0,0 +1,14 @@
+module tryon
+
+go 1.22.5
+
+require (
+	github.com/aliyun/aliyun-oss-go-sdk v3.0.2+incompatible
+	github.com/jessevdk/go-flags v1.6.1
+)
+
+require (
+	golang.org/x/sys v0.21.0 // indirect
+	golang.org/x/time v0.7.0 // indirect
+	gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
+)

+ 15 - 0
go.sum

@@ -0,0 +1,15 @@
+github.com/aliyun/aliyun-oss-go-sdk v3.0.2+incompatible h1:8psS8a+wKfiLt1iVDX79F7Y6wUM49Lcha2FMXt4UM8g=
+github.com/aliyun/aliyun-oss-go-sdk v3.0.2+incompatible/go.mod h1:T/Aws4fEfogEE9v+HPhhw+CntffsBHJ8nXQCwKr0/g8=
+github.com/jessevdk/go-flags v1.6.1 h1:Cvu5U8UGrLay1rZfv/zP7iLpSHGUZ/Ou68T0iX1bBK4=
+github.com/jessevdk/go-flags v1.6.1/go.mod h1:Mk8T1hIAWpOiJiHa9rJASDK2UGWji0EuPGBnNLMooyc=
+github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
+github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
+github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
+github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
+github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
+golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
+golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

+ 179 - 0
main.go

@@ -0,0 +1,179 @@
+package main
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+	"time"
+	"tryon/gen"
+	"tryon/oss"
+
+	"github.com/jessevdk/go-flags"
+)
+
+var GAppOption = &AppOption{}
+
+// short只能是一个字符
+type AppOption struct {
+	Model string   `short:"m" long:"model" description:"model image"`
+	Scale float64  `short:"c" long:"scale" description:"scale [2.0-8.0] default 5.0"`
+	Shoes []string `short:"s" long:"shoes" description:"shoes images"`
+}
+
+func (o *AppOption) Parse() error {
+	_, err := flags.NewParser(o, flags.Default|flags.IgnoreUnknown).Parse()
+	return err
+
+}
+
+const (
+	LOCAL_MODEL_DIR = "models"
+	LOCAL_SHOES_DIR = "shoes"
+	OUTPUT_DIR      = "output"
+)
+
+func main() {
+	os.Mkdir(OUTPUT_DIR, os.ModePerm)
+	// 读取命令行参数 获取图片
+	// go run . -m xxx.png -s xx.png -s xx.png
+	err := GAppOption.Parse()
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+
+	// 适配/设置默认scale
+	if GAppOption.Scale == 0 {
+		GAppOption.Scale = 5.01
+	}
+	if GAppOption.Scale < 2 {
+		GAppOption.Scale = 2.01
+	}
+	if GAppOption.Scale > 8 {
+		GAppOption.Scale = 7.99
+	}
+
+	// 获取shoes目录中图片
+	if len(GAppOption.Shoes) == 0 {
+		files, err := getAllFilesInDir(LOCAL_SHOES_DIR)
+		if err != nil {
+			fmt.Println("获取shoes文件失败: ", err)
+			return
+		}
+		GAppOption.Shoes = append(GAppOption.Shoes, files...)
+	}
+
+	if len(GAppOption.Model) == 0 || len(GAppOption.Shoes) == 0 {
+		fmt.Println("缺失图片参数!")
+		return
+	}
+	// 根据参数拼接图片所在当前完整路径
+	modelLocalImage := fmt.Sprintf("%s/%s", LOCAL_MODEL_DIR, GAppOption.Model)
+	shoesLocalImages := []string{}
+	fmt.Println(len(shoesLocalImages))
+	for _, shoe := range GAppOption.Shoes {
+		shoesLocalImages = append(shoesLocalImages, fmt.Sprintf("%s/%s", LOCAL_SHOES_DIR, shoe))
+	}
+	ossReslut, err := oss.UpladImages(modelLocalImage, shoesLocalImages)
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+	fmt.Println("上传结果: ", ossReslut)
+	if len(ossReslut.Model) == 0 || len(ossReslut.Shoes) == 0 {
+		fmt.Println("上传图片出错")
+		return
+	}
+	taskId, err := gen.Generate(ossReslut.Model, ossReslut.Shoes, GAppOption.Scale)
+	fmt.Println("taskId: ", taskId)
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+
+	// 延迟
+	fmt.Println("等待处理。。。")
+	time.Sleep(3 * time.Second)
+
+	// 这里成功返回url,失败返回""/status
+	url, err := gen.GetReslut(taskId, OUTPUT_DIR)
+	fmt.Println("url: ", url)
+	if err != nil {
+		fmt.Println(err)
+		if len(url) == 0 {
+			return
+		}
+		if url == "FAILED" {
+			fmt.Println("后台处理失败status[FAILED]")
+			return
+
+		}
+		// 重新查询
+		replay := 10
+		for {
+			if replay == 0 {
+				fmt.Printf("status: %s,任务处理中,请多等待...", url)
+				return
+			}
+			time.Sleep(5 * time.Second)
+			url, err = gen.GetReslut(taskId, OUTPUT_DIR)
+			if err != nil {
+				// 重试逻辑
+				if len(url) > 0 {
+					replay--
+				} else {
+					// 异常逻辑
+					fmt.Println(err)
+					return
+				}
+			} else {
+				// 成功跳出
+				break
+			}
+
+		}
+	}
+	fmt.Println("url---> : ", url)
+
+	dist, err := gen.Download(url, OUTPUT_DIR)
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+	fmt.Println("获取结果: ", dist)
+
+}
+
+// 获取某个目录下的所有文件路径,递归遍历,并生成相对路径
+func getAllFilesInDir(baseDir string) ([]string, error) {
+	var files []string
+
+	// filepath.Walk 会递归遍历目录及子目录
+	err := filepath.Walk(baseDir, func(path string, info os.FileInfo, err error) error {
+		// 遇到错误直接返回
+		if err != nil {
+			return err
+		}
+
+		// 跳过目录,只处理文件
+		if !info.IsDir() {
+			// 计算相对路径
+			relPath, err := filepath.Rel(baseDir, path)
+			if err != nil {
+				return err
+			}
+
+			// 将路径中的反斜杠替换为正斜杠
+			relPath = strings.ReplaceAll(relPath, "\\", "/")
+
+			// 将相对路径加入文件列表
+			files = append(files, relPath)
+		}
+		return nil
+	})
+	if err != nil {
+		return nil, err
+	}
+	return files, nil
+}

+ 36 - 0
model/result.go

@@ -0,0 +1,36 @@
+package model
+
+// Response 结构体对应整个JSON响应
+type Response struct {
+	RequestID string `json:"request_id"`
+	Output    Output `json:"output"`
+	Usage     Usage  `json:"usage"`
+}
+
+// Output 结构体对应输出部分的JSON数据
+type Output struct {
+	TaskID     string `json:"task_id"`
+	TaskStatus string `json:"task_status"`
+	// SubmitTime    time.Time   `json:"submit_time"`
+	// ScheduledTime time.Time   `json:"scheduled_time"`
+	// EndTime       time.Time   `json:"end_time"`
+	Results     []Result    `json:"results"`
+	TaskMetrics TaskMetrics `json:"task_metrics"`
+}
+
+// Result 结构体对应结果数组中的每个元素
+type Result struct {
+	URL string `json:"url"`
+}
+
+// TaskMetrics 结构体对应任务指标部分的JSON数据
+type TaskMetrics struct {
+	TOTAL     int `json:"TOTAL"`
+	SUCCEEDED int `json:"SUCCEEDED"`
+	FAILED    int `json:"FAILED"`
+}
+
+// Usage 结构体对应使用情况部分的JSON数据
+type Usage struct {
+	ImageCount int `json:"image_count"`
+}

BIN
models/model2.jpg


+ 66 - 0
oss/aliyun.go

@@ -0,0 +1,66 @@
+package oss
+
+import (
+	"fmt"
+	"log"
+
+	"github.com/aliyun/aliyun-oss-go-sdk/oss"
+)
+
+const (
+	ENDPOINT      = "oss-cn-beijing.aliyuncs.com"
+	ACCESS_KEY    = "LTAI5tBUvFtfWU4H3AikcmwF"
+	ACCESS_SECRET = "W7Yceh0A0RODc6bELpcML1xHZOQ32q"
+	BUCKET        = "infish-oss"
+)
+
+func CreateClient() {
+	_, err := oss.New(ENDPOINT, ACCESS_KEY, ACCESS_SECRET)
+	if err != nil {
+		panic(err)
+	}
+
+	return
+
+}
+
+type UpladImagesRes struct {
+	Model string
+	Shoes []string
+}
+
+func UpladImages(model string, shoes []string) (UpladImagesRes, error) {
+
+	cli, err := oss.New(ENDPOINT, ACCESS_KEY, ACCESS_SECRET)
+	if err != nil {
+		log.Fatalf("Error: %v", err)
+		return UpladImagesRes{}, err
+	}
+
+	// 获取存储空间
+	bucket, err := cli.Bucket(BUCKET)
+	if err != nil {
+		return UpladImagesRes{}, err
+	}
+	res := UpladImagesRes{}
+	localFileName := model
+	objectName := fmt.Sprintf("tryon/%s", model)
+	err = bucket.PutObjectFromFile(objectName, localFileName)
+	if err != nil {
+		return UpladImagesRes{}, err
+	}
+	res.Model = fmt.Sprintf("https://infish-oss.oss-cn-beijing.aliyuncs.com/%s", objectName)
+
+	for _, shoe := range shoes {
+		localFileName = shoe
+		objectName = fmt.Sprintf("tryon/%s", shoe)
+		err = bucket.PutObjectFromFile(objectName, localFileName)
+		if err != nil {
+			fmt.Println(err)
+			continue
+		}
+		res.Shoes = append(res.Shoes, fmt.Sprintf("https://infish-oss.oss-cn-beijing.aliyuncs.com/%s", objectName))
+	}
+
+	return res, nil
+}

+ 19 - 0
readme.md

@@ -0,0 +1,19 @@
+# 功能描述
+
+> 根据模型和多张鞋子图片合成图片
+
+1. 读取命令行参数 获取图片
+2. 上传图片到oss
+3. 生成结果图片下载到本地目录
+
+## 命令说明
+
+```sh
+# 默认目录 -m ./models -s ./shoes
+# 默认 scale 5.0 [2.0-8.0] 数字越大越鲜亮
+# 完整参数示例
+/tryon.exe -m model2.jpg -c 5.6 -s shoe2.png -s shoe3.png
+
+# 简化参数示例
+/tryon.exe -m model2.jpg
+```sh

BIN
shoes/shoe2.png


BIN
shoes/shoe3.png


BIN
tryon-new.zip


+ 5 - 0
tryonpy/try.bat

@@ -0,0 +1,5 @@
+
+@REM /tryon.exe -m model2.jpg -c 5.6 -s shoe2.png -s shoe3.png
+@REM 默认 scale 5.0 [2.0-8.0] 数字越大越鲜亮
+
+/tryon.exe -m model2.jpg

+ 31 - 0
tryonpy/tryOn.py

@@ -0,0 +1,31 @@
+import requests
+
+url = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/virtualmodel/generation/'
+DASHSCOPE_API_KEY = "sk-36a7725c51be4ffe8d3374ea534014b6"
+headers = {
+    'X-DashScope-Async': 'enable',
+    'Authorization': f'Bearer {DASHSCOPE_API_KEY}',  # 确保你已经设置了 DASHSCOPE_API_KEY
+    'Content-Type': 'application/json'
+}
+
+data = {
+    "model": "shoemodel-v1",
+    "input": {
+        "template_image_url": "https://infish-oss.oss-cn-beijing.aliyuncs.com/test/model2.jpg",
+        "shoe_image_url": ["https://infish-oss.oss-cn-beijing.aliyuncs.com/test/shoe3.png"]
+    },
+    "parameters": {
+        "n": 1
+    }
+}
+
+task_id = "6488c66d-2a0c-4134-90bb-39ab957d6cdf"
+returl = f'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}'
+headersret = {
+    'Authorization': f'Bearer {DASHSCOPE_API_KEY}',  # 确保你已经设置了 DASHSCOPE_API_KEY
+}
+# response = requests.post(url, headers=headers, json=data)
+# print(response.json())
+
+response = requests.get(returl, headers=headersret)
+print(response.json())