RAG 을 활용하여 LLM 만들어보기

Chat History 추가와 Streaming 구현

몽자비루 2025. 8. 17. 14:49

이번에는 create_history_aware_retriever 을 활용해서 chat history 를 받을 수 있도록 구조를 수정해보려고 한다.

 

최종 코드는 최하단에 추가해둘 예정이며, 전체적인 흐름은 아래와 같다.

  • user_question 으로 사용자에게 질문 받기
  • get_dictionary_chain() 및 invoke 로 정규화된 질문  normalized_q 생성
  • RunnableWithMessageHistory(rag_chain) 과
    get_session_history() 으로 세션별 대화 히스토리 자동 관리
  • create_stuff_documents_chain 를 통해 context 문서와 질문을 통해 LLM 답변을 생성
  • LLM답변 중 pick("answer").stream(..) 으 로 streamlit 에 스트리밍 출력.

1. user_question 으로 사용자에게 질문 받기

나는 streamlit 을 구성하는 부분과 LLM을 호출하는 부분을 분리해서 만들었다.
먼저 streamlit 을 통해서 user_question 을 받아온 뒤에 get_dictionary_chain() 으로 질문을 정규화한다.

chat.py
import streamlit as st
from llm_with_History import get_ai_message

st.set_page_config(page_title="Chat Application", page_icon="📈")

st.title("소득세 챗봇")
st.caption("소득세 관련 질문을 해보세요!")

if "messages" not in st.session_state:
    st.session_state.messages = []

# 세션에 저장된 메시지를 반복하고 각 메시지 역할에 따라 메시지 표시
print(f"before: {st.session_state.messages}")
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])


if user_question := st.chat_input("소득세와 관련된 궁금한 내용들을 말씀해주세요!"):
    # 채팅이 입력되었을 때 상단에 메시지가 노출되도록 함.
    # "user", "assistant", "ai", "human" 가 있는데 그중 user 을 사용함.
    with st.chat_message("user"):
        st.write(user_question)

    # 채팅 메시지를 Session State에 저장
    st.session_state.messages.append({"role": "user", "content": user_question})


    # 로딩중인 부분을 보여주기
    with st.spinner("AI 응답을 생성하는 중..."):
        ai_message = get_ai_message(user_question)

        # with st.chat_message("ai"):
        #     st.write(ai_message)
        #     st.session_state.messages.append({"role": "ai", "content": ai_message})
        with st.chat_message("ai"):
            final_text = st.write_stream(ai_message)
            # 채팅 메시지를 Session State에 저장
    st.session_state.messages.append({"role": "ai", "content": final_text})

2. get_dictionary_chain() 및 invoke 를 통해 정규화된 질문을 normalized_q 변수에 입력하기

llm_with_History.py
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_upstage import UpstageEmbeddings, ChatUpstage
from langchain_pinecone import PineconeVectorStore

# 자동 히스토리 관리용
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from dotenv import load_dotenv
load_dotenv() 

def get_llm():
    llm = ChatUpstage()
    return llm


def get_dictionary_chain():
    llm = get_llm()
    dictionary = ["사람을 나타내는 표현은 모두 거주자로 변경해주세요"]

    prompt = ChatPromptTemplate.from_template(
        f"""
        사용자의 질문을 보고, 우리의 사전을 참고해서 사용자의 질문을 변경해주세요.
        변경할 필요가 없다고 판단되면, 질문 원문을 그대로 반환하세요.

        추가로 초과 금액은 단순하게 사용자가 제시한 금액에서 기준 금액을 뺀 값으로 계산해주세요.

        사전: {dictionary}
        질문: {{question}}
        """
        )

    return prompt | llm | StrOutputParser()


def get_ai_message(user_question: str, session_id: str = "abc123"):
    """
    - 사전 변환 → 히스토리 인지형 RAG 스트리밍
    - 문자열 제너레이터를 반환 (Streamlit의 st.write_stream에 바로 전달 가능)
    """
    # 1) 질문 정규화
    normalized_q = get_dictionary_chain().invoke({"question": user_question})

3. RunnableWithMessageHistory(rag_chain) 과 get_session_history() 으로 히스토리 저장 및 입력

4. create_stuff_documents_chain 를 통해 context 문서와 질문을 통해 LLM 답변을 생성

