diff --git a/handler/xibao/xibao.go b/handler/xibao/xibao.go index 638ceaf..c995211 100644 --- a/handler/xibao/xibao.go +++ b/handler/xibao/xibao.go @@ -82,7 +82,7 @@ func beiBao(msg model.Message) (reply *model.Reply) { imageMsg := message.ImageMessage{ Type: "image", Data: message.ImageMessageData{ - File: "file:///tmp/qqbot/" + fileName + ".png", + File: "file://" + filePath, }, } return &model.Reply{ @@ -100,7 +100,7 @@ func beiBaoTemp(msg model.Message) (reply *model.Reply, isTrigger bool) { imageMsg := message.ImageMessage{ Type: "image", Data: message.ImageMessageData{ - File: "file:///tmp/qqbot/" + fileName + ".png", + File: "file://" + filePath, }, } return &model.Reply{ diff --git a/service/xibao/image_gen.go b/service/xibao/image_gen.go index 4af84b5..604c2ec 100644 --- a/service/xibao/image_gen.go +++ b/service/xibao/image_gen.go @@ -8,7 +8,6 @@ import ( "git.lxtend.com/qqbot/message" "git.lxtend.com/qqbot/util" "github.com/fogleman/gg" - "github.com/google/uuid" "golang.org/x/image/font/opentype" ) @@ -29,10 +28,8 @@ func GenerateCongratulationImage(text string, inputFile, outputFile string, isGo dc.DrawImage(im, 0, 0) // 判断是否为图片 if imgUrl, ok := isImageCQ(text); ok { - fileName := uuid.New().String() + ".jpg" - filePath := "./tmp/" + fileName - - if err := util.DownloadFile(imgUrl, filePath); err != nil { + filePath, err := util.DownloadFile(imgUrl, "/tmp") + if err != nil { log.Print("无法下载图片:", err) return } diff --git a/util/url.go b/util/url.go index 73220e5..4ddb502 100644 --- a/util/url.go +++ b/util/url.go @@ -1,14 +1,41 @@ package util import ( + "bytes" "fmt" "io" "net/http" "net/url" "os" + "path" "strings" + + "github.com/google/uuid" ) +// GetImageExtension 根据文件头检测图片格式并返回对应的扩展名 +func GetImageExtension(data []byte) string { + if len(data) < 8 { + return ".jpg" // 默认返回jpg + } + + // 检查文件头(Magic Numbers) + switch { + case bytes.HasPrefix(data, []byte{0x89, 0x50, 0x4E, 0x47}): + return ".png" + case bytes.HasPrefix(data, []byte{0xFF, 0xD8, 0xFF}): + return ".jpg" + case bytes.HasPrefix(data, []byte{0x47, 0x49, 0x46}): + return ".gif" + case bytes.HasPrefix(data, []byte{0x42, 0x4D}): + return ".bmp" + case bytes.HasPrefix(data, []byte{0x52, 0x49, 0x46, 0x46}) && bytes.Contains(data[0:12], []byte("WEBP")): + return ".webp" + default: + return ".jpg" // 默认返回jpg + } +} + // isEquivalentURL 判断两个 URL 是否在规范化后相同 func IsEquivalentURL(url1, url2 string) bool { norm1 := normalizeURL(url1) @@ -36,9 +63,9 @@ func normalizeURL(rawURL string) string { return u.String() } -func DownloadFile(url string, filepath string) (err error) { +// DownloadFile 下载文件到指定目录,返回带有正确扩展名的完整文件路径 +func DownloadFile(url string, dirPath string) (filepath string, err error) { // 发送 HTTP GET 请求 - // resp, err := http.Get(url) var resp *http.Response var maxRetry = 100 var retry = 0 @@ -47,27 +74,47 @@ func DownloadFile(url string, filepath string) (err error) { retry++ } if err != nil { - return fmt.Errorf("下载失败: %v", err) + return "", fmt.Errorf("下载失败: %v", err) } defer resp.Body.Close() // 检查 HTTP 响应状态码 if resp.StatusCode != http.StatusOK { - return fmt.Errorf("请求失败,状态码: %d", resp.StatusCode) + return "", fmt.Errorf("请求失败,状态码: %d", resp.StatusCode) } + // 读取响应内容到内存中以检测文件类型 + bodyData, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("读取响应内容失败: %v", err) + } + + // 获取正确的文件扩展名 + ext := GetImageExtension(bodyData) + + // 生成随机文件名 + fileName := uuid.New().String() + ext + + // 确保目录存在 + if err := os.MkdirAll(dirPath, 0755); err != nil { + return "", fmt.Errorf("创建目录失败: %v", err) + } + + // 构建完整的文件路径 + filepath = path.Join(dirPath, fileName) + // 创建文件 out, err := os.Create(filepath) if err != nil { - return fmt.Errorf("创建文件失败: %v", err) + return "", fmt.Errorf("创建文件失败: %v", err) } defer out.Close() - // 将响应的内容复制到文件 - _, err = io.Copy(out, resp.Body) + // 将内容写入文件 + _, err = io.Copy(out, bytes.NewReader(bodyData)) if err != nil { - return fmt.Errorf("保存失败: %v", err) + return "", fmt.Errorf("保存失败: %v", err) } - return nil + return filepath, nil }