md formatting doesnt work yet

Started implementing owner feature

Add .gitattributes to enforce LF line endings

Temporary commit before merge

Updated owner management

Updated json and gitignore

Proceed with role management

Again, CI

Fix some lint errors

Implemented screening

Per-bot API keys implemented

Use getRoleByName func

Fix unused imports

Upgrade actions

rm unused function

Upgrade action

Fix unaddressed errors
This commit is contained in:
HugeFrog24
2024-10-20 17:17:21 +02:00
parent e5532df7f9
commit c8af457af1
18 changed files with 520 additions and 79 deletions

13
.gitattributes vendored Normal file
View File

@@ -0,0 +1,13 @@
# Enforce LF line endings for all files
* text eol=lf
# Specific file types that should always have LF line endings
*.go text eol=lf
*.json text eol=lf
*.sh text eol=lf
*.md text eol=lf
# Example: Binary files should not be modified
*.jpg binary
*.png binary
*.gif binary

54
.github/workflows/go-ci.yaml vendored Normal file
View File

@@ -0,0 +1,54 @@
name: CI
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
runs-on: ubuntu-latest
steps:
# Checkout the repository
- name: Checkout code
uses: actions/checkout@v4
# Set up Go environment
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: '1.23' # Specify the Go version you are using
# Cache Go modules
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
# Install Dependencies
- name: Install Dependencies
run: go mod tidy
# Run Linters using golangci-lint
- name: Lint Code
uses: golangci/golangci-lint-action@v6
with:
version: v1.60 # Specify the version of golangci-lint
args: --timeout 5m
# Run Tests
- name: Run Tests
run: go test ./... -v
# Security Analysis using gosec
- name: Security Scan
uses: securego/gosec@master
with:
args: ./...

4
.gitignore vendored Executable file → Normal file
View File

@@ -4,8 +4,8 @@ vendor/
# Environment variables # Environment variables
.env .env
# Log file # Any log files
bot.log *.log
# Database file # Database file
bot.db bot.db

0
anthropic.go Executable file → Normal file
View File

221
bot.go Executable file → Normal file
View File

