mirror of
https://github.com/HugeFrog24/go-telegram-bot.git
synced 2026-03-02 00:14:34 +00:00
MVP
md formatting doesnt work yet Started implementing owner feature Add .gitattributes to enforce LF line endings Temporary commit before merge Updated owner management Updated json and gitignore Proceed with role management Again, CI Fix some lint errors Implemented screening Per-bot API keys implemented Use getRoleByName func Fix unused imports Upgrade actions rm unused function Upgrade action Fix unaddressed errors
This commit is contained in:
13
.gitattributes
vendored
Normal file
13
.gitattributes
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Enforce LF line endings for all files
|
||||||
|
* text eol=lf
|
||||||
|
|
||||||
|
# Specific file types that should always have LF line endings
|
||||||
|
*.go text eol=lf
|
||||||
|
*.json text eol=lf
|
||||||
|
*.sh text eol=lf
|
||||||
|
*.md text eol=lf
|
||||||
|
|
||||||
|
# Example: Binary files should not be modified
|
||||||
|
*.jpg binary
|
||||||
|
*.png binary
|
||||||
|
*.gif binary
|
||||||
54
.github/workflows/go-ci.yaml
vendored
Normal file
54
.github/workflows/go-ci.yaml
vendored
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
# Checkout the repository
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
# Set up Go environment
|
||||||
|
- name: Set up Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.23' # Specify the Go version you are using
|
||||||
|
|
||||||
|
# Cache Go modules
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.cache/go-build
|
||||||
|
~/go/pkg/mod
|
||||||
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
|
# Install Dependencies
|
||||||
|
- name: Install Dependencies
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
# Run Linters using golangci-lint
|
||||||
|
- name: Lint Code
|
||||||
|
uses: golangci/golangci-lint-action@v6
|
||||||
|
with:
|
||||||
|
version: v1.60 # Specify the version of golangci-lint
|
||||||
|
args: --timeout 5m
|
||||||
|
|
||||||
|
# Run Tests
|
||||||
|
- name: Run Tests
|
||||||
|
run: go test ./... -v
|
||||||
|
|
||||||
|
# Security Analysis using gosec
|
||||||
|
- name: Security Scan
|
||||||
|
uses: securego/gosec@master
|
||||||
|
with:
|
||||||
|
args: ./...
|
||||||
4
.gitignore
vendored
Executable file → Normal file
4
.gitignore
vendored
Executable file → Normal file
@@ -4,8 +4,8 @@ vendor/
|
|||||||
# Environment variables
|
# Environment variables
|
||||||
.env
|
.env
|
||||||
|
|
||||||
# Log file
|
# Any log files
|
||||||
bot.log
|
*.log
|
||||||
|
|
||||||
# Database file
|
# Database file
|
||||||
bot.db
|
bot.db
|
||||||
|
|||||||
0
anthropic.go
Executable file → Normal file
0
anthropic.go
Executable file → Normal file
221
bot.go
Executable file → Normal file
221
bot.go
Executable file → Normal file
@@ -5,7 +5,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -17,7 +16,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Bot struct {
|
type Bot struct {
|
||||||
tgBot *bot.Bot
|
tgBot TelegramClient
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
anthropicClient *anthropic.Client
|
anthropicClient *anthropic.Client
|
||||||
chatMemories map[int64]*ChatMemory
|
chatMemories map[int64]*ChatMemory
|
||||||
@@ -30,7 +29,8 @@ type Bot struct {
|
|||||||
botID uint // Reference to BotModel.ID
|
botID uint // Reference to BotModel.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) {
|
// NewBot initializes and returns a new Bot instance.
|
||||||
|
func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient) (*Bot, error) {
|
||||||
// Retrieve or create Bot entry in the database
|
// Retrieve or create Bot entry in the database
|
||||||
var botEntry BotModel
|
var botEntry BotModel
|
||||||
err := db.Where("identifier = ?", config.ID).First(&botEntry).Error
|
err := db.Where("identifier = ?", config.ID).First(&botEntry).Error
|
||||||
@@ -43,7 +43,38 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
anthropicClient := anthropic.NewClient(os.Getenv("ANTHROPIC_API_KEY"))
|
// Ensure the owner exists in the Users table
|
||||||
|
var owner User
|
||||||
|
err = db.Where("telegram_id = ? AND bot_id = ?", config.OwnerTelegramID, botEntry.ID).First(&owner).Error
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
// Assign the "owner" role
|
||||||
|
var ownerRole Role
|
||||||
|
err := db.Where("name = ?", "owner").First(&ownerRole).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("owner role not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
owner = User{
|
||||||
|
BotID: botEntry.ID,
|
||||||
|
TelegramID: config.OwnerTelegramID,
|
||||||
|
Username: "", // Initialize as empty; will be updated upon interaction
|
||||||
|
RoleID: ownerRole.ID,
|
||||||
|
IsOwner: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Create(&owner).Error; err != nil {
|
||||||
|
// If unique constraint is violated, another owner already exists
|
||||||
|
if strings.Contains(err.Error(), "unique index") {
|
||||||
|
return nil, fmt.Errorf("an owner already exists for this bot")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to create owner user: %w", err)
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the per-bot Anthropic API key
|
||||||
|
anthropicClient := anthropic.NewClient(config.AnthropicAPIKey)
|
||||||
|
|
||||||
b := &Bot{
|
b := &Bot{
|
||||||
db: db,
|
db: db,
|
||||||
@@ -54,41 +85,80 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock) (*Bot, error) {
|
|||||||
userLimiters: make(map[int64]*userLimiter),
|
userLimiters: make(map[int64]*userLimiter),
|
||||||
clock: clock,
|
clock: clock,
|
||||||
botID: botEntry.ID, // Ensure BotModel has ID field
|
botID: botEntry.ID, // Ensure BotModel has ID field
|
||||||
|
tgBot: tgClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
tgBot, err := initTelegramBot(config.TelegramToken, b.handleUpdate)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
b.tgBot = tgBot
|
|
||||||
|
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Start begins the bot's operation.
|
||||||
func (b *Bot) Start(ctx context.Context) {
|
func (b *Bot) Start(ctx context.Context) {
|
||||||
b.tgBot.Start(ctx)
|
b.tgBot.Start(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bot) getOrCreateUser(userID int64, username string) (User, error) {
|
func (b *Bot) getOrCreateUser(userID int64, username string, isOwner bool) (User, error) {
|
||||||
var user User
|
var user User
|
||||||
err := b.db.Preload("Role").Where("telegram_id = ?", userID).First(&user).Error
|
err := b.db.Preload("Role").Where("telegram_id = ? AND bot_id = ?", userID, b.botID).First(&user).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
var defaultRole Role
|
// Check if an owner already exists for this bot
|
||||||
if err := b.db.Where("name = ?", "user").First(&defaultRole).Error; err != nil {
|
if isOwner {
|
||||||
return User{}, err
|
var existingOwner User
|
||||||
|
err := b.db.Where("bot_id = ? AND is_owner = ?", b.botID, true).First(&existingOwner).Error
|
||||||
|
if err == nil {
|
||||||
|
return User{}, fmt.Errorf("an owner already exists for this bot")
|
||||||
|
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return User{}, fmt.Errorf("failed to check existing owner: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
user = User{TelegramID: userID, Username: username, RoleID: defaultRole.ID}
|
|
||||||
|
var role Role
|
||||||
|
var roleName string
|
||||||
|
if isOwner {
|
||||||
|
roleName = "owner"
|
||||||
|
} else {
|
||||||
|
roleName = "user" // Assign "user" role to non-owner users
|
||||||
|
}
|
||||||
|
|
||||||
|
err := b.db.Where("name = ?", roleName).First(&role).Error
|
||||||
|
if err != nil {
|
||||||
|
return User{}, fmt.Errorf("failed to get role: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user = User{
|
||||||
|
BotID: b.botID,
|
||||||
|
TelegramID: userID,
|
||||||
|
Username: username,
|
||||||
|
RoleID: role.ID,
|
||||||
|
Role: role,
|
||||||
|
IsOwner: isOwner,
|
||||||
|
}
|
||||||
|
|
||||||
if err := b.db.Create(&user).Error; err != nil {
|
if err := b.db.Create(&user).Error; err != nil {
|
||||||
return User{}, err
|
// If unique constraint is violated, another owner already exists
|
||||||
|
if strings.Contains(err.Error(), "unique index") {
|
||||||
|
return User{}, fmt.Errorf("an owner already exists for this bot")
|
||||||
|
}
|
||||||
|
return User{}, fmt.Errorf("failed to create user: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return User{}, err
|
return User{}, err
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
if isOwner && !user.IsOwner {
|
||||||
|
return User{}, fmt.Errorf("cannot change existing user to owner")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *Bot) getRoleByName(roleName string) (Role, error) {
|
||||||
|
var role Role
|
||||||
|
err := b.db.Where("name = ?", roleName).First(&role).Error
|
||||||
|
return role, err
|
||||||
|
}
|
||||||
|
|
||||||
func (b *Bot) createMessage(chatID, userID int64, username, userRole, text string, isUser bool) Message {
|
func (b *Bot) createMessage(chatID, userID int64, username, userRole, text string, isUser bool) Message {
|
||||||
message := Message{
|
message := Message{
|
||||||
ChatID: chatID,
|
ChatID: chatID,
|
||||||
@@ -195,17 +265,28 @@ func (b *Bot) isAdminOrOwner(userID int64) bool {
|
|||||||
return user.Role.Name == "admin" || user.Role.Name == "owner"
|
return user.Role.Name == "admin" || user.Role.Name == "owner"
|
||||||
}
|
}
|
||||||
|
|
||||||
func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (*bot.Bot, error) {
|
func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (TelegramClient, error) {
|
||||||
opts := []bot.Option{
|
opts := []bot.Option{
|
||||||
bot.WithDefaultHandler(handleUpdate),
|
bot.WithDefaultHandler(handleUpdate),
|
||||||
}
|
}
|
||||||
|
|
||||||
return bot.New(token, opts...)
|
tgBot, err := bot.New(token, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tgBot, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendResponse sends a message to the specified chat.
|
|
||||||
// Returns an error if sending the message fails.
|
|
||||||
func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, businessConnectionID string) error {
|
func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, businessConnectionID string) error {
|
||||||
|
// Pass the outgoing message through the centralized screen for storage
|
||||||
|
_, err := b.screenOutgoingMessage(chatID, text, businessConnectionID)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error storing assistant message: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare message parameters
|
||||||
params := &bot.SendMessageParams{
|
params := &bot.SendMessageParams{
|
||||||
ChatID: chatID,
|
ChatID: chatID,
|
||||||
Text: text,
|
Text: text,
|
||||||
@@ -215,7 +296,8 @@ func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, busin
|
|||||||
params.BusinessConnectionID = businessConnectionID
|
params.BusinessConnectionID = businessConnectionID
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := b.tgBot.SendMessage(ctx, params)
|
// Send the message via Telegram client
|
||||||
|
_, err = b.tgBot.SendMessage(ctx, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[%s] [ERROR] Error sending message to chat %d with BusinessConnectionID %s: %v",
|
log.Printf("[%s] [ERROR] Error sending message to chat %d with BusinessConnectionID %s: %v",
|
||||||
b.config.ID, chatID, businessConnectionID, err)
|
b.config.ID, chatID, businessConnectionID, err)
|
||||||
@@ -225,16 +307,29 @@ func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, busin
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sendStats sends the bot statistics to the specified chat.
|
// sendStats sends the bot statistics to the specified chat.
|
||||||
func (b *Bot) sendStats(ctx context.Context, chatID int64, businessConnectionID string) {
|
func (b *Bot) sendStats(ctx context.Context, chatID int64, userID int64, username string, businessConnectionID string) {
|
||||||
totalUsers, totalMessages, err := b.getStats()
|
totalUsers, totalMessages, err := b.getStats()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Error fetching stats: %v\n", err)
|
fmt.Printf("Error fetching stats: %v\n", err)
|
||||||
b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve the stats at this time.", businessConnectionID)
|
if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve the stats at this time.", businessConnectionID); err != nil {
|
||||||
|
log.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
statsMessage := fmt.Sprintf("📊 **Bot Statistics:**\n\n- Total Users: %d\n- Total Messages: %d", totalUsers, totalMessages)
|
// Do NOT manually escape hyphens here
|
||||||
b.sendResponse(ctx, chatID, statsMessage, businessConnectionID)
|
statsMessage := fmt.Sprintf(
|
||||||
|
"📊 Bot Statistics:\n\n"+
|
||||||
|
"- Total Users: %d\n"+
|
||||||
|
"- Total Messages: %d",
|
||||||
|
totalUsers,
|
||||||
|
totalMessages,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Send the response through the centralized screen
|
||||||
|
if err := b.sendResponse(ctx, chatID, statsMessage, businessConnectionID); err != nil {
|
||||||
|
log.Printf("Error sending stats message: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getStats retrieves the total number of users and messages from the database.
|
// getStats retrieves the total number of users and messages from the database.
|
||||||
@@ -271,3 +366,77 @@ func isEmoji(r rune) bool {
|
|||||||
(r >= 0x2600 && r <= 0x26FF) || // Misc symbols
|
(r >= 0x2600 && r <= 0x26FF) || // Misc symbols
|
||||||
(r >= 0x2700 && r <= 0x27BF) // Dingbats
|
(r >= 0x2700 && r <= 0x27BF) // Dingbats
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *Bot) sendWhoAmI(ctx context.Context, chatID int64, userID int64, username string, businessConnectionID string) {
|
||||||
|
user, err := b.getOrCreateUser(userID, username, false)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error getting or creating user: %v", err)
|
||||||
|
if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve your information.", businessConnectionID); err != nil {
|
||||||
|
log.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
role, err := b.getRoleByName(user.Role.Name)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error getting role by name: %v", err)
|
||||||
|
if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve your role information.", businessConnectionID); err != nil {
|
||||||
|
log.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
whoAmIMessage := fmt.Sprintf(
|
||||||
|
"👤 Your Information:\n\n"+
|
||||||
|
"- Username: %s\n"+
|
||||||
|
"- Role: %s",
|
||||||
|
user.Username,
|
||||||
|
role.Name,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Send the response through the centralized screen
|
||||||
|
if err := b.sendResponse(ctx, chatID, whoAmIMessage, businessConnectionID); err != nil {
|
||||||
|
log.Printf("Error sending /whoami message: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// screenIncomingMessage handles storing of incoming messages.
|
||||||
|
func (b *Bot) screenIncomingMessage(message *models.Message) (Message, error) {
|
||||||
|
userRole := string(anthropic.RoleUser) // Convert RoleUser to string
|
||||||
|
userMessage := b.createMessage(message.Chat.ID, message.From.ID, message.From.Username, userRole, message.Text, true)
|
||||||
|
|
||||||
|
// If the message contains a sticker, include its details.
|
||||||
|
if message.Sticker != nil {
|
||||||
|
userMessage.StickerFileID = message.Sticker.FileID
|
||||||
|
if message.Sticker.Thumbnail != nil {
|
||||||
|
userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the message.
|
||||||
|
if err := b.storeMessage(userMessage); err != nil {
|
||||||
|
return Message{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update chat memory.
|
||||||
|
chatMemory := b.getOrCreateChatMemory(message.Chat.ID)
|
||||||
|
b.addMessageToChatMemory(chatMemory, userMessage)
|
||||||
|
|
||||||
|
return userMessage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// screenOutgoingMessage handles storing of outgoing messages.
|
||||||
|
func (b *Bot) screenOutgoingMessage(chatID int64, response string, businessConnectionID string) (Message, error) {
|
||||||
|
assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false)
|
||||||
|
|
||||||
|
// Store the message.
|
||||||
|
if err := b.storeMessage(assistantMessage); err != nil {
|
||||||
|
return Message{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update chat memory.
|
||||||
|
chatMemory := b.getOrCreateChatMemory(chatID)
|
||||||
|
b.addMessageToChatMemory(chatMemory, assistantMessage)
|
||||||
|
|
||||||
|
return assistantMessage, nil
|
||||||
|
}
|
||||||
|
|||||||
9
config.go
Executable file → Normal file
9
config.go
Executable file → Normal file
@@ -18,6 +18,9 @@ type BotConfig struct {
|
|||||||
TempBanDuration string `json:"temp_ban_duration"`
|
TempBanDuration string `json:"temp_ban_duration"`
|
||||||
Model anthropic.Model `json:"model"` // Changed from string to anthropic.Model
|
Model anthropic.Model `json:"model"` // Changed from string to anthropic.Model
|
||||||
SystemPrompts map[string]string `json:"system_prompts"`
|
SystemPrompts map[string]string `json:"system_prompts"`
|
||||||
|
Active bool `json:"active"` // New field to control bot activity
|
||||||
|
OwnerTelegramID int64 `json:"owner_telegram_id"`
|
||||||
|
AnthropicAPIKey string `json:"anthropic_api_key"` // Add this line
|
||||||
}
|
}
|
||||||
|
|
||||||
// Custom unmarshalling to handle anthropic.Model
|
// Custom unmarshalling to handle anthropic.Model
|
||||||
@@ -56,6 +59,12 @@ func loadAllConfigs(dir string) ([]BotConfig, error) {
|
|||||||
return nil, fmt.Errorf("failed to load config %s: %w", configPath, err)
|
return nil, fmt.Errorf("failed to load config %s: %w", configPath, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip inactive bots
|
||||||
|
if !config.Active {
|
||||||
|
fmt.Printf("Skipping inactive bot: %s\n", config.ID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Validate that ID is present
|
// Validate that ID is present
|
||||||
if config.ID == "" {
|
if config.ID == "" {
|
||||||
return nil, fmt.Errorf("config %s is missing 'id' field", configPath)
|
return nil, fmt.Errorf("config %s is missing 'id' field", configPath)
|
||||||
|
|||||||
3
config/default.json
Executable file → Normal file
3
config/default.json
Executable file → Normal file
@@ -1,6 +1,9 @@
|
|||||||
{
|
{
|
||||||
"id": "default_bot",
|
"id": "default_bot",
|
||||||
|
"active": false,
|
||||||
"telegram_token": "YOUR_TELEGRAM_BOT_TOKEN",
|
"telegram_token": "YOUR_TELEGRAM_BOT_TOKEN",
|
||||||
|
"owner_telegram_id": 111111111,
|
||||||
|
"anthropic_api_key": "YOUR_SPECIFIC_ANTHROPIC_API_KEY",
|
||||||
"memory_size": 10,
|
"memory_size": 10,
|
||||||
"messages_per_hour": 20,
|
"messages_per_hour": 20,
|
||||||
"messages_per_day": 100,
|
"messages_per_day": 100,
|
||||||
|
|||||||
11
database.go
Executable file → Normal file
11
database.go
Executable file → Normal file
@@ -27,11 +27,22 @@ func initDB() (*gorm.DB, error) {
|
|||||||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AutoMigrate the models
|
||||||
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{})
|
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to migrate database schema: %w", err)
|
return nil, fmt.Errorf("failed to migrate database schema: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enforce unique owner per bot using raw SQL
|
||||||
|
// Note: SQLite doesn't support partial indexes, but we can simulate it by making a unique index on (BotID, IsOwner)
|
||||||
|
// and ensuring that IsOwner can only be true for one user per BotID.
|
||||||
|
// This approach allows multiple users with IsOwner=false for the same BotID,
|
||||||
|
// but only one user can have IsOwner=true per BotID.
|
||||||
|
err = db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_bot_owner ON users (bot_id, is_owner) WHERE is_owner = 1;`).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create unique index for bot owners: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
err = createDefaultRoles(db)
|
err = createDefaultRoles(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
4
go.mod
Executable file → Normal file
4
go.mod
Executable file → Normal file
@@ -1,6 +1,6 @@
|
|||||||
module github.com/HugeFrog24/thatsky-telegram-bot
|
module github.com/HugeFrog24/go-telegram-bot
|
||||||
|
|
||||||
go 1.23.2
|
go 1.23
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/go-telegram/bot v1.9.0
|
github.com/go-telegram/bot v1.9.0
|
||||||
|
|||||||
86
handlers.go
Executable file → Normal file
86
handlers.go
Executable file → Normal file
@@ -22,9 +22,6 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
chatID := message.Chat.ID
|
|
||||||
userID := message.From.ID
|
|
||||||
|
|
||||||
// Extract businessConnectionID if available
|
// Extract businessConnectionID if available
|
||||||
var businessConnectionID string
|
var businessConnectionID string
|
||||||
if update.BusinessConnection != nil {
|
if update.BusinessConnection != nil {
|
||||||
@@ -33,6 +30,18 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
|
|||||||
businessConnectionID = message.BusinessConnectionID
|
businessConnectionID = message.BusinessConnectionID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
chatID := message.Chat.ID
|
||||||
|
userID := message.From.ID
|
||||||
|
username := message.From.Username
|
||||||
|
text := message.Text
|
||||||
|
|
||||||
|
// Pass the incoming message through the centralized screen for storage
|
||||||
|
_, err := b.screenIncomingMessage(message)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error storing user message: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Check if the message is a command
|
// Check if the message is a command
|
||||||
if message.Entities != nil {
|
if message.Entities != nil {
|
||||||
for _, entity := range message.Entities {
|
for _, entity := range message.Entities {
|
||||||
@@ -40,7 +49,10 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
|
|||||||
command := strings.TrimSpace(message.Text[entity.Offset : entity.Offset+entity.Length])
|
command := strings.TrimSpace(message.Text[entity.Offset : entity.Offset+entity.Length])
|
||||||
switch command {
|
switch command {
|
||||||
case "/stats":
|
case "/stats":
|
||||||
b.sendStats(ctx, chatID, businessConnectionID)
|
b.sendStats(ctx, chatID, userID, username, businessConnectionID)
|
||||||
|
return
|
||||||
|
case "/whoami":
|
||||||
|
b.sendWhoAmI(ctx, chatID, userID, username, businessConnectionID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -53,59 +65,71 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Existing rate limit and message handling
|
// Rate limit check
|
||||||
if !b.checkRateLimits(userID) {
|
if !b.checkRateLimits(userID) {
|
||||||
b.sendRateLimitExceededMessage(ctx, chatID, businessConnectionID)
|
b.sendRateLimitExceededMessage(ctx, chatID, businessConnectionID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
username := message.From.Username
|
|
||||||
text := message.Text
|
|
||||||
|
|
||||||
// Proceed only if the message contains text
|
// Proceed only if the message contains text
|
||||||
if text == "" {
|
if text == "" {
|
||||||
// Optionally, handle other message types or ignore
|
|
||||||
log.Printf("Received a non-text message from user %d in chat %d", userID, chatID)
|
log.Printf("Received a non-text message from user %d in chat %d", userID, chatID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := b.getOrCreateUser(userID, username)
|
// Determine if the user is the owner
|
||||||
|
var isOwner bool
|
||||||
|
err = b.db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", userID, b.botID, true).First(&User{}).Error
|
||||||
|
if err == nil {
|
||||||
|
isOwner = true
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := b.getOrCreateUser(userID, username, isOwner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error getting or creating user: %v", err)
|
log.Printf("Error getting or creating user: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userMessage := b.createMessage(chatID, userID, username, user.Role.Name, text, true)
|
// Update the username if it's empty or has changed
|
||||||
userMessage.UserRole = string(anthropic.RoleUser) // Convert to string
|
if user.Username != username {
|
||||||
b.storeMessage(userMessage)
|
user.Username = username
|
||||||
|
if err := b.db.Save(&user).Error; err != nil {
|
||||||
|
log.Printf("Error updating user username: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine if the text contains only emojis
|
||||||
|
isEmojiOnly := isOnlyEmojis(text)
|
||||||
|
|
||||||
|
// Prepare context messages for Anthropic
|
||||||
chatMemory := b.getOrCreateChatMemory(chatID)
|
chatMemory := b.getOrCreateChatMemory(chatID)
|
||||||
b.addMessageToChatMemory(chatMemory, userMessage)
|
b.addMessageToChatMemory(chatMemory, b.createMessage(chatID, userID, username, user.Role.Name, text, true))
|
||||||
|
|
||||||
contextMessages := b.prepareContextMessages(chatMemory)
|
contextMessages := b.prepareContextMessages(chatMemory)
|
||||||
|
|
||||||
isEmojiOnly := isOnlyEmojis(text) // Ensure you have this variable defined
|
// Get response from Anthropic
|
||||||
response, err := b.getAnthropicResponse(ctx, contextMessages, b.isNewChat(chatID), b.isAdminOrOwner(userID), isEmojiOnly)
|
response, err := b.getAnthropicResponse(ctx, contextMessages, b.isNewChat(chatID), isOwner, isEmojiOnly)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error getting Anthropic response: %v", err)
|
log.Printf("Error getting Anthropic response: %v", err)
|
||||||
response = "I'm sorry, I'm having trouble processing your request right now."
|
response = "I'm sorry, I'm having trouble processing your request right now."
|
||||||
}
|
}
|
||||||
|
|
||||||
b.sendResponse(ctx, chatID, response, businessConnectionID)
|
// Send the response through the centralized screen
|
||||||
|
if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil {
|
||||||
assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false)
|
log.Printf("Error sending response: %v", err)
|
||||||
b.storeMessage(assistantMessage)
|
return
|
||||||
b.addMessageToChatMemory(chatMemory, assistantMessage)
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bot) sendRateLimitExceededMessage(ctx context.Context, chatID int64, businessConnectionID string) {
|
func (b *Bot) sendRateLimitExceededMessage(ctx context.Context, chatID int64, businessConnectionID string) {
|
||||||
b.sendResponse(ctx, chatID, "Rate limit exceeded. Please try again later.", businessConnectionID)
|
if err := b.sendResponse(ctx, chatID, "Rate limit exceeded. Please try again later.", businessConnectionID); err != nil {
|
||||||
|
log.Printf("Error sending rate limit exceeded message: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, message *models.Message, businessConnectionID string) {
|
func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, message *models.Message, businessConnectionID string) {
|
||||||
username := message.From.Username
|
username := message.From.Username
|
||||||
|
|
||||||
// Create and store the sticker message
|
// Create the user message (without storing it manually)
|
||||||
userMessage := b.createMessage(chatID, userID, username, "user", "Sent a sticker.", true)
|
userMessage := b.createMessage(chatID, userID, username, "user", "Sent a sticker.", true)
|
||||||
userMessage.StickerFileID = message.Sticker.FileID
|
userMessage.StickerFileID = message.Sticker.FileID
|
||||||
|
|
||||||
@@ -114,9 +138,7 @@ func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, me
|
|||||||
userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID
|
userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID
|
||||||
}
|
}
|
||||||
|
|
||||||
b.storeMessage(userMessage)
|
// Update chat memory with the user message
|
||||||
|
|
||||||
// Update chat memory
|
|
||||||
chatMemory := b.getOrCreateChatMemory(chatID)
|
chatMemory := b.getOrCreateChatMemory(chatID)
|
||||||
b.addMessageToChatMemory(chatMemory, userMessage)
|
b.addMessageToChatMemory(chatMemory, userMessage)
|
||||||
|
|
||||||
@@ -134,11 +156,11 @@ func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, me
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b.sendResponse(ctx, chatID, response, businessConnectionID)
|
// Send the response through the centralized screen
|
||||||
|
if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil {
|
||||||
assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false)
|
log.Printf("Error sending response: %v", err)
|
||||||
b.storeMessage(assistantMessage)
|
return
|
||||||
b.addMessageToChatMemory(chatMemory, assistantMessage)
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bot) generateStickerResponse(ctx context.Context, message Message) (string, error) {
|
func (b *Bot) generateStickerResponse(ctx context.Context, message Message) (string, error) {
|
||||||
|
|||||||
26
main.go
Executable file → Normal file
26
main.go
Executable file → Normal file
@@ -24,9 +24,6 @@ func main() {
|
|||||||
log.Printf("Error loading .env file: %v", err)
|
log.Printf("Error loading .env file: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for required environment variables
|
|
||||||
checkRequiredEnvVars()
|
|
||||||
|
|
||||||
// Initialize database
|
// Initialize database
|
||||||
db, err := initDB()
|
db, err := initDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -52,14 +49,24 @@ func main() {
|
|||||||
go func(cfg BotConfig) {
|
go func(cfg BotConfig) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
// Create Bot instance with RealClock
|
// Create Bot instance without TelegramClient initially
|
||||||
realClock := RealClock{}
|
realClock := RealClock{}
|
||||||
bot, err := NewBot(db, cfg, realClock)
|
bot, err := NewBot(db, cfg, realClock, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error creating bot %s: %v", cfg.ID, err)
|
log.Printf("Error creating bot %s: %v", cfg.ID, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize TelegramClient with the bot's handleUpdate method
|
||||||
|
tgClient, err := initTelegramBot(cfg.TelegramToken, bot.handleUpdate)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error initializing Telegram client for bot %s: %v", cfg.ID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assign the TelegramClient to the bot
|
||||||
|
bot.tgBot = tgClient
|
||||||
|
|
||||||
// Start the bot
|
// Start the bot
|
||||||
log.Printf("Starting bot %s...", cfg.ID)
|
log.Printf("Starting bot %s...", cfg.ID)
|
||||||
bot.Start(ctx)
|
bot.Start(ctx)
|
||||||
@@ -79,12 +86,3 @@ func initLogger() (*os.File, error) {
|
|||||||
log.SetOutput(mw)
|
log.SetOutput(mw)
|
||||||
return logFile, nil
|
return logFile, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkRequiredEnvVars() {
|
|
||||||
requiredEnvVars := []string{"ANTHROPIC_API_KEY"}
|
|
||||||
for _, envVar := range requiredEnvVars {
|
|
||||||
if os.Getenv(envVar) == "" {
|
|
||||||
log.Fatalf("%s environment variable is not set", envVar)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
13
models.go
Executable file → Normal file
13
models.go
Executable file → Normal file
@@ -11,7 +11,7 @@ type BotModel struct {
|
|||||||
Identifier string `gorm:"uniqueIndex"` // Renamed from ID to Identifier
|
Identifier string `gorm:"uniqueIndex"` // Renamed from ID to Identifier
|
||||||
Name string
|
Name string
|
||||||
Configs []ConfigModel `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"`
|
Configs []ConfigModel `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"`
|
||||||
Users []User `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` // Added foreign key
|
Users []User `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"` // Associated users
|
||||||
Messages []Message `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"`
|
Messages []Message `gorm:"foreignKey:BotID;constraint:OnDelete:CASCADE"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -24,6 +24,7 @@ type ConfigModel struct {
|
|||||||
TempBanDuration string `json:"temp_ban_duration"`
|
TempBanDuration string `json:"temp_ban_duration"`
|
||||||
SystemPrompts string `json:"system_prompts"` // Consider JSON string or separate table
|
SystemPrompts string `json:"system_prompts"` // Consider JSON string or separate table
|
||||||
TelegramToken string `json:"telegram_token"`
|
TelegramToken string `json:"telegram_token"`
|
||||||
|
Active bool `json:"active"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
@@ -53,9 +54,15 @@ type Role struct {
|
|||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
BotID uint `gorm:"index"` // Added foreign key to BotModel
|
BotID uint `gorm:"index"` // Foreign key to BotModel
|
||||||
TelegramID int64 `gorm:"uniqueIndex"` // Consider composite unique index if TelegramID is unique per Bot
|
TelegramID int64 `gorm:"uniqueIndex;not null"` // Unique per user
|
||||||
Username string
|
Username string
|
||||||
RoleID uint
|
RoleID uint
|
||||||
Role Role `gorm:"foreignKey:RoleID"`
|
Role Role `gorm:"foreignKey:RoleID"`
|
||||||
|
IsOwner bool `gorm:"default:false"` // Indicates if the user is the owner
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compound unique index to ensure only one owner per bot
|
||||||
|
func (User) TableName() string {
|
||||||
|
return "users"
|
||||||
}
|
}
|
||||||
|
|||||||
0
rate_limiter.go
Executable file → Normal file
0
rate_limiter.go
Executable file → Normal file
104
rate_limiter_test.go
Executable file → Normal file
104
rate_limiter_test.go
Executable file → Normal file
@@ -1,8 +1,15 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-telegram/bot"
|
||||||
|
"github.com/go-telegram/bot/models"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestCheckRateLimits tests the checkRateLimits method of the Bot.
|
// TestCheckRateLimits tests the checkRateLimits method of the Bot.
|
||||||
@@ -22,6 +29,7 @@ func TestCheckRateLimits(t *testing.T) {
|
|||||||
TempBanDuration: "1m", // Temporary ban duration of 1 minute for testing
|
TempBanDuration: "1m", // Temporary ban duration of 1 minute for testing
|
||||||
SystemPrompts: make(map[string]string),
|
SystemPrompts: make(map[string]string),
|
||||||
TelegramToken: "YOUR_TELEGRAM_BOT_TOKEN",
|
TelegramToken: "YOUR_TELEGRAM_BOT_TOKEN",
|
||||||
|
OwnerTelegramID: 123456789,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the Bot with mock data and MockClock
|
// Initialize the Bot with mock data and MockClock
|
||||||
@@ -79,6 +87,102 @@ func TestCheckRateLimits(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOwnerAssignment(t *testing.T) {
|
||||||
|
// Initialize in-memory database for testing
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to open in-memory database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Migrate the schema
|
||||||
|
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to migrate database schema: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create default roles
|
||||||
|
err = createDefaultRoles(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create default roles: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a bot configuration
|
||||||
|
config := BotConfig{
|
||||||
|
ID: "test_bot",
|
||||||
|
TelegramToken: "TEST_TELEGRAM_TOKEN",
|
||||||
|
MemorySize: 10,
|
||||||
|
MessagePerHour: 5,
|
||||||
|
MessagePerDay: 10,
|
||||||
|
TempBanDuration: "1m",
|
||||||
|
SystemPrompts: make(map[string]string),
|
||||||
|
Active: true,
|
||||||
|
OwnerTelegramID: 111111111,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize MockClock
|
||||||
|
mockClock := &MockClock{
|
||||||
|
currentTime: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize MockTelegramClient
|
||||||
|
mockTGClient := &MockTelegramClient{
|
||||||
|
SendMessageFunc: func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
|
||||||
|
chatID, ok := params.ChatID.(int64)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("ChatID is not of type int64")
|
||||||
|
}
|
||||||
|
// Simulate successful message sending
|
||||||
|
return &models.Message{ID: 1, Chat: models.Chat{ID: chatID}}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the bot with the mock Telegram client
|
||||||
|
bot, err := NewBot(db, config, mockClock, mockTGClient)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create bot: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the owner exists
|
||||||
|
var owner User
|
||||||
|
err = db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", config.OwnerTelegramID, bot.botID, true).First(&owner).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Owner was not created: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to create another owner for the same bot
|
||||||
|
_, err = bot.getOrCreateUser(222222222, "AnotherOwner", true)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Expected error when creating a second owner, but got none")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the error message is appropriate
|
||||||
|
expectedErrorMsg := "an owner already exists for this bot"
|
||||||
|
if err.Error() != expectedErrorMsg {
|
||||||
|
t.Fatalf("Unexpected error message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assign admin role to a new user
|
||||||
|
adminUser, err := bot.getOrCreateUser(333333333, "AdminUser", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create admin user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if adminUser.Role.Name != "admin" {
|
||||||
|
t.Fatalf("Expected role 'admin', got '%s'", adminUser.Role.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to change an existing user to owner
|
||||||
|
_, err = bot.getOrCreateUser(333333333, "AdminUser", true)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Expected error when changing existing user to owner, but got none")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedErrorMsg = "cannot change existing user to owner"
|
||||||
|
if err.Error() != expectedErrorMsg {
|
||||||
|
t.Fatalf("Unexpected error message: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// To ensure thread safety and avoid race conditions during testing,
|
// To ensure thread safety and avoid race conditions during testing,
|
||||||
// you can run the tests with the `-race` flag:
|
// you can run the tests with the `-race` flag:
|
||||||
// go test -race -v
|
// go test -race -v
|
||||||
|
|||||||
16
telegram_client.go
Normal file
16
telegram_client.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// telegram_client.go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/go-telegram/bot"
|
||||||
|
"github.com/go-telegram/bot/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TelegramClient defines the methods required from the Telegram bot.
|
||||||
|
type TelegramClient interface {
|
||||||
|
SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error)
|
||||||
|
Start(ctx context.Context)
|
||||||
|
// Add other methods if needed.
|
||||||
|
}
|
||||||
35
telegram_client_mock.go
Normal file
35
telegram_client_mock.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
// telegram_client_mock.go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/go-telegram/bot"
|
||||||
|
"github.com/go-telegram/bot/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockTelegramClient is a mock implementation of TelegramClient for testing.
|
||||||
|
type MockTelegramClient struct {
|
||||||
|
// You can add fields to keep track of calls if needed.
|
||||||
|
SendMessageFunc func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error)
|
||||||
|
StartFunc func(ctx context.Context) // Optional: track Start calls
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMessage mocks sending a message.
|
||||||
|
func (m *MockTelegramClient) SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
|
||||||
|
if m.SendMessageFunc != nil {
|
||||||
|
return m.SendMessageFunc(ctx, params)
|
||||||
|
}
|
||||||
|
// Default behavior: return an empty message without error.
|
||||||
|
return &models.Message{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start mocks starting the Telegram client.
|
||||||
|
func (m *MockTelegramClient) Start(ctx context.Context) {
|
||||||
|
if m.StartFunc != nil {
|
||||||
|
m.StartFunc(ctx)
|
||||||
|
}
|
||||||
|
// Default behavior: do nothing.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add other mocked methods if your Bot uses more TelegramClient methods.
|
||||||
Reference in New Issue
Block a user