.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/RAG-GUI/app.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_RAG-GUI_app.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_RAG-GUI_app.py:


RAG-GUI
=======
A cookbook demonstrating how to run a RAG app with Streamlit.

.. GENERATED FROM PYTHON SOURCE LINES 7-201

.. code-block:: Python


    import functools
    import os
    import sys
    from pathlib import Path

    # Extend the import path *before* importing ``grag``: an insert placed
    # after the imports (as before) cannot influence how they resolve.
    # NOTE(review): assumes the app is launched from a directory two levels
    # below the repository root, so ``parents[1]`` is that root -- confirm.
    sys.path.insert(1, str(Path(os.getcwd()).parents[1]))

    import streamlit as st  # noqa: E402
    from grag.components.multivec_retriever import Retriever  # noqa: E402
    from grag.components.utils import get_config  # noqa: E402
    from grag.components.vectordb.deeplake_client import DeepLakeClient  # noqa: E402
    from grag.rag.basic_rag import BasicRAG  # noqa: E402

    st.set_page_config(
        page_title="GRAG",
        menu_items={
            "Get Help": "https://github.com/arjbingly/Capstone_5",
            "About": "This is a simple GUI for GRAG",
        },
    )


    def spinner(text):
        """Build a decorator that shows a Streamlit spinner while a function runs.

        The wrapped function executes inside an ``st.spinner`` context that
        displays ``text``, signalling to the user that work is in progress.

        Args:
            text (str): The message displayed next to the spinner.

        Returns:
            function: A decorator that wraps a function in the spinner context.
        """

        def _spinner(func):
            """Wrap ``func`` so that it executes inside an ``st.spinner`` context.

            Args:
                func (function): The function to wrap.

            Returns:
                function: The wrapped function. ``functools.wraps`` keeps the
                original name and docstring.
            """

            @functools.wraps(func)
            def wrapper_func(*args, **kwargs):
                """Call ``func`` while the spinner is displayed.

                Args:
                    *args: Positional arguments forwarded to ``func``.
                    **kwargs: Keyword arguments forwarded to ``func``.

                Returns:
                    Whatever ``func`` returns (the result is passed through
                    rather than discarded).
                """
                with st.spinner(text=text):
                    return func(*args, **kwargs)

            return wrapper_func

        return _spinner


    @st.cache_data
    def load_config():
        """Load the GRAG configuration via ``get_config`` (cached with ``st.cache_data``)."""
        return get_config()


    conf = load_config()


    class RAGApp:
        """Streamlit interface for a Retrieval-Augmented Generation (RAG) model.

        Attributes:
            app: The application object supplied by the caller (the
                ``streamlit`` module when run as a script).
            conf: Configuration settings for the application.
        """

        # Per-model overrides merged into the LLM keyword arguments; models not
        # listed here receive only the temperature.
        _MODEL_LLM_KWARGS = {
            "Mixtral-8x7B-Instruct-v0.1": {
                "n_gpu_layers": 16,
                "quantization": "Q4_K_M",
            },
            "gemma-7b-it": {
                "n_gpu_layers": 18,
                "quantization": "f16",
            },
        }

        def __init__(self, app, conf):
            """Store the application object and configuration.

            Args:
                app: The application or framework object this class works with.
                conf: A configuration object or mapping for the application.
            """
            self.app = app
            self.conf = conf

        def render_sidebar(self):
            """Render the sidebar: model choice, sampling parameters and actions."""
            with st.sidebar:
                st.title("GRAG")
                st.subheader("Models and parameters")
                # Everything inside ``with st.sidebar`` is already placed in the
                # sidebar, so plain ``st.*`` calls are used consistently.
                st.selectbox(
                    "Choose a model",
                    [
                        "Llama-2-13b-chat",
                        "Llama-2-7b-chat",
                        "Mixtral-8x7B-Instruct-v0.1",
                        "gemma-7b-it",
                    ],
                    key="selected_model",
                )
                st.slider(
                    "Temperature",
                    min_value=0.1,
                    max_value=1.0,
                    value=0.1,
                    step=0.1,
                    key="temperature",
                )
                st.slider(
                    "Top-k", min_value=1, max_value=5, value=3, step=1, key="top_k"
                )
                st.button("Load Model", on_click=self.load_rag)
                st.checkbox("Show sources", key="show_sources")

        @spinner(text="Loading model...")
        def load_rag(self):
            """(Re)load the RAG model from the current sidebar selections.

            Reads ``selected_model``, ``temperature`` and ``top_k`` from
            ``st.session_state`` and stores the new ``BasicRAG`` instance under
            ``st.session_state["rag"]``.
            """
            # Drop any previously loaded model before building the new one.
            st.session_state.pop("rag", None)

            model_name = st.session_state["selected_model"]
            llm_kwargs = {
                "temperature": st.session_state["temperature"],
                **self._MODEL_LLM_KWARGS.get(model_name, {}),
            }
            retriever_kwargs = {
                "client_kwargs": {
                    "read_only": True,
                },
                "top_k": st.session_state["top_k"],
            }

            # NOTE(review): a ready-made ``retriever`` is passed to BasicRAG
            # together with ``retriever_kwargs``; confirm that ``top_k`` is
            # still applied in that case.
            client = DeepLakeClient(collection_name="usc", read_only=True)
            retriever = Retriever(vectordb=client)

            st.session_state["rag"] = BasicRAG(
                model_name=model_name,
                stream=True,
                llm_kwargs=llm_kwargs,
                retriever=retriever,
                retriever_kwargs=retriever_kwargs,
            )
            st.success(
                f"""Model Loaded !!!
                Model Name: {st.session_state['selected_model']}
                Temperature: {st.session_state['temperature']}
                Top-k : {st.session_state['top_k']}"""
            )

        def clear_cache(self):
            """Clear all data cached with ``st.cache_data``."""
            st.cache_data.clear()

        def render_main(self):
            """Render the chat interface for querying the loaded RAG model."""
            st.title(":us: US Constitution Expert! :mortar_board:")
            if "rag" not in st.session_state:
                st.warning("You have not loaded any model")
                return

            user_input = st.chat_input("Ask me anything about the US Constitution.")
            if not user_input:
                return

            with st.chat_message("user"):
                st.write(user_input)

            with st.chat_message("assistant"):
                rag = st.session_state["rag"]
                # NOTE(review): presumably ``rag(query)`` returns a tuple whose
                # first item is the streamed response -- confirm.
                _ = st.write_stream(rag(user_input)[0])
                # ``.get`` tolerates the checkbox state not existing yet.
                if st.session_state.get("show_sources", False):
                    retrieved_docs = rag.retriever.get_chunk(user_input)
                    for index, doc in enumerate(retrieved_docs, start=1):
                        with st.expander(f"Source {index}"):
                            st.markdown(f"**{index}. {doc.metadata['source']}**")
                            # ``st.text`` shows plain text; markdown markers such
                            # as ``**`` would be rendered literally.
                            st.text(doc.page_content)

        def render(self):
            """Render the main chat area, then the sidebar."""
            self.render_main()
            self.render_sidebar()


    if __name__ == "__main__":
        app = RAGApp(st, conf)
        app.render()


.. _sphx_glr_download_auto_examples_RAG-GUI_app.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: app.ipynb <app.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: app.py <app.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_