@@ -5,7 +5,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"log" "log"
"os"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -17,7 +16,7 @@ import (
) )
type Bot struct { type Bot struct {
tgBot *bot.Bot tgBot TelegramClient
db *gorm.DB db *gorm.DB
anthropicClient *anthropic.Client anthropicClient *anthropic.Client
chatMemories map[int64]*ChatMemory chatMemories map[int64]*ChatMemory
@@ -30,7 +29,8 @@ type Bot struct {
botID uint // Reference to BotModel.ID botID uint // Reference to BotModel.ID
} }
func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) { // NewBot initializes and returns a new Bot instance.
func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient) (*Bot, error) {
// Retrieve or create Bot entry in the database // Retrieve or create Bot entry in the database
var botEntry BotModel var botEntry BotModel
err := db.Where("identifier = ?", config.ID).First(&botEntry).Error err := db.Where("identifier = ?", config.ID).First(&botEntry).Error
@@ -43,7 +43,38 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) {
return nil, err return nil, err
} }
anthropicClient := anthropic.NewClient(os.Getenv("ANTHROPIC_API_KEY")) // Ensure the owner exists in the Users table
var owner User
err = db.Where("telegram_id = ? AND bot_id = ?", config.OwnerTelegramID, botEntry.ID).First(&owner).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
// Assign the "owner" role
var ownerRole Role
err := db.Where("name = ?", "owner").First(&ownerRole).Error
if err != nil {
return nil, fmt.Errorf("owner role not found: %w", err)
}
owner = User{
BotID: botEntry.ID,
TelegramID: config.OwnerTelegramID,
Username: "", // Initialize as empty; will be updated upon interaction
RoleID: ownerRole.ID,
IsOwner: true,
}
if err := db.Create(&owner).Error; err != nil {
// If unique constraint is violated, another owner already exists
if strings.Contains(err.Error(), "unique index") {
return nil, fmt.Errorf("an owner already exists for this bot")
}
return nil, fmt.Errorf("failed to create owner user: %w", err)
}
} else if err != nil {
return nil, err
}
// Use the per-bot Anthropic API key
anthropicClient := anthropic.NewClient(config.AnthropicAPIKey)
b := &Bot{ b := &Bot{
db: db, db: db,
@@ -54,41 +85,80 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) {
userLimiters: make(map[int64]*userLimiter), userLimiters: make(map[int64]*userLimiter),
clock: clock, clock: clock,
botID: botEntry.ID, // Ensure BotModel has ID field botID: botEntry.ID, // Ensure BotModel has ID field
tgBot: tgClient,
} }
tgBot, err := initTelegramBot(config.TelegramToken, b.handleUpdate)
if err != nil {
return nil, err
}
b.tgBot = tgBot
return b, nil return b, nil
} }
// Start begins the bot's operation.
func (b *Bot) Start(ctx context.Context) { func (b *Bot) Start(ctx context.Context) {
b.tgBot.Start(ctx) b.tgBot.Start(ctx)
} }
func (b *Bot) getOrCreateUser(userID int64, username string) (User, error) { func (b *Bot) getOrCreateUser(userID int64, username string, isOwner bool) (User, error) {
var user User var user User
err := b.db.Preload("Role").Where("telegram_id = ?", userID).First(&user).Error err := b.db.Preload("Role").Where("telegram_id = ? AND bot_id = ?", userID, b.botID).First(&user).Error
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
var defaultRole Role // Check if an owner already exists for this bot
if err := b.db.Where("name = ?", "user").First(&defaultRole).Error; err != nil { if isOwner {
return User{}, err var existingOwner User
err := b.db.Where("bot_id = ? AND is_owner = ?", b.botID, true).First(&existingOwner).Error
if err == nil {
return User{}, fmt.Errorf("an owner already exists for this bot")
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
return User{}, fmt.Errorf("failed to check existing owner: %w", err)
} }
user = User{TelegramID: userID, Username: username, RoleID: defaultRole.ID} }
var role Role
var roleName string
if isOwner {
roleName = "owner"
} else {
roleName = "user" // Assign "user" role to non-owner users
}
err := b.db.Where("name = ?", roleName).First(&role).Error
if err != nil {
return User{}, fmt.Errorf("failed to get role: %w", err)
}
user = User{
BotID: b.botID,
TelegramID: userID,
Username: username,
RoleID: role.ID,
Role: role,
IsOwner: isOwner,
}
if err := b.db.Create(&user).Error; err != nil { if err := b.db.Create(&user).Error; err != nil {
return User{}, err // If unique constraint is violated, another owner already exists
if strings.Contains(err.Error(), "unique index") {
return User{}, fmt.Errorf("an owner already exists for this bot")
}
return User{}, fmt.Errorf("failed to create user: %w", err)
} }
} else { } else {
return User{}, err return User{}, err
} }
} else {
if isOwner && !user.IsOwner {
return User{}, fmt.Errorf("cannot change existing user to owner")
} }
}
return user, nil return user, nil
} }
func (b *Bot) getRoleByName(roleName string) (Role, error) {
var role Role
err := b.db.Where("name = ?", roleName).First(&role).Error
return role, err
}
func (b *Bot) createMessage(chatID, userID int64, username, userRole, text string, isUser bool) Message { func (b *Bot) createMessage(chatID, userID int64, username, userRole, text string, isUser bool) Message {
message := Message{ message := Message{
ChatID: chatID, ChatID: chatID,
@@ -195,17 +265,28 @@ func (b *Bot) isAdminOrOwner(userID int64) bool {
return user.Role.Name == "admin" || user.Role.Name == "owner" return user.Role.Name == "admin" || user.Role.Name == "owner"
} }
func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (*bot.Bot, error) { func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (TelegramClient, error) {
opts := []bot.Option{ opts := []bot.Option{
bot.WithDefaultHandler(handleUpdate), bot.WithDefaultHandler(handleUpdate),
} }
return bot.New(token, opts...) tgBot, err := bot.New(token, opts...)
if err != nil {
return nil, err
}
return tgBot, nil
} }
// sendResponse sends a message to the specified chat.
// Returns an error if sending the message fails.
func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, businessConnectionID string) error { func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, businessConnectionID string) error {
// Pass the outgoing message through the centralized screen for storage
_, err := b.screenOutgoingMessage(chatID, text, businessConnectionID)
if err != nil {
log.Printf("Error storing assistant message: %v", err)
return err
}
// Prepare message parameters
params := &bot.SendMessageParams{ params := &bot.SendMessageParams{
ChatID: chatID, ChatID: chatID,
Text: text, Text: text,
@@ -215,7 +296,8 @@ func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, busin
params.BusinessConnectionID = businessConnectionID params.BusinessConnectionID = businessConnectionID
} }
_, err := b.tgBot.SendMessage(ctx, params) // Send the message via Telegram client
_, err = b.tgBot.SendMessage(ctx, params)
if err != nil { if err != nil {
log.Printf("[%s] [ERROR] Error sending message to chat %d with BusinessConnectionID %s: %v", log.Printf("[%s] [ERROR] Error sending message to chat %d with BusinessConnectionID %s: %v",
b.config.ID, chatID, businessConnectionID, err) b.config.ID, chatID, businessConnectionID, err)
@@ -225,16 +307,29 @@ func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, busin
} }
// sendStats sends the bot statistics to the specified chat. // sendStats sends the bot statistics to the specified chat.
func (b *Bot) sendStats(ctx context.Context, chatID int64, businessConnectionID string) { func (b *Bot) sendStats(ctx context.Context, chatID int64, userID int64, username string, businessConnectionID string) {
totalUsers, totalMessages, err := b.getStats() totalUsers, totalMessages, err := b.getStats()
if err != nil { if err != nil {
fmt.Printf("Error fetching stats: %v\n", err) fmt.Printf("Error fetching stats: %v\n", err)
b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve the stats at this time.", businessConnectionID) if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve the stats at this time.", businessConnectionID); err != nil {
log.Printf("Error sending response: %v", err)
}
return return
} }
statsMessage := fmt.Sprintf("📊 **Bot Statistics:**\n\n- Total Users: %d\n- Total Messages: %d", totalUsers, totalMessages) // Do NOT manually escape hyphens here
b.sendResponse(ctx, chatID, statsMessage, businessConnectionID) statsMessage := fmt.Sprintf(
"📊 Bot Statistics:\n\n"+
"- Total Users: %d\n"+
"- Total Messages: %d",
totalUsers,
totalMessages,
)
// Send the response through the centralized screen
if err := b.sendResponse(ctx, chatID, statsMessage, businessConnectionID); err != nil {
log.Printf("Error sending stats message: %v", err)
}
} }
// getStats retrieves the total number of users and messages from the database. // getStats retrieves the total number of users and messages from the database.
@@ -271,3 +366,77 @@ func isEmoji(r rune) bool {
(r >= 0x2600 && r <= 0x26FF) || // Misc symbols (r >= 0x2600 && r <= 0x26FF) || // Misc symbols
(r >= 0x2700 && r <= 0x27BF) // Dingbats (r >= 0x2700 && r <= 0x27BF) // Dingbats
} }
func (b *Bot) sendWhoAmI(ctx context.Context, chatID int64, userID int64, username string, businessConnectionID string) {
user, err := b.getOrCreateUser(userID, username, false)
if err != nil {
log.Printf("Error getting or creating user: %v", err)
if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve your information.", businessConnectionID); err != nil {
log.Printf("Error sending response: %v", err)
}
return
}
role, err := b.getRoleByName(user.Role.Name)
if err != nil {
log.Printf("Error getting role by name: %v", err)
if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve your role information.", businessConnectionID); err != nil {
log.Printf("Error sending response: %v", err)
}
return
}
whoAmIMessage := fmt.Sprintf(
"👤 Your Information:\n\n"+
"- Username: %s\n"+
"- Role: %s",
user.Username,
role.Name,
)
// Send the response through the centralized screen
if err := b.sendResponse(ctx, chatID, whoAmIMessage, businessConnectionID); err != nil {
log.Printf("Error sending /whoami message: %v", err)
}
}
// screenIncomingMessage handles storing of incoming messages.
func (b *Bot) screenIncomingMessage(message *models.Message) (Message, error) {
userRole := string(anthropic.RoleUser) // Convert RoleUser to string
userMessage := b.createMessage(message.Chat.ID, message.From.ID, message.From.Username, userRole, message.Text, true)
// If the message contains a sticker, include its details.
if message.Sticker != nil {
userMessage.StickerFileID = message.Sticker.FileID
if message.Sticker.Thumbnail != nil {
userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID
}
}
// Store the message.
if err := b.storeMessage(userMessage); err != nil {
return Message{}, err
}
// Update chat memory.
chatMemory := b.getOrCreateChatMemory(message.Chat.ID)
b.addMessageToChatMemory(chatMemory, userMessage)
return userMessage, nil
}
// screenOutgoingMessage handles storing of outgoing messages.
func (b *Bot) screenOutgoingMessage(chatID int64, response string, businessConnectionID string) (Message, error) {
assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false)
// Store the message.
if err := b.storeMessage(assistantMessage); err != nil {
return Message{}, err
}
// Update chat memory.
chatMemory := b.getOrCreateChatMemory(chatID)
b.addMessageToChatMemory(chatMemory, assistantMessage)
return assistantMessage, nil
}

