sach91 bootcamp week8 exercise
This commit is contained in:
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Ollama Client - Wrapper for local Ollama API
|
||||
"""
|
||||
import requests
|
||||
import json
|
||||
from typing import List, Dict, Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class OllamaClient:
|
||||
"""Client for interacting with local Ollama models"""
|
||||
|
||||
def __init__(self, base_url: str = "http://localhost:11434", model: str = "llama3.2"):
|
||||
self.base_url = base_url
|
||||
self.model = model
|
||||
self.api_url = f"{base_url}/api"
|
||||
|
||||
def generate(self, prompt: str, system: Optional[str] = None,
|
||||
temperature: float = 0.7, max_tokens: int = 2048) -> str:
|
||||
"""Generate text from a prompt"""
|
||||
try:
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": {
|
||||
"temperature": temperature,
|
||||
"num_predict": max_tokens
|
||||
}
|
||||
}
|
||||
|
||||
if system:
|
||||
payload["system"] = system
|
||||
|
||||
response = requests.post(
|
||||
f"{self.api_url}/generate",
|
||||
json=payload,
|
||||
timeout=1200
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
return result.get("response", "").strip()
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Ollama API error: {e}")
|
||||
return f"Error: Unable to connect to Ollama. Is it running? ({str(e)})"
|
||||
|
||||
def chat(self, messages: List[Dict[str, str]],
|
||||
temperature: float = 0.7, max_tokens: int = 2048) -> str:
|
||||
"""Chat completion with message history"""
|
||||
try:
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"stream": False,
|
||||
"options": {
|
||||
"temperature": temperature,
|
||||
"num_predict": max_tokens
|
||||
}
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{self.api_url}/chat",
|
||||
json=payload,
|
||||
timeout=1200
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
return result.get("message", {}).get("content", "").strip()
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Ollama API error: {e}")
|
||||
return f"Error: Unable to connect to Ollama. Is it running? ({str(e)})"
|
||||
|
||||
def check_connection(self) -> bool:
|
||||
"""Check if Ollama is running and model is available"""
|
||||
try:
|
||||
response = requests.get(f"{self.base_url}/api/tags", timeout=5)
|
||||
response.raise_for_status()
|
||||
|
||||
models = response.json().get("models", [])
|
||||
model_names = [m["name"] for m in models]
|
||||
|
||||
if self.model not in model_names:
|
||||
logger.warning(f"Model {self.model} not found. Available: {model_names}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Cannot connect to Ollama: {e}")
|
||||
return False
|
||||
|
||||
def list_models(self) -> List[str]:
|
||||
"""List available Ollama models"""
|
||||
try:
|
||||
response = requests.get(f"{self.base_url}/api/tags", timeout=5)
|
||||
response.raise_for_status()
|
||||
|
||||
models = response.json().get("models", [])
|
||||
return [m["name"] for m in models]
|
||||
|
||||
except requests.exceptions.RequestException:
|
||||
return []
|
||||
Reference in New Issue
Block a user