diff --git a/bot.go b/bot.go index 9b3b80c..7b851cd 100644 --- a/bot.go +++ b/bot.go @@ -13,11 +13,13 @@ import ( "github.com/go-telegram/bot" "github.com/go-telegram/bot/models" "github.com/liushuangls/go-anthropic/v2" + "golang.org/x/text/cases" + "golang.org/x/text/language" "gorm.io/gorm" ) type Bot struct { - tgBot *bot.Bot + tgBot TelegramClient db *gorm.DB anthropicClient *anthropic.Client chatMemories map[int64]*ChatMemory @@ -30,7 +32,8 @@ type Bot struct { botID uint // Reference to BotModel.ID } -func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) { +// bot.go +func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient) (*Bot, error) { // Retrieve or create Bot entry in the database var botEntry BotModel err := db.Where("identifier = ?", config.ID).First(&botEntry).Error @@ -57,7 +60,7 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) { owner = User{ BotID: botEntry.ID, TelegramID: config.OwnerTelegramID, - Username: "Owner", // You might want to fetch the actual username + Username: "", // Initialize as empty; will be updated upon interaction RoleID: ownerRole.ID, IsOwner: true, } @@ -85,14 +88,9 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) { userLimiters: make(map[int64]*userLimiter), clock: clock, botID: botEntry.ID, // Ensure BotModel has ID field + tgBot: tgClient, } - tgBot, err := initTelegramBot(config.TelegramToken, b.handleUpdate) - if err != nil { - return nil, err - } - b.tgBot = tgBot - return b, nil } @@ -105,17 +103,28 @@ func (b *Bot) getOrCreateUser(userID int64, username string, isOwner bool) (User err := b.db.Preload("Role").Where("telegram_id = ? AND bot_id = ?", userID, b.botID).First(&user).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - var role Role + // Check if an owner already exists for this bot if isOwner { - role, err = b.getRoleByName("owner") - if err != nil { - return User{}, err + var existingOwner User + err := b.db.Where("bot_id = ? AND is_owner = ?", b.botID, true).First(&existingOwner).Error + if err == nil { + return User{}, fmt.Errorf("an owner already exists for this bot") + } else if !errors.Is(err, gorm.ErrRecordNotFound) { + return User{}, fmt.Errorf("failed to check existing owner: %w", err) } + } + + var role Role + var roleName string + if isOwner { + roleName = "owner" } else { - role, err = b.getRoleByName("admin") - if err != nil { - return User{}, err - } + roleName = "user" // Assign "user" role to non-owner users + } + + err := b.db.Where("name = ?", roleName).First(&role).Error + if err != nil { + return User{}, fmt.Errorf("failed to get role: %w", err) } user = User{ @@ -123,39 +132,23 @@ func (b *Bot) getOrCreateUser(userID int64, username string, isOwner bool) (User TelegramID: userID, Username: username, RoleID: role.ID, + Role: role, IsOwner: isOwner, } if err := b.db.Create(&user).Error; err != nil { - // Handle unique constraint for owner - if isOwner && strings.Contains(err.Error(), "unique index") { + // If unique constraint is violated, another owner already exists + if strings.Contains(err.Error(), "unique index") { return User{}, fmt.Errorf("an owner already exists for this bot") } - return User{}, err + return User{}, fmt.Errorf("failed to create user: %w", err) } } else { return User{}, err } } else { if isOwner && !user.IsOwner { - // Check if another owner exists - var existingOwner User - err := b.db.Where("bot_id = ? AND is_owner = ?", b.botID, true).First(&existingOwner).Error - if err == nil { - return User{}, fmt.Errorf("a bot can have only one owner") - } else if !errors.Is(err, gorm.ErrRecordNotFound) { - return User{}, err - } - // Promote to owner - role, err := b.getRoleByName("owner") - if err != nil { - return User{}, err - } - user.RoleID = role.ID - user.IsOwner = true - if err := b.db.Save(&user).Error; err != nil { - return User{}, err - } + return User{}, fmt.Errorf("cannot change existing user to owner") } } @@ -274,12 +267,17 @@ func (b *Bot) isAdminOrOwner(userID int64) bool { return user.Role.Name == "admin" || user.Role.Name == "owner" } -func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (*bot.Bot, error) { +func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (TelegramClient, error) { opts := []bot.Option{ bot.WithDefaultHandler(handleUpdate), } - return bot.New(token, opts...) + tgBot, err := bot.New(token, opts...) + if err != nil { + return nil, err + } + + return tgBot, nil } func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, businessConnectionID string) error { @@ -369,3 +367,37 @@ func isEmoji(r rune) bool { (r >= 0x2600 && r <= 0x26FF) || // Misc symbols (r >= 0x2700 && r <= 0x27BF) // Dingbats } + +func (b *Bot) sendWhoAmI(ctx context.Context, chatID int64, userID int64, username string, businessConnectionID string) { + user, err := b.getOrCreateUser(userID, username, false) + if err != nil { + log.Printf("Error getting or creating user: %v", err) + b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve your information.", businessConnectionID) + return + } + + caser := cases.Title(language.English) + whoAmIMessage := fmt.Sprintf( + "👤 Your Information:\n\n"+ + "- Username: %s\n"+ + "- Role: %s", + user.Username, + caser.String(user.Role.Name), + ) + + // Store the user's /whoami command + userMessage := b.createMessage(chatID, userID, username, "user", "/whoami", true) + if err := b.storeMessage(userMessage); err != nil { + log.Printf("Error storing user message: %v", err) + } + + // Send and store the bot's response + if err := b.sendResponse(ctx, chatID, whoAmIMessage, businessConnectionID); err != nil { + log.Printf("Error sending /whoami message: %v", err) + } + assistantMessage := b.createMessage(chatID, 0, "", "assistant", whoAmIMessage, false) + if err := b.storeMessage(assistantMessage); err != nil { + log.Printf("Error storing assistant message: %v", err) + } + b.addMessageToChatMemory(b.getOrCreateChatMemory(chatID), assistantMessage) +} diff --git a/handlers.go b/handlers.go index 149b772..39b668e 100644 --- a/handlers.go +++ b/handlers.go @@ -42,6 +42,9 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U case "/stats": b.sendStats(ctx, chatID, userID, message.From.Username, businessConnectionID) return + case "/whoami": + b.sendWhoAmI(ctx, chatID, userID, message.From.Username, businessConnectionID) + return } } } @@ -82,6 +85,14 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U return } + // Update the username if it's empty or has changed + if user.Username != username { + user.Username = username + if err := b.db.Save(&user).Error; err != nil { + log.Printf("Error updating user username: %v", err) + } + } + userMessage := b.createMessage(chatID, userID, username, user.Role.Name, text, true) userMessage.UserRole = string(anthropic.RoleUser) // Convert to string if err := b.storeMessage(userMessage); err != nil { @@ -109,7 +120,6 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U assistantMessage := b.createMessage(chatID, 0, "", "assistant", response, false) if err := b.storeMessage(assistantMessage); err != nil { log.Printf("Error storing assistant message: %v", err) - return } b.addMessageToChatMemory(chatMemory, assistantMessage) } diff --git a/main.go b/main.go index c93a754..1e6c2aa 100644 --- a/main.go +++ b/main.go @@ -52,14 +52,24 @@ func main() { go func(cfg BotConfig) { defer wg.Done() - // Create Bot instance with RealClock + // Create Bot instance without TelegramClient initially realClock := RealClock{} - bot, err := NewBot(db, cfg, realClock) + bot, err := NewBot(db, cfg, realClock, nil) if err != nil { log.Printf("Error creating bot %s: %v", cfg.ID, err) return } + // Initialize TelegramClient with the bot's handleUpdate method + tgClient, err := initTelegramBot(cfg.TelegramToken, bot.handleUpdate) + if err != nil { + log.Printf("Error initializing Telegram client for bot %s: %v", cfg.ID, err) + return + } + + // Assign the TelegramClient to the bot + bot.tgBot = tgClient + // Start the bot log.Printf("Starting bot %s...", cfg.ID) bot.Start(ctx) diff --git a/rate_limiter_test.go b/rate_limiter_test.go index 5965a3e..7b85f38 100644 --- a/rate_limiter_test.go +++ b/rate_limiter_test.go @@ -1,10 +1,13 @@ package main import ( - "strings" + "context" + "fmt" "testing" "time" + "github.com/go-telegram/bot" + "github.com/go-telegram/bot/models" "gorm.io/driver/sqlite" "gorm.io/gorm" ) @@ -121,8 +124,20 @@ func TestOwnerAssignment(t *testing.T) { currentTime: time.Now(), } - // Create the bot - bot, err := NewBot(db, config, mockClock) + // Initialize MockTelegramClient + mockTGClient := &MockTelegramClient{ + SendMessageFunc: func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) { + chatID, ok := params.ChatID.(int64) + if !ok { + return nil, fmt.Errorf("ChatID is not of type int64") + } + // Simulate successful message sending + return &models.Message{ID: 1, Chat: models.Chat{ID: chatID}}, nil + }, + } + + // Create the bot with the mock Telegram client + bot, err := NewBot(db, config, mockClock, mockTGClient) if err != nil { t.Fatalf("Failed to create bot: %v", err) } @@ -142,7 +157,7 @@ func TestOwnerAssignment(t *testing.T) { // Verify that the error message is appropriate expectedErrorMsg := "an owner already exists for this bot" - if err.Error() != expectedErrorMsg && !strings.Contains(err.Error(), "unique index") { + if err.Error() != expectedErrorMsg { t.Fatalf("Unexpected error message: %v", err) } @@ -156,34 +171,15 @@ func TestOwnerAssignment(t *testing.T) { t.Fatalf("Expected role 'admin', got '%s'", adminUser.Role.Name) } - // Assign owner role to a user from a different bot - otherBotConfig := BotConfig{ - ID: "other_bot", - TelegramToken: "OTHER_TELEGRAM_TOKEN", - MemorySize: 10, - MessagePerHour: 5, - MessagePerDay: 10, - TempBanDuration: "1m", - SystemPrompts: make(map[string]string), - Active: true, - OwnerTelegramID: 444444444, + // Attempt to change an existing user to owner + _, err = bot.getOrCreateUser(333333333, "AdminUser", true) + if err == nil { + t.Fatalf("Expected error when changing existing user to owner, but got none") } - otherBot, err := NewBot(db, otherBotConfig, mockClock) - if err != nil { - t.Fatalf("Failed to create other bot: %v", err) - } - - _, err = otherBot.getOrCreateUser(config.OwnerTelegramID, "OwnerOfOtherBot", true) - if err != nil { - t.Fatalf("Failed to assign existing owner to another bot: %v", err) - } - - // Verify multiple bots can have the same owner telegram ID - var ownerOfOtherBot User - err = db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", config.OwnerTelegramID, otherBot.botID, true).First(&ownerOfOtherBot).Error - if err != nil { - t.Fatalf("Owner of other bot was not created: %v", err) + expectedErrorMsg = "cannot change existing user to owner" + if err.Error() != expectedErrorMsg { + t.Fatalf("Unexpected error message: %v", err) } } diff --git a/telegram_client.go b/telegram_client.go new file mode 100644 index 0000000..10dc9bf --- /dev/null +++ b/telegram_client.go @@ -0,0 +1,16 @@ +// telegram_client.go +package main + +import ( + "context" + + "github.com/go-telegram/bot" + "github.com/go-telegram/bot/models" +) + +// TelegramClient defines the methods required from the Telegram bot. +type TelegramClient interface { + SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) + Start(ctx context.Context) + // Add other methods if needed. +} diff --git a/telegram_client_mock.go b/telegram_client_mock.go new file mode 100644 index 0000000..61520f7 --- /dev/null +++ b/telegram_client_mock.go @@ -0,0 +1,35 @@ +// telegram_client_mock.go +package main + +import ( + "context" + + "github.com/go-telegram/bot" + "github.com/go-telegram/bot/models" +) + +// MockTelegramClient is a mock implementation of TelegramClient for testing. +type MockTelegramClient struct { + // You can add fields to keep track of calls if needed. + SendMessageFunc func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) + StartFunc func(ctx context.Context) // Optional: track Start calls +} + +// SendMessage mocks sending a message. +func (m *MockTelegramClient) SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) { + if m.SendMessageFunc != nil { + return m.SendMessageFunc(ctx, params) + } + // Default behavior: return an empty message without error. + return &models.Message{}, nil +} + +// Start mocks starting the Telegram client. +func (m *MockTelegramClient) Start(ctx context.Context) { + if m.StartFunc != nil { + m.StartFunc(ctx) + } + // Default behavior: do nothing. +} + +// Add other mocked methods if your Bot uses more TelegramClient methods.