Proceed with role management

This commit is contained in:
HugeFrog24
2024-10-22 15:50:51 +02:00
parent acaf5d01ab
commit ce59b5f5f1
6 changed files with 172 additions and 73 deletions

112
bot.go
View File

@@ -13,11 +13,13 @@ import (
"github.com/go-telegram/bot" "github.com/go-telegram/bot"
"github.com/go-telegram/bot/models" "github.com/go-telegram/bot/models"
"github.com/liushuangls/go-anthropic/v2" "github.com/liushuangls/go-anthropic/v2"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"gorm.io/gorm" "gorm.io/gorm"
) )
type Bot struct { type Bot struct {
tgBot *bot.Bot tgBot TelegramClient
db *gorm.DB db *gorm.DB
anthropicClient *anthropic.Client anthropicClient *anthropic.Client
chatMemories map[int64]*ChatMemory chatMemories map[int64]*ChatMemory
@@ -30,7 +32,8 @@ type Bot struct {
botID uint // Reference to BotModel.ID botID uint // Reference to BotModel.ID
} }
func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) { // bot.go
func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient) (*Bot, error) {
// Retrieve or create Bot entry in the database // Retrieve or create Bot entry in the database
var botEntry BotModel var botEntry BotModel
err := db.Where("identifier = ?", config.ID).First(&botEntry).Error err := db.Where("identifier = ?", config.ID).First(&botEntry).Error
@@ -57,7 +60,7 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) {
owner = User{ owner = User{
BotID: botEntry.ID, BotID: botEntry.ID,
TelegramID: config.OwnerTelegramID, TelegramID: config.OwnerTelegramID,
Username: "Owner", // You might want to fetch the actual username Username: "", // Initialize as empty; will be updated upon interaction
RoleID: ownerRole.ID, RoleID: ownerRole.ID,
IsOwner: true, IsOwner: true,
} }
@@ -85,14 +88,9 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) {
userLimiters: make(map[int64]*userLimiter), userLimiters: make(map[int64]*userLimiter),
clock: clock, clock: clock,
botID: botEntry.ID, // Ensure BotModel has ID field botID: botEntry.ID, // Ensure BotModel has ID field
tgBot: tgClient,
} }
tgBot, err := initTelegramBot(config.TelegramToken, b.handleUpdate)
if err != nil {
return nil, err
}
b.tgBot = tgBot
return b, nil return b, nil
} }
@@ -105,17 +103,28 @@ func (b *Bot) getOrCreateUser(userID int64, username string, isOwner bool) (User
err := b.db.Preload("Role").Where("telegram_id = ? AND bot_id = ?", userID, b.botID).First(&user).Error err := b.db.Preload("Role").Where("telegram_id = ? AND bot_id = ?", userID, b.botID).First(&user).Error
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
var role Role // Check if an owner already exists for this bot
if isOwner { if isOwner {
role, err = b.getRoleByName("owner") var existingOwner User
if err != nil { err := b.db.Where("bot_id = ? AND is_owner = ?", b.botID, true).First(&existingOwner).Error
return User{}, err if err == nil {
return User{}, fmt.Errorf("an owner already exists for this bot")
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
return User{}, fmt.Errorf("failed to check existing owner: %w", err)
} }
}
var role Role
var roleName string
if isOwner {
roleName = "owner"
} else { } else {
role, err = b.getRoleByName("admin") roleName = "user" // Assign "user" role to non-owner users
if err != nil { }
return User{}, err
} err := b.db.Where("name = ?", roleName).First(&role).Error
if err != nil {
return User{}, fmt.Errorf("failed to get role: %w", err)
} }
user = User{ user = User{
@@ -123,39 +132,23 @@ func (b *Bot) getOrCreateUser(userID int64, username string, isOwner bool) (User
TelegramID: userID, TelegramID: userID,
Username: username, Username: username,
RoleID: role.ID, RoleID: role.ID,
Role: role,
IsOwner: isOwner, IsOwner: isOwner,
} }
if err := b.db.Create(&user).Error; err != nil { if err := b.db.Create(&user).Error; err != nil {
// Handle unique constraint for owner // If unique constraint is violated, another owner already exists
if isOwner && strings.Contains(err.Error(), "unique index") { if strings.Contains(err.Error(), "unique index") {
return User{}, fmt.Errorf("an owner already exists for this bot") return User{}, fmt.Errorf("an owner already exists for this bot")
} }
return User{}, err return User{}, fmt.Errorf("failed to create user: %w", err)
} }
} else { } else {
return User{}, err return User{}, err
} }
} else { } else {
if isOwner && !user.IsOwner { if isOwner && !user.IsOwner {
// Check if another owner exists return User{}, fmt.Errorf("cannot change existing user to owner")
var existingOwner User
err := b.db.Where("bot_id = ? AND is_owner = ?", b.botID, true).First(&existingOwner).Error
if err == nil {
return User{}, fmt.Errorf("a bot can have only one owner")
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
return User{}, err
}
// Promote to owner
role, err := b.getRoleByName("owner")
if err != nil {
return User{}, err
}
user.RoleID = role.ID
user.IsOwner = true
if err := b.db.Save(&user).Error; err != nil {
return User{}, err
}
} }
} }
@@ -274,12 +267,17 @@ func (b *Bot) isAdminOrOwner(userID int64) bool {
return user.Role.Name == "admin" || user.Role.Name == "owner" return user.Role.Name == "admin" || user.Role.Name == "owner"
} }
func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (*bot.Bot, error) { func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (TelegramClient, error) {
opts := []bot.Option{ opts := []bot.Option{
bot.WithDefaultHandler(handleUpdate), bot.WithDefaultHandler(handleUpdate),
} }
return bot.New(token, opts...) tgBot, err := bot.New(token, opts...)
if err != nil {
return nil, err
}
return tgBot, nil
} }
func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, businessConnectionID string) error { func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, businessConnectionID string) error {
@@ -369,3 +367,37 @@ func isEmoji(r rune) bool {
(r >= 0x2600 && r <= 0x26FF) || // Misc symbols (r >= 0x2600 && r <= 0x26FF) || // Misc symbols
(r >= 0x2700 && r <= 0x27BF) // Dingbats (r >= 0x2700 && r <= 0x27BF) // Dingbats
} }
func (b *Bot) sendWhoAmI(ctx context.Context, chatID int64, userID int64, username string, businessConnectionID string) {
user, err := b.getOrCreateUser(userID, username, false)
if err != nil {
log.Printf("Error getting or creating user: %v", err)
b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve your information.", businessConnectionID)
return
}
caser := cases.Title(language.English)
whoAmIMessage := fmt.Sprintf(
"👤 Your Information:\n\n"+
"- Username: %s\n"+
"- Role: %s",
user.Username,
caser.String(user.Role.Name),
)
// Store the user's /whoami command
userMessage := b.createMessage(chatID, userID, username, "user", "/whoami", true)
if err := b.storeMessage(userMessage); err != nil {
log.Printf("Error storing user message: %v", err)
}
// Send and store the bot's response
if err := b.sendResponse(ctx, chatID, whoAmIMessage, businessConnectionID); err != nil {
log.Printf("Error sending /whoami message: %v", err)
}
assistantMessage := b.createMessage(chatID, 0, "", "assistant", whoAmIMessage, false)
if err := b.storeMessage(assistantMessage); err != nil {
log.Printf("Error storing assistant message: %v", err)
}
b.addMessageToChatMemory(b.getOrCreateChatMemory(chatID), assistantMessage)
}

