This commit is contained in:
HugeFrog24
2026-02-11 18:39:02 +01:00
parent e9fd36b22d
commit 547dc8ca1a
30 changed files with 1936 additions and 229 deletions

277
bot.go Executable file → Normal file
View File

@@ -28,6 +28,14 @@ type Bot struct {
botID uint // Reference to BotModel.ID
}
// Helper function to determine message type
func messageType(msg *models.Message) string {
if msg.Sticker != nil {
return "sticker"
}
return "text"
}
// NewBot initializes and returns a new Bot instance.
func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient) (*Bot, error) {
// Retrieve or create Bot entry in the database
@@ -87,6 +95,15 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient)
tgBot: tgClient,
}
if tgClient == nil {
var err error
tgClient, err = initTelegramBot(config.TelegramToken, b)
if err != nil {
return nil, fmt.Errorf("failed to initialize Telegram bot: %w", err)
}
b.tgBot = tgClient
}
return b, nil
}
@@ -178,9 +195,10 @@ func (b *Bot) createMessage(chatID, userID int64, username, userRole, text strin
return message
}
func (b *Bot) storeMessage(message Message) error {
message.BotID = b.botID // Associate the message with the correct bot
return b.db.Create(&message).Error
// storeMessage stores a message in the database and updates its ID
func (b *Bot) storeMessage(message *Message) error {
message.BotID = b.botID // Associate the message with the correct bot
return b.db.Create(message).Error // This will update the message with its new ID
}
func (b *Bot) getOrCreateChatMemory(chatID int64) *ChatMemory {
@@ -190,14 +208,30 @@ func (b *Bot) getOrCreateChatMemory(chatID int64) *ChatMemory {
if !exists {
b.chatMemoriesMu.Lock()
// Double-check to prevent race condition
defer b.chatMemoriesMu.Unlock()
chatMemory, exists = b.chatMemories[chatID]
if !exists {
// Check if this is a new chat by querying the database
var count int64
b.db.Model(&Message{}).Where("chat_id = ? AND bot_id = ?", chatID, b.botID).Count(&count)
isNewChat := count == 0 // Truly new chat if no messages exist
var messages []Message
b.db.Where("chat_id = ? AND bot_id = ?", chatID, b.botID).
Order("timestamp asc").
Limit(b.memorySize * 2).
Find(&messages)
if !isNewChat {
// Fetch existing messages only if it's not a new chat
err := b.db.Where("chat_id = ? AND bot_id = ?", chatID, b.botID).
Order("timestamp asc").
Limit(b.memorySize * 2).
Find(&messages).Error
if err != nil {
ErrorLogger.Printf("Error fetching messages from database: %v", err)
messages = []Message{} // Initialize an empty slice on error
}
} else {
messages = []Message{} // Ensure messages is initialized for new chats
}
chatMemory = &ChatMemory{
Messages: messages,
@@ -206,19 +240,22 @@ func (b *Bot) getOrCreateChatMemory(chatID int64) *ChatMemory {
b.chatMemories[chatID] = chatMemory
}
b.chatMemoriesMu.Unlock()
}
return chatMemory
}
// addMessageToChatMemory adds a new message to the chat memory, ensuring the memory size is maintained.
func (b *Bot) addMessageToChatMemory(chatMemory *ChatMemory, message Message) {
b.chatMemoriesMu.Lock()
defer b.chatMemoriesMu.Unlock()
// Add the new message
chatMemory.Messages = append(chatMemory.Messages, message)
// Maintain the memory size
if len(chatMemory.Messages) > chatMemory.Size {
chatMemory.Messages = chatMemory.Messages[2:]
chatMemory.Messages = chatMemory.Messages[len(chatMemory.Messages)-chatMemory.Size:]
}
}
@@ -226,6 +263,12 @@ func (b *Bot) prepareContextMessages(chatMemory *ChatMemory) []anthropic.Message
b.chatMemoriesMu.RLock()
defer b.chatMemoriesMu.RUnlock()
// Debug logging
InfoLogger.Printf("Chat memory contains %d messages", len(chatMemory.Messages))
for i, msg := range chatMemory.Messages {
InfoLogger.Printf("Message %d: IsUser=%v, Text=%q", i, msg.IsUser, msg.Text)
}
var contextMessages []anthropic.Message
for _, msg := range chatMemory.Messages {
role := anthropic.RoleUser
@@ -252,7 +295,7 @@ func (b *Bot) prepareContextMessages(chatMemory *ChatMemory) []anthropic.Message
func (b *Bot) isNewChat(chatID int64) bool {
var count int64
b.db.Model(&Message{}).Where("chat_id = ? AND bot_id = ?", chatID, b.botID).Count(&count)
return count == 1
return count == 0 // Only consider a chat new if it has 0 messages
}
func (b *Bot) isAdminOrOwner(userID int64) bool {
@@ -264,9 +307,9 @@ func (b *Bot) isAdminOrOwner(userID int64) bool {
return user.Role.Name == "admin" || user.Role.Name == "owner"
}
func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (TelegramClient, error) {
func initTelegramBot(token string, b *Bot) (TelegramClient, error) {
opts := []bot.Option{
bot.WithDefaultHandler(handleUpdate),
bot.WithDefaultHandler(b.handleUpdate),
}
tgBot, err := bot.New(token, opts...)
@@ -274,11 +317,40 @@ func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot
return nil, err
}
// Define bot commands
commands := []models.BotCommand{
{
Command: "stats",
Description: "Get bot statistics. Usage: /stats or /stats user [user_id]",
},
{
Command: "whoami",
Description: "Get your user information",
},
{
Command: "clear",
Description: "Clear chat history (soft delete). Admins: /clear [user_id]",
},
{
Command: "clear_hard",
Description: "Clear chat history (permanently delete). Admins: /clear_hard [user_id]",
},
}
// Set bot commands
_, err = tgBot.SetMyCommands(context.Background(), &bot.SetMyCommandsParams{
Commands: commands,
})
if err != nil {
ErrorLogger.Printf("Error setting bot commands: %v", err)
return nil, err
}
return tgBot, nil
}
func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, businessConnectionID string) error {
// Pass the outgoing message through the centralized screen for storage
// Pass the outgoing message through the centralized screen for storage and chat memory update
_, err := b.screenOutgoingMessage(chatID, text)
if err != nil {
ErrorLogger.Printf("Error storing assistant message: %v", err)
@@ -306,28 +378,75 @@ func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, busin
}
// sendStats sends the bot statistics to the specified chat.
func (b *Bot) sendStats(ctx context.Context, chatID int64, businessConnectionID string) {
totalUsers, totalMessages, err := b.getStats()
func (b *Bot) sendStats(ctx context.Context, chatID int64, userID int64, targetUserID int64, businessConnectionID string) {
// If targetUserID is 0, show global stats
if targetUserID == 0 {
totalUsers, totalMessages, err := b.getStats()
if err != nil {
ErrorLogger.Printf("Error fetching stats: %v\n", err)
if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve the stats at this time.", businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending response: %v", err)
}
return
}
// Do NOT manually escape hyphens here
statsMessage := fmt.Sprintf(
"📊 Bot Statistics:\n\n"+
"- Total Users: %d\n"+
"- Total Messages: %d",
totalUsers,
totalMessages,
)
// Send the response through the centralized screen
if err := b.sendResponse(ctx, chatID, statsMessage, businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending stats message: %v", err)
}
return
}
// If targetUserID is not 0, show user-specific stats
// Check permissions if the user is trying to view someone else's stats
if targetUserID != userID {
if !b.isAdminOrOwner(userID) {
InfoLogger.Printf("User %d attempted to view stats for user %d without permission", userID, targetUserID)
if err := b.sendResponse(ctx, chatID, "Permission denied. Only admins and owners can view other users' statistics.", businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending response: %v", err)
}
return
}
}
// Get user stats
username, messagesIn, messagesOut, totalMessages, err := b.getUserStats(targetUserID)
if err != nil {
ErrorLogger.Printf("Error fetching stats: %v\n", err)
if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve the stats at this time.", businessConnectionID); err != nil {
ErrorLogger.Printf("Error fetching user stats: %v\n", err)
if err := b.sendResponse(ctx, chatID, fmt.Sprintf("Sorry, I couldn't retrieve statistics for user ID %d.", targetUserID), businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending response: %v", err)
}
return
}
// Do NOT manually escape hyphens here
// Build the user stats message
userInfo := fmt.Sprintf("@%s (ID: %d)", username, targetUserID)
if username == "" {
userInfo = fmt.Sprintf("User ID: %d", targetUserID)
}
statsMessage := fmt.Sprintf(
"📊 Bot Statistics:\n\n"+
"- Total Users: %d\n"+
"👤 User Statistics for %s:\n\n"+
"- Messages Sent: %d\n"+
"- Messages Received: %d\n"+
"- Total Messages: %d",
totalUsers,
userInfo,
messagesIn,
messagesOut,
totalMessages,
)
// Send the response through the centralized screen
if err := b.sendResponse(ctx, chatID, statsMessage, businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending stats message: %v", err)
ErrorLogger.Printf("Error sending user stats message: %v", err)
}
}
@@ -346,6 +465,35 @@ func (b *Bot) getStats() (int64, int64, error) {
return totalUsers, totalMessages, nil
}
// getUserStats retrieves statistics for a specific user
func (b *Bot) getUserStats(userID int64) (string, int64, int64, int64, error) {
// Get user information from database
var user User
err := b.db.Where("telegram_id = ? AND bot_id = ?", userID, b.botID).First(&user).Error
if err != nil {
return "", 0, 0, 0, fmt.Errorf("user not found: %w", err)
}
// Count messages sent by the user (IN)
var messagesIn int64
if err := b.db.Model(&Message{}).Where("user_id = ? AND bot_id = ? AND is_user = ?",
userID, b.botID, true).Count(&messagesIn).Error; err != nil {
return "", 0, 0, 0, err
}
// Count responses to the user (OUT)
var messagesOut int64
if err := b.db.Model(&Message{}).Where("chat_id IN (SELECT DISTINCT chat_id FROM messages WHERE user_id = ? AND bot_id = ?) AND bot_id = ? AND is_user = ?",
userID, b.botID, b.botID, false).Count(&messagesOut).Error; err != nil {
return "", 0, 0, 0, err
}
// Total messages is the sum
totalMessages := messagesIn + messagesOut
return user.Username, messagesIn, messagesOut, totalMessages, nil
}
// isOnlyEmojis checks if the string consists solely of emojis.
func isOnlyEmojis(s string) bool {
for _, r := range s {
@@ -399,41 +547,96 @@ func (b *Bot) sendWhoAmI(ctx context.Context, chatID int64, userID int64, userna
}
}
// screenIncomingMessage handles storing of incoming messages.
// screenIncomingMessage centralizes all incoming message processing: storing messages and updating chat memory.
func (b *Bot) screenIncomingMessage(message *models.Message) (Message, error) {
userRole := string(anthropic.RoleUser) // Convert RoleUser to string
userMessage := b.createMessage(message.Chat.ID, message.From.ID, message.From.Username, userRole, message.Text, true)
if b.config.DebugScreening {
start := time.Now()
defer func() {
InfoLogger.Printf(
"[Screen] Incoming: chat=%d user=%d type=%s memory_size=%d duration=%v",
message.Chat.ID,
message.From.ID,
messageType(message),
len(b.getOrCreateChatMemory(message.Chat.ID).Messages),
time.Since(start),
)
}()
}
// If the message contains a sticker, include its details.
userRole := string(anthropic.RoleUser)
// Determine message text based on message type
messageText := message.Text
if message.Sticker != nil {
if message.Sticker.Emoji != "" {
messageText = fmt.Sprintf("Sent a sticker: %s", message.Sticker.Emoji)
} else {
messageText = "Sent a sticker."
}
}
userMessage := b.createMessage(message.Chat.ID, message.From.ID, message.From.Username, userRole, messageText, true)
// Handle sticker-specific details if present
if message.Sticker != nil {
userMessage.StickerFileID = message.Sticker.FileID
userMessage.StickerEmoji = message.Sticker.Emoji // Store the sticker emoji
if message.Sticker.Thumbnail != nil {
userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID
}
}
// Store the message.
if err := b.storeMessage(userMessage); err != nil {
// Get the chat memory before storing the message
chatMemory := b.getOrCreateChatMemory(message.Chat.ID)
// Store the message and get its ID
if err := b.storeMessage(&userMessage); err != nil {
return Message{}, err
}
// Update chat memory.
chatMemory := b.getOrCreateChatMemory(message.Chat.ID)
// Add the message to the chat memory
b.addMessageToChatMemory(chatMemory, userMessage)
return userMessage, nil
}
// screenOutgoingMessage handles storing of outgoing messages.
// screenOutgoingMessage handles storing of outgoing messages and updating chat memory.
// It also marks the most recent unanswered user message as answered.
func (b *Bot) screenOutgoingMessage(chatID int64, response string) (Message, error) {
assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false)
if b.config.DebugScreening {
start := time.Now()
defer func() {
InfoLogger.Printf(
"[Screen] Outgoing: chat=%d len=%d memory_size=%d duration=%v",
chatID,
len(response),
len(b.getOrCreateChatMemory(chatID).Messages),
time.Since(start),
)
}()
}
// Store the message.
if err := b.storeMessage(assistantMessage); err != nil {
// Create and store the assistant message
assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false)
if err := b.storeMessage(&assistantMessage); err != nil {
return Message{}, err
}
// Update chat memory.
// Find and mark the most recent unanswered user message as answered
now := time.Now()
err := b.db.Model(&Message{}).
Where("chat_id = ? AND bot_id = ? AND is_user = ? AND answered_on IS NULL",
chatID, b.botID, true).
Order("timestamp DESC").
Limit(1).
Update("answered_on", now).Error
if err != nil {
ErrorLogger.Printf("Error marking user message as answered: %v", err)
// Continue even if there's an error updating the user message
}
// Update chat memory with the message that now has an ID
chatMemory := b.getOrCreateChatMemory(chatID)
b.addMessageToChatMemory(chatMemory, assistantMessage)