0
clock.go Executable file → Normal file
View File

9
config.go Executable file → Normal file
View File

@@ -18,6 +18,9 @@ type BotConfig struct {
TempBanDuration string `json:"temp_ban_duration"` TempBanDuration string `json:"temp_ban_duration"`
Model anthropic.Model `json:"model"` // Changed from string to anthropic.Model Model anthropic.Model `json:"model"` // Changed from string to anthropic.Model
SystemPrompts map[string]string `json:"system_prompts"` SystemPrompts map[string]string `json:"system_prompts"`
Active bool `json:"active"` // New field to control bot activity
OwnerTelegramID int64 `json:"owner_telegram_id"`
AnthropicAPIKey string `json:"anthropic_api_key"` // Add this line
} }
// Custom unmarshalling to handle anthropic.Model // Custom unmarshalling to handle anthropic.Model
@@ -56,6 +59,12 @@ func loadAllConfigs(dir string) ([]BotConfig, error) {
return nil, fmt.Errorf("failed to load config %s: %w", configPath, err) return nil, fmt.Errorf("failed to load config %s: %w", configPath, err)
} }
// Skip inactive bots
if !config.Active {
fmt.Printf("Skipping inactive bot: %s\n", config.ID)
continue
}
// Validate that ID is present // Validate that ID is present
if config.ID == "" { if config.ID == "" {
return nil, fmt.Errorf("config %s is missing 'id' field", configPath) return nil, fmt.Errorf("config %s is missing 'id' field", configPath)

3
config/default.json Executable file → Normal file
View File

@@ -1,6 +1,9 @@
{ {
"id": "default_bot", "id": "default_bot",
"active": false,
"telegram_token": "YOUR_TELEGRAM_BOT_TOKEN", "telegram_token": "YOUR_TELEGRAM_BOT_TOKEN",
"owner_telegram_id": 111111111,
"anthropic_api_key": "YOUR_SPECIFIC_ANTHROPIC_API_KEY",
"memory_size": 10, "memory_size": 10,
"messages_per_hour": 20, "messages_per_hour": 20,
"messages_per_day": 100, "messages_per_day": 100,

11
database.go Executable file → Normal file
View File

@@ -27,11 +27,22 @@ func initDB() (*gorm.DB, error) {
return nil, fmt.Errorf("failed to connect to database: %w", err) return nil, fmt.Errorf("failed to connect to database: %w", err)
} }
// AutoMigrate the models
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}) err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to migrate database schema: %w", err) return nil, fmt.Errorf("failed to migrate database schema: %w", err)
} }
// Enforce unique owner per bot using raw SQL
// Note: SQLite doesn't support partial indexes, but we can simulate it by making a unique index on (BotID, IsOwner)
// and ensuring that IsOwner can only be true for one user per BotID.
// This approach allows multiple users with IsOwner=false for the same BotID,
// but only one user can have IsOwner=true per BotID.
err = db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_bot_owner ON users (bot_id, is_owner) WHERE is_owner = 1;`).Error
if err != nil {
return nil, fmt.Errorf("failed to create unique index for bot owners: %w", err)
}
err = createDefaultRoles(db) err = createDefaultRoles(db)
if err != nil { if err != nil {
return nil, err return nil, err

4
go.mod Executable file → Normal file
View File

@@ -1,6 +1,6 @@
module github.com/HugeFrog24/thatsky-telegram-bot module github.com/HugeFrog24/go-telegram-bot
go 1.23.2 go 1.23
require ( require (
github.com/go-telegram/bot v1.9.0 github.com/go-telegram/bot v1.9.0

0
go.sum Executable file → Normal file
View File

86
handlers.go Executable file → Normal file
View File

@@ -22,9 +22,6 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
return return
} }
chatID := message.Chat.ID
userID := message.From.ID
// Extract businessConnectionID if available // Extract businessConnectionID if available
var businessConnectionID string var businessConnectionID string
if update.BusinessConnection != nil { if update.BusinessConnection != nil {
@@ -33,6 +30,18 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
businessConnectionID = message.BusinessConnectionID businessConnectionID = message.BusinessConnectionID
} }
chatID := message.Chat.ID
userID := message.From.ID
username := message.From.Username
text := message.Text
// Pass the incoming message through the centralized screen for storage
_, err := b.screenIncomingMessage(message)
if err != nil {
log.Printf("Error storing user message: %v", err)
return
}
// Check if the message is a command // Check if the message is a command
if message.Entities != nil { if message.Entities != nil {
for _, entity := range message.Entities { for _, entity := range message.Entities {
@@ -40,7 +49,10 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
command := strings.TrimSpace(message.Text[entity.Offset : entity.Offset+entity.Length]) command := strings.TrimSpace(message.Text[entity.Offset : entity.Offset+entity.Length])
switch command { switch command {
case "/stats": case "/stats":
b.sendStats(ctx, chatID, businessConnectionID) b.sendStats(ctx, chatID, userID, username, businessConnectionID)
return
case "/whoami":
b.sendWhoAmI(ctx, chatID, userID, username, businessConnectionID)
return return
} }
} }
@@ -53,59 +65,71 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
return return
} }
// Existing rate limit and message handling // Rate limit check
if !b.checkRateLimits(userID) { if !b.checkRateLimits(userID) {
b.sendRateLimitExceededMessage(ctx, chatID, businessConnectionID) b.sendRateLimitExceededMessage(ctx, chatID, businessConnectionID)
return return
} }
username := message.From.Username
text := message.Text
// Proceed only if the message contains text // Proceed only if the message contains text
if text == "" { if text == "" {
// Optionally, handle other message types or ignore
log.Printf("Received a non-text message from user %d in chat %d", userID, chatID) log.Printf("Received a non-text message from user %d in chat %d", userID, chatID)
return return
} }
user, err := b.getOrCreateUser(userID, username) // Determine if the user is the owner
var isOwner bool
err = b.db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", userID, b.botID, true).First(&User{}).Error
if err == nil {
isOwner = true
}
user, err := b.getOrCreateUser(userID, username, isOwner)
if err != nil { if err != nil {
log.Printf("Error getting or creating user: %v", err) log.Printf("Error getting or creating user: %v", err)
return return
} }
userMessage := b.createMessage(chatID, userID, username, user.Role.Name, text, true) // Update the username if it's empty or has changed
userMessage.UserRole = string(anthropic.RoleUser) // Convert to string if user.Username != username {
b.storeMessage(userMessage) user.Username = username
if err := b.db.Save(&user).Error; err != nil {
log.Printf("Error updating user username: %v", err)
}
}
// Determine if the text contains only emojis
isEmojiOnly := isOnlyEmojis(text)
// Prepare context messages for Anthropic
chatMemory := b.getOrCreateChatMemory(chatID) chatMemory := b.getOrCreateChatMemory(chatID)
b.addMessageToChatMemory(chatMemory, userMessage) b.addMessageToChatMemory(chatMemory, b.createMessage(chatID, userID, username, user.Role.Name, text, true))
contextMessages := b.prepareContextMessages(chatMemory) contextMessages := b.prepareContextMessages(chatMemory)
isEmojiOnly := isOnlyEmojis(text) // Ensure you have this variable defined // Get response from Anthropic
response, err := b.getAnthropicResponse(ctx, contextMessages, b.isNewChat(chatID), b.isAdminOrOwner(userID), isEmojiOnly) response, err := b.getAnthropicResponse(ctx, contextMessages, b.isNewChat(chatID), isOwner, isEmojiOnly)
if err != nil { if err != nil {
log.Printf("Error getting Anthropic response: %v", err) log.Printf("Error getting Anthropic response: %v", err)
response = "I'm sorry, I'm having trouble processing your request right now." response = "I'm sorry, I'm having trouble processing your request right now."
} }
b.sendResponse(ctx, chatID, response, businessConnectionID) // Send the response through the centralized screen
if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil {
assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false) log.Printf("Error sending response: %v", err)
b.storeMessage(assistantMessage) return
b.addMessageToChatMemory(chatMemory, assistantMessage) }
} }
func (b *Bot) sendRateLimitExceededMessage(ctx context.Context, chatID int64, businessConnectionID string) { func (b *Bot) sendRateLimitExceededMessage(ctx context.Context, chatID int64, businessConnectionID string) {
b.sendResponse(ctx, chatID, "Rate limit exceeded. Please try again later.", businessConnectionID) if err := b.sendResponse(ctx, chatID, "Rate limit exceeded. Please try again later.", businessConnectionID); err != nil {
log.Printf("Error sending rate limit exceeded message: %v", err)
}
} }
func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, message *models.Message, businessConnectionID string) { func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, message *models.Message, businessConnectionID string) {
username := message.From.Username username := message.From.Username
// Create and store the sticker message // Create the user message (without storing it manually)
userMessage := b.createMessage(chatID, userID, username, "user", "Sent a sticker.", true) userMessage := b.createMessage(chatID, userID, username, "user", "Sent a sticker.", true)
userMessage.StickerFileID = message.Sticker.FileID userMessage.StickerFileID = message.Sticker.FileID
@@ -114,9 +138,7 @@ func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, me
userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID
} }
b.storeMessage(userMessage) // Update chat memory with the user message
// Update chat memory
chatMemory := b.getOrCreateChatMemory(chatID) chatMemory := b.getOrCreateChatMemory(chatID)
b.addMessageToChatMemory(chatMemory, userMessage) b.addMessageToChatMemory(chatMemory, userMessage)
@@ -134,11 +156,11 @@ func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, me
} }
} }
b.sendResponse(ctx, chatID, response, businessConnectionID) // Send the response through the centralized screen
if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil {
assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false) log.Printf("Error sending response: %v", err)
b.storeMessage(assistantMessage) return
b.addMessageToChatMemory(chatMemory, assistantMessage) }
} }
func (b *Bot) generateStickerResponse(ctx context.Context, message Message) (string, error) { func (b *Bot) generateStickerResponse(ctx context.Context, message Message) (string, error) {

26
main.go Executable file → Normal file
View File

@@ -24,9 +24,6 @@ func main() {
log.Printf("Error loading .env file: %v", err) log.Printf("Error loading .env file: %v", err)
} }
// Check for required environment variables
checkRequiredEnvVars()
// Initialize database // Initialize database
db, err := initDB() db, err := initDB()
if err != nil { if err != nil {
@@ -52,14 +49,24 @@ func main() {
go func(cfg BotConfig) { go func(cfg BotConfig) {
defer wg.Done() defer wg.Done()
// Create Bot instance with RealClock // Create Bot instance without TelegramClient initially
realClock := RealClock{} realClock := RealClock{}
bot, err := NewBot(db, cfg, realClock) bot, err := NewBot(db, cfg, realClock, nil)
if err != nil { if err != nil {
log.Printf("Error creating bot %s: %v", cfg.ID, err) log.Printf("Error creating bot %s: %v", cfg.ID, err)
return return
} }
// Initialize TelegramClient with the bot's handleUpdate method
tgClient, err := initTelegramBot(cfg.TelegramToken, bot.handleUpdate)
if err != nil {
log.Printf("Error initializing Telegram client for bot %s: %v", cfg.ID, err)
return
}
// Assign the TelegramClient to the bot
bot.tgBot = tgClient
// Start the bot // Start the bot
log.Printf("Starting bot %s...", cfg.ID) log.Printf("Starting bot %s...", cfg.ID)
bot.Start(ctx) bot.Start(ctx)
@@ -79,12 +86,3 @@ func initLogger() (*os.File, error) {
log.SetOutput(mw) log.SetOutput(mw)
return logFile, nil return logFile, nil
} }
func checkRequiredEnvVars() {
requiredEnvVars := []string{"ANTHROPIC_API_KEY"}
for _, envVar := range requiredEnvVars {
if os.Getenv(envVar) == "" {
log.Fatalf("%s environment variable is not set", envVar)
}
}
}

13
models.go Executable file → Normal file
View File

@@ -11,7 +11,7 @@ type BotModel struct {
Identifier string `gorm:"uniqueIndex"` // Renamed from ID to Identifier Identifier string `gorm:"uniqueIndex"` // Renamed from ID to Identifier
Name string Name string
Configs []ConfigModel `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` Configs []ConfigModel `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"`
Users []User `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` // Added foreign key Users []User `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` // Associated users
Messages []Message `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` Messages []Message `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"`
} }
@@ -24,6 +24,7 @@ type ConfigModel struct {
TempBanDuration string `json:"temp_ban_duration"` TempBanDuration string `json:"temp_ban_duration"`
SystemPrompts string `json:"system_prompts"` // Consider JSON string or separate table SystemPrompts string `json:"system_prompts"` // Consider JSON string or separate table
TelegramToken string `json:"telegram_token"` TelegramToken string `json:"telegram_token"`
Active bool `json:"active"`
} }
type Message struct { type Message struct {
@@ -53,9 +54,15 @@ type Role struct {
type User struct { type User struct {
gorm.Model gorm.Model
BotID uint `gorm:"index"` // Added foreign key to BotModel BotID uint `gorm:"index"` // Foreign key to BotModel
TelegramID int64 `gorm:"uniqueIndex"` // Consider composite unique index if TelegramID is unique per Bot TelegramID int64 `gorm:"uniqueIndex;not null"` // Unique per user
Username string Username string
RoleID uint RoleID uint
Role Role `gorm:"foreignKey:RoleID"` Role Role `gorm:"foreignKey:RoleID"`
IsOwner bool `gorm:"default:false"` // Indicates if the user is the owner
}
// Compound unique index to ensure only one owner per bot
func (User) TableName() string {
return "users"
} }

0
rate_limiter.go Executable file → Normal file
View File

104
rate_limiter_test.go Executable file → Normal file
View File

@@ -1,8 +1,15 @@
package main package main
import ( import (
"context"
"fmt"
"testing" "testing"
"time" "time"
"github.com/go-telegram/bot"
"github.com/go-telegram/bot/models"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
) )
// TestCheckRateLimits tests the checkRateLimits method of the Bot. // TestCheckRateLimits tests the checkRateLimits method of the Bot.
@@ -22,6 +29,7 @@ func TestCheckRateLimits(t *testing.T) {
TempBanDuration: "1m", // Temporary ban duration of 1 minute for testing TempBanDuration: "1m", // Temporary ban duration of 1 minute for testing
SystemPrompts: make(map[string]string), SystemPrompts: make(map[string]string),
TelegramToken: "YOUR_TELEGRAM_BOT_TOKEN", TelegramToken: "YOUR_TELEGRAM_BOT_TOKEN",
OwnerTelegramID: 123456789,
} }
// Initialize the Bot with mock data and MockClock // Initialize the Bot with mock data and MockClock
@@ -79,6 +87,102 @@ func TestCheckRateLimits(t *testing.T) {
} }
} }
func TestOwnerAssignment(t *testing.T) {
// Initialize in-memory database for testing
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to open in-memory database: %v", err)
}
// Migrate the schema
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{})
if err != nil {
t.Fatalf("Failed to migrate database schema: %v", err)
}
// Create default roles
err = createDefaultRoles(db)
if err != nil {
t.Fatalf("Failed to create default roles: %v", err)
}
// Create a bot configuration
config := BotConfig{
ID: "test_bot",
TelegramToken: "TEST_TELEGRAM_TOKEN",
MemorySize: 10,
MessagePerHour: 5,
MessagePerDay: 10,
TempBanDuration: "1m",
SystemPrompts: make(map[string]string),
Active: true,
OwnerTelegramID: 111111111,
}
// Initialize MockClock
mockClock := &MockClock{
currentTime: time.Now(),
}
// Initialize MockTelegramClient
mockTGClient := &MockTelegramClient{
SendMessageFunc: func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
chatID, ok := params.ChatID.(int64)
if !ok {
return nil, fmt.Errorf("ChatID is not of type int64")
}
// Simulate successful message sending
return &models.Message{ID: 1, Chat: models.Chat{ID: chatID}}, nil
},
}
// Create the bot with the mock Telegram client
bot, err := NewBot(db, config, mockClock, mockTGClient)
if err != nil {
t.Fatalf("Failed to create bot: %v", err)
}
// Verify that the owner exists
var owner User
err = db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", config.OwnerTelegramID, bot.botID, true).First(&owner).Error
if err != nil {
t.Fatalf("Owner was not created: %v", err)
}
// Attempt to create another owner for the same bot
_, err = bot.getOrCreateUser(222222222, "AnotherOwner", true)
if err == nil {
t.Fatalf("Expected error when creating a second owner, but got none")
}
// Verify that the error message is appropriate
expectedErrorMsg := "an owner already exists for this bot"
if err.Error() != expectedErrorMsg {
t.Fatalf("Unexpected error message: %v", err)
}
// Assign admin role to a new user
adminUser, err := bot.getOrCreateUser(333333333, "AdminUser", false)
if err != nil {
t.Fatalf("Failed to create admin user: %v", err)
}
if adminUser.Role.Name != "admin" {
t.Fatalf("Expected role 'admin', got '%s'", adminUser.Role.Name)
}
// Attempt to change an existing user to owner
_, err = bot.getOrCreateUser(333333333, "AdminUser", true)
if err == nil {
t.Fatalf("Expected error when changing existing user to owner, but got none")
}
expectedErrorMsg = "cannot change existing user to owner"
if err.Error() != expectedErrorMsg {
t.Fatalf("Unexpected error message: %v", err)
}
}
// To ensure thread safety and avoid race conditions during testing, // To ensure thread safety and avoid race conditions during testing,
// you can run the tests with the `-race` flag: // you can run the tests with the `-race` flag:
// go test -race -v // go test -race -v

