189 lines
5.4 KiB
Go
189 lines
5.4 KiB
Go
package rss
|
||
|
||
import (
|
||
"fmt"
|
||
"regexp"
|
||
"strings"
|
||
"time"
|
||
|
||
"git.lxtend.com/lixiangwuxian/qqbot/constants"
|
||
"git.lxtend.com/lixiangwuxian/qqbot/handler"
|
||
"git.lxtend.com/lixiangwuxian/qqbot/model"
|
||
"git.lxtend.com/lixiangwuxian/qqbot/qq_message"
|
||
"git.lxtend.com/lixiangwuxian/qqbot/sqlite3"
|
||
"git.lxtend.com/lixiangwuxian/qqbot/util"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
func init() {
|
||
db := sqlite3.GetGormDB()
|
||
db.AutoMigrate(&RssFeed{}, &RssSubscribe{})
|
||
handler.RegisterAtHandler("订阅", Subscribe, constants.LEVEL_USER)
|
||
handler.RegisterAtHandler("我的订阅", MySubscribed, constants.LEVEL_USER)
|
||
handler.RegisterAtHandler("退订", Unsubscribe, constants.LEVEL_USER)
|
||
//test
|
||
handler.RegisterHandler("test_rss", TestRss, constants.LEVEL_ADMIN)
|
||
}
|
||
|
||
func TestRss(msg model.Message) (reply *model.Reply) {
|
||
go CheckNewRss()
|
||
return nil
|
||
}
|
||
|
||
func Subscribe(msg model.Message) (reply *model.Reply) {
|
||
//提取url
|
||
var subscribedFeeds []string
|
||
for _, data := range msg.StructuredMsg {
|
||
if data.GetMessageType() == "text" {
|
||
// 匹配RSS链接:可选协议,域名(包含所有顶级域),路径,必须以.xml结尾
|
||
urls := regexp.MustCompile(`(?i)(?:https?://)?(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(?:/[^\s]*)?\.xml\b`).FindAllString(data.(*qq_message.TextMessage).Data.Text, -1)
|
||
if len(urls) > 0 {
|
||
for _, url := range urls {
|
||
if title, err := SubscribeToFeed(url, msg.UserId, msg.GroupInfo.GroupId); err == nil {
|
||
subscribedFeeds = append(subscribedFeeds, title)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if len(subscribedFeeds) > 0 {
|
||
return &model.Reply{
|
||
ReplyMsg: fmt.Sprintf("成功订阅 %d 个RSS源:%s", len(subscribedFeeds), strings.Join(subscribedFeeds, "\n")),
|
||
ReferOriginMsg: true,
|
||
FromMsg: msg,
|
||
}
|
||
}
|
||
|
||
return &model.Reply{
|
||
ReplyMsg: "未找到有效的RSS链接(需要以.xml结尾)",
|
||
ReferOriginMsg: true,
|
||
FromMsg: msg,
|
||
}
|
||
}
|
||
|
||
// SubscribeToFeed 订阅RSS源
|
||
func SubscribeToFeed(feedURL string, userID int64, groupID int64) (string, error) {
|
||
db := sqlite3.GetGormDB().Begin()
|
||
defer db.Rollback()
|
||
|
||
// 确保URL有协议前缀
|
||
if !regexp.MustCompile(`^https?://`).MatchString(feedURL) {
|
||
feedURL = "https://" + feedURL
|
||
}
|
||
|
||
//检测rss源是否有效
|
||
|
||
if err := CheckRssFeed(feedURL); err != nil {
|
||
return "", fmt.Errorf("RSS源无效: %v", err)
|
||
}
|
||
|
||
// 检查RSS源是否已存在
|
||
var existingFeed RssFeed
|
||
result := db.Where("feed_url = ?", feedURL).First(&existingFeed)
|
||
|
||
var feedID int
|
||
if result.Error != nil {
|
||
// RSS源不存在,创建新的
|
||
newFeed := RssFeed{
|
||
FeedURL: feedURL,
|
||
Creator: fmt.Sprintf("%d", userID),
|
||
LastUpdate: time.Now(),
|
||
}
|
||
if err := db.Create(&newFeed).Error; err != nil {
|
||
return "", fmt.Errorf("创建RSS源失败: %v", err)
|
||
}
|
||
feedID = newFeed.ID
|
||
} else {
|
||
feedID = existingFeed.ID
|
||
}
|
||
|
||
// 检查是否已经订阅
|
||
var existingSubscribe RssSubscribe
|
||
result = db.Where("feed_id = ? AND group_id = ?", feedID, groupID).First(&existingSubscribe)
|
||
if result.Error == nil {
|
||
return "", fmt.Errorf("该群已订阅过此RSS源")
|
||
}
|
||
|
||
//获取最新文章hash
|
||
title, items, err := ParseFeed(feedURL)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
// 创建订阅关系
|
||
newSubscribe := RssSubscribe{
|
||
FeedID: feedID,
|
||
GroupID: int(groupID),
|
||
LastItemHash: items[0].Hash,
|
||
}
|
||
if err := db.Create(&newSubscribe).Error; err != nil {
|
||
return "", fmt.Errorf("创建订阅关系失败: %v", err)
|
||
}
|
||
return title, db.Commit().Error
|
||
}
|
||
|
||
func MySubscribed(msg model.Message) (reply *model.Reply) {
|
||
db := sqlite3.GetGormDB()
|
||
|
||
var feeds []RssSubscribe
|
||
db.Where("creator = ?", fmt.Sprintf("%d", msg.UserId)).Find(&feeds)
|
||
feedIdList := make([]int, 0)
|
||
for _, feed := range feeds {
|
||
feedIdList = append(feedIdList, feed.FeedID)
|
||
}
|
||
db.Where("feed_id IN (?)", feedIdList).Find(&feeds)
|
||
|
||
table := strings.Builder{}
|
||
table.WriteString("| 订阅源 | 创建时间 |\n")
|
||
table.WriteString("| --- | --- |\n")
|
||
for _, feed := range feeds {
|
||
table.WriteString(fmt.Sprintf("| %d | %s |\n", feed.FeedID, feed.CreateAt.Format("2006-01-02 15:04:05")))
|
||
}
|
||
|
||
return &model.Reply{
|
||
ReplyMsg: "你的订阅列表:\n" + table.String(),
|
||
ReferOriginMsg: true,
|
||
FromMsg: msg,
|
||
}
|
||
}
|
||
|
||
func Unsubscribe(msg model.Message) (reply *model.Reply) {
|
||
db := sqlite3.GetGormDB()
|
||
if len(msg.StructuredMsg) < 2 ||
|
||
(msg.StructuredMsg[1].GetMessageType() != qq_message.TypeText &&
|
||
len(util.SplitN(msg.StructuredMsg[1].(*qq_message.TextMessage).Data.Text, 2)) != 2) {
|
||
return &model.Reply{
|
||
ReplyMsg: "请输入要取消订阅的RSS源ID",
|
||
ReferOriginMsg: true,
|
||
FromMsg: msg,
|
||
}
|
||
}
|
||
if msg.StructuredMsg[1].GetMessageType() == qq_message.TypeText {
|
||
feedId := util.SplitN(msg.StructuredMsg[1].(*qq_message.TextMessage).Data.Text, 2)[1]
|
||
defer func() {
|
||
if db.Where("feed_id = ?", feedId).First(&RssSubscribe{}).Error == gorm.ErrRecordNotFound {
|
||
db.Where("id = ?", feedId).Delete(&RssFeed{})
|
||
}
|
||
}()
|
||
if err := db.Where("feed_id = ?", feedId).Where("group_id = ?", msg.GroupInfo.GroupId).Delete(&RssSubscribe{}).Error; err != nil {
|
||
return &model.Reply{
|
||
ReplyMsg: "取消订阅失败,报错:" + err.Error() + "\n请检查是否存在此订阅",
|
||
ReferOriginMsg: true,
|
||
FromMsg: msg,
|
||
}
|
||
}
|
||
|
||
return &model.Reply{
|
||
ReplyMsg: "取消订阅成功",
|
||
ReferOriginMsg: true,
|
||
FromMsg: msg,
|
||
}
|
||
}
|
||
return &model.Reply{
|
||
ReplyMsg: "请输入要取消订阅的RSS源ID",
|
||
ReferOriginMsg: true,
|
||
FromMsg: msg,
|
||
}
|
||
}
|