qq_bot/handler/rss/rss.go

194 lines
5.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package rss
import (
"fmt"
"regexp"
"strings"
"time"
"git.lxtend.com/lixiangwuxian/qqbot/constants"
"git.lxtend.com/lixiangwuxian/qqbot/handler"
"git.lxtend.com/lixiangwuxian/qqbot/model"
"git.lxtend.com/lixiangwuxian/qqbot/qq_message"
"git.lxtend.com/lixiangwuxian/qqbot/sqlite3"
"git.lxtend.com/lixiangwuxian/qqbot/util"
"gorm.io/gorm"
)
func init() {
	// Make sure the feed and subscription tables exist before any handler
	// below can touch them.
	// NOTE(review): the error returned by AutoMigrate is discarded.
	sqlite3.GetGormDB().AutoMigrate(&RssFeed{}, &RssSubscribe{})

	// Commands triggered when the bot is @-mentioned; all are open to
	// ordinary users.
	handler.RegisterAtHandler("订阅", Subscribe, constants.LEVEL_USER)
	handler.RegisterAtHandler("我的订阅", MySubscribed, constants.LEVEL_USER)
	handler.RegisterAtHandler("退订", Unsubscribe, constants.LEVEL_USER)

	// Debug-only command for checking that a feed URL can be parsed.
	handler.RegisterHandler("test_rss", TestRss, constants.LEVEL_USER)
}
// TestRss is a debugging handler for "test_rss <url>". It parses the feed at
// the given URL with ParseRssFeed and replies with the number of items and the
// first item, or with the parse error.
//
// Malformed input (no segments, a non-text first segment, or no URL argument)
// yields a usage reply instead of a panic, and an empty feed is reported
// without indexing into an empty result.
func TestRss(msg model.Message) (reply *model.Reply) {
	respond := func(text string) *model.Reply {
		return &model.Reply{
			ReplyMsg:       text,
			ReferOriginMsg: true,
			FromMsg:        msg,
		}
	}
	const usage = "用法: test_rss <RSS链接>"

	if len(msg.StructuredMsg) == 0 {
		return respond(usage)
	}
	textMsg, ok := msg.StructuredMsg[0].(*qq_message.TextMessage)
	if !ok {
		return respond(usage)
	}
	// Element 0 is the command word, element 1 the URL.
	args := util.SplitN(textMsg.Data.Text, 2)
	if len(args) < 2 {
		return respond(usage)
	}

	items, err := ParseRssFeed(args[1])
	if err != nil {
		return respond("解析RSS源失败: " + err.Error())
	}
	if len(items) == 0 {
		return respond("解析RSS源成功: 0 个条目")
	}
	return respond(fmt.Sprintf("解析RSS源成功: %d 个条目\n%v", len(items), items[0]))
}
// subscribeLinkPattern matches candidate RSS links in free text: an optional
// http(s) scheme (case-insensitive), one or more dot-terminated labels, an
// alphabetic top-level label of at least two letters, an optional path, and a
// mandatory ".xml" ending. It is compiled once rather than on every message.
var subscribeLinkPattern = regexp.MustCompile(`(?i)(?:https?://)?(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(?:/[^\s]*)?\.xml\b`)

