# services/llm/batch_processor.py
"""
Batch processing service for LLM operations.
"""
from typing import List, Dict, Any, Optional, Callable
import json
from dataclasses import dataclass
import time

from .base import LLMService
from utils.progress_bar import ProgressBar
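
# Interfaces assumed from the two project imports above, inferred from how
# this module calls them (the real definitions live elsewhere; parameter
# names here are guesses):
#   LLMService.generate_text(system_prompt: str, user_prompt: str) -> str
#   ProgressBar(total, label, done_label) supporting set_output_callback(fn),
#   start(), update(current), finish()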


@dataclass
class BatchConfig:
    """Configuration for batch processing."""
    batch_size: int = 20
    max_retries: int = 3
    retry_delay: int = 3  # seconds to sleep between retry attempts
    progress_callback: Optional[Callable[[str], None]] = None


class BatchProcessor:
    """
    Handles batch processing for LLM operations.
    """

    def __init__(
        self,
        llm_service: LLMService,
        config: Optional[BatchConfig] = None
    ):
        self.llm_service = llm_service
        self.config = config or BatchConfig()

    def process_batch(
        self,
        items: List[Dict[str, Any]],
        system_prompt: str,
        template: str,
        output_processor: Optional[Callable[[Any], Any]] = None
    ) -> List[Any]:
        """
        Process items in batches, reusing the same system prompt and
        template so every batch gets consistent context.

        Args:
            items: List of dictionaries containing data to process
            system_prompt: System prompt for context
            template: Template string for formatting requests
            output_processor: Optional function applied to each parsed
                response item

        Returns:
            List of processed results, in input order. Items from a batch
            that still fails after max_retries appear as None.
        """
        results = []
        total_items = len(items)

        # Set up progress tracking
        progress = ProgressBar(
            total_items,
            "Processing batches:",
            "Complete"
        )
        if self.config.progress_callback:
            progress.set_output_callback(self.config.progress_callback)
        progress.start()

        # Process in batches
        for start_idx in range(0, total_items, self.config.batch_size):
            end_idx = min(start_idx + self.config.batch_size, total_items)
            batch_items = items[start_idx:end_idx]

            # Prepare batch request: items and template travel together
            # as a single JSON payload
            batch_data = {
                "items": batch_items,
                "template": template
            }
            request_payload = json.dumps(batch_data)

            # Process batch with retries
            for attempt in range(self.config.max_retries):
                try:
                    response = self.llm_service.generate_text(
                        system_prompt=system_prompt,
                        user_prompt=request_payload
                    )

                    # Parse and process response
                    batch_results = self._process_response(
                        response,
                        output_processor
                    )

                    # Guard against the model returning the wrong number
                    # of results for this batch
                    if len(batch_results) != len(batch_items):
                        raise ValueError(
                            "Response count doesn't match input count"
                        )

                    results.extend(batch_results)
                    break

                except Exception as e:
                    if attempt < self.config.max_retries - 1:
                        if self.config.progress_callback:
                            self.config.progress_callback(
                                f"Error in batch {start_idx}-{end_idx}: {e}. Retrying..."
                            )
                        time.sleep(self.config.retry_delay)
                    else:
                        if self.config.progress_callback:
                            self.config.progress_callback(
                                f"Error in batch {start_idx}-{end_idx}: {e}"
                            )
                        # On final retry failure, keep output aligned with
                        # input by adding a None placeholder per item
                        results.extend([None] * len(batch_items))

            # Update progress
            progress.update(end_idx)

        progress.finish()
        return results

    def _process_response(
        self,
        response: str,
        output_processor: Optional[Callable[[Any], Any]] = None
    ) -> List[Any]:
        """Parse an LLM response, optionally post-processing each item."""
        try:
            # Parse JSON response
            parsed = json.loads(response)

            # Apply custom processing if provided
            if output_processor:
                return [output_processor(item) for item in parsed]
            return parsed

        except json.JSONDecodeError as e:
            # Chain the original error so the offending payload stays visible
            raise ValueError("Failed to parse LLM response as JSON") from e


# Example specialized batch processor for translations
class TranslationBatchProcessor(BatchProcessor):
    """Specialized batch processor for translations"""

    def translate_batch(
        self,
        texts: List[str],
        source_lang: str,
        target_lang: str
    ) -> List[str]:
        """
        Translate a batch of texts.

        Args:
            texts: List of texts to translate
            source_lang: Source language code
            target_lang: Target language code

        Returns:
            List of translated texts (None for texts whose batch failed)
        """
        # Prepare items
        items = [{"text": text} for text in texts]

        # Set up prompts
        system_prompt = (
            "You are a translator. Translate the provided texts "
            "maintaining special fields like <> and <#>."
        )

        template = (
            "Translate the following texts from {source_lang} to {target_lang}. "
            "Return translations as a JSON array of strings:"
            # The doubled braces survive .format() below as a literal {text}
            # placeholder; a bare {text} would raise KeyError: 'text'.
            "\n\n{{text}}"
        )

        # Process batch
        results = self.process_batch(
            items=items,
            system_prompt=system_prompt,
            template=template.format(
                source_lang=source_lang,
                target_lang=target_lang
            )
        )

        return results


# Example usage:
"""
from services.llm.llm_factory import LLMFactory
from services.llm.batch_processor import BatchProcessor, BatchConfig, TranslationBatchProcessor

# Create LLM service
llm_service = LLMFactory.create_service("openai")

# Set up batch processor with a progress callback
def progress_callback(message: str):
    print(message)

config = BatchConfig(
    batch_size=20,
    progress_callback=progress_callback
)

# General batch processor
processor = BatchProcessor(llm_service, config)

# Example batch process for a custom task
items = [
    {"text": "Hello", "context": "greeting"},
    {"text": "Goodbye", "context": "farewell"}
]

system_prompt = "You are a helpful assistant."
template = "Process these items considering their context: {items}"

results = processor.process_batch(
    items=items,
    system_prompt=system_prompt,
    template=template
)
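
# The output_processor hook post-processes each item parsed from the model's
# JSON response. A minimal sketch (this lambda is illustrative, not part of
# the API):
processed = processor.process_batch(
    items=items,
    system_prompt=system_prompt,
    template=template,
    output_processor=lambda item: str(item).strip()
)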

# Example translation batch
translator = TranslationBatchProcessor(llm_service, config)
texts = ["Hello world", "How are you?"]
translations = translator.translate_batch(
    texts=texts,
    source_lang="en",
    target_lang="es"
)
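
# Batches that still fail after max_retries yield None entries, so guard
# before using the results (the empty-string fallback is just one option):
safe_translations = [t if t is not None else "" for t in translations]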
"""