diff --git a/.gitignore b/.gitignore index 21238dc..ae342fe 100644 --- a/.gitignore +++ b/.gitignore @@ -4,8 +4,8 @@ vendor/ # Environment variables .env -# Log file -bot.log +# Any log files +*.log # Database file bot.db diff --git a/bot.go b/bot.go index b90593c..7b851cd 100644 --- a/bot.go +++ b/bot.go @@ -13,11 +13,13 @@ import ( "github.com/go-telegram/bot" "github.com/go-telegram/bot/models" "github.com/liushuangls/go-anthropic/v2" + "golang.org/x/text/cases" + "golang.org/x/text/language" "gorm.io/gorm" ) type Bot struct { - tgBot *bot.Bot + tgBot TelegramClient db *gorm.DB anthropicClient *anthropic.Client chatMemories map[int64]*ChatMemory @@ -30,7 +32,8 @@ type Bot struct { botID uint // Reference to BotModel.ID } -func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) { +// bot.go +func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient) (*Bot, error) { // Retrieve or create Bot entry in the database var botEntry BotModel err := db.Where("identifier = ?", config.ID).First(&botEntry).Error @@ -43,6 +46,37 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) { return nil, err } + // Ensure the owner exists in the Users table + var owner User + err = db.Where("telegram_id = ? AND bot_id = ?", config.OwnerTelegramID, botEntry.ID).First(&owner).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + // Assign the "owner" role + var ownerRole Role + err := db.Where("name = ?", "owner").First(&ownerRole).Error + if err != nil { + return nil, fmt.Errorf("owner role not found: %w", err) + } + + owner = User{ + BotID: botEntry.ID, + TelegramID: config.OwnerTelegramID, + Username: "", // Initialize as empty; will be updated upon interaction + RoleID: ownerRole.ID, + IsOwner: true, + } + + if err := db.Create(&owner).Error; err != nil { + // If unique constraint is violated, another owner already exists + if strings.Contains(err.Error(), "unique index") { + return nil, fmt.Errorf("an owner already exists for this bot") + } + return nil, fmt.Errorf("failed to create owner user: %w", err) + } + } else if err != nil { + return nil, err + } + + // Initialize Anthropic client anthropicClient := anthropic.NewClient(os.Getenv("ANTHROPIC_API_KEY")) b := &Bot{ @@ -54,14 +88,9 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) { userLimiters: make(map[int64]*userLimiter), clock: clock, botID: botEntry.ID, // Ensure BotModel has ID field + tgBot: tgClient, } - tgBot, err := initTelegramBot(config.TelegramToken, b.handleUpdate) - if err != nil { - return nil, err - } - b.tgBot = tgBot - return b, nil } @@ -69,26 +98,69 @@ func (b *Bot) Start(ctx context.Context) { b.tgBot.Start(ctx) } -func (b *Bot) getOrCreateUser(userID int64, username string) (User, error) { +func (b *Bot) getOrCreateUser(userID int64, username string, isOwner bool) (User, error) { var user User - err := b.db.Preload("Role").Where("telegram_id = ?", userID).First(&user).Error + err := b.db.Preload("Role").Where("telegram_id = ? AND bot_id = ?", userID, b.botID).First(&user).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - var defaultRole Role - if err := b.db.Where("name = ?", "user").First(&defaultRole).Error; err != nil { - return User{}, err + // Check if an owner already exists for this bot + if isOwner { + var existingOwner User + err := b.db.Where("bot_id = ? AND is_owner = ?", b.botID, true).First(&existingOwner).Error + if err == nil { + return User{}, fmt.Errorf("an owner already exists for this bot") + } else if !errors.Is(err, gorm.ErrRecordNotFound) { + return User{}, fmt.Errorf("failed to check existing owner: %w", err) + } } - user = User{TelegramID: userID, Username: username, RoleID: defaultRole.ID} + + var role Role + var roleName string + if isOwner { + roleName = "owner" + } else { + roleName = "user" // Assign "user" role to non-owner users + } + + err := b.db.Where("name = ?", roleName).First(&role).Error + if err != nil { + return User{}, fmt.Errorf("failed to get role: %w", err) + } + + user = User{ + BotID: b.botID, + TelegramID: userID, + Username: username, + RoleID: role.ID, + Role: role, + IsOwner: isOwner, + } + if err := b.db.Create(&user).Error; err != nil { - return User{}, err + // If unique constraint is violated, another owner already exists + if strings.Contains(err.Error(), "unique index") { + return User{}, fmt.Errorf("an owner already exists for this bot") + } + return User{}, fmt.Errorf("failed to create user: %w", err) } } else { return User{}, err } + } else { + if isOwner && !user.IsOwner { + return User{}, fmt.Errorf("cannot change existing user to owner") + } } + return user, nil } +func (b *Bot) getRoleByName(roleName string) (Role, error) { + var role Role + err := b.db.Where("name = ?", roleName).First(&role).Error + return role, err +} + func (b *Bot) createMessage(chatID, userID int64, username, userRole, text string, isUser bool) Message { message := Message{ ChatID: chatID, @@ -195,12 +267,17 @@ func (b *Bot) isAdminOrOwner(userID int64) bool { return user.Role.Name == "admin" || user.Role.Name == "owner" } -func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (*bot.Bot, error) { +func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (TelegramClient, error) { opts := []bot.Option{ bot.WithDefaultHandler(handleUpdate), } - return bot.New(token, opts...) + tgBot, err := bot.New(token, opts...) + if err != nil { + return nil, err + } + + return tgBot, nil } func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, businessConnectionID string) error { @@ -290,3 +367,37 @@ func isEmoji(r rune) bool { (r >= 0x2600 && r <= 0x26FF) || // Misc symbols (r >= 0x2700 && r <= 0x27BF) // Dingbats } + +func (b *Bot) sendWhoAmI(ctx context.Context, chatID int64, userID int64, username string, businessConnectionID string) { + user, err := b.getOrCreateUser(userID, username, false) + if err != nil { + log.Printf("Error getting or creating user: %v", err) + b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve your information.", businessConnectionID) + return + } + + caser := cases.Title(language.English) + whoAmIMessage := fmt.Sprintf( + "👤 Your Information:\n\n"+ + "- Username: %s\n"+ + "- Role: %s", + user.Username, + caser.String(user.Role.Name), + ) + + // Store the user's /whoami command + userMessage := b.createMessage(chatID, userID, username, "user", "/whoami", true) + if err := b.storeMessage(userMessage); err != nil { + log.Printf("Error storing user message: %v", err) + } + + // Send and store the bot's response + if err := b.sendResponse(ctx, chatID, whoAmIMessage, businessConnectionID); err != nil { + log.Printf("Error sending /whoami message: %v", err) + } + assistantMessage := b.createMessage(chatID, 0, "", "assistant", whoAmIMessage, false) + if err := b.storeMessage(assistantMessage); err != nil { + log.Printf("Error storing assistant message: %v", err) + } + b.addMessageToChatMemory(b.getOrCreateChatMemory(chatID), assistantMessage) +} diff --git a/config.go b/config.go index 0ff9436..2d1e556 100644 --- a/config.go +++ b/config.go @@ -19,6 +19,7 @@ type BotConfig struct { Model anthropic.Model `json:"model"` // Changed from string to anthropic.Model SystemPrompts map[string]string `json:"system_prompts"` Active bool `json:"active"` // New field to control bot activity + OwnerTelegramID int64 `json:"owner_telegram_id"` } // Custom unmarshalling to handle anthropic.Model diff --git a/config/default.json b/config/default.json index 9df35f7..036fd74 100644 --- a/config/default.json +++ b/config/default.json @@ -2,6 +2,7 @@ "id": "default_bot", "active": false, "telegram_token": "YOUR_TELEGRAM_BOT_TOKEN", + "owner_telegram_id": 111111111, "memory_size": 10, "messages_per_hour": 20, "messages_per_day": 100, diff --git a/database.go b/database.go index da7a29a..6c7a8f6 100644 --- a/database.go +++ b/database.go @@ -27,11 +27,22 @@ func initDB() (*gorm.DB, error) { return nil, fmt.Errorf("failed to connect to database: %w", err) } + // AutoMigrate the models err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}) if err != nil { return nil, fmt.Errorf("failed to migrate database schema: %w", err) } + // Enforce unique owner per bot using raw SQL + // Note: SQLite doesn't support partial indexes, but we can simulate it by making a unique index on (BotID, IsOwner) + // and ensuring that IsOwner can only be true for one user per BotID. + // This approach allows multiple users with IsOwner=false for the same BotID, + // but only one user can have IsOwner=true per BotID. + err = db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_bot_owner ON users (bot_id, is_owner) WHERE is_owner = 1;`).Error + if err != nil { + return nil, fmt.Errorf("failed to create unique index for bot owners: %w", err) + } + err = createDefaultRoles(db) if err != nil { return nil, err diff --git a/handlers.go b/handlers.go index a7d9473..39b668e 100644 --- a/handlers.go +++ b/handlers.go @@ -42,6 +42,9 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U case "/stats": b.sendStats(ctx, chatID, userID, message.From.Username, businessConnectionID) return + case "/whoami": + b.sendWhoAmI(ctx, chatID, userID, message.From.Username, businessConnectionID) + return } } } @@ -69,32 +72,55 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U return } - user, err := b.getOrCreateUser(userID, username) + // Determine if the user is the owner + var isOwner bool + err := b.db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", userID, b.botID, true).First(&User{}).Error + if err == nil { + isOwner = true + } + + user, err := b.getOrCreateUser(userID, username, isOwner) if err != nil { log.Printf("Error getting or creating user: %v", err) return } + // Update the username if it's empty or has changed + if user.Username != username { + user.Username = username + if err := b.db.Save(&user).Error; err != nil { + log.Printf("Error updating user username: %v", err) + } + } + userMessage := b.createMessage(chatID, userID, username, user.Role.Name, text, true) userMessage.UserRole = string(anthropic.RoleUser) // Convert to string - b.storeMessage(userMessage) + if err := b.storeMessage(userMessage); err != nil { + log.Printf("Error storing user message: %v", err) + return + } chatMemory := b.getOrCreateChatMemory(chatID) b.addMessageToChatMemory(chatMemory, userMessage) contextMessages := b.prepareContextMessages(chatMemory) - isEmojiOnly := isOnlyEmojis(text) // Ensure you have this variable defined - response, err := b.getAnthropicResponse(ctx, contextMessages, b.isNewChat(chatID), b.isAdminOrOwner(userID), isEmojiOnly) + isEmojiOnly := isOnlyEmojis(text) + response, err := b.getAnthropicResponse(ctx, contextMessages, b.isNewChat(chatID), isOwner, isEmojiOnly) if err != nil { log.Printf("Error getting Anthropic response: %v", err) response = "I'm sorry, I'm having trouble processing your request right now." } - b.sendResponse(ctx, chatID, response, businessConnectionID) + if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil { + log.Printf("Error sending response: %v", err) + return + } - assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false) - b.storeMessage(assistantMessage) + assistantMessage := b.createMessage(chatID, 0, "", "assistant", response, false) + if err := b.storeMessage(assistantMessage); err != nil { + log.Printf("Error storing assistant message: %v", err) + } b.addMessageToChatMemory(chatMemory, assistantMessage) } diff --git a/main.go b/main.go index c93a754..1e6c2aa 100644 --- a/main.go +++ b/main.go @@ -52,14 +52,24 @@ func main() { go func(cfg BotConfig) { defer wg.Done() - // Create Bot instance with RealClock + // Create Bot instance without TelegramClient initially realClock := RealClock{} - bot, err := NewBot(db, cfg, realClock) + bot, err := NewBot(db, cfg, realClock, nil) if err != nil { log.Printf("Error creating bot %s: %v", cfg.ID, err) return } + // Initialize TelegramClient with the bot's handleUpdate method + tgClient, err := initTelegramBot(cfg.TelegramToken, bot.handleUpdate) + if err != nil { + log.Printf("Error initializing Telegram client for bot %s: %v", cfg.ID, err) + return + } + + // Assign the TelegramClient to the bot + bot.tgBot = tgClient + // Start the bot log.Printf("Starting bot %s...", cfg.ID) bot.Start(ctx) diff --git a/models.go b/models.go index f518754..6ce7b1e 100644 --- a/models.go +++ b/models.go @@ -11,7 +11,7 @@ type BotModel struct { Identifier string `gorm:"uniqueIndex"` // Renamed from ID to Identifier Name string Configs []ConfigModel `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` - Users []User `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` // Added foreign key + Users []User `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` // Associated users Messages []Message `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` } @@ -54,9 +54,15 @@ type Role struct { type User struct { gorm.Model - BotID uint `gorm:"index"` // Added foreign key to BotModel - TelegramID int64 `gorm:"uniqueIndex"` // Consider composite unique index if TelegramID is unique per Bot + BotID uint `gorm:"index"` // Foreign key to BotModel + TelegramID int64 `gorm:"uniqueIndex;not null"` // Unique per user Username string RoleID uint Role Role `gorm:"foreignKey:RoleID"` + IsOwner bool `gorm:"default:false"` // Indicates if the user is the owner +} + +// Compound unique index to ensure only one owner per bot +func (User) TableName() string { + return "users" } diff --git a/rate_limiter_test.go b/rate_limiter_test.go index 6e4f54a..7b85f38 100644 --- a/rate_limiter_test.go +++ b/rate_limiter_test.go @@ -1,8 +1,15 @@ package main import ( + "context" + "fmt" "testing" "time" + + "github.com/go-telegram/bot" + "github.com/go-telegram/bot/models" + "gorm.io/driver/sqlite" + "gorm.io/gorm" ) // TestCheckRateLimits tests the checkRateLimits method of the Bot. @@ -22,6 +29,7 @@ func TestCheckRateLimits(t *testing.T) { TempBanDuration: "1m", // Temporary ban duration of 1 minute for testing SystemPrompts: make(map[string]string), TelegramToken: "YOUR_TELEGRAM_BOT_TOKEN", + OwnerTelegramID: 123456789, } // Initialize the Bot with mock data and MockClock @@ -79,6 +87,102 @@ func TestCheckRateLimits(t *testing.T) { } } +func TestOwnerAssignment(t *testing.T) { + // Initialize in-memory database for testing + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("Failed to open in-memory database: %v", err) + } + + // Migrate the schema + err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}) + if err != nil { + t.Fatalf("Failed to migrate database schema: %v", err) + } + + // Create default roles + err = createDefaultRoles(db) + if err != nil { + t.Fatalf("Failed to create default roles: %v", err) + } + + // Create a bot configuration + config := BotConfig{ + ID: "test_bot", + TelegramToken: "TEST_TELEGRAM_TOKEN", + MemorySize: 10, + MessagePerHour: 5, + MessagePerDay: 10, + TempBanDuration: "1m", + SystemPrompts: make(map[string]string), + Active: true, + OwnerTelegramID: 111111111, + } + + // Initialize MockClock + mockClock := &MockClock{ + currentTime: time.Now(), + } + + // Initialize MockTelegramClient + mockTGClient := &MockTelegramClient{ + SendMessageFunc: func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) { + chatID, ok := params.ChatID.(int64) + if !ok { + return nil, fmt.Errorf("ChatID is not of type int64") + } + // Simulate successful message sending + return &models.Message{ID: 1, Chat: models.Chat{ID: chatID}}, nil + }, + } + + // Create the bot with the mock Telegram client + bot, err := NewBot(db, config, mockClock, mockTGClient) + if err != nil { + t.Fatalf("Failed to create bot: %v", err) + } + + // Verify that the owner exists + var owner User + err = db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", config.OwnerTelegramID, bot.botID, true).First(&owner).Error + if err != nil { + t.Fatalf("Owner was not created: %v", err) + } + + // Attempt to create another owner for the same bot + _, err = bot.getOrCreateUser(222222222, "AnotherOwner", true) + if err == nil { + t.Fatalf("Expected error when creating a second owner, but got none") + } + + // Verify that the error message is appropriate + expectedErrorMsg := "an owner already exists for this bot" + if err.Error() != expectedErrorMsg { + t.Fatalf("Unexpected error message: %v", err) + } + + // Assign admin role to a new user + adminUser, err := bot.getOrCreateUser(333333333, "AdminUser", false) + if err != nil { + t.Fatalf("Failed to create admin user: %v", err) + } + + if adminUser.Role.Name != "admin" { + t.Fatalf("Expected role 'admin', got '%s'", adminUser.Role.Name) + } + + // Attempt to change an existing user to owner + _, err = bot.getOrCreateUser(333333333, "AdminUser", true) + if err == nil { + t.Fatalf("Expected error when changing existing user to owner, but got none") + } + + expectedErrorMsg = "cannot change existing user to owner" + if err.Error() != expectedErrorMsg { + t.Fatalf("Unexpected error message: %v", err) + } +} + // To ensure thread safety and avoid race conditions during testing, // you can run the tests with the `-race` flag: // go test -race -v diff --git a/telegram_client.go b/telegram_client.go new file mode 100644 index 0000000..10dc9bf --- /dev/null +++ b/telegram_client.go @@ -0,0 +1,16 @@ +// telegram_client.go +package main + +import ( + "context" + + "github.com/go-telegram/bot" + "github.com/go-telegram/bot/models" +) + +// TelegramClient defines the methods required from the Telegram bot. +type TelegramClient interface { + SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) + Start(ctx context.Context) + // Add other methods if needed. +} diff --git a/telegram_client_mock.go b/telegram_client_mock.go new file mode 100644 index 0000000..61520f7 --- /dev/null +++ b/telegram_client_mock.go @@ -0,0 +1,35 @@ +// telegram_client_mock.go +package main + +import ( + "context" + + "github.com/go-telegram/bot" + "github.com/go-telegram/bot/models" +) + +// MockTelegramClient is a mock implementation of TelegramClient for testing. +type MockTelegramClient struct { + // You can add fields to keep track of calls if needed. + SendMessageFunc func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) + StartFunc func(ctx context.Context) // Optional: track Start calls +} + +// SendMessage mocks sending a message. +func (m *MockTelegramClient) SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) { + if m.SendMessageFunc != nil { + return m.SendMessageFunc(ctx, params) + } + // Default behavior: return an empty message without error. + return &models.Message{}, nil +} + +// Start mocks starting the Telegram client. +func (m *MockTelegramClient) Start(ctx context.Context) { + if m.StartFunc != nil { + m.StartFunc(ctx) + } + // Default behavior: do nothing. +} + +// Add other mocked methods if your Bot uses more TelegramClient methods.