get_qa_chain() 에서 정규화된 질문을 히스토리 인지형 RAG 파이프라인을 사용하여 조립하는데
흐름은 아래와 같다.

  • get_history_retrieval() 함수를 통해 create_history_aware_retriever 으로
    채팅 기록과 새 질문을 입력받아 단독 질문으로 재작성하고 문서 검색을 수행한다.
  • answer_prompt 를 통해 최종 답변 프롬프트를 생성한다.
  • create_stuff_documents_chain() 으로 문서를 결합한다.
  • create_retriever_chain() 으로 최종 RAG 체인을 조립한다.
  • RunnableWithMessageHistory 를 활용하여 채팅 메시지 기록을 관리한다.
  • 이후 rag_with_history 에 최종 RAG 와 get_session_history 를 활용하여 결과를 받아옴.
llm_with_History.py
# 세션 히스토리 저장소
store = {}


def get_llm():
    llm = ChatUpstage()
    return llm

def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]

def get_dictionary_chain():
    llm = get_llm()
    dictionary = ["사람을 나타내는 표현은 모두 거주자로 변경해주세요"]

    prompt = ChatPromptTemplate.from_template(
        f"""
        사용자의 질문을 보고, 우리의 사전을 참고해서 사용자의 질문을 변경해주세요.
        변경할 필요가 없다고 판단되면, 질문 원문을 그대로 반환하세요.

        추가로 초과 금액은 단순하게 사용자가 제시한 금액에서 기준 금액을 뺀 값으로 계산해주세요.

        사전: {dictionary}
        질문: {{question}}
        """
        )
    return prompt | llm | StrOutputParser()


def get_history_retriever():
    llm = get_llm()
    retriever = get_retriever()
    
    contextualize_q_system_prompt = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, "
        "formulate a standalone question which can be understood "
        "without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )

    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    
    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, contextualize_q_prompt
    )
    return history_aware_retriever


def get_qa_chain():
    llm = get_llm()
    history_aware_retriever = get_history_retriever()

    # 2) 최종 답변 프롬프트 (stuff) – {context} + {input}
    answer_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "너는 소득세 도우미야. 아래 컨텍스트만 사용해 간결하고 정확히 답해.\n"
                "<context>\n{context}\n</context>",
            ),
            MessagesPlaceholder("chat_history"),
            ("user", "{input}"),
        ]
    )
    combine_docs_chain = create_stuff_documents_chain(llm=llm, prompt=answer_prompt)

    # 3) 최종 RAG 체인  (⚠️ 인자명: combine_docs_chain)
    rag_chain = create_retrieval_chain(
        retriever=history_aware_retriever,
        combine_docs_chain=combine_docs_chain,
    )
    return rag_chain


def get_ai_message(user_question: str, session_id: str = "abc123"):
    """
    - 사전 변환 → 히스토리 인지형 RAG 스트리밍
    - 문자열 제너레이터를 반환 (Streamlit의 st.write_stream에 바로 전달 가능)
    """
    # 1) 질문 정규화
    normalized_q = get_dictionary_chain().invoke({"question": user_question})

    # 2) 체인 + 자동 히스토리 래핑
    rag_chain = get_qa_chain()
    rag_with_history = RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )

5. LLM답변 중 pick("answer").stream(..) 으 로 streamlit 에 스트리밍 출력.

마지막으로 이렇게 완성한 값 중 "answer"을 pick 하여 stream 으로 streamlit 에 retrurn 한다.

llm_with_History.py
def get_ai_message(user_question: str, session_id: str = "abc123"):
    """
    - 사전 변환 → 히스토리 인지형 RAG 스트리밍
    - 문자열 제너레이터를 반환 (Streamlit의 st.write_stream에 바로 전달 가능)
    """
    # 1) 질문 정규화
    normalized_q = get_dictionary_chain().invoke({"question": user_question})

    # 2) 체인 + 자동 히스토리 래핑
    rag_chain = get_qa_chain()
    rag_with_history = RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )

    # 3) 'answer'만 스트리밍
    return rag_with_history.pick("answer").stream(
        {"input": normalized_q},
        config={"configurable": {"session_id": session_id}},
    )

 

아래와 같이 두번째 질문에서 이전의 대화 내용을 참고하여 1억 직장인의 소득세에 대해서 답변하는 것을 볼 수 있다.

 

6. 최종 코드

chat.py

더보기
import streamlit as st
from llm_with_History import get_ai_message

st.set_page_config(page_title="Chat Application", page_icon="📈")

st.title("소득세 챗봇")
st.caption("소득세 관련 질문을 해보세요!")

if "messages" not in st.session_state:
    st.session_state.messages = []

# 세션에 저장된 메시지를 반복하고 각 메시지 역할에 따라 메시지 표시
print(f"before: {st.session_state.messages}")
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])


