mirror of
https://github.com/HugeFrog24/go-telegram-bot.git
synced 2026-04-30 23:32:19 +00:00
Style and security
This commit is contained in:
@@ -0,0 +1,14 @@
|
|||||||
|
---
|
||||||
|
alwaysApply: true
|
||||||
|
always_on: true
|
||||||
|
trigger: always_on
|
||||||
|
applyTo: "**"
|
||||||
|
description: Snyk Security At Inception
|
||||||
|
---
|
||||||
|
|
||||||
|
# Project security best practices
|
||||||
|
|
||||||
|
- Always run snyk_code_scan tool for new first party code that is generated in a Snyk-supported language.
|
||||||
|
- If any security issues are found based on newly introduced or modified code or dependencies, attempt to fix the issues using the results context from Snyk.
|
||||||
|
- Rescan the code after fixing the issues to ensure that the issues were fixed and that there are no newly introduced issues.
|
||||||
|
- Repeat this process until no new issues are found.
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
# Git files
|
||||||
|
.git
|
||||||
|
.gitignore
|
||||||
|
.gitattributes
|
||||||
|
|
||||||
|
# Documentation
|
||||||
|
README.md
|
||||||
|
*.md
|
||||||
|
|
||||||
|
# Docker files
|
||||||
|
Dockerfile
|
||||||
|
docker-compose.yml
|
||||||
|
.dockerignore
|
||||||
|
|
||||||
|
# Environment files
|
||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
|
||||||
|
# Log files
|
||||||
|
*.log
|
||||||
|
logs/
|
||||||
|
|
||||||
|
# Database files
|
||||||
|
*.db
|
||||||
|
*.sqlite
|
||||||
|
*.sqlite3
|
||||||
|
bot.db
|
||||||
|
|
||||||
|
# Config files (except default template)
|
||||||
|
config/*
|
||||||
|
!config/default.json
|
||||||
|
|
||||||
|
# Test files
|
||||||
|
*_test.go
|
||||||
|
test/
|
||||||
|
tests/
|
||||||
|
|
||||||
|
# Build artifacts
|
||||||
|
telegram-bot
|
||||||
|
*.exe
|
||||||
|
*.dll
|
||||||
|
*.so
|
||||||
|
*.dylib
|
||||||
|
|
||||||
|
# IDE files
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# OS files
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# Go specific
|
||||||
|
vendor/
|
||||||
|
*.mod.backup
|
||||||
|
*.sum.backup
|
||||||
|
|
||||||
|
# Temporary files
|
||||||
|
tmp/
|
||||||
|
temp/
|
||||||
|
*.tmp
|
||||||
|
|
||||||
|
# Coverage files
|
||||||
|
*.out
|
||||||
|
coverage.html
|
||||||
|
|
||||||
|
# CI/CD files
|
||||||
|
.github/
|
||||||
|
.gitlab-ci.yml
|
||||||
|
.travis.yml
|
||||||
|
|
||||||
|
# Examples and documentation
|
||||||
|
examples/
|
||||||
|
docs/
|
||||||
Executable → Regular
Executable → Regular
+32
-28
@@ -7,23 +7,15 @@ on:
|
|||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
# Common setup job that other jobs can depend on
|
||||||
|
setup:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
# Checkout the repository
|
- uses: actions/checkout@v4
|
||||||
- name: Checkout code
|
- uses: actions/setup-go@v5
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
# Set up Go environment
|
|
||||||
- name: Set up Go
|
|
||||||
uses: actions/setup-go@v5
|
|
||||||
with:
|
with:
|
||||||
go-version: '1.23' # Specify the Go version you are using
|
go-version: '1.26.0'
|
||||||
|
- uses: actions/cache@v4
|
||||||
# Cache Go modules
|
|
||||||
- name: Cache Go modules
|
|
||||||
uses: actions/cache@v4
|
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/.cache/go-build
|
~/.cache/go-build
|
||||||
@@ -31,24 +23,36 @@ jobs:
|
|||||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-go-
|
${{ runner.os }}-go-
|
||||||
|
- run: go mod tidy
|
||||||
|
|
||||||
# Install Dependencies
|
# Lint job
|
||||||
- name: Install Dependencies
|
lint:
|
||||||
run: go mod tidy
|
needs: setup
|
||||||
|
runs-on: ubuntu-latest
|
||||||
# Run Linters using golangci-lint
|
steps:
|
||||||
- name: Lint Code
|
- uses: actions/checkout@v6
|
||||||
uses: golangci/golangci-lint-action@v6
|
- uses: golangci/golangci-lint-action@v9
|
||||||
with:
|
with:
|
||||||
version: v1.60 # Specify the version of golangci-lint
|
version: v2.10
|
||||||
args: --timeout 5m
|
args: --timeout 5m
|
||||||
|
|
||||||
# Run Tests
|
# Test job
|
||||||
- name: Run Tests
|
test:
|
||||||
run: go test ./... -v
|
needs: setup
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
- uses: actions/setup-go@v6
|
||||||
|
with:
|
||||||
|
go-version: '1.26.0'
|
||||||
|
- run: go test ./... -v
|
||||||
|
|
||||||
# Security Analysis using gosec
|
# Security scan job
|
||||||
- name: Security Scan
|
security:
|
||||||
uses: securego/gosec@master
|
needs: setup
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
- uses: securego/gosec@master
|
||||||
with:
|
with:
|
||||||
args: ./...
|
args: ./...
|
||||||
|
|||||||
Executable → Regular
+3
@@ -1,3 +1,6 @@
|
|||||||
|
# Local IDE config & user settings
|
||||||
|
.vscode/
|
||||||
|
|
||||||
# Go vendor directory
|
# Go vendor directory
|
||||||
vendor/
|
vendor/
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"mcpServers": {}
|
||||||
|
}
|
||||||
+55
@@ -0,0 +1,55 @@
|
|||||||
|
# Multi-stage build for Go Telegram Bot
|
||||||
|
# Build stage
|
||||||
|
FROM golang:1.26-alpine AS builder
|
||||||
|
|
||||||
|
# Install build dependencies including C compiler for CGO
|
||||||
|
RUN apk add --no-cache git ca-certificates tzdata gcc musl-dev
|
||||||
|
|
||||||
|
# Set working directory
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
# Copy go mod files first for better caching
|
||||||
|
COPY go.mod go.sum ./
|
||||||
|
|
||||||
|
# Download dependencies
|
||||||
|
RUN go mod download
|
||||||
|
|
||||||
|
# Copy source code
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# Build the application
|
||||||
|
RUN CGO_ENABLED=1 GOOS=linux go build -a -installsuffix cgo -o telegram-bot .
|
||||||
|
|
||||||
|
# Runtime stage
|
||||||
|
FROM alpine:latest
|
||||||
|
|
||||||
|
# Merged into a single RUN to minimise image layers (docker:S7031).
|
||||||
|
# Order matters: packages must be installed before adduser/addgroup,
|
||||||
|
# and the app directory must exist before chown runs.
|
||||||
|
RUN apk --no-cache add ca-certificates tzdata sqlite && \
|
||||||
|
addgroup -g 1001 -S appgroup && \
|
||||||
|
adduser -u 1001 -S appuser -G appgroup && \
|
||||||
|
mkdir -p /app/config /app/data /app/logs && \
|
||||||
|
chown -R appuser:appgroup /app
|
||||||
|
|
||||||
|
# Set working directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy binary from builder stage
|
||||||
|
COPY --from=builder /build/telegram-bot /app/telegram-bot
|
||||||
|
|
||||||
|
# Copy default config as template
|
||||||
|
COPY --chown=appuser:appgroup config/default.json /app/config/
|
||||||
|
|
||||||
|
# Switch to non-root user
|
||||||
|
USER appuser
|
||||||
|
|
||||||
|
# Expose any ports if needed (not required for this bot)
|
||||||
|
# EXPOSE 8080
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||||
|
CMD pgrep telegram-bot || exit 1
|
||||||
|
|
||||||
|
# Run the application
|
||||||
|
CMD ["/app/telegram-bot"]
|
||||||
@@ -3,6 +3,7 @@
|
|||||||
A scalable, multi-bot solution for Telegram using Go, GORM, and the Anthropic API.
|
A scalable, multi-bot solution for Telegram using Go, GORM, and the Anthropic API.
|
||||||
|
|
||||||
## Design Considerations
|
## Design Considerations
|
||||||
|
|
||||||
- AI-powered
|
- AI-powered
|
||||||
- Supports multiple bot profiles
|
- Supports multiple bot profiles
|
||||||
- Uses SQLite for persistence
|
- Uses SQLite for persistence
|
||||||
@@ -12,41 +13,42 @@ A scalable, multi-bot solution for Telegram using Go, GORM, and the Anthropic AP
|
|||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
1. Clone the repository or install using `go get`:
|
### Docker Deployment (Recommended)
|
||||||
- Option 1: Clone the repository
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/HugeFrog24/go-telegram-bot.git
|
|
||||||
```
|
|
||||||
|
|
||||||
- Option 2: Install using go get
|
|
||||||
```bash
|
|
||||||
go get -u github.com/HugeFrog24/go-telegram-bot
|
|
||||||
```
|
|
||||||
|
|
||||||
- Navigate to the project directory:
|
1. Clone the repository:
|
||||||
```bash
|
|
||||||
cd go-telegram-bot
|
```bash
|
||||||
```
|
git clone https://github.com/HugeFrog24/go-telegram-bot.git
|
||||||
|
cd go-telegram-bot
|
||||||
|
```
|
||||||
|
|
||||||
2. Copy the default config template and edit it:
|
2. Copy the default config template and edit it:
|
||||||
```bash
|
```bash
|
||||||
cp config/default.json config/config-mybot.json
|
cp config/default.json config/mybot.json
|
||||||
|
nano config/mybot.json
|
||||||
```
|
```
|
||||||
|
|
||||||
Replace `config-mybot.json` with the name of your bot.
|
> [!IMPORTANT]
|
||||||
|
|
||||||
```bash
|
|
||||||
nano config/config-mybot.json
|
|
||||||
```
|
|
||||||
|
|
||||||
You can set up as many bots as you want. Just copy the template and edit the parameters.
|
|
||||||
|
|
||||||
> [!IMPORTANT]
|
|
||||||
> Keep your config files secret and do not commit them to version control.
|
> Keep your config files secret and do not commit them to version control.
|
||||||
|
|
||||||
3. Build the application:
|
3. Create data directory and run:
|
||||||
```bash
|
```bash
|
||||||
go build -o telegram-multibot
|
mkdir -p data
|
||||||
|
docker-compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
### Native Deployment
|
||||||
|
|
||||||
|
1. Install using `go get`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go get -u github.com/HugeFrog24/go-telegram-bot
|
||||||
|
cd go-telegram-bot
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Configure as above, then build:
|
||||||
|
```bash
|
||||||
|
go build -o telegram-bot
|
||||||
```
|
```
|
||||||
|
|
||||||
## Systemd Unit Setup
|
## Systemd Unit Setup
|
||||||
@@ -60,30 +62,32 @@ To enable the bot to start automatically on system boot and run in the backgroun
|
|||||||
```
|
```
|
||||||
|
|
||||||
Edit the service file:
|
Edit the service file:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
nano /etc/systemd/system/telegram-bot.service
|
sudo nano /etc/systemd/system/telegram-bot.service
|
||||||
```
|
```
|
||||||
|
|
||||||
Adjust the following parameters:
|
Adjust the following parameters:
|
||||||
- `WorkingDirectory`
|
|
||||||
- `ExecStart`
|
|
||||||
- `User`
|
|
||||||
|
|
||||||
3. Enable and start the service:
|
- WorkingDirectory
|
||||||
|
- ExecStart
|
||||||
|
- User
|
||||||
|
|
||||||
|
2. Enable and start the service:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sudo systemctl daemon-reload
|
sudo systemctl daemon-reload
|
||||||
```
|
```
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sudo systemctl enable telegram-bot.service
|
sudo systemctl enable telegram-bot
|
||||||
```
|
```
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sudo systemctl start telegram-bot.service
|
sudo systemctl start telegram-bot
|
||||||
```
|
```
|
||||||
|
|
||||||
4. Check the status:
|
3. Check the status:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sudo systemctl status telegram-bot
|
sudo systemctl status telegram-bot
|
||||||
@@ -93,18 +97,52 @@ For more details on the systemd setup, refer to the [demo service file](examples
|
|||||||
|
|
||||||
## Logs
|
## Logs
|
||||||
|
|
||||||
View logs using journalctl:
|
### Docker
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
journalctl -u telegram-bot
|
docker-compose logs -f telegram-bot
|
||||||
```
|
```
|
||||||
|
|
||||||
Follow logs:
|
### Systemd
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
journalctl -u telegram-bot -f
|
journalctl -u telegram-bot -f
|
||||||
```
|
```
|
||||||
|
|
||||||
View errors:
|
## Commands
|
||||||
|
|
||||||
|
| Command | Access | Description |
|
||||||
|
| --------------------------------- | ----------- | ------------------------------------------------------------ |
|
||||||
|
| `/stats` | All users | Show global bot statistics (total users and messages) |
|
||||||
|
| `/stats user` | All users | Show your own message statistics |
|
||||||
|
| `/stats user <user_id>` | Admin/Owner | Show statistics for a specific user |
|
||||||
|
| `/whoami` | All users | Show your Telegram ID, username, and role |
|
||||||
|
| `/clear` | All users | Soft-delete your own chat history |
|
||||||
|
| `/clear <user_id>` | Admin/Owner | Soft-delete all messages for a user across every chat |
|
||||||
|
| `/clear <user_id> <chat_id>` | Admin/Owner | Soft-delete a user's messages in a specific chat |
|
||||||
|
| `/clear_hard` | All users | Permanently delete your own chat history |
|
||||||
|
| `/clear_hard <user_id>` | Admin/Owner | Permanently delete all messages for a user across every chat |
|
||||||
|
| `/clear_hard <user_id> <chat_id>` | Admin/Owner | Permanently delete a user's messages in a specific chat |
|
||||||
|
| `/set_model <model-id>` | Admin/Owner | Switch the AI model live without restarting |
|
||||||
|
|
||||||
|
> **Note:** In private DMs each user's `chat_id` equals their `user_id`. The scoped `<chat_id>` form is mainly useful for group chat moderation.
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
The GitHub actions workflow already runs tests on every commit:
|
||||||
|
|
||||||
|
> [](https://github.com/HugeFrog24/go-telegram-bot/actions/workflows/go-ci.yaml)
|
||||||
|
|
||||||
|
However, you can run the tests locally using:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
journalctl -u telegram-bot -p err
|
go test -race -v ./...
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Storage
|
||||||
|
|
||||||
|
At the moment, a SQLite database (`./data/bot.db`) is used for persistent storage.
|
||||||
|
|
||||||
|
Remember to back it up regularly.
|
||||||
|
|
||||||
|
Future versions will support more robust storage backends.
|
||||||
|
|||||||
Executable → Regular
+84
-4
@@ -2,12 +2,20 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/liushuangls/go-anthropic/v2"
|
"github.com/liushuangls/go-anthropic/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Message, isNewChat, isAdminOrOwner, isEmojiOnly bool) (string, error) {
|
// ErrModelNotFound is returned when the configured Anthropic model is no longer available
|
||||||
|
// (deprecated or removed). Callers can use errors.Is to detect this and surface an
|
||||||
|
// actionable message to admins/owners while keeping the response vague for regular users.
|
||||||
|
var ErrModelNotFound = errors.New("model not found or deprecated")
|
||||||
|
|
||||||
|
func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Message, isNewChat, isOwner, isEmojiOnly bool, username string, firstName string, lastName string, isPremium bool, languageCode string, messageTime int) (string, error) {
|
||||||
// Use prompts from config
|
// Use prompts from config
|
||||||
var systemMessage string
|
var systemMessage string
|
||||||
if isNewChat {
|
if isNewChat {
|
||||||
@@ -19,7 +27,57 @@ func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Mes
|
|||||||
// Combine default prompt with custom instructions
|
// Combine default prompt with custom instructions
|
||||||
systemMessage = b.config.SystemPrompts["default"] + " " + b.config.SystemPrompts["custom_instructions"] + " " + systemMessage
|
systemMessage = b.config.SystemPrompts["default"] + " " + b.config.SystemPrompts["custom_instructions"] + " " + systemMessage
|
||||||
|
|
||||||
if !isAdminOrOwner {
|
// Handle username placeholder
|
||||||
|
usernameValue := username
|
||||||
|
if username == "" {
|
||||||
|
usernameValue = "unknown" // Use "unknown" when username is not available
|
||||||
|
}
|
||||||
|
systemMessage = strings.ReplaceAll(systemMessage, "{username}", usernameValue)
|
||||||
|
|
||||||
|
// Handle firstname placeholder
|
||||||
|
firstnameValue := firstName
|
||||||
|
if firstName == "" {
|
||||||
|
firstnameValue = "unknown" // Use "unknown" when first name is not available
|
||||||
|
}
|
||||||
|
systemMessage = strings.ReplaceAll(systemMessage, "{firstname}", firstnameValue)
|
||||||
|
|
||||||
|
// Handle lastname placeholder
|
||||||
|
lastnameValue := lastName
|
||||||
|
if lastName == "" {
|
||||||
|
lastnameValue = "" // Empty string when last name is not available
|
||||||
|
}
|
||||||
|
systemMessage = strings.ReplaceAll(systemMessage, "{lastname}", lastnameValue)
|
||||||
|
|
||||||
|
// Handle language code placeholder
|
||||||
|
langValue := languageCode
|
||||||
|
if languageCode == "" {
|
||||||
|
langValue = "en" // Default to English when language code is not available
|
||||||
|
}
|
||||||
|
systemMessage = strings.ReplaceAll(systemMessage, "{language}", langValue)
|
||||||
|
|
||||||
|
// Handle premium status
|
||||||
|
premiumStatus := "regular user"
|
||||||
|
if isPremium {
|
||||||
|
premiumStatus = "premium user"
|
||||||
|
}
|
||||||
|
systemMessage = strings.ReplaceAll(systemMessage, "{premium_status}", premiumStatus)
|
||||||
|
|
||||||
|
// Handle time awareness
|
||||||
|
timeObj := time.Unix(int64(messageTime), 0)
|
||||||
|
hour := timeObj.Hour()
|
||||||
|
var timeContext string
|
||||||
|
if hour >= 5 && hour < 12 {
|
||||||
|
timeContext = "morning"
|
||||||
|
} else if hour >= 12 && hour < 18 {
|
||||||
|
timeContext = "afternoon"
|
||||||
|
} else if hour >= 18 && hour < 22 {
|
||||||
|
timeContext = "evening"
|
||||||
|
} else {
|
||||||
|
timeContext = "night"
|
||||||
|
}
|
||||||
|
systemMessage = strings.ReplaceAll(systemMessage, "{time_context}", timeContext)
|
||||||
|
|
||||||
|
if !isOwner {
|
||||||
systemMessage += " " + b.config.SystemPrompts["avoid_sensitive"]
|
systemMessage += " " + b.config.SystemPrompts["avoid_sensitive"]
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -27,6 +85,16 @@ func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Mes
|
|||||||
systemMessage += " " + b.config.SystemPrompts["respond_with_emojis"]
|
systemMessage += " " + b.config.SystemPrompts["respond_with_emojis"]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Debug logging
|
||||||
|
InfoLogger.Printf("Sending %d messages to Anthropic", len(messages))
|
||||||
|
for i, msg := range messages {
|
||||||
|
for _, content := range msg.Content {
|
||||||
|
if content.Type == anthropic.MessagesContentTypeText {
|
||||||
|
InfoLogger.Printf("Message %d: Role=%v, Text=%v", i, msg.Role, content.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure the roles are correct
|
// Ensure the roles are correct
|
||||||
for i := range messages {
|
for i := range messages {
|
||||||
switch messages[i].Role {
|
switch messages[i].Role {
|
||||||
@@ -42,13 +110,25 @@ func (b *Bot) getAnthropicResponse(ctx context.Context, messages []anthropic.Mes
|
|||||||
|
|
||||||
model := anthropic.Model(b.config.Model)
|
model := anthropic.Model(b.config.Model)
|
||||||
|
|
||||||
resp, err := b.anthropicClient.CreateMessages(ctx, anthropic.MessagesRequest{
|
// Create the request
|
||||||
|
request := anthropic.MessagesRequest{
|
||||||
Model: model, // Now `model` is of type anthropic.Model
|
Model: model, // Now `model` is of type anthropic.Model
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
System: systemMessage,
|
System: systemMessage,
|
||||||
MaxTokens: 1000,
|
MaxTokens: 1000,
|
||||||
})
|
}
|
||||||
|
|
||||||
|
// Apply temperature if set in config
|
||||||
|
if b.config.Temperature != nil {
|
||||||
|
request.Temperature = b.config.Temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := b.anthropicClient.CreateMessages(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
var apiErr *anthropic.APIError
|
||||||
|
if errors.As(err, &apiErr) && apiErr.IsNotFoundErr() {
|
||||||
|
return "", fmt.Errorf("%w: %s", ErrModelNotFound, b.config.Model)
|
||||||
|
}
|
||||||
return "", fmt.Errorf("error creating Anthropic message: %w", err)
|
return "", fmt.Errorf("error creating Anthropic message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,197 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestLanguageCodeReplacement tests that language code is properly handled and replaced
|
||||||
|
func TestLanguageCodeReplacement(t *testing.T) {
|
||||||
|
// Test with provided language code
|
||||||
|
systemMessage := "User's language preference: '{language}'"
|
||||||
|
|
||||||
|
// Test with a specific language code
|
||||||
|
langValue := "fr"
|
||||||
|
result := strings.ReplaceAll(systemMessage, "{language}", langValue)
|
||||||
|
|
||||||
|
if !strings.Contains(result, "User's language preference: 'fr'") {
|
||||||
|
t.Errorf("Expected language code 'fr' to be replaced, got: %s", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with empty language code (should default to "en")
|
||||||
|
langValue = ""
|
||||||
|
if langValue == "" {
|
||||||
|
langValue = "en" // Default to English when language code is not available
|
||||||
|
}
|
||||||
|
result = strings.ReplaceAll(systemMessage, "{language}", langValue)
|
||||||
|
|
||||||
|
if !strings.Contains(result, "User's language preference: 'en'") {
|
||||||
|
t.Errorf("Expected default language code 'en' to be used, got: %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPremiumStatusReplacement tests that premium status is properly handled and replaced
|
||||||
|
func TestPremiumStatusReplacement(t *testing.T) {
|
||||||
|
systemMessage := "User is a {premium_status}"
|
||||||
|
|
||||||
|
// Test with premium user
|
||||||
|
isPremium := true
|
||||||
|
premiumStatus := "regular user"
|
||||||
|
if isPremium {
|
||||||
|
premiumStatus = "premium user"
|
||||||
|
}
|
||||||
|
result := strings.ReplaceAll(systemMessage, "{premium_status}", premiumStatus)
|
||||||
|
|
||||||
|
if !strings.Contains(result, "User is a premium user") {
|
||||||
|
t.Errorf("Expected premium status to be replaced with 'premium user', got: %s", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with regular user
|
||||||
|
isPremium = false
|
||||||
|
premiumStatus = "regular user"
|
||||||
|
if isPremium {
|
||||||
|
premiumStatus = "premium user"
|
||||||
|
}
|
||||||
|
result = strings.ReplaceAll(systemMessage, "{premium_status}", premiumStatus)
|
||||||
|
|
||||||
|
if !strings.Contains(result, "User is a regular user") {
|
||||||
|
t.Errorf("Expected premium status to be replaced with 'regular user', got: %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTimeContextCalculation tests that time context is correctly calculated for different hours
|
||||||
|
func TestTimeContextCalculation(t *testing.T) {
|
||||||
|
// Test cases for different hours
|
||||||
|
testCases := []struct {
|
||||||
|
hour int
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{3, "night"}, // Night: hours < 5 or hours >= 22
|
||||||
|
{5, "morning"}, // Morning: 5 <= hours < 12
|
||||||
|
{12, "afternoon"}, // Afternoon: 12 <= hours < 18
|
||||||
|
{17, "afternoon"}, // Afternoon: 12 <= hours < 18
|
||||||
|
{18, "evening"}, // Evening: 18 <= hours < 22
|
||||||
|
{21, "evening"}, // Evening: 18 <= hours < 22
|
||||||
|
{22, "night"}, // Night: hours < 5 or hours >= 22
|
||||||
|
{23, "night"}, // Night: hours < 5 or hours >= 22
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(fmt.Sprintf("Hour_%d", tc.hour), func(t *testing.T) {
|
||||||
|
// Create a timestamp for the specified hour
|
||||||
|
testTime := time.Date(2025, 5, 15, tc.hour, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
// Get the hour directly from the test time to ensure it's what we expect
|
||||||
|
actualHour := testTime.Hour()
|
||||||
|
if actualHour != tc.hour {
|
||||||
|
t.Fatalf("Test setup error: expected hour %d, got %d", tc.hour, actualHour)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate time context using the same logic as in anthropic.go
|
||||||
|
var timeContext string
|
||||||
|
if actualHour >= 5 && actualHour < 12 {
|
||||||
|
timeContext = "morning"
|
||||||
|
} else if actualHour >= 12 && actualHour < 18 {
|
||||||
|
timeContext = "afternoon"
|
||||||
|
} else if actualHour >= 18 && actualHour < 22 {
|
||||||
|
timeContext = "evening"
|
||||||
|
} else {
|
||||||
|
timeContext = "night"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the calculated time context matches the expected value
|
||||||
|
if timeContext != tc.expected {
|
||||||
|
t.Errorf("For hour %d: expected time context '%s', got '%s'",
|
||||||
|
actualHour, tc.expected, timeContext)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSystemMessagePlaceholderReplacement tests that all placeholders are correctly replaced
|
||||||
|
func TestSystemMessagePlaceholderReplacement(t *testing.T) {
|
||||||
|
systemMessage := "The user you're talking to has username '{username}' and display name '{firstname} {lastname}'.\n" +
|
||||||
|
"User's language preference: '{language}'\n" +
|
||||||
|
"User is a {premium_status}\n" +
|
||||||
|
"It's currently {time_context} in your timezone"
|
||||||
|
|
||||||
|
// Set up test data
|
||||||
|
username := "testuser"
|
||||||
|
firstName := "Test"
|
||||||
|
lastName := "User"
|
||||||
|
isPremium := true
|
||||||
|
languageCode := "de"
|
||||||
|
|
||||||
|
// Create a timestamp for a specific hour (e.g., 14:00 = afternoon)
|
||||||
|
testTime := time.Date(2025, 5, 15, 14, 0, 0, 0, time.UTC)
|
||||||
|
messageTime := int(testTime.Unix())
|
||||||
|
|
||||||
|
// Handle username placeholder
|
||||||
|
usernameValue := username
|
||||||
|
if username == "" {
|
||||||
|
usernameValue = "unknown"
|
||||||
|
}
|
||||||
|
systemMessage = strings.ReplaceAll(systemMessage, "{username}", usernameValue)
|
||||||
|
|
||||||
|
// Handle firstname placeholder
|
||||||
|
firstnameValue := firstName
|
||||||
|
if firstName == "" {
|
||||||
|
firstnameValue = "unknown"
|
||||||
|
}
|
||||||
|
systemMessage = strings.ReplaceAll(systemMessage, "{firstname}", firstnameValue)
|
||||||
|
|
||||||
|
// Handle lastname placeholder
|
||||||
|
lastnameValue := lastName
|
||||||
|
if lastName == "" {
|
||||||
|
lastnameValue = ""
|
||||||
|
}
|
||||||
|
systemMessage = strings.ReplaceAll(systemMessage, "{lastname}", lastnameValue)
|
||||||
|
|
||||||
|
// Handle language code placeholder
|
||||||
|
langValue := languageCode
|
||||||
|
if languageCode == "" {
|
||||||
|
langValue = "en"
|
||||||
|
}
|
||||||
|
systemMessage = strings.ReplaceAll(systemMessage, "{language}", langValue)
|
||||||
|
|
||||||
|
// Handle premium status
|
||||||
|
premiumStatus := "regular user"
|
||||||
|
if isPremium {
|
||||||
|
premiumStatus = "premium user"
|
||||||
|
}
|
||||||
|
systemMessage = strings.ReplaceAll(systemMessage, "{premium_status}", premiumStatus)
|
||||||
|
|
||||||
|
// Handle time awareness
|
||||||
|
timeObj := time.Unix(int64(messageTime), 0)
|
||||||
|
hour := timeObj.Hour()
|
||||||
|
var timeContext string
|
||||||
|
if hour >= 5 && hour < 12 {
|
||||||
|
timeContext = "morning"
|
||||||
|
} else if hour >= 12 && hour < 18 {
|
||||||
|
timeContext = "afternoon"
|
||||||
|
} else if hour >= 18 && hour < 22 {
|
||||||
|
timeContext = "evening"
|
||||||
|
} else {
|
||||||
|
timeContext = "night"
|
||||||
|
}
|
||||||
|
systemMessage = strings.ReplaceAll(systemMessage, "{time_context}", timeContext)
|
||||||
|
|
||||||
|
// Check that all placeholders were replaced correctly
|
||||||
|
if !strings.Contains(systemMessage, "username 'testuser'") {
|
||||||
|
t.Errorf("Username not replaced correctly, got: %s", systemMessage)
|
||||||
|
}
|
||||||
|
if !strings.Contains(systemMessage, "display name 'Test User'") {
|
||||||
|
t.Errorf("Display name not replaced correctly, got: %s", systemMessage)
|
||||||
|
}
|
||||||
|
if !strings.Contains(systemMessage, "language preference: 'de'") {
|
||||||
|
t.Errorf("Language preference not replaced correctly, got: %s", systemMessage)
|
||||||
|
}
|
||||||
|
if !strings.Contains(systemMessage, "User is a premium user") {
|
||||||
|
t.Errorf("Premium status not replaced correctly, got: %s", systemMessage)
|
||||||
|
}
|
||||||
|
if !strings.Contains(systemMessage, "It's currently afternoon in your timezone") {
|
||||||
|
t.Errorf("Time context not replaced correctly, got: %s", systemMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -28,6 +28,14 @@ type Bot struct {
|
|||||||
botID uint // Reference to BotModel.ID
|
botID uint // Reference to BotModel.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper function to determine message type
|
||||||
|
func messageType(msg *models.Message) string {
|
||||||
|
if msg.Sticker != nil {
|
||||||
|
return "sticker"
|
||||||
|
}
|
||||||
|
return "text"
|
||||||
|
}
|
||||||
|
|
||||||
// NewBot initializes and returns a new Bot instance.
|
// NewBot initializes and returns a new Bot instance.
|
||||||
func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient) (*Bot, error) {
|
func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient) (*Bot, error) {
|
||||||
// Retrieve or create Bot entry in the database
|
// Retrieve or create Bot entry in the database
|
||||||
@@ -87,6 +95,15 @@ func NewBot(db *gorm.DB, config BotConfig, clock Clock, tgClient TelegramClient)
|
|||||||
tgBot: tgClient,
|
tgBot: tgClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if tgClient == nil {
|
||||||
|
var err error
|
||||||
|
tgClient, err = initTelegramBot(config.TelegramToken, b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to initialize Telegram bot: %w", err)
|
||||||
|
}
|
||||||
|
b.tgBot = tgClient
|
||||||
|
}
|
||||||
|
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -178,9 +195,10 @@ func (b *Bot) createMessage(chatID, userID int64, username, userRole, text strin
|
|||||||
return message
|
return message
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bot) storeMessage(message Message) error {
|
// storeMessage stores a message in the database and updates its ID
|
||||||
message.BotID = b.botID // Associate the message with the correct bot
|
func (b *Bot) storeMessage(message *Message) error {
|
||||||
return b.db.Create(&message).Error
|
message.BotID = b.botID // Associate the message with the correct bot
|
||||||
|
return b.db.Create(message).Error // This will update the message with its new ID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bot) getOrCreateChatMemory(chatID int64) *ChatMemory {
|
func (b *Bot) getOrCreateChatMemory(chatID int64) *ChatMemory {
|
||||||
@@ -190,14 +208,35 @@ func (b *Bot) getOrCreateChatMemory(chatID int64) *ChatMemory {
|
|||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
b.chatMemoriesMu.Lock()
|
b.chatMemoriesMu.Lock()
|
||||||
// Double-check to prevent race condition
|
defer b.chatMemoriesMu.Unlock()
|
||||||
|
|
||||||
chatMemory, exists = b.chatMemories[chatID]
|
chatMemory, exists = b.chatMemories[chatID]
|
||||||
if !exists {
|
if !exists {
|
||||||
|
// Check if this is a new chat by querying the database
|
||||||
|
var count int64
|
||||||
|
b.db.Model(&Message{}).Where("chat_id = ? AND bot_id = ?", chatID, b.botID).Count(&count)
|
||||||
|
isNewChat := count == 0 // Truly new chat if no messages exist
|
||||||
|
|
||||||
var messages []Message
|
var messages []Message
|
||||||
b.db.Where("chat_id = ? AND bot_id = ?", chatID, b.botID).
|
if !isNewChat {
|
||||||
Order("timestamp asc").
|
// Fetch existing messages only if it's not a new chat
|
||||||
Limit(b.memorySize * 2).
|
err := b.db.Where("chat_id = ? AND bot_id = ?", chatID, b.botID).
|
||||||
Find(&messages)
|
Order("timestamp desc").
|
||||||
|
Limit(b.memorySize * 2).
|
||||||
|
Find(&messages).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
ErrorLogger.Printf("Error fetching messages from database: %v", err)
|
||||||
|
messages = []Message{} // Initialize an empty slice on error
|
||||||
|
} else {
|
||||||
|
// Reverse from newest-first to chronological order for conversation context.
|
||||||
|
for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
|
||||||
|
messages[i], messages[j] = messages[j], messages[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
messages = []Message{} // Ensure messages is initialized for new chats
|
||||||
|
}
|
||||||
|
|
||||||
chatMemory = &ChatMemory{
|
chatMemory = &ChatMemory{
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
@@ -206,19 +245,22 @@ func (b *Bot) getOrCreateChatMemory(chatID int64) *ChatMemory {
|
|||||||
|
|
||||||
b.chatMemories[chatID] = chatMemory
|
b.chatMemories[chatID] = chatMemory
|
||||||
}
|
}
|
||||||
b.chatMemoriesMu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return chatMemory
|
return chatMemory
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addMessageToChatMemory adds a new message to the chat memory, ensuring the memory size is maintained.
|
||||||
func (b *Bot) addMessageToChatMemory(chatMemory *ChatMemory, message Message) {
|
func (b *Bot) addMessageToChatMemory(chatMemory *ChatMemory, message Message) {
|
||||||
b.chatMemoriesMu.Lock()
|
b.chatMemoriesMu.Lock()
|
||||||
defer b.chatMemoriesMu.Unlock()
|
defer b.chatMemoriesMu.Unlock()
|
||||||
|
|
||||||
|
// Add the new message
|
||||||
chatMemory.Messages = append(chatMemory.Messages, message)
|
chatMemory.Messages = append(chatMemory.Messages, message)
|
||||||
|
|
||||||
|
// Maintain the memory size
|
||||||
if len(chatMemory.Messages) > chatMemory.Size {
|
if len(chatMemory.Messages) > chatMemory.Size {
|
||||||
chatMemory.Messages = chatMemory.Messages[2:]
|
chatMemory.Messages = chatMemory.Messages[len(chatMemory.Messages)-chatMemory.Size:]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -226,6 +268,17 @@ func (b *Bot) prepareContextMessages(chatMemory *ChatMemory) []anthropic.Message
|
|||||||
b.chatMemoriesMu.RLock()
|
b.chatMemoriesMu.RLock()
|
||||||
defer b.chatMemoriesMu.RUnlock()
|
defer b.chatMemoriesMu.RUnlock()
|
||||||
|
|
||||||
|
// Debug logging
|
||||||
|
InfoLogger.Printf("Chat memory contains %d messages", len(chatMemory.Messages))
|
||||||
|
for i, msg := range chatMemory.Messages {
|
||||||
|
InfoLogger.Printf("Message %d: IsUser=%v, Text=%q", i, msg.IsUser, msg.Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: consecutive messages with the same role are permitted.
|
||||||
|
// The Anthropic API automatically merges them into a single turn rather than
|
||||||
|
// returning an error. This can happen after a /clear (which only deletes user
|
||||||
|
// messages, leaving assistant messages in the DB) followed by a restart.
|
||||||
|
// See: https://platform.claude.com/docs/en/api/messages
|
||||||
var contextMessages []anthropic.Message
|
var contextMessages []anthropic.Message
|
||||||
for _, msg := range chatMemory.Messages {
|
for _, msg := range chatMemory.Messages {
|
||||||
role := anthropic.RoleUser
|
role := anthropic.RoleUser
|
||||||
@@ -252,21 +305,90 @@ func (b *Bot) prepareContextMessages(chatMemory *ChatMemory) []anthropic.Message
|
|||||||
func (b *Bot) isNewChat(chatID int64) bool {
|
func (b *Bot) isNewChat(chatID int64) bool {
|
||||||
var count int64
|
var count int64
|
||||||
b.db.Model(&Message{}).Where("chat_id = ? AND bot_id = ?", chatID, b.botID).Count(&count)
|
b.db.Model(&Message{}).Where("chat_id = ? AND bot_id = ?", chatID, b.botID).Count(&count)
|
||||||
return count == 1
|
return count == 0 // Only consider a chat new if it has 0 messages
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bot) isAdminOrOwner(userID int64) bool {
|
// roleHasScope reports whether role (with pre-loaded Scopes) contains the given scope name.
|
||||||
|
func roleHasScope(role Role, scope string) bool {
|
||||||
|
for _, s := range role.Scopes {
|
||||||
|
if s.Name == scope {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasScope reports whether the user identified by userID holds the given scope for this bot.
|
||||||
|
// Owners implicitly hold all scopes regardless of their assigned role.
|
||||||
|
func (b *Bot) hasScope(userID int64, scope string) bool {
|
||||||
var user User
|
var user User
|
||||||
err := b.db.Preload("Role").Where("telegram_id = ?", userID).First(&user).Error
|
if err := b.db.Preload("Role.Scopes").
|
||||||
if err != nil {
|
Where("telegram_id = ? AND bot_id = ?", userID, b.botID).
|
||||||
|
First(&user).Error; err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return user.Role.Name == "admin" || user.Role.Name == "owner"
|
if user.IsOwner {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return roleHasScope(user.Role, scope)
|
||||||
}
|
}
|
||||||
|
|
||||||
func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot *bot.Bot, update *models.Update)) (TelegramClient, error) {
|
// publicBotCommands are shown to every user in the Telegram command palette.
|
||||||
|
var publicBotCommands = []models.BotCommand{
|
||||||
|
{Command: "stats", Description: "Get bot statistics. Usage: /stats or /stats user [user_id]"},
|
||||||
|
{Command: "whoami", Description: "Get your user information"},
|
||||||
|
{Command: "clear", Description: "Clear chat history (soft delete). Admins: /clear [user_id]"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// adminBotCommands are shown only in admin/owner chats via BotCommandScopeChatMember.
|
||||||
|
var adminBotCommands = []models.BotCommand{
|
||||||
|
{Command: "clear_hard", Description: "Clear chat history (permanently delete). Admins: /clear_hard [user_id]"},
|
||||||
|
{Command: "set_model", Description: "Switch the AI model (admin/owner only). Usage: /set_model <model-id>"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerAdminCommandsForUser scopes the full command palette to a specific user's private chat.
|
||||||
|
// In Telegram private chats, chat_id == user_id, so both fields carry the same value.
|
||||||
|
// Errors are logged but treated as non-fatal: the user retains access via permission checks.
|
||||||
|
func (b *Bot) registerAdminCommandsForUser(ctx context.Context, telegramID int64) {
|
||||||
|
allCommands := make([]models.BotCommand, 0, len(publicBotCommands)+len(adminBotCommands))
|
||||||
|
allCommands = append(allCommands, publicBotCommands...)
|
||||||
|
allCommands = append(allCommands, adminBotCommands...)
|
||||||
|
_, err := b.tgBot.SetMyCommands(ctx, &bot.SetMyCommandsParams{
|
||||||
|
Commands: allCommands,
|
||||||
|
Scope: &models.BotCommandScopeChatMember{ChatID: telegramID, UserID: telegramID},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
ErrorLogger.Printf("Failed to register admin commands for user %d: %v", telegramID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setElevatedCommands registers the full command palette (public + admin) for every user
|
||||||
|
// whose role carries the model:set scope, or who is the bot owner. Called once at startup
|
||||||
|
// and uses the freshly created tgBot directly (b.tgBot is not yet assigned at that point).
|
||||||
|
func setElevatedCommands(tgBot TelegramClient, users []User) {
|
||||||
|
allCommands := make([]models.BotCommand, 0, len(publicBotCommands)+len(adminBotCommands))
|
||||||
|
allCommands = append(allCommands, publicBotCommands...)
|
||||||
|
allCommands = append(allCommands, adminBotCommands...)
|
||||||
|
for _, u := range users {
|
||||||
|
if u.TelegramID == 0 {
|
||||||
|
continue // skip placeholder users not yet seen in a chat
|
||||||
|
}
|
||||||
|
if !u.IsOwner && !roleHasScope(u.Role, ScopeModelSet) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, err := tgBot.SetMyCommands(context.Background(), &bot.SetMyCommandsParams{
|
||||||
|
Commands: allCommands,
|
||||||
|
Scope: &models.BotCommandScopeChatMember{ChatID: u.TelegramID, UserID: u.TelegramID},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
ErrorLogger.Printf("Warning: could not set admin commands for user %d: %v", u.TelegramID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func initTelegramBot(token string, b *Bot) (TelegramClient, error) {
|
||||||
opts := []bot.Option{
|
opts := []bot.Option{
|
||||||
bot.WithDefaultHandler(handleUpdate),
|
bot.WithDefaultHandler(b.handleUpdate),
|
||||||
}
|
}
|
||||||
|
|
||||||
tgBot, err := bot.New(token, opts...)
|
tgBot, err := bot.New(token, opts...)
|
||||||
@@ -274,12 +396,33 @@ func initTelegramBot(token string, handleUpdate func(ctx context.Context, tgBot
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Register public commands for all users.
|
||||||
|
_, err = tgBot.SetMyCommands(context.Background(), &bot.SetMyCommandsParams{
|
||||||
|
Commands: publicBotCommands,
|
||||||
|
Scope: &models.BotCommandScopeDefault{},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
ErrorLogger.Printf("Error setting default bot commands: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register full command palette (public + admin) scoped to each known elevated user.
|
||||||
|
// BotCommandScopeChatMember targets the user's private DM with the bot (chat_id == user_id).
|
||||||
|
// Elevation is determined by scope rather than role name, so renaming roles requires no code change.
|
||||||
|
// This is best-effort: failures are logged but do not prevent the bot from starting.
|
||||||
|
var allUsers []User
|
||||||
|
if err := b.db.Preload("Role.Scopes").Where("bot_id = ?", b.botID).Find(&allUsers).Error; err != nil {
|
||||||
|
ErrorLogger.Printf("Warning: could not query users for command scoping: %v", err)
|
||||||
|
} else {
|
||||||
|
setElevatedCommands(tgBot, allUsers)
|
||||||
|
}
|
||||||
|
|
||||||
return tgBot, nil
|
return tgBot, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, businessConnectionID string) error {
|
func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, businessConnectionID string) error {
|
||||||
// Pass the outgoing message through the centralized screen for storage
|
// Pass the outgoing message through the centralized screen for storage and chat memory update
|
||||||
_, err := b.screenOutgoingMessage(chatID, text, businessConnectionID)
|
_, err := b.screenOutgoingMessage(chatID, text)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorLogger.Printf("Error storing assistant message: %v", err)
|
ErrorLogger.Printf("Error storing assistant message: %v", err)
|
||||||
return err
|
return err
|
||||||
@@ -306,46 +449,122 @@ func (b *Bot) sendResponse(ctx context.Context, chatID int64, text string, busin
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sendStats sends the bot statistics to the specified chat.
|
// sendStats sends the bot statistics to the specified chat.
|
||||||
func (b *Bot) sendStats(ctx context.Context, chatID int64, userID int64, username string, businessConnectionID string) {
|
func (b *Bot) sendStats(ctx context.Context, chatID int64, userID int64, targetUserID int64, businessConnectionID string) {
|
||||||
totalUsers, totalMessages, err := b.getStats()
|
// If targetUserID is 0, show global stats
|
||||||
|
if targetUserID == 0 {
|
||||||
|
totalUsers, totalMessages, err := b.getStats()
|
||||||
|
if err != nil {
|
||||||
|
ErrorLogger.Printf("Error fetching stats: %v\n", err)
|
||||||
|
if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve the stats at this time.", businessConnectionID); err != nil {
|
||||||
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do NOT manually escape hyphens here
|
||||||
|
statsMessage := fmt.Sprintf(
|
||||||
|
"📊 Bot Statistics:\n\n"+
|
||||||
|
"- Total Users: %d\n"+
|
||||||
|
"- Total Messages: %d",
|
||||||
|
totalUsers,
|
||||||
|
totalMessages,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Send the response through the centralized screen
|
||||||
|
if err := b.sendResponse(ctx, chatID, statsMessage, businessConnectionID); err != nil {
|
||||||
|
ErrorLogger.Printf("Error sending stats message: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If targetUserID is not 0, show user-specific stats
|
||||||
|
// Check permissions if the user is trying to view someone else's stats
|
||||||
|
if targetUserID != userID {
|
||||||
|
if !b.hasScope(userID, ScopeStatsViewAny) {
|
||||||
|
InfoLogger.Printf("User %d attempted to view stats for user %d without permission", userID, targetUserID)
|
||||||
|
if err := b.sendResponse(ctx, chatID, "Permission denied. Only admins and owners can view other users' statistics.", businessConnectionID); err != nil {
|
||||||
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get user stats
|
||||||
|
username, messagesIn, messagesOut, totalMessages, err := b.getUserStats(targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorLogger.Printf("Error fetching stats: %v\n", err)
|
ErrorLogger.Printf("Error fetching user stats: %v\n", err)
|
||||||
if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't retrieve the stats at this time.", businessConnectionID); err != nil {
|
if err := b.sendResponse(ctx, chatID, fmt.Sprintf("Sorry, I couldn't retrieve statistics for user ID %d.", targetUserID), businessConnectionID); err != nil {
|
||||||
ErrorLogger.Printf("Error sending response: %v", err)
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do NOT manually escape hyphens here
|
// Build the user stats message
|
||||||
|
userInfo := fmt.Sprintf("@%s (ID: %d)", username, targetUserID)
|
||||||
|
if username == "" {
|
||||||
|
userInfo = fmt.Sprintf("User ID: %d", targetUserID)
|
||||||
|
}
|
||||||
|
|
||||||
statsMessage := fmt.Sprintf(
|
statsMessage := fmt.Sprintf(
|
||||||
"📊 Bot Statistics:\n\n"+
|
"👤 User Statistics for %s:\n\n"+
|
||||||
"- Total Users: %d\n"+
|
"- Messages Sent: %d\n"+
|
||||||
|
"- Messages Received: %d\n"+
|
||||||
"- Total Messages: %d",
|
"- Total Messages: %d",
|
||||||
totalUsers,
|
userInfo,
|
||||||
|
messagesIn,
|
||||||
|
messagesOut,
|
||||||
totalMessages,
|
totalMessages,
|
||||||
)
|
)
|
||||||
|
|
||||||
// Send the response through the centralized screen
|
|
||||||
if err := b.sendResponse(ctx, chatID, statsMessage, businessConnectionID); err != nil {
|
if err := b.sendResponse(ctx, chatID, statsMessage, businessConnectionID); err != nil {
|
||||||
ErrorLogger.Printf("Error sending stats message: %v", err)
|
ErrorLogger.Printf("Error sending user stats message: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getStats retrieves the total number of users and messages from the database.
|
// getStats retrieves the total number of users and messages from the database.
|
||||||
func (b *Bot) getStats() (int64, int64, error) {
|
func (b *Bot) getStats() (int64, int64, error) {
|
||||||
var totalUsers int64
|
var totalUsers int64
|
||||||
if err := b.db.Model(&User{}).Count(&totalUsers).Error; err != nil {
|
if err := b.db.Model(&User{}).Where("bot_id = ?", b.botID).Count(&totalUsers).Error; err != nil {
|
||||||
return 0, 0, err
|
return 0, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var totalMessages int64
|
var totalMessages int64
|
||||||
if err := b.db.Model(&Message{}).Count(&totalMessages).Error; err != nil {
|
if err := b.db.Model(&Message{}).Where("bot_id = ?", b.botID).Count(&totalMessages).Error; err != nil {
|
||||||
return 0, 0, err
|
return 0, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return totalUsers, totalMessages, nil
|
return totalUsers, totalMessages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getUserStats retrieves statistics for a specific user
|
||||||
|
func (b *Bot) getUserStats(userID int64) (string, int64, int64, int64, error) {
|
||||||
|
// Get user information from database
|
||||||
|
var user User
|
||||||
|
err := b.db.Where("telegram_id = ? AND bot_id = ?", userID, b.botID).First(&user).Error
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, 0, 0, fmt.Errorf("user not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count messages sent by the user (IN)
|
||||||
|
var messagesIn int64
|
||||||
|
if err := b.db.Model(&Message{}).Where("user_id = ? AND bot_id = ? AND is_user = ?",
|
||||||
|
userID, b.botID, true).Count(&messagesIn).Error; err != nil {
|
||||||
|
return "", 0, 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count responses to the user (OUT)
|
||||||
|
var messagesOut int64
|
||||||
|
if err := b.db.Model(&Message{}).Where("chat_id IN (SELECT DISTINCT chat_id FROM messages WHERE user_id = ? AND bot_id = ? AND deleted_at IS NULL) AND bot_id = ? AND is_user = ?",
|
||||||
|
userID, b.botID, b.botID, false).Count(&messagesOut).Error; err != nil {
|
||||||
|
return "", 0, 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Total messages is the sum
|
||||||
|
totalMessages := messagesIn + messagesOut
|
||||||
|
|
||||||
|
return user.Username, messagesIn, messagesOut, totalMessages, nil
|
||||||
|
}
|
||||||
|
|
||||||
// isOnlyEmojis checks if the string consists solely of emojis.
|
// isOnlyEmojis checks if the string consists solely of emojis.
|
||||||
func isOnlyEmojis(s string) bool {
|
func isOnlyEmojis(s string) bool {
|
||||||
for _, r := range s {
|
for _, r := range s {
|
||||||
@@ -399,41 +618,96 @@ func (b *Bot) sendWhoAmI(ctx context.Context, chatID int64, userID int64, userna
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// screenIncomingMessage handles storing of incoming messages.
|
// screenIncomingMessage centralizes all incoming message processing: storing messages and updating chat memory.
|
||||||
func (b *Bot) screenIncomingMessage(message *models.Message) (Message, error) {
|
func (b *Bot) screenIncomingMessage(message *models.Message) (Message, error) {
|
||||||
userRole := string(anthropic.RoleUser) // Convert RoleUser to string
|
if b.config.DebugScreening {
|
||||||
userMessage := b.createMessage(message.Chat.ID, message.From.ID, message.From.Username, userRole, message.Text, true)
|
start := time.Now()
|
||||||
|
defer func() {
|
||||||
|
InfoLogger.Printf(
|
||||||
|
"[Screen] Incoming: chat=%d user=%d type=%s memory_size=%d duration=%v",
|
||||||
|
message.Chat.ID,
|
||||||
|
message.From.ID,
|
||||||
|
messageType(message),
|
||||||
|
len(b.getOrCreateChatMemory(message.Chat.ID).Messages),
|
||||||
|
time.Since(start),
|
||||||
|
)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
// If the message contains a sticker, include its details.
|
userRole := string(anthropic.RoleUser)
|
||||||
|
|
||||||
|
// Determine message text based on message type
|
||||||
|
messageText := message.Text
|
||||||
|
if message.Sticker != nil {
|
||||||
|
if message.Sticker.Emoji != "" {
|
||||||
|
messageText = fmt.Sprintf("Sent a sticker: %s", message.Sticker.Emoji)
|
||||||
|
} else {
|
||||||
|
messageText = "Sent a sticker."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
userMessage := b.createMessage(message.Chat.ID, message.From.ID, message.From.Username, userRole, messageText, true)
|
||||||
|
|
||||||
|
// Handle sticker-specific details if present
|
||||||
if message.Sticker != nil {
|
if message.Sticker != nil {
|
||||||
userMessage.StickerFileID = message.Sticker.FileID
|
userMessage.StickerFileID = message.Sticker.FileID
|
||||||
|
userMessage.StickerEmoji = message.Sticker.Emoji // Store the sticker emoji
|
||||||
if message.Sticker.Thumbnail != nil {
|
if message.Sticker.Thumbnail != nil {
|
||||||
userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID
|
userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store the message.
|
// Get the chat memory before storing the message
|
||||||
if err := b.storeMessage(userMessage); err != nil {
|
chatMemory := b.getOrCreateChatMemory(message.Chat.ID)
|
||||||
|
|
||||||
|
// Store the message and get its ID
|
||||||
|
if err := b.storeMessage(&userMessage); err != nil {
|
||||||
return Message{}, err
|
return Message{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update chat memory.
|
// Add the message to the chat memory
|
||||||
chatMemory := b.getOrCreateChatMemory(message.Chat.ID)
|
|
||||||
b.addMessageToChatMemory(chatMemory, userMessage)
|
b.addMessageToChatMemory(chatMemory, userMessage)
|
||||||
|
|
||||||
return userMessage, nil
|
return userMessage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// screenOutgoingMessage handles storing of outgoing messages.
|
// screenOutgoingMessage handles storing of outgoing messages and updating chat memory.
|
||||||
func (b *Bot) screenOutgoingMessage(chatID int64, response string, businessConnectionID string) (Message, error) {
|
// It also marks the most recent unanswered user message as answered.
|
||||||
assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false)
|
func (b *Bot) screenOutgoingMessage(chatID int64, response string) (Message, error) {
|
||||||
|
if b.config.DebugScreening {
|
||||||
|
start := time.Now()
|
||||||
|
defer func() {
|
||||||
|
InfoLogger.Printf(
|
||||||
|
"[Screen] Outgoing: chat=%d len=%d memory_size=%d duration=%v",
|
||||||
|
chatID,
|
||||||
|
len(response),
|
||||||
|
len(b.getOrCreateChatMemory(chatID).Messages),
|
||||||
|
time.Since(start),
|
||||||
|
)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
// Store the message.
|
// Create and store the assistant message
|
||||||
if err := b.storeMessage(assistantMessage); err != nil {
|
assistantMessage := b.createMessage(chatID, 0, "", string(anthropic.RoleAssistant), response, false)
|
||||||
|
if err := b.storeMessage(&assistantMessage); err != nil {
|
||||||
return Message{}, err
|
return Message{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update chat memory.
|
// Find and mark the most recent unanswered user message as answered
|
||||||
|
now := time.Now()
|
||||||
|
err := b.db.Model(&Message{}).
|
||||||
|
Where("chat_id = ? AND bot_id = ? AND is_user = ? AND answered_on IS NULL",
|
||||||
|
chatID, b.botID, true).
|
||||||
|
Order("timestamp DESC").
|
||||||
|
Limit(1).
|
||||||
|
Update("answered_on", now).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
ErrorLogger.Printf("Error marking user message as answered: %v", err)
|
||||||
|
// Continue even if there's an error updating the user message
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update chat memory with the message that now has an ID
|
||||||
chatMemory := b.getOrCreateChatMemory(chatID)
|
chatMemory := b.getOrCreateChatMemory(chatID)
|
||||||
b.addMessageToChatMemory(chatMemory, assistantMessage)
|
b.addMessageToChatMemory(chatMemory, assistantMessage)
|
||||||
|
|
||||||
@@ -441,8 +715,8 @@ func (b *Bot) screenOutgoingMessage(chatID int64, response string, businessConne
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bot) promoteUserToAdmin(promoterID, userToPromoteID int64) error {
|
func (b *Bot) promoteUserToAdmin(promoterID, userToPromoteID int64) error {
|
||||||
// Check if the promoter is an owner or admin
|
// Check if the promoter has the user:promote scope
|
||||||
if !b.isAdminOrOwner(promoterID) {
|
if !b.hasScope(promoterID, ScopeUserPromote) {
|
||||||
return errors.New("only admins or owners can promote users to admin")
|
return errors.New("only admins or owners can promote users to admin")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -461,5 +735,11 @@ func (b *Bot) promoteUserToAdmin(promoterID, userToPromoteID int64) error {
|
|||||||
// Update the user's role
|
// Update the user's role
|
||||||
userToPromote.RoleID = adminRole.ID
|
userToPromote.RoleID = adminRole.ID
|
||||||
userToPromote.Role = adminRole
|
userToPromote.Role = adminRole
|
||||||
return b.db.Save(&userToPromote).Error
|
if err := b.db.Save(&userToPromote).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Surface admin commands in the newly promoted user's private chat.
|
||||||
|
b.registerAdminCommandsForUser(context.Background(), userToPromoteID)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,22 +5,26 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/liushuangls/go-anthropic/v2"
|
"github.com/liushuangls/go-anthropic/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type BotConfig struct {
|
type BotConfig struct {
|
||||||
ID string `json:"id"` // Unique identifier for the bot
|
ID string `json:"id"`
|
||||||
TelegramToken string `json:"telegram_token"` // Telegram Bot Token
|
TelegramToken string `json:"telegram_token"`
|
||||||
MemorySize int `json:"memory_size"`
|
MemorySize int `json:"memory_size"`
|
||||||
MessagePerHour int `json:"messages_per_hour"`
|
MessagePerHour int `json:"messages_per_hour"`
|
||||||
MessagePerDay int `json:"messages_per_day"`
|
MessagePerDay int `json:"messages_per_day"`
|
||||||
TempBanDuration string `json:"temp_ban_duration"`
|
TempBanDuration string `json:"temp_ban_duration"`
|
||||||
Model anthropic.Model `json:"model"` // Changed from string to anthropic.Model
|
Model anthropic.Model `json:"model"`
|
||||||
|
Temperature *float32 `json:"temperature,omitempty"` // Controls creativity vs determinism (0.0-1.0)
|
||||||
SystemPrompts map[string]string `json:"system_prompts"`
|
SystemPrompts map[string]string `json:"system_prompts"`
|
||||||
Active bool `json:"active"` // New field to control bot activity
|
Active bool `json:"active"`
|
||||||
OwnerTelegramID int64 `json:"owner_telegram_id"`
|
OwnerTelegramID int64 `json:"owner_telegram_id"`
|
||||||
AnthropicAPIKey string `json:"anthropic_api_key"` // Add this line
|
AnthropicAPIKey string `json:"anthropic_api_key"`
|
||||||
|
DebugScreening bool `json:"debug_screening"` // Enable detailed screening logs
|
||||||
|
ConfigFilePath string `json:"-"` // Set at load time; not serialized
|
||||||
}
|
}
|
||||||
|
|
||||||
// Custom unmarshalling to handle anthropic.Model
|
// Custom unmarshalling to handle anthropic.Model
|
||||||
@@ -32,15 +36,45 @@ func (c *BotConfig) UnmarshalJSON(data []byte) error {
|
|||||||
}{
|
}{
|
||||||
Alias: (*Alias)(c),
|
Alias: (*Alias)(c),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(data, &aux); err != nil {
|
if err := json.Unmarshal(data, &aux); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Model = anthropic.Model(aux.Model)
|
c.Model = anthropic.Model(aux.Model)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateConfigPath ensures the file path is within the allowed directory
|
||||||
|
func validateConfigPath(configDir, filename string) (string, error) {
|
||||||
|
// Clean the paths to remove any . or .. components
|
||||||
|
configDir = filepath.Clean(configDir)
|
||||||
|
filename = filepath.Clean(filename)
|
||||||
|
|
||||||
|
// Get absolute paths
|
||||||
|
absConfigDir, err := filepath.Abs(configDir)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get absolute path for config directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fullPath := filepath.Join(absConfigDir, filename)
|
||||||
|
absPath, err := filepath.Abs(fullPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get absolute path for config file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use filepath.Rel to check if the path is within the config directory
|
||||||
|
rel, err := filepath.Rel(absConfigDir, absPath)
|
||||||
|
if err != nil || strings.HasPrefix(rel, "..") || strings.Contains(rel, "..") {
|
||||||
|
return "", fmt.Errorf("invalid config path: file must be within the config directory")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify file extension
|
||||||
|
if filepath.Ext(absPath) != ".json" {
|
||||||
|
return "", fmt.Errorf("invalid file extension: must be .json")
|
||||||
|
}
|
||||||
|
|
||||||
|
return absPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
func loadAllConfigs(dir string) ([]BotConfig, error) {
|
func loadAllConfigs(dir string) ([]BotConfig, error) {
|
||||||
var configs []BotConfig
|
var configs []BotConfig
|
||||||
ids := make(map[string]bool)
|
ids := make(map[string]bool)
|
||||||
@@ -53,59 +87,84 @@ func loadAllConfigs(dir string) ([]BotConfig, error) {
|
|||||||
|
|
||||||
for _, file := range files {
|
for _, file := range files {
|
||||||
if filepath.Ext(file.Name()) == ".json" {
|
if filepath.Ext(file.Name()) == ".json" {
|
||||||
configPath := filepath.Join(dir, file.Name())
|
validPath, err := validateConfigPath(dir, file.Name())
|
||||||
config, err := loadConfig(configPath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to load config %s: %w", configPath, err)
|
InfoLogger.Printf("Invalid config path for %s: %v", file.Name(), err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := loadConfig(validPath)
|
||||||
|
if err != nil {
|
||||||
|
InfoLogger.Printf("Failed to load config %s: %v", validPath, err)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip inactive bots
|
|
||||||
if !config.Active {
|
if !config.Active {
|
||||||
InfoLogger.Printf("Skipping inactive bot: %s", config.ID)
|
InfoLogger.Printf("Skipping inactive bot: %s", config.ID)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate that ID is present
|
if err := validateConfig(&config, ids, tokens); err != nil {
|
||||||
if config.ID == "" {
|
InfoLogger.Printf("Config validation failed for %s: %v", validPath, err)
|
||||||
return nil, fmt.Errorf("config %s is missing 'id' field", configPath)
|
continue
|
||||||
}
|
|
||||||
|
|
||||||
// Check for unique ID
|
|
||||||
if _, exists := ids[config.ID]; exists {
|
|
||||||
return nil, fmt.Errorf("duplicate bot id '%s' found in %s", config.ID, configPath)
|
|
||||||
}
|
|
||||||
ids[config.ID] = true
|
|
||||||
|
|
||||||
// Validate Telegram Token
|
|
||||||
if config.TelegramToken == "" {
|
|
||||||
return nil, fmt.Errorf("config %s is missing 'telegram_token' field", configPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for unique Telegram Token
|
|
||||||
if _, exists := tokens[config.TelegramToken]; exists {
|
|
||||||
return nil, fmt.Errorf("duplicate telegram_token '%s' found in %s", config.TelegramToken, configPath)
|
|
||||||
}
|
|
||||||
tokens[config.TelegramToken] = true
|
|
||||||
|
|
||||||
// Validate Model
|
|
||||||
if config.Model == "" {
|
|
||||||
return nil, fmt.Errorf("config %s is missing 'model' field", configPath)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
config.ConfigFilePath = validPath
|
||||||
configs = append(configs, config)
|
configs = append(configs, config)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(configs) == 0 {
|
||||||
|
return nil, fmt.Errorf("no valid configs found")
|
||||||
|
}
|
||||||
|
|
||||||
return configs, nil
|
return configs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateConfig(config *BotConfig, ids, tokens map[string]bool) error {
|
||||||
|
if config.ID == "" {
|
||||||
|
return fmt.Errorf("missing 'id' field")
|
||||||
|
}
|
||||||
|
if _, exists := ids[config.ID]; exists {
|
||||||
|
return fmt.Errorf("duplicate bot id '%s'", config.ID)
|
||||||
|
}
|
||||||
|
ids[config.ID] = true
|
||||||
|
|
||||||
|
if config.TelegramToken == "" {
|
||||||
|
return fmt.Errorf("missing 'telegram_token' field")
|
||||||
|
}
|
||||||
|
if _, exists := tokens[config.TelegramToken]; exists {
|
||||||
|
return fmt.Errorf("duplicate telegram_token")
|
||||||
|
}
|
||||||
|
tokens[config.TelegramToken] = true
|
||||||
|
|
||||||
|
if config.Model == "" {
|
||||||
|
return fmt.Errorf("missing 'model' field")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.MessagePerHour <= 0 {
|
||||||
|
return fmt.Errorf("'messages_per_hour' must be greater than 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.MessagePerDay <= 0 {
|
||||||
|
return fmt.Errorf("'messages_per_day' must be greater than 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func loadConfig(filename string) (BotConfig, error) {
|
func loadConfig(filename string) (BotConfig, error) {
|
||||||
var config BotConfig
|
var config BotConfig
|
||||||
file, err := os.Open(filename)
|
// Use filepath.Clean before opening the file
|
||||||
|
file, err := os.OpenFile(filepath.Clean(filename), os.O_RDONLY, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config, fmt.Errorf("failed to open config file %s: %w", filename, err)
|
return config, fmt.Errorf("failed to open config file %s: %w", filename, err)
|
||||||
}
|
}
|
||||||
defer file.Close()
|
defer func() {
|
||||||
|
if err := file.Close(); err != nil {
|
||||||
|
InfoLogger.Printf("Failed to close config file: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
decoder := json.NewDecoder(file)
|
decoder := json.NewDecoder(file)
|
||||||
if err := decoder.Decode(&config); err != nil {
|
if err := decoder.Decode(&config); err != nil {
|
||||||
@@ -115,20 +174,63 @@ func loadConfig(filename string) (BotConfig, error) {
|
|||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BotConfig) Reload(filename string) error {
|
// Reload reloads the BotConfig from the specified filename within the given config directory
|
||||||
file, err := os.Open(filename)
|
func (c *BotConfig) Reload(configDir, filename string) error {
|
||||||
|
// Validate the config path
|
||||||
|
validPath, err := validateConfigPath(configDir, filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to open config file %s: %w", filename, err)
|
return fmt.Errorf("invalid config path: %w", err)
|
||||||
}
|
}
|
||||||
defer file.Close()
|
|
||||||
|
// Use filepath.Clean before opening the file
|
||||||
|
cleanPath := filepath.Clean(validPath)
|
||||||
|
file, err := os.OpenFile(cleanPath, os.O_RDONLY, 0)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open config file %s: %w", cleanPath, err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := file.Close(); err != nil {
|
||||||
|
InfoLogger.Printf("Failed to close config file: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
decoder := json.NewDecoder(file)
|
decoder := json.NewDecoder(file)
|
||||||
if err := decoder.Decode(c); err != nil {
|
if err := decoder.Decode(c); err != nil {
|
||||||
return fmt.Errorf("failed to decode JSON from %s: %w", filename, err)
|
return fmt.Errorf("failed to decode JSON from %s: %w", validPath, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the Model is correctly casted
|
|
||||||
c.Model = anthropic.Model(c.Model)
|
c.Model = anthropic.Model(c.Model)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PersistModel updates the model field in memory and writes it back to the config file on disk.
|
||||||
|
// Only the "model" key is changed; all other fields are preserved verbatim.
|
||||||
|
func (c *BotConfig) PersistModel(newModel string) error {
|
||||||
|
if c.ConfigFilePath == "" {
|
||||||
|
return fmt.Errorf("config file path not set; cannot persist model")
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(c.ConfigFilePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read config for update: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var raw map[string]any
|
||||||
|
if err := json.Unmarshal(data, &raw); err != nil {
|
||||||
|
return fmt.Errorf("failed to parse config for update: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
raw["model"] = newModel
|
||||||
|
|
||||||
|
updated, err := json.MarshalIndent(raw, "", "\t")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to re-encode config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(c.ConfigFilePath, updated, 0600); err != nil {
|
||||||
|
return fmt.Errorf("failed to write config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Model = anthropic.Model(newModel)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Executable → Regular
+5
-3
@@ -8,12 +8,14 @@
|
|||||||
"messages_per_hour": 20,
|
"messages_per_hour": 20,
|
||||||
"messages_per_day": 100,
|
"messages_per_day": 100,
|
||||||
"temp_ban_duration": "24h",
|
"temp_ban_duration": "24h",
|
||||||
"model": "claude-3-5-sonnet-20240620",
|
"model": "claude-3-5-haiku-latest",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"debug_screening": false,
|
||||||
"system_prompts": {
|
"system_prompts": {
|
||||||
"default": "You are a helpful assistant.",
|
"default": "You are a helpful assistant.",
|
||||||
"custom_instructions": "Please follow these guidelines:\n- Your name is Atom.\n- If a user asks about buying apples, inform them that we don't sell apples.\n- When asked for a joke, tell a clean, family-friendly joke about programming or technology.\n- If someone inquires about our services, explain that we offer AI-powered chatbot solutions.\n- For any questions about pricing, direct users to contact our sales team at sales@example.com.\n- If asked about your capabilities, be honest about what you can and cannot do.\nAlways maintain a friendly and professional tone.",
|
"custom_instructions": "You are texting through a limited Telegram interface with 15-word maximum. Write like texting a friend - use shorthand, skip grammar, use slang/abbreviations. System cuts off anything longer than 15 words.\n\n- Your name is Atom.\n- The user you're talking to has username '{username}' and display name '{firstname} {lastname}'.\n- User's language preference: '{language}'\n- User is a {premium_status}\n- It's currently {time_context} in your timezone. Use appropriate time-based greetings and address the user by name.\n- If a user asks about buying apples, inform them that we don't sell apples.\n- When asked for a joke, tell a clean, family-friendly joke about programming or technology.\n- If someone inquires about our services, explain that we offer AI-powered chatbot solutions.\n- For any questions about pricing, direct users to contact our sales team at sales@example.com.\n- If asked about your capabilities, be honest about what you can and cannot do.\nAlways maintain a friendly and professional tone.",
|
||||||
"continue_conversation": "Continuing our conversation. Remember previous context if relevant.",
|
"continue_conversation": "Continuing our conversation. Remember previous context if relevant.",
|
||||||
"avoid_sensitive": "Avoid discussing sensitive topics or providing harmful information.",
|
"avoid_sensitive": "Avoid discussing sensitive topics or providing harmful information.",
|
||||||
"respond_with_emojis": "Since the user sent only emojis, respond using emojis only."
|
"respond_with_emojis": "Since the user sent only emojis, respond using emojis only."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+818
@@ -0,0 +1,818 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/liushuangls/go-anthropic/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Set up loggers
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
initLoggers()
|
||||||
|
os.Exit(m.Run())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBotConfig_UnmarshalJSON tests the custom unmarshalling of BotConfig
|
||||||
|
func TestBotConfig_UnmarshalJSON(t *testing.T) { //NOSONAR go:S100 -- underscore separation is idiomatic in Go test names
|
||||||
|
jsonData := `{
|
||||||
|
"id": "bot123",
|
||||||
|
"telegram_token": "token123",
|
||||||
|
"memory_size": 1024,
|
||||||
|
"messages_per_hour": 10,
|
||||||
|
"messages_per_day": 100,
|
||||||
|
"temp_ban_duration": "1h",
|
||||||
|
"model": "claude-v1",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"system_prompts": {"welcome": "Hello!"},
|
||||||
|
"active": true,
|
||||||
|
"owner_telegram_id": 123456789,
|
||||||
|
"anthropic_api_key": "api_key_123"
|
||||||
|
}`
|
||||||
|
|
||||||
|
var config BotConfig
|
||||||
|
if err := json.Unmarshal([]byte(jsonData), &config); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedModel := anthropic.Model("claude-v1")
|
||||||
|
if config.Model != expectedModel {
|
||||||
|
t.Errorf("Expected model %s, got %s", expectedModel, config.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedID := "bot123"
|
||||||
|
if config.ID != expectedID {
|
||||||
|
t.Errorf("Expected ID %s, got %s", expectedID, config.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add more field checks as necessary
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidateConfigPath tests the validateConfigPath function
|
||||||
|
func TestValidateConfigPath(t *testing.T) {
|
||||||
|
execDir, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get current directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
configDir string
|
||||||
|
filename string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid Path",
|
||||||
|
configDir: execDir,
|
||||||
|
filename: "config.json",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid Extension",
|
||||||
|
configDir: execDir,
|
||||||
|
filename: "config.yaml",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Path Traversal",
|
||||||
|
configDir: execDir,
|
||||||
|
filename: "../config.json",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Absolute Path Outside",
|
||||||
|
configDir: execDir,
|
||||||
|
filename: "/etc/passwd",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nested Valid Path",
|
||||||
|
configDir: execDir,
|
||||||
|
filename: "subdir/config.json",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a subdirectory for testing
|
||||||
|
subDir := filepath.Join(execDir, "subdir")
|
||||||
|
if err := os.MkdirAll(subDir, 0755); err != nil {
|
||||||
|
t.Fatalf("Failed to create subdir: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := os.RemoveAll(subDir); err != nil {
|
||||||
|
t.Errorf("Failed to remove test subdirectory: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
configDir := tt.configDir
|
||||||
|
filename := tt.filename
|
||||||
|
if tt.name == "Nested Valid Path" {
|
||||||
|
configDir = subDir
|
||||||
|
}
|
||||||
|
_, err := validateConfigPath(configDir, filename)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("validateConfigPath() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoadConfig tests the loadConfig function
|
||||||
|
func TestLoadConfig(t *testing.T) {
|
||||||
|
// Create a temporary directory
|
||||||
|
tempDir, err := os.MkdirTemp("", "config_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp dir: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := os.RemoveAll(tempDir); err != nil {
|
||||||
|
t.Errorf("Failed to remove temp directory: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Valid config JSON
|
||||||
|
validConfig := `{
|
||||||
|
"id": "bot123",
|
||||||
|
"telegram_token": "token123",
|
||||||
|
"memory_size": 1024,
|
||||||
|
"messages_per_hour": 10,
|
||||||
|
"messages_per_day": 100,
|
||||||
|
"temp_ban_duration": "1h",
|
||||||
|
"model": "claude-v1",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"system_prompts": {"welcome": "Hello!"},
|
||||||
|
"active": true,
|
||||||
|
"owner_telegram_id": 123456789,
|
||||||
|
"anthropic_api_key": "api_key_123"
|
||||||
|
}`
|
||||||
|
|
||||||
|
// Invalid config JSON
|
||||||
|
invalidConfig := `{
|
||||||
|
"id": "bot123",
|
||||||
|
"telegram_token": "token123",
|
||||||
|
"memory_size": "should be int",
|
||||||
|
"model": "claude-v1"
|
||||||
|
}`
|
||||||
|
|
||||||
|
// Write valid config file
|
||||||
|
validPath := filepath.Join(tempDir, "valid_config.json")
|
||||||
|
if err := os.WriteFile(validPath, []byte(validConfig), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to write valid config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write invalid config file
|
||||||
|
invalidPath := filepath.Join(tempDir, "invalid_config.json")
|
||||||
|
if err := os.WriteFile(invalidPath, []byte(invalidConfig), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to write invalid config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
filename string
|
||||||
|
wantErr bool
|
||||||
|
expectID string
|
||||||
|
expectErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Load Valid Config",
|
||||||
|
filename: validPath,
|
||||||
|
wantErr: false,
|
||||||
|
expectID: "bot123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Load Invalid Config",
|
||||||
|
filename: invalidPath,
|
||||||
|
wantErr: true,
|
||||||
|
expectErr: "failed to decode JSON",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-existent File",
|
||||||
|
filename: filepath.Join(tempDir, "nonexistent.json"),
|
||||||
|
wantErr: true,
|
||||||
|
expectErr: "failed to open config file",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
config, err := loadConfig(tt.filename)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("loadConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.wantErr && err != nil && tt.expectErr != "" {
|
||||||
|
if !contains(err.Error(), tt.expectErr) {
|
||||||
|
t.Errorf("loadConfig() error = %v, expected to contain %v", err, tt.expectErr)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if config.ID != tt.expectID {
|
||||||
|
t.Errorf("Expected ID %s, got %s", tt.expectID, config.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidateConfig tests the validateConfig function
|
||||||
|
func TestValidateConfig(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config BotConfig
|
||||||
|
ids map[string]bool
|
||||||
|
tokens map[string]bool
|
||||||
|
wantErr bool
|
||||||
|
expectedError string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid Config",
|
||||||
|
config: BotConfig{
|
||||||
|
ID: "bot123",
|
||||||
|
TelegramToken: "token123",
|
||||||
|
Model: "claude-v1",
|
||||||
|
Active: true,
|
||||||
|
OwnerTelegramID: 123456789,
|
||||||
|
MessagePerHour: 10,
|
||||||
|
MessagePerDay: 100,
|
||||||
|
},
|
||||||
|
ids: make(map[string]bool),
|
||||||
|
tokens: make(map[string]bool),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Missing ID",
|
||||||
|
config: BotConfig{
|
||||||
|
TelegramToken: "token123",
|
||||||
|
Model: "claude-v1",
|
||||||
|
Active: true,
|
||||||
|
},
|
||||||
|
ids: make(map[string]bool),
|
||||||
|
tokens: make(map[string]bool),
|
||||||
|
wantErr: true,
|
||||||
|
expectedError: "missing 'id' field",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Duplicate ID",
|
||||||
|
config: BotConfig{
|
||||||
|
ID: "bot123",
|
||||||
|
TelegramToken: "token123",
|
||||||
|
Model: "claude-v1",
|
||||||
|
Active: true,
|
||||||
|
},
|
||||||
|
ids: map[string]bool{"bot123": true},
|
||||||
|
tokens: make(map[string]bool),
|
||||||
|
wantErr: true,
|
||||||
|
expectedError: "duplicate bot id",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Missing Telegram Token",
|
||||||
|
config: BotConfig{
|
||||||
|
ID: "bot123",
|
||||||
|
Model: "claude-v1",
|
||||||
|
Active: true,
|
||||||
|
},
|
||||||
|
ids: make(map[string]bool),
|
||||||
|
tokens: make(map[string]bool),
|
||||||
|
wantErr: true,
|
||||||
|
expectedError: "missing 'telegram_token' field",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Duplicate Telegram Token",
|
||||||
|
config: BotConfig{
|
||||||
|
ID: "bot123",
|
||||||
|
TelegramToken: "token123",
|
||||||
|
Model: "claude-v1",
|
||||||
|
Active: true,
|
||||||
|
},
|
||||||
|
ids: make(map[string]bool),
|
||||||
|
tokens: map[string]bool{"token123": true},
|
||||||
|
wantErr: true,
|
||||||
|
expectedError: "duplicate telegram_token",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Missing Model",
|
||||||
|
config: BotConfig{
|
||||||
|
ID: "bot123",
|
||||||
|
TelegramToken: "token123",
|
||||||
|
Active: true,
|
||||||
|
},
|
||||||
|
ids: make(map[string]bool),
|
||||||
|
tokens: make(map[string]bool),
|
||||||
|
wantErr: true,
|
||||||
|
expectedError: "missing 'model' field",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Zero MessagePerHour",
|
||||||
|
config: BotConfig{
|
||||||
|
ID: "bot123",
|
||||||
|
TelegramToken: "token123",
|
||||||
|
Model: "claude-v1",
|
||||||
|
MessagePerHour: 0,
|
||||||
|
MessagePerDay: 100,
|
||||||
|
},
|
||||||
|
ids: make(map[string]bool),
|
||||||
|
tokens: make(map[string]bool),
|
||||||
|
wantErr: true,
|
||||||
|
expectedError: "'messages_per_hour' must be greater than 0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Zero MessagePerDay",
|
||||||
|
config: BotConfig{
|
||||||
|
ID: "bot123",
|
||||||
|
TelegramToken: "token123",
|
||||||
|
Model: "claude-v1",
|
||||||
|
MessagePerHour: 10,
|
||||||
|
MessagePerDay: 0,
|
||||||
|
},
|
||||||
|
ids: make(map[string]bool),
|
||||||
|
tokens: make(map[string]bool),
|
||||||
|
wantErr: true,
|
||||||
|
expectedError: "'messages_per_day' must be greater than 0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validateConfig(&tt.config, tt.ids, tt.tokens)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("validateConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.wantErr && err != nil && tt.expectedError != "" {
|
||||||
|
if !contains(err.Error(), tt.expectedError) {
|
||||||
|
t.Errorf("validateConfig() error = %v, expected to contain %v", err, tt.expectedError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoadAllConfigs tests the loadAllConfigs function
|
||||||
|
func TestLoadAllConfigs(t *testing.T) {
|
||||||
|
// Create a temporary directory
|
||||||
|
tempDir, err := os.MkdirTemp("", "load_all_configs_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp dir: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := os.RemoveAll(tempDir); err != nil {
|
||||||
|
t.Errorf("Failed to remove temp directory: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupFiles map[string]string // filename -> content
|
||||||
|
expectConfigs int
|
||||||
|
expectError bool
|
||||||
|
expectErrorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Load All Valid Configs",
|
||||||
|
setupFiles: map[string]string{
|
||||||
|
"valid_config.json": `{
|
||||||
|
"id": "bot123",
|
||||||
|
"telegram_token": "token123",
|
||||||
|
"memory_size": 1024,
|
||||||
|
"messages_per_hour": 10,
|
||||||
|
"messages_per_day": 100,
|
||||||
|
"temp_ban_duration": "1h",
|
||||||
|
"model": "claude-v1",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"system_prompts": {"welcome": "Hello!"},
|
||||||
|
"active": true,
|
||||||
|
"owner_telegram_id": 123456789,
|
||||||
|
"anthropic_api_key": "api_key_123"
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
expectConfigs: 1,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Skip Inactive Config",
|
||||||
|
setupFiles: map[string]string{
|
||||||
|
"valid_config.json": `{
|
||||||
|
"id": "bot123",
|
||||||
|
"telegram_token": "token123",
|
||||||
|
"memory_size": 1024,
|
||||||
|
"messages_per_hour": 10,
|
||||||
|
"messages_per_day": 100,
|
||||||
|
"temp_ban_duration": "1h",
|
||||||
|
"model": "claude-v1",
|
||||||
|
"system_prompts": {"welcome": "Hello!"},
|
||||||
|
"active": true,
|
||||||
|
"owner_telegram_id": 123456789,
|
||||||
|
"anthropic_api_key": "api_key_123"
|
||||||
|
}`,
|
||||||
|
"inactive_config.json": `{
|
||||||
|
"id": "bot124",
|
||||||
|
"telegram_token": "token124",
|
||||||
|
"memory_size": 512,
|
||||||
|
"messages_per_hour": 5,
|
||||||
|
"messages_per_day": 50,
|
||||||
|
"temp_ban_duration": "30m",
|
||||||
|
"model": "claude-v2",
|
||||||
|
"temperature": 0.5,
|
||||||
|
"system_prompts": {"welcome": "Hi!"},
|
||||||
|
"active": false,
|
||||||
|
"owner_telegram_id": 987654321,
|
||||||
|
"anthropic_api_key": "api_key_124"
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
expectConfigs: 1,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Duplicate Bot ID",
|
||||||
|
setupFiles: map[string]string{
|
||||||
|
"valid_config.json": `{
|
||||||
|
"id": "bot123",
|
||||||
|
"telegram_token": "token123",
|
||||||
|
"memory_size": 1024,
|
||||||
|
"messages_per_hour": 10,
|
||||||
|
"messages_per_day": 100,
|
||||||
|
"temp_ban_duration": "1h",
|
||||||
|
"model": "claude-v1",
|
||||||
|
"system_prompts": {"welcome": "Hello!"},
|
||||||
|
"active": true,
|
||||||
|
"owner_telegram_id": 123456789,
|
||||||
|
"anthropic_api_key": "api_key_123"
|
||||||
|
}`,
|
||||||
|
"duplicate_id_config.json": `{
|
||||||
|
"id": "bot123",
|
||||||
|
"telegram_token": "token125",
|
||||||
|
"memory_size": 256,
|
||||||
|
"messages_per_hour": 2,
|
||||||
|
"messages_per_day": 20,
|
||||||
|
"temp_ban_duration": "15m",
|
||||||
|
"model": "claude-v3",
|
||||||
|
"temperature": 0.3,
|
||||||
|
"system_prompts": {"welcome": "Hey!"},
|
||||||
|
"active": true,
|
||||||
|
"owner_telegram_id": 1122334455,
|
||||||
|
"anthropic_api_key": "api_key_125"
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
expectConfigs: 1,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Duplicate Telegram Token",
|
||||||
|
setupFiles: map[string]string{
|
||||||
|
"valid_config.json": `{
|
||||||
|
"id": "bot123",
|
||||||
|
"telegram_token": "token123",
|
||||||
|
"memory_size": 1024,
|
||||||
|
"messages_per_hour": 10,
|
||||||
|
"messages_per_day": 100,
|
||||||
|
"temp_ban_duration": "1h",
|
||||||
|
"model": "claude-v1",
|
||||||
|
"system_prompts": {"welcome": "Hello!"},
|
||||||
|
"active": true,
|
||||||
|
"owner_telegram_id": 123456789,
|
||||||
|
"anthropic_api_key": "api_key_123"
|
||||||
|
}`,
|
||||||
|
"duplicate_token_config.json": `{
|
||||||
|
"id": "bot126",
|
||||||
|
"telegram_token": "token123",
|
||||||
|
"memory_size": 128,
|
||||||
|
"messages_per_hour": 1,
|
||||||
|
"messages_per_day": 10,
|
||||||
|
"temp_ban_duration": "5m",
|
||||||
|
"model": "claude-v4",
|
||||||
|
"temperature": 0.2,
|
||||||
|
"system_prompts": {"welcome": "Greetings!"},
|
||||||
|
"active": true,
|
||||||
|
"owner_telegram_id": 5566778899,
|
||||||
|
"anthropic_api_key": "api_key_126"
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
expectConfigs: 1,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid Config",
|
||||||
|
setupFiles: map[string]string{
|
||||||
|
"valid_config.json": `{
|
||||||
|
"id": "bot123",
|
||||||
|
"telegram_token": "token123",
|
||||||
|
"memory_size": 1024,
|
||||||
|
"messages_per_hour": 10,
|
||||||
|
"messages_per_day": 100,
|
||||||
|
"temp_ban_duration": "1h",
|
||||||
|
"model": "claude-v1",
|
||||||
|
"system_prompts": {"welcome": "Hello!"},
|
||||||
|
"active": true,
|
||||||
|
"owner_telegram_id": 123456789,
|
||||||
|
"anthropic_api_key": "api_key_123"
|
||||||
|
}`,
|
||||||
|
"invalid_config.json": `{
|
||||||
|
"id": "bot127",
|
||||||
|
"telegram_token": "token127",
|
||||||
|
"model": "",
|
||||||
|
"active": true
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
expectConfigs: 1,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Clear the tempDir before each test
|
||||||
|
if err := os.RemoveAll(tempDir); err != nil {
|
||||||
|
t.Fatalf("Failed to remove temp dir: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(tempDir, 0755); err != nil {
|
||||||
|
t.Fatalf("Failed to create temp dir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the test files directly
|
||||||
|
for filename, content := range tt.setupFiles {
|
||||||
|
err := os.WriteFile(filepath.Join(tempDir, filename), []byte(content), 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to write file %s: %v", filename, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
configs, err := loadAllConfigs(tempDir)
|
||||||
|
if (err != nil) != tt.expectError {
|
||||||
|
t.Errorf("loadAllConfigs() error = %v, wantErr %v", err, tt.expectError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(configs) != tt.expectConfigs {
|
||||||
|
t.Errorf("Expected %d configs, got %d", tt.expectConfigs, len(configs))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBotConfig_Reload tests the Reload method of BotConfig
|
||||||
|
func TestBotConfig_Reload(t *testing.T) { //NOSONAR go:S100 -- underscore separation is idiomatic in Go test names
|
||||||
|
// Create a temporary directory
|
||||||
|
tempDir, err := os.MkdirTemp("", "reload_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp dir: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := os.RemoveAll(tempDir); err != nil {
|
||||||
|
t.Errorf("Failed to remove temp directory: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Create initial config file
|
||||||
|
config1 := `{
|
||||||
|
"id": "bot123",
|
||||||
|
"telegram_token": "token123",
|
||||||
|
"memory_size": 1024,
|
||||||
|
"messages_per_hour": 10,
|
||||||
|
"messages_per_day": 100,
|
||||||
|
"temp_ban_duration": "1h",
|
||||||
|
"model": "claude-v1",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"system_prompts": {"welcome": "Hello!"},
|
||||||
|
"active": true,
|
||||||
|
"owner_telegram_id": 123456789,
|
||||||
|
"anthropic_api_key": "api_key_123"
|
||||||
|
}`
|
||||||
|
configPath := filepath.Join(tempDir, "config.json")
|
||||||
|
if err := os.WriteFile(configPath, []byte(config1), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to write initial config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize BotConfig
|
||||||
|
var config BotConfig
|
||||||
|
if err := config.Reload(tempDir, "config.json"); err != nil {
|
||||||
|
t.Fatalf("Failed to reload config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify initial load
|
||||||
|
if config.ID != "bot123" {
|
||||||
|
t.Errorf("Expected ID 'bot123', got '%s'", config.ID)
|
||||||
|
}
|
||||||
|
if config.Model != "claude-v1" {
|
||||||
|
t.Errorf("Expected Model 'claude-v1', got '%s'", config.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update config file
|
||||||
|
config2 := `{
|
||||||
|
"id": "bot123",
|
||||||
|
"telegram_token": "token123_updated",
|
||||||
|
"memory_size": 2048,
|
||||||
|
"messages_per_hour": 20,
|
||||||
|
"messages_per_day": 200,
|
||||||
|
"temp_ban_duration": "2h",
|
||||||
|
"model": "claude-v2",
|
||||||
|
"temperature": 0.3,
|
||||||
|
"system_prompts": {"welcome": "Hi there!"},
|
||||||
|
"active": true,
|
||||||
|
"owner_telegram_id": 987654321,
|
||||||
|
"anthropic_api_key": "api_key_456"
|
||||||
|
}`
|
||||||
|
if err := os.WriteFile(configPath, []byte(config2), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to write updated config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload config
|
||||||
|
if err := config.Reload(tempDir, "config.json"); err != nil {
|
||||||
|
t.Fatalf("Failed to reload updated config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify updated config
|
||||||
|
if config.TelegramToken != "token123_updated" {
|
||||||
|
t.Errorf("Expected TelegramToken 'token123_updated', got '%s'", config.TelegramToken)
|
||||||
|
}
|
||||||
|
if config.MemorySize != 2048 {
|
||||||
|
t.Errorf("Expected MemorySize 2048, got %d", config.MemorySize)
|
||||||
|
}
|
||||||
|
if config.Model != "claude-v2" {
|
||||||
|
t.Errorf("Expected Model 'claude-v2', got '%s'", config.Model)
|
||||||
|
}
|
||||||
|
if config.OwnerTelegramID != 987654321 {
|
||||||
|
t.Errorf("Expected OwnerTelegramID 987654321, got %d", config.OwnerTelegramID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBotConfig_UnmarshalJSON_Invalid tests unmarshalling with invalid model
|
||||||
|
func TestBotConfig_UnmarshalJSON_Invalid(t *testing.T) { //NOSONAR go:S100 -- underscore separation is idiomatic in Go test names
|
||||||
|
jsonData := `{
|
||||||
|
"id": "bot123",
|
||||||
|
"telegram_token": "token123",
|
||||||
|
"memory_size": 1024,
|
||||||
|
"messages_per_hour": 10,
|
||||||
|
"messages_per_day": 100,
|
||||||
|
"temp_ban_duration": "1h",
|
||||||
|
"model": "",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"system_prompts": {"welcome": "Hello!"},
|
||||||
|
"active": true,
|
||||||
|
"owner_telegram_id": 123456789,
|
||||||
|
"anthropic_api_key": "api_key_123"
|
||||||
|
}`
|
||||||
|
|
||||||
|
var config BotConfig
|
||||||
|
err := json.Unmarshal([]byte(jsonData), &config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Model != "" {
|
||||||
|
t.Errorf("Expected empty model, got %s", config.Model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to check substring
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
return strings.Contains(s, substr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTemperatureConfig tests that the temperature value is correctly loaded
|
||||||
|
func TestTemperatureConfig(t *testing.T) {
|
||||||
|
// Create a temporary directory
|
||||||
|
tempDir, err := os.MkdirTemp("", "temperature_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp dir: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := os.RemoveAll(tempDir); err != nil {
|
||||||
|
t.Errorf("Failed to remove temp directory: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Create config with temperature
|
||||||
|
configWithTemp := `{
|
||||||
|
"id": "bot123",
|
||||||
|
"telegram_token": "token123",
|
||||||
|
"memory_size": 1024,
|
||||||
|
"messages_per_hour": 10,
|
||||||
|
"messages_per_day": 100,
|
||||||
|
"temp_ban_duration": "1h",
|
||||||
|
"model": "claude-v1",
|
||||||
|
"temperature": 0.42,
|
||||||
|
"system_prompts": {"welcome": "Hello!"},
|
||||||
|
"active": true,
|
||||||
|
"owner_telegram_id": 123456789,
|
||||||
|
"anthropic_api_key": "api_key_123"
|
||||||
|
}`
|
||||||
|
|
||||||
|
// Create config without temperature
|
||||||
|
configWithoutTemp := `{
|
||||||
|
"id": "bot124",
|
||||||
|
"telegram_token": "token124",
|
||||||
|
"memory_size": 1024,
|
||||||
|
"messages_per_hour": 10,
|
||||||
|
"messages_per_day": 100,
|
||||||
|
"temp_ban_duration": "1h",
|
||||||
|
"model": "claude-v1",
|
||||||
|
"system_prompts": {"welcome": "Hello!"},
|
||||||
|
"active": true,
|
||||||
|
"owner_telegram_id": 123456789,
|
||||||
|
"anthropic_api_key": "api_key_123"
|
||||||
|
}`
|
||||||
|
|
||||||
|
// Write config files
|
||||||
|
withTempPath := filepath.Join(tempDir, "with_temp.json")
|
||||||
|
if err := os.WriteFile(withTempPath, []byte(configWithTemp), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to write config with temperature: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
withoutTempPath := filepath.Join(tempDir, "without_temp.json")
|
||||||
|
if err := os.WriteFile(withoutTempPath, []byte(configWithoutTemp), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to write config without temperature: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test loading config with temperature
|
||||||
|
configWithTempObj, err := loadConfig(withTempPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load config with temperature: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify temperature is set correctly
|
||||||
|
if configWithTempObj.Temperature == nil {
|
||||||
|
t.Errorf("Expected Temperature to be set, got nil")
|
||||||
|
} else if *configWithTempObj.Temperature != 0.42 {
|
||||||
|
t.Errorf("Expected Temperature 0.42, got %f", *configWithTempObj.Temperature)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test loading config without temperature
|
||||||
|
configWithoutTempObj, err := loadConfig(withoutTempPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load config without temperature: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify temperature is nil when not specified
|
||||||
|
if configWithoutTempObj.Temperature != nil {
|
||||||
|
t.Errorf("Expected Temperature to be nil, got %f", *configWithoutTempObj.Temperature)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Additional tests can be added here to cover more scenarios
|
||||||
|
|
||||||
|
// TestBotConfig_PersistModel verifies that PersistModel updates the model both in memory
|
||||||
|
// and on disk while leaving all other config fields unchanged.
|
||||||
|
func TestBotConfig_PersistModel(t *testing.T) { //NOSONAR go:S100 -- underscore separation is idiomatic in Go test names
|
||||||
|
tempDir, err := os.MkdirTemp("", "persist_model_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp dir: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := os.RemoveAll(tempDir); err != nil {
|
||||||
|
t.Errorf("Failed to remove temp directory: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
initialJSON := `{
|
||||||
|
"id": "bot1",
|
||||||
|
"telegram_token": "token1",
|
||||||
|
"model": "claude-v1",
|
||||||
|
"messages_per_hour": 10,
|
||||||
|
"messages_per_day": 100
|
||||||
|
}`
|
||||||
|
configPath := filepath.Join(tempDir, "config.json")
|
||||||
|
if err := os.WriteFile(configPath, []byte(initialJSON), 0600); err != nil {
|
||||||
|
t.Fatalf("Failed to write config file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := BotConfig{
|
||||||
|
ID: "bot1",
|
||||||
|
Model: "claude-v1",
|
||||||
|
ConfigFilePath: configPath,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Successful model update
|
||||||
|
if err := config.PersistModel("claude-sonnet-4-6"); err != nil {
|
||||||
|
t.Fatalf("PersistModel() unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// In-memory model must be updated immediately
|
||||||
|
if string(config.Model) != "claude-sonnet-4-6" {
|
||||||
|
t.Errorf("in-memory model: got %q, want %q", config.Model, "claude-sonnet-4-6")
|
||||||
|
}
|
||||||
|
|
||||||
|
// On-disk model must be updated; other fields must be preserved
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read updated config file: %v", err)
|
||||||
|
}
|
||||||
|
var raw map[string]any
|
||||||
|
if err := json.Unmarshal(data, &raw); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal updated config: %v", err)
|
||||||
|
}
|
||||||
|
if raw["model"] != "claude-sonnet-4-6" {
|
||||||
|
t.Errorf("on-disk model: got %v, want %q", raw["model"], "claude-sonnet-4-6")
|
||||||
|
}
|
||||||
|
if raw["id"] != "bot1" {
|
||||||
|
t.Errorf("on-disk id should be preserved: got %v, want %q", raw["id"], "bot1")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error case: empty ConfigFilePath must return an error
|
||||||
|
noPath := BotConfig{Model: "claude-v1"}
|
||||||
|
if err := noPath.PersistModel("claude-sonnet-4-6"); err == nil {
|
||||||
|
t.Error("PersistModel with empty ConfigFilePath: expected error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
Executable → Regular
+63
-2
@@ -3,6 +3,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/driver/sqlite"
|
"gorm.io/driver/sqlite"
|
||||||
@@ -11,6 +12,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func initDB() (*gorm.DB, error) {
|
func initDB() (*gorm.DB, error) {
|
||||||
|
if err := os.MkdirAll("data", 0750); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create data directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
newLogger := logger.New(
|
newLogger := logger.New(
|
||||||
log.New(log.Writer(), "\r\n", log.LstdFlags),
|
log.New(log.Writer(), "\r\n", log.LstdFlags),
|
||||||
logger.Config{
|
logger.Config{
|
||||||
@@ -20,15 +25,21 @@ func initDB() (*gorm.DB, error) {
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
db, err := gorm.Open(sqlite.Open("bot.db"), &gorm.Config{
|
db, err := gorm.Open(sqlite.Open("data/bot.db?_journal_mode=WAL&_busy_timeout=5000&_foreign_keys=on"), &gorm.Config{
|
||||||
Logger: newLogger,
|
Logger: newLogger,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err)
|
||||||
|
}
|
||||||
|
sqlDB.SetMaxOpenConns(1)
|
||||||
|
|
||||||
// AutoMigrate the models
|
// AutoMigrate the models
|
||||||
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{})
|
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}, &Scope{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to migrate database schema: %w", err)
|
return nil, fmt.Errorf("failed to migrate database schema: %w", err)
|
||||||
}
|
}
|
||||||
@@ -48,9 +59,59 @@ func initDB() (*gorm.DB, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := createDefaultScopes(db); err != nil {
|
||||||
|
return nil, fmt.Errorf("createDefaultScopes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func createDefaultScopes(db *gorm.DB) error {
|
||||||
|
all := []string{
|
||||||
|
ScopeStatsViewOwn, ScopeStatsViewAny,
|
||||||
|
ScopeHistoryClearOwn, ScopeHistoryClearAny,
|
||||||
|
ScopeHistoryClearHardOwn, ScopeHistoryClearHardAny,
|
||||||
|
ScopeModelSet, ScopeUserPromote,
|
||||||
|
}
|
||||||
|
for _, name := range all {
|
||||||
|
if err := db.FirstOrCreate(&Scope{}, Scope{Name: name}).Error; err != nil {
|
||||||
|
return fmt.Errorf("failed to create scope %s: %w", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
userScopes := []string{
|
||||||
|
ScopeStatsViewOwn,
|
||||||
|
ScopeHistoryClearOwn,
|
||||||
|
ScopeHistoryClearHardOwn,
|
||||||
|
}
|
||||||
|
elevatedScopes := []string{
|
||||||
|
ScopeStatsViewOwn, ScopeStatsViewAny,
|
||||||
|
ScopeHistoryClearOwn, ScopeHistoryClearAny,
|
||||||
|
ScopeHistoryClearHardOwn, ScopeHistoryClearHardAny,
|
||||||
|
ScopeModelSet, ScopeUserPromote,
|
||||||
|
}
|
||||||
|
assignments := map[string][]string{
|
||||||
|
"user": userScopes,
|
||||||
|
"admin": elevatedScopes,
|
||||||
|
// owner gets the same scopes as admin; owner uniqueness is enforced by the IsOwner flag
|
||||||
|
"owner": elevatedScopes,
|
||||||
|
}
|
||||||
|
for roleName, scopes := range assignments {
|
||||||
|
var role Role
|
||||||
|
if err := db.Where("name = ?", roleName).First(&role).Error; err != nil {
|
||||||
|
return fmt.Errorf("role %s not found: %w", roleName, err)
|
||||||
|
}
|
||||||
|
var scopeModels []Scope
|
||||||
|
if err := db.Where("name IN ?", scopes).Find(&scopeModels).Error; err != nil {
|
||||||
|
return fmt.Errorf("failed to find scopes for %s: %w", roleName, err)
|
||||||
|
}
|
||||||
|
if err := db.Model(&role).Association("Scopes").Replace(scopeModels); err != nil {
|
||||||
|
return fmt.Errorf("failed to assign scopes to %s: %w", roleName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func createDefaultRoles(db *gorm.DB) error {
|
func createDefaultRoles(db *gorm.DB) error {
|
||||||
roles := []string{"user", "admin", "owner"}
|
roles := []string{"user", "admin", "owner"}
|
||||||
for _, roleName := range roles {
|
for _, roleName := range roles {
|
||||||
|
|||||||
@@ -0,0 +1,39 @@
|
|||||||
|
services:
|
||||||
|
telegram-bot:
|
||||||
|
image: bogerserge/go-telegram-bot:latest
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
platforms:
|
||||||
|
- linux/amd64
|
||||||
|
- linux/arm64
|
||||||
|
container_name: go-telegram-bot
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
# Optional: Environment variables (can be overridden with .env file)
|
||||||
|
# environment:
|
||||||
|
# - BOT_LOG_LEVEL=info
|
||||||
|
|
||||||
|
# Volume mounts
|
||||||
|
volumes:
|
||||||
|
# Bind mount config directory for live configuration updates
|
||||||
|
- ./config:/app/config:ro
|
||||||
|
# Named volume for persistent database storage
|
||||||
|
- ./data:/app/data
|
||||||
|
# Optional: Bind mount for log access (uncomment if needed)
|
||||||
|
# - ./logs:/app/logs
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "pgrep", "telegram-bot"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
start_period: 10s
|
||||||
|
|
||||||
|
# Logging configuration
|
||||||
|
logging:
|
||||||
|
driver: "json-file"
|
||||||
|
options:
|
||||||
|
max-size: "10m"
|
||||||
|
max-file: "3"
|
||||||
Executable → Regular
@@ -1,19 +1,23 @@
|
|||||||
module github.com/HugeFrog24/go-telegram-bot
|
module github.com/HugeFrog24/go-telegram-bot
|
||||||
|
|
||||||
go 1.23
|
go 1.26.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/go-telegram/bot v1.9.1
|
github.com/go-telegram/bot v1.19.0
|
||||||
github.com/joho/godotenv v1.5.1
|
github.com/liushuangls/go-anthropic/v2 v2.17.1
|
||||||
github.com/liushuangls/go-anthropic/v2 v2.8.1
|
github.com/stretchr/testify v1.11.1
|
||||||
golang.org/x/time v0.7.0
|
golang.org/x/time v0.14.0
|
||||||
gorm.io/driver/sqlite v1.5.6
|
gorm.io/driver/sqlite v1.6.0
|
||||||
gorm.io/gorm v1.25.12
|
gorm.io/gorm v1.31.1
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
github.com/mattn/go-sqlite3 v1.14.24 // indirect
|
github.com/mattn/go-sqlite3 v1.14.34 // indirect
|
||||||
golang.org/x/text v0.19.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/stretchr/objx v0.5.3 // indirect
|
||||||
|
golang.org/x/text v0.34.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,20 +1,30 @@
|
|||||||
github.com/go-telegram/bot v1.9.1 h1:4vkNV6vDmEPZaYP7sZYaagOaJyV4GerfOPkjg/Ki5ic=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/go-telegram/bot v1.9.1/go.mod h1:i2TRs7fXWIeaceF3z7KzsMt/he0TwkVC680mvdTFYeM=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/go-telegram/bot v1.19.0 h1:tuvTQhgNietHFRN0HUDhuXsgfgkGSaO8WWwZQW3DMQg=
|
||||||
|
github.com/go-telegram/bot v1.19.0/go.mod h1:i2TRs7fXWIeaceF3z7KzsMt/he0TwkVC680mvdTFYeM=
|
||||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
github.com/liushuangls/go-anthropic/v2 v2.17.1 h1:ca3oFzgQHs9/mJr+xx2XFQIYcQLM2rDCqieUx0g+8p4=
|
||||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
github.com/liushuangls/go-anthropic/v2 v2.17.1/go.mod h1:a550cJXPoTG2FL3DvfKG2zzD5O2vjgvo4tHtoGPzFLU=
|
||||||
github.com/liushuangls/go-anthropic/v2 v2.8.1 h1:pxFl88IgkG7e8Z1XwOYu48LcmEN0+6UdO58HF9altw0=
|
github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
|
||||||
github.com/liushuangls/go-anthropic/v2 v2.8.1/go.mod h1:8BKv/fkeTaL5R9R9bGkaknYBueyw2WxY20o7bImbOek=
|
github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||||
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
|
github.com/stretchr/objx v0.5.3 h1:jmXUvGomnU1o3W/V5h2VEradbpJDwGrzugQQvL0POH4=
|
||||||
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
github.com/stretchr/objx v0.5.3/go.mod h1:rDQraq+vQZU7Fde9LOZLr8Tax6zZvy4kuNKF+QYS+U0=
|
||||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE=
|
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||||
gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=
|
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
||||||
|
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
||||||
|
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
||||||
|
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
|
||||||
|
|||||||
Executable → Regular
+336
-78
@@ -2,6 +2,9 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/go-telegram/bot"
|
"github.com/go-telegram/bot"
|
||||||
@@ -9,6 +12,20 @@ import (
|
|||||||
"github.com/liushuangls/go-anthropic/v2"
|
"github.com/liushuangls/go-anthropic/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// anthropicErrorResponse returns the message to send back to the user when getAnthropicResponse
|
||||||
|
// fails. Admins and owners receive an actionable hint when the model is deprecated; regular users
|
||||||
|
// always get the generic fallback to avoid leaking internal details.
|
||||||
|
func (b *Bot) anthropicErrorResponse(err error, userID int64) string {
|
||||||
|
if errors.Is(err, ErrModelNotFound) && b.hasScope(userID, ScopeModelSet) {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"⚠️ Model `%s` is no longer available (deprecated or removed by Anthropic).\n"+
|
||||||
|
"Use /set_model <model-id> to switch. Current models: https://platform.claude.com/docs/en/about-claude/models/overview",
|
||||||
|
b.config.Model,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return "I'm sorry, I'm having trouble processing your request right now."
|
||||||
|
}
|
||||||
|
|
||||||
func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.Update) {
|
func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.Update) {
|
||||||
var message *models.Message
|
var message *models.Message
|
||||||
|
|
||||||
@@ -29,53 +46,32 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
|
|||||||
businessConnectionID = message.BusinessConnectionID
|
businessConnectionID = message.BusinessConnectionID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if message.From == nil {
|
||||||
|
// Channel posts and some automated messages have no sender — ignore them.
|
||||||
|
// see: https://core.telegram.org/bots/api#message
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
chatID := message.Chat.ID
|
chatID := message.Chat.ID
|
||||||
userID := message.From.ID
|
userID := message.From.ID
|
||||||
username := message.From.Username
|
username := message.From.Username
|
||||||
|
firstName := message.From.FirstName
|
||||||
|
lastName := message.From.LastName
|
||||||
|
languageCode := message.From.LanguageCode
|
||||||
|
isPremium := message.From.IsPremium
|
||||||
|
messageTime := message.Date
|
||||||
text := message.Text
|
text := message.Text
|
||||||
|
|
||||||
// Pass the incoming message through the centralized screen for storage
|
// Check if it's a new chat (before storing the message so the flag is accurate).
|
||||||
_, err := b.screenIncomingMessage(message)
|
isNewChatFlag := b.isNewChat(chatID)
|
||||||
|
|
||||||
|
// Screen incoming message (store to DB + add to chat memory)
|
||||||
|
userMsg, err := b.screenIncomingMessage(message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorLogger.Printf("Error storing user message: %v", err)
|
ErrorLogger.Printf("Error storing user message: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the message is a command
|
|
||||||
if message.Entities != nil {
|
|
||||||
for _, entity := range message.Entities {
|
|
||||||
if entity.Type == "bot_command" {
|
|
||||||
command := strings.TrimSpace(message.Text[entity.Offset : entity.Offset+entity.Length])
|
|
||||||
switch command {
|
|
||||||
case "/stats":
|
|
||||||
b.sendStats(ctx, chatID, userID, username, businessConnectionID)
|
|
||||||
return
|
|
||||||
case "/whoami":
|
|
||||||
b.sendWhoAmI(ctx, chatID, userID, username, businessConnectionID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the message contains a sticker
|
|
||||||
if message.Sticker != nil {
|
|
||||||
b.handleStickerMessage(ctx, chatID, userID, message, businessConnectionID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rate limit check
|
|
||||||
if !b.checkRateLimits(userID) {
|
|
||||||
b.sendRateLimitExceededMessage(ctx, chatID, businessConnectionID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Proceed only if the message contains text
|
|
||||||
if text == "" {
|
|
||||||
InfoLogger.Printf("Received a non-text message from user %d in chat %d", userID, chatID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine if the user is the owner
|
// Determine if the user is the owner
|
||||||
var isOwner bool
|
var isOwner bool
|
||||||
err = b.db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", userID, b.botID, true).First(&User{}).Error
|
err = b.db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", userID, b.botID, true).First(&User{}).Error
|
||||||
@@ -83,13 +79,14 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
|
|||||||
isOwner = true
|
isOwner = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always create/get the user record — on the very first message and on all subsequent ones.
|
||||||
user, err := b.getOrCreateUser(userID, username, isOwner)
|
user, err := b.getOrCreateUser(userID, username, isOwner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorLogger.Printf("Error getting or creating user: %v", err)
|
ErrorLogger.Printf("Error getting or creating user: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the username if it's empty or has changed
|
// Update the username if it has changed
|
||||||
if user.Username != username {
|
if user.Username != username {
|
||||||
user.Username = username
|
user.Username = username
|
||||||
if err := b.db.Save(&user).Error; err != nil {
|
if err := b.db.Save(&user).Error; err != nil {
|
||||||
@@ -97,22 +94,175 @@ func (b *Bot) handleUpdate(ctx context.Context, tgBot *bot.Bot, update *models.U
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if the message is a command — applies on every message, including the very first.
|
||||||
|
if message.Entities != nil {
|
||||||
|
for _, entity := range message.Entities {
|
||||||
|
if entity.Type == "bot_command" {
|
||||||
|
command := strings.TrimSpace(message.Text[entity.Offset : entity.Offset+entity.Length])
|
||||||
|
switch command {
|
||||||
|
case "/stats":
|
||||||
|
// Parse command parameters
|
||||||
|
parts := strings.Fields(message.Text)
|
||||||
|
|
||||||
|
// Default: show global stats
|
||||||
|
if len(parts) == 1 {
|
||||||
|
b.sendStats(ctx, chatID, userID, 0, businessConnectionID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for "user" parameter
|
||||||
|
if len(parts) >= 2 && parts[1] == "user" {
|
||||||
|
targetUserID := userID // Default to current user
|
||||||
|
|
||||||
|
// If a user ID is provided, parse it
|
||||||
|
if len(parts) >= 3 {
|
||||||
|
var parseErr error
|
||||||
|
targetUserID, parseErr = strconv.ParseInt(parts[2], 10, 64)
|
||||||
|
if parseErr != nil {
|
||||||
|
InfoLogger.Printf("User %d provided invalid user ID format: %s", userID, parts[2])
|
||||||
|
if err := b.sendResponse(ctx, chatID, "Invalid user ID format. Usage: /stats user [user_id]", businessConnectionID); err != nil {
|
||||||
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.sendStats(ctx, chatID, userID, targetUserID, businessConnectionID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid parameter
|
||||||
|
if err := b.sendResponse(ctx, chatID, "Invalid command format. Usage: /stats or /stats user [user_id]", businessConnectionID); err != nil {
|
||||||
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case "/whoami":
|
||||||
|
b.sendWhoAmI(ctx, chatID, userID, username, businessConnectionID)
|
||||||
|
return
|
||||||
|
case "/clear":
|
||||||
|
parts := strings.Fields(message.Text)
|
||||||
|
var targetUserID, targetChatID int64
|
||||||
|
if len(parts) > 1 {
|
||||||
|
var parseErr error
|
||||||
|
targetUserID, parseErr = strconv.ParseInt(parts[1], 10, 64)
|
||||||
|
if parseErr != nil {
|
||||||
|
InfoLogger.Printf("User %d provided invalid user ID format: %s", userID, parts[1])
|
||||||
|
if err := b.sendResponse(ctx, chatID, "Invalid user ID format. Usage: /clear [user_id] [chat_id]", businessConnectionID); err != nil {
|
||||||
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(parts) > 2 {
|
||||||
|
var parseErr error
|
||||||
|
targetChatID, parseErr = strconv.ParseInt(parts[2], 10, 64)
|
||||||
|
if parseErr != nil {
|
||||||
|
InfoLogger.Printf("User %d provided invalid chat ID format: %s", userID, parts[2])
|
||||||
|
if err := b.sendResponse(ctx, chatID, "Invalid chat ID format. Usage: /clear [user_id] [chat_id]", businessConnectionID); err != nil {
|
||||||
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.clearChatHistory(ctx, chatID, userID, targetUserID, targetChatID, businessConnectionID, false)
|
||||||
|
return
|
||||||
|
case "/set_model":
|
||||||
|
if !b.hasScope(userID, ScopeModelSet) {
|
||||||
|
if err := b.sendResponse(ctx, chatID, "Permission denied. Only admins and owners can change the model.", businessConnectionID); err != nil {
|
||||||
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
parts := strings.Fields(message.Text)
|
||||||
|
if len(parts) < 2 || strings.TrimSpace(parts[1]) == "" {
|
||||||
|
if err := b.sendResponse(ctx, chatID, "Usage: /set_model <model-id>", businessConnectionID); err != nil {
|
||||||
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
newModel := strings.TrimSpace(parts[1])
|
||||||
|
// No upfront model validation:
|
||||||
|
// - The go-anthropic library constants are not enumerable at runtime (Go has no const reflection).
|
||||||
|
// - A live /v1/models probe would add a network round-trip and show in the API audit log.
|
||||||
|
// - An invalid model ID will produce a 404 on the next real message, which routes through
|
||||||
|
// anthropicErrorResponse and already delivers an actionable admin-facing hint.
|
||||||
|
if err := b.config.PersistModel(newModel); err != nil {
|
||||||
|
ErrorLogger.Printf("Failed to persist model change: %v", err)
|
||||||
|
if err := b.sendResponse(ctx, chatID, fmt.Sprintf("Model updated in memory to `%s`, but failed to save to config file: %v", newModel, err), businessConnectionID); err != nil {
|
||||||
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
InfoLogger.Printf("Model changed to %s by user %d", newModel, userID)
|
||||||
|
if err := b.sendResponse(ctx, chatID, fmt.Sprintf("✅ Model updated to `%s` and saved to config.", newModel), businessConnectionID); err != nil {
|
||||||
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case "/clear_hard":
|
||||||
|
parts := strings.Fields(message.Text)
|
||||||
|
var targetUserID, targetChatID int64
|
||||||
|
if len(parts) > 1 {
|
||||||
|
var parseErr error
|
||||||
|
targetUserID, parseErr = strconv.ParseInt(parts[1], 10, 64)
|
||||||
|
if parseErr != nil {
|
||||||
|
InfoLogger.Printf("User %d provided invalid user ID format: %s", userID, parts[1])
|
||||||
|
if err := b.sendResponse(ctx, chatID, "Invalid user ID format. Usage: /clear_hard [user_id] [chat_id]", businessConnectionID); err != nil {
|
||||||
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(parts) > 2 {
|
||||||
|
var parseErr error
|
||||||
|
targetChatID, parseErr = strconv.ParseInt(parts[2], 10, 64)
|
||||||
|
if parseErr != nil {
|
||||||
|
InfoLogger.Printf("User %d provided invalid chat ID format: %s", userID, parts[2])
|
||||||
|
if err := b.sendResponse(ctx, chatID, "Invalid chat ID format. Usage: /clear_hard [user_id] [chat_id]", businessConnectionID); err != nil {
|
||||||
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.clearChatHistory(ctx, chatID, userID, targetUserID, targetChatID, businessConnectionID, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rate limit check applies to all message types including stickers.
|
||||||
|
if !b.checkRateLimits(userID) {
|
||||||
|
b.sendRateLimitExceededMessage(ctx, chatID, businessConnectionID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build context once — shared by the sticker and text response paths.
|
||||||
|
chatMemory := b.getOrCreateChatMemory(chatID)
|
||||||
|
contextMessages := b.prepareContextMessages(chatMemory)
|
||||||
|
|
||||||
|
// Check if the message contains a sticker
|
||||||
|
if message.Sticker != nil {
|
||||||
|
b.handleStickerMessage(ctx, chatID, userMsg, message, contextMessages, businessConnectionID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Proceed only if the message contains text
|
||||||
|
if text == "" {
|
||||||
|
InfoLogger.Printf("Received a non-text message from user %d in chat %d", userID, chatID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Determine if the text contains only emojis
|
// Determine if the text contains only emojis
|
||||||
isEmojiOnly := isOnlyEmojis(text)
|
isEmojiOnly := isOnlyEmojis(text)
|
||||||
|
|
||||||
// Prepare context messages for Anthropic
|
|
||||||
chatMemory := b.getOrCreateChatMemory(chatID)
|
|
||||||
b.addMessageToChatMemory(chatMemory, b.createMessage(chatID, userID, username, user.Role.Name, text, true))
|
|
||||||
contextMessages := b.prepareContextMessages(chatMemory)
|
|
||||||
|
|
||||||
// Get response from Anthropic
|
// Get response from Anthropic
|
||||||
response, err := b.getAnthropicResponse(ctx, contextMessages, b.isNewChat(chatID), isOwner, isEmojiOnly)
|
response, err := b.getAnthropicResponse(ctx, contextMessages, isNewChatFlag, isOwner, isEmojiOnly, username, firstName, lastName, isPremium, languageCode, messageTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorLogger.Printf("Error getting Anthropic response: %v", err)
|
ErrorLogger.Printf("Error getting Anthropic response: %v", err)
|
||||||
response = "I'm sorry, I'm having trouble processing your request right now."
|
response = b.anthropicErrorResponse(err, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send the response through the centralized screen
|
// Send the response
|
||||||
if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil {
|
if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil {
|
||||||
ErrorLogger.Printf("Error sending response: %v", err)
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
return
|
return
|
||||||
@@ -125,24 +275,11 @@ func (b *Bot) sendRateLimitExceededMessage(ctx context.Context, chatID int64, bu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, message *models.Message, businessConnectionID string) {
|
func (b *Bot) handleStickerMessage(ctx context.Context, chatID int64, userMessage Message, message *models.Message, contextMessages []anthropic.Message, businessConnectionID string) {
|
||||||
username := message.From.Username
|
// userMessage was already screened (stored + added to memory) by handleUpdate — do not call screenIncomingMessage again.
|
||||||
|
|
||||||
// Create the user message (without storing it manually)
|
|
||||||
userMessage := b.createMessage(chatID, userID, username, "user", "Sent a sticker.", true)
|
|
||||||
userMessage.StickerFileID = message.Sticker.FileID
|
|
||||||
|
|
||||||
// Safely store the Thumbnail's FileID if available
|
|
||||||
if message.Sticker.Thumbnail != nil {
|
|
||||||
userMessage.StickerPNGFile = message.Sticker.Thumbnail.FileID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update chat memory with the user message
|
|
||||||
chatMemory := b.getOrCreateChatMemory(chatID)
|
|
||||||
b.addMessageToChatMemory(chatMemory, userMessage)
|
|
||||||
|
|
||||||
// Generate AI response about the sticker
|
// Generate AI response about the sticker
|
||||||
response, err := b.generateStickerResponse(ctx, userMessage)
|
response, err := b.generateStickerResponse(ctx, userMessage, contextMessages)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorLogger.Printf("Error generating sticker response: %v", err)
|
ErrorLogger.Printf("Error generating sticker response: %v", err)
|
||||||
// Provide a fallback dynamic response based on sticker type
|
// Provide a fallback dynamic response based on sticker type
|
||||||
@@ -155,34 +292,155 @@ func (b *Bot) handleStickerMessage(ctx context.Context, chatID, userID int64, me
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send the response through the centralized screen
|
// Send the response
|
||||||
if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil {
|
if err := b.sendResponse(ctx, chatID, response, businessConnectionID); err != nil {
|
||||||
ErrorLogger.Printf("Error sending response: %v", err)
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bot) generateStickerResponse(ctx context.Context, message Message) (string, error) {
|
func (b *Bot) generateStickerResponse(ctx context.Context, message Message, contextMessages []anthropic.Message) (string, error) {
|
||||||
// Example: Use the sticker type to generate a response
|
// contextMessages already contains the sticker turn (added by screenIncomingMessage as
|
||||||
|
// "Sent a sticker: <emoji>"), so the full conversation history is preserved.
|
||||||
if message.StickerFileID != "" {
|
if message.StickerFileID != "" {
|
||||||
// Prepare context with information about the sticker
|
messageTime := int(message.Timestamp.Unix())
|
||||||
contextMessages := []anthropic.Message{
|
response, err := b.getAnthropicResponse(ctx, contextMessages, false, false, true, message.Username, "", "", false, "", messageTime)
|
||||||
{
|
|
||||||
Role: anthropic.RoleUser,
|
|
||||||
Content: []anthropic.MessageContent{
|
|
||||||
anthropic.NewTextMessageContent("User sent a sticker."),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Since this is a sticker message, isEmojiOnly is false
|
|
||||||
response, err := b.getAnthropicResponse(ctx, contextMessages, false, false, false)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return "Hmm, that's interesting!", nil
|
return "Hmm, that's interesting!", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *Bot) clearChatHistory(ctx context.Context, chatID int64, currentUserID int64, targetUserID int64, targetChatID int64, businessConnectionID string, hardDelete bool) {
|
||||||
|
// If targetUserID is provided and different from currentUserID, check permissions
|
||||||
|
if targetUserID != 0 && targetUserID != currentUserID {
|
||||||
|
requiredScope := ScopeHistoryClearAny
|
||||||
|
if hardDelete {
|
||||||
|
requiredScope = ScopeHistoryClearHardAny
|
||||||
|
}
|
||||||
|
if !b.hasScope(currentUserID, requiredScope) {
|
||||||
|
InfoLogger.Printf("User %d attempted to clear history for user %d without permission", currentUserID, targetUserID)
|
||||||
|
if err := b.sendResponse(ctx, chatID, "Permission denied. Only admins and owners can clear other users' histories.", businessConnectionID); err != nil {
|
||||||
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the target user exists
|
||||||
|
var targetUser User
|
||||||
|
err := b.db.Where("telegram_id = ? AND bot_id = ?", targetUserID, b.botID).First(&targetUser).Error
|
||||||
|
if err != nil {
|
||||||
|
ErrorLogger.Printf("Error finding target user %d: %v", targetUserID, err)
|
||||||
|
if err := b.sendResponse(ctx, chatID, fmt.Sprintf("User with ID %d not found.", targetUserID), businessConnectionID); err != nil {
|
||||||
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If no targetUserID is provided, set it to currentUserID
|
||||||
|
targetUserID = currentUserID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete messages from the database
|
||||||
|
//
|
||||||
|
// Assumption: this bot is primarily used in private DMs, where each user's messages
|
||||||
|
// are stored with chat_id == their own user_id — not the caller's chat_id. Scoping
|
||||||
|
// a cross-user delete by the caller's chatID would therefore match 0 rows.
|
||||||
|
//
|
||||||
|
// When clearing another user's history the default (targetChatID == 0) deletes all
|
||||||
|
// of that user's messages across every chat for this bot — the natural meaning of
|
||||||
|
// "/clear <userID>" (wipe their entire history with the bot).
|
||||||
|
//
|
||||||
|
// When targetChatID != 0 the deletion is scoped to that specific chat, which is
|
||||||
|
// useful for group moderation ("/clear <userID> <chatID>").
|
||||||
|
var err error
|
||||||
|
if hardDelete {
|
||||||
|
// Permanently delete messages
|
||||||
|
if targetUserID == currentUserID {
|
||||||
|
// Own history — delete ALL messages (user + assistant) in the current chat.
|
||||||
|
err = b.db.Unscoped().Where("chat_id = ? AND bot_id = ?", chatID, b.botID).Delete(&Message{}).Error
|
||||||
|
InfoLogger.Printf("User %d permanently deleted their own chat history in chat %d", currentUserID, chatID)
|
||||||
|
} else {
|
||||||
|
if targetChatID != 0 {
|
||||||
|
// Chat-scoped: delete ALL messages (user + assistant) in the specified chat.
|
||||||
|
err = b.db.Unscoped().Where("chat_id = ? AND bot_id = ?", targetChatID, b.botID).Delete(&Message{}).Error
|
||||||
|
InfoLogger.Printf("Admin/owner %d permanently deleted chat history for user %d in chat %d", currentUserID, targetUserID, targetChatID)
|
||||||
|
} else {
|
||||||
|
// Bot-wide: delete all of the user's own messages across every chat, then delete
|
||||||
|
// assistant messages from their DM chat (where chat_id == user_id by Telegram convention).
|
||||||
|
err = b.db.Unscoped().Where("bot_id = ? AND user_id = ?", b.botID, targetUserID).Delete(&Message{}).Error
|
||||||
|
if err == nil {
|
||||||
|
err = b.db.Unscoped().Where("chat_id = ? AND bot_id = ? AND is_user = ?", targetUserID, b.botID, false).Delete(&Message{}).Error
|
||||||
|
}
|
||||||
|
InfoLogger.Printf("Admin/owner %d permanently deleted all chat history for user %d", currentUserID, targetUserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Soft delete messages
|
||||||
|
if targetUserID == currentUserID {
|
||||||
|
// Own history — delete ALL messages (user + assistant) in the current chat.
|
||||||
|
err = b.db.Where("chat_id = ? AND bot_id = ?", chatID, b.botID).Delete(&Message{}).Error
|
||||||
|
InfoLogger.Printf("User %d soft deleted their own chat history in chat %d", currentUserID, chatID)
|
||||||
|
} else {
|
||||||
|
if targetChatID != 0 {
|
||||||
|
// Chat-scoped: delete ALL messages (user + assistant) in the specified chat.
|
||||||
|
err = b.db.Where("chat_id = ? AND bot_id = ?", targetChatID, b.botID).Delete(&Message{}).Error
|
||||||
|
InfoLogger.Printf("Admin/owner %d soft deleted chat history for user %d in chat %d", currentUserID, targetUserID, targetChatID)
|
||||||
|
} else {
|
||||||
|
// Bot-wide: delete all of the user's own messages across every chat, then delete
|
||||||
|
// assistant messages from their DM chat (where chat_id == user_id by Telegram convention).
|
||||||
|
err = b.db.Where("bot_id = ? AND user_id = ?", b.botID, targetUserID).Delete(&Message{}).Error
|
||||||
|
if err == nil {
|
||||||
|
err = b.db.Where("chat_id = ? AND bot_id = ? AND is_user = ?", targetUserID, b.botID, false).Delete(&Message{}).Error
|
||||||
|
}
|
||||||
|
InfoLogger.Printf("Admin/owner %d soft deleted all chat history for user %d", currentUserID, targetUserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
ErrorLogger.Printf("Error clearing chat history: %v", err)
|
||||||
|
if err := b.sendResponse(ctx, chatID, "Sorry, I couldn't clear the chat history.", businessConnectionID); err != nil {
|
||||||
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Evict the relevant in-memory cache entry so the next access rebuilds from
|
||||||
|
// the now-clean DB. Applies to all cases: own history, cross-user
|
||||||
|
// scoped to a specific chat, and bot-wide cross-user clear.
|
||||||
|
b.chatMemoriesMu.Lock()
|
||||||
|
if targetUserID == currentUserID {
|
||||||
|
// Own history is always scoped to the current chat.
|
||||||
|
delete(b.chatMemories, chatID)
|
||||||
|
} else if targetChatID != 0 {
|
||||||
|
// Admin cleared a specific chat — evict that chat's cache.
|
||||||
|
delete(b.chatMemories, targetChatID)
|
||||||
|
} else {
|
||||||
|
// Bot-wide clear: primary use-case is DMs where chatID == userID.
|
||||||
|
delete(b.chatMemories, targetUserID)
|
||||||
|
}
|
||||||
|
b.chatMemoriesMu.Unlock()
|
||||||
|
|
||||||
|
// Send a confirmation message
|
||||||
|
var confirmationMessage string
|
||||||
|
if targetUserID == currentUserID {
|
||||||
|
confirmationMessage = "Your chat history has been cleared."
|
||||||
|
} else {
|
||||||
|
// Get the username of the target user if available
|
||||||
|
var targetUser User
|
||||||
|
err := b.db.Where("telegram_id = ? AND bot_id = ?", targetUserID, b.botID).First(&targetUser).Error
|
||||||
|
if err == nil && targetUser.Username != "" {
|
||||||
|
confirmationMessage = fmt.Sprintf("Chat history for user @%s (ID: %d) has been cleared.", targetUser.Username, targetUserID)
|
||||||
|
} else {
|
||||||
|
confirmationMessage = fmt.Sprintf("Chat history for user with ID %d has been cleared.", targetUserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := b.sendResponse(ctx, chatID, confirmationMessage, businessConnectionID); err != nil {
|
||||||
|
ErrorLogger.Printf("Error sending response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,891 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-telegram/bot"
|
||||||
|
"github.com/go-telegram/bot/models"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHandleUpdate_NewChat(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
db := setupTestDB(t)
|
||||||
|
mockClock := &MockClock{
|
||||||
|
currentTime: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
config := BotConfig{
|
||||||
|
ID: "test_bot",
|
||||||
|
OwnerTelegramID: 123, // owner's ID
|
||||||
|
TelegramToken: "test_token",
|
||||||
|
MemorySize: 10,
|
||||||
|
MessagePerHour: 5,
|
||||||
|
MessagePerDay: 10,
|
||||||
|
TempBanDuration: "1h",
|
||||||
|
SystemPrompts: make(map[string]string),
|
||||||
|
Active: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockTgClient := &MockTelegramClient{}
|
||||||
|
|
||||||
|
// Create bot model first
|
||||||
|
botModel := &BotModel{
|
||||||
|
Identifier: config.ID,
|
||||||
|
Name: config.ID,
|
||||||
|
}
|
||||||
|
err := db.Create(botModel).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Create bot config
|
||||||
|
configModel := &ConfigModel{
|
||||||
|
BotID: botModel.ID,
|
||||||
|
MemorySize: config.MemorySize,
|
||||||
|
MessagePerHour: config.MessagePerHour,
|
||||||
|
MessagePerDay: config.MessagePerDay,
|
||||||
|
TempBanDuration: config.TempBanDuration,
|
||||||
|
SystemPrompts: "{}",
|
||||||
|
TelegramToken: config.TelegramToken,
|
||||||
|
Active: config.Active,
|
||||||
|
}
|
||||||
|
err = db.Create(configModel).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Create bot instance
|
||||||
|
b, err := NewBot(db, config, mockClock, mockTgClient)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
userID int64
|
||||||
|
isOwner bool
|
||||||
|
wantResp string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Owner First Message",
|
||||||
|
userID: 123, // owner's ID
|
||||||
|
isOwner: true,
|
||||||
|
wantResp: "I'm sorry, I'm having trouble processing your request right now.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Regular User First Message",
|
||||||
|
userID: 456,
|
||||||
|
isOwner: false,
|
||||||
|
wantResp: "I'm sorry, I'm having trouble processing your request right now.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Setup mock response expectations for error case to test fallback messages
|
||||||
|
mockTgClient.SendMessageFunc = func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
|
||||||
|
assert.Equal(t, tc.userID, params.ChatID)
|
||||||
|
assert.Equal(t, tc.wantResp, params.Text)
|
||||||
|
return &models.Message{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create update with new message
|
||||||
|
update := &models.Update{
|
||||||
|
Message: &models.Message{
|
||||||
|
Chat: models.Chat{ID: tc.userID},
|
||||||
|
From: &models.User{
|
||||||
|
ID: tc.userID,
|
||||||
|
Username: "testuser",
|
||||||
|
},
|
||||||
|
Text: "Hello",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle the update
|
||||||
|
b.handleUpdate(context.Background(), nil, update)
|
||||||
|
|
||||||
|
// Verify message was stored
|
||||||
|
var storedMsg Message
|
||||||
|
err := db.Where("chat_id = ? AND user_id = ? AND text = ?", tc.userID, tc.userID, "Hello").First(&storedMsg).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify response was stored
|
||||||
|
var respMsg Message
|
||||||
|
err = db.Where("chat_id = ? AND is_user = ? AND text = ?", tc.userID, false, tc.wantResp).First(&respMsg).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearChatHistory(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
db := setupTestDB(t)
|
||||||
|
mockClock := &MockClock{
|
||||||
|
currentTime: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
config := BotConfig{
|
||||||
|
ID: "test_bot",
|
||||||
|
OwnerTelegramID: 123, // owner's ID
|
||||||
|
TelegramToken: "test_token",
|
||||||
|
MemorySize: 10,
|
||||||
|
MessagePerHour: 5,
|
||||||
|
MessagePerDay: 10,
|
||||||
|
TempBanDuration: "1h",
|
||||||
|
SystemPrompts: make(map[string]string),
|
||||||
|
Active: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockTgClient := &MockTelegramClient{}
|
||||||
|
|
||||||
|
// Create bot model first
|
||||||
|
botModel := &BotModel{
|
||||||
|
Identifier: config.ID,
|
||||||
|
Name: config.ID,
|
||||||
|
}
|
||||||
|
err := db.Create(botModel).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Create bot config
|
||||||
|
configModel := &ConfigModel{
|
||||||
|
BotID: botModel.ID,
|
||||||
|
MemorySize: config.MemorySize,
|
||||||
|
MessagePerHour: config.MessagePerHour,
|
||||||
|
MessagePerDay: config.MessagePerDay,
|
||||||
|
TempBanDuration: config.TempBanDuration,
|
||||||
|
SystemPrompts: "{}",
|
||||||
|
TelegramToken: config.TelegramToken,
|
||||||
|
Active: config.Active,
|
||||||
|
}
|
||||||
|
err = db.Create(configModel).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Create bot instance
|
||||||
|
b, err := NewBot(db, config, mockClock, mockTgClient)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Create test users
|
||||||
|
ownerID := int64(123)
|
||||||
|
adminID := int64(456)
|
||||||
|
regularUserID := int64(789)
|
||||||
|
nonExistentUserID := int64(999)
|
||||||
|
chatID := int64(1000)
|
||||||
|
|
||||||
|
// Create admin role
|
||||||
|
adminRole, err := b.getRoleByName("admin")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Create admin user
|
||||||
|
adminUser := User{
|
||||||
|
BotID: b.botID,
|
||||||
|
TelegramID: adminID,
|
||||||
|
Username: "admin",
|
||||||
|
RoleID: adminRole.ID,
|
||||||
|
Role: adminRole,
|
||||||
|
IsOwner: false,
|
||||||
|
}
|
||||||
|
err = db.Create(&adminUser).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Create regular user
|
||||||
|
regularRole, err := b.getRoleByName("user")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
regularUser := User{
|
||||||
|
BotID: b.botID,
|
||||||
|
TelegramID: regularUserID,
|
||||||
|
Username: "regular",
|
||||||
|
RoleID: regularRole.ID,
|
||||||
|
Role: regularRole,
|
||||||
|
IsOwner: false,
|
||||||
|
}
|
||||||
|
err = db.Create(®ularUser).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Create test messages for each user.
|
||||||
|
// Each user's messages are stored with chat_id == their own user_id, mirroring
|
||||||
|
// how Telegram private DMs work (chat_id == user_id in 1-on-1 bot conversations).
|
||||||
|
// Using a shared artificial chatID here would mask the cross-user delete bug.
|
||||||
|
for _, userID := range []int64{ownerID, adminID, regularUserID} {
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
message := Message{
|
||||||
|
BotID: b.botID,
|
||||||
|
ChatID: userID, // per-user chat, not a shared chatID
|
||||||
|
UserID: userID,
|
||||||
|
Username: "test",
|
||||||
|
UserRole: "user",
|
||||||
|
Text: "Test message",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
IsUser: true,
|
||||||
|
}
|
||||||
|
err = db.Create(&message).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test cases
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
currentUserID int64
|
||||||
|
targetUserID int64
|
||||||
|
hardDelete bool
|
||||||
|
expectedError bool
|
||||||
|
expectedCount int64
|
||||||
|
expectedMsg string
|
||||||
|
targetChatID int64
|
||||||
|
businessConnID string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Owner clears own history",
|
||||||
|
currentUserID: ownerID,
|
||||||
|
targetUserID: ownerID,
|
||||||
|
hardDelete: false,
|
||||||
|
expectedError: false,
|
||||||
|
expectedCount: 0,
|
||||||
|
expectedMsg: "Your chat history has been cleared.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Admin clears own history",
|
||||||
|
currentUserID: adminID,
|
||||||
|
targetUserID: adminID,
|
||||||
|
hardDelete: false,
|
||||||
|
expectedError: false,
|
||||||
|
expectedCount: 0,
|
||||||
|
expectedMsg: "Your chat history has been cleared.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Regular user clears own history",
|
||||||
|
currentUserID: regularUserID,
|
||||||
|
targetUserID: regularUserID,
|
||||||
|
hardDelete: false,
|
||||||
|
expectedError: false,
|
||||||
|
expectedCount: 0,
|
||||||
|
expectedMsg: "Your chat history has been cleared.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Owner clears admin's history",
|
||||||
|
currentUserID: ownerID,
|
||||||
|
targetUserID: adminID,
|
||||||
|
hardDelete: false,
|
||||||
|
expectedError: false,
|
||||||
|
expectedCount: 0,
|
||||||
|
expectedMsg: "Chat history for user @admin (ID: 456) has been cleared.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Admin clears regular user's history",
|
||||||
|
currentUserID: adminID,
|
||||||
|
targetUserID: regularUserID,
|
||||||
|
hardDelete: false,
|
||||||
|
expectedError: false,
|
||||||
|
expectedCount: 0,
|
||||||
|
expectedMsg: "Chat history for user @regular (ID: 789) has been cleared.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Regular user attempts to clear admin's history",
|
||||||
|
currentUserID: regularUserID,
|
||||||
|
targetUserID: adminID,
|
||||||
|
hardDelete: false,
|
||||||
|
expectedError: true,
|
||||||
|
expectedCount: 5, // Messages should remain
|
||||||
|
expectedMsg: "Permission denied. Only admins and owners can clear other users' histories.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Admin attempts to clear non-existent user's history",
|
||||||
|
currentUserID: adminID,
|
||||||
|
targetUserID: nonExistentUserID,
|
||||||
|
hardDelete: false,
|
||||||
|
expectedError: true,
|
||||||
|
expectedCount: 5, // Messages should remain for admin
|
||||||
|
expectedMsg: "User with ID 999 not found.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Owner hard deletes regular user's history",
|
||||||
|
currentUserID: ownerID,
|
||||||
|
targetUserID: regularUserID,
|
||||||
|
hardDelete: true,
|
||||||
|
expectedError: false,
|
||||||
|
expectedCount: 0,
|
||||||
|
expectedMsg: "Chat history for user @regular (ID: 789) has been cleared.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// targetChatID scopes the delete to a specific chat; messages in other chats survive.
|
||||||
|
// We seed messages with ChatID == userID (per-user DM), so targeting a different chatID
|
||||||
|
// should leave the user's messages untouched (expectedCount == 5).
|
||||||
|
name: "Admin clears regular user's history scoped to non-matching chat",
|
||||||
|
currentUserID: adminID,
|
||||||
|
targetUserID: regularUserID,
|
||||||
|
targetChatID: int64(9999), // a chat the user has no messages in
|
||||||
|
hardDelete: false,
|
||||||
|
expectedError: false,
|
||||||
|
expectedCount: 5, // messages in chat 789 are unaffected
|
||||||
|
expectedMsg: "Chat history for user @regular (ID: 789) has been cleared.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Reset messages for the test case
|
||||||
|
if tc.name != "Owner hard deletes regular user's history" {
|
||||||
|
// Delete all messages for the target user
|
||||||
|
err = db.Where("user_id = ?", tc.targetUserID).Delete(&Message{}).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Recreate messages for the target user
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
message := Message{
|
||||||
|
BotID: b.botID,
|
||||||
|
ChatID: chatID,
|
||||||
|
UserID: tc.targetUserID,
|
||||||
|
Username: "test",
|
||||||
|
UserRole: "user",
|
||||||
|
Text: "Test message",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
IsUser: true,
|
||||||
|
}
|
||||||
|
err = db.Create(&message).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup mock response expectations
|
||||||
|
var sentMessage string
|
||||||
|
mockTgClient.SendMessageFunc = func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
|
||||||
|
sentMessage = params.Text
|
||||||
|
return &models.Message{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the clearChatHistory method
|
||||||
|
b.clearChatHistory(context.Background(), chatID, tc.currentUserID, tc.targetUserID, tc.targetChatID, tc.businessConnID, tc.hardDelete)
|
||||||
|
|
||||||
|
// Verify the response message
|
||||||
|
assert.Equal(t, tc.expectedMsg, sentMessage)
|
||||||
|
|
||||||
|
// Count remaining messages for the target user
|
||||||
|
var count int64
|
||||||
|
if tc.hardDelete {
|
||||||
|
db.Unscoped().Model(&Message{}).Where("user_id = ? AND chat_id = ?", tc.targetUserID, chatID).Count(&count)
|
||||||
|
} else {
|
||||||
|
db.Model(&Message{}).Where("user_id = ? AND chat_id = ?", tc.targetUserID, chatID).Count(&count)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tc.expectedCount, count)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStatsCommand(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
db := setupTestDB(t)
|
||||||
|
mockClock := &MockClock{
|
||||||
|
currentTime: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
config := BotConfig{
|
||||||
|
ID: "test_bot",
|
||||||
|
OwnerTelegramID: 123, // owner's ID
|
||||||
|
TelegramToken: "test_token",
|
||||||
|
MemorySize: 10,
|
||||||
|
MessagePerHour: 5,
|
||||||
|
MessagePerDay: 10,
|
||||||
|
TempBanDuration: "1h",
|
||||||
|
SystemPrompts: make(map[string]string),
|
||||||
|
Active: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockTgClient := &MockTelegramClient{}
|
||||||
|
|
||||||
|
// Create bot model first
|
||||||
|
botModel := &BotModel{
|
||||||
|
Identifier: config.ID,
|
||||||
|
Name: config.ID,
|
||||||
|
}
|
||||||
|
err := db.Create(botModel).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Create bot config
|
||||||
|
configModel := &ConfigModel{
|
||||||
|
BotID: botModel.ID,
|
||||||
|
MemorySize: config.MemorySize,
|
||||||
|
MessagePerHour: config.MessagePerHour,
|
||||||
|
MessagePerDay: config.MessagePerDay,
|
||||||
|
TempBanDuration: config.TempBanDuration,
|
||||||
|
SystemPrompts: "{}",
|
||||||
|
TelegramToken: config.TelegramToken,
|
||||||
|
Active: config.Active,
|
||||||
|
}
|
||||||
|
err = db.Create(configModel).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Create bot instance
|
||||||
|
b, err := NewBot(db, config, mockClock, mockTgClient)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Create test users
|
||||||
|
ownerID := int64(123)
|
||||||
|
adminID := int64(456)
|
||||||
|
regularUserID := int64(789)
|
||||||
|
chatID := int64(1000)
|
||||||
|
|
||||||
|
// Create admin role
|
||||||
|
adminRole, err := b.getRoleByName("admin")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Create admin user
|
||||||
|
adminUser := User{
|
||||||
|
BotID: b.botID,
|
||||||
|
TelegramID: adminID,
|
||||||
|
Username: "admin",
|
||||||
|
RoleID: adminRole.ID,
|
||||||
|
Role: adminRole,
|
||||||
|
IsOwner: false,
|
||||||
|
}
|
||||||
|
err = db.Create(&adminUser).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Create regular user
|
||||||
|
regularRole, err := b.getRoleByName("user")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
regularUser := User{
|
||||||
|
BotID: b.botID,
|
||||||
|
TelegramID: regularUserID,
|
||||||
|
Username: "regular",
|
||||||
|
RoleID: regularRole.ID,
|
||||||
|
Role: regularRole,
|
||||||
|
IsOwner: false,
|
||||||
|
}
|
||||||
|
err = db.Create(®ularUser).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Create test messages for each user
|
||||||
|
for _, userID := range []int64{ownerID, adminID, regularUserID} {
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
// User message
|
||||||
|
userMessage := Message{
|
||||||
|
BotID: b.botID,
|
||||||
|
ChatID: chatID,
|
||||||
|
UserID: userID,
|
||||||
|
Username: "test",
|
||||||
|
UserRole: "user",
|
||||||
|
Text: "Test message",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
IsUser: true,
|
||||||
|
}
|
||||||
|
err = db.Create(&userMessage).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Bot response
|
||||||
|
botMessage := Message{
|
||||||
|
BotID: b.botID,
|
||||||
|
ChatID: chatID,
|
||||||
|
UserID: 0,
|
||||||
|
Username: "AI Assistant",
|
||||||
|
UserRole: "assistant",
|
||||||
|
Text: "Test response",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
IsUser: false,
|
||||||
|
}
|
||||||
|
err = db.Create(&botMessage).Error
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test cases
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
command string
|
||||||
|
currentUserID int64
|
||||||
|
expectedError bool
|
||||||
|
expectedMsg string
|
||||||
|
businessConnID string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Global stats",
|
||||||
|
command: "/stats",
|
||||||
|
currentUserID: regularUserID,
|
||||||
|
expectedError: false,
|
||||||
|
expectedMsg: "📊 Bot Statistics:",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "User requests own stats",
|
||||||
|
command: "/stats user",
|
||||||
|
currentUserID: regularUserID,
|
||||||
|
expectedError: false,
|
||||||
|
expectedMsg: "👤 User Statistics for @regular (ID: 789):",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Admin requests another user's stats",
|
||||||
|
command: "/stats user 789",
|
||||||
|
currentUserID: adminID,
|
||||||
|
expectedError: false,
|
||||||
|
expectedMsg: "👤 User Statistics for @regular (ID: 789):",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Owner requests another user's stats",
|
||||||
|
command: "/stats user 456",
|
||||||
|
currentUserID: ownerID,
|
||||||
|
expectedError: false,
|
||||||
|
expectedMsg: "👤 User Statistics for @admin (ID: 456):",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Regular user attempts to request another user's stats",
|
||||||
|
command: "/stats user 456",
|
||||||
|
currentUserID: regularUserID,
|
||||||
|
expectedError: true,
|
||||||
|
expectedMsg: "Permission denied. Only admins and owners can view other users' statistics.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "User provides invalid user ID format",
|
||||||
|
command: "/stats user abc",
|
||||||
|
currentUserID: adminID,
|
||||||
|
expectedError: true,
|
||||||
|
expectedMsg: "Invalid user ID format. Usage: /stats user [user_id]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "User provides invalid command format",
|
||||||
|
command: "/stats invalid",
|
||||||
|
currentUserID: adminID,
|
||||||
|
expectedError: true,
|
||||||
|
expectedMsg: "Invalid command format. Usage: /stats or /stats user [user_id]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "User requests non-existent user's stats",
|
||||||
|
command: "/stats user 999",
|
||||||
|
currentUserID: adminID,
|
||||||
|
expectedError: true,
|
||||||
|
expectedMsg: "Sorry, I couldn't retrieve statistics for user ID 999.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Setup mock response expectations
|
||||||
|
var sentMessage string
|
||||||
|
mockTgClient.SendMessageFunc = func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
|
||||||
|
sentMessage = params.Text
|
||||||
|
return &models.Message{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create update with command
|
||||||
|
update := &models.Update{
|
||||||
|
Message: &models.Message{
|
||||||
|
Chat: models.Chat{ID: chatID},
|
||||||
|
From: &models.User{
|
||||||
|
ID: tc.currentUserID,
|
||||||
|
Username: getUsernameByID(tc.currentUserID),
|
||||||
|
},
|
||||||
|
Text: tc.command,
|
||||||
|
Entities: []models.MessageEntity{
|
||||||
|
{
|
||||||
|
Type: "bot_command",
|
||||||
|
Offset: 0,
|
||||||
|
Length: 6, // Length of "/stats"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle the update
|
||||||
|
b.handleUpdate(context.Background(), nil, update)
|
||||||
|
|
||||||
|
// Verify the response message contains the expected text
|
||||||
|
assert.Contains(t, sentMessage, tc.expectedMsg)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to get username by ID for test
|
||||||
|
func getUsernameByID(id int64) string {
|
||||||
|
switch id {
|
||||||
|
case 123:
|
||||||
|
return "owner"
|
||||||
|
case 456:
|
||||||
|
return "admin"
|
||||||
|
case 789:
|
||||||
|
return "regular"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTestDB(t *testing.T) *gorm.DB {
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to open test database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AutoMigrate the models
|
||||||
|
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}, &Scope{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to migrate database schema: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create default roles and scopes
|
||||||
|
err = createDefaultRoles(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create default roles: %v", err)
|
||||||
|
}
|
||||||
|
if err := createDefaultScopes(db); err != nil {
|
||||||
|
t.Fatalf("Failed to create default scopes: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupBotForTest creates a minimal Bot instance backed by an in-memory DB.
|
||||||
|
// It follows the same pattern as the existing handler tests to avoid duplication.
|
||||||
|
func setupBotForTest(t *testing.T, ownerID int64) (*Bot, *MockTelegramClient) {
|
||||||
|
t.Helper()
|
||||||
|
db := setupTestDB(t)
|
||||||
|
mockClock := &MockClock{currentTime: time.Now()}
|
||||||
|
config := BotConfig{
|
||||||
|
ID: "test_bot",
|
||||||
|
OwnerTelegramID: ownerID,
|
||||||
|
TelegramToken: "test_token",
|
||||||
|
MemorySize: 10,
|
||||||
|
MessagePerHour: 5,
|
||||||
|
MessagePerDay: 10,
|
||||||
|
TempBanDuration: "1h",
|
||||||
|
Model: "claude-3-5-haiku-latest",
|
||||||
|
SystemPrompts: make(map[string]string),
|
||||||
|
Active: true,
|
||||||
|
}
|
||||||
|
mockTgClient := &MockTelegramClient{}
|
||||||
|
botModel := &BotModel{Identifier: config.ID, Name: config.ID}
|
||||||
|
assert.NoError(t, db.Create(botModel).Error)
|
||||||
|
assert.NoError(t, db.Create(&ConfigModel{
|
||||||
|
BotID: botModel.ID,
|
||||||
|
MemorySize: config.MemorySize,
|
||||||
|
MessagePerHour: config.MessagePerHour,
|
||||||
|
MessagePerDay: config.MessagePerDay,
|
||||||
|
TempBanDuration: config.TempBanDuration,
|
||||||
|
SystemPrompts: "{}",
|
||||||
|
TelegramToken: config.TelegramToken,
|
||||||
|
Active: config.Active,
|
||||||
|
}).Error)
|
||||||
|
b, err := NewBot(db, config, mockClock, mockTgClient)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
return b, mockTgClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAnthropicErrorResponse verifies that model-deprecation errors surface actionable
|
||||||
|
// details only to admin/owner, and that regular users and non-model errors always get
|
||||||
|
// the generic fallback.
|
||||||
|
func TestAnthropicErrorResponse(t *testing.T) { //NOSONAR go:S100 -- underscore separation is idiomatic in Go test names
|
||||||
|
b, _ := setupBotForTest(t, 123)
|
||||||
|
|
||||||
|
// Create admin user
|
||||||
|
adminRole, err := b.getRoleByName("admin")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NoError(t, b.db.Create(&User{
|
||||||
|
BotID: b.botID, TelegramID: 456, Username: "admin",
|
||||||
|
RoleID: adminRole.ID, Role: adminRole,
|
||||||
|
}).Error)
|
||||||
|
|
||||||
|
// Create regular user
|
||||||
|
userRole, err := b.getRoleByName("user")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NoError(t, b.db.Create(&User{
|
||||||
|
BotID: b.botID, TelegramID: 789, Username: "regular",
|
||||||
|
RoleID: userRole.ID, Role: userRole,
|
||||||
|
}).Error)
|
||||||
|
|
||||||
|
modelErr := fmt.Errorf("%w: claude-3-5-haiku-latest", ErrModelNotFound)
|
||||||
|
otherErr := errors.New("network error")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
userID int64
|
||||||
|
wantSubstr string
|
||||||
|
wantMissing string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "owner receives actionable model-not-found message",
|
||||||
|
err: modelErr,
|
||||||
|
userID: 123,
|
||||||
|
wantSubstr: "/set_model",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "admin receives actionable model-not-found message",
|
||||||
|
err: modelErr,
|
||||||
|
userID: 456,
|
||||||
|
wantSubstr: "/set_model",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "regular user receives generic message for model-not-found",
|
||||||
|
err: modelErr,
|
||||||
|
userID: 789,
|
||||||
|
wantSubstr: "I'm sorry",
|
||||||
|
wantMissing: "/set_model",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "owner receives generic message for non-model error",
|
||||||
|
err: otherErr,
|
||||||
|
userID: 123,
|
||||||
|
wantSubstr: "I'm sorry",
|
||||||
|
wantMissing: "/set_model",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
resp := b.anthropicErrorResponse(tc.err, tc.userID)
|
||||||
|
assert.Contains(t, resp, tc.wantSubstr)
|
||||||
|
if tc.wantMissing != "" {
|
||||||
|
assert.NotContains(t, resp, tc.wantMissing)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSetModelCommand verifies that /set_model enforces permissions, validates input,
|
||||||
|
// updates the model in memory, and persists the change to the config file on disk.
|
||||||
|
func TestSetModelCommand(t *testing.T) { //NOSONAR go:S100 -- underscore separation is idiomatic in Go test names
|
||||||
|
b, mockTgClient := setupBotForTest(t, 123)
|
||||||
|
|
||||||
|
// Point the config at a temporary file so PersistModel can write to disk.
|
||||||
|
tempDir, err := os.MkdirTemp("", "set_model_cmd_test")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer func() { _ = os.RemoveAll(tempDir) }()
|
||||||
|
|
||||||
|
configPath := filepath.Join(tempDir, "config.json")
|
||||||
|
initialJSON := `{"id":"test_bot","telegram_token":"test_token","model":"claude-3-5-haiku-latest","messages_per_hour":5,"messages_per_day":10}`
|
||||||
|
assert.NoError(t, os.WriteFile(configPath, []byte(initialJSON), 0600))
|
||||||
|
b.config.ConfigFilePath = configPath
|
||||||
|
|
||||||
|
// Create admin and regular users
|
||||||
|
adminRole, err := b.getRoleByName("admin")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NoError(t, b.db.Create(&User{
|
||||||
|
BotID: b.botID, TelegramID: 456, Username: "admin",
|
||||||
|
RoleID: adminRole.ID, Role: adminRole,
|
||||||
|
}).Error)
|
||||||
|
userRole, err := b.getRoleByName("user")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NoError(t, b.db.Create(&User{
|
||||||
|
BotID: b.botID, TelegramID: 789, Username: "regular",
|
||||||
|
RoleID: userRole.ID, Role: userRole,
|
||||||
|
}).Error)
|
||||||
|
|
||||||
|
chatID := int64(1000)
|
||||||
|
|
||||||
|
// Seed chat 1000 with a prior message so isNewChatFlag is false for all subtests.
|
||||||
|
// Commands are only processed in the non-new-chat branch of handleUpdate.
|
||||||
|
assert.NoError(t, b.db.Create(&Message{
|
||||||
|
BotID: b.botID, ChatID: chatID, UserID: 789, Username: "regular",
|
||||||
|
UserRole: "user", Text: "hello", IsUser: true,
|
||||||
|
}).Error)
|
||||||
|
|
||||||
|
makeUpdate := func(userID int64, text string, cmdLen int) *models.Update {
|
||||||
|
return &models.Update{
|
||||||
|
Message: &models.Message{
|
||||||
|
Chat: models.Chat{ID: chatID},
|
||||||
|
From: &models.User{ID: userID, Username: getUsernameByID(userID)},
|
||||||
|
Text: text,
|
||||||
|
Entities: []models.MessageEntity{
|
||||||
|
{Type: "bot_command", Offset: 0, Length: cmdLen},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
userID int64
|
||||||
|
text string
|
||||||
|
wantSubstr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "regular user is denied",
|
||||||
|
userID: 789,
|
||||||
|
text: "/set_model claude-sonnet-4-6",
|
||||||
|
wantSubstr: "Permission denied",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "admin missing argument shows usage",
|
||||||
|
userID: 456,
|
||||||
|
text: "/set_model",
|
||||||
|
wantSubstr: "Usage:",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "owner missing argument shows usage",
|
||||||
|
userID: 123,
|
||||||
|
text: "/set_model",
|
||||||
|
wantSubstr: "Usage:",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "admin sets model successfully",
|
||||||
|
userID: 456,
|
||||||
|
text: "/set_model claude-sonnet-4-6",
|
||||||
|
wantSubstr: "✅",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var sentMessage string
|
||||||
|
mockTgClient.SendMessageFunc = func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
|
||||||
|
sentMessage = params.Text
|
||||||
|
return &models.Message{}, nil
|
||||||
|
}
|
||||||
|
b.handleUpdate(context.Background(), nil, makeUpdate(tc.userID, tc.text, 10))
|
||||||
|
assert.Contains(t, sentMessage, tc.wantSubstr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the successful update took effect in memory and on disk.
|
||||||
|
t.Run("model change persisted in memory and on disk", func(t *testing.T) {
|
||||||
|
assert.Equal(t, "claude-sonnet-4-6", string(b.config.Model))
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"claude-sonnet-4-6"`)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHasScope verifies that scope checks honour role assignments and the owner bypass.
|
||||||
|
func TestHasScope(t *testing.T) { //NOSONAR go:S100 -- underscore separation is idiomatic in Go test names
|
||||||
|
const ownerID int64 = 100
|
||||||
|
b, _ := setupBotForTest(t, ownerID)
|
||||||
|
|
||||||
|
// Admin user
|
||||||
|
adminRole, err := b.getRoleByName("admin")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NoError(t, b.db.Create(&User{
|
||||||
|
BotID: b.botID, TelegramID: 200, Username: "admin_user",
|
||||||
|
RoleID: adminRole.ID, Role: adminRole,
|
||||||
|
}).Error)
|
||||||
|
|
||||||
|
// Regular user
|
||||||
|
userRole, err := b.getRoleByName("user")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NoError(t, b.db.Create(&User{
|
||||||
|
BotID: b.botID, TelegramID: 300, Username: "regular_user",
|
||||||
|
RoleID: userRole.ID, Role: userRole,
|
||||||
|
}).Error)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
userID int64
|
||||||
|
scope string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"owner bypass: model:set", ownerID, ScopeModelSet, true},
|
||||||
|
{"owner bypass: stats:view:any", ownerID, ScopeStatsViewAny, true},
|
||||||
|
{"admin: model:set", 200, ScopeModelSet, true},
|
||||||
|
{"admin: stats:view:any", 200, ScopeStatsViewAny, true},
|
||||||
|
{"admin: history:clear:any", 200, ScopeHistoryClearAny, true},
|
||||||
|
{"user: model:set denied", 300, ScopeModelSet, false},
|
||||||
|
{"user: stats:view:any denied", 300, ScopeStatsViewAny, false},
|
||||||
|
{"user: history:clear:any denied", 300, ScopeHistoryClearAny, false},
|
||||||
|
{"user: stats:view:own allowed", 300, ScopeStatsViewOwn, true},
|
||||||
|
{"user: history:clear:own allowed", 300, ScopeHistoryClearOwn, true},
|
||||||
|
{"unknown telegram_id", 999, ScopeModelSet, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tc.want, b.hasScope(tc.userID, tc.scope))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -47,19 +47,13 @@ func main() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize TelegramClient with the bot's handleUpdate method
|
// Start the bot in a separate goroutine
|
||||||
tgClient, err := initTelegramBot(cfg.TelegramToken, bot.handleUpdate)
|
go bot.Start(ctx)
|
||||||
if err != nil {
|
|
||||||
ErrorLogger.Printf("Error initializing Telegram client for bot %s: %v", cfg.ID, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assign the TelegramClient to the bot
|
// Keep the bot running until the context is cancelled
|
||||||
bot.tgBot = tgClient
|
<-ctx.Done()
|
||||||
|
|
||||||
// Start the bot
|
InfoLogger.Printf("Bot %s stopped", cfg.ID)
|
||||||
InfoLogger.Printf("Starting bot %s...", cfg.ID)
|
|
||||||
bot.Start(ctx)
|
|
||||||
}(config)
|
}(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -29,16 +29,19 @@ type ConfigModel struct {
|
|||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
BotID uint
|
BotID uint `gorm:"index"`
|
||||||
ChatID int64
|
ChatID int64 `gorm:"index"`
|
||||||
UserID int64
|
UserID int64 `gorm:"index"`
|
||||||
Username string
|
Username string `gorm:"index"`
|
||||||
UserRole string
|
UserRole string // Store the role as a string
|
||||||
Text string
|
Text string `gorm:"type:text"`
|
||||||
StickerFileID string `json:"sticker_file_id,omitempty"` // New field to store Sticker File ID
|
Timestamp time.Time `gorm:"index"`
|
||||||
StickerPNGFile string `json:"sticker_png_file,omitempty"` // Optionally store PNG file ID if needed
|
|
||||||
Timestamp time.Time
|
|
||||||
IsUser bool
|
IsUser bool
|
||||||
|
StickerFileID string
|
||||||
|
StickerPNGFile string
|
||||||
|
StickerEmoji string // Store the emoji associated with the sticker
|
||||||
|
DeletedAt gorm.DeletedAt `gorm:"index"` // Add soft delete field
|
||||||
|
AnsweredOn *time.Time `gorm:"index"` // Tracks when a user message was answered (NULL for assistant messages and unanswered user messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatMemory struct {
|
type ChatMemory struct {
|
||||||
@@ -47,22 +50,41 @@ type ChatMemory struct {
|
|||||||
BusinessConnectionID string // New field to store the business connection ID
|
BusinessConnectionID string // New field to store the business connection ID
|
||||||
}
|
}
|
||||||
|
|
||||||
type Role struct {
|
// Scope name constants — used in DB seeding, hasScope checks, and tests.
|
||||||
|
const (
|
||||||
|
ScopeStatsViewOwn = "stats:view:own"
|
||||||
|
ScopeStatsViewAny = "stats:view:any"
|
||||||
|
ScopeHistoryClearOwn = "history:clear:own"
|
||||||
|
ScopeHistoryClearAny = "history:clear:any"
|
||||||
|
ScopeHistoryClearHardOwn = "history:clear_hard:own"
|
||||||
|
ScopeHistoryClearHardAny = "history:clear_hard:any"
|
||||||
|
ScopeModelSet = "model:set"
|
||||||
|
ScopeUserPromote = "user:promote"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Scope struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Name string `gorm:"uniqueIndex"`
|
Name string `gorm:"uniqueIndex"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Role struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string `gorm:"uniqueIndex"`
|
||||||
|
Scopes []Scope `gorm:"many2many:role_scopes;"`
|
||||||
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
BotID uint `gorm:"index"` // Foreign key to BotModel
|
BotID uint `gorm:"uniqueIndex:idx_user_bot;index"` // Foreign key to BotModel
|
||||||
TelegramID int64 `gorm:"uniqueIndex;not null"` // Unique per user
|
TelegramID int64 `gorm:"uniqueIndex:idx_user_bot;not null"` // Unique per (telegram_id, bot_id) pair
|
||||||
Username string
|
Username string
|
||||||
RoleID uint
|
RoleID uint
|
||||||
Role Role `gorm:"foreignKey:RoleID"`
|
Role Role `gorm:"foreignKey:RoleID"`
|
||||||
IsOwner bool `gorm:"default:false"` // Indicates if the user is the owner
|
IsOwner bool `gorm:"default:false"` // Indicates if the user is the owner
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compound unique index to ensure only one owner per bot
|
// idx_user_bot is a composite unique index on (bot_id, telegram_id),
|
||||||
|
// allowing the same Telegram user to be registered independently on each bot.
|
||||||
func (User) TableName() string {
|
func (User) TableName() string {
|
||||||
return "users"
|
return "users"
|
||||||
}
|
}
|
||||||
|
|||||||
Executable → Regular
+8
-2
@@ -50,8 +50,14 @@ func (b *Bot) checkRateLimits(userID int64) bool {
|
|||||||
limiter.lastDailyReset = now
|
limiter.lastDailyReset = now
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the message exceeds rate limits
|
// Check if the message exceeds rate limits.
|
||||||
if !limiter.hourlyLimiter.Allow() || !limiter.dailyLimiter.Allow() {
|
// Reserve from both limiters first, then cancel both if either is over budget.
|
||||||
|
// This prevents consuming a token from one limiter when the other rejects.
|
||||||
|
dailyRes := limiter.dailyLimiter.ReserveN(now, 1)
|
||||||
|
hourlyRes := limiter.hourlyLimiter.ReserveN(now, 1)
|
||||||
|
if dailyRes.DelayFrom(now) > 0 || hourlyRes.DelayFrom(now) > 0 {
|
||||||
|
dailyRes.CancelAt(now)
|
||||||
|
hourlyRes.CancelAt(now)
|
||||||
banDuration, err := time.ParseDuration(b.config.TempBanDuration)
|
banDuration, err := time.ParseDuration(b.config.TempBanDuration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If parsing fails, default to a 24-hour ban
|
// If parsing fails, default to a 24-hour ban
|
||||||
|
|||||||
Executable → Regular
Executable → Regular
+1
-1
@@ -11,6 +11,6 @@ import (
|
|||||||
// TelegramClient defines the methods required from the Telegram bot.
|
// TelegramClient defines the methods required from the Telegram bot.
|
||||||
type TelegramClient interface {
|
type TelegramClient interface {
|
||||||
SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error)
|
SendMessage(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error)
|
||||||
|
SetMyCommands(ctx context.Context, params *bot.SetMyCommandsParams) (bool, error)
|
||||||
Start(ctx context.Context)
|
Start(ctx context.Context)
|
||||||
// Add other methods if needed.
|
|
||||||
}
|
}
|
||||||
|
|||||||
Executable → Regular
+20
-8
@@ -6,13 +6,15 @@ import (
|
|||||||
|
|
||||||
"github.com/go-telegram/bot"
|
"github.com/go-telegram/bot"
|
||||||
"github.com/go-telegram/bot/models"
|
"github.com/go-telegram/bot/models"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockTelegramClient is a mock implementation of TelegramClient for testing.
|
// MockTelegramClient is a mock implementation of TelegramClient for testing.
|
||||||
type MockTelegramClient struct {
|
type MockTelegramClient struct {
|
||||||
// You can add fields to keep track of calls if needed.
|
mock.Mock
|
||||||
SendMessageFunc func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error)
|
SendMessageFunc func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error)
|
||||||
StartFunc func(ctx context.Context) // Optional: track Start calls
|
SetMyCommandsFunc func(ctx context.Context, params *bot.SetMyCommandsParams) (bool, error)
|
||||||
|
StartFunc func(ctx context.Context)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendMessage mocks sending a message.
|
// SendMessage mocks sending a message.
|
||||||
@@ -20,16 +22,26 @@ func (m *MockTelegramClient) SendMessage(ctx context.Context, params *bot.SendMe
|
|||||||
if m.SendMessageFunc != nil {
|
if m.SendMessageFunc != nil {
|
||||||
return m.SendMessageFunc(ctx, params)
|
return m.SendMessageFunc(ctx, params)
|
||||||
}
|
}
|
||||||
// Default behavior: return an empty message without error.
|
args := m.Called(ctx, params)
|
||||||
return &models.Message{}, nil
|
if msg, ok := args.Get(0).(*models.Message); ok {
|
||||||
|
return msg, args.Error(1)
|
||||||
|
}
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMyCommands mocks registering bot commands.
|
||||||
|
func (m *MockTelegramClient) SetMyCommands(ctx context.Context, params *bot.SetMyCommandsParams) (bool, error) {
|
||||||
|
if m.SetMyCommandsFunc != nil {
|
||||||
|
return m.SetMyCommandsFunc(ctx, params)
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start mocks starting the Telegram client.
|
// Start mocks starting the Telegram client.
|
||||||
func (m *MockTelegramClient) Start(ctx context.Context) {
|
func (m *MockTelegramClient) Start(ctx context.Context) {
|
||||||
if m.StartFunc != nil {
|
if m.StartFunc != nil {
|
||||||
m.StartFunc(ctx)
|
m.StartFunc(ctx)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
// Default behavior: do nothing.
|
m.Called(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add other mocked methods if your Bot uses more TelegramClient methods.
|
|
||||||
|
|||||||
Executable → Regular
+130
-14
@@ -12,26 +12,38 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
errOpenDB = "Failed to open in-memory database: %v"
|
||||||
|
errMigrateSchema = "Failed to migrate database schema: %v"
|
||||||
|
errCreateRoles = "Failed to create default roles: %v"
|
||||||
|
errCreateScopes = "Failed to create default scopes: %v"
|
||||||
|
errCreateBot = "Failed to create bot: %v"
|
||||||
|
memoryDSN = ":memory:"
|
||||||
|
)
|
||||||
|
|
||||||
func TestOwnerAssignment(t *testing.T) {
|
func TestOwnerAssignment(t *testing.T) {
|
||||||
// Initialize loggers
|
// Initialize loggers
|
||||||
initLoggers()
|
initLoggers()
|
||||||
|
|
||||||
// Initialize in-memory database for testing
|
// Initialize in-memory database for testing
|
||||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
db, err := gorm.Open(sqlite.Open(memoryDSN), &gorm.Config{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to open in-memory database: %v", err)
|
t.Fatalf(errOpenDB, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migrate the schema
|
// Migrate the schema
|
||||||
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{})
|
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}, &Scope{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to migrate database schema: %v", err)
|
t.Fatalf(errMigrateSchema, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create default roles
|
// Create default roles and scopes
|
||||||
err = createDefaultRoles(db)
|
err = createDefaultRoles(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create default roles: %v", err)
|
t.Fatalf(errCreateRoles, err)
|
||||||
|
}
|
||||||
|
if err := createDefaultScopes(db); err != nil {
|
||||||
|
t.Fatalf(errCreateScopes, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a bot configuration
|
// Create a bot configuration
|
||||||
@@ -67,7 +79,7 @@ func TestOwnerAssignment(t *testing.T) {
|
|||||||
// Create the bot with the mock Telegram client
|
// Create the bot with the mock Telegram client
|
||||||
bot, err := NewBot(db, config, mockClock, mockTGClient)
|
bot, err := NewBot(db, config, mockClock, mockTGClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create bot: %v", err)
|
t.Fatalf(errCreateBot, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that the owner exists
|
// Verify that the owner exists
|
||||||
@@ -119,21 +131,24 @@ func TestPromoteUserToAdmin(t *testing.T) {
|
|||||||
initLoggers()
|
initLoggers()
|
||||||
|
|
||||||
// Initialize in-memory database for testing
|
// Initialize in-memory database for testing
|
||||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
db, err := gorm.Open(sqlite.Open(memoryDSN), &gorm.Config{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to open in-memory database: %v", err)
|
t.Fatalf(errOpenDB, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migrate the schema
|
// Migrate the schema
|
||||||
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{})
|
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}, &Scope{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to migrate database schema: %v", err)
|
t.Fatalf(errMigrateSchema, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create default roles
|
// Create default roles and scopes
|
||||||
err = createDefaultRoles(db)
|
err = createDefaultRoles(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create default roles: %v", err)
|
t.Fatalf(errCreateRoles, err)
|
||||||
|
}
|
||||||
|
if err := createDefaultScopes(db); err != nil {
|
||||||
|
t.Fatalf(errCreateScopes, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config := BotConfig{
|
config := BotConfig{
|
||||||
@@ -153,7 +168,7 @@ func TestPromoteUserToAdmin(t *testing.T) {
|
|||||||
|
|
||||||
bot, err := NewBot(db, config, mockClock, mockTGClient)
|
bot, err := NewBot(db, config, mockClock, mockTGClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create bot: %v", err)
|
t.Fatalf(errCreateBot, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create an owner
|
// Create an owner
|
||||||
@@ -184,6 +199,107 @@ func TestPromoteUserToAdmin(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestGetOrCreateUser tests the getOrCreateUser method of the Bot.
|
||||||
|
// It verifies that a new user is created when one does not exist,
|
||||||
|
// and an existing user is returned when one does exist.
|
||||||
|
func TestGetOrCreateUser(t *testing.T) {
|
||||||
|
// Initialize loggers
|
||||||
|
initLoggers()
|
||||||
|
|
||||||
|
// Initialize in-memory database for testing
|
||||||
|
db, err := gorm.Open(sqlite.Open(memoryDSN), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf(errOpenDB, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Migrate the schema
|
||||||
|
err = db.AutoMigrate(&BotModel{}, &ConfigModel{}, &Message{}, &User{}, &Role{}, &Scope{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf(errMigrateSchema, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create default roles and scopes
|
||||||
|
err = createDefaultRoles(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf(errCreateRoles, err)
|
||||||
|
}
|
||||||
|
if err := createDefaultScopes(db); err != nil {
|
||||||
|
t.Fatalf(errCreateScopes, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a mock clock starting at a fixed time
|
||||||
|
mockClock := &MockClock{
|
||||||
|
currentTime: time.Date(2023, 10, 1, 0, 0, 0, 0, time.UTC),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a mock configuration
|
||||||
|
config := BotConfig{
|
||||||
|
ID: "bot1",
|
||||||
|
MemorySize: 10,
|
||||||
|
MessagePerHour: 5,
|
||||||
|
MessagePerDay: 10,
|
||||||
|
TempBanDuration: "1m",
|
||||||
|
SystemPrompts: make(map[string]string),
|
||||||
|
TelegramToken: "YOUR_TELEGRAM_BOT_TOKEN",
|
||||||
|
OwnerTelegramID: 123456789,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize MockTelegramClient
|
||||||
|
mockTGClient := &MockTelegramClient{
|
||||||
|
SendMessageFunc: func(ctx context.Context, params *bot.SendMessageParams) (*models.Message, error) {
|
||||||
|
chatID, ok := params.ChatID.(int64)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("ChatID is not of type int64")
|
||||||
|
}
|
||||||
|
// Simulate successful message sending
|
||||||
|
return &models.Message{ID: 1, Chat: models.Chat{ID: chatID}}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the bot with the mock Telegram client
|
||||||
|
bot, err := NewBot(db, config, mockClock, mockTGClient)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf(errCreateBot, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the owner exists
|
||||||
|
var owner User
|
||||||
|
err = db.Where("telegram_id = ? AND bot_id = ? AND is_owner = ?", config.OwnerTelegramID, bot.botID, true).First(&owner).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Owner was not created: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to create another owner for the same bot
|
||||||
|
_, err = bot.getOrCreateUser(222222222, "AnotherOwner", true)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Expected error when creating a second owner, but got none")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new user
|
||||||
|
newUser, err := bot.getOrCreateUser(987654321, "TestUser", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create a new user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the new user was created
|
||||||
|
var userInDB User
|
||||||
|
err = db.Where("telegram_id = ?", newUser.TelegramID).First(&userInDB).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New user was not created in the database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the existing user
|
||||||
|
existingUser, err := bot.getOrCreateUser(987654321, "TestUser", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get existing user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the existing user is the same as the new user
|
||||||
|
if existingUser.ID != userInDB.ID {
|
||||||
|
t.Fatalf("Expected to get the existing user, but got a different user")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// To ensure thread safety and avoid race conditions during testing,
|
// To ensure thread safety and avoid race conditions during testing,
|
||||||
// you can run the tests with the `-race` flag:
|
// you can run the tests with the `-race` flag:
|
||||||
// go test -race -v
|
// go test -race -v
|
||||||
|
|||||||
Reference in New Issue
Block a user