md formatting doesnt work yet

Started implementing owner feature

Add .gitattributes to enforce LF line endings

Temporary commit before merge

Updated owner management

Updated json and gitignore

Proceed with role management

Again, CI

Fix some lint errors

Implemented screening

Per-bot API keys implemented

Use getRoleByName func

Fix unused imports

Upgrade actions

rm unused function

Upgrade action

Fix unaddressed errors
This commit is contained in:
HugeFrog24
2024-10-20 17:17:21 +02:00
parent e5532df7f9
commit c8af457af1
18 changed files with 520 additions and 79 deletions

13
.gitattributes vendored Normal file
View File

@@ -0,0 +1,13 @@
# Enforce LF line endings for all files
* text eol=lf
# Specific file types that should always have LF line endings
*.go text eol=lf
*.json text eol=lf
*.sh text eol=lf
*.md text eol=lf
# Example: Binary files should not be modified
*.jpg binary
*.png binary
*.gif binary

54
.github/workflows/go-ci.yaml vendored Normal file
View File

@@ -0,0 +1,54 @@
name: CI
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
runs-on: ubuntu-latest
steps:
# Checkout the repository
- name: Checkout code
uses: actions/checkout@v4
# Set up Go environment
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: '1.23' # Specify the Go version you are using
# Cache Go modules
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
# Install Dependencies
- name: Install Dependencies
run: go mod tidy
# Run Linters using golangci-lint
- name: Lint Code
uses: golangci/golangci-lint-action@v6
with:
version: v1.60 # Specify the version of golangci-lint
args: --timeout 5m
# Run Tests
- name: Run Tests
run: go test ./... -v
# Security Analysis using gosec
- name: Security Scan
uses: securego/gosec@master
with:
args: ./...

4
.gitignore vendored Executable file → Normal file
View File

@@ -4,8 +4,8 @@ vendor/
# Environment variables
.env
# Log file
bot.log
# Any log files
*.log
# Database file
bot.db

0
anthropic.go Executable file → Normal file
View File