if user_question := st.chat_input("소득세와 관련된 궁금한 내용들을 말씀해주세요!"):
    # 채팅이 입력되었을 때 상단에 메시지가 노출되도록 함.
    # "user", "assistant", "ai", "human" 가 있는데 그중 user 을 사용함.
    with st.chat_message("user"):
        st.write(user_question)

    # 채팅 메시지를 Session State에 저장
    st.session_state.messages.append({"role": "user", "content": user_question})


    # 로딩중인 부분을 보여주기
    with st.spinner("AI 응답을 생성하는 중..."):
        ai_message = get_ai_message(user_question)

        # with st.chat_message("ai"):
        #     st.write(ai_message)
        #     st.session_state.messages.append({"role": "ai", "content": ai_message})
        with st.chat_message("ai"):
            final_text = st.write_stream(ai_message)
            # 채팅 메시지를 Session State에 저장
    st.session_state.messages.append({"role": "ai", "content": final_text})

 

llm_with_History.py

더보기
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_upstage import UpstageEmbeddings, ChatUpstage
from langchain_pinecone import PineconeVectorStore

# (선택) 자동 히스토리 관리용
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from dotenv import load_dotenv

load_dotenv() 
# 세션 히스토리 저장소
store = {}


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]


def get_retriever():
    embeddings = UpstageEmbeddings(model="solar-embedding-1-large")
    index_name = "tax-index-markdown"
    database = PineconeVectorStore.from_existing_index(
        index_name=index_name,
        embedding=embeddings,
    )
    # 필요 시 k 조정
    retriever = database.as_retriever(search_kwargs={"k": 1})
    return retriever


def get_llm():
    llm = ChatUpstage()
    return llm


def get_dictionary_chain():
    llm = get_llm()
    dictionary = ["사람을 나타내는 표현은 모두 거주자로 변경해주세요"]

    prompt = ChatPromptTemplate.from_template(
        f"""
        사용자의 질문을 보고, 우리의 사전을 참고해서 사용자의 질문을 변경해주세요.
        변경할 필요가 없다고 판단되면, 질문 원문을 그대로 반환하세요.

        추가로 초과 금액은 단순하게 사용자가 제시한 금액에서 기준 금액을 뺀 값으로 계산해주세요.

        사전: {dictionary}
        질문: {{question}}
        """
        )

    return prompt | llm | StrOutputParser()

def get_history_retriever():
    llm = get_llm()
    retriever = get_retriever()
    
    contextualize_q_system_prompt = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, "
        "formulate a standalone question which can be understood "
        "without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )

    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    
    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, contextualize_q_prompt
    )
    return history_aware_retriever


def get_qa_chain():
    llm = get_llm()
    history_aware_retriever = get_history_retriever()

    # 2) 최종 답변 프롬프트 (stuff) – {context} + {input}
    answer_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "너는 소득세 도우미야. 아래 컨텍스트만 사용해 간결하고 정확히 답해.\n"
                "<context>\n{context}\n</context>",
            ),
            MessagesPlaceholder("chat_history"),
            ("user", "{input}"),
        ]
    )
    combine_docs_chain = create_stuff_documents_chain(llm=llm, prompt=answer_prompt)

    # 3) 최종 RAG 체인  (⚠️ 인자명: combine_docs_chain)
    rag_chain = create_retrieval_chain(
        retriever=history_aware_retriever,
        combine_docs_chain=combine_docs_chain,
    )
    return rag_chain


def get_ai_message(user_question: str, session_id: str = "abc123"):
    """
    - 사전 변환 → 히스토리 인지형 RAG 스트리밍
    - 문자열 제너레이터를 반환 (Streamlit의 st.write_stream에 바로 전달 가능)
    """
    # 1) 질문 정규화
    normalized_q = get_dictionary_chain().invoke({"question": user_question})

    # 2) 체인 + 자동 히스토리 래핑
    rag_chain = get_qa_chain()
    rag_with_history = RunnableWithMessageHistory(
        rag_chain,                            # 실제 실행할 체인 (여기서는 RAG 체인)
        get_session_history,                  # 세션별 히스토리 불러오는 함수
        input_messages_key="input",           # 입력 프롬프트에서 사용자가 친 질문이 담기는 키
        history_messages_key="chat_history",  # 대화 히스토리(대화 로그)를 넘길 때 쓰는 키
        output_messages_key="answer",         # 체인 실행 후 나오는 답변이 저장되는 키
    )


    # 3) 'answer'만 스트리밍
    return rag_with_history.pick("answer").stream(
        {"input": normalized_q},
        config={"configurable": {"session_id": session_id}},
    )