ChatGPT 같은 대형 언어 모델들도 특정 상황에서는 추가적인 학습 데이터가 필요할 때가 있는데, 이를 해결하기 위한 방법이 few-shot 학습이다. Few-shot 학습은 적은 수의 예시만으로도 모델이 새로운 문제에 잘 적응할 수 있게 도와주는 기술이다. 이 글에서는 Python의 Langchain 라이브러리를 사용하여, few-shot 학습을 AI 채팅 모델에 적용하고, 모델의 성능을 높이는 과정을 정리해본다. 같은 질문을 했을때, few-shot 학습전과 후의 AI모델 응답을 비교했는데 의도한대로 잘 나와서 놀랐다.
Few-shot 학습을 하지 않았을 때,
Few-shot 학습하지 않았을때, 어떤 결과가 나오는지 우선 살펴보기로 한다. 먼저 오늘 실행에 필요한 라이브러리를 정의한다.
ChatOllama
: Ollama 모델을 사용한 채팅 인터페이스.ChatPromptTemplate
: 프롬프트를 구성하기 위한 템플릿 생성.FewShotChatMessagePromptTemplate
: few-shot 학습을 위한 프롬프트 템플릿.
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate
참고로 이 포스팅에서 사용한 langchain 라이브러리 버전은 아래와 같다.
import langchain_core, langchain_ollama
print("langchain_core : " + langchain_core.__version__)
print("langchain_ollama : " + langchain_ollama.__version__)
# 출력결과
langchain_core : 0.3.9
langchain_ollama : 0.2.0
Ollama 로 설치한 모델중에 gemma2:9b
를 사용할 예정이다. 2b 모델도 사용해봤는데 few shot 학습이 잘 먹히지 않아 9b 로 진행했다. temperature=0
으로 모델이 예측할 때 출력의 다양성을 낮추도록 설정했다. Docker 를 사용하다보니 base_url
정보를 추가로 설정했는데, 생략해도 된다.
llm = ChatOllama(
model="gemma2:9b",
temperature=0,
base_url="http://host.docker.internal:11434"
)
프롬프트 템플릿을 정의한다. 시스템 메시지는 “수학의 놀라운 마법사"라는 컨셉을 제공하고, 사용자의 입력을 받을 준비를 한다.
first_prompt = ChatPromptTemplate.from_messages(
[
("system", "You are a wondrous wizard of math."),
("human", "{input}")
]
)
프롬프트와 모델을 연결한 체인을 생성한다. “What is 8 🍎 9?” 라는 입력을 모델에 전달한다. 🍎 에 대한 정의가 전혀 없으므로, 어떻게 추론할지 궁금해진다.
chain1 = first_prompt | llm
ai_msg1 = chain1.invoke({"input": "What is 8 🍎 9?"})
별다른 학습이 없다보니 LLM 모델에서는 마음대로 “8 🍎 9"에서 “🍎“는 더하기를 의미한다고 생각하고 있고, 8 + 9 = 17 이라는 답변을 생성했다.
ai_msg1.content
# 출력결과
'Ah, a delightful riddle! \n\nIn the realm of numbers, "🍎" usually means addition. So, 8 🍎 9 is simply:\n\n8 + 9 = 17 \n\n\nLet me know if you have any other magical mathematical puzzles for me! ✨🧮✨'
Few-shot 학습을 시켜보자.
모델에게 제공할 few-shot 학습 예시를 설정한다. 3개의 예시를 만들었고, “🍎“를 곱셈으로 해석한 결과를 제공했다.
examples = [
{"input": "2 🍎 4", "output": "8"},
{"input": "3 🍎 5", "output": "15"},
{"input": "7 🍎 3", "output": "21"}
]
예시 입력과 출력을 포함하는 프롬프트 템플릿을 생성한다.
example_prompt = ChatPromptTemplate.from_messages(
[
("human", "{input}"),
("ai", "{output}")
]
)
앞서 설정한 예시를 기반으로 few-shot 학습 프롬프트를 구성한다.
few_shot_prompt = FewShotChatMessagePromptTemplate(
example_prompt=example_prompt,
examples=examples
)
시스템 메시지와 few-shot 예시를 결합하여 최종 프롬프트를 생성한다.
final_prompt = ChatPromptTemplate.from_messages(
[
("system", "You are a wondrous wizard of math."),
few_shot_prompt,
("human", "{input}")
]
)
최종 프롬프트와 모델을 연결하여 새로운 체인을 생성하고, 입력을 통해 결과를 받았다.
chain2 = final_prompt | llm
ai_msg2 = chain2.invoke({"input": "What is 8 🍎 9?"})
그 결과는? 모델은 “🍎“를 곱셈으로 잘 해석하였고, 8 * 9 = 72 라는 답변을 반환했다.
ai_msg2
# 출력결과
'72 \n\nRemember, "🍎" means multiplication! 😊 \n',