221
bot.go Executable file → Normal file
View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"log"
"os"
"strings"
"sync"
"time"
@@ -17,7 +16,7 @@ import (
)
type Bot struct {
tgBot *bot.Bot
tgBot TelegramClient
db *gorm.DB
anthropicClient *anthropic.Client
chatMemories map[int64]*ChatMemory
@@ -30,7 +29,8 @@ type Bot struct {
botID uint // Reference to BotModel.ID
}
func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) {
// NewBot initializes and returns a new Bot instance.
func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient) (*Bot, error) {
// Retrieve or create Bot entry in the database
var botEntry BotModel
err := db.Where("identifier = ?", config.ID).First(&botEntry).Error
@@ -43,7 +43,38 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) {
return nil, err
}
anthropicClient := anthropic.NewClient(os.Getenv("ANTHROPIC_API_KEY"))
// Ensure the owner exists in the Users table
var owner User
err = db.Where("telegram_id = ? AND bot_id = ?", config.OwnerTelegramID, botEntry.ID).First(&owner).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
// Assign the "owner" role
var ownerRole Role
err := db.Where("name = ?", "owner").First(&ownerRole).Error
if err != nil {
return nil, fmt.Errorf("owner role not found: %w", err)
}
owner = User{
BotID: botEntry.ID,
TelegramID: config.OwnerTelegramID,
Username: "", // Initialize as empty; will be updated upon interaction
RoleID: ownerRole.ID,
IsOwner: true,
}
if err := db.Create(&owner).Error; err != nil {
// If unique constraint is violated, another owner already exists
if strings.Contains(err.Error(), "unique index") {
return nil, fmt.Errorf("an owner already exists for this bot")
}
return nil, fmt.Errorf("failed to create owner user: %w", err)
}
} else if err != nil {
return nil, err
}
// Use the per-bot Anthropic API key
anthropicClient := anthropic.NewClient(config.AnthropicAPIKey)
b := &Bot{
db: db,
@@ -54,41 +85,80 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) {
userLimiters: make(map[int64]*userLimiter),
clock: clock,
botID: botEntry.ID, // Ensure BotModel has ID field
tgBot: tgClient,
}
tgBot, err := initTelegramBot(config.TelegramToken, b.handleUpdate)
if err != nil {
return nil, err
}
b.tgBot = tgBot
return b, nil
}
// Start begins the bot's operation.
func (b *Bot) Start(ctx context.Context) {
b.tgBot.Start(ctx)
}
func (b *Bot) getOrCreateUser(userID int64, username string) (User, error) {
func (b *Bot) getOrCreateUser(userID int64, username string, isOwner bool) (User, error) {
var user User
err := b.db.Preload("Role").Where("telegram_id = ?", userID).First(&user).Error
err := b.db.Preload("Role").Where("telegram_id = ? AND bot_id = ?", userID, b.botID).First(&user).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
var defaultRole Role
if err := b.db.Where("name = ?", "user").First(&defaultRole).Error; err != nil {
return User{}, err
// Check if an owner already exists for this bot
if isOwner {
var existingOwner User
err := b.db.Where("bot_id = ? AND is_owner = ?", b.botID, true).First(&existingOwner).Error
if err == nil {
return User{}, fmt.Errorf("an owner already exists for this bot")
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
return User{}, fmt.Errorf("failed to check existing owner: %w", err)
}
}
user = User{TelegramID: userID, Username: username, RoleID: defaultRole.ID}
var role Role
var roleName string
if isOwner {
roleName = "owner"
} else {
roleName = "user" // Assign "user" role to non-owner users
}
err := b.db.Where("name = ?", roleName).First(&role).Error
if err != nil {
return User{}, fmt.Errorf("failed to get role: %w", err)
}
user = User{
BotID: b.botID,
TelegramID: userID,
Username: username,
RoleID: role.ID,
Role: role,
IsOwner: isOwner,
}
if err := b.db.Create(&user).Error; err != nil {
return User{}, err
// If unique constraint is violated, another owner already exists
if strings.Contains(err.Error(), "unique index") {
return User{}, fmt.Errorf("an owner already exists for this bot")
}
return User{}, fmt.Errorf("failed to create user: %w", err)
}
} else {
return User{}, err
}
} else {
if isOwner && !user.IsOwner {
return User{}, fmt.Errorf("cannot change existing user to owner")
}
}
return user, nil
}
func (b *Bot) getRoleByName(roleName string) (Role, error) {
var role Role
err := b.db.Where("name = ?", roleName).First(&role).Error
return role, err
}
func (b *Bot) createMessage(chatID, userID int64, username, userRole, text string, isUser bool) Message {
message := Message{
ChatID: chatID,
@@ -195,17 +265,28 @@ func (b *Bot) isAdminOrOwner(userID int64) bool {
return user.Role.Name == "admin" || user.Role.Name == "owner"
}
func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (*bot.Bot, error) {
func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (TelegramClient, error) {
opts := []bot.Option{
bot.WithDefaultHandler(handleUpdate),
}
return bot.New(token, opts...)
tgBot, err := bot.New(token, opts...)
if err != nil {
return nil, err
}
return tgBot, nil
}
// sendResponse sends a message to the specified chat.
// Returns an error if sending the message fails.
func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, businessConnectionID string) error {
// Pass the outgoing message through the centralized screen for storage
_, err := b.screenOutgoingMessage(chatID, text, businessConnectionID)
if err != nil {
log.Printf("Error storing assistant message: %v", err)
return err
}
// Prepare message parameters
params := &bot.SendMessageParams{
ChatID: chatID,
Text: text,
@@ -215,7 +296,8 @@ func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, busin
params.BusinessConnectionID = businessConnectionID
}
_, err := b.tgBot.SendMessage(ctx, params)
// Send the message via Telegram client
_, err = b.tgBot.SendMessage(ctx, params)
if err != nil {
log.Printf("[%s] [ERROR] Error sending message to chat %d with BusinessConnectionID %s: %v",
b.config.ID, chatID, businessConnectionID, err)
@@ -225,16 +307,29 @@ func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, busin
}
// sendStats sends the bot statistics to the specified chat.
func (b *Bot) sendStats(ctx context.Context, chatID int64, businessConnectionID string) {
func (b *Bot) sendStats(ctx context.Context, chatID int64, userID int64, username string, businessConnectionID string) {
totalUsers, totalMessages, err := b.getStats()
if err != nil {
fmt.Printf("Error fetching stats: %v\n", err)
b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve the stats at this time.", businessConnectionID)
if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve the stats at this time.", businessConnectionID); err != nil {
log.Printf("Error sending response: %v", err)
}
return
}
statsMessage := fmt.Sprintf("📊 **Bot Statistics:**\n\n- Total Users: %d\n- Total Messages: %d", totalUsers, totalMessages)
b.sendResponse(ctx, chatID, statsMessage, businessConnectionID)
// Do NOT manually escape hyphens here
statsMessage := fmt.Sprintf(
"📊 Bot Statistics:\n\n"+
"- Total Users: %d\n"+
"- Total Messages: %d",
totalUsers,
totalMessages,
)
// Send the response through the centralized screen
if err := b.sendResponse(ctx, chatID, statsMessage, businessConnectionID); err != nil {
log.Printf("Error sending stats message: %v", err)
}
}
// getStats retrieves the total number of users and messages from the database.
@@ -271,3 +366,77 @@ func isEmoji(r rune) bool {
(r >= 0x2600 && r <= 0x26FF) || // Misc symbols
(r >= 0x2700 && r <= 0x27BF) // Dingbats
}
func (b *Bot) sendWhoAmI(ctx context.Context, chatID int64, userID int64, username string, businessConnectionID string) {
user, err := b.getOrCreateUser(userID, username, false)
if err != nil {
log.Printf("Error getting or creating user: %v", err)
if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve your information.", businessConnectionID); err != nil {
log.Printf("Error sending response: %v", err)
}
return
}
role, err := b.getRoleByName(user.Role.Name)
if err != nil {
log.Printf("Error getting role by name: %v", err)
if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve your role information.", businessConnectionID); err != nil {
log.Printf("Error sending response: %v", err)
}
return
}
whoAmIMessage := fmt.Sprintf(
"👤 Your Information:\n\n"+
"- Username: %s\n"+
"- Role: %s",
user.Username,
role.Name,
)
// Send the response through the centralized screen
if err := b.sendResponse(ctx, chatID, whoAmIMessage, businessConnectionID); err != nil {
log.Printf("Error sending /whoami message: %v", err)
}
}
// screenIncomingMessage handles storing of incoming messages.
func (b *Bot) screenIncomingMessage(message *models.Message) (Message, error) {
userRole := string(anthropic.RoleUser) // Convert RoleUser to string
userMessage := b.createMessage(message.Chat.ID, message.From.ID, message.From.Username, userRole, message.Text, true)
// If the message contains a sticker, include its details.
if message.Sticker != nil {
userMessage.StickerFileID = message.Sticker.FileID
if message.Sticker.Thumbnail != nil {
userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID
}
}
// Store the message.
if err := b.storeMessage(userMessage); err != nil {
return Message{}, err
}
// Update chat memory.
chatMemory := b.getOrCreateChatMemory(message.Chat.ID)
b.addMessageToChatMemory(chatMemory, userMessage)
return userMessage, nil
}
// screenOutgoingMessage handles storing of outgoing messages.
func (b *Bot) screenOutgoingMessage(chatID int64, response string, businessConnectionID string) (Message, error) {
assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false)
// Store the message.
if err := b.storeMessage(assistantMessage); err != nil {
return Message{}, err
}
// Update chat memory.
chatMemory := b.getOrCreateChatMemory(chatID)
b.addMessageToChatMemory(chatMemory, assistantMessage)
return assistantMessage, nil
}

