This commit is contained in:
HugeFrog24
2026-02-11 18:39:02 +01:00
parent e9fd36b22d
commit 547dc8ca1a
30 changed files with 1936 additions and 229 deletions

View File

@@ -0,0 +1,14 @@
---
alwaysApply: true
always_on: true
trigger: always_on
applyTo: "**"
description: Snyk Security At Inception
---
# Project security best practices
- Always run snyk_code_scan tool for new first party code that is generated in a Snyk-supported language.
- If any security issues are found based on newly introduced or modified code or dependencies, attempt to fix the issues using the results context from Snyk.
- Rescan the code after fixing the issues to ensure that the issues were fixed and that there are no newly introduced issues.
- Repeat this process until no new issues are found.

77
.dockerignore Normal file
View File

@@ -0,0 +1,77 @@
# Git files
.git
.gitignore
.gitattributes
# Documentation
README.md
*.md
# Docker files
Dockerfile
docker-compose.yml
.dockerignore
# Environment files
.env
.env.*
# Log files
*.log
logs/
# Database files
*.db
*.sqlite
*.sqlite3
bot.db
# Config files (except default template)
config/*
!config/default.json
# Test files
*_test.go
test/
tests/
# Build artifacts
telegram-bot
*.exe
*.dll
*.so
*.dylib
# IDE files
.vscode/
.idea/
*.swp
*.swo
*~
# OS files
.DS_Store
Thumbs.db
# Go specific
vendor/
*.mod.backup
*.sum.backup
# Temporary files
tmp/
temp/
*.tmp
# Coverage files
*.out
coverage.html
# CI/CD files
.github/
.gitlab-ci.yml
.travis.yml
# Examples and documentation
examples/
docs/

0
.gitattributes vendored Executable file → Normal file
View File

8
.github/workflows/go-ci.yaml vendored Executable file → Normal file
View File

@@ -14,7 +14,7 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: '1.23' go-version: '1.24.2'
- uses: actions/cache@v4 - uses: actions/cache@v4
with: with:
path: | path: |
@@ -31,9 +31,9 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: golangci/golangci-lint-action@v6 - uses: golangci/golangci-lint-action@v7
with: with:
version: v1.60 version: v2.0
args: --timeout 5m args: --timeout 5m
# Test job # Test job
@@ -44,7 +44,7 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: '1.23' go-version: '1.24.2'
- run: go test ./... -v - run: go test ./... -v
# Security scan job # Security scan job

0
.gitignore vendored Executable file → Normal file
View File

3
.roo/mcp.json Normal file
View File

@@ -0,0 +1,3 @@
{
"mcpServers": {}
}

57
Dockerfile Normal file
View File

@@ -0,0 +1,57 @@
# Multi-stage build for Go Telegram Bot
# Build stage
FROM golang:1.24-alpine AS builder
# Install build dependencies including C compiler for CGO
RUN apk add --no-cache git ca-certificates tzdata gcc musl-dev
# Set working directory
WORKDIR /build
# Copy go mod files first for better caching
COPY go.mod go.sum ./
# Download dependencies
RUN go mod download
# Copy source code
COPY . .
# Build the application
RUN CGO_ENABLED=1 GOOS=linux go build -a -installsuffix cgo -o telegram-bot .
# Runtime stage
FROM alpine:latest
# Install runtime dependencies
RUN apk --no-cache add ca-certificates tzdata sqlite
# Create non-root user
RUN addgroup -g 1001 -S appgroup && \
adduser -u 1001 -S appuser -G appgroup
# Set working directory
WORKDIR /app
# Create necessary directories
RUN mkdir -p /app/config /app/data /app/logs && \
chown -R appuser:appgroup /app
# Copy binary from builder stage
COPY --from=builder /build/telegram-bot /app/telegram-bot
# Copy default config as template
COPY --chown=appuser:appgroup config/default.json /app/config/
# Switch to non-root user
USER appuser
# Expose any ports if needed (not required for this bot)
# EXPOSE 8080
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD pgrep telegram-bot || exit 1
# Run the application
CMD ["/app/telegram-bot"]

65
README.md Executable file → Normal file
View File

@@ -12,39 +12,38 @@ A scalable, multi-bot solution for Telegram using Go, GORM, and the Anthropic AP
## Usage ## Usage
1. Clone the repository or install using `go get`: ### Docker Deployment (Recommended)
- Option 1: Clone the repository
1. Clone the repository:
```bash ```bash
git clone https://github.com/HugeFrog24/go-telegram-bot.git git clone https://github.com/HugeFrog24/go-telegram-bot.git
```
- Option 2: Install using go get
```bash
go get -u github.com/HugeFrog24/go-telegram-bot
```
- Navigate to the project directory:
```bash
cd go-telegram-bot cd go-telegram-bot
``` ```
2. Copy the default config template and edit it: 2. Copy the default config template and edit it:
```bash ```bash
cp config/default.json config/config-mybot.json cp config/default.json config/mybot.json
nano config/mybot.json
``` ```
Replace `config-mybot.json` with the name of your bot.
```bash
nano config/config-mybot.json
```
You can set up as many bots as you want. Just copy the template and edit the parameters.
> [!IMPORTANT] > [!IMPORTANT]
> Keep your config files secret and do not commit them to version control. > Keep your config files secret and do not commit them to version control.
3. Build the application: 3. Create data directory and run:
```bash
mkdir -p data
docker-compose up -d
```
### Native Deployment
1. Install using `go get`:
```bash
go get -u github.com/HugeFrog24/go-telegram-bot
cd go-telegram-bot
```
2. Configure as above, then build:
```bash ```bash
go build -o telegram-bot go build -o telegram-bot
``` ```
@@ -76,11 +75,11 @@ To enable the bot to start automatically on system boot and run in the backgroun
``` ```
```bash ```bash
sudo systemctl enable telegram-bot.service sudo systemctl enable telegram-bot
``` ```
```bash ```bash
sudo systemctl start telegram-bot.service sudo systemctl start telegram-bot
``` ```
4. Check the status: 4. Check the status:
@@ -93,22 +92,16 @@ For more details on the systemd setup, refer to the [demo service file](examples
## Logs ## Logs
View logs using journalctl: ### Docker
```bash ```bash
journalctl -u telegram-bot docker-compose logs -f telegram-bot
``` ```
Follow logs: ### Systemd
```bash ```bash
journalctl -u telegram-bot -f journalctl -u telegram-bot -f
``` ```
View errors:
```bash
journalctl -u telegram-bot -p err
```
## Testing ## Testing
The GitHub actions workflow already runs tests on every commit: The GitHub actions workflow already runs tests on every commit:
@@ -119,3 +112,11 @@ However, you can run the tests locally using:
```bash ```bash
go test -race -v ./... go test -race -v ./...
``` ```
## Storage
At the moment, a SQLite database (`./data/bot.db`) is used for persistent storage.
Remember to back it up regularly.
Future versions will support more robust storage backends.

76
anthropic.go Executable file → Normal file
View File

@@ -3,11 +3,13 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"time"
"github.com/liushuangls/go-anthropic/v2" "github.com/liushuangls/go-anthropic/v2"
) )
func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Message, isNewChat, isAdminOrOwner, isEmojiOnly bool) (string, error) { func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Message, isNewChat, isAdminOrOwner, isEmojiOnly bool, username string, firstName string, lastName string, isPremium bool, languageCode string, messageTime int) (string, error) {
// Use prompts from config // Use prompts from config
var systemMessage string var systemMessage string
if isNewChat { if isNewChat {
@@ -19,6 +21,56 @@ func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Mes
// Combine default prompt with custom instructions // Combine default prompt with custom instructions
systemMessage = b.config.SystemPrompts["default"] + " " + b.config.SystemPrompts["custom_instructions"] + " " + systemMessage systemMessage = b.config.SystemPrompts["default"] + " " + b.config.SystemPrompts["custom_instructions"] + " " + systemMessage
// Handle username placeholder
usernameValue := username
if username == "" {
usernameValue = "unknown" // Use "unknown" when username is not available
}
systemMessage = strings.ReplaceAll(systemMessage, "{username}", usernameValue)
// Handle firstname placeholder
firstnameValue := firstName
if firstName == "" {
firstnameValue = "unknown" // Use "unknown" when first name is not available
}
systemMessage = strings.ReplaceAll(systemMessage, "{firstname}", firstnameValue)
// Handle lastname placeholder
lastnameValue := lastName
if lastName == "" {
lastnameValue = "" // Empty string when last name is not available
}
systemMessage = strings.ReplaceAll(systemMessage, "{lastname}", lastnameValue)
// Handle language code placeholder
langValue := languageCode
if languageCode == "" {
langValue = "en" // Default to English when language code is not available
}
systemMessage = strings.ReplaceAll(systemMessage, "{language}", langValue)
// Handle premium status
premiumStatus := "regular user"
if isPremium {
premiumStatus = "premium user"
}
systemMessage = strings.ReplaceAll(systemMessage, "{premium_status}", premiumStatus)
// Handle time awareness
timeObj := time.Unix(int64(messageTime), 0)
hour := timeObj.Hour()
var timeContext string
if hour >= 5 && hour < 12 {
timeContext = "morning"
} else if hour >= 12 && hour < 18 {
timeContext = "afternoon"
} else if hour >= 18 && hour < 22 {
timeContext = "evening"
} else {
timeContext = "night"
}
systemMessage = strings.ReplaceAll(systemMessage, "{time_context}", timeContext)
if !isAdminOrOwner { if !isAdminOrOwner {
systemMessage += " " + b.config.SystemPrompts["avoid_sensitive"] systemMessage += " " + b.config.SystemPrompts["avoid_sensitive"]
} }
@@ -27,6 +79,16 @@ func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Mes
systemMessage += " " + b.config.SystemPrompts["respond_with_emojis"] systemMessage += " " + b.config.SystemPrompts["respond_with_emojis"]
} }
// Debug logging
InfoLogger.Printf("Sending %d messages to Anthropic", len(messages))
for i, msg := range messages {
for _, content := range msg.Content {
if content.Type == anthropic.MessagesContentTypeText {
InfoLogger.Printf("Message %d: Role=%v, Text=%v", i, msg.Role, content.Text)
}
}
}
// Ensure the roles are correct // Ensure the roles are correct
for i := range messages { for i := range messages {
switch messages[i].Role { switch messages[i].Role {
@@ -42,12 +104,20 @@ func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Mes
model := anthropic.Model(b.config.Model) model := anthropic.Model(b.config.Model)
resp, err := b.anthropicClient.CreateMessages(ctx, anthropic.MessagesRequest{ // Create the request
request := anthropic.MessagesRequest{
Model: model, // Now `model` is of type anthropic.Model Model: model, // Now `model` is of type anthropic.Model
Messages: messages, Messages: messages,
System: systemMessage, System: systemMessage,
MaxTokens: 1000, MaxTokens: 1000,
}) }
// Apply temperature if set in config
if b.config.Temperature != nil {
request.Temperature = b.config.Temperature
}
resp, err := b.anthropicClient.CreateMessages(ctx, request)
if err != nil { if err != nil {
return "", fmt.Errorf("error creating Anthropic message: %w", err) return "", fmt.Errorf("error creating Anthropic message: %w", err)
} }

197
anthropic_test.go Normal file
View File

@@ -0,0 +1,197 @@
package main
import (
"fmt"
"strings"
"testing"
"time"
)
// TestLanguageCodeReplacement tests that language code is properly handled and replaced
func TestLanguageCodeReplacement(t *testing.T) {
// Test with provided language code
systemMessage := "User's language preference: '{language}'"
// Test with a specific language code
langValue := "fr"
result := strings.ReplaceAll(systemMessage, "{language}", langValue)
if !strings.Contains(result, "User's language preference: 'fr'") {
t.Errorf("Expected language code 'fr' to be replaced, got: %s", result)
}
// Test with empty language code (should default to "en")
langValue = ""
if langValue == "" {
langValue = "en" // Default to English when language code is not available
}
result = strings.ReplaceAll(systemMessage, "{language}", langValue)
if !strings.Contains(result, "User's language preference: 'en'") {
t.Errorf("Expected default language code 'en' to be used, got: %s", result)
}
}
// TestPremiumStatusReplacement tests that premium status is properly handled and replaced
func TestPremiumStatusReplacement(t *testing.T) {
systemMessage := "User is a {premium_status}"
// Test with premium user
isPremium := true
premiumStatus := "regular user"
if isPremium {
premiumStatus = "premium user"
}
result := strings.ReplaceAll(systemMessage, "{premium_status}", premiumStatus)
if !strings.Contains(result, "User is a premium user") {
t.Errorf("Expected premium status to be replaced with 'premium user', got: %s", result)
}
// Test with regular user
isPremium = false
premiumStatus = "regular user"
if isPremium {
premiumStatus = "premium user"
}
result = strings.ReplaceAll(systemMessage, "{premium_status}", premiumStatus)
if !strings.Contains(result, "User is a regular user") {
t.Errorf("Expected premium status to be replaced with 'regular user', got: %s", result)
}
}
// TestTimeContextCalculation tests that time context is correctly calculated for different hours
func TestTimeContextCalculation(t *testing.T) {
// Test cases for different hours
testCases := []struct {
hour int
expected string
}{
{3, "night"}, // Night: hours < 5 or hours >= 22
{5, "morning"}, // Morning: 5 <= hours < 12
{12, "afternoon"}, // Afternoon: 12 <= hours < 18
{17, "afternoon"}, // Afternoon: 12 <= hours < 18
{18, "evening"}, // Evening: 18 <= hours < 22
{21, "evening"}, // Evening: 18 <= hours < 22
{22, "night"}, // Night: hours < 5 or hours >= 22
{23, "night"}, // Night: hours < 5 or hours >= 22
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("Hour_%d", tc.hour), func(t *testing.T) {
// Create a timestamp for the specified hour
testTime := time.Date(2025, 5, 15, tc.hour, 0, 0, 0, time.UTC)
// Get the hour directly from the test time to ensure it's what we expect
actualHour := testTime.Hour()
if actualHour != tc.hour {
t.Fatalf("Test setup error: expected hour %d, got %d", tc.hour, actualHour)
}
// Calculate time context using the same logic as in anthropic.go
var timeContext string
if actualHour >= 5 && actualHour < 12 {
timeContext = "morning"
} else if actualHour >= 12 && actualHour < 18 {
timeContext = "afternoon"
} else if actualHour >= 18 && actualHour < 22 {
timeContext = "evening"
} else {
timeContext = "night"
}
// Check if the calculated time context matches the expected value
if timeContext != tc.expected {
t.Errorf("For hour %d: expected time context '%s', got '%s'",
actualHour, tc.expected, timeContext)
}
})
}
}
// TestSystemMessagePlaceholderReplacement tests that all placeholders are correctly replaced
func TestSystemMessagePlaceholderReplacement(t *testing.T) {
systemMessage := "The user you're talking to has username '{username}' and display name '{firstname} {lastname}'.\n" +
"User's language preference: '{language}'\n" +
"User is a {premium_status}\n" +
"It's currently {time_context} in your timezone"
// Set up test data
username := "testuser"
firstName := "Test"
lastName := "User"
isPremium := true
languageCode := "de"
// Create a timestamp for a specific hour (e.g., 14:00 = afternoon)
testTime := time.Date(2025, 5, 15, 14, 0, 0, 0, time.UTC)
messageTime := int(testTime.Unix())
// Handle username placeholder
usernameValue := username
if username == "" {
usernameValue = "unknown"
}
systemMessage = strings.ReplaceAll(systemMessage, "{username}", usernameValue)
// Handle firstname placeholder
firstnameValue := firstName
if firstName == "" {
firstnameValue = "unknown"
}
systemMessage = strings.ReplaceAll(systemMessage, "{firstname}", firstnameValue)
// Handle lastname placeholder
lastnameValue := lastName
if lastName == "" {
lastnameValue = ""
}
systemMessage = strings.ReplaceAll(systemMessage, "{lastname}", lastnameValue)
// Handle language code placeholder
langValue := languageCode
if languageCode == "" {
langValue = "en"
}
systemMessage = strings.ReplaceAll(systemMessage, "{language}", langValue)
// Handle premium status
premiumStatus := "regular user"
if isPremium {
premiumStatus = "premium user"
}
systemMessage = strings.ReplaceAll(systemMessage, "{premium_status}", premiumStatus)
// Handle time awareness
timeObj := time.Unix(int64(messageTime), 0)
hour := timeObj.Hour()
var timeContext string
if hour >= 5 && hour < 12 {
timeContext = "morning"
} else if hour >= 12 && hour < 18 {
timeContext = "afternoon"
} else if hour >= 18 && hour < 22 {
timeContext = "evening"
} else {
timeContext = "night"
}
systemMessage = strings.ReplaceAll(systemMessage, "{time_context}", timeContext)
// Check that all placeholders were replaced correctly
if !strings.Contains(systemMessage, "username 'testuser'") {
t.Errorf("Username not replaced correctly, got: %s", systemMessage)
}
if !strings.Contains(systemMessage, "display name 'Test User'") {
t.Errorf("Display name not replaced correctly, got: %s", systemMessage)
}
if !strings.Contains(systemMessage, "language preference: 'de'") {
t.Errorf("Language preference not replaced correctly, got: %s", systemMessage)
}
if !strings.Contains(systemMessage, "User is a premium user") {
t.Errorf("Premium status not replaced correctly, got: %s", systemMessage)
}
if !strings.Contains(systemMessage, "It's currently afternoon in your timezone") {
t.Errorf("Time context not replaced correctly, got: %s", systemMessage)
}
}

253
bot.go Executable file → Normal file
View File

@@ -28,6 +28,14 @@ type Bot struct {
botID uint // Reference to BotModel.ID botID uint // Reference to BotModel.ID
} }
// Helper function to determine message type
func messageType(msg *models.Message) string {
if msg.Sticker != nil {
return "sticker"
}
return "text"
}
// NewBot initializes and returns a new Bot instance. // NewBot initializes and returns a new Bot instance.
func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient) (*Bot, error) { func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient) (*Bot, error) {
// Retrieve or create Bot entry in the database // Retrieve or create Bot entry in the database
@@ -87,6 +95,15 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient)
tgBot: tgClient, tgBot: tgClient,
} }
if tgClient == nil {
var err error
tgClient, err = initTelegramBot(config.TelegramToken, b)
if err != nil {
return nil, fmt.Errorf("failed to initialize Telegram bot: %w", err)
}
b.tgBot = tgClient
}
return b, nil return b, nil
} }
@@ -178,9 +195,10 @@ func (b *Bot) createMessage(chatID, userID int64, username, userRole, text strin
return message return message
} }
func (b *Bot) storeMessage(message Message) error { // storeMessage stores a message in the database and updates its ID
func (b *Bot) storeMessage(message *Message) error {
message.BotID = b.botID // Associate the message with the correct bot message.BotID = b.botID // Associate the message with the correct bot
return b.db.Create(&message).Error return b.db.Create(message).Error // This will update the message with its new ID
} }
func (b *Bot) getOrCreateChatMemory(chatID int64) *ChatMemory { func (b *Bot) getOrCreateChatMemory(chatID int64) *ChatMemory {
@@ -190,14 +208,30 @@ func (b *Bot) getOrCreateChatMemory(chatID int64) *ChatMemory {
if !exists { if !exists {
b.chatMemoriesMu.Lock() b.chatMemoriesMu.Lock()
// Double-check to prevent race condition defer b.chatMemoriesMu.Unlock()
chatMemory, exists = b.chatMemories[chatID] chatMemory, exists = b.chatMemories[chatID]
if !exists { if !exists {
// Check if this is a new chat by querying the database
var count int64
b.db.Model(&Message{}).Where("chat_id = ? AND bot_id = ?", chatID, b.botID).Count(&count)
isNewChat := count == 0 // Truly new chat if no messages exist
var messages []Message var messages []Message
b.db.Where("chat_id = ? AND bot_id = ?", chatID, b.botID). if !isNewChat {
// Fetch existing messages only if it's not a new chat
err := b.db.Where("chat_id = ? AND bot_id = ?", chatID, b.botID).
Order("timestamp asc"). Order("timestamp asc").
Limit(b.memorySize * 2). Limit(b.memorySize * 2).
Find(&messages) Find(&messages).Error
if err != nil {
ErrorLogger.Printf("Error fetching messages from database: %v", err)
messages = []Message{} // Initialize an empty slice on error
}
} else {
messages = []Message{} // Ensure messages is initialized for new chats
}
chatMemory = &ChatMemory{ chatMemory = &ChatMemory{
Messages: messages, Messages: messages,
@@ -206,19 +240,22 @@ func (b *Bot) getOrCreateChatMemory(chatID int64) *ChatMemory {
b.chatMemories[chatID] = chatMemory b.chatMemories[chatID] = chatMemory
} }
b.chatMemoriesMu.Unlock()
} }
return chatMemory return chatMemory
} }
// addMessageToChatMemory adds a new message to the chat memory, ensuring the memory size is maintained.
func (b *Bot) addMessageToChatMemory(chatMemory *ChatMemory, message Message) { func (b *Bot) addMessageToChatMemory(chatMemory *ChatMemory, message Message) {
b.chatMemoriesMu.Lock() b.chatMemoriesMu.Lock()
defer b.chatMemoriesMu.Unlock() defer b.chatMemoriesMu.Unlock()
// Add the new message
chatMemory.Messages = append(chatMemory.Messages, message) chatMemory.Messages = append(chatMemory.Messages, message)
// Maintain the memory size
if len(chatMemory.Messages) > chatMemory.Size { if len(chatMemory.Messages) > chatMemory.Size {
chatMemory.Messages = chatMemory.Messages[2:] chatMemory.Messages = chatMemory.Messages[len(chatMemory.Messages)-chatMemory.Size:]
} }
} }
@@ -226,6 +263,12 @@ func (b *Bot) prepareContextMessages(chatMemory *ChatMemory) []anthropic.Message
b.chatMemoriesMu.RLock() b.chatMemoriesMu.RLock()
defer b.chatMemoriesMu.RUnlock() defer b.chatMemoriesMu.RUnlock()
// Debug logging
InfoLogger.Printf("Chat memory contains %d messages", len(chatMemory.Messages))
for i, msg := range chatMemory.Messages {
InfoLogger.Printf("Message %d: IsUser=%v, Text=%q", i, msg.IsUser, msg.Text)
}
var contextMessages []anthropic.Message var contextMessages []anthropic.Message
for _, msg := range chatMemory.Messages { for _, msg := range chatMemory.Messages {
role := anthropic.RoleUser role := anthropic.RoleUser
@@ -252,7 +295,7 @@ func (b *Bot) prepareContextMessages(chatMemory *ChatMemory) []anthropic.Message
func (b *Bot) isNewChat(chatID int64) bool { func (b *Bot) isNewChat(chatID int64) bool {
var count int64 var count int64
b.db.Model(&Message{}).Where("chat_id = ? AND bot_id = ?", chatID, b.botID).Count(&count) b.db.Model(&Message{}).Where("chat_id = ? AND bot_id = ?", chatID, b.botID).Count(&count)
return count == 1 return count == 0 // Only consider a chat new if it has 0 messages
} }
func (b *Bot) isAdminOrOwner(userID int64) bool { func (b *Bot) isAdminOrOwner(userID int64) bool {
@@ -264,9 +307,9 @@ func (b *Bot) isAdminOrOwner(userID int64) bool {
return user.Role.Name == "admin" || user.Role.Name == "owner" return user.Role.Name == "admin" || user.Role.Name == "owner"
} }
func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (TelegramClient, error) { func initTelegramBot(token string, b *Bot) (TelegramClient, error) {
opts := []bot.Option{ opts := []bot.Option{
bot.WithDefaultHandler(handleUpdate), bot.WithDefaultHandler(b.handleUpdate),
} }
tgBot, err := bot.New(token, opts...) tgBot, err := bot.New(token, opts...)
@@ -274,11 +317,40 @@ func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot
return nil, err return nil, err
} }
// Define bot commands
commands := []models.BotCommand{
{
Command: "stats",
Description: "Get bot statistics. Usage: /stats or /stats user [user_id]",
},
{
Command: "whoami",
Description: "Get your user information",
},
{
Command: "clear",
Description: "Clear chat history (soft delete). Admins: /clear [user_id]",
},
{
Command: "clear_hard",
Description: "Clear chat history (permanently delete). Admins: /clear_hard [user_id]",
},
}
// Set bot commands
_, err = tgBot.SetMyCommands(context.Background(), &bot.SetMyCommandsParams{
Commands: commands,
})
if err != nil {
ErrorLogger.Printf("Error setting bot commands: %v", err)
return nil, err
}
return tgBot, nil return tgBot, nil
} }
func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, businessConnectionID string) error { func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, businessConnectionID string) error {
// Pass the outgoing message through the centralized screen for storage // Pass the outgoing message through the centralized screen for storage and chat memory update
_, err := b.screenOutgoingMessage(chatID, text) _, err := b.screenOutgoingMessage(chatID, text)
if err != nil { if err != nil {
ErrorLogger.Printf("Error storing assistant message: %v", err) ErrorLogger.Printf("Error storing assistant message: %v", err)
@@ -306,7 +378,9 @@ func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, busin
} }
// sendStats sends the bot statistics to the specified chat. // sendStats sends the bot statistics to the specified chat.
func (b *Bot) sendStats(ctx context.Context, chatID int64, businessConnectionID string) { func (b *Bot) sendStats(ctx context.Context, chatID int64, userID int64, targetUserID int64, businessConnectionID string) {
// If targetUserID is 0, show global stats
if targetUserID == 0 {
totalUsers, totalMessages, err := b.getStats() totalUsers, totalMessages, err := b.getStats()
if err != nil { if err != nil {
ErrorLogger.Printf("Error fetching stats: %v\n", err) ErrorLogger.Printf("Error fetching stats: %v\n", err)
@@ -329,6 +403,51 @@ func (b *Bot) sendStats(ctx context.Context, chatID int64, businessConnectionID
if err := b.sendResponse(ctx, chatID, statsMessage, businessConnectionID); err != nil { if err := b.sendResponse(ctx, chatID, statsMessage, businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending stats message: %v", err) ErrorLogger.Printf("Error sending stats message: %v", err)
} }
return
}
// If targetUserID is not 0, show user-specific stats
// Check permissions if the user is trying to view someone else's stats
if targetUserID != userID {
if !b.isAdminOrOwner(userID) {
InfoLogger.Printf("User %d attempted to view stats for user %d without permission", userID, targetUserID)
if err := b.sendResponse(ctx, chatID, "Permission denied. Only admins and owners can view other users' statistics.", businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending response: %v", err)
}
return
}
}
// Get user stats
username, messagesIn, messagesOut, totalMessages, err := b.getUserStats(targetUserID)
if err != nil {
ErrorLogger.Printf("Error fetching user stats: %v\n", err)
if err := b.sendResponse(ctx, chatID, fmt.Sprintf("Sorry, I couldn't retrieve statistics for user ID %d.", targetUserID), businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending response: %v", err)
}
return
}
// Build the user stats message
userInfo := fmt.Sprintf("@%s (ID: %d)", username, targetUserID)
if username == "" {
userInfo = fmt.Sprintf("User ID: %d", targetUserID)
}
statsMessage := fmt.Sprintf(
"👤 User Statistics for %s:\n\n"+
"- Messages Sent: %d\n"+
"- Messages Received: %d\n"+
"- Total Messages: %d",
userInfo,
messagesIn,
messagesOut,
totalMessages,
)
if err := b.sendResponse(ctx, chatID, statsMessage, businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending user stats message: %v", err)
}
} }
// getStats retrieves the total number of users and messages from the database. // getStats retrieves the total number of users and messages from the database.
@@ -346,6 +465,35 @@ func (b *Bot) getStats() (int64, int64, error) {
return totalUsers, totalMessages, nil return totalUsers, totalMessages, nil
} }
// getUserStats retrieves statistics for a specific user
func (b *Bot) getUserStats(userID int64) (string, int64, int64, int64, error) {
// Get user information from database
var user User
err := b.db.Where("telegram_id = ? AND bot_id = ?", userID, b.botID).First(&user).Error
if err != nil {
return "", 0, 0, 0, fmt.Errorf("user not found: %w", err)
}
// Count messages sent by the user (IN)
var messagesIn int64
if err := b.db.Model(&Message{}).Where("user_id = ? AND bot_id = ? AND is_user = ?",
userID, b.botID, true).Count(&messagesIn).Error; err != nil {
return "", 0, 0, 0, err
}
// Count responses to the user (OUT)
var messagesOut int64
if err := b.db.Model(&Message{}).Where("chat_id IN (SELECT DISTINCT chat_id FROM messages WHERE user_id = ? AND bot_id = ?) AND bot_id = ? AND is_user = ?",
userID, b.botID, b.botID, false).Count(&messagesOut).Error; err != nil {
return "", 0, 0, 0, err
}
// Total messages is the sum
totalMessages := messagesIn + messagesOut
return user.Username, messagesIn, messagesOut, totalMessages, nil
}
// isOnlyEmojis checks if the string consists solely of emojis. // isOnlyEmojis checks if the string consists solely of emojis.
func isOnlyEmojis(s string) bool { func isOnlyEmojis(s string) bool {
for _, r := range s { for _, r := range s {
@@ -399,41 +547,96 @@ func (b *Bot) sendWhoAmI(ctx context.Context, chatID int64, userID int64, userna
} }
} }
// screenIncomingMessage handles storing of incoming messages. // screenIncomingMessage centralizes all incoming message processing: storing messages and updating chat memory.
func (b *Bot) screenIncomingMessage(message *models.Message) (Message, error) { func (b *Bot) screenIncomingMessage(message *models.Message) (Message, error) {
userRole := string(anthropic.RoleUser) // Convert RoleUser to string if b.config.DebugScreening {
userMessage := b.createMessage(message.Chat.ID, message.From.ID, message.From.Username, userRole, message.Text, true) start := time.Now()
defer func() {
InfoLogger.Printf(
"[Screen] Incoming: chat=%d user=%d type=%s memory_size=%d duration=%v",
message.Chat.ID,
message.From.ID,
messageType(message),
len(b.getOrCreateChatMemory(message.Chat.ID).Messages),
time.Since(start),
)
}()
}
// If the message contains a sticker, include its details. userRole := string(anthropic.RoleUser)
// Determine message text based on message type
messageText := message.Text
if message.Sticker != nil {
if message.Sticker.Emoji != "" {
messageText = fmt.Sprintf("Sent a sticker: %s", message.Sticker.Emoji)
} else {
messageText = "Sent a sticker."
}
}
userMessage := b.createMessage(message.Chat.ID, message.From.ID, message.From.Username, userRole, messageText, true)
// Handle sticker-specific details if present
if message.Sticker != nil { if message.Sticker != nil {
userMessage.StickerFileID = message.Sticker.FileID userMessage.StickerFileID = message.Sticker.FileID
userMessage.StickerEmoji = message.Sticker.Emoji // Store the sticker emoji
if message.Sticker.Thumbnail != nil { if message.Sticker.Thumbnail != nil {
userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID
} }
} }
// Store the message. // Get the chat memory before storing the message
if err := b.storeMessage(userMessage); err != nil { chatMemory := b.getOrCreateChatMemory(message.Chat.ID)
// Store the message and get its ID
if err := b.storeMessage(&userMessage); err != nil {
return Message{}, err return Message{}, err
} }
// Update chat memory. // Add the message to the chat memory
chatMemory := b.getOrCreateChatMemory(message.Chat.ID)
b.addMessageToChatMemory(chatMemory, userMessage) b.addMessageToChatMemory(chatMemory, userMessage)
return userMessage, nil return userMessage, nil
} }
// screenOutgoingMessage handles storing of outgoing messages. // screenOutgoingMessage handles storing of outgoing messages and updating chat memory.
// It also marks the most recent unanswered user message as answered.
func (b *Bot) screenOutgoingMessage(chatID int64, response string) (Message, error) { func (b *Bot) screenOutgoingMessage(chatID int64, response string) (Message, error) {
assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false) if b.config.DebugScreening {
start := time.Now()
defer func() {
InfoLogger.Printf(
"[Screen] Outgoing: chat=%d len=%d memory_size=%d duration=%v",
chatID,
len(response),
len(b.getOrCreateChatMemory(chatID).Messages),
time.Since(start),
)
}()
}
// Store the message. // Create and store the assistant message
if err := b.storeMessage(assistantMessage); err != nil { assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false)
if err := b.storeMessage(&assistantMessage); err != nil {
return Message{}, err return Message{}, err
} }
// Update chat memory. // Find and mark the most recent unanswered user message as answered
now := time.Now()
err := b.db.Model(&Message{}).
Where("chat_id = ? AND bot_id = ? AND is_user = ? AND answered_on IS NULL",
chatID, b.botID, true).
Order("timestamp DESC").
Limit(1).
Update("answered_on", now).Error
if err != nil {
ErrorLogger.Printf("Error marking user message as answered: %v", err)
// Continue even if there's an error updating the user message
}
// Update chat memory with the message that now has an ID
chatMemory := b.getOrCreateChatMemory(chatID) chatMemory := b.getOrCreateChatMemory(chatID)
b.addMessageToChatMemory(chatMemory, assistantMessage) b.addMessageToChatMemory(chatMemory, assistantMessage)

0
clock.go Executable file → Normal file
View File

19
config.go Executable file → Normal file
View File

@@ -18,10 +18,12 @@ type BotConfig struct {
MessagePerDay int `json:"messages_per_day"` MessagePerDay int `json:"messages_per_day"`
TempBanDuration string `json:"temp_ban_duration"` TempBanDuration string `json:"temp_ban_duration"`
Model anthropic.Model `json:"model"` Model anthropic.Model `json:"model"`
Temperature *float32 `json:"temperature,omitempty"` // Controls creativity vs determinism (0.0-1.0)
SystemPrompts map[string]string `json:"system_prompts"` SystemPrompts map[string]string `json:"system_prompts"`
Active bool `json:"active"` Active bool `json:"active"`
OwnerTelegramID int64 `json:"owner_telegram_id"` OwnerTelegramID int64 `json:"owner_telegram_id"`
AnthropicAPIKey string `json:"anthropic_api_key"` AnthropicAPIKey string `json:"anthropic_api_key"`
DebugScreening bool `json:"debug_screening"` // Enable detailed screening logs
} }
// Custom unmarshalling to handle anthropic.Model // Custom unmarshalling to handle anthropic.Model
@@ -144,12 +146,15 @@ func validateConfig(config *BotConfig, ids, tokens map[string]bool) error {
func loadConfig(filename string) (BotConfig, error) { func loadConfig(filename string) (BotConfig, error) {
var config BotConfig var config BotConfig
// Use filepath.Clean before opening the file // Use filepath.Clean before opening the file
cleanPath := filepath.Clean(filename) file, err := os.OpenFile(filepath.Clean(filename), os.O_RDONLY, 0)
file, err := os.OpenFile(cleanPath, os.O_RDONLY, 0)
if err != nil { if err != nil {
return config, fmt.Errorf("failed to open config file %s: %w", cleanPath, err) return config, fmt.Errorf("failed to open config file %s: %w", filename, err)
} }
defer file.Close() defer func() {
if err := file.Close(); err != nil {
InfoLogger.Printf("Failed to close config file: %v", err)
}
}()
decoder := json.NewDecoder(file) decoder := json.NewDecoder(file)
if err := decoder.Decode(&config); err != nil { if err := decoder.Decode(&config); err != nil {
@@ -173,7 +178,11 @@ func (c *BotConfig) Reload(configDir, filename string) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to open config file %s: %w", cleanPath, err) return fmt.Errorf("failed to open config file %s: %w", cleanPath, err)
} }
defer file.Close() defer func() {
if err := file.Close(); err != nil {
InfoLogger.Printf("Failed to close config file: %v", err)
}
}()
decoder := json.NewDecoder(file) decoder := json.NewDecoder(file)
if err := decoder.Decode(c); err != nil { if err := decoder.Decode(c); err != nil {

6
config/default.json Executable file → Normal file
View File

@@ -8,10 +8,12 @@
"messages_per_hour": 20, "messages_per_hour": 20,
"messages_per_day": 100, "messages_per_day": 100,
"temp_ban_duration": "24h", "temp_ban_duration": "24h",
"model": "claude-3-5-sonnet-20240620", "model": "claude-3-5-haiku-latest",
"temperature": 0.7,
"debug_screening": false,
"system_prompts": { "system_prompts": {
"default": "You are a helpful assistant.", "default": "You are a helpful assistant.",
"custom_instructions": "Please follow these guidelines:\n- Your name is Atom.\n- If a user asks about buying apples, inform them that we don't sell apples.\n- When asked for a joke, tell a clean, family-friendly joke about programming or technology.\n- If someone inquires about our services, explain that we offer AI-powered chatbot solutions.\n- For any questions about pricing, direct users to contact our sales team at sales@example.com.\n- If asked about your capabilities, be honest about what you can and cannot do.\nAlways maintain a friendly and professional tone.", "custom_instructions": "You are texting through a limited Telegram interface with 15-word maximum. Write like texting a friend - use shorthand, skip grammar, use slang/abbreviations. System cuts off anything longer than 15 words.\n\n- Your name is Atom.\n- The user you're talking to has username '{username}' and display name '{firstname} {lastname}'.\n- User's language preference: '{language}'\n- User is a {premium_status}\n- It's currently {time_context} in your timezone. Use appropriate time-based greetings and address the user by name.\n- If a user asks about buying apples, inform them that we don't sell apples.\n- When asked for a joke, tell a clean, family-friendly joke about programming or technology.\n- If someone inquires about our services, explain that we offer AI-powered chatbot solutions.\n- For any questions about pricing, direct users to contact our sales team at sales@example.com.\n- If asked about your capabilities, be honest about what you can and cannot do.\nAlways maintain a friendly and professional tone.",
"continue_conversation": "Continuing our conversation. Remember previous context if relevant.", "continue_conversation": "Continuing our conversation. Remember previous context if relevant.",
"avoid_sensitive": "Avoid discussing sensitive topics or providing harmful information.", "avoid_sensitive": "Avoid discussing sensitive topics or providing harmful information.",
"respond_with_emojis": "Since the user sent only emojis, respond using emojis only." "respond_with_emojis": "Since the user sent only emojis, respond using emojis only."

115
config_test.go Executable file → Normal file
View File

@@ -10,7 +10,7 @@ import (
"github.com/liushuangls/go-anthropic/v2" "github.com/liushuangls/go-anthropic/v2"
) )
// Add this at the beginning of the file, after the imports // Set up loggers
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
initLoggers() initLoggers()
os.Exit(m.Run()) os.Exit(m.Run())
@@ -26,6 +26,7 @@ func TestBotConfig_UnmarshalJSON(t *testing.T) {
"messages_per_day": 100, "messages_per_day": 100,
"temp_ban_duration": "1h", "temp_ban_duration": "1h",
"model": "claude-v1", "model": "claude-v1",
"temperature": 0.7,
"system_prompts": {"welcome": "Hello!"}, "system_prompts": {"welcome": "Hello!"},
"active": true, "active": true,
"owner_telegram_id": 123456789, "owner_telegram_id": 123456789,
@@ -100,7 +101,11 @@ func TestValidateConfigPath(t *testing.T) {
if err := os.MkdirAll(subDir, 0755); err != nil { if err := os.MkdirAll(subDir, 0755); err != nil {
t.Fatalf("Failed to create subdir: %v", err) t.Fatalf("Failed to create subdir: %v", err)
} }
defer os.RemoveAll(subDir) defer func() {
if err := os.RemoveAll(subDir); err != nil {
t.Errorf("Failed to remove test subdirectory: %v", err)
}
}()
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@@ -124,7 +129,11 @@ func TestLoadConfig(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to create temp dir: %v", err) t.Fatalf("Failed to create temp dir: %v", err)
} }
defer os.RemoveAll(tempDir) defer func() {
if err := os.RemoveAll(tempDir); err != nil {
t.Errorf("Failed to remove temp directory: %v", err)
}
}()
// Valid config JSON // Valid config JSON
validConfig := `{ validConfig := `{
@@ -135,6 +144,7 @@ func TestLoadConfig(t *testing.T) {
"messages_per_day": 100, "messages_per_day": 100,
"temp_ban_duration": "1h", "temp_ban_duration": "1h",
"model": "claude-v1", "model": "claude-v1",
"temperature": 0.7,
"system_prompts": {"welcome": "Hello!"}, "system_prompts": {"welcome": "Hello!"},
"active": true, "active": true,
"owner_telegram_id": 123456789, "owner_telegram_id": 123456789,
@@ -318,7 +328,11 @@ func TestLoadAllConfigs(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to create temp dir: %v", err) t.Fatalf("Failed to create temp dir: %v", err)
} }
defer os.RemoveAll(tempDir) defer func() {
if err := os.RemoveAll(tempDir); err != nil {
t.Errorf("Failed to remove temp directory: %v", err)
}
}()
tests := []struct { tests := []struct {
name string name string
@@ -338,6 +352,7 @@ func TestLoadAllConfigs(t *testing.T) {
"messages_per_day": 100, "messages_per_day": 100,
"temp_ban_duration": "1h", "temp_ban_duration": "1h",
"model": "claude-v1", "model": "claude-v1",
"temperature": 0.7,
"system_prompts": {"welcome": "Hello!"}, "system_prompts": {"welcome": "Hello!"},
"active": true, "active": true,
"owner_telegram_id": 123456789, "owner_telegram_id": 123456789,
@@ -371,6 +386,7 @@ func TestLoadAllConfigs(t *testing.T) {
"messages_per_day": 50, "messages_per_day": 50,
"temp_ban_duration": "30m", "temp_ban_duration": "30m",
"model": "claude-v2", "model": "claude-v2",
"temperature": 0.5,
"system_prompts": {"welcome": "Hi!"}, "system_prompts": {"welcome": "Hi!"},
"active": false, "active": false,
"owner_telegram_id": 987654321, "owner_telegram_id": 987654321,
@@ -404,6 +420,7 @@ func TestLoadAllConfigs(t *testing.T) {
"messages_per_day": 20, "messages_per_day": 20,
"temp_ban_duration": "15m", "temp_ban_duration": "15m",
"model": "claude-v3", "model": "claude-v3",
"temperature": 0.3,
"system_prompts": {"welcome": "Hey!"}, "system_prompts": {"welcome": "Hey!"},
"active": true, "active": true,
"owner_telegram_id": 1122334455, "owner_telegram_id": 1122334455,
@@ -437,6 +454,7 @@ func TestLoadAllConfigs(t *testing.T) {
"messages_per_day": 10, "messages_per_day": 10,
"temp_ban_duration": "5m", "temp_ban_duration": "5m",
"model": "claude-v4", "model": "claude-v4",
"temperature": 0.2,
"system_prompts": {"welcome": "Greetings!"}, "system_prompts": {"welcome": "Greetings!"},
"active": true, "active": true,
"owner_telegram_id": 5566778899, "owner_telegram_id": 5566778899,
@@ -511,7 +529,11 @@ func TestBotConfig_Reload(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to create temp dir: %v", err) t.Fatalf("Failed to create temp dir: %v", err)
} }
defer os.RemoveAll(tempDir) defer func() {
if err := os.RemoveAll(tempDir); err != nil {
t.Errorf("Failed to remove temp directory: %v", err)
}
}()
// Create initial config file // Create initial config file
config1 := `{ config1 := `{
@@ -522,6 +544,7 @@ func TestBotConfig_Reload(t *testing.T) {
"messages_per_day": 100, "messages_per_day": 100,
"temp_ban_duration": "1h", "temp_ban_duration": "1h",
"model": "claude-v1", "model": "claude-v1",
"temperature": 0.7,
"system_prompts": {"welcome": "Hello!"}, "system_prompts": {"welcome": "Hello!"},
"active": true, "active": true,
"owner_telegram_id": 123456789, "owner_telegram_id": 123456789,
@@ -555,6 +578,7 @@ func TestBotConfig_Reload(t *testing.T) {
"messages_per_day": 200, "messages_per_day": 200,
"temp_ban_duration": "2h", "temp_ban_duration": "2h",
"model": "claude-v2", "model": "claude-v2",
"temperature": 0.3,
"system_prompts": {"welcome": "Hi there!"}, "system_prompts": {"welcome": "Hi there!"},
"active": true, "active": true,
"owner_telegram_id": 987654321, "owner_telegram_id": 987654321,
@@ -594,6 +618,7 @@ func TestBotConfig_UnmarshalJSON_Invalid(t *testing.T) {
"messages_per_day": 100, "messages_per_day": 100,
"temp_ban_duration": "1h", "temp_ban_duration": "1h",
"model": "", "model": "",
"temperature": 0.7,
"system_prompts": {"welcome": "Hello!"}, "system_prompts": {"welcome": "Hello!"},
"active": true, "active": true,
"owner_telegram_id": 123456789, "owner_telegram_id": 123456789,
@@ -616,4 +641,84 @@ func contains(s, substr string) bool {
return strings.Contains(s, substr) return strings.Contains(s, substr)
} }
// TestTemperatureConfig tests that the temperature value is correctly loaded
func TestTemperatureConfig(t *testing.T) {
// Create a temporary directory
tempDir, err := os.MkdirTemp("", "temperature_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() {
if err := os.RemoveAll(tempDir); err != nil {
t.Errorf("Failed to remove temp directory: %v", err)
}
}()
// Create config with temperature
configWithTemp := `{
"id": "bot123",
"telegram_token": "token123",
"memory_size": 1024,
"messages_per_hour": 10,
"messages_per_day": 100,
"temp_ban_duration": "1h",
"model": "claude-v1",
"temperature": 0.42,
"system_prompts": {"welcome": "Hello!"},
"active": true,
"owner_telegram_id": 123456789,
"anthropic_api_key": "api_key_123"
}`
// Create config without temperature
configWithoutTemp := `{
"id": "bot124",
"telegram_token": "token124",
"memory_size": 1024,
"messages_per_hour": 10,
"messages_per_day": 100,
"temp_ban_duration": "1h",
"model": "claude-v1",
"system_prompts": {"welcome": "Hello!"},
"active": true,
"owner_telegram_id": 123456789,
"anthropic_api_key": "api_key_123"
}`
// Write config files
withTempPath := filepath.Join(tempDir, "with_temp.json")
if err := os.WriteFile(withTempPath, []byte(configWithTemp), 0644); err != nil {
t.Fatalf("Failed to write config with temperature: %v", err)
}
withoutTempPath := filepath.Join(tempDir, "without_temp.json")
if err := os.WriteFile(withoutTempPath, []byte(configWithoutTemp), 0644); err != nil {
t.Fatalf("Failed to write config without temperature: %v", err)
}
// Test loading config with temperature
configWithTempObj, err := loadConfig(withTempPath)
if err != nil {
t.Fatalf("Failed to load config with temperature: %v", err)
}
// Verify temperature is set correctly
if configWithTempObj.Temperature == nil {
t.Errorf("Expected Temperature to be set, got nil")
} else if *configWithTempObj.Temperature != 0.42 {
t.Errorf("Expected Temperature 0.42, got %f", *configWithTempObj.Temperature)
}
// Test loading config without temperature
configWithoutTempObj, err := loadConfig(withoutTempPath)
if err != nil {
t.Fatalf("Failed to load config without temperature: %v", err)
}
// Verify temperature is nil when not specified
if configWithoutTempObj.Temperature != nil {
t.Errorf("Expected Temperature to be nil, got %f", *configWithoutTempObj.Temperature)
}
}
// Additional tests can be added here to cover more scenarios // Additional tests can be added here to cover more scenarios

2
database.go Executable file → Normal file
View File

@@ -20,7 +20,7 @@ func initDB() (*gorm.DB, error) {
}, },
) )
db, err := gorm.Open(sqlite.Open("bot.db"), &gorm.Config{ db, err := gorm.Open(sqlite.Open("data/bot.db"), &gorm.Config{
Logger: newLogger, Logger: newLogger,
}) })
if err != nil { if err != nil {

39
docker-compose.yml Normal file
View File

@@ -0,0 +1,39 @@
services:
telegram-bot:
image: bogerserge/go-telegram-bot:latest
build:
context: .
dockerfile: Dockerfile
platforms:
- linux/amd64
- linux/arm64
container_name: go-telegram-bot
restart: unless-stopped
# Optional: Environment variables (can be overridden with .env file)
# environment:
# - BOT_LOG_LEVEL=info
# Volume mounts
volumes:
# Bind mount config directory for live configuration updates
- ./config:/app/config:ro
# Named volume for persistent database storage
- ./data:/app/data
# Optional: Bind mount for log access (uncomment if needed)
# - ./logs:/app/logs
# Health check
healthcheck:
test: ["CMD", "pgrep", "telegram-bot"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
# Logging configuration
logging:
driver: "json-file"
options:
max-size: "10m"
max-file: "3"

0
examples/systemd/telegram-bot.service Executable file → Normal file
View File

21
go.mod Executable file → Normal file
View File

@@ -1,18 +1,23 @@
module github.com/HugeFrog24/go-telegram-bot module github.com/HugeFrog24/go-telegram-bot
go 1.23 go 1.24.2
require ( require (
github.com/go-telegram/bot v1.9.1 github.com/go-telegram/bot v1.18.0
github.com/liushuangls/go-anthropic/v2 v2.8.2 github.com/liushuangls/go-anthropic/v2 v2.17.0
golang.org/x/time v0.7.0 github.com/stretchr/testify v1.11.1
gorm.io/driver/sqlite v1.5.6 golang.org/x/time v0.14.0
gorm.io/gorm v1.25.12 gorm.io/driver/sqlite v1.6.0
gorm.io/gorm v1.31.1
) )
require ( require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-sqlite3 v1.14.24 // indirect github.com/mattn/go-sqlite3 v1.14.34 // indirect
golang.org/x/text v0.19.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.5.3 // indirect
golang.org/x/text v0.34.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )

54
go.sum Executable file → Normal file
View File

@@ -1,18 +1,44 @@
github.com/go-telegram/bot v1.9.1 h1:4vkNV6vDmEPZaYP7sZYaagOaJyV4GerfOPkjg/Ki5ic= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/go-telegram/bot v1.9.1/go.mod h1:i2TRs7fXWIeaceF3z7KzsMt/he0TwkVC680mvdTFYeM= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-telegram/bot v1.17.0 h1:Hs0kGxSj97QFqOQP0zxduY/4tSx8QDzvNI9uVRS+zmY=
github.com/go-telegram/bot v1.17.0/go.mod h1:i2TRs7fXWIeaceF3z7KzsMt/he0TwkVC680mvdTFYeM=
github.com/go-telegram/bot v1.18.0 h1:yQzv437DY42SYTPBY48RinAvwbmf1ox5QICskIYWCD8=
github.com/go-telegram/bot v1.18.0/go.mod h1:i2TRs7fXWIeaceF3z7KzsMt/he0TwkVC680mvdTFYeM=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/liushuangls/go-anthropic/v2 v2.8.2 h1:PbR9oQF3JDnU/hmbbQI+3tkCqNtdaw4K6S0YfzByl9I= github.com/liushuangls/go-anthropic/v2 v2.16.2 h1:eK2tdDTKlMiHEdTKhbSUf11dgY0K//PulXDFAj2EeHQ=
github.com/liushuangls/go-anthropic/v2 v2.8.2/go.mod h1:8BKv/fkeTaL5R9R9bGkaknYBueyw2WxY20o7bImbOek= github.com/liushuangls/go-anthropic/v2 v2.16.2/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= github.com/liushuangls/go-anthropic/v2 v2.17.0 h1:iBA6h7aghi1q86owEQ95XE2R2MF/0dQ7bCxtwTxOg4c=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/liushuangls/go-anthropic/v2 v2.17.0/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU=
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE= github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.5.3 h1:jmXUvGomnU1o3W/V5h2VEradbpJDwGrzugQQvL0POH4=
github.com/stretchr/objx v0.5.3/go.mod h1:rDQraq+vQZU7Fde9LOZLr8Tax6zZvy4kuNKF+QYS+U0=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
gorm.io/gorm v1.31.0 h1:0VlycGreVhK7RF/Bwt51Fk8v0xLiiiFdbGDPIZQ7mJY=
gorm.io/gorm v1.31.0/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=

286
handlers.go Executable file → Normal file
View File

@@ -2,6 +2,8 @@ package main
import ( import (
"context" "context"
"fmt"
"strconv"
"strings" "strings"
"github.com/go-telegram/bot" "github.com/go-telegram/bot"
@@ -32,15 +34,63 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
chatID := message.Chat.ID chatID := message.Chat.ID
userID := message.From.ID userID := message.From.ID
username := message.From.Username username := message.From.Username
firstName := message.From.FirstName
lastName := message.From.LastName
languageCode := message.From.LanguageCode
isPremium := message.From.IsPremium
messageTime := message.Date
text := message.Text text := message.Text
// Pass the incoming message through the centralized screen for storage // Check if it's a new chat
_, err := b.screenIncomingMessage(message) isNewChatFlag := b.isNewChat(chatID)
if err != nil {
// Screen incoming message
if _, err := b.screenIncomingMessage(message); err != nil {
ErrorLogger.Printf("Error storing user message: %v", err) ErrorLogger.Printf("Error storing user message: %v", err)
return return
} }
// Determine if the user is the owner
var isOwner bool
err := b.db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", userID, b.botID, true).First(&User{}).Error
if err == nil {
isOwner = true
}
// Get the chat memory which now contains the user's message
chatMemory := b.getOrCreateChatMemory(chatID)
contextMessages := b.prepareContextMessages(chatMemory)
if isNewChatFlag {
// Get response from Anthropic using the context messages
response, err := b.getAnthropicResponse(ctx, contextMessages, true, isOwner, false, username, firstName, lastName, isPremium, languageCode, messageTime)
if err != nil {
ErrorLogger.Printf("Error getting Anthropic response: %v", err)
// Use the same error message as in the non-new chat case
response = "I'm sorry, I'm having trouble processing your request right now."
}
// Send the AI-generated response or error message
if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending response: %v", err)
return
}
} else {
user, err := b.getOrCreateUser(userID, username, isOwner)
if err != nil {
ErrorLogger.Printf("Error getting or creating user: %v", err)
return
}
// Update the username if it's empty or has changed
if user.Username != username {
user.Username = username
if err := b.db.Save(&user).Error; err != nil {
ErrorLogger.Printf("Error updating user username: %v", err)
}
}
// Check if the message is a command // Check if the message is a command
if message.Entities != nil { if message.Entities != nil {
for _, entity := range message.Entities { for _, entity := range message.Entities {
@@ -48,11 +98,82 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
command := strings.TrimSpace(message.Text[entity.Offset : entity.Offset+entity.Length]) command := strings.TrimSpace(message.Text[entity.Offset : entity.Offset+entity.Length])
switch command { switch command {
case "/stats": case "/stats":
b.sendStats(ctx, chatID, businessConnectionID) // Parse command parameters
parts := strings.Fields(message.Text)
// Default: show global stats
if len(parts) == 1 {
b.sendStats(ctx, chatID, userID, 0, businessConnectionID)
return
}
// Check for "user" parameter
if len(parts) >= 2 && parts[1] == "user" {
targetUserID := userID // Default to current user
// If a user ID is provided, parse it
if len(parts) >= 3 {
var parseErr error
targetUserID, parseErr = strconv.ParseInt(parts[2], 10, 64)
if parseErr != nil {
InfoLogger.Printf("User %d provided invalid user ID format: %s", userID, parts[2])
if err := b.sendResponse(ctx, chatID, "Invalid user ID format. Usage: /stats user [user_id]", businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending response: %v", err)
}
return
}
}
b.sendStats(ctx, chatID, userID, targetUserID, businessConnectionID)
return
}
// Invalid parameter
if err := b.sendResponse(ctx, chatID, "Invalid command format. Usage: /stats or /stats user [user_id]", businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending response: %v", err)
}
return return
case "/whoami": case "/whoami":
b.sendWhoAmI(ctx, chatID, userID, username, businessConnectionID) b.sendWhoAmI(ctx, chatID, userID, username, businessConnectionID)
return return
case "/clear":
// Extract optional user ID parameter
parts := strings.Fields(message.Text)
var targetUserID int64 = 0
if len(parts) > 1 {
// Parse the user ID
var parseErr error
targetUserID, parseErr = strconv.ParseInt(parts[1], 10, 64)
if parseErr != nil {
// Invalid user ID format
InfoLogger.Printf("User %d provided invalid user ID format: %s", userID, parts[1])
if err := b.sendResponse(ctx, chatID, "Invalid user ID format. Usage: /clear [user_id]", businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending response: %v", err)
}
return
}
}
b.clearChatHistory(ctx, chatID, userID, targetUserID, businessConnectionID, false)
return
case "/clear_hard":
// Extract optional user ID parameter
parts := strings.Fields(message.Text)
var targetUserID int64 = 0
if len(parts) > 1 {
// Parse the user ID
var parseErr error
targetUserID, parseErr = strconv.ParseInt(parts[1], 10, 64)
if parseErr != nil {
// Invalid user ID format
InfoLogger.Printf("User %d provided invalid user ID format: %s", userID, parts[1])
if err := b.sendResponse(ctx, chatID, "Invalid user ID format. Usage: /clear_hard [user_id]", businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending response: %v", err)
}
return
}
}
b.clearChatHistory(ctx, chatID, userID, targetUserID, businessConnectionID, true)
return
} }
} }
} }
@@ -60,7 +181,7 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
// Check if the message contains a sticker // Check if the message contains a sticker
if message.Sticker != nil { if message.Sticker != nil {
b.handleStickerMessage(ctx, chatID, userID, message, businessConnectionID) b.handleStickerMessage(ctx, chatID, message, businessConnectionID)
return return
} }
@@ -76,48 +197,27 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
return return
} }
// Determine if the user is the owner
var isOwner bool
err = b.db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", userID, b.botID, true).First(&User{}).Error
if err == nil {
isOwner = true
}
user, err := b.getOrCreateUser(userID, username, isOwner)
if err != nil {
ErrorLogger.Printf("Error getting or creating user: %v", err)
return
}
// Update the username if it's empty or has changed
if user.Username != username {
user.Username = username
if err := b.db.Save(&user).Error; err != nil {
ErrorLogger.Printf("Error updating user username: %v", err)
}
}
// Determine if the text contains only emojis // Determine if the text contains only emojis
isEmojiOnly := isOnlyEmojis(text) isEmojiOnly := isOnlyEmojis(text)
// Prepare context messages for Anthropic // Prepare context messages for Anthropic
chatMemory := b.getOrCreateChatMemory(chatID) chatMemory := b.getOrCreateChatMemory(chatID)
b.addMessageToChatMemory(chatMemory, b.createMessage(chatID, userID, username, user.Role.Name, text, true))
contextMessages := b.prepareContextMessages(chatMemory) contextMessages := b.prepareContextMessages(chatMemory)
// Get response from Anthropic // Get response from Anthropic
response, err := b.getAnthropicResponse(ctx, contextMessages, b.isNewChat(chatID), isOwner, isEmojiOnly) response, err := b.getAnthropicResponse(ctx, contextMessages, false, isOwner, isEmojiOnly, username, firstName, lastName, isPremium, languageCode, messageTime) // isNewChat is false here
if err != nil { if err != nil {
ErrorLogger.Printf("Error getting Anthropic response: %v", err) ErrorLogger.Printf("Error getting Anthropic response: %v", err)
response = "I'm sorry, I'm having trouble processing your request right now." response = "I'm sorry, I'm having trouble processing your request right now."
} }
// Send the response through the centralized screen // Send the response
if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil { if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending response: %v", err) ErrorLogger.Printf("Error sending response: %v", err)
return return
} }
} }
}
func (b *Bot) sendRateLimitExceededMessage(ctx context.Context, chatID int64, businessConnectionID string) { func (b *Bot) sendRateLimitExceededMessage(ctx context.Context, chatID int64, businessConnectionID string) {
if err := b.sendResponse(ctx, chatID, "Rate limit exceeded. Please try again later.", businessConnectionID); err != nil { if err := b.sendResponse(ctx, chatID, "Rate limit exceeded. Please try again later.", businessConnectionID); err != nil {
@@ -125,22 +225,14 @@ func (b *Bot) sendRateLimitExceededMessage(ctx context.Context, chatID int64, bu
} }
} }
func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, message *models.Message, businessConnectionID string) { func (b *Bot) handleStickerMessage(ctx context.Context, chatID int64, message *models.Message, businessConnectionID string) {
username := message.From.Username // Process sticker through centralized screening
userMessage, err := b.screenIncomingMessage(message)
// Create the user message (without storing it manually) if err != nil {
userMessage := b.createMessage(chatID, userID, username, "user", "Sent a sticker.", true) ErrorLogger.Printf("Error processing sticker message: %v", err)
userMessage.StickerFileID = message.Sticker.FileID return
// Safely store the Thumbnail's FileID if available
if message.Sticker.Thumbnail != nil {
userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID
} }
// Update chat memory with the user message
chatMemory := b.getOrCreateChatMemory(chatID)
b.addMessageToChatMemory(chatMemory, userMessage)
// Generate AI response about the sticker // Generate AI response about the sticker
response, err := b.generateStickerResponse(ctx, userMessage) response, err := b.generateStickerResponse(ctx, userMessage)
if err != nil { if err != nil {
@@ -155,7 +247,7 @@ func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, me
} }
} }
// Send the response through the centralized screen // Send the response
if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil { if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending response: %v", err) ErrorLogger.Printf("Error sending response: %v", err)
return return
@@ -165,18 +257,28 @@ func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, me
func (b *Bot) generateStickerResponse(ctx context.Context, message Message) (string, error) { func (b *Bot) generateStickerResponse(ctx context.Context, message Message) (string, error) {
// Example: Use the sticker type to generate a response // Example: Use the sticker type to generate a response
if message.StickerFileID != "" { if message.StickerFileID != "" {
// Create message content with emoji information if available
var messageContent string
if message.StickerEmoji != "" {
messageContent = fmt.Sprintf("User sent a sticker: %s", message.StickerEmoji)
} else {
messageContent = "User sent a sticker."
}
// Prepare context with information about the sticker // Prepare context with information about the sticker
contextMessages := []anthropic.Message{ contextMessages := []anthropic.Message{
{ {
Role: anthropic.RoleUser, Role: anthropic.RoleUser,
Content: []anthropic.MessageContent{ Content: []anthropic.MessageContent{
anthropic.NewTextMessageContent("User sent a sticker."), anthropic.NewTextMessageContent(messageContent),
}, },
}, },
} }
// Since this is a sticker message, isEmojiOnly is false // Treat sticker messages like emoji messages to get emoji responses
response, err := b.getAnthropicResponse(ctx, contextMessages, false, false, false) // Convert the timestamp to Unix time for the messageTime parameter
messageTime := int(message.Timestamp.Unix())
response, err := b.getAnthropicResponse(ctx, contextMessages, false, false, true, message.Username, "", "", false, "", messageTime)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -186,3 +288,93 @@ func (b *Bot) generateStickerResponse(ctx context.Context, message Message) (str
return "Hmm, that's interesting!", nil return "Hmm, that's interesting!", nil
} }
func (b *Bot) clearChatHistory(ctx context.Context, chatID int64, currentUserID int64, targetUserID int64, businessConnectionID string, hardDelete bool) {
// If targetUserID is provided and different from currentUserID, check permissions
if targetUserID != 0 && targetUserID != currentUserID {
// Check if the current user is an admin or owner
if !b.isAdminOrOwner(currentUserID) {
InfoLogger.Printf("User %d attempted to clear history for user %d without permission", currentUserID, targetUserID)
if err := b.sendResponse(ctx, chatID, "Permission denied. Only admins and owners can clear other users' histories.", businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending response: %v", err)
}
return
}
// Check if the target user exists
var targetUser User
err := b.db.Where("telegram_id = ? AND bot_id = ?", targetUserID, b.botID).First(&targetUser).Error
if err != nil {
ErrorLogger.Printf("Error finding target user %d: %v", targetUserID, err)
if err := b.sendResponse(ctx, chatID, fmt.Sprintf("User with ID %d not found.", targetUserID), businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending response: %v", err)
}
return
}
} else {
// If no targetUserID is provided, set it to currentUserID
targetUserID = currentUserID
}
// Delete messages from the database
var err error
if hardDelete {
// Permanently delete messages
if targetUserID == currentUserID {
// Deleting own messages
err = b.db.Unscoped().Where("chat_id = ? AND bot_id = ? AND user_id = ?", chatID, b.botID, targetUserID).Delete(&Message{}).Error
InfoLogger.Printf("User %d permanently deleted their own chat history in chat %d", currentUserID, chatID)
} else {
// Deleting another user's messages
err = b.db.Unscoped().Where("chat_id = ? AND bot_id = ? AND user_id = ?", chatID, b.botID, targetUserID).Delete(&Message{}).Error
InfoLogger.Printf("Admin/owner %d permanently deleted chat history for user %d in chat %d", currentUserID, targetUserID, chatID)
}
} else {
// Soft delete messages
if targetUserID == currentUserID {
// Deleting own messages
err = b.db.Where("chat_id = ? AND bot_id = ? AND user_id = ?", chatID, b.botID, targetUserID).Delete(&Message{}).Error
InfoLogger.Printf("User %d soft deleted their own chat history in chat %d", currentUserID, chatID)
} else {
// Deleting another user's messages
err = b.db.Where("chat_id = ? AND bot_id = ? AND user_id = ?", chatID, b.botID, targetUserID).Delete(&Message{}).Error
InfoLogger.Printf("Admin/owner %d soft deleted chat history for user %d in chat %d", currentUserID, targetUserID, chatID)
}
}
if err != nil {
ErrorLogger.Printf("Error clearing chat history: %v", err)
if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't clear the chat history.", businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending response: %v", err)
}
return
}
// Reset the chat memory if clearing own history
if targetUserID == currentUserID {
chatMemory := b.getOrCreateChatMemory(chatID)
chatMemory.Messages = []Message{} // Clear the messages
b.chatMemoriesMu.Lock()
b.chatMemories[chatID] = chatMemory
b.chatMemoriesMu.Unlock()
}
// Send a confirmation message
var confirmationMessage string
if targetUserID == currentUserID {
confirmationMessage = "Your chat history has been cleared."
} else {
// Get the username of the target user if available
var targetUser User
err := b.db.Where("telegram_id = ? AND bot_id = ?", targetUserID, b.botID).First(&targetUser).Error
if err == nil && targetUser.Username != "" {
confirmationMessage = fmt.Sprintf("Chat history for user @%s (ID: %d) has been cleared.", targetUser.Username, targetUserID)
} else {
confirmationMessage = fmt.Sprintf("Chat history for user with ID %d has been cleared.", targetUserID)
}
}
if err := b.sendResponse(ctx, chatID, confirmationMessage, businessConnectionID); err != nil {
ErrorLogger.Printf("Error sending response: %v", err)
}
}

609
handlers_test.go Normal file
View File

@@ -0,0 +1,609 @@
package main
import (
"context"
"testing"
"time"
"github.com/go-telegram/bot"
"github.com/go-telegram/bot/models"
"github.com/stretchr/testify/assert"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func TestHandleUpdate_NewChat(t *testing.T) {
// Setup
db := setupTestDB(t)
mockClock := &MockClock{
currentTime: time.Now(),
}
config := BotConfig{
ID: "test_bot",
OwnerTelegramID: 123, // owner's ID
TelegramToken: "test_token",
MemorySize: 10,
MessagePerHour: 5,
MessagePerDay: 10,
TempBanDuration: "1h",
SystemPrompts: make(map[string]string),
Active: true,
}
mockTgClient := &MockTelegramClient{}
// Create bot model first
botModel := &BotModel{
Identifier: config.ID,
Name: config.ID,
}
err := db.Create(botModel).Error
assert.NoError(t, err)
// Create bot config
configModel := &ConfigModel{
BotID: botModel.ID,
MemorySize: config.MemorySize,
MessagePerHour: config.MessagePerHour,
MessagePerDay: config.MessagePerDay,
TempBanDuration: config.TempBanDuration,
SystemPrompts: "{}",
TelegramToken: config.TelegramToken,
Active: config.Active,
}
err = db.Create(configModel).Error
assert.NoError(t, err)
// Create bot instance
b, err := NewBot(db, config, mockClock, mockTgClient)
assert.NoError(t, err)
testCases := []struct {
name string
userID int64
isOwner bool
wantResp string
}{
{
name: "Owner First Message",
userID: 123, // owner's ID
isOwner: true,
wantResp: "I'm sorry, I'm having trouble processing your request right now.",
},
{
name: "Regular User First Message",
userID: 456,
isOwner: false,
wantResp: "I'm sorry, I'm having trouble processing your request right now.",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Setup mock response expectations for error case to test fallback messages
mockTgClient.SendMessageFunc = func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
assert.Equal(t, tc.userID, params.ChatID)
assert.Equal(t, tc.wantResp, params.Text)
return &models.Message{}, nil
}
// Create update with new message
update := &models.Update{
Message: &models.Message{
Chat: models.Chat{ID: tc.userID},
From: &models.User{
ID: tc.userID,
Username: "testuser",
},
Text: "Hello",
},
}
// Handle the update
b.handleUpdate(context.Background(), nil, update)
// Verify message was stored
var storedMsg Message
err := db.Where("chat_id = ? AND user_id = ? AND text = ?", tc.userID, tc.userID, "Hello").First(&storedMsg).Error
assert.NoError(t, err)
// Verify response was stored
var respMsg Message
err = db.Where("chat_id = ? AND is_user = ? AND text = ?", tc.userID, false, tc.wantResp).First(&respMsg).Error
assert.NoError(t, err)
})
}
}
func TestClearChatHistory(t *testing.T) {
// Setup
db := setupTestDB(t)
mockClock := &MockClock{
currentTime: time.Now(),
}
config := BotConfig{
ID: "test_bot",
OwnerTelegramID: 123, // owner's ID
TelegramToken: "test_token",
MemorySize: 10,
MessagePerHour: 5,
MessagePerDay: 10,
TempBanDuration: "1h",
SystemPrompts: make(map[string]string),
Active: true,
}
mockTgClient := &MockTelegramClient{}
// Create bot model first
botModel := &BotModel{
Identifier: config.ID,
Name: config.ID,
}
err := db.Create(botModel).Error
assert.NoError(t, err)
// Create bot config
configModel := &ConfigModel{
BotID: botModel.ID,
MemorySize: config.MemorySize,
MessagePerHour: config.MessagePerHour,
MessagePerDay: config.MessagePerDay,
TempBanDuration: config.TempBanDuration,
SystemPrompts: "{}",
TelegramToken: config.TelegramToken,
Active: config.Active,
}
err = db.Create(configModel).Error
assert.NoError(t, err)
// Create bot instance
b, err := NewBot(db, config, mockClock, mockTgClient)
assert.NoError(t, err)
// Create test users
ownerID := int64(123)
adminID := int64(456)
regularUserID := int64(789)
nonExistentUserID := int64(999)
chatID := int64(1000)
// Create admin role
adminRole, err := b.getRoleByName("admin")
assert.NoError(t, err)
// Create admin user
adminUser := User{
BotID: b.botID,
TelegramID: adminID,
Username: "admin",
RoleID: adminRole.ID,
Role: adminRole,
IsOwner: false,
}
err = db.Create(&adminUser).Error
assert.NoError(t, err)
// Create regular user
regularRole, err := b.getRoleByName("user")
assert.NoError(t, err)
regularUser := User{
BotID: b.botID,
TelegramID: regularUserID,
Username: "regular",
RoleID: regularRole.ID,
Role: regularRole,
IsOwner: false,
}
err = db.Create(&regularUser).Error
assert.NoError(t, err)
// Create test messages for each user
for _, userID := range []int64{ownerID, adminID, regularUserID} {
for i := 0; i < 5; i++ {
message := Message{
BotID: b.botID,
ChatID: chatID,
UserID: userID,
Username: "test",
UserRole: "user",
Text: "Test message",
Timestamp: time.Now(),
IsUser: true,
}
err = db.Create(&message).Error
assert.NoError(t, err)
}
}
// Test cases
testCases := []struct {
name string
currentUserID int64
targetUserID int64
hardDelete bool
expectedError bool
expectedCount int64
expectedMsg string
businessConnID string
}{
{
name: "Owner clears own history",
currentUserID: ownerID,
targetUserID: ownerID,
hardDelete: false,
expectedError: false,
expectedCount: 0,
expectedMsg: "Your chat history has been cleared.",
},
{
name: "Admin clears own history",
currentUserID: adminID,
targetUserID: adminID,
hardDelete: false,
expectedError: false,
expectedCount: 0,
expectedMsg: "Your chat history has been cleared.",
},
{
name: "Regular user clears own history",
currentUserID: regularUserID,
targetUserID: regularUserID,
hardDelete: false,
expectedError: false,
expectedCount: 0,
expectedMsg: "Your chat history has been cleared.",
},
{
name: "Owner clears admin's history",
currentUserID: ownerID,
targetUserID: adminID,
hardDelete: false,
expectedError: false,
expectedCount: 0,
expectedMsg: "Chat history for user @admin (ID: 456) has been cleared.",
},
{
name: "Admin clears regular user's history",
currentUserID: adminID,
targetUserID: regularUserID,
hardDelete: false,
expectedError: false,
expectedCount: 0,
expectedMsg: "Chat history for user @regular (ID: 789) has been cleared.",
},
{
name: "Regular user attempts to clear admin's history",
currentUserID: regularUserID,
targetUserID: adminID,
hardDelete: false,
expectedError: true,
expectedCount: 5, // Messages should remain
expectedMsg: "Permission denied. Only admins and owners can clear other users' histories.",
},
{
name: "Admin attempts to clear non-existent user's history",
currentUserID: adminID,
targetUserID: nonExistentUserID,
hardDelete: false,
expectedError: true,
expectedCount: 5, // Messages should remain for admin
expectedMsg: "User with ID 999 not found.",
},
{
name: "Owner hard deletes regular user's history",
currentUserID: ownerID,
targetUserID: regularUserID,
hardDelete: true,
expectedError: false,
expectedCount: 0,
expectedMsg: "Chat history for user @regular (ID: 789) has been cleared.",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Reset messages for the test case
if tc.name != "Owner hard deletes regular user's history" {
// Delete all messages for the target user
err = db.Where("user_id = ?", tc.targetUserID).Delete(&Message{}).Error
assert.NoError(t, err)
// Recreate messages for the target user
for i := 0; i < 5; i++ {
message := Message{
BotID: b.botID,
ChatID: chatID,
UserID: tc.targetUserID,
Username: "test",
UserRole: "user",
Text: "Test message",
Timestamp: time.Now(),
IsUser: true,
}
err = db.Create(&message).Error
assert.NoError(t, err)
}
}
// Setup mock response expectations
var sentMessage string
mockTgClient.SendMessageFunc = func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
sentMessage = params.Text
return &models.Message{}, nil
}
// Call the clearChatHistory method
b.clearChatHistory(context.Background(), chatID, tc.currentUserID, tc.targetUserID, tc.businessConnID, tc.hardDelete)
// Verify the response message
assert.Equal(t, tc.expectedMsg, sentMessage)
// Count remaining messages for the target user
var count int64
if tc.hardDelete {
db.Unscoped().Model(&Message{}).Where("user_id = ? AND chat_id = ?", tc.targetUserID, chatID).Count(&count)
} else {
db.Model(&Message{}).Where("user_id = ? AND chat_id = ?", tc.targetUserID, chatID).Count(&count)
}
assert.Equal(t, tc.expectedCount, count)
})
}
}
func TestStatsCommand(t *testing.T) {
// Setup
db := setupTestDB(t)
mockClock := &MockClock{
currentTime: time.Now(),
}
config := BotConfig{
ID: "test_bot",
OwnerTelegramID: 123, // owner's ID
TelegramToken: "test_token",
MemorySize: 10,
MessagePerHour: 5,
MessagePerDay: 10,
TempBanDuration: "1h",
SystemPrompts: make(map[string]string),
Active: true,
}
mockTgClient := &MockTelegramClient{}
// Create bot model first
botModel := &BotModel{
Identifier: config.ID,
Name: config.ID,
}
err := db.Create(botModel).Error
assert.NoError(t, err)
// Create bot config
configModel := &ConfigModel{
BotID: botModel.ID,
MemorySize: config.MemorySize,
MessagePerHour: config.MessagePerHour,
MessagePerDay: config.MessagePerDay,
TempBanDuration: config.TempBanDuration,
SystemPrompts: "{}",
TelegramToken: config.TelegramToken,
Active: config.Active,
}
err = db.Create(configModel).Error
assert.NoError(t, err)
// Create bot instance
b, err := NewBot(db, config, mockClock, mockTgClient)
assert.NoError(t, err)
// Create test users
ownerID := int64(123)
adminID := int64(456)
regularUserID := int64(789)
chatID := int64(1000)
// Create admin role
adminRole, err := b.getRoleByName("admin")
assert.NoError(t, err)
// Create admin user
adminUser := User{
BotID: b.botID,
TelegramID: adminID,
Username: "admin",
RoleID: adminRole.ID,
Role: adminRole,
IsOwner: false,
}
err = db.Create(&adminUser).Error
assert.NoError(t, err)
// Create regular user
regularRole, err := b.getRoleByName("user")
assert.NoError(t, err)
regularUser := User{
BotID: b.botID,
TelegramID: regularUserID,
Username: "regular",
RoleID: regularRole.ID,
Role: regularRole,
IsOwner: false,
}
err = db.Create(&regularUser).Error
assert.NoError(t, err)
// Create test messages for each user
for _, userID := range []int64{ownerID, adminID, regularUserID} {
for i := 0; i < 5; i++ {
// User message
userMessage := Message{
BotID: b.botID,
ChatID: chatID,
UserID: userID,
Username: "test",
UserRole: "user",
Text: "Test message",
Timestamp: time.Now(),
IsUser: true,
}
err = db.Create(&userMessage).Error
assert.NoError(t, err)
// Bot response
botMessage := Message{
BotID: b.botID,
ChatID: chatID,
UserID: 0,
Username: "AI Assistant",
UserRole: "assistant",
Text: "Test response",
Timestamp: time.Now(),
IsUser: false,
}
err = db.Create(&botMessage).Error
assert.NoError(t, err)
}
}
// Test cases
testCases := []struct {
name string
command string
currentUserID int64
expectedError bool
expectedMsg string
businessConnID string
}{
{
name: "Global stats",
command: "/stats",
currentUserID: regularUserID,
expectedError: false,
expectedMsg: "📊 Bot Statistics:",
},
{
name: "User requests own stats",
command: "/stats user",
currentUserID: regularUserID,
expectedError: false,
expectedMsg: "👤 User Statistics for @regular (ID: 789):",
},
{
name: "Admin requests another user's stats",
command: "/stats user 789",
currentUserID: adminID,
expectedError: false,
expectedMsg: "👤 User Statistics for @regular (ID: 789):",
},
{
name: "Owner requests another user's stats",
command: "/stats user 456",
currentUserID: ownerID,
expectedError: false,
expectedMsg: "👤 User Statistics for @admin (ID: 456):",
},
{
name: "Regular user attempts to request another user's stats",
command: "/stats user 456",
currentUserID: regularUserID,
expectedError: true,
expectedMsg: "Permission denied. Only admins and owners can view other users' statistics.",
},
{
name: "User provides invalid user ID format",
command: "/stats user abc",
currentUserID: adminID,
expectedError: true,
expectedMsg: "Invalid user ID format. Usage: /stats user [user_id]",
},
{
name: "User provides invalid command format",
command: "/stats invalid",
currentUserID: adminID,
expectedError: true,
expectedMsg: "Invalid command format. Usage: /stats or /stats user [user_id]",
},
{
name: "User requests non-existent user's stats",
command: "/stats user 999",
currentUserID: adminID,
expectedError: true,
expectedMsg: "Sorry, I couldn't retrieve statistics for user ID 999.",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Setup mock response expectations
var sentMessage string
mockTgClient.SendMessageFunc = func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
sentMessage = params.Text
return &models.Message{}, nil
}
// Create update with command
update := &models.Update{
Message: &models.Message{
Chat: models.Chat{ID: chatID},
From: &models.User{
ID: tc.currentUserID,
Username: getUsernameByID(tc.currentUserID),
},
Text: tc.command,
Entities: []models.MessageEntity{
{
Type: "bot_command",
Offset: 0,
Length: 6, // Length of "/stats"
},
},
},
}
// Handle the update
b.handleUpdate(context.Background(), nil, update)
// Verify the response message contains the expected text
assert.Contains(t, sentMessage, tc.expectedMsg)
})
}
}
// Helper function to get username by ID for test
func getUsernameByID(id int64) string {
switch id {
case 123:
return "owner"
case 456:
return "admin"
case 789:
return "regular"
default:
return "unknown"
}
}
func setupTestDB(t *testing.T) *gorm.DB {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to open test database: %v", err)
}
// AutoMigrate the models
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{})
if err != nil {
t.Fatalf("Failed to migrate database schema: %v", err)
}
// Create default roles
err = createDefaultRoles(db)
if err != nil {
t.Fatalf("Failed to create default roles: %v", err)
}
return db
}

0
logger.go Executable file → Normal file
View File

16
main.go Executable file → Normal file
View File

@@ -47,19 +47,13 @@ func main() {
return return
} }
// Initialize TelegramClient with the bot's handleUpdate method // Start the bot in a separate goroutine
tgClient, err := initTelegramBot(cfg.TelegramToken, bot.handleUpdate) go bot.Start(ctx)
if err != nil {
ErrorLogger.Printf("Error initializing Telegram client for bot %s: %v", cfg.ID, err)
return
}
// Assign the TelegramClient to the bot // Keep the bot running until the context is cancelled
bot.tgBot = tgClient <-ctx.Done()
// Start the bot InfoLogger.Printf("Bot %s stopped", cfg.ID)
InfoLogger.Printf("Starting bot %s...", cfg.ID)
bot.Start(ctx)
}(config) }(config)
} }

21
models.go Executable file → Normal file
View File

@@ -29,16 +29,19 @@ type ConfigModel struct {
type Message struct { type Message struct {
gorm.Model gorm.Model
BotID uint BotID uint `gorm:"index"`
ChatID int64 ChatID int64 `gorm:"index"`
UserID int64 UserID int64 `gorm:"index"`
Username string Username string `gorm:"index"`
UserRole string UserRole string // Store the role as a string
Text string Text string `gorm:"type:text"`
StickerFileID string `json:"sticker_file_id,omitempty"` // New field to store Sticker File ID Timestamp time.Time `gorm:"index"`
StickerPNGFile string `json:"sticker_png_file,omitempty"` // Optionally store PNG file ID if needed
Timestamp time.Time
IsUser bool IsUser bool
StickerFileID string
StickerPNGFile string
StickerEmoji string // Store the emoji associated with the sticker
DeletedAt gorm.DeletedAt `gorm:"index"` // Add soft delete field
AnsweredOn *time.Time `gorm:"index"` // Tracks when a user message was answered (NULL for assistant messages and unanswered user messages)
} }
type ChatMemory struct { type ChatMemory struct {

0
rate_limiter.go Executable file → Normal file
View File

0
rate_limiter_test.go Executable file → Normal file
View File

0
telegram_client.go Executable file → Normal file
View File

17
telegram_client_mock.go Executable file → Normal file
View File

@@ -6,13 +6,14 @@ import (
"github.com/go-telegram/bot" "github.com/go-telegram/bot"
"github.com/go-telegram/bot/models" "github.com/go-telegram/bot/models"
"github.com/stretchr/testify/mock"
) )
// MockTelegramClient is a mock implementation of TelegramClient for testing. // MockTelegramClient is a mock implementation of TelegramClient for testing.
type MockTelegramClient struct { type MockTelegramClient struct {
// You can add fields to keep track of calls if needed. mock.Mock
SendMessageFunc func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) SendMessageFunc func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error)
StartFunc func(ctx context.Context) // Optional: track Start calls StartFunc func(ctx context.Context)
} }
// SendMessage mocks sending a message. // SendMessage mocks sending a message.
@@ -20,16 +21,18 @@ func (m *MockTelegramClient) SendMessage(ctx context.Context, params *bot.SendMe
if m.SendMessageFunc != nil { if m.SendMessageFunc != nil {
return m.SendMessageFunc(ctx, params) return m.SendMessageFunc(ctx, params)
} }
// Default behavior: return an empty message without error. args := m.Called(ctx, params)
return &models.Message{}, nil if msg, ok := args.Get(0).(*models.Message); ok {
return msg, args.Error(1)
}
return nil, args.Error(1)
} }
// Start mocks starting the Telegram client. // Start mocks starting the Telegram client.
func (m *MockTelegramClient) Start(ctx context.Context) { func (m *MockTelegramClient) Start(ctx context.Context) {
if m.StartFunc != nil { if m.StartFunc != nil {
m.StartFunc(ctx) m.StartFunc(ctx)
return
} }
// Default behavior: do nothing. m.Called(ctx)
} }
// Add other mocked methods if your Bot uses more TelegramClient methods.

98
user_management_test.go Executable file → Normal file
View File

@@ -184,6 +184,104 @@ func TestPromoteUserToAdmin(t *testing.T) {
} }
} }
// TestGetOrCreateUser tests the getOrCreateUser method of the Bot.
// It verifies that a new user is created when one does not exist,
// and an existing user is returned when one does exist.
func TestGetOrCreateUser(t *testing.T) {
// Initialize loggers
initLoggers()
// Initialize in-memory database for testing
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to open in-memory database: %v", err)
}
// Migrate the schema
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{})
if err != nil {
t.Fatalf("Failed to migrate database schema: %v", err)
}
// Create default roles
err = createDefaultRoles(db)
if err != nil {
t.Fatalf("Failed to create default roles: %v", err)
}
// Create a mock clock starting at a fixed time
mockClock := &MockClock{
currentTime: time.Date(2023, 10, 1, 0, 0, 0, 0, time.UTC),
}
// Create a mock configuration
config := BotConfig{
ID: "bot1",
MemorySize: 10,
MessagePerHour: 5,
MessagePerDay: 10,
TempBanDuration: "1m",
SystemPrompts: make(map[string]string),
TelegramToken: "YOUR_TELEGRAM_BOT_TOKEN",
OwnerTelegramID: 123456789,
}
// Initialize MockTelegramClient
mockTGClient := &MockTelegramClient{
SendMessageFunc: func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
chatID, ok := params.ChatID.(int64)
if !ok {
return nil, fmt.Errorf("ChatID is not of type int64")
}
// Simulate successful message sending
return &models.Message{ID: 1, Chat: models.Chat{ID: chatID}}, nil
},
}
// Create the bot with the mock Telegram client
bot, err := NewBot(db, config, mockClock, mockTGClient)
if err != nil {
t.Fatalf("Failed to create bot: %v", err)
}
// Verify that the owner exists
var owner User
err = db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", config.OwnerTelegramID, bot.botID, true).First(&owner).Error
if err != nil {
t.Fatalf("Owner was not created: %v", err)
}
// Attempt to create another owner for the same bot
_, err = bot.getOrCreateUser(222222222, "AnotherOwner", true)
if err == nil {
t.Fatalf("Expected error when creating a second owner, but got none")
}
// Create a new user
newUser, err := bot.getOrCreateUser(987654321, "TestUser", false)
if err != nil {
t.Fatalf("Failed to create a new user: %v", err)
}
// Verify that the new user was created
var userInDB User
err = db.Where("telegram_id = ?", newUser.TelegramID).First(&userInDB).Error
if err != nil {
t.Fatalf("New user was not created in the database: %v", err)
}
// Get the existing user
existingUser, err := bot.getOrCreateUser(987654321, "TestUser", false)
if err != nil {
t.Fatalf("Failed to get existing user: %v", err)
}
// Verify that the existing user is the same as the new user
if existingUser.ID != userInDB.ID {
t.Fatalf("Expected to get the existing user, but got a different user")
}
}
// To ensure thread safety and avoid race conditions during testing, // To ensure thread safety and avoid race conditions during testing,
// you can run the tests with the `-race` flag: // you can run the tests with the `-race` flag:
// go test -race -v // go test -race -v