This commit is contained in:
HugeFrog24
2024-10-23 23:19:08 +02:00
parent 6e2d2fce2f
commit 166200c473

140
config.go
View File

@@ -5,22 +5,23 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"github.com/liushuangls/go-anthropic/v2" "github.com/liushuangls/go-anthropic/v2"
) )
type BotConfig struct { type BotConfig struct {
ID string `json:"id"` // Unique identifier for the bot ID string `json:"id"`
TelegramToken string `json:"telegram_token"` // Telegram Bot Token TelegramToken string `json:"telegram_token"`
MemorySize int `json:"memory_size"` MemorySize int `json:"memory_size"`
MessagePerHour int `json:"messages_per_hour"` MessagePerHour int `json:"messages_per_hour"`
MessagePerDay int `json:"messages_per_day"` MessagePerDay int `json:"messages_per_day"`
TempBanDuration string `json:"temp_ban_duration"` TempBanDuration string `json:"temp_ban_duration"`
Model anthropic.Model `json:"model"` // Changed from string to anthropic.Model Model anthropic.Model `json:"model"`
SystemPrompts map[string]string `json:"system_prompts"` SystemPrompts map[string]string `json:"system_prompts"`
Active bool `json:"active"` // New field to control bot activity Active bool `json:"active"`
OwnerTelegramID int64 `json:"owner_telegram_id"` OwnerTelegramID int64 `json:"owner_telegram_id"`
AnthropicAPIKey string `json:"anthropic_api_key"` // Add this line AnthropicAPIKey string `json:"anthropic_api_key"`
} }
// Custom unmarshalling to handle anthropic.Model // Custom unmarshalling to handle anthropic.Model
@@ -32,15 +33,53 @@ func (c *BotConfig) UnmarshalJSON(data []byte) error {
}{ }{
Alias: (*Alias)(c), Alias: (*Alias)(c),
} }
if err := json.Unmarshal(data, &aux); err != nil { if err := json.Unmarshal(data, &aux); err != nil {
return err return err
} }
c.Model = anthropic.Model(aux.Model) c.Model = anthropic.Model(aux.Model)
return nil return nil
} }
// validateConfigPath ensures the file path is within the allowed directory
func validateConfigPath(configDir, filename string) (string, error) {
// Clean the paths to remove any . or .. components
configDir = filepath.Clean(configDir)
filename = filepath.Clean(filename)
// Get absolute paths
absConfigDir, err := filepath.Abs(configDir)
if err != nil {
return "", fmt.Errorf("failed to get absolute path for config directory: %w", err)
}
fullPath := filepath.Join(absConfigDir, filename)
absPath, err := filepath.Abs(fullPath)
if err != nil {
return "", fmt.Errorf("failed to get absolute path for config file: %w", err)
}
// Check if the file path is within the config directory
if !isSubPath(absConfigDir, absPath) {
return "", fmt.Errorf("invalid config path: file must be within the config directory")
}
// Verify file extension
if filepath.Ext(absPath) != ".json" {
return "", fmt.Errorf("invalid file extension: must be .json")
}
return absPath, nil
}
// isSubPath checks if childPath is a subdirectory of parentPath
func isSubPath(parentPath, childPath string) bool {
rel, err := filepath.Rel(parentPath, childPath)
if err != nil {
return false
}
return !strings.HasPrefix(rel, ".."+string(filepath.Separator)) && rel != ".."
}
func loadAllConfigs(dir string) ([]BotConfig, error) { func loadAllConfigs(dir string) ([]BotConfig, error) {
var configs []BotConfig var configs []BotConfig
ids := make(map[string]bool) ids := make(map[string]bool)
@@ -53,55 +92,59 @@ func loadAllConfigs(dir string) ([]BotConfig, error) {
for _, file := range files { for _, file := range files {
if filepath.Ext(file.Name()) == ".json" { if filepath.Ext(file.Name()) == ".json" {
configPath := filepath.Join(dir, file.Name()) validPath, err := validateConfigPath(dir, file.Name())
config, err := loadConfig(configPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load config %s: %w", configPath, err) return nil, fmt.Errorf("invalid config path: %w", err)
} }
// Skip inactive bots config, err := loadConfig(validPath)
if err != nil {
return nil, fmt.Errorf("failed to load config %s: %w", validPath, err)
}
// Validation checks...
if !config.Active { if !config.Active {
InfoLogger.Printf("Skipping inactive bot: %s", config.ID) InfoLogger.Printf("Skipping inactive bot: %s", config.ID)
continue continue
} }
// Validate that ID is present if err := validateConfig(&config, ids, tokens); err != nil {
if config.ID == "" { return nil, fmt.Errorf("config validation failed for %s: %w", validPath, err)
return nil, fmt.Errorf("config %s is missing 'id' field", configPath)
}
// Check for unique ID
if _, exists := ids[config.ID]; exists {
return nil, fmt.Errorf("duplicate bot id '%s' found in %s", config.ID, configPath)
}
ids[config.ID] = true
// Validate Telegram Token
if config.TelegramToken == "" {
return nil, fmt.Errorf("config %s is missing 'telegram_token' field", configPath)
}
// Check for unique Telegram Token
if _, exists := tokens[config.TelegramToken]; exists {
return nil, fmt.Errorf("duplicate telegram_token '%s' found in %s", config.TelegramToken, configPath)
}
tokens[config.TelegramToken] = true
// Validate Model
if config.Model == "" {
return nil, fmt.Errorf("config %s is missing 'model' field", configPath)
} }
configs = append(configs, config) configs = append(configs, config)
} }
} }
return configs, nil return configs, nil
} }
func validateConfig(config *BotConfig, ids, tokens map[string]bool) error {
if config.ID == "" {
return fmt.Errorf("missing 'id' field")
}
if _, exists := ids[config.ID]; exists {
return fmt.Errorf("duplicate bot id '%s'", config.ID)
}
ids[config.ID] = true
if config.TelegramToken == "" {
return fmt.Errorf("missing 'telegram_token' field")
}
if _, exists := tokens[config.TelegramToken]; exists {
return fmt.Errorf("duplicate telegram_token")
}
tokens[config.TelegramToken] = true
if config.Model == "" {
return fmt.Errorf("missing 'model' field")
}
return nil
}
func loadConfig(filename string) (BotConfig, error) { func loadConfig(filename string) (BotConfig, error) {
var config BotConfig var config BotConfig
file, err := os.Open(filename) file, err := os.OpenFile(filename, os.O_RDONLY, 0)
if err != nil { if err != nil {
return config, fmt.Errorf("failed to open config file %s: %w", filename, err) return config, fmt.Errorf("failed to open config file %s: %w", filename, err)
} }
@@ -116,19 +159,30 @@ func loadConfig(filename string) (BotConfig, error) {
} }
func (c *BotConfig) Reload(filename string) error { func (c *BotConfig) Reload(filename string) error {
file, err := os.Open(filename) // Get the directory of the current executable
execDir, err := os.Executable()
if err != nil { if err != nil {
return fmt.Errorf("failed to open config file %s: %w", filename, err) return fmt.Errorf("failed to get executable directory: %w", err)
}
configDir := filepath.Dir(execDir)
// Validate the config path
validPath, err := validateConfigPath(configDir, filename)
if err != nil {
return fmt.Errorf("invalid config path: %w", err)
}
file, err := os.OpenFile(validPath, os.O_RDONLY, 0)
if err != nil {
return fmt.Errorf("failed to open config file %s: %w", validPath, err)
} }
defer file.Close() defer file.Close()
decoder := json.NewDecoder(file) decoder := json.NewDecoder(file)
if err := decoder.Decode(c); err != nil { if err := decoder.Decode(c); err != nil {
return fmt.Errorf("failed to decode JSON from %s: %w", filename, err) return fmt.Errorf("failed to decode JSON from %s: %w", validPath, err)
} }
// Ensure the Model is correctly casted
c.Model = anthropic.Model(c.Model) c.Model = anthropic.Model(c.Model)
return nil return nil
} }