diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..aafbd82 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,13 @@ +# Enforce LF line endings for all files +* text eol=lf + +# Specific file types that should always have LF line endings +*.go text eol=lf +*.json text eol=lf +*.sh text eol=lf +*.md text eol=lf + +# Example: Binary files should not be modified +*.jpg binary +*.png binary +*.gif binary diff --git a/.github/workflows/go-ci.yaml b/.github/workflows/go-ci.yaml new file mode 100644 index 0000000..64413b1 --- /dev/null +++ b/.github/workflows/go-ci.yaml @@ -0,0 +1,54 @@ +name: CI + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + build: + runs-on: ubuntu-latest + + steps: + # Checkout the repository + - name: Checkout code + uses: actions/checkout@v4 + + # Set up Go environment + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.23' # Specify the Go version you are using + + # Cache Go modules + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + # Install Dependencies + - name: Install Dependencies + run: go mod tidy + + # Run Linters using golangci-lint + - name: Lint Code + uses: golangci/golangci-lint-action@v6 + with: + version: v1.60 # Specify the version of golangci-lint + args: --timeout 5m + + # Run Tests + - name: Run Tests + run: go test ./... -v + + # Security Analysis using gosec + - name: Security Scan + uses: securego/gosec@master + with: + args: ./... diff --git a/.gitignore b/.gitignore old mode 100755 new mode 100644 index 21238dc..ae342fe --- a/.gitignore +++ b/.gitignore @@ -4,8 +4,8 @@ vendor/ # Environment variables .env -# Log file -bot.log +# Any log files +*.log # Database file bot.db diff --git a/anthropic.go b/anthropic.go old mode 100755 new mode 100644 diff --git a/bot.go b/bot.go old mode 100755 new mode 100644 index 64bb25a..5fbe500 --- a/bot.go +++ b/bot.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "log" - "os" "strings" "sync" "time" @@ -17,7 +16,7 @@ import ( ) type Bot struct { - tgBot *bot.Bot + tgBot TelegramClient db *gorm.DB anthropicClient *anthropic.Client chatMemories map[int64]*ChatMemory @@ -30,7 +29,8 @@ type Bot struct { botID uint // Reference to BotModel.ID } -func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) { +// NewBot initializes and returns a new Bot instance. +func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient) (*Bot, error) { // Retrieve or create Bot entry in the database var botEntry BotModel err := db.Where("identifier = ?", config.ID).First(&botEntry).Error @@ -43,7 +43,38 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) { return nil, err } - anthropicClient := anthropic.NewClient(os.Getenv("ANTHROPIC_API_KEY")) + // Ensure the owner exists in the Users table + var owner User + err = db.Where("telegram_id = ? AND bot_id = ?", config.OwnerTelegramID, botEntry.ID).First(&owner).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + // Assign the "owner" role + var ownerRole Role + err := db.Where("name = ?", "owner").First(&ownerRole).Error + if err != nil { + return nil, fmt.Errorf("owner role not found: %w", err) + } + + owner = User{ + BotID: botEntry.ID, + TelegramID: config.OwnerTelegramID, + Username: "", // Initialize as empty; will be updated upon interaction + RoleID: ownerRole.ID, + IsOwner: true, + } + + if err := db.Create(&owner).Error; err != nil { + // If unique constraint is violated, another owner already exists + if strings.Contains(err.Error(), "unique index") { + return nil, fmt.Errorf("an owner already exists for this bot") + } + return nil, fmt.Errorf("failed to create owner user: %w", err) + } + } else if err != nil { + return nil, err + } + + // Use the per-bot Anthropic API key + anthropicClient := anthropic.NewClient(config.AnthropicAPIKey) b := &Bot{ db: db, @@ -54,41 +85,80 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) { userLimiters: make(map[int64]*userLimiter), clock: clock, botID: botEntry.ID, // Ensure BotModel has ID field + tgBot: tgClient, } - tgBot, err := initTelegramBot(config.TelegramToken, b.handleUpdate) - if err != nil { - return nil, err - } - b.tgBot = tgBot - return b, nil } +// Start begins the bot's operation. func (b *Bot) Start(ctx context.Context) { b.tgBot.Start(ctx) } -func (b *Bot) getOrCreateUser(userID int64, username string) (User, error) { +func (b *Bot) getOrCreateUser(userID int64, username string, isOwner bool) (User, error) { var user User - err := b.db.Preload("Role").Where("telegram_id = ?", userID).First(&user).Error + err := b.db.Preload("Role").Where("telegram_id = ? AND bot_id = ?", userID, b.botID).First(&user).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - var defaultRole Role - if err := b.db.Where("name = ?", "user").First(&defaultRole).Error; err != nil { - return User{}, err + // Check if an owner already exists for this bot + if isOwner { + var existingOwner User + err := b.db.Where("bot_id = ? AND is_owner = ?", b.botID, true).First(&existingOwner).Error + if err == nil { + return User{}, fmt.Errorf("an owner already exists for this bot") + } else if !errors.Is(err, gorm.ErrRecordNotFound) { + return User{}, fmt.Errorf("failed to check existing owner: %w", err) + } } - user = User{TelegramID: userID, Username: username, RoleID: defaultRole.ID} + + var role Role + var roleName string + if isOwner { + roleName = "owner" + } else { + roleName = "user" // Assign "user" role to non-owner users + } + + err := b.db.Where("name = ?", roleName).First(&role).Error + if err != nil { + return User{}, fmt.Errorf("failed to get role: %w", err) + } + + user = User{ + BotID: b.botID, + TelegramID: userID, + Username: username, + RoleID: role.ID, + Role: role, + IsOwner: isOwner, + } + if err := b.db.Create(&user).Error; err != nil { - return User{}, err + // If unique constraint is violated, another owner already exists + if strings.Contains(err.Error(), "unique index") { + return User{}, fmt.Errorf("an owner already exists for this bot") + } + return User{}, fmt.Errorf("failed to create user: %w", err) } } else { return User{}, err } + } else { + if isOwner && !user.IsOwner { + return User{}, fmt.Errorf("cannot change existing user to owner") + } } + return user, nil } +func (b *Bot) getRoleByName(roleName string) (Role, error) { + var role Role + err := b.db.Where("name = ?", roleName).First(&role).Error + return role, err +} + func (b *Bot) createMessage(chatID, userID int64, username, userRole, text string, isUser bool) Message { message := Message{ ChatID: chatID, @@ -195,17 +265,28 @@ func (b *Bot) isAdminOrOwner(userID int64) bool { return user.Role.Name == "admin" || user.Role.Name == "owner" } -func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (*bot.Bot, error) { +func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (TelegramClient, error) { opts := []bot.Option{ bot.WithDefaultHandler(handleUpdate), } - return bot.New(token, opts...) + tgBot, err := bot.New(token, opts...) + if err != nil { + return nil, err + } + + return tgBot, nil } -// sendResponse sends a message to the specified chat. -// Returns an error if sending the message fails. func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, businessConnectionID string) error { + // Pass the outgoing message through the centralized screen for storage + _, err := b.screenOutgoingMessage(chatID, text, businessConnectionID) + if err != nil { + log.Printf("Error storing assistant message: %v", err) + return err + } + + // Prepare message parameters params := &bot.SendMessageParams{ ChatID: chatID, Text: text, @@ -215,7 +296,8 @@ func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, busin params.BusinessConnectionID = businessConnectionID } - _, err := b.tgBot.SendMessage(ctx, params) + // Send the message via Telegram client + _, err = b.tgBot.SendMessage(ctx, params) if err != nil { log.Printf("[%s] [ERROR] Error sending message to chat %d with BusinessConnectionID %s: %v", b.config.ID, chatID, businessConnectionID, err) @@ -225,16 +307,29 @@ func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, busin } // sendStats sends the bot statistics to the specified chat. -func (b *Bot) sendStats(ctx context.Context, chatID int64, businessConnectionID string) { +func (b *Bot) sendStats(ctx context.Context, chatID int64, userID int64, username string, businessConnectionID string) { totalUsers, totalMessages, err := b.getStats() if err != nil { fmt.Printf("Error fetching stats: %v\n", err) - b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve the stats at this time.", businessConnectionID) + if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve the stats at this time.", businessConnectionID); err != nil { + log.Printf("Error sending response: %v", err) + } return } - statsMessage := fmt.Sprintf("📊 **Bot Statistics:**\n\n- Total Users: %d\n- Total Messages: %d", totalUsers, totalMessages) - b.sendResponse(ctx, chatID, statsMessage, businessConnectionID) + // Do NOT manually escape hyphens here + statsMessage := fmt.Sprintf( + "📊 Bot Statistics:\n\n"+ + "- Total Users: %d\n"+ + "- Total Messages: %d", + totalUsers, + totalMessages, + ) + + // Send the response through the centralized screen + if err := b.sendResponse(ctx, chatID, statsMessage, businessConnectionID); err != nil { + log.Printf("Error sending stats message: %v", err) + } } // getStats retrieves the total number of users and messages from the database. @@ -271,3 +366,77 @@ func isEmoji(r rune) bool { (r >= 0x2600 && r <= 0x26FF) || // Misc symbols (r >= 0x2700 && r <= 0x27BF) // Dingbats } + +func (b *Bot) sendWhoAmI(ctx context.Context, chatID int64, userID int64, username string, businessConnectionID string) { + user, err := b.getOrCreateUser(userID, username, false) + if err != nil { + log.Printf("Error getting or creating user: %v", err) + if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve your information.", businessConnectionID); err != nil { + log.Printf("Error sending response: %v", err) + } + return + } + + role, err := b.getRoleByName(user.Role.Name) + if err != nil { + log.Printf("Error getting role by name: %v", err) + if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve your role information.", businessConnectionID); err != nil { + log.Printf("Error sending response: %v", err) + } + return + } + + whoAmIMessage := fmt.Sprintf( + "👤 Your Information:\n\n"+ + "- Username: %s\n"+ + "- Role: %s", + user.Username, + role.Name, + ) + + // Send the response through the centralized screen + if err := b.sendResponse(ctx, chatID, whoAmIMessage, businessConnectionID); err != nil { + log.Printf("Error sending /whoami message: %v", err) + } +} + +// screenIncomingMessage handles storing of incoming messages. +func (b *Bot) screenIncomingMessage(message *models.Message) (Message, error) { + userRole := string(anthropic.RoleUser) // Convert RoleUser to string + userMessage := b.createMessage(message.Chat.ID, message.From.ID, message.From.Username, userRole, message.Text, true) + + // If the message contains a sticker, include its details. + if message.Sticker != nil { + userMessage.StickerFileID = message.Sticker.FileID + if message.Sticker.Thumbnail != nil { + userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID + } + } + + // Store the message. + if err := b.storeMessage(userMessage); err != nil { + return Message{}, err + } + + // Update chat memory. + chatMemory := b.getOrCreateChatMemory(message.Chat.ID) + b.addMessageToChatMemory(chatMemory, userMessage) + + return userMessage, nil +} + +// screenOutgoingMessage handles storing of outgoing messages. +func (b *Bot) screenOutgoingMessage(chatID int64, response string, businessConnectionID string) (Message, error) { + assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false) + + // Store the message. + if err := b.storeMessage(assistantMessage); err != nil { + return Message{}, err + } + + // Update chat memory. + chatMemory := b.getOrCreateChatMemory(chatID) + b.addMessageToChatMemory(chatMemory, assistantMessage) + + return assistantMessage, nil +} diff --git a/clock.go b/clock.go old mode 100755 new mode 100644 diff --git a/config.go b/config.go old mode 100755 new mode 100644 index c02496d..650cde4 --- a/config.go +++ b/config.go @@ -18,6 +18,9 @@ type BotConfig struct { TempBanDuration string `json:"temp_ban_duration"` Model anthropic.Model `json:"model"` // Changed from string to anthropic.Model SystemPrompts map[string]string `json:"system_prompts"` + Active bool `json:"active"` // New field to control bot activity + OwnerTelegramID int64 `json:"owner_telegram_id"` + AnthropicAPIKey string `json:"anthropic_api_key"` // Add this line } // Custom unmarshalling to handle anthropic.Model @@ -56,6 +59,12 @@ func loadAllConfigs(dir string) ([]BotConfig, error) { return nil, fmt.Errorf("failed to load config %s: %w", configPath, err) } + // Skip inactive bots + if !config.Active { + fmt.Printf("Skipping inactive bot: %s\n", config.ID) + continue + } + // Validate that ID is present if config.ID == "" { return nil, fmt.Errorf("config %s is missing 'id' field", configPath) diff --git a/config/default.json b/config/default.json old mode 100755 new mode 100644 index b9d0708..f7ddd8b --- a/config/default.json +++ b/config/default.json @@ -1,6 +1,9 @@ { "id": "default_bot", + "active": false, "telegram_token": "YOUR_TELEGRAM_BOT_TOKEN", + "owner_telegram_id": 111111111, + "anthropic_api_key": "YOUR_SPECIFIC_ANTHROPIC_API_KEY", "memory_size": 10, "messages_per_hour": 20, "messages_per_day": 100, diff --git a/database.go b/database.go old mode 100755 new mode 100644 index da7a29a..6c7a8f6 --- a/database.go +++ b/database.go @@ -27,11 +27,22 @@ func initDB() (*gorm.DB, error) { return nil, fmt.Errorf("failed to connect to database: %w", err) } + // AutoMigrate the models err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}) if err != nil { return nil, fmt.Errorf("failed to migrate database schema: %w", err) } + // Enforce unique owner per bot using raw SQL + // Note: SQLite doesn't support partial indexes, but we can simulate it by making a unique index on (BotID, IsOwner) + // and ensuring that IsOwner can only be true for one user per BotID. + // This approach allows multiple users with IsOwner=false for the same BotID, + // but only one user can have IsOwner=true per BotID. + err = db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_bot_owner ON users (bot_id, is_owner) WHERE is_owner = 1;`).Error + if err != nil { + return nil, fmt.Errorf("failed to create unique index for bot owners: %w", err) + } + err = createDefaultRoles(db) if err != nil { return nil, err diff --git a/go.mod b/go.mod old mode 100755 new mode 100644 index 8ee2d93..5ff897e --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ -module github.com/HugeFrog24/thatsky-telegram-bot +module github.com/HugeFrog24/go-telegram-bot -go 1.23.2 +go 1.23 require ( github.com/go-telegram/bot v1.9.0 diff --git a/go.sum b/go.sum old mode 100755 new mode 100644 diff --git a/handlers.go b/handlers.go old mode 100755 new mode 100644 index 52b1ad9..5611ae5 --- a/handlers.go +++ b/handlers.go @@ -22,9 +22,6 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U return } - chatID := message.Chat.ID - userID := message.From.ID - // Extract businessConnectionID if available var businessConnectionID string if update.BusinessConnection != nil { @@ -33,6 +30,18 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U businessConnectionID = message.BusinessConnectionID } + chatID := message.Chat.ID + userID := message.From.ID + username := message.From.Username + text := message.Text + + // Pass the incoming message through the centralized screen for storage + _, err := b.screenIncomingMessage(message) + if err != nil { + log.Printf("Error storing user message: %v", err) + return + } + // Check if the message is a command if message.Entities != nil { for _, entity := range message.Entities { @@ -40,7 +49,10 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U command := strings.TrimSpace(message.Text[entity.Offset : entity.Offset+entity.Length]) switch command { case "/stats": - b.sendStats(ctx, chatID, businessConnectionID) + b.sendStats(ctx, chatID, userID, username, businessConnectionID) + return + case "/whoami": + b.sendWhoAmI(ctx, chatID, userID, username, businessConnectionID) return } } @@ -53,59 +65,71 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U return } - // Existing rate limit and message handling + // Rate limit check if !b.checkRateLimits(userID) { b.sendRateLimitExceededMessage(ctx, chatID, businessConnectionID) return } - username := message.From.Username - text := message.Text - // Proceed only if the message contains text if text == "" { - // Optionally, handle other message types or ignore log.Printf("Received a non-text message from user %d in chat %d", userID, chatID) return } - user, err := b.getOrCreateUser(userID, username) + // Determine if the user is the owner + var isOwner bool + err = b.db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", userID, b.botID, true).First(&User{}).Error + if err == nil { + isOwner = true + } + + user, err := b.getOrCreateUser(userID, username, isOwner) if err != nil { log.Printf("Error getting or creating user: %v", err) return } - userMessage := b.createMessage(chatID, userID, username, user.Role.Name, text, true) - userMessage.UserRole = string(anthropic.RoleUser) // Convert to string - b.storeMessage(userMessage) + // Update the username if it's empty or has changed + if user.Username != username { + user.Username = username + if err := b.db.Save(&user).Error; err != nil { + log.Printf("Error updating user username: %v", err) + } + } + // Determine if the text contains only emojis + isEmojiOnly := isOnlyEmojis(text) + + // Prepare context messages for Anthropic chatMemory := b.getOrCreateChatMemory(chatID) - b.addMessageToChatMemory(chatMemory, userMessage) - + b.addMessageToChatMemory(chatMemory, b.createMessage(chatID, userID, username, user.Role.Name, text, true)) contextMessages := b.prepareContextMessages(chatMemory) - isEmojiOnly := isOnlyEmojis(text) // Ensure you have this variable defined - response, err := b.getAnthropicResponse(ctx, contextMessages, b.isNewChat(chatID), b.isAdminOrOwner(userID), isEmojiOnly) + // Get response from Anthropic + response, err := b.getAnthropicResponse(ctx, contextMessages, b.isNewChat(chatID), isOwner, isEmojiOnly) if err != nil { log.Printf("Error getting Anthropic response: %v", err) response = "I'm sorry, I'm having trouble processing your request right now." } - b.sendResponse(ctx, chatID, response, businessConnectionID) - - assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false) - b.storeMessage(assistantMessage) - b.addMessageToChatMemory(chatMemory, assistantMessage) + // Send the response through the centralized screen + if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil { + log.Printf("Error sending response: %v", err) + return + } } func (b *Bot) sendRateLimitExceededMessage(ctx context.Context, chatID int64, businessConnectionID string) { - b.sendResponse(ctx, chatID, "Rate limit exceeded. Please try again later.", businessConnectionID) + if err := b.sendResponse(ctx, chatID, "Rate limit exceeded. Please try again later.", businessConnectionID); err != nil { + log.Printf("Error sending rate limit exceeded message: %v", err) + } } func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, message *models.Message, businessConnectionID string) { username := message.From.Username - // Create and store the sticker message + // Create the user message (without storing it manually) userMessage := b.createMessage(chatID, userID, username, "user", "Sent a sticker.", true) userMessage.StickerFileID = message.Sticker.FileID @@ -114,9 +138,7 @@ func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, me userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID } - b.storeMessage(userMessage) - - // Update chat memory + // Update chat memory with the user message chatMemory := b.getOrCreateChatMemory(chatID) b.addMessageToChatMemory(chatMemory, userMessage) @@ -134,11 +156,11 @@ func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, me } } - b.sendResponse(ctx, chatID, response, businessConnectionID) - - assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false) - b.storeMessage(assistantMessage) - b.addMessageToChatMemory(chatMemory, assistantMessage) + // Send the response through the centralized screen + if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil { + log.Printf("Error sending response: %v", err) + return + } } func (b *Bot) generateStickerResponse(ctx context.Context, message Message) (string, error) { diff --git a/main.go b/main.go old mode 100755 new mode 100644 index c93a754..f3eff8b --- a/main.go +++ b/main.go @@ -24,9 +24,6 @@ func main() { log.Printf("Error loading .env file: %v", err) } - // Check for required environment variables - checkRequiredEnvVars() - // Initialize database db, err := initDB() if err != nil { @@ -52,14 +49,24 @@ func main() { go func(cfg BotConfig) { defer wg.Done() - // Create Bot instance with RealClock + // Create Bot instance without TelegramClient initially realClock := RealClock{} - bot, err := NewBot(db, cfg, realClock) + bot, err := NewBot(db, cfg, realClock, nil) if err != nil { log.Printf("Error creating bot %s: %v", cfg.ID, err) return } + // Initialize TelegramClient with the bot's handleUpdate method + tgClient, err := initTelegramBot(cfg.TelegramToken, bot.handleUpdate) + if err != nil { + log.Printf("Error initializing Telegram client for bot %s: %v", cfg.ID, err) + return + } + + // Assign the TelegramClient to the bot + bot.tgBot = tgClient + // Start the bot log.Printf("Starting bot %s...", cfg.ID) bot.Start(ctx) @@ -79,12 +86,3 @@ func initLogger() (*os.File, error) { log.SetOutput(mw) return logFile, nil } - -func checkRequiredEnvVars() { - requiredEnvVars := []string{"ANTHROPIC_API_KEY"} - for _, envVar := range requiredEnvVars { - if os.Getenv(envVar) == "" { - log.Fatalf("%s environment variable is not set", envVar) - } - } -} diff --git a/models.go b/models.go old mode 100755 new mode 100644 index cec57fe..6ce7b1e --- a/models.go +++ b/models.go @@ -11,7 +11,7 @@ type BotModel struct { Identifier string `gorm:"uniqueIndex"` // Renamed from ID to Identifier Name string Configs []ConfigModel `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` - Users []User `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` // Added foreign key + Users []User `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` // Associated users Messages []Message `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` } @@ -24,6 +24,7 @@ type ConfigModel struct { TempBanDuration string `json:"temp_ban_duration"` SystemPrompts string `json:"system_prompts"` // Consider JSON string or separate table TelegramToken string `json:"telegram_token"` + Active bool `json:"active"` } type Message struct { @@ -53,9 +54,15 @@ type Role struct { type User struct { gorm.Model - BotID uint `gorm:"index"` // Added foreign key to BotModel - TelegramID int64 `gorm:"uniqueIndex"` // Consider composite unique index if TelegramID is unique per Bot + BotID uint `gorm:"index"` // Foreign key to BotModel + TelegramID int64 `gorm:"uniqueIndex;not null"` // Unique per user Username string RoleID uint Role Role `gorm:"foreignKey:RoleID"` + IsOwner bool `gorm:"default:false"` // Indicates if the user is the owner +} + +// Compound unique index to ensure only one owner per bot +func (User) TableName() string { + return "users" } diff --git a/rate_limiter.go b/rate_limiter.go old mode 100755 new mode 100644 diff --git a/rate_limiter_test.go b/rate_limiter_test.go old mode 100755 new mode 100644 index 6e4f54a..7b85f38 --- a/rate_limiter_test.go +++ b/rate_limiter_test.go @@ -1,8 +1,15 @@ package main import ( + "context" + "fmt" "testing" "time" + + "github.com/go-telegram/bot" + "github.com/go-telegram/bot/models" + "gorm.io/driver/sqlite" + "gorm.io/gorm" ) // TestCheckRateLimits tests the checkRateLimits method of the Bot. @@ -22,6 +29,7 @@ func TestCheckRateLimits(t *testing.T) { TempBanDuration: "1m", // Temporary ban duration of 1 minute for testing SystemPrompts: make(map[string]string), TelegramToken: "YOUR_TELEGRAM_BOT_TOKEN", + OwnerTelegramID: 123456789, } // Initialize the Bot with mock data and MockClock @@ -79,6 +87,102 @@ func TestCheckRateLimits(t *testing.T) { } } +func TestOwnerAssignment(t *testing.T) { + // Initialize in-memory database for testing + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("Failed to open in-memory database: %v", err) + } + + // Migrate the schema + err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}) + if err != nil { + t.Fatalf("Failed to migrate database schema: %v", err) + } + + // Create default roles + err = createDefaultRoles(db) + if err != nil { + t.Fatalf("Failed to create default roles: %v", err) + } + + // Create a bot configuration + config := BotConfig{ + ID: "test_bot", + TelegramToken: "TEST_TELEGRAM_TOKEN", + MemorySize: 10, + MessagePerHour: 5, + MessagePerDay: 10, + TempBanDuration: "1m", + SystemPrompts: make(map[string]string), + Active: true, + OwnerTelegramID: 111111111, + } + + // Initialize MockClock + mockClock := &MockClock{ + currentTime: time.Now(), + } + + // Initialize MockTelegramClient + mockTGClient := &MockTelegramClient{ + SendMessageFunc: func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) { + chatID, ok := params.ChatID.(int64) + if !ok { + return nil, fmt.Errorf("ChatID is not of type int64") + } + // Simulate successful message sending + return &models.Message{ID: 1, Chat: models.Chat{ID: chatID}}, nil + }, + } + + // Create the bot with the mock Telegram client + bot, err := NewBot(db, config, mockClock, mockTGClient) + if err != nil { + t.Fatalf("Failed to create bot: %v", err) + } + + // Verify that the owner exists + var owner User + err = db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", config.OwnerTelegramID, bot.botID, true).First(&owner).Error + if err != nil { + t.Fatalf("Owner was not created: %v", err) + } + + // Attempt to create another owner for the same bot + _, err = bot.getOrCreateUser(222222222, "AnotherOwner", true) + if err == nil { + t.Fatalf("Expected error when creating a second owner, but got none") + } + + // Verify that the error message is appropriate + expectedErrorMsg := "an owner already exists for this bot" + if err.Error() != expectedErrorMsg { + t.Fatalf("Unexpected error message: %v", err) + } + + // Assign admin role to a new user + adminUser, err := bot.getOrCreateUser(333333333, "AdminUser", false) + if err != nil { + t.Fatalf("Failed to create admin user: %v", err) + } + + if adminUser.Role.Name != "admin" { + t.Fatalf("Expected role 'admin', got '%s'", adminUser.Role.Name) + } + + // Attempt to change an existing user to owner + _, err = bot.getOrCreateUser(333333333, "AdminUser", true) + if err == nil { + t.Fatalf("Expected error when changing existing user to owner, but got none") + } + + expectedErrorMsg = "cannot change existing user to owner" + if err.Error() != expectedErrorMsg { + t.Fatalf("Unexpected error message: %v", err) + } +} + // To ensure thread safety and avoid race conditions during testing, // you can run the tests with the `-race` flag: // go test -race -v diff --git a/telegram_client.go b/telegram_client.go new file mode 100644 index 0000000..10dc9bf --- /dev/null +++ b/telegram_client.go @@ -0,0 +1,16 @@ +// telegram_client.go +package main + +import ( + "context" + + "github.com/go-telegram/bot" + "github.com/go-telegram/bot/models" +) + +// TelegramClient defines the methods required from the Telegram bot. +type TelegramClient interface { + SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) + Start(ctx context.Context) + // Add other methods if needed. +} diff --git a/telegram_client_mock.go b/telegram_client_mock.go new file mode 100644 index 0000000..61520f7 --- /dev/null +++ b/telegram_client_mock.go @@ -0,0 +1,35 @@ +// telegram_client_mock.go +package main + +import ( + "context" + + "github.com/go-telegram/bot" + "github.com/go-telegram/bot/models" +) + +// MockTelegramClient is a mock implementation of TelegramClient for testing. +type MockTelegramClient struct { + // You can add fields to keep track of calls if needed. + SendMessageFunc func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) + StartFunc func(ctx context.Context) // Optional: track Start calls +} + +// SendMessage mocks sending a message. +func (m *MockTelegramClient) SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) { + if m.SendMessageFunc != nil { + return m.SendMessageFunc(ctx, params) + } + // Default behavior: return an empty message without error. + return &models.Message{}, nil +} + +// Start mocks starting the Telegram client. +func (m *MockTelegramClient) Start(ctx context.Context) { + if m.StartFunc != nil { + m.StartFunc(ctx) + } + // Default behavior: do nothing. +} + +// Add other mocked methods if your Bot uses more TelegramClient methods.