diff --git a/handler/rss/job.go b/handler/rss/job.go index d1355d3..81063b5 100644 --- a/handler/rss/job.go +++ b/handler/rss/job.go @@ -47,7 +47,7 @@ func CheckNewRss() { continue } //获取最新的rss数据 - title, items, err := ParseRssFeed(feed.FeedURL) + title, items, err := ParseFeed(feed.FeedURL) if err != nil { continue } @@ -64,11 +64,12 @@ func CheckNewRss() { &qq_message.TextMessage{ Type: qq_message.TypeText, Data: qq_message.TextMessageData{ - Text: fmt.Sprintf("您订阅的%s发布了新的文章: %s", title, items[0].Title), + Text: fmt.Sprintf("您订阅的%s发布了新的文章: %s\n%s", title, items[0].Title, items[0].Link), }, }, }, }) + db.Model(&group).Update("last_item_hash", items[0].Hash) } } } diff --git a/handler/rss/parse.go b/handler/rss/parse.go index 67a20ad..4ee27f2 100644 --- a/handler/rss/parse.go +++ b/handler/rss/parse.go @@ -3,6 +3,7 @@ package rss import ( "crypto/md5" "encoding/xml" + "errors" "fmt" "io" "net/http" @@ -43,7 +44,7 @@ func CheckRssFeed(feedURL string) error { return nil } -func ParseRssFeed(feedURL string) (string, []RssItem, error) { +func ParseFeed(feedURL string) (string, []RssItem, error) { //确认大小 resp, err := http.Head(feedURL) if err != nil { @@ -81,6 +82,10 @@ func ParseRssFeed(feedURL string) (string, []RssItem, error) { return "", nil, fmt.Errorf("解析RSS/Atom数据失败: %v", err) } + if len(items) == 0 { + return title, nil, errors.New("未解析到rss信息") + } + return title, items, nil } diff --git a/handler/rss/rss.go b/handler/rss/rss.go index 55e4799..ef3bc38 100644 --- a/handler/rss/rss.go +++ b/handler/rss/rss.go @@ -27,7 +27,7 @@ func init() { func TestRss(msg model.Message) (reply *model.Reply) { rssUrl := util.SplitN(msg.StructuredMsg[0].(*qq_message.TextMessage).Data.Text, 2)[1] - title, items, err := ParseRssFeed(rssUrl) + title, items, err := ParseFeed(rssUrl) if err != nil { return &model.Reply{ ReplyMsg: "解析RSS源失败: " + err.Error(), @@ -76,7 +76,8 @@ func Subscribe(msg model.Message) (reply *model.Reply) { // SubscribeToFeed 订阅RSS源 func SubscribeToFeed(feedURL string, userID int64, groupID int64) error { - db := sqlite3.GetGormDB() + db := sqlite3.GetGormDB().Begin() + defer db.Rollback() // 确保URL有协议前缀 if !regexp.MustCompile(`^https?://`).MatchString(feedURL) { @@ -116,16 +117,22 @@ func SubscribeToFeed(feedURL string, userID int64, groupID int64) error { return fmt.Errorf("该群已订阅过此RSS源") } + //获取最新文章hash + _, items, err := ParseFeed(feedURL) + if err != nil { + return err + } + // 创建订阅关系 newSubscribe := RssSubscribe{ - FeedID: feedID, - GroupID: int(groupID), + FeedID: feedID, + GroupID: int(groupID), + LastItemHash: items[0].Hash, } if err := db.Create(&newSubscribe).Error; err != nil { return fmt.Errorf("创建订阅关系失败: %v", err) } - - return nil + return db.Commit().Error } func MySubscribed(msg model.Message) (reply *model.Reply) {