diff --git a/.cursor/rules/snyk_rules.mdc b/.cursor/rules/snyk_rules.mdc new file mode 100644 index 0000000..eef0634 --- /dev/null +++ b/.cursor/rules/snyk_rules.mdc @@ -0,0 +1,14 @@ +--- +alwaysApply: true +always_on: true +trigger: always_on +applyTo: "**" +description: Snyk Security At Inception +--- + +# Project security best practices + +- Always run snyk_code_scan tool for new first party code that is generated in a Snyk-supported language. +- If any security issues are found based on newly introduced or modified code or dependencies, attempt to fix the issues using the results context from Snyk. +- Rescan the code after fixing the issues to ensure that the issues were fixed and that there are no newly introduced issues. +- Repeat this process until no new issues are found. diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..7deffff --- /dev/null +++ b/.dockerignore @@ -0,0 +1,77 @@ +# Git files +.git +.gitignore +.gitattributes + +# Documentation +README.md +*.md + +# Docker files +Dockerfile +docker-compose.yml +.dockerignore + +# Environment files +.env +.env.* + +# Log files +*.log +logs/ + +# Database files +*.db +*.sqlite +*.sqlite3 +bot.db + +# Config files (except default template) +config/* +!config/default.json + +# Test files +*_test.go +test/ +tests/ + +# Build artifacts +telegram-bot +*.exe +*.dll +*.so +*.dylib + +# IDE files +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS files +.DS_Store +Thumbs.db + +# Go specific +vendor/ +*.mod.backup +*.sum.backup + +# Temporary files +tmp/ +temp/ +*.tmp + +# Coverage files +*.out +coverage.html + +# CI/CD files +.github/ +.gitlab-ci.yml +.travis.yml + +# Examples and documentation +examples/ +docs/ \ No newline at end of file diff --git a/.gitattributes b/.gitattributes old mode 100755 new mode 100644 diff --git a/.github/workflows/go-ci.yaml b/.github/workflows/go-ci.yaml old mode 100755 new mode 100644 index 774c1de..fc59cf4 --- a/.github/workflows/go-ci.yaml +++ b/.github/workflows/go-ci.yaml @@ -14,7 +14,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24.2' - uses: actions/cache@v4 with: path: | @@ -31,9 +31,9 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: golangci/golangci-lint-action@v6 + - uses: golangci/golangci-lint-action@v7 with: - version: v1.60 + version: v2.0 args: --timeout 5m # Test job @@ -44,7 +44,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24.2' - run: go test ./... -v # Security scan job diff --git a/.gitignore b/.gitignore old mode 100755 new mode 100644 diff --git a/.roo/mcp.json b/.roo/mcp.json new file mode 100644 index 0000000..7001130 --- /dev/null +++ b/.roo/mcp.json @@ -0,0 +1,3 @@ +{ + "mcpServers": {} +} \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..b375df0 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,57 @@ +# Multi-stage build for Go Telegram Bot +# Build stage +FROM golang:1.24-alpine AS builder + +# Install build dependencies including C compiler for CGO +RUN apk add --no-cache git ca-certificates tzdata gcc musl-dev + +# Set working directory +WORKDIR /build + +# Copy go mod files first for better caching +COPY go.mod go.sum ./ + +# Download dependencies +RUN go mod download + +# Copy source code +COPY . . + +# Build the application +RUN CGO_ENABLED=1 GOOS=linux go build -a -installsuffix cgo -o telegram-bot . + +# Runtime stage +FROM alpine:latest + +# Install runtime dependencies +RUN apk --no-cache add ca-certificates tzdata sqlite + +# Create non-root user +RUN addgroup -g 1001 -S appgroup && \ + adduser -u 1001 -S appuser -G appgroup + +# Set working directory +WORKDIR /app + +# Create necessary directories +RUN mkdir -p /app/config /app/data /app/logs && \ + chown -R appuser:appgroup /app + +# Copy binary from builder stage +COPY --from=builder /build/telegram-bot /app/telegram-bot + +# Copy default config as template +COPY --chown=appuser:appgroup config/default.json /app/config/ + +# Switch to non-root user +USER appuser + +# Expose any ports if needed (not required for this bot) +# EXPOSE 8080 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD pgrep telegram-bot || exit 1 + +# Run the application +CMD ["/app/telegram-bot"] \ No newline at end of file diff --git a/README.md b/README.md old mode 100755 new mode 100644 index ed6af73..cc5acb9 --- a/README.md +++ b/README.md @@ -12,39 +12,38 @@ A scalable, multi-bot solution for Telegram using Go, GORM, and the Anthropic AP ## Usage -1. Clone the repository or install using `go get`: - - Option 1: Clone the repository - ```bash - git clone https://github.com/HugeFrog24/go-telegram-bot.git - ``` - - - Option 2: Install using go get - ```bash - go get -u github.com/HugeFrog24/go-telegram-bot - ``` +### Docker Deployment (Recommended) - - Navigate to the project directory: - ```bash - cd go-telegram-bot - ``` +1. Clone the repository: + ```bash + git clone https://github.com/HugeFrog24/go-telegram-bot.git + cd go-telegram-bot + ``` 2. Copy the default config template and edit it: ```bash - cp config/default.json config/config-mybot.json + cp config/default.json config/mybot.json + nano config/mybot.json ``` - Replace `config-mybot.json` with the name of your bot. - - ```bash - nano config/config-mybot.json - ``` - - You can set up as many bots as you want. Just copy the template and edit the parameters. - -> [!IMPORTANT] +> [!IMPORTANT] > Keep your config files secret and do not commit them to version control. -3. Build the application: +3. Create data directory and run: + ```bash + mkdir -p data + docker-compose up -d + ``` + +### Native Deployment + +1. Install using `go get`: + ```bash + go get -u github.com/HugeFrog24/go-telegram-bot + cd go-telegram-bot + ``` + +2. Configure as above, then build: ```bash go build -o telegram-bot ``` @@ -76,11 +75,11 @@ To enable the bot to start automatically on system boot and run in the backgroun ``` ```bash - sudo systemctl enable telegram-bot.service + sudo systemctl enable telegram-bot ``` ```bash - sudo systemctl start telegram-bot.service + sudo systemctl start telegram-bot ``` 4. Check the status: @@ -93,22 +92,16 @@ For more details on the systemd setup, refer to the [demo service file](examples ## Logs -View logs using journalctl: - +### Docker ```bash -journalctl -u telegram-bot +docker-compose logs -f telegram-bot ``` -Follow logs: +### Systemd ```bash journalctl -u telegram-bot -f ``` -View errors: -```bash -journalctl -u telegram-bot -p err -``` - ## Testing The GitHub actions workflow already runs tests on every commit: @@ -119,3 +112,11 @@ However, you can run the tests locally using: ```bash go test -race -v ./... ``` + +## Storage + +At the moment, a SQLite database (`./data/bot.db`) is used for persistent storage. + +Remember to back it up regularly. + +Future versions will support more robust storage backends. diff --git a/anthropic.go b/anthropic.go old mode 100755 new mode 100644 index 82e1021..734a92b --- a/anthropic.go +++ b/anthropic.go @@ -3,11 +3,13 @@ package main import ( "context" "fmt" + "strings" + "time" "github.com/liushuangls/go-anthropic/v2" ) -func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Message, isNewChat, isAdminOrOwner, isEmojiOnly bool) (string, error) { +func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Message, isNewChat, isAdminOrOwner, isEmojiOnly bool, username string, firstName string, lastName string, isPremium bool, languageCode string, messageTime int) (string, error) { // Use prompts from config var systemMessage string if isNewChat { @@ -19,6 +21,56 @@ func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Mes // Combine default prompt with custom instructions systemMessage = b.config.SystemPrompts["default"] + " " + b.config.SystemPrompts["custom_instructions"] + " " + systemMessage + // Handle username placeholder + usernameValue := username + if username == "" { + usernameValue = "unknown" // Use "unknown" when username is not available + } + systemMessage = strings.ReplaceAll(systemMessage, "{username}", usernameValue) + + // Handle firstname placeholder + firstnameValue := firstName + if firstName == "" { + firstnameValue = "unknown" // Use "unknown" when first name is not available + } + systemMessage = strings.ReplaceAll(systemMessage, "{firstname}", firstnameValue) + + // Handle lastname placeholder + lastnameValue := lastName + if lastName == "" { + lastnameValue = "" // Empty string when last name is not available + } + systemMessage = strings.ReplaceAll(systemMessage, "{lastname}", lastnameValue) + + // Handle language code placeholder + langValue := languageCode + if languageCode == "" { + langValue = "en" // Default to English when language code is not available + } + systemMessage = strings.ReplaceAll(systemMessage, "{language}", langValue) + + // Handle premium status + premiumStatus := "regular user" + if isPremium { + premiumStatus = "premium user" + } + systemMessage = strings.ReplaceAll(systemMessage, "{premium_status}", premiumStatus) + + // Handle time awareness + timeObj := time.Unix(int64(messageTime), 0) + hour := timeObj.Hour() + var timeContext string + if hour >= 5 && hour < 12 { + timeContext = "morning" + } else if hour >= 12 && hour < 18 { + timeContext = "afternoon" + } else if hour >= 18 && hour < 22 { + timeContext = "evening" + } else { + timeContext = "night" + } + systemMessage = strings.ReplaceAll(systemMessage, "{time_context}", timeContext) + if !isAdminOrOwner { systemMessage += " " + b.config.SystemPrompts["avoid_sensitive"] } @@ -27,6 +79,16 @@ func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Mes systemMessage += " " + b.config.SystemPrompts["respond_with_emojis"] } + // Debug logging + InfoLogger.Printf("Sending %d messages to Anthropic", len(messages)) + for i, msg := range messages { + for _, content := range msg.Content { + if content.Type == anthropic.MessagesContentTypeText { + InfoLogger.Printf("Message %d: Role=%v, Text=%v", i, msg.Role, content.Text) + } + } + } + // Ensure the roles are correct for i := range messages { switch messages[i].Role { @@ -42,12 +104,20 @@ func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Mes model := anthropic.Model(b.config.Model) - resp, err := b.anthropicClient.CreateMessages(ctx, anthropic.MessagesRequest{ + // Create the request + request := anthropic.MessagesRequest{ Model: model, // Now `model` is of type anthropic.Model Messages: messages, System: systemMessage, MaxTokens: 1000, - }) + } + + // Apply temperature if set in config + if b.config.Temperature != nil { + request.Temperature = b.config.Temperature + } + + resp, err := b.anthropicClient.CreateMessages(ctx, request) if err != nil { return "", fmt.Errorf("error creating Anthropic message: %w", err) } diff --git a/anthropic_test.go b/anthropic_test.go new file mode 100644 index 0000000..cf82305 --- /dev/null +++ b/anthropic_test.go @@ -0,0 +1,197 @@ +package main + +import ( + "fmt" + "strings" + "testing" + "time" +) + +// TestLanguageCodeReplacement tests that language code is properly handled and replaced +func TestLanguageCodeReplacement(t *testing.T) { + // Test with provided language code + systemMessage := "User's language preference: '{language}'" + + // Test with a specific language code + langValue := "fr" + result := strings.ReplaceAll(systemMessage, "{language}", langValue) + + if !strings.Contains(result, "User's language preference: 'fr'") { + t.Errorf("Expected language code 'fr' to be replaced, got: %s", result) + } + + // Test with empty language code (should default to "en") + langValue = "" + if langValue == "" { + langValue = "en" // Default to English when language code is not available + } + result = strings.ReplaceAll(systemMessage, "{language}", langValue) + + if !strings.Contains(result, "User's language preference: 'en'") { + t.Errorf("Expected default language code 'en' to be used, got: %s", result) + } +} + +// TestPremiumStatusReplacement tests that premium status is properly handled and replaced +func TestPremiumStatusReplacement(t *testing.T) { + systemMessage := "User is a {premium_status}" + + // Test with premium user + isPremium := true + premiumStatus := "regular user" + if isPremium { + premiumStatus = "premium user" + } + result := strings.ReplaceAll(systemMessage, "{premium_status}", premiumStatus) + + if !strings.Contains(result, "User is a premium user") { + t.Errorf("Expected premium status to be replaced with 'premium user', got: %s", result) + } + + // Test with regular user + isPremium = false + premiumStatus = "regular user" + if isPremium { + premiumStatus = "premium user" + } + result = strings.ReplaceAll(systemMessage, "{premium_status}", premiumStatus) + + if !strings.Contains(result, "User is a regular user") { + t.Errorf("Expected premium status to be replaced with 'regular user', got: %s", result) + } +} + +// TestTimeContextCalculation tests that time context is correctly calculated for different hours +func TestTimeContextCalculation(t *testing.T) { + // Test cases for different hours + testCases := []struct { + hour int + expected string + }{ + {3, "night"}, // Night: hours < 5 or hours >= 22 + {5, "morning"}, // Morning: 5 <= hours < 12 + {12, "afternoon"}, // Afternoon: 12 <= hours < 18 + {17, "afternoon"}, // Afternoon: 12 <= hours < 18 + {18, "evening"}, // Evening: 18 <= hours < 22 + {21, "evening"}, // Evening: 18 <= hours < 22 + {22, "night"}, // Night: hours < 5 or hours >= 22 + {23, "night"}, // Night: hours < 5 or hours >= 22 + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("Hour_%d", tc.hour), func(t *testing.T) { + // Create a timestamp for the specified hour + testTime := time.Date(2025, 5, 15, tc.hour, 0, 0, 0, time.UTC) + + // Get the hour directly from the test time to ensure it's what we expect + actualHour := testTime.Hour() + if actualHour != tc.hour { + t.Fatalf("Test setup error: expected hour %d, got %d", tc.hour, actualHour) + } + + // Calculate time context using the same logic as in anthropic.go + var timeContext string + if actualHour >= 5 && actualHour < 12 { + timeContext = "morning" + } else if actualHour >= 12 && actualHour < 18 { + timeContext = "afternoon" + } else if actualHour >= 18 && actualHour < 22 { + timeContext = "evening" + } else { + timeContext = "night" + } + + // Check if the calculated time context matches the expected value + if timeContext != tc.expected { + t.Errorf("For hour %d: expected time context '%s', got '%s'", + actualHour, tc.expected, timeContext) + } + }) + } +} + +// TestSystemMessagePlaceholderReplacement tests that all placeholders are correctly replaced +func TestSystemMessagePlaceholderReplacement(t *testing.T) { + systemMessage := "The user you're talking to has username '{username}' and display name '{firstname} {lastname}'.\n" + + "User's language preference: '{language}'\n" + + "User is a {premium_status}\n" + + "It's currently {time_context} in your timezone" + + // Set up test data + username := "testuser" + firstName := "Test" + lastName := "User" + isPremium := true + languageCode := "de" + + // Create a timestamp for a specific hour (e.g., 14:00 = afternoon) + testTime := time.Date(2025, 5, 15, 14, 0, 0, 0, time.UTC) + messageTime := int(testTime.Unix()) + + // Handle username placeholder + usernameValue := username + if username == "" { + usernameValue = "unknown" + } + systemMessage = strings.ReplaceAll(systemMessage, "{username}", usernameValue) + + // Handle firstname placeholder + firstnameValue := firstName + if firstName == "" { + firstnameValue = "unknown" + } + systemMessage = strings.ReplaceAll(systemMessage, "{firstname}", firstnameValue) + + // Handle lastname placeholder + lastnameValue := lastName + if lastName == "" { + lastnameValue = "" + } + systemMessage = strings.ReplaceAll(systemMessage, "{lastname}", lastnameValue) + + // Handle language code placeholder + langValue := languageCode + if languageCode == "" { + langValue = "en" + } + systemMessage = strings.ReplaceAll(systemMessage, "{language}", langValue) + + // Handle premium status + premiumStatus := "regular user" + if isPremium { + premiumStatus = "premium user" + } + systemMessage = strings.ReplaceAll(systemMessage, "{premium_status}", premiumStatus) + + // Handle time awareness + timeObj := time.Unix(int64(messageTime), 0) + hour := timeObj.Hour() + var timeContext string + if hour >= 5 && hour < 12 { + timeContext = "morning" + } else if hour >= 12 && hour < 18 { + timeContext = "afternoon" + } else if hour >= 18 && hour < 22 { + timeContext = "evening" + } else { + timeContext = "night" + } + systemMessage = strings.ReplaceAll(systemMessage, "{time_context}", timeContext) + + // Check that all placeholders were replaced correctly + if !strings.Contains(systemMessage, "username 'testuser'") { + t.Errorf("Username not replaced correctly, got: %s", systemMessage) + } + if !strings.Contains(systemMessage, "display name 'Test User'") { + t.Errorf("Display name not replaced correctly, got: %s", systemMessage) + } + if !strings.Contains(systemMessage, "language preference: 'de'") { + t.Errorf("Language preference not replaced correctly, got: %s", systemMessage) + } + if !strings.Contains(systemMessage, "User is a premium user") { + t.Errorf("Premium status not replaced correctly, got: %s", systemMessage) + } + if !strings.Contains(systemMessage, "It's currently afternoon in your timezone") { + t.Errorf("Time context not replaced correctly, got: %s", systemMessage) + } +} diff --git a/bot.go b/bot.go old mode 100755 new mode 100644 index ec45c30..4a6a3cc --- a/bot.go +++ b/bot.go @@ -28,6 +28,14 @@ type Bot struct { botID uint // Reference to BotModel.ID } +// Helper function to determine message type +func messageType(msg *models.Message) string { + if msg.Sticker != nil { + return "sticker" + } + return "text" +} + // NewBot initializes and returns a new Bot instance. func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient) (*Bot, error) { // Retrieve or create Bot entry in the database @@ -87,6 +95,15 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient) tgBot: tgClient, } + if tgClient == nil { + var err error + tgClient, err = initTelegramBot(config.TelegramToken, b) + if err != nil { + return nil, fmt.Errorf("failed to initialize Telegram bot: %w", err) + } + b.tgBot = tgClient + } + return b, nil } @@ -178,9 +195,10 @@ func (b *Bot) createMessage(chatID, userID int64, username, userRole, text strin return message } -func (b *Bot) storeMessage(message Message) error { - message.BotID = b.botID // Associate the message with the correct bot - return b.db.Create(&message).Error +// storeMessage stores a message in the database and updates its ID +func (b *Bot) storeMessage(message *Message) error { + message.BotID = b.botID // Associate the message with the correct bot + return b.db.Create(message).Error // This will update the message with its new ID } func (b *Bot) getOrCreateChatMemory(chatID int64) *ChatMemory { @@ -190,14 +208,30 @@ func (b *Bot) getOrCreateChatMemory(chatID int64) *ChatMemory { if !exists { b.chatMemoriesMu.Lock() - // Double-check to prevent race condition + defer b.chatMemoriesMu.Unlock() + chatMemory, exists = b.chatMemories[chatID] if !exists { + // Check if this is a new chat by querying the database + var count int64 + b.db.Model(&Message{}).Where("chat_id = ? AND bot_id = ?", chatID, b.botID).Count(&count) + isNewChat := count == 0 // Truly new chat if no messages exist + var messages []Message - b.db.Where("chat_id = ? AND bot_id = ?", chatID, b.botID). - Order("timestamp asc"). - Limit(b.memorySize * 2). - Find(&messages) + if !isNewChat { + // Fetch existing messages only if it's not a new chat + err := b.db.Where("chat_id = ? AND bot_id = ?", chatID, b.botID). + Order("timestamp asc"). + Limit(b.memorySize * 2). + Find(&messages).Error + + if err != nil { + ErrorLogger.Printf("Error fetching messages from database: %v", err) + messages = []Message{} // Initialize an empty slice on error + } + } else { + messages = []Message{} // Ensure messages is initialized for new chats + } chatMemory = &ChatMemory{ Messages: messages, @@ -206,19 +240,22 @@ func (b *Bot) getOrCreateChatMemory(chatID int64) *ChatMemory { b.chatMemories[chatID] = chatMemory } - b.chatMemoriesMu.Unlock() } return chatMemory } +// addMessageToChatMemory adds a new message to the chat memory, ensuring the memory size is maintained. func (b *Bot) addMessageToChatMemory(chatMemory *ChatMemory, message Message) { b.chatMemoriesMu.Lock() defer b.chatMemoriesMu.Unlock() + // Add the new message chatMemory.Messages = append(chatMemory.Messages, message) + + // Maintain the memory size if len(chatMemory.Messages) > chatMemory.Size { - chatMemory.Messages = chatMemory.Messages[2:] + chatMemory.Messages = chatMemory.Messages[len(chatMemory.Messages)-chatMemory.Size:] } } @@ -226,6 +263,12 @@ func (b *Bot) prepareContextMessages(chatMemory *ChatMemory) []anthropic.Message b.chatMemoriesMu.RLock() defer b.chatMemoriesMu.RUnlock() + // Debug logging + InfoLogger.Printf("Chat memory contains %d messages", len(chatMemory.Messages)) + for i, msg := range chatMemory.Messages { + InfoLogger.Printf("Message %d: IsUser=%v, Text=%q", i, msg.IsUser, msg.Text) + } + var contextMessages []anthropic.Message for _, msg := range chatMemory.Messages { role := anthropic.RoleUser @@ -252,7 +295,7 @@ func (b *Bot) prepareContextMessages(chatMemory *ChatMemory) []anthropic.Message func (b *Bot) isNewChat(chatID int64) bool { var count int64 b.db.Model(&Message{}).Where("chat_id = ? AND bot_id = ?", chatID, b.botID).Count(&count) - return count == 1 + return count == 0 // Only consider a chat new if it has 0 messages } func (b *Bot) isAdminOrOwner(userID int64) bool { @@ -264,9 +307,9 @@ func (b *Bot) isAdminOrOwner(userID int64) bool { return user.Role.Name == "admin" || user.Role.Name == "owner" } -func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (TelegramClient, error) { +func initTelegramBot(token string, b *Bot) (TelegramClient, error) { opts := []bot.Option{ - bot.WithDefaultHandler(handleUpdate), + bot.WithDefaultHandler(b.handleUpdate), } tgBot, err := bot.New(token, opts...) @@ -274,11 +317,40 @@ func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot return nil, err } + // Define bot commands + commands := []models.BotCommand{ + { + Command: "stats", + Description: "Get bot statistics. Usage: /stats or /stats user [user_id]", + }, + { + Command: "whoami", + Description: "Get your user information", + }, + { + Command: "clear", + Description: "Clear chat history (soft delete). Admins: /clear [user_id]", + }, + { + Command: "clear_hard", + Description: "Clear chat history (permanently delete). Admins: /clear_hard [user_id]", + }, + } + + // Set bot commands + _, err = tgBot.SetMyCommands(context.Background(), &bot.SetMyCommandsParams{ + Commands: commands, + }) + if err != nil { + ErrorLogger.Printf("Error setting bot commands: %v", err) + return nil, err + } + return tgBot, nil } func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, businessConnectionID string) error { - // Pass the outgoing message through the centralized screen for storage + // Pass the outgoing message through the centralized screen for storage and chat memory update _, err := b.screenOutgoingMessage(chatID, text) if err != nil { ErrorLogger.Printf("Error storing assistant message: %v", err) @@ -306,28 +378,75 @@ func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, busin } // sendStats sends the bot statistics to the specified chat. -func (b *Bot) sendStats(ctx context.Context, chatID int64, businessConnectionID string) { - totalUsers, totalMessages, err := b.getStats() +func (b *Bot) sendStats(ctx context.Context, chatID int64, userID int64, targetUserID int64, businessConnectionID string) { + // If targetUserID is 0, show global stats + if targetUserID == 0 { + totalUsers, totalMessages, err := b.getStats() + if err != nil { + ErrorLogger.Printf("Error fetching stats: %v\n", err) + if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve the stats at this time.", businessConnectionID); err != nil { + ErrorLogger.Printf("Error sending response: %v", err) + } + return + } + + // Do NOT manually escape hyphens here + statsMessage := fmt.Sprintf( + "📊 Bot Statistics:\n\n"+ + "- Total Users: %d\n"+ + "- Total Messages: %d", + totalUsers, + totalMessages, + ) + + // Send the response through the centralized screen + if err := b.sendResponse(ctx, chatID, statsMessage, businessConnectionID); err != nil { + ErrorLogger.Printf("Error sending stats message: %v", err) + } + return + } + + // If targetUserID is not 0, show user-specific stats + // Check permissions if the user is trying to view someone else's stats + if targetUserID != userID { + if !b.isAdminOrOwner(userID) { + InfoLogger.Printf("User %d attempted to view stats for user %d without permission", userID, targetUserID) + if err := b.sendResponse(ctx, chatID, "Permission denied. Only admins and owners can view other users' statistics.", businessConnectionID); err != nil { + ErrorLogger.Printf("Error sending response: %v", err) + } + return + } + } + + // Get user stats + username, messagesIn, messagesOut, totalMessages, err := b.getUserStats(targetUserID) if err != nil { - ErrorLogger.Printf("Error fetching stats: %v\n", err) - if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve the stats at this time.", businessConnectionID); err != nil { + ErrorLogger.Printf("Error fetching user stats: %v\n", err) + if err := b.sendResponse(ctx, chatID, fmt.Sprintf("Sorry, I couldn't retrieve statistics for user ID %d.", targetUserID), businessConnectionID); err != nil { ErrorLogger.Printf("Error sending response: %v", err) } return } - // Do NOT manually escape hyphens here + // Build the user stats message + userInfo := fmt.Sprintf("@%s (ID: %d)", username, targetUserID) + if username == "" { + userInfo = fmt.Sprintf("User ID: %d", targetUserID) + } + statsMessage := fmt.Sprintf( - "📊 Bot Statistics:\n\n"+ - "- Total Users: %d\n"+ + "👤 User Statistics for %s:\n\n"+ + "- Messages Sent: %d\n"+ + "- Messages Received: %d\n"+ "- Total Messages: %d", - totalUsers, + userInfo, + messagesIn, + messagesOut, totalMessages, ) - // Send the response through the centralized screen if err := b.sendResponse(ctx, chatID, statsMessage, businessConnectionID); err != nil { - ErrorLogger.Printf("Error sending stats message: %v", err) + ErrorLogger.Printf("Error sending user stats message: %v", err) } } @@ -346,6 +465,35 @@ func (b *Bot) getStats() (int64, int64, error) { return totalUsers, totalMessages, nil } +// getUserStats retrieves statistics for a specific user +func (b *Bot) getUserStats(userID int64) (string, int64, int64, int64, error) { + // Get user information from database + var user User + err := b.db.Where("telegram_id = ? AND bot_id = ?", userID, b.botID).First(&user).Error + if err != nil { + return "", 0, 0, 0, fmt.Errorf("user not found: %w", err) + } + + // Count messages sent by the user (IN) + var messagesIn int64 + if err := b.db.Model(&Message{}).Where("user_id = ? AND bot_id = ? AND is_user = ?", + userID, b.botID, true).Count(&messagesIn).Error; err != nil { + return "", 0, 0, 0, err + } + + // Count responses to the user (OUT) + var messagesOut int64 + if err := b.db.Model(&Message{}).Where("chat_id IN (SELECT DISTINCT chat_id FROM messages WHERE user_id = ? AND bot_id = ?) AND bot_id = ? AND is_user = ?", + userID, b.botID, b.botID, false).Count(&messagesOut).Error; err != nil { + return "", 0, 0, 0, err + } + + // Total messages is the sum + totalMessages := messagesIn + messagesOut + + return user.Username, messagesIn, messagesOut, totalMessages, nil +} + // isOnlyEmojis checks if the string consists solely of emojis. func isOnlyEmojis(s string) bool { for _, r := range s { @@ -399,41 +547,96 @@ func (b *Bot) sendWhoAmI(ctx context.Context, chatID int64, userID int64, userna } } -// screenIncomingMessage handles storing of incoming messages. +// screenIncomingMessage centralizes all incoming message processing: storing messages and updating chat memory. func (b *Bot) screenIncomingMessage(message *models.Message) (Message, error) { - userRole := string(anthropic.RoleUser) // Convert RoleUser to string - userMessage := b.createMessage(message.Chat.ID, message.From.ID, message.From.Username, userRole, message.Text, true) + if b.config.DebugScreening { + start := time.Now() + defer func() { + InfoLogger.Printf( + "[Screen] Incoming: chat=%d user=%d type=%s memory_size=%d duration=%v", + message.Chat.ID, + message.From.ID, + messageType(message), + len(b.getOrCreateChatMemory(message.Chat.ID).Messages), + time.Since(start), + ) + }() + } - // If the message contains a sticker, include its details. + userRole := string(anthropic.RoleUser) + + // Determine message text based on message type + messageText := message.Text + if message.Sticker != nil { + if message.Sticker.Emoji != "" { + messageText = fmt.Sprintf("Sent a sticker: %s", message.Sticker.Emoji) + } else { + messageText = "Sent a sticker." + } + } + + userMessage := b.createMessage(message.Chat.ID, message.From.ID, message.From.Username, userRole, messageText, true) + + // Handle sticker-specific details if present if message.Sticker != nil { userMessage.StickerFileID = message.Sticker.FileID + userMessage.StickerEmoji = message.Sticker.Emoji // Store the sticker emoji if message.Sticker.Thumbnail != nil { userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID } } - // Store the message. - if err := b.storeMessage(userMessage); err != nil { + // Get the chat memory before storing the message + chatMemory := b.getOrCreateChatMemory(message.Chat.ID) + + // Store the message and get its ID + if err := b.storeMessage(&userMessage); err != nil { return Message{}, err } - // Update chat memory. - chatMemory := b.getOrCreateChatMemory(message.Chat.ID) + // Add the message to the chat memory b.addMessageToChatMemory(chatMemory, userMessage) return userMessage, nil } -// screenOutgoingMessage handles storing of outgoing messages. +// screenOutgoingMessage handles storing of outgoing messages and updating chat memory. +// It also marks the most recent unanswered user message as answered. func (b *Bot) screenOutgoingMessage(chatID int64, response string) (Message, error) { - assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false) + if b.config.DebugScreening { + start := time.Now() + defer func() { + InfoLogger.Printf( + "[Screen] Outgoing: chat=%d len=%d memory_size=%d duration=%v", + chatID, + len(response), + len(b.getOrCreateChatMemory(chatID).Messages), + time.Since(start), + ) + }() + } - // Store the message. - if err := b.storeMessage(assistantMessage); err != nil { + // Create and store the assistant message + assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false) + if err := b.storeMessage(&assistantMessage); err != nil { return Message{}, err } - // Update chat memory. + // Find and mark the most recent unanswered user message as answered + now := time.Now() + err := b.db.Model(&Message{}). + Where("chat_id = ? AND bot_id = ? AND is_user = ? AND answered_on IS NULL", + chatID, b.botID, true). + Order("timestamp DESC"). + Limit(1). + Update("answered_on", now).Error + + if err != nil { + ErrorLogger.Printf("Error marking user message as answered: %v", err) + // Continue even if there's an error updating the user message + } + + // Update chat memory with the message that now has an ID chatMemory := b.getOrCreateChatMemory(chatID) b.addMessageToChatMemory(chatMemory, assistantMessage) diff --git a/clock.go b/clock.go old mode 100755 new mode 100644 diff --git a/config.go b/config.go old mode 100755 new mode 100644 index 904b784..f83b90f --- a/config.go +++ b/config.go @@ -18,10 +18,12 @@ type BotConfig struct { MessagePerDay int `json:"messages_per_day"` TempBanDuration string `json:"temp_ban_duration"` Model anthropic.Model `json:"model"` + Temperature *float32 `json:"temperature,omitempty"` // Controls creativity vs determinism (0.0-1.0) SystemPrompts map[string]string `json:"system_prompts"` Active bool `json:"active"` OwnerTelegramID int64 `json:"owner_telegram_id"` AnthropicAPIKey string `json:"anthropic_api_key"` + DebugScreening bool `json:"debug_screening"` // Enable detailed screening logs } // Custom unmarshalling to handle anthropic.Model @@ -144,12 +146,15 @@ func validateConfig(config *BotConfig, ids, tokens map[string]bool) error { func loadConfig(filename string) (BotConfig, error) { var config BotConfig // Use filepath.Clean before opening the file - cleanPath := filepath.Clean(filename) - file, err := os.OpenFile(cleanPath, os.O_RDONLY, 0) + file, err := os.OpenFile(filepath.Clean(filename), os.O_RDONLY, 0) if err != nil { - return config, fmt.Errorf("failed to open config file %s: %w", cleanPath, err) + return config, fmt.Errorf("failed to open config file %s: %w", filename, err) } - defer file.Close() + defer func() { + if err := file.Close(); err != nil { + InfoLogger.Printf("Failed to close config file: %v", err) + } + }() decoder := json.NewDecoder(file) if err := decoder.Decode(&config); err != nil { @@ -173,7 +178,11 @@ func (c *BotConfig) Reload(configDir, filename string) error { if err != nil { return fmt.Errorf("failed to open config file %s: %w", cleanPath, err) } - defer file.Close() + defer func() { + if err := file.Close(); err != nil { + InfoLogger.Printf("Failed to close config file: %v", err) + } + }() decoder := json.NewDecoder(file) if err := decoder.Decode(c); err != nil { diff --git a/config/default.json b/config/default.json old mode 100755 new mode 100644 index f7ddd8b..52dc735 --- a/config/default.json +++ b/config/default.json @@ -8,12 +8,14 @@ "messages_per_hour": 20, "messages_per_day": 100, "temp_ban_duration": "24h", - "model": "claude-3-5-sonnet-20240620", + "model": "claude-3-5-haiku-latest", + "temperature": 0.7, + "debug_screening": false, "system_prompts": { "default": "You are a helpful assistant.", - "custom_instructions": "Please follow these guidelines:\n- Your name is Atom.\n- If a user asks about buying apples, inform them that we don't sell apples.\n- When asked for a joke, tell a clean, family-friendly joke about programming or technology.\n- If someone inquires about our services, explain that we offer AI-powered chatbot solutions.\n- For any questions about pricing, direct users to contact our sales team at sales@example.com.\n- If asked about your capabilities, be honest about what you can and cannot do.\nAlways maintain a friendly and professional tone.", + "custom_instructions": "You are texting through a limited Telegram interface with 15-word maximum. Write like texting a friend - use shorthand, skip grammar, use slang/abbreviations. System cuts off anything longer than 15 words.\n\n- Your name is Atom.\n- The user you're talking to has username '{username}' and display name '{firstname} {lastname}'.\n- User's language preference: '{language}'\n- User is a {premium_status}\n- It's currently {time_context} in your timezone. Use appropriate time-based greetings and address the user by name.\n- If a user asks about buying apples, inform them that we don't sell apples.\n- When asked for a joke, tell a clean, family-friendly joke about programming or technology.\n- If someone inquires about our services, explain that we offer AI-powered chatbot solutions.\n- For any questions about pricing, direct users to contact our sales team at sales@example.com.\n- If asked about your capabilities, be honest about what you can and cannot do.\nAlways maintain a friendly and professional tone.", "continue_conversation": "Continuing our conversation. Remember previous context if relevant.", "avoid_sensitive": "Avoid discussing sensitive topics or providing harmful information.", "respond_with_emojis": "Since the user sent only emojis, respond using emojis only." } -} \ No newline at end of file +} diff --git a/config_test.go b/config_test.go old mode 100755 new mode 100644 index 0fcd945..f635447 --- a/config_test.go +++ b/config_test.go @@ -10,7 +10,7 @@ import ( "github.com/liushuangls/go-anthropic/v2" ) -// Add this at the beginning of the file, after the imports +// Set up loggers func TestMain(m *testing.M) { initLoggers() os.Exit(m.Run()) @@ -26,6 +26,7 @@ func TestBotConfig_UnmarshalJSON(t *testing.T) { "messages_per_day": 100, "temp_ban_duration": "1h", "model": "claude-v1", + "temperature": 0.7, "system_prompts": {"welcome": "Hello!"}, "active": true, "owner_telegram_id": 123456789, @@ -100,7 +101,11 @@ func TestValidateConfigPath(t *testing.T) { if err := os.MkdirAll(subDir, 0755); err != nil { t.Fatalf("Failed to create subdir: %v", err) } - defer os.RemoveAll(subDir) + defer func() { + if err := os.RemoveAll(subDir); err != nil { + t.Errorf("Failed to remove test subdirectory: %v", err) + } + }() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -124,7 +129,11 @@ func TestLoadConfig(t *testing.T) { if err != nil { t.Fatalf("Failed to create temp dir: %v", err) } - defer os.RemoveAll(tempDir) + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Errorf("Failed to remove temp directory: %v", err) + } + }() // Valid config JSON validConfig := `{ @@ -135,6 +144,7 @@ func TestLoadConfig(t *testing.T) { "messages_per_day": 100, "temp_ban_duration": "1h", "model": "claude-v1", + "temperature": 0.7, "system_prompts": {"welcome": "Hello!"}, "active": true, "owner_telegram_id": 123456789, @@ -318,7 +328,11 @@ func TestLoadAllConfigs(t *testing.T) { if err != nil { t.Fatalf("Failed to create temp dir: %v", err) } - defer os.RemoveAll(tempDir) + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Errorf("Failed to remove temp directory: %v", err) + } + }() tests := []struct { name string @@ -338,6 +352,7 @@ func TestLoadAllConfigs(t *testing.T) { "messages_per_day": 100, "temp_ban_duration": "1h", "model": "claude-v1", + "temperature": 0.7, "system_prompts": {"welcome": "Hello!"}, "active": true, "owner_telegram_id": 123456789, @@ -371,6 +386,7 @@ func TestLoadAllConfigs(t *testing.T) { "messages_per_day": 50, "temp_ban_duration": "30m", "model": "claude-v2", + "temperature": 0.5, "system_prompts": {"welcome": "Hi!"}, "active": false, "owner_telegram_id": 987654321, @@ -404,6 +420,7 @@ func TestLoadAllConfigs(t *testing.T) { "messages_per_day": 20, "temp_ban_duration": "15m", "model": "claude-v3", + "temperature": 0.3, "system_prompts": {"welcome": "Hey!"}, "active": true, "owner_telegram_id": 1122334455, @@ -437,6 +454,7 @@ func TestLoadAllConfigs(t *testing.T) { "messages_per_day": 10, "temp_ban_duration": "5m", "model": "claude-v4", + "temperature": 0.2, "system_prompts": {"welcome": "Greetings!"}, "active": true, "owner_telegram_id": 5566778899, @@ -511,7 +529,11 @@ func TestBotConfig_Reload(t *testing.T) { if err != nil { t.Fatalf("Failed to create temp dir: %v", err) } - defer os.RemoveAll(tempDir) + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Errorf("Failed to remove temp directory: %v", err) + } + }() // Create initial config file config1 := `{ @@ -522,6 +544,7 @@ func TestBotConfig_Reload(t *testing.T) { "messages_per_day": 100, "temp_ban_duration": "1h", "model": "claude-v1", + "temperature": 0.7, "system_prompts": {"welcome": "Hello!"}, "active": true, "owner_telegram_id": 123456789, @@ -555,6 +578,7 @@ func TestBotConfig_Reload(t *testing.T) { "messages_per_day": 200, "temp_ban_duration": "2h", "model": "claude-v2", + "temperature": 0.3, "system_prompts": {"welcome": "Hi there!"}, "active": true, "owner_telegram_id": 987654321, @@ -594,6 +618,7 @@ func TestBotConfig_UnmarshalJSON_Invalid(t *testing.T) { "messages_per_day": 100, "temp_ban_duration": "1h", "model": "", + "temperature": 0.7, "system_prompts": {"welcome": "Hello!"}, "active": true, "owner_telegram_id": 123456789, @@ -616,4 +641,84 @@ func contains(s, substr string) bool { return strings.Contains(s, substr) } +// TestTemperatureConfig tests that the temperature value is correctly loaded +func TestTemperatureConfig(t *testing.T) { + // Create a temporary directory + tempDir, err := os.MkdirTemp("", "temperature_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Errorf("Failed to remove temp directory: %v", err) + } + }() + + // Create config with temperature + configWithTemp := `{ + "id": "bot123", + "telegram_token": "token123", + "memory_size": 1024, + "messages_per_hour": 10, + "messages_per_day": 100, + "temp_ban_duration": "1h", + "model": "claude-v1", + "temperature": 0.42, + "system_prompts": {"welcome": "Hello!"}, + "active": true, + "owner_telegram_id": 123456789, + "anthropic_api_key": "api_key_123" + }` + + // Create config without temperature + configWithoutTemp := `{ + "id": "bot124", + "telegram_token": "token124", + "memory_size": 1024, + "messages_per_hour": 10, + "messages_per_day": 100, + "temp_ban_duration": "1h", + "model": "claude-v1", + "system_prompts": {"welcome": "Hello!"}, + "active": true, + "owner_telegram_id": 123456789, + "anthropic_api_key": "api_key_123" + }` + + // Write config files + withTempPath := filepath.Join(tempDir, "with_temp.json") + if err := os.WriteFile(withTempPath, []byte(configWithTemp), 0644); err != nil { + t.Fatalf("Failed to write config with temperature: %v", err) + } + + withoutTempPath := filepath.Join(tempDir, "without_temp.json") + if err := os.WriteFile(withoutTempPath, []byte(configWithoutTemp), 0644); err != nil { + t.Fatalf("Failed to write config without temperature: %v", err) + } + + // Test loading config with temperature + configWithTempObj, err := loadConfig(withTempPath) + if err != nil { + t.Fatalf("Failed to load config with temperature: %v", err) + } + + // Verify temperature is set correctly + if configWithTempObj.Temperature == nil { + t.Errorf("Expected Temperature to be set, got nil") + } else if *configWithTempObj.Temperature != 0.42 { + t.Errorf("Expected Temperature 0.42, got %f", *configWithTempObj.Temperature) + } + + // Test loading config without temperature + configWithoutTempObj, err := loadConfig(withoutTempPath) + if err != nil { + t.Fatalf("Failed to load config without temperature: %v", err) + } + + // Verify temperature is nil when not specified + if configWithoutTempObj.Temperature != nil { + t.Errorf("Expected Temperature to be nil, got %f", *configWithoutTempObj.Temperature) + } +} + // Additional tests can be added here to cover more scenarios diff --git a/database.go b/database.go old mode 100755 new mode 100644 index 66ef280..fe787ca --- a/database.go +++ b/database.go @@ -20,7 +20,7 @@ func initDB() (*gorm.DB, error) { }, ) - db, err := gorm.Open(sqlite.Open("bot.db"), &gorm.Config{ + db, err := gorm.Open(sqlite.Open("data/bot.db"), &gorm.Config{ Logger: newLogger, }) if err != nil { diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..9d0d689 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,39 @@ +services: + telegram-bot: + image: bogerserge/go-telegram-bot:latest + build: + context: . + dockerfile: Dockerfile + platforms: + - linux/amd64 + - linux/arm64 + container_name: go-telegram-bot + restart: unless-stopped + + # Optional: Environment variables (can be overridden with .env file) + # environment: + # - BOT_LOG_LEVEL=info + + # Volume mounts + volumes: + # Bind mount config directory for live configuration updates + - ./config:/app/config:ro + # Named volume for persistent database storage + - ./data:/app/data + # Optional: Bind mount for log access (uncomment if needed) + # - ./logs:/app/logs + + # Health check + healthcheck: + test: ["CMD", "pgrep", "telegram-bot"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 10s + + # Logging configuration + logging: + driver: "json-file" + options: + max-size: "10m" + max-file: "3" diff --git a/examples/systemd/telegram-bot.service b/examples/systemd/telegram-bot.service old mode 100755 new mode 100644 diff --git a/go.mod b/go.mod old mode 100755 new mode 100644 index db4f838..22a33b0 --- a/go.mod +++ b/go.mod @@ -1,18 +1,23 @@ module github.com/HugeFrog24/go-telegram-bot -go 1.23 +go 1.24.2 require ( - github.com/go-telegram/bot v1.9.1 - github.com/liushuangls/go-anthropic/v2 v2.8.2 - golang.org/x/time v0.7.0 - gorm.io/driver/sqlite v1.5.6 - gorm.io/gorm v1.25.12 + github.com/go-telegram/bot v1.18.0 + github.com/liushuangls/go-anthropic/v2 v2.17.0 + github.com/stretchr/testify v1.11.1 + golang.org/x/time v0.14.0 + gorm.io/driver/sqlite v1.6.0 + gorm.io/gorm v1.31.1 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect - github.com/mattn/go-sqlite3 v1.14.24 // indirect - golang.org/x/text v0.19.0 // indirect + github.com/mattn/go-sqlite3 v1.14.34 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.3 // indirect + golang.org/x/text v0.34.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum old mode 100755 new mode 100644 index da0e112..1d44f91 --- a/go.sum +++ b/go.sum @@ -1,18 +1,44 @@ -github.com/go-telegram/bot v1.9.1 h1:4vkNV6vDmEPZaYP7sZYaagOaJyV4GerfOPkjg/Ki5ic= -github.com/go-telegram/bot v1.9.1/go.mod h1:i2TRs7fXWIeaceF3z7KzsMt/he0TwkVC680mvdTFYeM= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-telegram/bot v1.17.0 h1:Hs0kGxSj97QFqOQP0zxduY/4tSx8QDzvNI9uVRS+zmY= +github.com/go-telegram/bot v1.17.0/go.mod h1:i2TRs7fXWIeaceF3z7KzsMt/he0TwkVC680mvdTFYeM= +github.com/go-telegram/bot v1.18.0 h1:yQzv437DY42SYTPBY48RinAvwbmf1ox5QICskIYWCD8= +github.com/go-telegram/bot v1.18.0/go.mod h1:i2TRs7fXWIeaceF3z7KzsMt/he0TwkVC680mvdTFYeM= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/liushuangls/go-anthropic/v2 v2.8.2 h1:PbR9oQF3JDnU/hmbbQI+3tkCqNtdaw4K6S0YfzByl9I= -github.com/liushuangls/go-anthropic/v2 v2.8.2/go.mod h1:8BKv/fkeTaL5R9R9bGkaknYBueyw2WxY20o7bImbOek= -github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= -github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= -golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= -golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= -golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE= -gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= -gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= -gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= +github.com/liushuangls/go-anthropic/v2 v2.16.2 h1:eK2tdDTKlMiHEdTKhbSUf11dgY0K//PulXDFAj2EeHQ= +github.com/liushuangls/go-anthropic/v2 v2.16.2/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU= +github.com/liushuangls/go-anthropic/v2 v2.17.0 h1:iBA6h7aghi1q86owEQ95XE2R2MF/0dQ7bCxtwTxOg4c= +github.com/liushuangls/go-anthropic/v2 v2.17.0/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU= +github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= +github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= +github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk= +github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.5.3 h1:jmXUvGomnU1o3W/V5h2VEradbpJDwGrzugQQvL0POH4= +github.com/stretchr/objx v0.5.3/go.mod h1:rDQraq+vQZU7Fde9LOZLr8Tax6zZvy4kuNKF+QYS+U0= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= +golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.31.0 h1:0VlycGreVhK7RF/Bwt51Fk8v0xLiiiFdbGDPIZQ7mJY= +gorm.io/gorm v1.31.0/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= diff --git a/handlers.go b/handlers.go old mode 100755 new mode 100644 index 9fbe698..af6ca76 --- a/handlers.go +++ b/handlers.go @@ -2,6 +2,8 @@ package main import ( "context" + "fmt" + "strconv" "strings" "github.com/go-telegram/bot" @@ -32,90 +34,188 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U chatID := message.Chat.ID userID := message.From.ID username := message.From.Username + firstName := message.From.FirstName + lastName := message.From.LastName + languageCode := message.From.LanguageCode + isPremium := message.From.IsPremium + messageTime := message.Date text := message.Text - // Pass the incoming message through the centralized screen for storage - _, err := b.screenIncomingMessage(message) - if err != nil { + // Check if it's a new chat + isNewChatFlag := b.isNewChat(chatID) + + // Screen incoming message + if _, err := b.screenIncomingMessage(message); err != nil { ErrorLogger.Printf("Error storing user message: %v", err) return } - // Check if the message is a command - if message.Entities != nil { - for _, entity := range message.Entities { - if entity.Type == "bot_command" { - command := strings.TrimSpace(message.Text[entity.Offset : entity.Offset+entity.Length]) - switch command { - case "/stats": - b.sendStats(ctx, chatID, businessConnectionID) - return - case "/whoami": - b.sendWhoAmI(ctx, chatID, userID, username, businessConnectionID) - return - } - } - } - } - - // Check if the message contains a sticker - if message.Sticker != nil { - b.handleStickerMessage(ctx, chatID, userID, message, businessConnectionID) - return - } - - // Rate limit check - if !b.checkRateLimits(userID) { - b.sendRateLimitExceededMessage(ctx, chatID, businessConnectionID) - return - } - - // Proceed only if the message contains text - if text == "" { - InfoLogger.Printf("Received a non-text message from user %d in chat %d", userID, chatID) - return - } - // Determine if the user is the owner var isOwner bool - err = b.db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", userID, b.botID, true).First(&User{}).Error + err := b.db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", userID, b.botID, true).First(&User{}).Error if err == nil { isOwner = true } - user, err := b.getOrCreateUser(userID, username, isOwner) - if err != nil { - ErrorLogger.Printf("Error getting or creating user: %v", err) - return - } - - // Update the username if it's empty or has changed - if user.Username != username { - user.Username = username - if err := b.db.Save(&user).Error; err != nil { - ErrorLogger.Printf("Error updating user username: %v", err) - } - } - - // Determine if the text contains only emojis - isEmojiOnly := isOnlyEmojis(text) - - // Prepare context messages for Anthropic + // Get the chat memory which now contains the user's message chatMemory := b.getOrCreateChatMemory(chatID) - b.addMessageToChatMemory(chatMemory, b.createMessage(chatID, userID, username, user.Role.Name, text, true)) contextMessages := b.prepareContextMessages(chatMemory) - // Get response from Anthropic - response, err := b.getAnthropicResponse(ctx, contextMessages, b.isNewChat(chatID), isOwner, isEmojiOnly) - if err != nil { - ErrorLogger.Printf("Error getting Anthropic response: %v", err) - response = "I'm sorry, I'm having trouble processing your request right now." - } + if isNewChatFlag { - // Send the response through the centralized screen - if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil { - ErrorLogger.Printf("Error sending response: %v", err) - return + // Get response from Anthropic using the context messages + response, err := b.getAnthropicResponse(ctx, contextMessages, true, isOwner, false, username, firstName, lastName, isPremium, languageCode, messageTime) + if err != nil { + ErrorLogger.Printf("Error getting Anthropic response: %v", err) + // Use the same error message as in the non-new chat case + response = "I'm sorry, I'm having trouble processing your request right now." + } + + // Send the AI-generated response or error message + if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil { + ErrorLogger.Printf("Error sending response: %v", err) + return + } + } else { + user, err := b.getOrCreateUser(userID, username, isOwner) + if err != nil { + ErrorLogger.Printf("Error getting or creating user: %v", err) + return + } + + // Update the username if it's empty or has changed + if user.Username != username { + user.Username = username + if err := b.db.Save(&user).Error; err != nil { + ErrorLogger.Printf("Error updating user username: %v", err) + } + } + + // Check if the message is a command + if message.Entities != nil { + for _, entity := range message.Entities { + if entity.Type == "bot_command" { + command := strings.TrimSpace(message.Text[entity.Offset : entity.Offset+entity.Length]) + switch command { + case "/stats": + // Parse command parameters + parts := strings.Fields(message.Text) + + // Default: show global stats + if len(parts) == 1 { + b.sendStats(ctx, chatID, userID, 0, businessConnectionID) + return + } + + // Check for "user" parameter + if len(parts) >= 2 && parts[1] == "user" { + targetUserID := userID // Default to current user + + // If a user ID is provided, parse it + if len(parts) >= 3 { + var parseErr error + targetUserID, parseErr = strconv.ParseInt(parts[2], 10, 64) + if parseErr != nil { + InfoLogger.Printf("User %d provided invalid user ID format: %s", userID, parts[2]) + if err := b.sendResponse(ctx, chatID, "Invalid user ID format. Usage: /stats user [user_id]", businessConnectionID); err != nil { + ErrorLogger.Printf("Error sending response: %v", err) + } + return + } + } + + b.sendStats(ctx, chatID, userID, targetUserID, businessConnectionID) + return + } + + // Invalid parameter + if err := b.sendResponse(ctx, chatID, "Invalid command format. Usage: /stats or /stats user [user_id]", businessConnectionID); err != nil { + ErrorLogger.Printf("Error sending response: %v", err) + } + return + case "/whoami": + b.sendWhoAmI(ctx, chatID, userID, username, businessConnectionID) + return + case "/clear": + // Extract optional user ID parameter + parts := strings.Fields(message.Text) + var targetUserID int64 = 0 + if len(parts) > 1 { + // Parse the user ID + var parseErr error + targetUserID, parseErr = strconv.ParseInt(parts[1], 10, 64) + if parseErr != nil { + // Invalid user ID format + InfoLogger.Printf("User %d provided invalid user ID format: %s", userID, parts[1]) + if err := b.sendResponse(ctx, chatID, "Invalid user ID format. Usage: /clear [user_id]", businessConnectionID); err != nil { + ErrorLogger.Printf("Error sending response: %v", err) + } + return + } + } + b.clearChatHistory(ctx, chatID, userID, targetUserID, businessConnectionID, false) + return + case "/clear_hard": + // Extract optional user ID parameter + parts := strings.Fields(message.Text) + var targetUserID int64 = 0 + if len(parts) > 1 { + // Parse the user ID + var parseErr error + targetUserID, parseErr = strconv.ParseInt(parts[1], 10, 64) + if parseErr != nil { + // Invalid user ID format + InfoLogger.Printf("User %d provided invalid user ID format: %s", userID, parts[1]) + if err := b.sendResponse(ctx, chatID, "Invalid user ID format. Usage: /clear_hard [user_id]", businessConnectionID); err != nil { + ErrorLogger.Printf("Error sending response: %v", err) + } + return + } + } + b.clearChatHistory(ctx, chatID, userID, targetUserID, businessConnectionID, true) + return + } + } + } + } + + // Check if the message contains a sticker + if message.Sticker != nil { + b.handleStickerMessage(ctx, chatID, message, businessConnectionID) + return + } + + // Rate limit check + if !b.checkRateLimits(userID) { + b.sendRateLimitExceededMessage(ctx, chatID, businessConnectionID) + return + } + + // Proceed only if the message contains text + if text == "" { + InfoLogger.Printf("Received a non-text message from user %d in chat %d", userID, chatID) + return + } + + // Determine if the text contains only emojis + isEmojiOnly := isOnlyEmojis(text) + + // Prepare context messages for Anthropic + chatMemory := b.getOrCreateChatMemory(chatID) + contextMessages := b.prepareContextMessages(chatMemory) + + // Get response from Anthropic + response, err := b.getAnthropicResponse(ctx, contextMessages, false, isOwner, isEmojiOnly, username, firstName, lastName, isPremium, languageCode, messageTime) // isNewChat is false here + if err != nil { + ErrorLogger.Printf("Error getting Anthropic response: %v", err) + response = "I'm sorry, I'm having trouble processing your request right now." + } + + // Send the response + if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil { + ErrorLogger.Printf("Error sending response: %v", err) + return + } } } @@ -125,22 +225,14 @@ func (b *Bot) sendRateLimitExceededMessage(ctx context.Context, chatID int64, bu } } -func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, message *models.Message, businessConnectionID string) { - username := message.From.Username - - // Create the user message (without storing it manually) - userMessage := b.createMessage(chatID, userID, username, "user", "Sent a sticker.", true) - userMessage.StickerFileID = message.Sticker.FileID - - // Safely store the Thumbnail's FileID if available - if message.Sticker.Thumbnail != nil { - userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID +func (b *Bot) handleStickerMessage(ctx context.Context, chatID int64, message *models.Message, businessConnectionID string) { + // Process sticker through centralized screening + userMessage, err := b.screenIncomingMessage(message) + if err != nil { + ErrorLogger.Printf("Error processing sticker message: %v", err) + return } - // Update chat memory with the user message - chatMemory := b.getOrCreateChatMemory(chatID) - b.addMessageToChatMemory(chatMemory, userMessage) - // Generate AI response about the sticker response, err := b.generateStickerResponse(ctx, userMessage) if err != nil { @@ -155,7 +247,7 @@ func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, me } } - // Send the response through the centralized screen + // Send the response if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil { ErrorLogger.Printf("Error sending response: %v", err) return @@ -165,18 +257,28 @@ func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, me func (b *Bot) generateStickerResponse(ctx context.Context, message Message) (string, error) { // Example: Use the sticker type to generate a response if message.StickerFileID != "" { + // Create message content with emoji information if available + var messageContent string + if message.StickerEmoji != "" { + messageContent = fmt.Sprintf("User sent a sticker: %s", message.StickerEmoji) + } else { + messageContent = "User sent a sticker." + } + // Prepare context with information about the sticker contextMessages := []anthropic.Message{ { Role: anthropic.RoleUser, Content: []anthropic.MessageContent{ - anthropic.NewTextMessageContent("User sent a sticker."), + anthropic.NewTextMessageContent(messageContent), }, }, } - // Since this is a sticker message, isEmojiOnly is false - response, err := b.getAnthropicResponse(ctx, contextMessages, false, false, false) + // Treat sticker messages like emoji messages to get emoji responses + // Convert the timestamp to Unix time for the messageTime parameter + messageTime := int(message.Timestamp.Unix()) + response, err := b.getAnthropicResponse(ctx, contextMessages, false, false, true, message.Username, "", "", false, "", messageTime) if err != nil { return "", err } @@ -186,3 +288,93 @@ func (b *Bot) generateStickerResponse(ctx context.Context, message Message) (str return "Hmm, that's interesting!", nil } + +func (b *Bot) clearChatHistory(ctx context.Context, chatID int64, currentUserID int64, targetUserID int64, businessConnectionID string, hardDelete bool) { + // If targetUserID is provided and different from currentUserID, check permissions + if targetUserID != 0 && targetUserID != currentUserID { + // Check if the current user is an admin or owner + if !b.isAdminOrOwner(currentUserID) { + InfoLogger.Printf("User %d attempted to clear history for user %d without permission", currentUserID, targetUserID) + if err := b.sendResponse(ctx, chatID, "Permission denied. Only admins and owners can clear other users' histories.", businessConnectionID); err != nil { + ErrorLogger.Printf("Error sending response: %v", err) + } + return + } + + // Check if the target user exists + var targetUser User + err := b.db.Where("telegram_id = ? AND bot_id = ?", targetUserID, b.botID).First(&targetUser).Error + if err != nil { + ErrorLogger.Printf("Error finding target user %d: %v", targetUserID, err) + if err := b.sendResponse(ctx, chatID, fmt.Sprintf("User with ID %d not found.", targetUserID), businessConnectionID); err != nil { + ErrorLogger.Printf("Error sending response: %v", err) + } + return + } + } else { + // If no targetUserID is provided, set it to currentUserID + targetUserID = currentUserID + } + + // Delete messages from the database + var err error + if hardDelete { + // Permanently delete messages + if targetUserID == currentUserID { + // Deleting own messages + err = b.db.Unscoped().Where("chat_id = ? AND bot_id = ? AND user_id = ?", chatID, b.botID, targetUserID).Delete(&Message{}).Error + InfoLogger.Printf("User %d permanently deleted their own chat history in chat %d", currentUserID, chatID) + } else { + // Deleting another user's messages + err = b.db.Unscoped().Where("chat_id = ? AND bot_id = ? AND user_id = ?", chatID, b.botID, targetUserID).Delete(&Message{}).Error + InfoLogger.Printf("Admin/owner %d permanently deleted chat history for user %d in chat %d", currentUserID, targetUserID, chatID) + } + } else { + // Soft delete messages + if targetUserID == currentUserID { + // Deleting own messages + err = b.db.Where("chat_id = ? AND bot_id = ? AND user_id = ?", chatID, b.botID, targetUserID).Delete(&Message{}).Error + InfoLogger.Printf("User %d soft deleted their own chat history in chat %d", currentUserID, chatID) + } else { + // Deleting another user's messages + err = b.db.Where("chat_id = ? AND bot_id = ? AND user_id = ?", chatID, b.botID, targetUserID).Delete(&Message{}).Error + InfoLogger.Printf("Admin/owner %d soft deleted chat history for user %d in chat %d", currentUserID, targetUserID, chatID) + } + } + + if err != nil { + ErrorLogger.Printf("Error clearing chat history: %v", err) + if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't clear the chat history.", businessConnectionID); err != nil { + ErrorLogger.Printf("Error sending response: %v", err) + } + return + } + + // Reset the chat memory if clearing own history + if targetUserID == currentUserID { + chatMemory := b.getOrCreateChatMemory(chatID) + chatMemory.Messages = []Message{} // Clear the messages + b.chatMemoriesMu.Lock() + b.chatMemories[chatID] = chatMemory + b.chatMemoriesMu.Unlock() + } + + // Send a confirmation message + var confirmationMessage string + if targetUserID == currentUserID { + confirmationMessage = "Your chat history has been cleared." + } else { + // Get the username of the target user if available + var targetUser User + err := b.db.Where("telegram_id = ? AND bot_id = ?", targetUserID, b.botID).First(&targetUser).Error + if err == nil && targetUser.Username != "" { + confirmationMessage = fmt.Sprintf("Chat history for user @%s (ID: %d) has been cleared.", targetUser.Username, targetUserID) + } else { + confirmationMessage = fmt.Sprintf("Chat history for user with ID %d has been cleared.", targetUserID) + } + } + + if err := b.sendResponse(ctx, chatID, confirmationMessage, businessConnectionID); err != nil { + ErrorLogger.Printf("Error sending response: %v", err) + } +} diff --git a/handlers_test.go b/handlers_test.go new file mode 100644 index 0000000..2fb6f52 --- /dev/null +++ b/handlers_test.go @@ -0,0 +1,609 @@ +package main + +import ( + "context" + "testing" + "time" + + "github.com/go-telegram/bot" + "github.com/go-telegram/bot/models" + "github.com/stretchr/testify/assert" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func TestHandleUpdate_NewChat(t *testing.T) { + // Setup + db := setupTestDB(t) + mockClock := &MockClock{ + currentTime: time.Now(), + } + + config := BotConfig{ + ID: "test_bot", + OwnerTelegramID: 123, // owner's ID + TelegramToken: "test_token", + MemorySize: 10, + MessagePerHour: 5, + MessagePerDay: 10, + TempBanDuration: "1h", + SystemPrompts: make(map[string]string), + Active: true, + } + + mockTgClient := &MockTelegramClient{} + + // Create bot model first + botModel := &BotModel{ + Identifier: config.ID, + Name: config.ID, + } + err := db.Create(botModel).Error + assert.NoError(t, err) + + // Create bot config + configModel := &ConfigModel{ + BotID: botModel.ID, + MemorySize: config.MemorySize, + MessagePerHour: config.MessagePerHour, + MessagePerDay: config.MessagePerDay, + TempBanDuration: config.TempBanDuration, + SystemPrompts: "{}", + TelegramToken: config.TelegramToken, + Active: config.Active, + } + err = db.Create(configModel).Error + assert.NoError(t, err) + + // Create bot instance + b, err := NewBot(db, config, mockClock, mockTgClient) + assert.NoError(t, err) + + testCases := []struct { + name string + userID int64 + isOwner bool + wantResp string + }{ + { + name: "Owner First Message", + userID: 123, // owner's ID + isOwner: true, + wantResp: "I'm sorry, I'm having trouble processing your request right now.", + }, + { + name: "Regular User First Message", + userID: 456, + isOwner: false, + wantResp: "I'm sorry, I'm having trouble processing your request right now.", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup mock response expectations for error case to test fallback messages + mockTgClient.SendMessageFunc = func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) { + assert.Equal(t, tc.userID, params.ChatID) + assert.Equal(t, tc.wantResp, params.Text) + return &models.Message{}, nil + } + + // Create update with new message + update := &models.Update{ + Message: &models.Message{ + Chat: models.Chat{ID: tc.userID}, + From: &models.User{ + ID: tc.userID, + Username: "testuser", + }, + Text: "Hello", + }, + } + + // Handle the update + b.handleUpdate(context.Background(), nil, update) + + // Verify message was stored + var storedMsg Message + err := db.Where("chat_id = ? AND user_id = ? AND text = ?", tc.userID, tc.userID, "Hello").First(&storedMsg).Error + assert.NoError(t, err) + + // Verify response was stored + var respMsg Message + err = db.Where("chat_id = ? AND is_user = ? AND text = ?", tc.userID, false, tc.wantResp).First(&respMsg).Error + assert.NoError(t, err) + }) + } +} + +func TestClearChatHistory(t *testing.T) { + // Setup + db := setupTestDB(t) + mockClock := &MockClock{ + currentTime: time.Now(), + } + + config := BotConfig{ + ID: "test_bot", + OwnerTelegramID: 123, // owner's ID + TelegramToken: "test_token", + MemorySize: 10, + MessagePerHour: 5, + MessagePerDay: 10, + TempBanDuration: "1h", + SystemPrompts: make(map[string]string), + Active: true, + } + + mockTgClient := &MockTelegramClient{} + + // Create bot model first + botModel := &BotModel{ + Identifier: config.ID, + Name: config.ID, + } + err := db.Create(botModel).Error + assert.NoError(t, err) + + // Create bot config + configModel := &ConfigModel{ + BotID: botModel.ID, + MemorySize: config.MemorySize, + MessagePerHour: config.MessagePerHour, + MessagePerDay: config.MessagePerDay, + TempBanDuration: config.TempBanDuration, + SystemPrompts: "{}", + TelegramToken: config.TelegramToken, + Active: config.Active, + } + err = db.Create(configModel).Error + assert.NoError(t, err) + + // Create bot instance + b, err := NewBot(db, config, mockClock, mockTgClient) + assert.NoError(t, err) + + // Create test users + ownerID := int64(123) + adminID := int64(456) + regularUserID := int64(789) + nonExistentUserID := int64(999) + chatID := int64(1000) + + // Create admin role + adminRole, err := b.getRoleByName("admin") + assert.NoError(t, err) + + // Create admin user + adminUser := User{ + BotID: b.botID, + TelegramID: adminID, + Username: "admin", + RoleID: adminRole.ID, + Role: adminRole, + IsOwner: false, + } + err = db.Create(&adminUser).Error + assert.NoError(t, err) + + // Create regular user + regularRole, err := b.getRoleByName("user") + assert.NoError(t, err) + regularUser := User{ + BotID: b.botID, + TelegramID: regularUserID, + Username: "regular", + RoleID: regularRole.ID, + Role: regularRole, + IsOwner: false, + } + err = db.Create(®ularUser).Error + assert.NoError(t, err) + + // Create test messages for each user + for _, userID := range []int64{ownerID, adminID, regularUserID} { + for i := 0; i < 5; i++ { + message := Message{ + BotID: b.botID, + ChatID: chatID, + UserID: userID, + Username: "test", + UserRole: "user", + Text: "Test message", + Timestamp: time.Now(), + IsUser: true, + } + err = db.Create(&message).Error + assert.NoError(t, err) + } + } + + // Test cases + testCases := []struct { + name string + currentUserID int64 + targetUserID int64 + hardDelete bool + expectedError bool + expectedCount int64 + expectedMsg string + businessConnID string + }{ + { + name: "Owner clears own history", + currentUserID: ownerID, + targetUserID: ownerID, + hardDelete: false, + expectedError: false, + expectedCount: 0, + expectedMsg: "Your chat history has been cleared.", + }, + { + name: "Admin clears own history", + currentUserID: adminID, + targetUserID: adminID, + hardDelete: false, + expectedError: false, + expectedCount: 0, + expectedMsg: "Your chat history has been cleared.", + }, + { + name: "Regular user clears own history", + currentUserID: regularUserID, + targetUserID: regularUserID, + hardDelete: false, + expectedError: false, + expectedCount: 0, + expectedMsg: "Your chat history has been cleared.", + }, + { + name: "Owner clears admin's history", + currentUserID: ownerID, + targetUserID: adminID, + hardDelete: false, + expectedError: false, + expectedCount: 0, + expectedMsg: "Chat history for user @admin (ID: 456) has been cleared.", + }, + { + name: "Admin clears regular user's history", + currentUserID: adminID, + targetUserID: regularUserID, + hardDelete: false, + expectedError: false, + expectedCount: 0, + expectedMsg: "Chat history for user @regular (ID: 789) has been cleared.", + }, + { + name: "Regular user attempts to clear admin's history", + currentUserID: regularUserID, + targetUserID: adminID, + hardDelete: false, + expectedError: true, + expectedCount: 5, // Messages should remain + expectedMsg: "Permission denied. Only admins and owners can clear other users' histories.", + }, + { + name: "Admin attempts to clear non-existent user's history", + currentUserID: adminID, + targetUserID: nonExistentUserID, + hardDelete: false, + expectedError: true, + expectedCount: 5, // Messages should remain for admin + expectedMsg: "User with ID 999 not found.", + }, + { + name: "Owner hard deletes regular user's history", + currentUserID: ownerID, + targetUserID: regularUserID, + hardDelete: true, + expectedError: false, + expectedCount: 0, + expectedMsg: "Chat history for user @regular (ID: 789) has been cleared.", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Reset messages for the test case + if tc.name != "Owner hard deletes regular user's history" { + // Delete all messages for the target user + err = db.Where("user_id = ?", tc.targetUserID).Delete(&Message{}).Error + assert.NoError(t, err) + + // Recreate messages for the target user + for i := 0; i < 5; i++ { + message := Message{ + BotID: b.botID, + ChatID: chatID, + UserID: tc.targetUserID, + Username: "test", + UserRole: "user", + Text: "Test message", + Timestamp: time.Now(), + IsUser: true, + } + err = db.Create(&message).Error + assert.NoError(t, err) + } + } + + // Setup mock response expectations + var sentMessage string + mockTgClient.SendMessageFunc = func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) { + sentMessage = params.Text + return &models.Message{}, nil + } + + // Call the clearChatHistory method + b.clearChatHistory(context.Background(), chatID, tc.currentUserID, tc.targetUserID, tc.businessConnID, tc.hardDelete) + + // Verify the response message + assert.Equal(t, tc.expectedMsg, sentMessage) + + // Count remaining messages for the target user + var count int64 + if tc.hardDelete { + db.Unscoped().Model(&Message{}).Where("user_id = ? AND chat_id = ?", tc.targetUserID, chatID).Count(&count) + } else { + db.Model(&Message{}).Where("user_id = ? AND chat_id = ?", tc.targetUserID, chatID).Count(&count) + } + assert.Equal(t, tc.expectedCount, count) + }) + } +} + +func TestStatsCommand(t *testing.T) { + // Setup + db := setupTestDB(t) + mockClock := &MockClock{ + currentTime: time.Now(), + } + + config := BotConfig{ + ID: "test_bot", + OwnerTelegramID: 123, // owner's ID + TelegramToken: "test_token", + MemorySize: 10, + MessagePerHour: 5, + MessagePerDay: 10, + TempBanDuration: "1h", + SystemPrompts: make(map[string]string), + Active: true, + } + + mockTgClient := &MockTelegramClient{} + + // Create bot model first + botModel := &BotModel{ + Identifier: config.ID, + Name: config.ID, + } + err := db.Create(botModel).Error + assert.NoError(t, err) + + // Create bot config + configModel := &ConfigModel{ + BotID: botModel.ID, + MemorySize: config.MemorySize, + MessagePerHour: config.MessagePerHour, + MessagePerDay: config.MessagePerDay, + TempBanDuration: config.TempBanDuration, + SystemPrompts: "{}", + TelegramToken: config.TelegramToken, + Active: config.Active, + } + err = db.Create(configModel).Error + assert.NoError(t, err) + + // Create bot instance + b, err := NewBot(db, config, mockClock, mockTgClient) + assert.NoError(t, err) + + // Create test users + ownerID := int64(123) + adminID := int64(456) + regularUserID := int64(789) + chatID := int64(1000) + + // Create admin role + adminRole, err := b.getRoleByName("admin") + assert.NoError(t, err) + + // Create admin user + adminUser := User{ + BotID: b.botID, + TelegramID: adminID, + Username: "admin", + RoleID: adminRole.ID, + Role: adminRole, + IsOwner: false, + } + err = db.Create(&adminUser).Error + assert.NoError(t, err) + + // Create regular user + regularRole, err := b.getRoleByName("user") + assert.NoError(t, err) + regularUser := User{ + BotID: b.botID, + TelegramID: regularUserID, + Username: "regular", + RoleID: regularRole.ID, + Role: regularRole, + IsOwner: false, + } + err = db.Create(®ularUser).Error + assert.NoError(t, err) + + // Create test messages for each user + for _, userID := range []int64{ownerID, adminID, regularUserID} { + for i := 0; i < 5; i++ { + // User message + userMessage := Message{ + BotID: b.botID, + ChatID: chatID, + UserID: userID, + Username: "test", + UserRole: "user", + Text: "Test message", + Timestamp: time.Now(), + IsUser: true, + } + err = db.Create(&userMessage).Error + assert.NoError(t, err) + + // Bot response + botMessage := Message{ + BotID: b.botID, + ChatID: chatID, + UserID: 0, + Username: "AI Assistant", + UserRole: "assistant", + Text: "Test response", + Timestamp: time.Now(), + IsUser: false, + } + err = db.Create(&botMessage).Error + assert.NoError(t, err) + } + } + + // Test cases + testCases := []struct { + name string + command string + currentUserID int64 + expectedError bool + expectedMsg string + businessConnID string + }{ + { + name: "Global stats", + command: "/stats", + currentUserID: regularUserID, + expectedError: false, + expectedMsg: "📊 Bot Statistics:", + }, + { + name: "User requests own stats", + command: "/stats user", + currentUserID: regularUserID, + expectedError: false, + expectedMsg: "👤 User Statistics for @regular (ID: 789):", + }, + { + name: "Admin requests another user's stats", + command: "/stats user 789", + currentUserID: adminID, + expectedError: false, + expectedMsg: "👤 User Statistics for @regular (ID: 789):", + }, + { + name: "Owner requests another user's stats", + command: "/stats user 456", + currentUserID: ownerID, + expectedError: false, + expectedMsg: "👤 User Statistics for @admin (ID: 456):", + }, + { + name: "Regular user attempts to request another user's stats", + command: "/stats user 456", + currentUserID: regularUserID, + expectedError: true, + expectedMsg: "Permission denied. Only admins and owners can view other users' statistics.", + }, + { + name: "User provides invalid user ID format", + command: "/stats user abc", + currentUserID: adminID, + expectedError: true, + expectedMsg: "Invalid user ID format. Usage: /stats user [user_id]", + }, + { + name: "User provides invalid command format", + command: "/stats invalid", + currentUserID: adminID, + expectedError: true, + expectedMsg: "Invalid command format. Usage: /stats or /stats user [user_id]", + }, + { + name: "User requests non-existent user's stats", + command: "/stats user 999", + currentUserID: adminID, + expectedError: true, + expectedMsg: "Sorry, I couldn't retrieve statistics for user ID 999.", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup mock response expectations + var sentMessage string + mockTgClient.SendMessageFunc = func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) { + sentMessage = params.Text + return &models.Message{}, nil + } + + // Create update with command + update := &models.Update{ + Message: &models.Message{ + Chat: models.Chat{ID: chatID}, + From: &models.User{ + ID: tc.currentUserID, + Username: getUsernameByID(tc.currentUserID), + }, + Text: tc.command, + Entities: []models.MessageEntity{ + { + Type: "bot_command", + Offset: 0, + Length: 6, // Length of "/stats" + }, + }, + }, + } + + // Handle the update + b.handleUpdate(context.Background(), nil, update) + + // Verify the response message contains the expected text + assert.Contains(t, sentMessage, tc.expectedMsg) + }) + } +} + +// Helper function to get username by ID for test +func getUsernameByID(id int64) string { + switch id { + case 123: + return "owner" + case 456: + return "admin" + case 789: + return "regular" + default: + return "unknown" + } +} + +func setupTestDB(t *testing.T) *gorm.DB { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("Failed to open test database: %v", err) + } + + // AutoMigrate the models + err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}) + if err != nil { + t.Fatalf("Failed to migrate database schema: %v", err) + } + + // Create default roles + err = createDefaultRoles(db) + if err != nil { + t.Fatalf("Failed to create default roles: %v", err) + } + + return db +} diff --git a/logger.go b/logger.go old mode 100755 new mode 100644 diff --git a/main.go b/main.go old mode 100755 new mode 100644 index c70c224..62bdaf9 --- a/main.go +++ b/main.go @@ -47,19 +47,13 @@ func main() { return } - // Initialize TelegramClient with the bot's handleUpdate method - tgClient, err := initTelegramBot(cfg.TelegramToken, bot.handleUpdate) - if err != nil { - ErrorLogger.Printf("Error initializing Telegram client for bot %s: %v", cfg.ID, err) - return - } + // Start the bot in a separate goroutine + go bot.Start(ctx) - // Assign the TelegramClient to the bot - bot.tgBot = tgClient + // Keep the bot running until the context is cancelled + <-ctx.Done() - // Start the bot - InfoLogger.Printf("Starting bot %s...", cfg.ID) - bot.Start(ctx) + InfoLogger.Printf("Bot %s stopped", cfg.ID) }(config) } diff --git a/models.go b/models.go old mode 100755 new mode 100644 index 6ce7b1e..ef69b77 --- a/models.go +++ b/models.go @@ -29,16 +29,19 @@ type ConfigModel struct { type Message struct { gorm.Model - BotID uint - ChatID int64 - UserID int64 - Username string - UserRole string - Text string - StickerFileID string `json:"sticker_file_id,omitempty"` // New field to store Sticker File ID - StickerPNGFile string `json:"sticker_png_file,omitempty"` // Optionally store PNG file ID if needed - Timestamp time.Time + BotID uint `gorm:"index"` + ChatID int64 `gorm:"index"` + UserID int64 `gorm:"index"` + Username string `gorm:"index"` + UserRole string // Store the role as a string + Text string `gorm:"type:text"` + Timestamp time.Time `gorm:"index"` IsUser bool + StickerFileID string + StickerPNGFile string + StickerEmoji string // Store the emoji associated with the sticker + DeletedAt gorm.DeletedAt `gorm:"index"` // Add soft delete field + AnsweredOn *time.Time `gorm:"index"` // Tracks when a user message was answered (NULL for assistant messages and unanswered user messages) } type ChatMemory struct { diff --git a/rate_limiter.go b/rate_limiter.go old mode 100755 new mode 100644 diff --git a/rate_limiter_test.go b/rate_limiter_test.go old mode 100755 new mode 100644 diff --git a/telegram_client.go b/telegram_client.go old mode 100755 new mode 100644 diff --git a/telegram_client_mock.go b/telegram_client_mock.go old mode 100755 new mode 100644 index 61520f7..a11caae --- a/telegram_client_mock.go +++ b/telegram_client_mock.go @@ -6,13 +6,14 @@ import ( "github.com/go-telegram/bot" "github.com/go-telegram/bot/models" + "github.com/stretchr/testify/mock" ) // MockTelegramClient is a mock implementation of TelegramClient for testing. type MockTelegramClient struct { - // You can add fields to keep track of calls if needed. + mock.Mock SendMessageFunc func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) - StartFunc func(ctx context.Context) // Optional: track Start calls + StartFunc func(ctx context.Context) } // SendMessage mocks sending a message. @@ -20,16 +21,18 @@ func (m *MockTelegramClient) SendMessage(ctx context.Context, params *bot.SendMe if m.SendMessageFunc != nil { return m.SendMessageFunc(ctx, params) } - // Default behavior: return an empty message without error. - return &models.Message{}, nil + args := m.Called(ctx, params) + if msg, ok := args.Get(0).(*models.Message); ok { + return msg, args.Error(1) + } + return nil, args.Error(1) } // Start mocks starting the Telegram client. func (m *MockTelegramClient) Start(ctx context.Context) { if m.StartFunc != nil { m.StartFunc(ctx) + return } - // Default behavior: do nothing. + m.Called(ctx) } - -// Add other mocked methods if your Bot uses more TelegramClient methods. diff --git a/user_management_test.go b/user_management_test.go old mode 100755 new mode 100644 index b24588e..52cc5d5 --- a/user_management_test.go +++ b/user_management_test.go @@ -184,6 +184,104 @@ func TestPromoteUserToAdmin(t *testing.T) { } } +// TestGetOrCreateUser tests the getOrCreateUser method of the Bot. +// It verifies that a new user is created when one does not exist, +// and an existing user is returned when one does exist. +func TestGetOrCreateUser(t *testing.T) { + // Initialize loggers + initLoggers() + + // Initialize in-memory database for testing + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("Failed to open in-memory database: %v", err) + } + + // Migrate the schema + err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}) + if err != nil { + t.Fatalf("Failed to migrate database schema: %v", err) + } + + // Create default roles + err = createDefaultRoles(db) + if err != nil { + t.Fatalf("Failed to create default roles: %v", err) + } + + // Create a mock clock starting at a fixed time + mockClock := &MockClock{ + currentTime: time.Date(2023, 10, 1, 0, 0, 0, 0, time.UTC), + } + + // Create a mock configuration + config := BotConfig{ + ID: "bot1", + MemorySize: 10, + MessagePerHour: 5, + MessagePerDay: 10, + TempBanDuration: "1m", + SystemPrompts: make(map[string]string), + TelegramToken: "YOUR_TELEGRAM_BOT_TOKEN", + OwnerTelegramID: 123456789, + } + + // Initialize MockTelegramClient + mockTGClient := &MockTelegramClient{ + SendMessageFunc: func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) { + chatID, ok := params.ChatID.(int64) + if !ok { + return nil, fmt.Errorf("ChatID is not of type int64") + } + // Simulate successful message sending + return &models.Message{ID: 1, Chat: models.Chat{ID: chatID}}, nil + }, + } + + // Create the bot with the mock Telegram client + bot, err := NewBot(db, config, mockClock, mockTGClient) + if err != nil { + t.Fatalf("Failed to create bot: %v", err) + } + + // Verify that the owner exists + var owner User + err = db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", config.OwnerTelegramID, bot.botID, true).First(&owner).Error + if err != nil { + t.Fatalf("Owner was not created: %v", err) + } + + // Attempt to create another owner for the same bot + _, err = bot.getOrCreateUser(222222222, "AnotherOwner", true) + if err == nil { + t.Fatalf("Expected error when creating a second owner, but got none") + } + + // Create a new user + newUser, err := bot.getOrCreateUser(987654321, "TestUser", false) + if err != nil { + t.Fatalf("Failed to create a new user: %v", err) + } + + // Verify that the new user was created + var userInDB User + err = db.Where("telegram_id = ?", newUser.TelegramID).First(&userInDB).Error + if err != nil { + t.Fatalf("New user was not created in the database: %v", err) + } + + // Get the existing user + existingUser, err := bot.getOrCreateUser(987654321, "TestUser", false) + if err != nil { + t.Fatalf("Failed to get existing user: %v", err) + } + + // Verify that the existing user is the same as the new user + if existingUser.ID != userInDB.ID { + t.Fatalf("Expected to get the existing user, but got a different user") + } +} + // To ensure thread safety and avoid race conditions during testing, // you can run the tests with the `-race` flag: // go test -race -v