View File

@@ -42,6 +42,9 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
case "/stats": case "/stats":
b.sendStats(ctx, chatID, userID, message.From.Username, businessConnectionID) b.sendStats(ctx, chatID, userID, message.From.Username, businessConnectionID)
return return
case "/whoami":
b.sendWhoAmI(ctx, chatID, userID, message.From.Username, businessConnectionID)
return
} }
} }
} }
@@ -82,6 +85,14 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
return return
} }
// Update the username if it's empty or has changed
if user.Username != username {
user.Username = username
if err := b.db.Save(&user).Error; err != nil {
log.Printf("Error updating user username: %v", err)
}
}
userMessage := b.createMessage(chatID, userID, username, user.Role.Name, text, true) userMessage := b.createMessage(chatID, userID, username, user.Role.Name, text, true)
userMessage.UserRole = string(anthropic.RoleUser) // Convert to string userMessage.UserRole = string(anthropic.RoleUser) // Convert to string
if err := b.storeMessage(userMessage); err != nil { if err := b.storeMessage(userMessage); err != nil {
@@ -109,7 +120,6 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
assistantMessage := b.createMessage(chatID, 0, "", "assistant", response, false) assistantMessage := b.createMessage(chatID, 0, "", "assistant", response, false)
if err := b.storeMessage(assistantMessage); err != nil { if err := b.storeMessage(assistantMessage); err != nil {
log.Printf("Error storing assistant message: %v", err) log.Printf("Error storing assistant message: %v", err)
return
} }
b.addMessageToChatMemory(chatMemory, assistantMessage) b.addMessageToChatMemory(chatMemory, assistantMessage)
} }

14
main.go
View File

