ParamManagerScripts/backend/script_groups/ragex/x2.py

127 lines
3.7 KiB
Python

"""
Este script realiza la consulta usando RAGEX a la base de datos de documentos.
"""
import os
import sys
from pathlib import Path
import json
from langchain_community.vectorstores import Chroma
from langchain_openai import (
OpenAIEmbeddings,
) # Cambiado de HuggingFaceEmbeddings a OpenAIEmbeddings
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from rich.console import Console
from rich.markdown import Markdown
import os
import argparse
from openai_api_key import openai_api_key
console = Console()
class CitationTracker:
def __init__(self):
self.citations = []
def add_citation(self, text, metadata):
self.citations.append({"text": text, "metadata": metadata})
def get_formatted_citations(self):
result = "\n## Fuentes\n\n"
for i, citation in enumerate(self.citations, 1):
source = citation["metadata"]["source"]
result += f"{i}. [{os.path.basename(source)}]({source}) - Fragmento {citation['metadata']['chunk_id']}\n"
return result
def search_with_citation(query, db_directory, model="gpt-3.5-turbo"):
# Cargar embeddings y base de datos
embeddings = OpenAIEmbeddings(
model="text-embedding-3-small"
) # Usar OpenAI Embeddings igual que en x1.py
db = Chroma(persist_directory=db_directory, embedding_function=embeddings)
api_key = openai_api_key()
os.environ["OPENAI_API_KEY"] = api_key
# Configurar el LLM de OpenAI
llm = ChatOpenAI(model_name=model)
# Rastreador de citas
citation_tracker = CitationTracker()
# Recuperar documentos relevantes
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
# Plantilla para el prompt
template = """
Responde a la siguiente pregunta basándote exclusivamente en la información proporcionada.
Incluye referencias a las fuentes originales para cada afirmación importante usando [Fuente N].
Si la información no es suficiente, indícalo claramente.
Contexto:
{context}
Pregunta: {question}
Respuesta (incluye [Fuente N] para citar):
"""
prompt = ChatPromptTemplate.from_template(template)
# Función para formatear el contexto
def format_docs(docs):
formatted_context = ""
for i, doc in enumerate(docs, 1):
citation_tracker.add_citation(doc.page_content, doc.metadata)
formatted_context += f"[Fuente {i}]: {doc.page_content}\n\n"
return formatted_context
# Cadena RAG
rag_chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
# Ejecutar búsqueda
response = rag_chain.invoke(query)
# Agregar citas al final
full_response = response + "\n\n" + citation_tracker.get_formatted_citations()
return full_response
def main():
# Cargar configuraciones del entorno
configs = json.loads(os.environ.get("SCRIPT_CONFIGS", "{}"))
# Obtener working directory
working_directory = configs.get("working_directory", ".")
# Obtener configuraciones de nivel 2 (grupo)
group_config = configs.get("level2", {})
work_config = configs.get("level3", {})
in_dir = work_config.get("in_dir", ".")
docs_directory = os.path.join(working_directory, in_dir)
model = work_config.get("model", "gpt-3.5-turbo")
query = work_config.get("query", "")
db_directory = os.path.join(working_directory, "chroma_db")
result = search_with_citation(query, db_directory, model)
console.print(Markdown(result))
if __name__ == "__main__":
main()