16
telegram_client.go Normal file
View File

@@ -0,0 +1,16 @@
// telegram_client.go
package main
import (
"context"
"github.com/go-telegram/bot"
"github.com/go-telegram/bot/models"
)
// TelegramClient defines the methods required from the Telegram bot.
type TelegramClient interface {
SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error)
Start(ctx context.Context)
// Add other methods if needed.
}

35
telegram_client_mock.go Normal file
View File

@@ -0,0 +1,35 @@
// telegram_client_mock.go
package main
import (
"context"
"github.com/go-telegram/bot"
"github.com/go-telegram/bot/models"
)
// MockTelegramClient is a mock implementation of TelegramClient for testing.
type MockTelegramClient struct {
// You can add fields to keep track of calls if needed.
SendMessageFunc func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error)
StartFunc func(ctx context.Context) // Optional: track Start calls
}
// SendMessage mocks sending a message.
func (m *MockTelegramClient) SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
if m.SendMessageFunc != nil {
return m.SendMessageFunc(ctx, params)
}
// Default behavior: return an empty message without error.
return &models.Message{}, nil
}
// Start mocks starting the Telegram client.
func (m *MockTelegramClient) Start(ctx context.Context) {
if m.StartFunc != nil {
m.StartFunc(ctx)
}
// Default behavior: do nothing.
}
// Add other mocked methods if your Bot uses more TelegramClient methods.