108 lines
3.6 KiB
Python
108 lines
3.6 KiB
Python
"""
Ollama Client - Wrapper for local Ollama API
"""
|
|
import requests
|
|
import json
|
|
from typing import List, Dict, Optional
|
|
import logging
|
|
|
|
# Module-level logger named after this module; used by OllamaClient to report API errors.
logger = logging.getLogger(__name__)
|
|
|
|
class OllamaClient:
    """Client for interacting with local Ollama models.

    Thin wrapper around the Ollama HTTP endpoints ``/api/generate``,
    ``/api/chat`` and ``/api/tags``, issued through ``requests``.
    """

    def __init__(self, base_url: str = "http://localhost:11434", model: str = "llama3.2",
                 timeout: float = 1200):
        """Store connection settings; no network request is made here.

        Args:
            base_url: Root URL of the Ollama server. Trailing slashes are
                stripped so the derived endpoint URLs never contain ``//``.
            model: Model name sent with every generate/chat request.
            timeout: Seconds to wait for a generate/chat response.
        """
        # Normalised root URL (no trailing slash).
        self.base_url = base_url.rstrip("/")
        # Model name used for generate/chat and looked up by check_connection.
        self.model = model
        # Prefix shared by every endpoint this client calls.
        self.api_url = f"{self.base_url}/api"
        # Timeout (seconds) applied to generate/chat POST requests.
        self.timeout = timeout

    @staticmethod
    def _with_default_tag(name: str) -> str:
        """Return *name* with ``:latest`` appended when it has no explicit tag."""
        # Only the last path segment can carry a tag; a colon earlier in the
        # name belongs to a registry host:port (e.g. "localhost:5000/ns/model").
        last_segment = name.rsplit("/", 1)[-1]
        return name if ":" in last_segment else f"{name}:latest"

    def _post(self, endpoint: str, payload: Dict) -> Dict:
        """POST *payload* as JSON to ``{api_url}/{endpoint}`` and return the decoded body.

        Raises:
            requests.exceptions.RequestException: on connection failure,
                timeout, or a non-2xx status (via ``raise_for_status``).
        """
        response = requests.post(
            f"{self.api_url}/{endpoint}",
            json=payload,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def _fetch_model_names(self) -> List[str]:
        """Return the ``name`` of every entry listed by ``/api/tags``.

        Raises:
            requests.exceptions.RequestException: if the request fails.
        """
        response = requests.get(f"{self.api_url}/tags", timeout=5)
        response.raise_for_status()
        # Tolerate a missing or null "models" field.
        models = response.json().get("models") or []
        return [m["name"] for m in models]

    def generate(self, prompt: str, system: Optional[str] = None,
                 temperature: float = 0.7, max_tokens: int = 2048) -> str:
        """Generate text from a prompt.

        Args:
            prompt: Prompt text sent to the model.
            system: Optional system prompt; omitted from the request when falsy.
            temperature: Sampling temperature, sent as ``options.temperature``.
            max_tokens: Sent as ``options.num_predict``.

        Returns:
            The ``response`` field of the reply with surrounding whitespace
            stripped, or a string starting with ``"Error:"`` if the request
            failed. No exception is raised for request failures.
        """
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": temperature,
                "num_predict": max_tokens,
            },
        }
        if system:
            payload["system"] = system

        try:
            result = self._post("generate", payload)
        except requests.exceptions.RequestException as e:
            logger.error("Ollama API error: %s", e)
            return f"Error: Unable to connect to Ollama. Is it running? ({str(e)})"

        # "or" guards against an explicit null in the reply.
        return (result.get("response") or "").strip()

    def chat(self, messages: List[Dict[str, str]],
             temperature: float = 0.7, max_tokens: int = 2048) -> str:
        """Chat completion with message history.

        Args:
            messages: Conversation history, passed to the server unchanged
                (presumably dicts with ``role``/``content`` keys - confirm
                against callers).
            temperature: Sampling temperature, sent as ``options.temperature``.
            max_tokens: Sent as ``options.num_predict``.

        Returns:
            The ``message.content`` field of the reply with surrounding
            whitespace stripped, or a string starting with ``"Error:"`` if the
            request failed.
        """
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": False,
            "options": {
                "temperature": temperature,
                "num_predict": max_tokens,
            },
        }

        try:
            result = self._post("chat", payload)
        except requests.exceptions.RequestException as e:
            logger.error("Ollama API error: %s", e)
            return f"Error: Unable to connect to Ollama. Is it running? ({str(e)})"

        # "or" guards against explicit nulls at either level of the reply.
        message = result.get("message") or {}
        return (message.get("content") or "").strip()

    def check_connection(self) -> bool:
        """Check if Ollama is running and the configured model is available.

        Returns:
            ``True`` when ``/api/tags`` is reachable and lists ``self.model``;
            ``False`` otherwise (the reason is logged).
        """
        try:
            model_names = self._fetch_model_names()
        except requests.exceptions.RequestException as e:
            logger.error("Cannot connect to Ollama: %s", e)
            return False

        # The server lists names with an explicit tag ("llama3.2:latest"),
        # while an untagged name implicitly means ":latest"; compare both
        # sides in tagged form so "llama3.2" matches "llama3.2:latest".
        available = {self._with_default_tag(name) for name in model_names}
        if self._with_default_tag(self.model) not in available:
            logger.warning("Model %s not found. Available: %s", self.model, model_names)
            return False

        return True

    def list_models(self) -> List[str]:
        """List available Ollama models.

        Returns:
            Model names reported by ``/api/tags``, or an empty list if the
            request fails (the failure is logged, not raised).
        """
        try:
            return self._fetch_model_names()
        except requests.exceptions.RequestException as e:
            logger.warning("Could not list Ollama models: %s", e)
            return []
|