From c8af457af17466489b2c847fcbe75f3a727eb2ef Mon Sep 17 00:00:00 2001 From: HugeFrog24 <62775760+HugeFrog24@users.noreply.github.com> Date: Sun, 20 Oct 2024 17:17:21 +0200 Subject: [PATCH] MVP md formatting doesnt work yet Started implementing owner feature Add .gitattributes to enforce LF line endings Temporary commit before merge Updated owner management Updated json and gitignore Proceed with role management Again, CI Fix some lint errors Implemented screening Per-bot API keys implemented Use getRoleByName func Fix unused imports Upgrade actions rm unused function Upgrade action Fix unaddressed errors --- .gitattributes | 13 +++ .github/workflows/go-ci.yaml | 54 +++++++++ .gitignore | 4 +- anthropic.go | 0 bot.go | 221 ++++++++++++++++++++++++++++++----- clock.go | 0 config.go | 9 ++ config/default.json | 3 + database.go | 11 ++ go.mod | 4 +- go.sum | 0 handlers.go | 86 +++++++++----- main.go | 26 ++--- models.go | 13 ++- rate_limiter.go | 0 rate_limiter_test.go | 104 +++++++++++++++++ telegram_client.go | 16 +++ telegram_client_mock.go | 35 ++++++ 18 files changed, 520 insertions(+), 79 deletions(-) create mode 100644 .gitattributes create mode 100644 .github/workflows/go-ci.yaml mode change 100755 => 100644 .gitignore mode change 100755 => 100644 anthropic.go mode change 100755 => 100644 bot.go mode change 100755 => 100644 clock.go mode change 100755 => 100644 config.go mode change 100755 => 100644 config/default.json mode change 100755 => 100644 database.go mode change 100755 => 100644 go.mod mode change 100755 => 100644 go.sum mode change 100755 => 100644 handlers.go mode change 100755 => 100644 main.go mode change 100755 => 100644 models.go mode change 100755 => 100644 rate_limiter.go mode change 100755 => 100644 rate_limiter_test.go create mode 100644 telegram_client.go create mode 100644 telegram_client_mock.go diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..aafbd82 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,13 @@ +# Enforce LF line endings for all files +* text eol=lf + +# Specific file types that should always have LF line endings +*.go text eol=lf +*.json text eol=lf +*.sh text eol=lf +*.md text eol=lf + +# Example: Binary files should not be modified +*.jpg binary +*.png binary +*.gif binary diff --git a/.github/workflows/go-ci.yaml b/.github/workflows/go-ci.yaml new file mode 100644 index 0000000..64413b1 --- /dev/null +++ b/.github/workflows/go-ci.yaml @@ -0,0 +1,54 @@ +name: CI + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + build: + runs-on: ubuntu-latest + + steps: + # Checkout the repository + - name: Checkout code + uses: actions/checkout@v4 + + # Set up Go environment + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.23' # Specify the Go version you are using + + # Cache Go modules + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + # Install Dependencies + - name: Install Dependencies + run: go mod tidy + + # Run Linters using golangci-lint + - name: Lint Code + uses: golangci/golangci-lint-action@v6 + with: + version: v1.60 # Specify the version of golangci-lint + args: --timeout 5m + + # Run Tests + - name: Run Tests + run: go test ./... -v + + # Security Analysis using gosec + - name: Security Scan + uses: securego/gosec@master + with: + args: ./... diff --git a/.gitignore b/.gitignore old mode 100755 new mode 100644 index 21238dc..ae342fe --- a/.gitignore +++ b/.gitignore @@ -4,8 +4,8 @@ vendor/ # Environment variables .env -# Log file -bot.log +# Any log files +*.log # Database file bot.db diff --git a/anthropic.go b/anthropic.go old mode 100755 new mode 100644 diff --git a/bot.go b/bot.go old mode 100755 new mode 100644 index 64bb25a..5fbe500 --- a/bot.go +++ b/bot.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "log" - "os" "strings" "sync" "time" @@ -17,7 +16,7 @@ import ( ) type Bot struct { - tgBot *bot.Bot + tgBot TelegramClient db *gorm.DB anthropicClient *anthropic.Client chatMemories map[int64]*ChatMemory @@ -30,7 +29,8 @@ type Bot struct { botID uint // Reference to BotModel.ID } -func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) { +// NewBot initializes and returns a new Bot instance. +func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient) (*Bot, error) { // Retrieve or create Bot entry in the database var botEntry BotModel err := db.Where("identifier = ?", config.ID).First(&botEntry).Error @@ -43,7 +43,38 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) { return nil, err } - anthropicClient := anthropic.NewClient(os.Getenv("ANTHROPIC_API_KEY")) + // Ensure the owner exists in the Users table + var owner User + err = db.Where("telegram_id = ? AND bot_id = ?", config.OwnerTelegramID, botEntry.ID).First(&owner).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + // Assign the "owner" role + var ownerRole Role + err := db.Where("name = ?", "owner").First(&ownerRole).Error + if err != nil { + return nil, fmt.Errorf("owner role not found: %w", err) + } + + owner = User{ + BotID: botEntry.ID, + TelegramID: config.OwnerTelegramID, + Username: "", // Initialize as empty; will be updated upon interaction + RoleID: ownerRole.ID, + IsOwner: true, + } + + if err := db.Create(&owner).Error; err != nil { + // If unique constraint is violated, another owner already exists + if strings.Contains(err.Error(), "unique index") { + return nil, fmt.Errorf("an owner already exists for this bot") + } + return nil, fmt.Errorf("failed to create owner user: %w", err) + } + } else if err != nil { + return nil, err + } + + // Use the per-bot Anthropic API key + anthropicClient := anthropic.NewClient(config.AnthropicAPIKey) b := &Bot{ db: db, @@ -54,41 +85,80 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) { userLimiters: make(map[int64]*userLimiter), clock: clock, botID: botEntry.ID, // Ensure BotModel has ID field + tgBot: tgClient, } - tgBot, err := initTelegramBot(config.TelegramToken, b.handleUpdate) - if err != nil { - return nil, err - } - b.tgBot = tgBot - return b, nil } +// Start begins the bot's operation. func (b *Bot) Start(ctx context.Context) { b.tgBot.Start(ctx) } -func (b *Bot) getOrCreateUser(userID int64, username string) (User, error) { +func (b *Bot) getOrCreateUser(userID int64, username string, isOwner bool) (User, error) { var user User - err := b.db.Preload("Role").Where("telegram_id = ?", userID).First(&user).Error + err := b.db.Preload("Role").Where("telegram_id = ? AND bot_id = ?", userID, b.botID).First(&user).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - var defaultRole Role - if err := b.db.Where("name = ?", "user").First(&defaultRole).Error; err != nil { - return User{}, err + // Check if an owner already exists for this bot + if isOwner { + var existingOwner User + err := b.db.Where("bot_id = ? AND is_owner = ?", b.botID, true).First(&existingOwner).Error + if err == nil { + return User{}, fmt.Errorf("an owner already exists for this bot") + } else if !errors.Is(err, gorm.ErrRecordNotFound) { + return User{}, fmt.Errorf("failed to check existing owner: %w", err) + } } - user = User{TelegramID: userID, Username: username, RoleID: defaultRole.ID} + + var role Role + var roleName string + if isOwner { + roleName = "owner" + } else { + roleName = "user" // Assign "user" role to non-owner users + } + + err := b.db.Where("name = ?", roleName).First(&role).Error + if err != nil { + return User{}, fmt.Errorf("failed to get role: %w", err) + } + + user = User{ + BotID: b.botID, + TelegramID: userID, + Username: username, + RoleID: role.ID, + Role: role, + IsOwner: isOwner, + } + if err := b.db.Create(&user).Error; err != nil { - return User{}, err + // If unique constraint is violated, another owner already exists + if strings.Contains(err.Error(), "unique index") { + return User{}, fmt.Errorf("an owner already exists for this bot") + } + return User{}, fmt.Errorf("failed to create user: %w", err) } } else { return User{}, err } + } else { + if isOwner && !user.IsOwner { + return User{}, fmt.Errorf("cannot change existing user to owner") + } } + return user, nil } +func (b *Bot) getRoleByName(roleName string) (Role, error) { + var role Role + err := b.db.Where("name = ?", roleName).First(&role).Error + return role, err +} + func (b *Bot) createMessage(chatID, userID int64, username, userRole, text string, isUser bool) Message { message := Message{ ChatID: chatID, @@ -195,17 +265,28 @@ func (b *Bot) isAdminOrOwner(userID int64) bool { return user.Role.Name == "admin" || user.Role.Name == "owner" } -func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (*bot.Bot, error) { +func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (TelegramClient, error) { opts := []bot.Option{ bot.WithDefaultHandler(handleUpdate), } - return bot.New(token, opts...) + tgBot, err := bot.New(token, opts...) + if err != nil { + return nil, err + } + + return tgBot, nil } -// sendResponse sends a message to the specified chat. -// Returns an error if sending the message fails. func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, businessConnectionID string) error { + // Pass the outgoing message through the centralized screen for storage + _, err := b.screenOutgoingMessage(chatID, text, businessConnectionID) + if err != nil { + log.Printf("Error storing assistant message: %v", err) + return err + } + + // Prepare message parameters params := &bot.SendMessageParams{ ChatID: chatID, Text: text, @@ -215,7 +296,8 @@ func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, busin params.BusinessConnectionID = businessConnectionID } - _, err := b.tgBot.SendMessage(ctx, params) + // Send the message via Telegram client + _, err = b.tgBot.SendMessage(ctx, params) if err != nil { log.Printf("[%s] [ERROR] Error sending message to chat %d with BusinessConnectionID %s: %v", b.config.ID, chatID, businessConnectionID, err) @@ -225,16 +307,29 @@ func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, busin } // sendStats sends the bot statistics to the specified chat. -func (b *Bot) sendStats(ctx context.Context, chatID int64, businessConnectionID string) { +func (b *Bot) sendStats(ctx context.Context, chatID int64, userID int64, username string, businessConnectionID string) { totalUsers, totalMessages, err := b.getStats() if err != nil { fmt.Printf("Error fetching stats: %v\n", err) - b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve the stats at this time.", businessConnectionID) + if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve the stats at this time.", businessConnectionID); err != nil { + log.Printf("Error sending response: %v", err) + } return } - statsMessage := fmt.Sprintf("📊 **Bot Statistics:**\n\n- Total Users: %d\n- Total Messages: %d", totalUsers, totalMessages) - b.sendResponse(ctx, chatID, statsMessage, businessConnectionID) + // Do NOT manually escape hyphens here + statsMessage := fmt.Sprintf( + "📊 Bot Statistics:\n\n"+ + "- Total Users: %d\n"+ + "- Total Messages: %d", + totalUsers, + totalMessages, + ) + + // Send the response through the centralized screen + if err := b.sendResponse(ctx, chatID, statsMessage, businessConnectionID); err != nil { + log.Printf("Error sending stats message: %v", err) + } } // getStats retrieves the total number of users and messages from the database. @@ -271,3 +366,77 @@ func isEmoji(r rune) bool { (r >= 0x2600 && r <= 0x26FF) || // Misc symbols (r >= 0x2700 && r <= 0x27BF) // Dingbats } + +func (b *Bot) sendWhoAmI(ctx context.Context, chatID int64, userID int64, username string, businessConnectionID string) { + user, err := b.getOrCreateUser(userID, username, false) + if err != nil { + log.Printf("Error getting or creating user: %v", err) + if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve your information.", businessConnectionID); err != nil { + log.Printf("Error sending response: %v", err) + } + return + } + + role, err := b.getRoleByName(user.Role.Name) + if err != nil { + log.Printf("Error getting role by name: %v", err) + if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve your role information.", businessConnectionID); err != nil { + log.Printf("Error sending response: %v", err) + } + return + } + + whoAmIMessage := fmt.Sprintf( + "👤 Your Information:\n\n"+ + "- Username: %s\n"+ + "- Role: %s", + user.Username, + role.Name, + ) + + // Send the response through the centralized screen + if err := b.sendResponse(ctx, chatID, whoAmIMessage, businessConnectionID); err != nil { + log.Printf("Error sending /whoami message: %v", err) + } +} + +// screenIncomingMessage handles storing of incoming messages. +func (b *Bot) screenIncomingMessage(message *models.Message) (Message, error) { + userRole := string(anthropic.RoleUser) // Convert RoleUser to string + userMessage := b.createMessage(message.Chat.ID, message.From.ID, message.From.Username, userRole, message.Text, true) + + // If the message contains a sticker, include its details. + if message.Sticker != nil { + userMessage.StickerFileID = message.Sticker.FileID + if message.Sticker.Thumbnail != nil { + userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID + } + } + + // Store the message. + if err := b.storeMessage(userMessage); err != nil { + return Message{}, err + } + + // Update chat memory. + chatMemory := b.getOrCreateChatMemory(message.Chat.ID) + b.addMessageToChatMemory(chatMemory, userMessage) + + return userMessage, nil +} + +// screenOutgoingMessage handles storing of outgoing messages. +func (b *Bot) screenOutgoingMessage(chatID int64, response string, businessConnectionID string) (Message, error) { + assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false) + + // Store the message. + if err := b.storeMessage(assistantMessage); err != nil { + return Message{}, err + } + + // Update chat memory. + chatMemory := b.getOrCreateChatMemory(chatID) + b.addMessageToChatMemory(chatMemory, assistantMessage) + + return assistantMessage, nil +} diff --git a/clock.go b/clock.go old mode 100755 new mode 100644 diff --git a/config.go b/config.go old mode 100755 new mode 100644 index c02496d..650cde4 --- a/config.go +++ b/config.go @@ -18,6 +18,9 @@ type BotConfig struct { TempBanDuration string `json:"temp_ban_duration"` Model anthropic.Model `json:"model"` // Changed from string to anthropic.Model SystemPrompts map[string]string `json:"system_prompts"` + Active bool `json:"active"` // New field to control bot activity + OwnerTelegramID int64 `json:"owner_telegram_id"` + AnthropicAPIKey string `json:"anthropic_api_key"` // Add this line } // Custom unmarshalling to handle anthropic.Model @@ -56,6 +59,12 @@ func loadAllConfigs(dir string) ([]BotConfig, error) { return nil, fmt.Errorf("failed to load config %s: %w", configPath, err) } + // Skip inactive bots + if !config.Active { + fmt.Printf("Skipping inactive bot: %s\n", config.ID) + continue + } + // Validate that ID is present if config.ID == "" { return nil, fmt.Errorf("config %s is missing 'id' field", configPath) diff --git a/config/default.json b/config/default.json old mode 100755 new mode 100644 index b9d0708..f7ddd8b --- a/config/default.json +++ b/config/default.json @@ -1,6 +1,9 @@ { "id": "default_bot", + "active": false, "telegram_token": "YOUR_TELEGRAM_BOT_TOKEN", + "owner_telegram_id": 111111111, + "anthropic_api_key": "YOUR_SPECIFIC_ANTHROPIC_API_KEY", "memory_size": 10, "messages_per_hour": 20, "messages_per_day": 100, diff --git a/database.go b/database.go old mode 100755 new mode 100644 index da7a29a..6c7a8f6 --- a/database.go +++ b/database.go @@ -27,11 +27,22 @@ func initDB() (*gorm.DB, error) { return nil, fmt.Errorf("failed to connect to database: %w", err) } + // AutoMigrate the models err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}) if err != nil { return nil, fmt.Errorf("failed to migrate database schema: %w", err) } + // Enforce unique owner per bot using raw SQL + // Note: SQLite doesn't support partial indexes, but we can simulate it by making a unique index on (BotID, IsOwner) + // and ensuring that IsOwner can only be true for one user per BotID. + // This approach allows multiple users with IsOwner=false for the same BotID, + // but only one user can have IsOwner=true per BotID. + err = db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_bot_owner ON users (bot_id, is_owner) WHERE is_owner = 1;`).Error + if err != nil { + return nil, fmt.Errorf("failed to create unique index for bot owners: %w", err) + } + err = createDefaultRoles(db) if err != nil { return nil, err diff --git a/go.mod b/go.mod old mode 100755 new mode 100644 index 8ee2d93..5ff897e --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ -module github.com/HugeFrog24/thatsky-telegram-bot +module github.com/HugeFrog24/go-telegram-bot -go 1.23.2 +go 1.23 require ( github.com/go-telegram/bot v1.9.0 diff --git a/go.sum b/go.sum old mode 100755 new mode 100644 diff --git a/handlers.go b/handlers.go old mode 100755 new mode 100644 index 52b1ad9..5611ae5 --- a/handlers.go +++ b/handlers.go @@ -22,9 +22,6 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U return } - chatID := message.Chat.ID - userID := message.From.ID - // Extract businessConnectionID if available var businessConnectionID string if update.BusinessConnection != nil { @@ -33,6 +30,18 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U businessConnectionID = message.BusinessConnectionID } + chatID := message.Chat.ID + userID := message.From.ID + username := message.From.Username + text := message.Text + + // Pass the incoming message through the centralized screen for storage + _, err := b.screenIncomingMessage(message) + if err != nil { + log.Printf("Error storing user message: %v", err) + return + } + // Check if the message is a command if message.Entities != nil { for _, entity := range message.Entities { @@ -40,7 +49,10 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U command := strings.TrimSpace(message.Text[entity.Offset : entity.Offset+entity.Length]) switch command { case "/stats": - b.sendStats(ctx, chatID, businessConnectionID) + b.sendStats(ctx, chatID, userID, username, businessConnectionID) + return + case "/whoami": + b.sendWhoAmI(ctx, chatID, userID, username, businessConnectionID) return } } @@ -53,59 +65,71 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U return } - // Existing rate limit and message handling + // Rate limit check if !b.checkRateLimits(userID) { b.sendRateLimitExceededMessage(ctx, chatID, businessConnectionID) return } - username := message.From.Username - text := message.Text - // Proceed only if the message contains text if text == "" { - // Optionally, handle other message types or ignore log.Printf("Received a non-text message from user %d in chat %d", userID, chatID) return } - user, err := b.getOrCreateUser(userID, username) + // Determine if the user is the owner + var isOwner bool + err = b.db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", userID, b.botID, true).First(&User{}).Error + if err == nil { + isOwner = true + } + + user, err := b.getOrCreateUser(userID, username, isOwner) if err != nil { log.Printf("Error getting or creating user: %v", err) return } - userMessage := b.createMessage(chatID, userID, username, user.Role.Name, text, true) - userMessage.UserRole = string(anthropic.RoleUser) // Convert to string - b.storeMessage(userMessage) + // Update the username if it's empty or has changed + if user.Username != username { + user.Username = username + if err := b.db.Save(&user).Error; err != nil { + log.Printf("Error updating user username: %v", err) + } + } + // Determine if the text contains only emojis + isEmojiOnly := isOnlyEmojis(text) + + // Prepare context messages for Anthropic chatMemory := b.getOrCreateChatMemory(chatID) - b.addMessageToChatMemory(chatMemory, userMessage) - + b.addMessageToChatMemory(chatMemory, b.createMessage(chatID, userID, username, user.Role.Name, text, true)) contextMessages := b.prepareContextMessages(chatMemory) - isEmojiOnly := isOnlyEmojis(text) // Ensure you have this variable defined - response, err := b.getAnthropicResponse(ctx, contextMessages, b.isNewChat(chatID), b.isAdminOrOwner(userID), isEmojiOnly) + // Get response from Anthropic + response, err := b.getAnthropicResponse(ctx, contextMessages, b.isNewChat(chatID), isOwner, isEmojiOnly) if err != nil { log.Printf("Error getting Anthropic response: %v", err) response = "I'm sorry, I'm having trouble processing your request right now." } - b.sendResponse(ctx, chatID, response, businessConnectionID) - - assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false) - b.storeMessage(assistantMessage) - b.addMessageToChatMemory(chatMemory, assistantMessage) + // Send the response through the centralized screen + if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil { + log.Printf("Error sending response: %v", err) + return + } } func (b *Bot) sendRateLimitExceededMessage(ctx context.Context, chatID int64, businessConnectionID string) { - b.sendResponse(ctx, chatID, "Rate limit exceeded. Please try again later.", businessConnectionID) + if err := b.sendResponse(ctx, chatID, "Rate limit exceeded. Please try again later.", businessConnectionID); err != nil { + log.Printf("Error sending rate limit exceeded message: %v", err) + } } func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, message *models.Message, businessConnectionID string) { username := message.From.Username - // Create and store the sticker message + // Create the user message (without storing it manually) userMessage := b.createMessage(chatID, userID, username, "user", "Sent a sticker.", true) userMessage.StickerFileID = message.Sticker.FileID @@ -114,9 +138,7 @@ func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, me userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID } - b.storeMessage(userMessage) - - // Update chat memory + // Update chat memory with the user message chatMemory := b.getOrCreateChatMemory(chatID) b.addMessageToChatMemory(chatMemory, userMessage) @@ -134,11 +156,11 @@ func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, me } } - b.sendResponse(ctx, chatID, response, businessConnectionID) - - assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false) - b.storeMessage(assistantMessage) - b.addMessageToChatMemory(chatMemory, assistantMessage) + // Send the response through the centralized screen + if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil { + log.Printf("Error sending response: %v", err) + return + } } func (b *Bot) generateStickerResponse(ctx context.Context, message Message) (string, error) { diff --git a/main.go b/main.go old mode 100755 new mode 100644 index c93a754..f3eff8b --- a/main.go +++ b/main.go @@ -24,9 +24,6 @@ func main() { log.Printf("Error loading .env file: %v", err) } - // Check for required environment variables - checkRequiredEnvVars() - // Initialize database db, err := initDB() if err != nil { @@ -52,14 +49,24 @@ func main() { go func(cfg BotConfig) { defer wg.Done() - // Create Bot instance with RealClock + // Create Bot instance without TelegramClient initially realClock := RealClock{} - bot, err := NewBot(db, cfg, realClock) + bot, err := NewBot(db, cfg, realClock, nil) if err != nil { log.Printf("Error creating bot %s: %v", cfg.ID, err) return } + // Initialize TelegramClient with the bot's handleUpdate method + tgClient, err := initTelegramBot(cfg.TelegramToken, bot.handleUpdate) + if err != nil { + log.Printf("Error initializing Telegram client for bot %s: %v", cfg.ID, err) + return + } + + // Assign the TelegramClient to the bot + bot.tgBot = tgClient + // Start the bot log.Printf("Starting bot %s...", cfg.ID) bot.Start(ctx) @@ -79,12 +86,3 @@ func initLogger() (*os.File, error) { log.SetOutput(mw) return logFile, nil } - -func checkRequiredEnvVars() { - requiredEnvVars := []string{"ANTHROPIC_API_KEY"} - for _, envVar := range requiredEnvVars { - if os.Getenv(envVar) == "" { - log.Fatalf("%s environment variable is not set", envVar) - } - } -} diff --git a/models.go b/models.go old mode 100755 new mode 100644 index cec57fe..6ce7b1e --- a/models.go +++ b/models.go @@ -11,7 +11,7 @@ type BotModel struct { Identifier string `gorm:"uniqueIndex"` // Renamed from ID to Identifier Name string Configs []ConfigModel `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` - Users []User `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` // Added foreign key + Users []User `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` // Associated users Messages []Message `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` } @@ -24,6 +24,7 @@ type ConfigModel struct { TempBanDuration string `json:"temp_ban_duration"` SystemPrompts string `json:"system_prompts"` // Consider JSON string or separate table TelegramToken string `json:"telegram_token"` + Active bool `json:"active"` } type Message struct { @@ -53,9 +54,15 @@ type Role struct { type User struct { gorm.Model - BotID uint `gorm:"index"` // Added foreign key to BotModel - TelegramID int64 `gorm:"uniqueIndex"` // Consider composite unique index if TelegramID is unique per Bot + BotID uint `gorm:"index"` // Foreign key to BotModel + TelegramID int64 `gorm:"uniqueIndex;not null"` // Unique per user Username string RoleID uint Role Role `gorm:"foreignKey:RoleID"` + IsOwner bool `gorm:"default:false"` // Indicates if the user is the owner +} + +// Compound unique index to ensure only one owner per bot +func (User) TableName() string { + return "users" } diff --git a/rate_limiter.go b/rate_limiter.go old mode 100755 new mode 100644 diff --git a/rate_limiter_test.go b/rate_limiter_test.go old mode 100755 new mode 100644 index 6e4f54a..7b85f38 --- a/rate_limiter_test.go +++ b/rate_limiter_test.go @@ -1,8 +1,15 @@ package main import ( + "context" + "fmt" "testing" "time" + + "github.com/go-telegram/bot" + "github.com/go-telegram/bot/models" + "gorm.io/driver/sqlite" + "gorm.io/gorm" ) // TestCheckRateLimits tests the checkRateLimits method of the Bot. @@ -22,6 +29,7 @@ func TestCheckRateLimits(t *testing.T) { TempBanDuration: "1m", // Temporary ban duration of 1 minute for testing SystemPrompts: make(map[string]string), TelegramToken: "YOUR_TELEGRAM_BOT_TOKEN", + OwnerTelegramID: 123456789, } // Initialize the Bot with mock data and MockClock @@ -79,6 +87,102 @@ func TestCheckRateLimits(t *testing.T) { } } +func TestOwnerAssignment(t *testing.T) { + // Initialize in-memory database for testing + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("Failed to open in-memory database: %v", err) + } + + // Migrate the schema + err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}) + if err != nil { + t.Fatalf("Failed to migrate database schema: %v", err) + } + + // Create default roles + err = createDefaultRoles(db) + if err != nil { + t.Fatalf("Failed to create default roles: %v", err) + } + + // Create a bot configuration + config := BotConfig{ + ID: "test_bot", + TelegramToken: "TEST_TELEGRAM_TOKEN", + MemorySize: 10, + MessagePerHour: 5, + MessagePerDay: 10, + TempBanDuration: "1m", + SystemPrompts: make(map[string]string), + Active: true, + OwnerTelegramID: 111111111, + } + + // Initialize MockClock + mockClock := &MockClock{ + currentTime: time.Now(), + } + + // Initialize MockTelegramClient + mockTGClient := &MockTelegramClient{ + SendMessageFunc: func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) { + chatID, ok := params.ChatID.(int64) + if !ok { + return nil, fmt.Errorf("ChatID is not of type int64") + } + // Simulate successful message sending + return &models.Message{ID: 1, Chat: models.Chat{ID: chatID}}, nil + }, + } + + // Create the bot with the mock Telegram client + bot, err := NewBot(db, config, mockClock, mockTGClient) + if err != nil { + t.Fatalf("Failed to create bot: %v", err) + } + + // Verify that the owner exists + var owner User + err = db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", config.OwnerTelegramID, bot.botID, true).First(&owner).Error + if err != nil { + t.Fatalf("Owner was not created: %v", err) + } + + // Attempt to create another owner for the same bot + _, err = bot.getOrCreateUser(222222222, "AnotherOwner", true) + if err == nil { + t.Fatalf("Expected error when creating a second owner, but got none") + } + + // Verify that the error message is appropriate + expectedErrorMsg := "an owner already exists for this bot" + if err.Error() != expectedErrorMsg { + t.Fatalf("Unexpected error message: %v", err) + } + + // Assign admin role to a new user + adminUser, err := bot.getOrCreateUser(333333333, "AdminUser", false) + if err != nil { + t.Fatalf("Failed to create admin user: %v", err) + } + + if adminUser.Role.Name != "admin" { + t.Fatalf("Expected role 'admin', got '%s'", adminUser.Role.Name) + } + + // Attempt to change an existing user to owner + _, err = bot.getOrCreateUser(333333333, "AdminUser", true) + if err == nil { + t.Fatalf("Expected error when changing existing user to owner, but got none") + } + + expectedErrorMsg = "cannot change existing user to owner" + if err.Error() != expectedErrorMsg { + t.Fatalf("Unexpected error message: %v", err) + } +} + // To ensure thread safety and avoid race conditions during testing, // you can run the tests with the `-race` flag: // go test -race -v diff --git a/telegram_client.go b/telegram_client.go new file mode 100644 index 0000000..10dc9bf --- /dev/null +++ b/telegram_client.go @@ -0,0 +1,16 @@ +// telegram_client.go +package main + +import ( + "context" + + "github.com/go-telegram/bot" + "github.com/go-telegram/bot/models" +) + +// TelegramClient defines the methods required from the Telegram bot. +type TelegramClient interface { + SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) + Start(ctx context.Context) + // Add other methods if needed. +} diff --git a/telegram_client_mock.go b/telegram_client_mock.go new file mode 100644 index 0000000..61520f7 --- /dev/null +++ b/telegram_client_mock.go @@ -0,0 +1,35 @@ +// telegram_client_mock.go +package main + +import ( + "context" + + "github.com/go-telegram/bot" + "github.com/go-telegram/bot/models" +) + +// MockTelegramClient is a mock implementation of TelegramClient for testing. +type MockTelegramClient struct { + // You can add fields to keep track of calls if needed. + SendMessageFunc func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) + StartFunc func(ctx context.Context) // Optional: track Start calls +} + +// SendMessage mocks sending a message. +func (m *MockTelegramClient) SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) { + if m.SendMessageFunc != nil { + return m.SendMessageFunc(ctx, params) + } + // Default behavior: return an empty message without error. + return &models.Message{}, nil +} + +// Start mocks starting the Telegram client. +func (m *MockTelegramClient) Start(ctx context.Context) { + if m.StartFunc != nil { + m.StartFunc(ctx) + } + // Default behavior: do nothing. +} + +// Add other mocked methods if your Bot uses more TelegramClient methods.