Note
Go to the end to download the full example code.
RAG-GUI
A cookbook demonstrating how to run a RAG app with Streamlit.
import os
import sys
from pathlib import Path
import streamlit as st
from grag.components.multivec_retriever import Retriever
from grag.components.utils import get_config
from grag.components.vectordb.deeplake_client import DeepLakeClient
from grag.rag.basic_rag import BasicRAG
# Make the directory two levels above the current working directory importable.
# It is inserted at index 1, i.e. right after the first sys.path entry.
sys.path.insert(1, str(Path.cwd().parents[1]))

# Browser-tab title plus the "Get Help" / "About" entries of Streamlit's app menu.
st.set_page_config(
    page_title="GRAG",
    menu_items={
        "Get Help": "https://github.com/arjbingly/Capstone_5",
        "About": "This is a simple GUI for GRAG",
    },
)
def spinner(text):
    """Build a decorator that shows a Streamlit spinner while the wrapped function runs.

    Every call of the decorated function is executed inside ``st.spinner(text=text)``,
    so the user sees a loading indicator for the duration of the call.

    Args:
        text (str): The message to display next to the spinner.

    Returns:
        function: A decorator that takes a function and returns a wrapped version of it.
            The wrapped version forwards all arguments and returns whatever the
            original function returns.
    """
    # Function-scope import keeps this block self-contained.
    import functools

    def _spinner(func):
        """Wrap ``func`` so that it executes inside the spinner context.

        Args:
            func (function): The function to wrap.

        Returns:
            function: The wrapped function; its name and docstring are copied from ``func``.
        """

        # functools.wraps preserves func's metadata (__name__, __doc__, __wrapped__),
        # so no separate docstring is given to the wrapper itself.
        @functools.wraps(func)
        def wrapper_func(*args, **kwargs):
            with st.spinner(text=text):
                # Propagate the result instead of silently dropping it.
                return func(*args, **kwargs)

        return wrapper_func

    return _spinner
@st.cache_data
def load_config():
    """Return the result of ``get_config()``; the call is cached via ``st.cache_data``."""
    config = get_config()
    return config
# Module-level configuration object, obtained through the cached load_config() helper.
conf = load_config()
class RAGApp:
    """Streamlit GUI for chatting with a Retrieval-Augmented Generation (RAG) pipeline.

    Widget values and the loaded pipeline are kept in ``st.session_state`` under the
    keys ``selected_model``, ``temperature``, ``top_k``, ``show_sources`` and ``rag``.

    Attributes:
        app: The application handle passed to the constructor. It is stored only;
            the methods of this class call the module-level ``st`` directly.
        conf: The configuration object passed to the constructor. It is stored only;
            no method of this class reads it.
    """

    # Models offered in the sidebar selectbox, in display order.
    _MODEL_CHOICES = [
        "Llama-2-13b-chat",
        "Llama-2-7b-chat",
        "Mixtral-8x7B-Instruct-v0.1",
        "gemma-7b-it",
    ]

    # Extra LLM keyword arguments that are applied only for the listed models.
    _MODEL_LLM_KWARGS = {
        "Mixtral-8x7B-Instruct-v0.1": {"n_gpu_layers": 16, "quantization": "Q4_K_M"},
        "gemma-7b-it": {"n_gpu_layers": 18, "quantization": "f16"},
    }

    def __init__(self, app, conf):
        """Store the application handle and the configuration.

        Args:
            app: The main application or framework instance that this class will interact with.
            conf: A configuration object or dictionary containing settings for the application.
        """
        self.app = app
        self.conf = conf

    def render_sidebar(self):
        """Render the sidebar: model selection, generation parameters and the load button."""
        with st.sidebar:
            st.title("GRAG")
            st.subheader("Models and parameters")
            # Each widget publishes its value to st.session_state under its ``key``.
            st.selectbox(
                "Choose a model",
                self._MODEL_CHOICES,
                key="selected_model",
            )
            st.slider(
                "Temperature",
                min_value=0.1,
                max_value=1.0,
                value=0.1,
                step=0.1,
                key="temperature",
            )
            st.slider("Top-k", min_value=1, max_value=5, value=3, step=1, key="top_k")
            # load_rag runs as the button's click callback.
            st.button("Load Model", on_click=self.load_rag)
            st.checkbox("Show sources", key="show_sources")

    @spinner(text="Loading model...")
    def load_rag(self):
        """Build a ``BasicRAG`` from the current sidebar settings and store it in session state.

        Reads ``selected_model``, ``temperature`` and ``top_k`` from ``st.session_state``
        and writes the new pipeline to ``st.session_state["rag"]``.
        """
        state = st.session_state

        # Drop any previously loaded pipeline before building the new one
        # (presumably to release the old model first -- confirm).
        if "rag" in state:
            del state["rag"]

        model_name = state["selected_model"]
        llm_kwargs = {
            "temperature": state["temperature"],
            # Model-specific overrides; models without an entry add nothing.
            **self._MODEL_LLM_KWARGS.get(model_name, {}),
        }

        # NOTE(review): a prebuilt retriever is passed to BasicRAG as well; confirm that
        # BasicRAG still applies retriever_kwargs (e.g. top_k) in that case.
        retriever_kwargs = {
            "client_kwargs": {
                "read_only": True,
            },
            "top_k": state["top_k"],
        }

        client = DeepLakeClient(collection_name="usc", read_only=True)
        retriever = Retriever(vectordb=client)
        state["rag"] = BasicRAG(
            model_name=model_name,
            stream=True,
            llm_kwargs=llm_kwargs,
            retriever=retriever,
            retriever_kwargs=retriever_kwargs,
        )

        st.success(
            "Model Loaded !!!\n"
            f"Model Name: {state['selected_model']}\n"
            f"Temperature: {state['temperature']}\n"
            f"Top-k : {state['top_k']}"
        )

    def clear_cache(self):
        """Clear everything stored by ``st.cache_data``."""
        st.cache_data.clear()

    def render_main(self):
        """Render the chat area: the question input, the streamed answer and optional sources."""
        st.title(":us: US Constitution Expert! :mortar_board:")

        if "rag" not in st.session_state:
            st.warning("You have not loaded any model")
            return

        user_input = st.chat_input("Ask me anything about the US Constitution.")
        if not user_input:
            return

        rag = st.session_state["rag"]

        with st.chat_message("user"):
            st.write(user_input)

        with st.chat_message("assistant"):
            # The first element of the pipeline's result is streamed to the page.
            # The return value is bound to ``_`` on purpose (presumably so Streamlit's
            # "magic" does not echo the returned text a second time -- confirm).
            _ = st.write_stream(rag(user_input)[0])

            if st.session_state["show_sources"]:
                retrieved_docs = rag.retriever.get_chunk(user_input)
                for position, doc in enumerate(retrieved_docs, start=1):
                    with st.expander(f"Source {position}"):
                        st.markdown(f"**{position}. {doc.metadata['source']}**")
                        # st.text performs no Markdown parsing, so the content is
                        # passed as-is instead of being wrapped in literal "**".
                        st.text(doc.page_content)

    def render(self):
        """Render the main chat area first, then the sidebar."""
        self.render_main()
        self.render_sidebar()
if __name__ == "__main__":
    # Build the GUI around the streamlit module and the loaded config, then draw it.
    app = RAGApp(app=st, conf=conf)
    app.render()