0
clock.go Executable file → Normal file
View File

9
config.go Executable file → Normal file
View File

@@ -18,6 +18,9 @@ type BotConfig struct {
TempBanDuration string `json:"temp_ban_duration"`
Model anthropic.Model `json:"model"` // Changed from string to anthropic.Model
SystemPrompts map[string]string `json:"system_prompts"`
Active bool `json:"active"` // New field to control bot activity
OwnerTelegramID int64 `json:"owner_telegram_id"`
AnthropicAPIKey string `json:"anthropic_api_key"` // Add this line
}
// Custom unmarshalling to handle anthropic.Model
@@ -56,6 +59,12 @@ func loadAllConfigs(dir string) ([]BotConfig, error) {
return nil, fmt.Errorf("failed to load config %s: %w", configPath, err)
}
// Skip inactive bots
if !config.Active {
fmt.Printf("Skipping inactive bot: %s\n", config.ID)
continue
}
// Validate that ID is present
if config.ID == "" {
return nil, fmt.Errorf("config %s is missing 'id' field", configPath)

3
config/default.json Executable file → Normal file
View File

@@ -1,6 +1,9 @@
{
"id": "default_bot",
"active": false,
"telegram_token": "YOUR_TELEGRAM_BOT_TOKEN",
"owner_telegram_id": 111111111,
"anthropic_api_key": "YOUR_SPECIFIC_ANTHROPIC_API_KEY",
"memory_size": 10,
"messages_per_hour": 20,
"messages_per_day": 100,

11
database.go Executable file → Normal file
View File

@@ -27,11 +27,22 @@ func initDB() (*gorm.DB, error) {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// AutoMigrate the models
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{})
if err != nil {
return nil, fmt.Errorf("failed to migrate database schema: %w", err)
}
// Enforce unique owner per bot using raw SQL
// Note: SQLite doesn't support partial indexes, but we can simulate it by making a unique index on (BotID, IsOwner)
// and ensuring that IsOwner can only be true for one user per BotID.
// This approach allows multiple users with IsOwner=false for the same BotID,
// but only one user can have IsOwner=true per BotID.
err = db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_bot_owner ON users (bot_id, is_owner) WHERE is_owner = 1;`).Error
if err != nil {
return nil, fmt.Errorf("failed to create unique index for bot owners: %w", err)
}
err = createDefaultRoles(db)
if err != nil {
return nil, err

4
go.mod Executable file → Normal file
View File

@@ -1,6 +1,6 @@
module github.com/HugeFrog24/thatsky-telegram-bot
module github.com/HugeFrog24/go-telegram-bot
go 1.23.2
go 1.23
require (
github.com/go-telegram/bot v1.9.0

0
go.sum Executable file → Normal file
View File

86
handlers.go Executable file → Normal file
View File

@@ -22,9 +22,6 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
return
}
chatID := message.Chat.ID
userID := message.From.ID
// Extract businessConnectionID if available
var businessConnectionID string
if update.BusinessConnection != nil {
@@ -33,6 +30,18 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
businessConnectionID = message.BusinessConnectionID
}
chatID := message.Chat.ID
userID := message.From.ID
username := message.From.Username
text := message.Text
// Pass the incoming message through the centralized screen for storage
_, err := b.screenIncomingMessage(message)
if err != nil {
log.Printf("Error storing user message: %v", err)
return
}
// Check if the message is a command
if message.Entities != nil {
for _, entity := range message.Entities {
@@ -40,7 +49,10 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
command := strings.TrimSpace(message.Text[entity.Offset : entity.Offset+entity.Length])
switch command {
case "/stats":
b.sendStats(ctx, chatID, businessConnectionID)
b.sendStats(ctx, chatID, userID, username, businessConnectionID)
return
case "/whoami":
b.sendWhoAmI(ctx, chatID, userID, username, businessConnectionID)
return
}
}
@@ -53,59 +65,71 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
return
}
// Existing rate limit and message handling
// Rate limit check
if !b.checkRateLimits(userID) {
b.sendRateLimitExceededMessage(ctx, chatID, businessConnectionID)
return
}
username := message.From.Username
text := message.Text
// Proceed only if the message contains text
if text == "" {
// Optionally, handle other message types or ignore
log.Printf("Received a non-text message from user %d in chat %d", userID, chatID)
return
}
user, err := b.getOrCreateUser(userID, username)
// Determine if the user is the owner
var isOwner bool
err = b.db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", userID, b.botID, true).First(&User{}).Error
if err == nil {
isOwner = true
}
user, err := b.getOrCreateUser(userID, username, isOwner)
if err != nil {
log.Printf("Error getting or creating user: %v", err)
return
}
userMessage := b.createMessage(chatID, userID, username, user.Role.Name, text, true)
userMessage.UserRole = string(anthropic.RoleUser) // Convert to string
b.storeMessage(userMessage)
// Update the username if it's empty or has changed
if user.Username != username {
user.Username = username
if err := b.db.Save(&user).Error; err != nil {
log.Printf("Error updating user username: %v", err)
}
}
// Determine if the text contains only emojis
isEmojiOnly := isOnlyEmojis(text)
// Prepare context messages for Anthropic
chatMemory := b.getOrCreateChatMemory(chatID)
b.addMessageToChatMemory(chatMemory, userMessage)
b.addMessageToChatMemory(chatMemory, b.createMessage(chatID, userID, username, user.Role.Name, text, true))
contextMessages := b.prepareContextMessages(chatMemory)
isEmojiOnly := isOnlyEmojis(text) // Ensure you have this variable defined
response, err := b.getAnthropicResponse(ctx, contextMessages, b.isNewChat(chatID), b.isAdminOrOwner(userID), isEmojiOnly)
// Get response from Anthropic
response, err := b.getAnthropicResponse(ctx, contextMessages, b.isNewChat(chatID), isOwner, isEmojiOnly)
if err != nil {
log.Printf("Error getting Anthropic response: %v", err)
response = "I'm sorry, I'm having trouble processing your request right now."
}
b.sendResponse(ctx, chatID, response, businessConnectionID)
assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false)
b.storeMessage(assistantMessage)
b.addMessageToChatMemory(chatMemory, assistantMessage)
// Send the response through the centralized screen
if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil {
log.Printf("Error sending response: %v", err)
return
}
}
func (b *Bot) sendRateLimitExceededMessage(ctx context.Context, chatID int64, businessConnectionID string) {
b.sendResponse(ctx, chatID, "Rate limit exceeded. Please try again later.", businessConnectionID)
if err := b.sendResponse(ctx, chatID, "Rate limit exceeded. Please try again later.", businessConnectionID); err != nil {
log.Printf("Error sending rate limit exceeded message: %v", err)
}
}
func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, message *models.Message, businessConnectionID string) {
username := message.From.Username
// Create and store the sticker message
// Create the user message (without storing it manually)
userMessage := b.createMessage(chatID, userID, username, "user", "Sent a sticker.", true)
userMessage.StickerFileID = message.Sticker.FileID
@@ -114,9 +138,7 @@ func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, me
userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID
}
b.storeMessage(userMessage)
// Update chat memory
// Update chat memory with the user message
chatMemory := b.getOrCreateChatMemory(chatID)
b.addMessageToChatMemory(chatMemory, userMessage)
@@ -134,11 +156,11 @@ func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, me
}
}
b.sendResponse(ctx, chatID, response, businessConnectionID)
assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false)
b.storeMessage(assistantMessage)
b.addMessageToChatMemory(chatMemory, assistantMessage)
// Send the response through the centralized screen
if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil {
log.Printf("Error sending response: %v", err)
return
}
}
func (b *Bot) generateStickerResponse(ctx context.Context, message Message) (string, error) {

26
main.go Executable file → Normal file
View File

@@ -24,9 +24,6 @@ func main() {
log.Printf("Error loading .env file: %v", err)
}
// Check for required environment variables
checkRequiredEnvVars()
// Initialize database
db, err := initDB()
if err != nil {
@@ -52,14 +49,24 @@ func main() {
go func(cfg BotConfig) {
defer wg.Done()
// Create Bot instance with RealClock
// Create Bot instance without TelegramClient initially
realClock := RealClock{}
bot, err := NewBot(db, cfg, realClock)
bot, err := NewBot(db, cfg, realClock, nil)
if err != nil {
log.Printf("Error creating bot %s: %v", cfg.ID, err)
return
}
// Initialize TelegramClient with the bot's handleUpdate method
tgClient, err := initTelegramBot(cfg.TelegramToken, bot.handleUpdate)
if err != nil {
log.Printf("Error initializing Telegram client for bot %s: %v", cfg.ID, err)
return
}
// Assign the TelegramClient to the bot
bot.tgBot = tgClient
// Start the bot
log.Printf("Starting bot %s...", cfg.ID)
bot.Start(ctx)
@@ -79,12 +86,3 @@ func initLogger() (*os.File, error) {
log.SetOutput(mw)
return logFile, nil
}
func checkRequiredEnvVars() {
requiredEnvVars := []string{"ANTHROPIC_API_KEY"}
for _, envVar := range requiredEnvVars {
if os.Getenv(envVar) == "" {
log.Fatalf("%s environment variable is not set", envVar)
}
}
}

13
models.go Executable file → Normal file
View File

@@ -11,7 +11,7 @@ type BotModel struct {
Identifier string `gorm:"uniqueIndex"` // Renamed from ID to Identifier
Name string
Configs []ConfigModel `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"`
Users []User `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` // Added foreign key
Users []User `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` // Associated users
Messages []Message `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"`
}
@@ -24,6 +24,7 @@ type ConfigModel struct {
TempBanDuration string `json:"temp_ban_duration"`
SystemPrompts string `json:"system_prompts"` // Consider JSON string or separate table
TelegramToken string `json:"telegram_token"`
Active bool `json:"active"`
}
type Message struct {
@@ -53,9 +54,15 @@ type Role struct {
type User struct {
gorm.Model
BotID uint `gorm:"index"` // Added foreign key to BotModel
TelegramID int64 `gorm:"uniqueIndex"` // Consider composite unique index if TelegramID is unique per Bot
BotID uint `gorm:"index"` // Foreign key to BotModel
TelegramID int64 `gorm:"uniqueIndex;not null"` // Unique per user
Username string
RoleID uint
Role Role `gorm:"foreignKey:RoleID"`
IsOwner bool `gorm:"default:false"` // Indicates if the user is the owner
}
// Compound unique index to ensure only one owner per bot
func (User) TableName() string {
return "users"
}

0
rate_limiter.go Executable file → Normal file
View File

104
rate_limiter_test.go Executable file → Normal file
View File

@@ -1,8 +1,15 @@
package main
import (
"context"
"fmt"
"testing"
"time"
"github.com/go-telegram/bot"
"github.com/go-telegram/bot/models"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// TestCheckRateLimits tests the checkRateLimits method of the Bot.
@@ -22,6 +29,7 @@ func TestCheckRateLimits(t *testing.T) {
TempBanDuration: "1m", // Temporary ban duration of 1 minute for testing
SystemPrompts: make(map[string]string),
TelegramToken: "YOUR_TELEGRAM_BOT_TOKEN",
OwnerTelegramID: 123456789,
}
// Initialize the Bot with mock data and MockClock
@@ -79,6 +87,102 @@ func TestCheckRateLimits(t *testing.T) {
}
}
func TestOwnerAssignment(t *testing.T) {
// Initialize in-memory database for testing
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to open in-memory database: %v", err)
}
// Migrate the schema
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{})
if err != nil {
t.Fatalf("Failed to migrate database schema: %v", err)
}
// Create default roles
err = createDefaultRoles(db)
if err != nil {
t.Fatalf("Failed to create default roles: %v", err)
}
// Create a bot configuration
config := BotConfig{
ID: "test_bot",
TelegramToken: "TEST_TELEGRAM_TOKEN",
MemorySize: 10,
MessagePerHour: 5,
MessagePerDay: 10,
TempBanDuration: "1m",
SystemPrompts: make(map[string]string),
Active: true,
OwnerTelegramID: 111111111,
}
// Initialize MockClock
mockClock := &MockClock{
currentTime: time.Now(),
}
// Initialize MockTelegramClient
mockTGClient := &MockTelegramClient{
SendMessageFunc: func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
chatID, ok := params.ChatID.(int64)
if !ok {
return nil, fmt.Errorf("ChatID is not of type int64")
}
// Simulate successful message sending
return &models.Message{ID: 1, Chat: models.Chat{ID: chatID}}, nil
},
}
// Create the bot with the mock Telegram client
bot, err := NewBot(db, config, mockClock, mockTGClient)
if err != nil {
t.Fatalf("Failed to create bot: %v", err)
}
// Verify that the owner exists
var owner User
err = db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", config.OwnerTelegramID, bot.botID, true).First(&owner).Error
if err != nil {
t.Fatalf("Owner was not created: %v", err)
}
// Attempt to create another owner for the same bot
_, err = bot.getOrCreateUser(222222222, "AnotherOwner", true)
if err == nil {
t.Fatalf("Expected error when creating a second owner, but got none")
}
// Verify that the error message is appropriate
expectedErrorMsg := "an owner already exists for this bot"
if err.Error() != expectedErrorMsg {
t.Fatalf("Unexpected error message: %v", err)
}
// Assign admin role to a new user
adminUser, err := bot.getOrCreateUser(333333333, "AdminUser", false)
if err != nil {
t.Fatalf("Failed to create admin user: %v", err)
}
if adminUser.Role.Name != "admin" {
t.Fatalf("Expected role 'admin', got '%s'", adminUser.Role.Name)
}
// Attempt to change an existing user to owner
_, err = bot.getOrCreateUser(333333333, "AdminUser", true)
if err == nil {
t.Fatalf("Expected error when changing existing user to owner, but got none")
}
expectedErrorMsg = "cannot change existing user to owner"
if err.Error() != expectedErrorMsg {
t.Fatalf("Unexpected error message: %v", err)
}
}
// To ensure thread safety and avoid race conditions during testing,
// you can run the tests with the `-race` flag:
// go test -race -v