@@ -52,14 +52,24 @@ func main() {
go func(cfg BotConfig) { go func(cfg BotConfig) {
defer wg.Done() defer wg.Done()
// Create Bot instance with RealClock // Create Bot instance without TelegramClient initially
realClock := RealClock{} realClock := RealClock{}
bot, err := NewBot(db, cfg, realClock) bot, err := NewBot(db, cfg, realClock, nil)
if err != nil { if err != nil {
log.Printf("Error creating bot %s: %v", cfg.ID, err) log.Printf("Error creating bot %s: %v", cfg.ID, err)
return return
} }
// Initialize TelegramClient with the bot's handleUpdate method
tgClient, err := initTelegramBot(cfg.TelegramToken, bot.handleUpdate)
if err != nil {
log.Printf("Error initializing Telegram client for bot %s: %v", cfg.ID, err)
return
}
// Assign the TelegramClient to the bot
bot.tgBot = tgClient
// Start the bot // Start the bot
log.Printf("Starting bot %s...", cfg.ID) log.Printf("Starting bot %s...", cfg.ID)
bot.Start(ctx) bot.Start(ctx)

View File

@@ -1,10 +1,13 @@
package main package main
import ( import (
"strings" "context"
"fmt"
"testing" "testing"
"time" "time"
"github.com/go-telegram/bot"
"github.com/go-telegram/bot/models"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -121,8 +124,20 @@ func TestOwnerAssignment(t *testing.T) {
currentTime: time.Now(), currentTime: time.Now(),
} }
// Create the bot // Initialize MockTelegramClient
bot, err := NewBot(db, config, mockClock) mockTGClient := &MockTelegramClient{
SendMessageFunc: func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
chatID, ok := params.ChatID.(int64)
if !ok {
return nil, fmt.Errorf("ChatID is not of type int64")
}
// Simulate successful message sending
return &models.Message{ID: 1, Chat: models.Chat{ID: chatID}}, nil
},
}
// Create the bot with the mock Telegram client
bot, err := NewBot(db, config, mockClock, mockTGClient)
if err != nil { if err != nil {
t.Fatalf("Failed to create bot: %v", err) t.Fatalf("Failed to create bot: %v", err)
} }
@@ -142,7 +157,7 @@ func TestOwnerAssignment(t *testing.T) {
// Verify that the error message is appropriate // Verify that the error message is appropriate
expectedErrorMsg := "an owner already exists for this bot" expectedErrorMsg := "an owner already exists for this bot"
if err.Error() != expectedErrorMsg && !strings.Contains(err.Error(), "unique index") { if err.Error() != expectedErrorMsg {
t.Fatalf("Unexpected error message: %v", err) t.Fatalf("Unexpected error message: %v", err)
} }
@@ -156,34 +171,15 @@ func TestOwnerAssignment(t *testing.T) {
t.Fatalf("Expected role 'admin', got '%s'", adminUser.Role.Name) t.Fatalf("Expected role 'admin', got '%s'", adminUser.Role.Name)
} }
// Assign owner role to a user from a different bot // Attempt to change an existing user to owner
otherBotConfig := BotConfig{ _, err = bot.getOrCreateUser(333333333, "AdminUser", true)
ID: "other_bot", if err == nil {
TelegramToken: "OTHER_TELEGRAM_TOKEN", t.Fatalf("Expected error when changing existing user to owner, but got none")
MemorySize: 10,
MessagePerHour: 5,
MessagePerDay: 10,
TempBanDuration: "1m",
SystemPrompts: make(map[string]string),
Active: true,
OwnerTelegramID: 444444444,
} }
otherBot, err := NewBot(db, otherBotConfig, mockClock) expectedErrorMsg = "cannot change existing user to owner"
if err != nil { if err.Error() != expectedErrorMsg {
t.Fatalf("Failed to create other bot: %v", err) t.Fatalf("Unexpected error message: %v", err)
}
_, err = otherBot.getOrCreateUser(config.OwnerTelegramID, "OwnerOfOtherBot", true)
if err != nil {
t.Fatalf("Failed to assign existing owner to another bot: %v", err)
}
// Verify multiple bots can have the same owner telegram ID
var ownerOfOtherBot User
err = db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", config.OwnerTelegramID, otherBot.botID, true).First(&ownerOfOtherBot).Error
if err != nil {
t.Fatalf("Owner of other bot was not created: %v", err)
} }
} }

16
telegram_client.go Normal file
View File

@@ -0,0 +1,16 @@
// telegram_client.go
package main
import (
"context"
"github.com/go-telegram/bot"
"github.com/go-telegram/bot/models"
)
// TelegramClient defines the methods required from the Telegram bot.
type TelegramClient interface {
SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error)
Start(ctx context.Context)
// Add other methods if needed.
}

35
telegram_client_mock.go Normal file
View File

@@ -0,0 +1,35 @@
// telegram_client_mock.go
package main
import (
"context"
"github.com/go-telegram/bot"
"github.com/go-telegram/bot/models"
)
// MockTelegramClient is a mock implementation of TelegramClient for testing.
type MockTelegramClient struct {
// You can add fields to keep track of calls if needed.
SendMessageFunc func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error)
StartFunc func(ctx context.Context) // Optional: track Start calls
}
// SendMessage mocks sending a message.
func (m *MockTelegramClient) SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
if m.SendMessageFunc != nil {
return m.SendMessageFunc(ctx, params)
}
// Default behavior: return an empty message without error.
return &models.Message{}, nil
}
// Start mocks starting the Telegram client.
func (m *MockTelegramClient) Start(ctx context.Context) {
if m.StartFunc != nil {
m.StartFunc(ctx)
}
// Default behavior: do nothing.
}
// Add other mocked methods if your Bot uses more TelegramClient methods.