AI 모델 정확도 높이기: Langchain과 Few-shot 학습으로 모델 개선하기

ChatGPT 같은 대형 언어 모델들도 특정 상황에서는 추가적인 학습 데이터가 필요할 때가 있는데, 이를 해결하기 위한 방법이 few-shot 학습이다. Few-shot 학습은 적은 수의 예시만으로도 모델이 새로운 문제에 잘 적응할 수 있게 도와주는 기술이다. 이 글에서는 Python의 Langchain 라이브러리를 사용하여, few-shot 학습을 AI 채팅 모델에 적용하고, 모델의 성능을 높이는 과정을 정리해본다. 같은 질문을 했을때, few-shot 학습전과 후의 AI모델 응답을 비교했는데 의도한대로 잘 나와서 놀랐다.

Few-shot 학습을 하지 않았을 때,

Few-shot 학습하지 않았을때, 어떤 결과가 나오는지 우선 살펴보기로 한다. 먼저 오늘 실행에 필요한 라이브러리를 정의한다.

  • ChatOllama: Ollama 모델을 사용한 채팅 인터페이스.
  • ChatPromptTemplate: 프롬프트를 구성하기 위한 템플릿 생성.
  • FewShotChatMessagePromptTemplate: few-shot 학습을 위한 프롬프트 템플릿.
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate

참고로 이 포스팅에서 사용한 langchain 라이브러리 버전은 아래와 같다.

import langchain_core, langchain_ollama
print("langchain_core : " + langchain_core.__version__)
print("langchain_ollama : " + langchain_ollama.__version__)

# 출력결과 
langchain_core : 0.3.9
langchain_ollama : 0.2.0

Ollama 로 설치한 모델중에 gemma2:9b 를 사용할 예정이다. 2b 모델도 사용해봤는데 few shot 학습이 잘 먹히지 않아 9b 로 진행했다. temperature=0 으로 모델이 예측할 때 출력의 다양성을 낮추도록 설정했다. Docker 를 사용하다보니 base_url 정보를 추가로 설정했는데, 생략해도 된다.

llm = ChatOllama(
    model="gemma2:9b",
    temperature=0,
    base_url="http://host.docker.internal:11434"
)

프롬프트 템플릿을 정의한다. 시스템 메시지는 “수학의 놀라운 마법사"라는 컨셉을 제공하고, 사용자의 입력을 받을 준비를 한다.

first_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a wondrous wizard of math."),
        ("human", "{input}")
    ]
)

프롬프트와 모델을 연결한 체인을 생성한다. “What is 8 🍎 9?” 라는 입력을 모델에 전달한다. 🍎 에 대한 정의가 전혀 없으므로, 어떻게 추론할지 궁금해진다.

chain1 = first_prompt | llm
ai_msg1 = chain1.invoke({"input": "What is 8 🍎 9?"})

별다른 학습이 없다보니 LLM 모델에서는 마음대로 “8 🍎 9"에서 “🍎“는 더하기를 의미한다고 생각하고 있고, 8 + 9 = 17 이라는 답변을 생성했다.

ai_msg1.content

# 출력결과
'Ah, a delightful riddle!  \n\nIn the realm of numbers, "🍎" usually means addition. So, 8 🍎 9 is simply:\n\n8 + 9 = 17 \n\n\nLet me know if you have any other magical mathematical puzzles for me! ✨🧮✨'

Few-shot 학습을 시켜보자.

모델에게 제공할 few-shot 학습 예시를 설정한다. 3개의 예시를 만들었고, “🍎“를 곱셈으로 해석한 결과를 제공했다.

examples = [
    {"input": "2 🍎 4", "output": "8"},
    {"input": "3 🍎 5", "output": "15"},
    {"input": "7 🍎 3", "output": "21"}
]

예시 입력과 출력을 포함하는 프롬프트 템플릿을 생성한다.

example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{input}"),
        ("ai", "{output}")
    ]
)

앞서 설정한 예시를 기반으로 few-shot 학습 프롬프트를 구성한다.

few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    examples=examples
)

시스템 메시지와 few-shot 예시를 결합하여 최종 프롬프트를 생성한다.

final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a wondrous wizard of math."),
        few_shot_prompt,
        ("human", "{input}")
    ]
)

최종 프롬프트와 모델을 연결하여 새로운 체인을 생성하고, 입력을 통해 결과를 받았다.

chain2 = final_prompt | llm
ai_msg2 = chain2.invoke({"input": "What is 8 🍎 9?"})

그 결과는? 모델은 “🍎“를 곱셈으로 잘 해석하였고, 8 * 9 = 72 라는 답변을 반환했다.

ai_msg2

# 출력결과
'72 \n\nRemember, "🍎" means multiplication!  😊  \n',

더 보면 좋을 글들