# services/llm/grok_service.py
"""
Grok (xAI) service implementation
"""

import httpx
from typing import Dict, List, Optional
import json

from .base import LLMService
from config.api_keys import APIKeyManager
from utils.logger import setup_logger


class GrokService(LLMService):
    def __init__(
        self,
        model: str = "grok-3-mini-fast",
        temperature: float = 0.3,
        max_tokens: int = 16000,
    ):  # We should use the grok-3-mini-fast model
        api_key = APIKeyManager.get_grok_key()
        if not api_key:
            raise ValueError(
                "Grok API key not found. Please set the GROK_API_KEY environment variable."
            )

        self.api_key = api_key
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.base_url = "https://api.x.ai/v1"
        # Reuse a single httpx client so every request carries the auth headers.
        self.client = httpx.Client(
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }
        )
        self.logger = setup_logger("grok_xai")

    def _send_request(self, payload: Dict) -> Optional[Dict]:
        """Sends a request to the Grok API and returns the parsed JSON response, or None on failure."""
        try:
            response = self.client.post(
                f"{self.base_url}/chat/completions", json=payload, timeout=60
            )
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            self.logger.error(
                f"Error in Grok API call: {e.response.status_code} - {e.response.text}"
            )
            return None
        except Exception as e:
            self.logger.error(f"An unexpected error occurred: {e}")
            return None

    def generate_text(self, prompt: str) -> str:
        """Generates a completion for a single user prompt."""
        self.logger.info(f"--- PROMPT ---\n{prompt}")
        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        response_data = self._send_request(payload)
        if response_data and response_data.get("choices"):
            response_content = response_data["choices"][0]["message"]["content"]
            self.logger.info(f"--- RESPONSE ---\n{response_content}")
            return response_content
        return "Failed to get a response from Grok."

    def get_similarity_scores(
        self, texts_pairs: Dict[str, List[str]]
    ) -> Optional[List[float]]:
        """Scores the semantic similarity of text pairs, returning floats in [0.0, 1.0] or None on failure."""
        system_prompt = (
            "You are an expert in semantic analysis. Evaluate the semantic similarity between the pairs of texts provided. "
            "Return your response ONLY as a JSON object containing a single key 'similarity_scores' with a list of floats from 0.0 to 1.0. "
            "Do not include any other text, explanation, or markdown formatting. The output must be valid JSON."
        )
        request_payload = json.dumps(texts_pairs)

        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": request_payload},
            ],
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            # Ask the API for a strict JSON object so the response can be parsed directly.
            "response_format": {"type": "json_object"},
        }

        response_data = self._send_request(payload)
        if response_data and response_data.get("choices"):
            response_content = response_data["choices"][0]["message"]["content"]
            try:
                scores_data = json.loads(response_content)
                if isinstance(scores_data, dict) and "similarity_scores" in scores_data:
                    return scores_data["similarity_scores"]
                raise ValueError("Unexpected JSON format from Grok.")
            except (json.JSONDecodeError, ValueError) as e:
                self.logger.error(f"Error decoding Grok JSON response: {e}")
                return None
        return None
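

# Minimal usage sketch, assuming GROK_API_KEY is set in the environment and the sibling
# modules imported above are importable (run as `python -m services.llm.grok_service`
# so the relative import resolves). The key names in `pairs` are illustrative only;
# the exact layout of texts_pairs depends on the caller's prompting convention.
if __name__ == "__main__":
    service = GrokService()

    # Free-form generation.
    answer = service.generate_text("Summarize what an LLM service wrapper does in one sentence.")
    print(answer)

    # Similarity scoring; the model is instructed to reply with {"similarity_scores": [...]}.
    pairs = {
        "texts_1": ["the cat sat on the mat"],
        "texts_2": ["a cat is sitting on a mat"],
    }
    scores = service.get_similarity_scores(pairs)
    print(scores)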