mirror of
https://github.com/HugeFrog24/go-telegram-bot.git
synced 2026-05-01 07:42:18 +00:00
Compare commits
3 Commits
v1.0.0
..
e1a9261699
| Author | SHA1 | Date | |
|---|---|---|---|
| e1a9261699 | |||
| 6e2d2fce2f | |||
| 37d6242c06 |
@@ -123,7 +123,6 @@ journalctl -u telegram-bot -f
|
||||
| `/clear_hard` | All users | Permanently delete your own chat history |
|
||||
| `/clear_hard <user_id>` | Admin/Owner | Permanently delete all messages for a user across every chat |
|
||||
| `/clear_hard <user_id> <chat_id>` | Admin/Owner | Permanently delete a user's messages in a specific chat |
|
||||
| `/set_model <model-id>` | Admin/Owner | Switch the AI model live without restarting |
|
||||
|
||||
> **Note:** In private DMs each user's `chat_id` equals their `user_id`. The scoped `<chat_id>` form is mainly useful for group chat moderation.
|
||||
|
||||
|
||||
+2
-12
@@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -10,12 +9,7 @@ import (
|
||||
"github.com/liushuangls/go-anthropic/v2"
|
||||
)
|
||||
|
||||
// ErrModelNotFound is returned when the configured Anthropic model is no longer available
|
||||
// (deprecated or removed). Callers can use errors.Is to detect this and surface an
|
||||
// actionable message to admins/owners while keeping the response vague for regular users.
|
||||
var ErrModelNotFound = errors.New("model not found or deprecated")
|
||||
|
||||
func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Message, isNewChat, isOwner, isEmojiOnly bool, username string, firstName string, lastName string, isPremium bool, languageCode string, messageTime int) (string, error) {
|
||||
func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Message, isNewChat, isAdminOrOwner, isEmojiOnly bool, username string, firstName string, lastName string, isPremium bool, languageCode string, messageTime int) (string, error) {
|
||||
// Use prompts from config
|
||||
var systemMessage string
|
||||
if isNewChat {
|
||||
@@ -77,7 +71,7 @@ func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Mes
|
||||
}
|
||||
systemMessage = strings.ReplaceAll(systemMessage, "{time_context}", timeContext)
|
||||
|
||||
if !isOwner {
|
||||
if !isAdminOrOwner {
|
||||
systemMessage += " " + b.config.SystemPrompts["avoid_sensitive"]
|
||||
}
|
||||
|
||||
@@ -125,10 +119,6 @@ func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Mes
|
||||
|
||||
resp, err := b.anthropicClient.CreateMessages(ctx, request)
|
||||
if err != nil {
|
||||
var apiErr *anthropic.APIError
|
||||
if errors.As(err, &apiErr) && apiErr.IsNotFoundErr() {
|
||||
return "", fmt.Errorf("%w: %s", ErrModelNotFound, b.config.Model)
|
||||
}
|
||||
return "", fmt.Errorf("error creating Anthropic message: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -221,18 +221,13 @@ func (b *Bot) getOrCreateChatMemory(chatID int64) *ChatMemory {
|
||||
if !isNewChat {
|
||||
// Fetch existing messages only if it's not a new chat
|
||||
err := b.db.Where("chat_id = ? AND bot_id = ?", chatID, b.botID).
|
||||
Order("timestamp desc").
|
||||
Order("timestamp asc").
|
||||
Limit(b.memorySize * 2).
|
||||
Find(&messages).Error
|
||||
|
||||
if err != nil {
|
||||
ErrorLogger.Printf("Error fetching messages from database: %v", err)
|
||||
messages = []Message{} // Initialize an empty slice on error
|
||||
} else {
|
||||
// Reverse from newest-first to chronological order for conversation context.
|
||||
for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
|
||||
messages[i], messages[j] = messages[j], messages[i]
|
||||
}
|
||||
}
|
||||
} else {
|
||||
messages = []Message{} // Ensure messages is initialized for new chats
|
||||
@@ -308,82 +303,13 @@ func (b *Bot) isNewChat(chatID int64) bool {
|
||||
return count == 0 // Only consider a chat new if it has 0 messages
|
||||
}
|
||||
|
||||
// roleHasScope reports whether role (with pre-loaded Scopes) contains the given scope name.
|
||||
func roleHasScope(role Role, scope string) bool {
|
||||
for _, s := range role.Scopes {
|
||||
if s.Name == scope {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hasScope reports whether the user identified by userID holds the given scope for this bot.
|
||||
// Owners implicitly hold all scopes regardless of their assigned role.
|
||||
func (b *Bot) hasScope(userID int64, scope string) bool {
|
||||
func (b *Bot) isAdminOrOwner(userID int64) bool {
|
||||
var user User
|
||||
if err := b.db.Preload("Role.Scopes").
|
||||
Where("telegram_id = ? AND bot_id = ?", userID, b.botID).
|
||||
First(&user).Error; err != nil {
|
||||
err := b.db.Preload("Role").Where("telegram_id = ? AND bot_id = ?", userID, b.botID).First(&user).Error
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if user.IsOwner {
|
||||
return true
|
||||
}
|
||||
return roleHasScope(user.Role, scope)
|
||||
}
|
||||
|
||||
// publicBotCommands are shown to every user in the Telegram command palette.
|
||||
var publicBotCommands = []models.BotCommand{
|
||||
{Command: "stats", Description: "Get bot statistics. Usage: /stats or /stats user [user_id]"},
|
||||
{Command: "whoami", Description: "Get your user information"},
|
||||
{Command: "clear", Description: "Clear chat history (soft delete). Admins: /clear [user_id]"},
|
||||
}
|
||||
|
||||
// adminBotCommands are shown only in admin/owner chats via BotCommandScopeChatMember.
|
||||
var adminBotCommands = []models.BotCommand{
|
||||
{Command: "clear_hard", Description: "Clear chat history (permanently delete). Admins: /clear_hard [user_id]"},
|
||||
{Command: "set_model", Description: "Switch the AI model (admin/owner only). Usage: /set_model <model-id>"},
|
||||
}
|
||||
|
||||
// registerAdminCommandsForUser scopes the full command palette to a specific user's private chat.
|
||||
// In Telegram private chats, chat_id == user_id, so both fields carry the same value.
|
||||
// Errors are logged but treated as non-fatal: the user retains access via permission checks.
|
||||
func (b *Bot) registerAdminCommandsForUser(ctx context.Context, telegramID int64) {
|
||||
allCommands := make([]models.BotCommand, 0, len(publicBotCommands)+len(adminBotCommands))
|
||||
allCommands = append(allCommands, publicBotCommands...)
|
||||
allCommands = append(allCommands, adminBotCommands...)
|
||||
_, err := b.tgBot.SetMyCommands(ctx, &bot.SetMyCommandsParams{
|
||||
Commands: allCommands,
|
||||
Scope: &models.BotCommandScopeChatMember{ChatID: telegramID, UserID: telegramID},
|
||||
})
|
||||
if err != nil {
|
||||
ErrorLogger.Printf("Failed to register admin commands for user %d: %v", telegramID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// setElevatedCommands registers the full command palette (public + admin) for every user
|
||||
// whose role carries the model:set scope, or who is the bot owner. Called once at startup
|
||||
// and uses the freshly created tgBot directly (b.tgBot is not yet assigned at that point).
|
||||
func setElevatedCommands(tgBot TelegramClient, users []User) {
|
||||
allCommands := make([]models.BotCommand, 0, len(publicBotCommands)+len(adminBotCommands))
|
||||
allCommands = append(allCommands, publicBotCommands...)
|
||||
allCommands = append(allCommands, adminBotCommands...)
|
||||
for _, u := range users {
|
||||
if u.TelegramID == 0 {
|
||||
continue // skip placeholder users not yet seen in a chat
|
||||
}
|
||||
if !u.IsOwner && !roleHasScope(u.Role, ScopeModelSet) {
|
||||
continue
|
||||
}
|
||||
_, err := tgBot.SetMyCommands(context.Background(), &bot.SetMyCommandsParams{
|
||||
Commands: allCommands,
|
||||
Scope: &models.BotCommandScopeChatMember{ChatID: u.TelegramID, UserID: u.TelegramID},
|
||||
})
|
||||
if err != nil {
|
||||
ErrorLogger.Printf("Warning: could not set admin commands for user %d: %v", u.TelegramID, err)
|
||||
}
|
||||
}
|
||||
return user.Role.Name == "admin" || user.Role.Name == "owner"
|
||||
}
|
||||
|
||||
func initTelegramBot(token string, b *Bot) (TelegramClient, error) {
|
||||
@@ -396,25 +322,33 @@ func initTelegramBot(token string, b *Bot) (TelegramClient, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Register public commands for all users.
|
||||
_, err = tgBot.SetMyCommands(context.Background(), &bot.SetMyCommandsParams{
|
||||
Commands: publicBotCommands,
|
||||
Scope: &models.BotCommandScopeDefault{},
|
||||
})
|
||||
if err != nil {
|
||||
ErrorLogger.Printf("Error setting default bot commands: %v", err)
|
||||
return nil, err
|
||||
// Define bot commands
|
||||
commands := []models.BotCommand{
|
||||
{
|
||||
Command: "stats",
|
||||
Description: "Get bot statistics. Usage: /stats or /stats user [user_id]",
|
||||
},
|
||||
{
|
||||
Command: "whoami",
|
||||
Description: "Get your user information",
|
||||
},
|
||||
{
|
||||
Command: "clear",
|
||||
Description: "Clear chat history (soft delete). Admins: /clear [user_id]",
|
||||
},
|
||||
{
|
||||
Command: "clear_hard",
|
||||
Description: "Clear chat history (permanently delete). Admins: /clear_hard [user_id]",
|
||||
},
|
||||
}
|
||||
|
||||
// Register full command palette (public + admin) scoped to each known elevated user.
|
||||
// BotCommandScopeChatMember targets the user's private DM with the bot (chat_id == user_id).
|
||||
// Elevation is determined by scope rather than role name, so renaming roles requires no code change.
|
||||
// This is best-effort: failures are logged but do not prevent the bot from starting.
|
||||
var allUsers []User
|
||||
if err := b.db.Preload("Role.Scopes").Where("bot_id = ?", b.botID).Find(&allUsers).Error; err != nil {
|
||||
ErrorLogger.Printf("Warning: could not query users for command scoping: %v", err)
|
||||
} else {
|
||||
setElevatedCommands(tgBot, allUsers)
|
||||
// Set bot commands
|
||||
_, err = tgBot.SetMyCommands(context.Background(), &bot.SetMyCommandsParams{
|
||||
Commands: commands,
|
||||
})
|
||||
if err != nil {
|
||||
ErrorLogger.Printf("Error setting bot commands: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tgBot, nil
|
||||
@@ -480,7 +414,7 @@ func (b *Bot) sendStats(ctx context.Context, chatID int64, userID int64, targetU
|
||||
// If targetUserID is not 0, show user-specific stats
|
||||
// Check permissions if the user is trying to view someone else's stats
|
||||
if targetUserID != userID {
|
||||
if !b.hasScope(userID, ScopeStatsViewAny) {
|
||||
if !b.isAdminOrOwner(userID) {
|
||||
InfoLogger.Printf("User %d attempted to view stats for user %d without permission", userID, targetUserID)
|
||||
if err := b.sendResponse(ctx, chatID, "Permission denied. Only admins and owners can view other users' statistics.", businessConnectionID); err != nil {
|
||||
ErrorLogger.Printf("Error sending response: %v", err)
|
||||
@@ -554,7 +488,7 @@ func (b *Bot) getUserStats(userID int64) (string, int64, int64, int64, error) {
|
||||
|
||||
// Count responses to the user (OUT)
|
||||
var messagesOut int64
|
||||
if err := b.db.Model(&Message{}).Where("chat_id IN (SELECT DISTINCT chat_id FROM messages WHERE user_id = ? AND bot_id = ? AND deleted_at IS NULL) AND bot_id = ? AND is_user = ?",
|
||||
if err := b.db.Model(&Message{}).Where("chat_id IN (SELECT DISTINCT chat_id FROM messages WHERE user_id = ? AND bot_id = ?) AND bot_id = ? AND is_user = ?",
|
||||
userID, b.botID, b.botID, false).Count(&messagesOut).Error; err != nil {
|
||||
return "", 0, 0, 0, err
|
||||
}
|
||||
@@ -715,8 +649,8 @@ func (b *Bot) screenOutgoingMessage(chatID int64, response string) (Message, err
|
||||
}
|
||||
|
||||
func (b *Bot) promoteUserToAdmin(promoterID, userToPromoteID int64) error {
|
||||
// Check if the promoter has the user:promote scope
|
||||
if !b.hasScope(promoterID, ScopeUserPromote) {
|
||||
// Check if the promoter is an owner or admin
|
||||
if !b.isAdminOrOwner(promoterID) {
|
||||
return errors.New("only admins or owners can promote users to admin")
|
||||
}
|
||||
|
||||
@@ -735,11 +669,5 @@ func (b *Bot) promoteUserToAdmin(promoterID, userToPromoteID int64) error {
|
||||
// Update the user's role
|
||||
userToPromote.RoleID = adminRole.ID
|
||||
userToPromote.Role = adminRole
|
||||
if err := b.db.Save(&userToPromote).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Surface admin commands in the newly promoted user's private chat.
|
||||
b.registerAdminCommandsForUser(context.Background(), userToPromoteID)
|
||||
return nil
|
||||
return b.db.Save(&userToPromote).Error
|
||||
}
|
||||
|
||||
@@ -24,7 +24,6 @@ type BotConfig struct {
|
||||
OwnerTelegramID int64 `json:"owner_telegram_id"`
|
||||
AnthropicAPIKey string `json:"anthropic_api_key"`
|
||||
DebugScreening bool `json:"debug_screening"` // Enable detailed screening logs
|
||||
ConfigFilePath string `json:"-"` // Set at load time; not serialized
|
||||
}
|
||||
|
||||
// Custom unmarshalling to handle anthropic.Model
|
||||
@@ -109,7 +108,6 @@ func loadAllConfigs(dir string) ([]BotConfig, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
config.ConfigFilePath = validPath
|
||||
configs = append(configs, config)
|
||||
}
|
||||
}
|
||||
@@ -202,35 +200,3 @@ func (c *BotConfig) Reload(configDir, filename string) error {
|
||||
c.Model = anthropic.Model(c.Model)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PersistModel updates the model field in memory and writes it back to the config file on disk.
|
||||
// Only the "model" key is changed; all other fields are preserved verbatim.
|
||||
func (c *BotConfig) PersistModel(newModel string) error {
|
||||
if c.ConfigFilePath == "" {
|
||||
return fmt.Errorf("config file path not set; cannot persist model")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(c.ConfigFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config for update: %w", err)
|
||||
}
|
||||
|
||||
var raw map[string]any
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return fmt.Errorf("failed to parse config for update: %w", err)
|
||||
}
|
||||
|
||||
raw["model"] = newModel
|
||||
|
||||
updated, err := json.MarshalIndent(raw, "", "\t")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to re-encode config: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(c.ConfigFilePath, updated, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write config: %w", err)
|
||||
}
|
||||
|
||||
c.Model = anthropic.Model(newModel)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -752,67 +752,3 @@ func TestTemperatureConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
// Additional tests can be added here to cover more scenarios
|
||||
|
||||
// TestBotConfig_PersistModel verifies that PersistModel updates the model both in memory
|
||||
// and on disk while leaving all other config fields unchanged.
|
||||
func TestBotConfig_PersistModel(t *testing.T) { //NOSONAR go:S100 -- underscore separation is idiomatic in Go test names
|
||||
tempDir, err := os.MkdirTemp("", "persist_model_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.RemoveAll(tempDir); err != nil {
|
||||
t.Errorf("Failed to remove temp directory: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
initialJSON := `{
|
||||
"id": "bot1",
|
||||
"telegram_token": "token1",
|
||||
"model": "claude-v1",
|
||||
"messages_per_hour": 10,
|
||||
"messages_per_day": 100
|
||||
}`
|
||||
configPath := filepath.Join(tempDir, "config.json")
|
||||
if err := os.WriteFile(configPath, []byte(initialJSON), 0600); err != nil {
|
||||
t.Fatalf("Failed to write config file: %v", err)
|
||||
}
|
||||
|
||||
config := BotConfig{
|
||||
ID: "bot1",
|
||||
Model: "claude-v1",
|
||||
ConfigFilePath: configPath,
|
||||
}
|
||||
|
||||
// Successful model update
|
||||
if err := config.PersistModel("claude-sonnet-4-6"); err != nil {
|
||||
t.Fatalf("PersistModel() unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// In-memory model must be updated immediately
|
||||
if string(config.Model) != "claude-sonnet-4-6" {
|
||||
t.Errorf("in-memory model: got %q, want %q", config.Model, "claude-sonnet-4-6")
|
||||
}
|
||||
|
||||
// On-disk model must be updated; other fields must be preserved
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read updated config file: %v", err)
|
||||
}
|
||||
var raw map[string]any
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
t.Fatalf("Failed to unmarshal updated config: %v", err)
|
||||
}
|
||||
if raw["model"] != "claude-sonnet-4-6" {
|
||||
t.Errorf("on-disk model: got %v, want %q", raw["model"], "claude-sonnet-4-6")
|
||||
}
|
||||
if raw["id"] != "bot1" {
|
||||
t.Errorf("on-disk id should be preserved: got %v, want %q", raw["id"], "bot1")
|
||||
}
|
||||
|
||||
// Error case: empty ConfigFilePath must return an error
|
||||
noPath := BotConfig{Model: "claude-v1"}
|
||||
if err := noPath.PersistModel("claude-sonnet-4-6"); err == nil {
|
||||
t.Error("PersistModel with empty ConfigFilePath: expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
+2
-52
@@ -25,7 +25,7 @@ func initDB() (*gorm.DB, error) {
|
||||
},
|
||||
)
|
||||
|
||||
db, err := gorm.Open(sqlite.Open("data/bot.db?_journal_mode=WAL&_busy_timeout=5000&_foreign_keys=on"), &gorm.Config{
|
||||
db, err := gorm.Open(sqlite.Open("data/bot.db?_journal_mode=WAL&_busy_timeout=5000"), &gorm.Config{
|
||||
Logger: newLogger,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -39,7 +39,7 @@ func initDB() (*gorm.DB, error) {
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
|
||||
// AutoMigrate the models
|
||||
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}, &Scope{})
|
||||
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to migrate database schema: %w", err)
|
||||
}
|
||||
@@ -59,59 +59,9 @@ func initDB() (*gorm.DB, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := createDefaultScopes(db); err != nil {
|
||||
return nil, fmt.Errorf("createDefaultScopes: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func createDefaultScopes(db *gorm.DB) error {
|
||||
all := []string{
|
||||
ScopeStatsViewOwn, ScopeStatsViewAny,
|
||||
ScopeHistoryClearOwn, ScopeHistoryClearAny,
|
||||
ScopeHistoryClearHardOwn, ScopeHistoryClearHardAny,
|
||||
ScopeModelSet, ScopeUserPromote,
|
||||
}
|
||||
for _, name := range all {
|
||||
if err := db.FirstOrCreate(&Scope{}, Scope{Name: name}).Error; err != nil {
|
||||
return fmt.Errorf("failed to create scope %s: %w", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
userScopes := []string{
|
||||
ScopeStatsViewOwn,
|
||||
ScopeHistoryClearOwn,
|
||||
ScopeHistoryClearHardOwn,
|
||||
}
|
||||
elevatedScopes := []string{
|
||||
ScopeStatsViewOwn, ScopeStatsViewAny,
|
||||
ScopeHistoryClearOwn, ScopeHistoryClearAny,
|
||||
ScopeHistoryClearHardOwn, ScopeHistoryClearHardAny,
|
||||
ScopeModelSet, ScopeUserPromote,
|
||||
}
|
||||
assignments := map[string][]string{
|
||||
"user": userScopes,
|
||||
"admin": elevatedScopes,
|
||||
// owner gets the same scopes as admin; owner uniqueness is enforced by the IsOwner flag
|
||||
"owner": elevatedScopes,
|
||||
}
|
||||
for roleName, scopes := range assignments {
|
||||
var role Role
|
||||
if err := db.Where("name = ?", roleName).First(&role).Error; err != nil {
|
||||
return fmt.Errorf("role %s not found: %w", roleName, err)
|
||||
}
|
||||
var scopeModels []Scope
|
||||
if err := db.Where("name IN ?", scopes).Find(&scopeModels).Error; err != nil {
|
||||
return fmt.Errorf("failed to find scopes for %s: %w", roleName, err)
|
||||
}
|
||||
if err := db.Model(&role).Association("Scopes").Replace(scopeModels); err != nil {
|
||||
return fmt.Errorf("failed to assign scopes to %s: %w", roleName, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func createDefaultRoles(db *gorm.DB) error {
|
||||
roles := []string{"user", "admin", "owner"}
|
||||
for _, roleName := range roles {
|
||||
|
||||
+67
-87
@@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -12,20 +11,6 @@ import (
|
||||
"github.com/liushuangls/go-anthropic/v2"
|
||||
)
|
||||
|
||||
// anthropicErrorResponse returns the message to send back to the user when getAnthropicResponse
|
||||
// fails. Admins and owners receive an actionable hint when the model is deprecated; regular users
|
||||
// always get the generic fallback to avoid leaking internal details.
|
||||
func (b *Bot) anthropicErrorResponse(err error, userID int64) string {
|
||||
if errors.Is(err, ErrModelNotFound) && b.hasScope(userID, ScopeModelSet) {
|
||||
return fmt.Sprintf(
|
||||
"⚠️ Model `%s` is no longer available (deprecated or removed by Anthropic).\n"+
|
||||
"Use /set_model <model-id> to switch. Current models: https://platform.claude.com/docs/en/about-claude/models/overview",
|
||||
b.config.Model,
|
||||
)
|
||||
}
|
||||
return "I'm sorry, I'm having trouble processing your request right now."
|
||||
}
|
||||
|
||||
func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.Update) {
|
||||
var message *models.Message
|
||||
|
||||
@@ -62,7 +47,7 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
|
||||
messageTime := message.Date
|
||||
text := message.Text
|
||||
|
||||
// Check if it's a new chat (before storing the message so the flag is accurate).
|
||||
// Check if it's a new chat
|
||||
isNewChatFlag := b.isNewChat(chatID)
|
||||
|
||||
// Screen incoming message (store to DB + add to chat memory)
|
||||
@@ -79,14 +64,33 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
|
||||
isOwner = true
|
||||
}
|
||||
|
||||
// Always create/get the user record — on the very first message and on all subsequent ones.
|
||||
// Get the chat memory which now contains the user's message
|
||||
chatMemory := b.getOrCreateChatMemory(chatID)
|
||||
contextMessages := b.prepareContextMessages(chatMemory)
|
||||
|
||||
if isNewChatFlag {
|
||||
|
||||
// Get response from Anthropic using the context messages
|
||||
response, err := b.getAnthropicResponse(ctx, contextMessages, true, isOwner, false, username, firstName, lastName, isPremium, languageCode, messageTime)
|
||||
if err != nil {
|
||||
ErrorLogger.Printf("Error getting Anthropic response: %v", err)
|
||||
// Use the same error message as in the non-new chat case
|
||||
response = "I'm sorry, I'm having trouble processing your request right now."
|
||||
}
|
||||
|
||||
// Send the AI-generated response or error message
|
||||
if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil {
|
||||
ErrorLogger.Printf("Error sending response: %v", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
user, err := b.getOrCreateUser(userID, username, isOwner)
|
||||
if err != nil {
|
||||
ErrorLogger.Printf("Error getting or creating user: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update the username if it has changed
|
||||
// Update the username if it's empty or has changed
|
||||
if user.Username != username {
|
||||
user.Username = username
|
||||
if err := b.db.Save(&user).Error; err != nil {
|
||||
@@ -94,7 +98,7 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the message is a command — applies on every message, including the very first.
|
||||
// Check if the message is a command
|
||||
if message.Entities != nil {
|
||||
for _, entity := range message.Entities {
|
||||
if entity.Type == "bot_command" {
|
||||
@@ -166,38 +170,6 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
|
||||
}
|
||||
b.clearChatHistory(ctx, chatID, userID, targetUserID, targetChatID, businessConnectionID, false)
|
||||
return
|
||||
case "/set_model":
|
||||
if !b.hasScope(userID, ScopeModelSet) {
|
||||
if err := b.sendResponse(ctx, chatID, "Permission denied. Only admins and owners can change the model.", businessConnectionID); err != nil {
|
||||
ErrorLogger.Printf("Error sending response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
parts := strings.Fields(message.Text)
|
||||
if len(parts) < 2 || strings.TrimSpace(parts[1]) == "" {
|
||||
if err := b.sendResponse(ctx, chatID, "Usage: /set_model <model-id>", businessConnectionID); err != nil {
|
||||
ErrorLogger.Printf("Error sending response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
newModel := strings.TrimSpace(parts[1])
|
||||
// No upfront model validation:
|
||||
// - The go-anthropic library constants are not enumerable at runtime (Go has no const reflection).
|
||||
// - A live /v1/models probe would add a network round-trip and show in the API audit log.
|
||||
// - An invalid model ID will produce a 404 on the next real message, which routes through
|
||||
// anthropicErrorResponse and already delivers an actionable admin-facing hint.
|
||||
if err := b.config.PersistModel(newModel); err != nil {
|
||||
ErrorLogger.Printf("Failed to persist model change: %v", err)
|
||||
if err := b.sendResponse(ctx, chatID, fmt.Sprintf("Model updated in memory to `%s`, but failed to save to config file: %v", newModel, err), businessConnectionID); err != nil {
|
||||
ErrorLogger.Printf("Error sending response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
InfoLogger.Printf("Model changed to %s by user %d", newModel, userID)
|
||||
if err := b.sendResponse(ctx, chatID, fmt.Sprintf("✅ Model updated to `%s` and saved to config.", newModel), businessConnectionID); err != nil {
|
||||
ErrorLogger.Printf("Error sending response: %v", err)
|
||||
}
|
||||
return
|
||||
case "/clear_hard":
|
||||
parts := strings.Fields(message.Text)
|
||||
var targetUserID, targetChatID int64
|
||||
@@ -230,19 +202,15 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limit check applies to all message types including stickers.
|
||||
// Rate limit check applies to all message types including stickers
|
||||
if !b.checkRateLimits(userID) {
|
||||
b.sendRateLimitExceededMessage(ctx, chatID, businessConnectionID)
|
||||
return
|
||||
}
|
||||
|
||||
// Build context once — shared by the sticker and text response paths.
|
||||
chatMemory := b.getOrCreateChatMemory(chatID)
|
||||
contextMessages := b.prepareContextMessages(chatMemory)
|
||||
|
||||
// Check if the message contains a sticker
|
||||
if message.Sticker != nil {
|
||||
b.handleStickerMessage(ctx, chatID, userMsg, message, contextMessages, businessConnectionID)
|
||||
b.handleStickerMessage(ctx, chatID, userMsg, message, businessConnectionID)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -255,11 +223,15 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
|
||||
// Determine if the text contains only emojis
|
||||
isEmojiOnly := isOnlyEmojis(text)
|
||||
|
||||
// Prepare context messages for Anthropic
|
||||
chatMemory := b.getOrCreateChatMemory(chatID)
|
||||
contextMessages := b.prepareContextMessages(chatMemory)
|
||||
|
||||
// Get response from Anthropic
|
||||
response, err := b.getAnthropicResponse(ctx, contextMessages, isNewChatFlag, isOwner, isEmojiOnly, username, firstName, lastName, isPremium, languageCode, messageTime)
|
||||
response, err := b.getAnthropicResponse(ctx, contextMessages, false, isOwner, isEmojiOnly, username, firstName, lastName, isPremium, languageCode, messageTime) // isNewChat is false here
|
||||
if err != nil {
|
||||
ErrorLogger.Printf("Error getting Anthropic response: %v", err)
|
||||
response = b.anthropicErrorResponse(err, userID)
|
||||
response = "I'm sorry, I'm having trouble processing your request right now."
|
||||
}
|
||||
|
||||
// Send the response
|
||||
@@ -267,6 +239,7 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
|
||||
ErrorLogger.Printf("Error sending response: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bot) sendRateLimitExceededMessage(ctx context.Context, chatID int64, businessConnectionID string) {
|
||||
@@ -275,11 +248,11 @@ func (b *Bot) sendRateLimitExceededMessage(ctx context.Context, chatID int64, bu
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bot) handleStickerMessage(ctx context.Context, chatID int64, userMessage Message, message *models.Message, contextMessages []anthropic.Message, businessConnectionID string) {
|
||||
func (b *Bot) handleStickerMessage(ctx context.Context, chatID int64, userMessage Message, message *models.Message, businessConnectionID string) {
|
||||
// userMessage was already screened (stored + added to memory) by handleUpdate — do not call screenIncomingMessage again.
|
||||
|
||||
// Generate AI response about the sticker
|
||||
response, err := b.generateStickerResponse(ctx, userMessage, contextMessages)
|
||||
response, err := b.generateStickerResponse(ctx, userMessage)
|
||||
if err != nil {
|
||||
ErrorLogger.Printf("Error generating sticker response: %v", err)
|
||||
// Provide a fallback dynamic response based on sticker type
|
||||
@@ -299,15 +272,35 @@ func (b *Bot) handleStickerMessage(ctx context.Context, chatID int64, userMessag
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bot) generateStickerResponse(ctx context.Context, message Message, contextMessages []anthropic.Message) (string, error) {
|
||||
// contextMessages already contains the sticker turn (added by screenIncomingMessage as
|
||||
// "Sent a sticker: <emoji>"), so the full conversation history is preserved.
|
||||
func (b *Bot) generateStickerResponse(ctx context.Context, message Message) (string, error) {
|
||||
// Example: Use the sticker type to generate a response
|
||||
if message.StickerFileID != "" {
|
||||
// Create message content with emoji information if available
|
||||
var messageContent string
|
||||
if message.StickerEmoji != "" {
|
||||
messageContent = fmt.Sprintf("User sent a sticker: %s", message.StickerEmoji)
|
||||
} else {
|
||||
messageContent = "User sent a sticker."
|
||||
}
|
||||
|
||||
// Prepare context with information about the sticker
|
||||
contextMessages := []anthropic.Message{
|
||||
{
|
||||
Role: anthropic.RoleUser,
|
||||
Content: []anthropic.MessageContent{
|
||||
anthropic.NewTextMessageContent(messageContent),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Treat sticker messages like emoji messages to get emoji responses
|
||||
// Convert the timestamp to Unix time for the messageTime parameter
|
||||
messageTime := int(message.Timestamp.Unix())
|
||||
response, err := b.getAnthropicResponse(ctx, contextMessages, false, false, true, message.Username, "", "", false, "", messageTime)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
@@ -317,11 +310,8 @@ func (b *Bot) generateStickerResponse(ctx context.Context, message Message, cont
|
||||
func (b *Bot) clearChatHistory(ctx context.Context, chatID int64, currentUserID int64, targetUserID int64, targetChatID int64, businessConnectionID string, hardDelete bool) {
|
||||
// If targetUserID is provided and different from currentUserID, check permissions
|
||||
if targetUserID != 0 && targetUserID != currentUserID {
|
||||
requiredScope := ScopeHistoryClearAny
|
||||
if hardDelete {
|
||||
requiredScope = ScopeHistoryClearHardAny
|
||||
}
|
||||
if !b.hasScope(currentUserID, requiredScope) {
|
||||
// Check if the current user is an admin or owner
|
||||
if !b.isAdminOrOwner(currentUserID) {
|
||||
InfoLogger.Printf("User %d attempted to clear history for user %d without permission", currentUserID, targetUserID)
|
||||
if err := b.sendResponse(ctx, chatID, "Permission denied. Only admins and owners can clear other users' histories.", businessConnectionID); err != nil {
|
||||
ErrorLogger.Printf("Error sending response: %v", err)
|
||||
@@ -360,42 +350,32 @@ func (b *Bot) clearChatHistory(ctx context.Context, chatID int64, currentUserID
|
||||
if hardDelete {
|
||||
// Permanently delete messages
|
||||
if targetUserID == currentUserID {
|
||||
// Own history — delete ALL messages (user + assistant) in the current chat.
|
||||
err = b.db.Unscoped().Where("chat_id = ? AND bot_id = ?", chatID, b.botID).Delete(&Message{}).Error
|
||||
// Deleting own messages — scope to the current chat only.
|
||||
err = b.db.Unscoped().Where("chat_id = ? AND bot_id = ? AND user_id = ?", chatID, b.botID, targetUserID).Delete(&Message{}).Error
|
||||
InfoLogger.Printf("User %d permanently deleted their own chat history in chat %d", currentUserID, chatID)
|
||||
} else {
|
||||
// Deleting another user's messages — scope bot-wide by default; chat-scoped if targetChatID given (see above).
|
||||
if targetChatID != 0 {
|
||||
// Chat-scoped: delete ALL messages (user + assistant) in the specified chat.
|
||||
err = b.db.Unscoped().Where("chat_id = ? AND bot_id = ?", targetChatID, b.botID).Delete(&Message{}).Error
|
||||
err = b.db.Unscoped().Where("chat_id = ? AND bot_id = ? AND user_id = ?", targetChatID, b.botID, targetUserID).Delete(&Message{}).Error
|
||||
InfoLogger.Printf("Admin/owner %d permanently deleted chat history for user %d in chat %d", currentUserID, targetUserID, targetChatID)
|
||||
} else {
|
||||
// Bot-wide: delete all of the user's own messages across every chat, then delete
|
||||
// assistant messages from their DM chat (where chat_id == user_id by Telegram convention).
|
||||
err = b.db.Unscoped().Where("bot_id = ? AND user_id = ?", b.botID, targetUserID).Delete(&Message{}).Error
|
||||
if err == nil {
|
||||
err = b.db.Unscoped().Where("chat_id = ? AND bot_id = ? AND is_user = ?", targetUserID, b.botID, false).Delete(&Message{}).Error
|
||||
}
|
||||
InfoLogger.Printf("Admin/owner %d permanently deleted all chat history for user %d", currentUserID, targetUserID)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Soft delete messages
|
||||
if targetUserID == currentUserID {
|
||||
// Own history — delete ALL messages (user + assistant) in the current chat.
|
||||
err = b.db.Where("chat_id = ? AND bot_id = ?", chatID, b.botID).Delete(&Message{}).Error
|
||||
// Deleting own messages — scope to the current chat only.
|
||||
err = b.db.Where("chat_id = ? AND bot_id = ? AND user_id = ?", chatID, b.botID, targetUserID).Delete(&Message{}).Error
|
||||
InfoLogger.Printf("User %d soft deleted their own chat history in chat %d", currentUserID, chatID)
|
||||
} else {
|
||||
// Deleting another user's messages — scope bot-wide by default; chat-scoped if targetChatID given (see above).
|
||||
if targetChatID != 0 {
|
||||
// Chat-scoped: delete ALL messages (user + assistant) in the specified chat.
|
||||
err = b.db.Where("chat_id = ? AND bot_id = ?", targetChatID, b.botID).Delete(&Message{}).Error
|
||||
err = b.db.Where("chat_id = ? AND bot_id = ? AND user_id = ?", targetChatID, b.botID, targetUserID).Delete(&Message{}).Error
|
||||
InfoLogger.Printf("Admin/owner %d soft deleted chat history for user %d in chat %d", currentUserID, targetUserID, targetChatID)
|
||||
} else {
|
||||
// Bot-wide: delete all of the user's own messages across every chat, then delete
|
||||
// assistant messages from their DM chat (where chat_id == user_id by Telegram convention).
|
||||
err = b.db.Where("bot_id = ? AND user_id = ?", b.botID, targetUserID).Delete(&Message{}).Error
|
||||
if err == nil {
|
||||
err = b.db.Where("chat_id = ? AND bot_id = ? AND is_user = ?", targetUserID, b.botID, false).Delete(&Message{}).Error
|
||||
}
|
||||
InfoLogger.Printf("Admin/owner %d soft deleted all chat history for user %d", currentUserID, targetUserID)
|
||||
}
|
||||
}
|
||||
|
||||
+2
-267
@@ -2,10 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -615,277 +611,16 @@ func setupTestDB(t *testing.T) *gorm.DB {
|
||||
}
|
||||
|
||||
// AutoMigrate the models
|
||||
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}, &Scope{})
|
||||
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to migrate database schema: %v", err)
|
||||
}
|
||||
|
||||
// Create default roles and scopes
|
||||
// Create default roles
|
||||
err = createDefaultRoles(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create default roles: %v", err)
|
||||
}
|
||||
if err := createDefaultScopes(db); err != nil {
|
||||
t.Fatalf("Failed to create default scopes: %v", err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// setupBotForTest creates a minimal Bot instance backed by an in-memory DB.
|
||||
// It follows the same pattern as the existing handler tests to avoid duplication.
|
||||
func setupBotForTest(t *testing.T, ownerID int64) (*Bot, *MockTelegramClient) {
|
||||
t.Helper()
|
||||
db := setupTestDB(t)
|
||||
mockClock := &MockClock{currentTime: time.Now()}
|
||||
config := BotConfig{
|
||||
ID: "test_bot",
|
||||
OwnerTelegramID: ownerID,
|
||||
TelegramToken: "test_token",
|
||||
MemorySize: 10,
|
||||
MessagePerHour: 5,
|
||||
MessagePerDay: 10,
|
||||
TempBanDuration: "1h",
|
||||
Model: "claude-3-5-haiku-latest",
|
||||
SystemPrompts: make(map[string]string),
|
||||
Active: true,
|
||||
}
|
||||
mockTgClient := &MockTelegramClient{}
|
||||
botModel := &BotModel{Identifier: config.ID, Name: config.ID}
|
||||
assert.NoError(t, db.Create(botModel).Error)
|
||||
assert.NoError(t, db.Create(&ConfigModel{
|
||||
BotID: botModel.ID,
|
||||
MemorySize: config.MemorySize,
|
||||
MessagePerHour: config.MessagePerHour,
|
||||
MessagePerDay: config.MessagePerDay,
|
||||
TempBanDuration: config.TempBanDuration,
|
||||
SystemPrompts: "{}",
|
||||
TelegramToken: config.TelegramToken,
|
||||
Active: config.Active,
|
||||
}).Error)
|
||||
b, err := NewBot(db, config, mockClock, mockTgClient)
|
||||
assert.NoError(t, err)
|
||||
return b, mockTgClient
|
||||
}
|
||||
|
||||
// TestAnthropicErrorResponse verifies that model-deprecation errors surface actionable
|
||||
// details only to admin/owner, and that regular users and non-model errors always get
|
||||
// the generic fallback.
|
||||
func TestAnthropicErrorResponse(t *testing.T) { //NOSONAR go:S100 -- underscore separation is idiomatic in Go test names
|
||||
b, _ := setupBotForTest(t, 123)
|
||||
|
||||
// Create admin user
|
||||
adminRole, err := b.getRoleByName("admin")
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, b.db.Create(&User{
|
||||
BotID: b.botID, TelegramID: 456, Username: "admin",
|
||||
RoleID: adminRole.ID, Role: adminRole,
|
||||
}).Error)
|
||||
|
||||
// Create regular user
|
||||
userRole, err := b.getRoleByName("user")
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, b.db.Create(&User{
|
||||
BotID: b.botID, TelegramID: 789, Username: "regular",
|
||||
RoleID: userRole.ID, Role: userRole,
|
||||
}).Error)
|
||||
|
||||
modelErr := fmt.Errorf("%w: claude-3-5-haiku-latest", ErrModelNotFound)
|
||||
otherErr := errors.New("network error")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
userID int64
|
||||
wantSubstr string
|
||||
wantMissing string
|
||||
}{
|
||||
{
|
||||
name: "owner receives actionable model-not-found message",
|
||||
err: modelErr,
|
||||
userID: 123,
|
||||
wantSubstr: "/set_model",
|
||||
},
|
||||
{
|
||||
name: "admin receives actionable model-not-found message",
|
||||
err: modelErr,
|
||||
userID: 456,
|
||||
wantSubstr: "/set_model",
|
||||
},
|
||||
{
|
||||
name: "regular user receives generic message for model-not-found",
|
||||
err: modelErr,
|
||||
userID: 789,
|
||||
wantSubstr: "I'm sorry",
|
||||
wantMissing: "/set_model",
|
||||
},
|
||||
{
|
||||
name: "owner receives generic message for non-model error",
|
||||
err: otherErr,
|
||||
userID: 123,
|
||||
wantSubstr: "I'm sorry",
|
||||
wantMissing: "/set_model",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp := b.anthropicErrorResponse(tc.err, tc.userID)
|
||||
assert.Contains(t, resp, tc.wantSubstr)
|
||||
if tc.wantMissing != "" {
|
||||
assert.NotContains(t, resp, tc.wantMissing)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetModelCommand verifies that /set_model enforces permissions, validates input,
|
||||
// updates the model in memory, and persists the change to the config file on disk.
|
||||
func TestSetModelCommand(t *testing.T) { //NOSONAR go:S100 -- underscore separation is idiomatic in Go test names
|
||||
b, mockTgClient := setupBotForTest(t, 123)
|
||||
|
||||
// Point the config at a temporary file so PersistModel can write to disk.
|
||||
tempDir, err := os.MkdirTemp("", "set_model_cmd_test")
|
||||
assert.NoError(t, err)
|
||||
defer func() { _ = os.RemoveAll(tempDir) }()
|
||||
|
||||
configPath := filepath.Join(tempDir, "config.json")
|
||||
initialJSON := `{"id":"test_bot","telegram_token":"test_token","model":"claude-3-5-haiku-latest","messages_per_hour":5,"messages_per_day":10}`
|
||||
assert.NoError(t, os.WriteFile(configPath, []byte(initialJSON), 0600))
|
||||
b.config.ConfigFilePath = configPath
|
||||
|
||||
// Create admin and regular users
|
||||
adminRole, err := b.getRoleByName("admin")
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, b.db.Create(&User{
|
||||
BotID: b.botID, TelegramID: 456, Username: "admin",
|
||||
RoleID: adminRole.ID, Role: adminRole,
|
||||
}).Error)
|
||||
userRole, err := b.getRoleByName("user")
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, b.db.Create(&User{
|
||||
BotID: b.botID, TelegramID: 789, Username: "regular",
|
||||
RoleID: userRole.ID, Role: userRole,
|
||||
}).Error)
|
||||
|
||||
chatID := int64(1000)
|
||||
|
||||
// Seed chat 1000 with a prior message so isNewChatFlag is false for all subtests.
|
||||
// Commands are only processed in the non-new-chat branch of handleUpdate.
|
||||
assert.NoError(t, b.db.Create(&Message{
|
||||
BotID: b.botID, ChatID: chatID, UserID: 789, Username: "regular",
|
||||
UserRole: "user", Text: "hello", IsUser: true,
|
||||
}).Error)
|
||||
|
||||
makeUpdate := func(userID int64, text string, cmdLen int) *models.Update {
|
||||
return &models.Update{
|
||||
Message: &models.Message{
|
||||
Chat: models.Chat{ID: chatID},
|
||||
From: &models.User{ID: userID, Username: getUsernameByID(userID)},
|
||||
Text: text,
|
||||
Entities: []models.MessageEntity{
|
||||
{Type: "bot_command", Offset: 0, Length: cmdLen},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
text string
|
||||
wantSubstr string
|
||||
}{
|
||||
{
|
||||
name: "regular user is denied",
|
||||
userID: 789,
|
||||
text: "/set_model claude-sonnet-4-6",
|
||||
wantSubstr: "Permission denied",
|
||||
},
|
||||
{
|
||||
name: "admin missing argument shows usage",
|
||||
userID: 456,
|
||||
text: "/set_model",
|
||||
wantSubstr: "Usage:",
|
||||
},
|
||||
{
|
||||
name: "owner missing argument shows usage",
|
||||
userID: 123,
|
||||
text: "/set_model",
|
||||
wantSubstr: "Usage:",
|
||||
},
|
||||
{
|
||||
name: "admin sets model successfully",
|
||||
userID: 456,
|
||||
text: "/set_model claude-sonnet-4-6",
|
||||
wantSubstr: "✅",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var sentMessage string
|
||||
mockTgClient.SendMessageFunc = func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
|
||||
sentMessage = params.Text
|
||||
return &models.Message{}, nil
|
||||
}
|
||||
b.handleUpdate(context.Background(), nil, makeUpdate(tc.userID, tc.text, 10))
|
||||
assert.Contains(t, sentMessage, tc.wantSubstr)
|
||||
})
|
||||
}
|
||||
|
||||
// Verify the successful update took effect in memory and on disk.
|
||||
t.Run("model change persisted in memory and on disk", func(t *testing.T) {
|
||||
assert.Equal(t, "claude-sonnet-4-6", string(b.config.Model))
|
||||
data, err := os.ReadFile(configPath)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"claude-sonnet-4-6"`)
|
||||
})
|
||||
}
|
||||
|
||||
// TestHasScope verifies that scope checks honour role assignments and the owner bypass.
|
||||
func TestHasScope(t *testing.T) { //NOSONAR go:S100 -- underscore separation is idiomatic in Go test names
|
||||
const ownerID int64 = 100
|
||||
b, _ := setupBotForTest(t, ownerID)
|
||||
|
||||
// Admin user
|
||||
adminRole, err := b.getRoleByName("admin")
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, b.db.Create(&User{
|
||||
BotID: b.botID, TelegramID: 200, Username: "admin_user",
|
||||
RoleID: adminRole.ID, Role: adminRole,
|
||||
}).Error)
|
||||
|
||||
// Regular user
|
||||
userRole, err := b.getRoleByName("user")
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, b.db.Create(&User{
|
||||
BotID: b.botID, TelegramID: 300, Username: "regular_user",
|
||||
RoleID: userRole.ID, Role: userRole,
|
||||
}).Error)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
scope string
|
||||
want bool
|
||||
}{
|
||||
{"owner bypass: model:set", ownerID, ScopeModelSet, true},
|
||||
{"owner bypass: stats:view:any", ownerID, ScopeStatsViewAny, true},
|
||||
{"admin: model:set", 200, ScopeModelSet, true},
|
||||
{"admin: stats:view:any", 200, ScopeStatsViewAny, true},
|
||||
{"admin: history:clear:any", 200, ScopeHistoryClearAny, true},
|
||||
{"user: model:set denied", 300, ScopeModelSet, false},
|
||||
{"user: stats:view:any denied", 300, ScopeStatsViewAny, false},
|
||||
{"user: history:clear:any denied", 300, ScopeHistoryClearAny, false},
|
||||
{"user: stats:view:own allowed", 300, ScopeStatsViewOwn, true},
|
||||
{"user: history:clear:own allowed", 300, ScopeHistoryClearOwn, true},
|
||||
{"unknown telegram_id", 999, ScopeModelSet, false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, b.hasScope(tc.userID, tc.scope))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,27 +50,9 @@ type ChatMemory struct {
|
||||
BusinessConnectionID string // New field to store the business connection ID
|
||||
}
|
||||
|
||||
// Scope name constants — used in DB seeding, hasScope checks, and tests.
|
||||
const (
|
||||
ScopeStatsViewOwn = "stats:view:own"
|
||||
ScopeStatsViewAny = "stats:view:any"
|
||||
ScopeHistoryClearOwn = "history:clear:own"
|
||||
ScopeHistoryClearAny = "history:clear:any"
|
||||
ScopeHistoryClearHardOwn = "history:clear_hard:own"
|
||||
ScopeHistoryClearHardAny = "history:clear_hard:any"
|
||||
ScopeModelSet = "model:set"
|
||||
ScopeUserPromote = "user:promote"
|
||||
)
|
||||
|
||||
type Scope struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"uniqueIndex"`
|
||||
}
|
||||
|
||||
type Role struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"uniqueIndex"`
|
||||
Scopes []Scope `gorm:"many2many:role_scopes;"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
|
||||
+1
-1
@@ -11,6 +11,6 @@ import (
|
||||
// TelegramClient defines the methods required from the Telegram bot.
|
||||
type TelegramClient interface {
|
||||
SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error)
|
||||
SetMyCommands(ctx context.Context, params *bot.SetMyCommandsParams) (bool, error)
|
||||
Start(ctx context.Context)
|
||||
// Add other methods if needed.
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
type MockTelegramClient struct {
|
||||
mock.Mock
|
||||
SendMessageFunc func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error)
|
||||
SetMyCommandsFunc func(ctx context.Context, params *bot.SetMyCommandsParams) (bool, error)
|
||||
StartFunc func(ctx context.Context)
|
||||
}
|
||||
|
||||
@@ -29,14 +28,6 @@ func (m *MockTelegramClient) SendMessage(ctx context.Context, params *bot.SendMe
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
// SetMyCommands mocks registering bot commands.
|
||||
func (m *MockTelegramClient) SetMyCommands(ctx context.Context, params *bot.SetMyCommandsParams) (bool, error) {
|
||||
if m.SetMyCommandsFunc != nil {
|
||||
return m.SetMyCommandsFunc(ctx, params)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Start mocks starting the Telegram client.
|
||||
func (m *MockTelegramClient) Start(ctx context.Context) {
|
||||
if m.StartFunc != nil {
|
||||
|
||||
+21
-39
@@ -12,38 +12,26 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
errOpenDB = "Failed to open in-memory database: %v"
|
||||
errMigrateSchema = "Failed to migrate database schema: %v"
|
||||
errCreateRoles = "Failed to create default roles: %v"
|
||||
errCreateScopes = "Failed to create default scopes: %v"
|
||||
errCreateBot = "Failed to create bot: %v"
|
||||
memoryDSN = ":memory:"
|
||||
)
|
||||
|
||||
func TestOwnerAssignment(t *testing.T) {
|
||||
// Initialize loggers
|
||||
initLoggers()
|
||||
|
||||
// Initialize in-memory database for testing
|
||||
db, err := gorm.Open(sqlite.Open(memoryDSN), &gorm.Config{})
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf(errOpenDB, err)
|
||||
t.Fatalf("Failed to open in-memory database: %v", err)
|
||||
}
|
||||
|
||||
// Migrate the schema
|
||||
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}, &Scope{})
|
||||
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{})
|
||||
if err != nil {
|
||||
t.Fatalf(errMigrateSchema, err)
|
||||
t.Fatalf("Failed to migrate database schema: %v", err)
|
||||
}
|
||||
|
||||
// Create default roles and scopes
|
||||
// Create default roles
|
||||
err = createDefaultRoles(db)
|
||||
if err != nil {
|
||||
t.Fatalf(errCreateRoles, err)
|
||||
}
|
||||
if err := createDefaultScopes(db); err != nil {
|
||||
t.Fatalf(errCreateScopes, err)
|
||||
t.Fatalf("Failed to create default roles: %v", err)
|
||||
}
|
||||
|
||||
// Create a bot configuration
|
||||
@@ -79,7 +67,7 @@ func TestOwnerAssignment(t *testing.T) {
|
||||
// Create the bot with the mock Telegram client
|
||||
bot, err := NewBot(db, config, mockClock, mockTGClient)
|
||||
if err != nil {
|
||||
t.Fatalf(errCreateBot, err)
|
||||
t.Fatalf("Failed to create bot: %v", err)
|
||||
}
|
||||
|
||||
// Verify that the owner exists
|
||||
@@ -131,24 +119,21 @@ func TestPromoteUserToAdmin(t *testing.T) {
|
||||
initLoggers()
|
||||
|
||||
// Initialize in-memory database for testing
|
||||
db, err := gorm.Open(sqlite.Open(memoryDSN), &gorm.Config{})
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf(errOpenDB, err)
|
||||
t.Fatalf("Failed to open in-memory database: %v", err)
|
||||
}
|
||||
|
||||
// Migrate the schema
|
||||
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}, &Scope{})
|
||||
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{})
|
||||
if err != nil {
|
||||
t.Fatalf(errMigrateSchema, err)
|
||||
t.Fatalf("Failed to migrate database schema: %v", err)
|
||||
}
|
||||
|
||||
// Create default roles and scopes
|
||||
// Create default roles
|
||||
err = createDefaultRoles(db)
|
||||
if err != nil {
|
||||
t.Fatalf(errCreateRoles, err)
|
||||
}
|
||||
if err := createDefaultScopes(db); err != nil {
|
||||
t.Fatalf(errCreateScopes, err)
|
||||
t.Fatalf("Failed to create default roles: %v", err)
|
||||
}
|
||||
|
||||
config := BotConfig{
|
||||
@@ -168,7 +153,7 @@ func TestPromoteUserToAdmin(t *testing.T) {
|
||||
|
||||
bot, err := NewBot(db, config, mockClock, mockTGClient)
|
||||
if err != nil {
|
||||
t.Fatalf(errCreateBot, err)
|
||||
t.Fatalf("Failed to create bot: %v", err)
|
||||
}
|
||||
|
||||
// Create an owner
|
||||
@@ -207,24 +192,21 @@ func TestGetOrCreateUser(t *testing.T) {
|
||||
initLoggers()
|
||||
|
||||
// Initialize in-memory database for testing
|
||||
db, err := gorm.Open(sqlite.Open(memoryDSN), &gorm.Config{})
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf(errOpenDB, err)
|
||||
t.Fatalf("Failed to open in-memory database: %v", err)
|
||||
}
|
||||
|
||||
// Migrate the schema
|
||||
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}, &Scope{})
|
||||
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{})
|
||||
if err != nil {
|
||||
t.Fatalf(errMigrateSchema, err)
|
||||
t.Fatalf("Failed to migrate database schema: %v", err)
|
||||
}
|
||||
|
||||
// Create default roles and scopes
|
||||
// Create default roles
|
||||
err = createDefaultRoles(db)
|
||||
if err != nil {
|
||||
t.Fatalf(errCreateRoles, err)
|
||||
}
|
||||
if err := createDefaultScopes(db); err != nil {
|
||||
t.Fatalf(errCreateScopes, err)
|
||||
t.Fatalf("Failed to create default roles: %v", err)
|
||||
}
|
||||
|
||||
// Create a mock clock starting at a fixed time
|
||||
@@ -259,7 +241,7 @@ func TestGetOrCreateUser(t *testing.T) {
|
||||
// Create the bot with the mock Telegram client
|
||||
bot, err := NewBot(db, config, mockClock, mockTGClient)
|
||||
if err != nil {
|
||||
t.Fatalf(errCreateBot, err)
|
||||
t.Fatalf("Failed to create bot: %v", err)
|
||||
}
|
||||
|
||||
// Verify that the owner exists
|
||||
|
||||
Reference in New Issue
Block a user