80 lines
2.7 KiB
Python
80 lines
2.7 KiB
Python
"""LLM configuration settings."""
|
|
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
|
|
|
|
@dataclass
class LLMConfig:
    """Connection settings for one LLM model.

    ``api_key`` is kept out of the generated ``repr`` so that printing or
    logging a config object does not expose the secret. It still takes part
    in equality comparison and remains a required constructor argument.
    """

    provider: str  # provider identifier, e.g. 'openai' or 'deepseek'
    model_name: str  # model identifier as passed to the provider
    api_key: str = field(repr=False)  # secret: hidden from repr()
    api_base: Optional[str] = None  # base URL of the API; None when not set
    # Extra request parameters (e.g. max_tokens, temperature); None when unset.
    additional_params: Optional[Dict[str, Any]] = None
|
|
|
|
class LLMProviderSettings:
    """Built-in model presets, grouped by LLM provider.

    Each ``*_SETTINGS`` table maps a model name to a preset dict holding the
    ``provider`` and ``model_name`` identifiers, the default ``max_tokens``
    and ``temperature``, and optionally an ``api_base`` URL.
    """

    OPENAI_SETTINGS = {
        'gpt-3.5-turbo-16k': {
            'provider': 'openai',
            'model_name': 'gpt-3.5-turbo-16k',
            'max_tokens': 16000,
            'temperature': 0.7,
        },
        'gpt-4': {
            'provider': 'openai',
            'model_name': 'gpt-4',
            'max_tokens': 8000,
            'temperature': 0.7,
        },
    }

    DEEPSEEK_SETTINGS = {
        'deepseek-chat': {
            'provider': 'deepseek',
            'model_name': 'deepseek-chat',
            'max_tokens': 8000,
            'temperature': 0.7,
            'api_base': 'https://api.deepseek.com/v1',  # Example API base, replace with actual
        },
    }

    @classmethod
    def _settings_by_provider(cls) -> Dict[str, Dict[str, Dict[str, Any]]]:
        """Return the preset tables keyed by provider name.

        Single source of truth for which providers exist; resolved through
        ``cls`` so a subclass that overrides a table is honoured.
        """
        return {
            'openai': cls.OPENAI_SETTINGS,
            'deepseek': cls.DEEPSEEK_SETTINGS,
        }

    @classmethod
    def get_config(cls, provider: str, model_name: str, api_key: str) -> LLMConfig:
        """Get LLM configuration for a specific provider and model.

        Args:
            provider: Provider name, a key of the provider registry.
            model_name: Model name, a key of that provider's preset table.
            api_key: Credential stored on the returned config as-is.

        Returns:
            An ``LLMConfig`` built from the preset. ``api_base`` is taken from
            the preset when present, otherwise ``None``; ``additional_params``
            is a new dict holding ``max_tokens`` and ``temperature``.

        Raises:
            ValueError: If the provider is unknown or the model has no preset
                for that provider.
        """
        settings = cls._settings_by_provider().get(provider, {}).get(model_name)
        if settings is None:
            raise ValueError(f"Unsupported provider '{provider}' or model '{model_name}'")

        return LLMConfig(
            provider=settings['provider'],
            model_name=settings['model_name'],
            api_key=api_key,
            # Optional per preset: only some presets define a custom endpoint.
            api_base=settings.get('api_base'),
            additional_params={
                'max_tokens': settings['max_tokens'],
                'temperature': settings['temperature'],
            },
        )

    @classmethod
    def list_available_models(cls) -> Dict[str, List[str]]:
        """List all available models and their providers.

        Returns:
            A dict mapping each provider name to a new list of its model names.
        """
        return {
            provider: list(table)
            for provider, table in cls._settings_by_provider().items()
        }
|