// Subscribe handles the "订阅" command. It extracts every RSS link (ending in
// .xml) from the text segments of msg and subscribes the sender's group to
// each one via SubscribeToFeed.
//
// The reply states how many links were subscribed. For links that were found
// but could not be subscribed (invalid feed, already subscribed, database
// error), it lists each link with its error. If no link is found at all, it
// asks for a link ending in .xml.
func Subscribe(msg model.Message) (reply *model.Reply) {
	subscribed := 0
	var failures []string
	seen := make(map[string]bool)

	for _, data := range msg.StructuredMsg {
		if data.GetMessageType() != "text" {
			continue
		}
		textMsg, ok := data.(*qq_message.TextMessage)
		if !ok {
			continue
		}
		for _, link := range subscribeLinkPattern.FindAllString(textMsg.Data.Text, -1) {
			// The same link pasted twice would only produce a spurious
			// "already subscribed" failure.
			if seen[link] {
				continue
			}
			seen[link] = true
			if err := SubscribeToFeed(link, msg.UserId, msg.GroupInfo.GroupId); err != nil {
				failures = append(failures, fmt.Sprintf("%s: %v", link, err))
				continue
			}
			subscribed++
		}
	}

	if subscribed == 0 && len(failures) == 0 {
		return &model.Reply{
			ReplyMsg:       "未找到有效的RSS链接需要以.xml结尾",
			ReferOriginMsg: true,
			FromMsg:        msg,
		}
	}

	var b strings.Builder
	fmt.Fprintf(&b, "成功订阅 %d 个RSS源", subscribed)
	if len(failures) > 0 {
		fmt.Fprintf(&b, "\n订阅失败 %d 个:", len(failures))
		for _, f := range failures {
			b.WriteString("\n")
			b.WriteString(f)
		}
	}
	return &model.Reply{
		ReplyMsg:       b.String(),
		ReferOriginMsg: true,
		FromMsg:        msg,
	}
}
// SubscribeToFeed 订阅RSS源
func SubscribeToFeed(feedURL string, userID int64, groupID int64) error {
db := sqlite3.GetGormDB()
// 确保URL有协议前缀
if !regexp.MustCompile(`^https?://`).MatchString(feedURL) {
feedURL = "https://" + feedURL
}
//检测rss源是否有效
if err := CheckRssFeed(feedURL); err != nil {
return fmt.Errorf("RSS源无效: %v", err)
}
// 检查RSS源是否已存在
var existingFeed RssFeed
result := db.Where("feed_url = ?", feedURL).First(&existingFeed)
var feedID int
if result.Error != nil {
// RSS源不存在创建新的
newFeed := RssFeed{
FeedURL: feedURL,
Creator: fmt.Sprintf("%d", userID),
LastUpdate: time.Now(),
}
if err := db.Create(&newFeed).Error; err != nil {
return fmt.Errorf("创建RSS源失败: %v", err)
}
feedID = newFeed.ID
} else {
feedID = existingFeed.ID
}
// 检查是否已经订阅
var existingSubscribe RssSubscribe
result = db.Where("feed_id = ? AND group_id = ?", feedID, groupID).First(&existingSubscribe)
if result.Error == nil {
return fmt.Errorf("该群已订阅过此RSS源")
}
// 创建订阅关系
newSubscribe := RssSubscribe{
FeedID: feedID,
GroupID: int(groupID),
}
if err := db.Create(&newSubscribe).Error; err != nil {
return fmt.Errorf("创建订阅关系失败: %v", err)
}
return nil
}
// MySubscribed handles the "我的订阅" command. It replies with a table of the
// RssSubscribe rows whose "creator" column equals the sender's user ID,
// showing each row's feed ID and creation time.
//
// The original re-queried all subscriptions for the same feed IDs and
// replaced the user's rows with that superset. That listed other groups'
// subscriptions as the sender's own, so only the user-filtered query is kept.
// Query errors are now reported instead of silently producing an empty table.
//
// NOTE(review): SubscribeToFeed only sets FeedID and GroupID when creating an
// RssSubscribe; the creator is recorded on RssFeed instead. Confirm that
// RssSubscribe really has a populated "creator" column, otherwise this query
// will fail or always return nothing.
func MySubscribed(msg model.Message) (reply *model.Reply) {
	respond := func(text string) *model.Reply {
		return &model.Reply{
			ReplyMsg:       text,
			ReferOriginMsg: true,
			FromMsg:        msg,
		}
	}

	db := sqlite3.GetGormDB()
	var subs []RssSubscribe
	if err := db.Where("creator = ?", fmt.Sprintf("%d", msg.UserId)).Find(&subs).Error; err != nil {
		return respond("查询订阅失败: " + err.Error())
	}
	if len(subs) == 0 {
		return respond("你还没有订阅任何RSS源")
	}

	var table strings.Builder
	table.WriteString("| 订阅源 | 创建时间 |\n")
	table.WriteString("| --- | --- |\n")
	for _, sub := range subs {
		fmt.Fprintf(&table, "| %d | %s |\n", sub.FeedID, sub.CreateAt.Format("2006-01-02 15:04:05"))
	}
	return respond("你的订阅列表:\n" + table.String())
}
// Unsubscribe handles the "退订 <feed id>" command. It expects
// msg.StructuredMsg[1] to be a text segment whose second part, as split by
// util.SplitN(text, 2), is the numeric ID of an RSS feed. It deletes the
// subscription linking that feed to the sender's group.
//
// If no group subscribes to the feed afterwards, the RssFeed row itself is
// deleted as well.
//
// The original guard combined its checks with && instead of ||. A non-text
// segment therefore reached an unchecked type assertion (panic), and a text
// segment without an argument panicked on the [1] index. Deleting a
// subscription that did not exist also reported success. Both are fixed here.
func Unsubscribe(msg model.Message) (reply *model.Reply) {
	respond := func(text string) *model.Reply {
		return &model.Reply{
			ReplyMsg:       text,
			ReferOriginMsg: true,
			FromMsg:        msg,
		}
	}
	const usage = "请输入要取消订阅的RSS源ID"

	if len(msg.StructuredMsg) < 2 || msg.StructuredMsg[1].GetMessageType() != qq_message.TypeText {
		return respond(usage)
	}
	textMsg, ok := msg.StructuredMsg[1].(*qq_message.TextMessage)
	if !ok {
		return respond(usage)
	}
	args := util.SplitN(textMsg.Data.Text, 2)
	if len(args) != 2 {
		return respond(usage)
	}
	var feedID int
	if _, err := fmt.Sscan(args[1], &feedID); err != nil {
		return respond(usage)
	}

	db := sqlite3.GetGormDB()
	result := db.Where("feed_id = ? AND group_id = ?", feedID, msg.GroupInfo.GroupId).Delete(&RssSubscribe{})
	if result.Error != nil {
		return respond("取消订阅失败,报错:" + result.Error.Error() + "\n请检查是否存在此订阅")
	}
	if result.RowsAffected == 0 {
		return respond("取消订阅失败: 本群未订阅此RSS源")
	}

	// Garbage-collect the feed once no group references it any more.
	if db.Where("feed_id = ?", feedID).First(&RssSubscribe{}).Error == gorm.ErrRecordNotFound {
		db.Where("id = ?", feedID).Delete(&RssFeed{})
	}
	return respond("取消订阅成功")
}