16
telegram_client.go Normal file
View File

@@ -0,0 +1,16 @@
// telegram_client.go
package main
import (
"context"
"github.com/go-telegram/bot"
"github.com/go-telegram/bot/models"
)
// TelegramClient defines the methods required from the Telegram bot.
type TelegramClient interface {
SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error)
Start(ctx context.Context)
// Add other methods if needed.
}

35
telegram_client_mock.go Normal file
View File

@@ -0,0 +1,35 @@
// telegram_client_mock.go
package main
import (
"context"
"github.com/go-telegram/bot"
"github.com/go-telegram/bot/models"
)
// MockTelegramClient is a mock implementation of TelegramClient for testing.
type MockTelegramClient struct {
// You can add fields to keep track of calls if needed.
SendMessageFunc func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error)
StartFunc func(ctx context.Context) // Optional: track Start calls
}
// SendMessage mocks sending a message.
func (m *MockTelegramClient) SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
if m.SendMessageFunc != nil {
return m.SendMessageFunc(ctx, params)
}
// Default behavior: return an empty message without error.
return &models.Message{}, nil
}
// Start mocks starting the Telegram client.
func (m *MockTelegramClient) Start(ctx context.Context) {
if m.StartFunc != nil {
m.StartFunc(ctx)
}
// Default behavior: do nothing.
}
// Add other mocked methods if your Bot uses more TelegramClient methods.