20 Sep 2024
Introduction
Fine-tuning large models, especially in deep learning, can be computationally expensive and slow. A common challenge when working with large language models (LLMs) or other neural networks is how to adapt or fine-tune these models without updating all the parameters. Fine-tuning large matrices in these models can be resource-intensive due to the sheer size of the weight matrices.
One effective solution to this problem is the LoRA (Low-Rank Adaptation) technique. LoRA allows for a more efficient fine-tuning process by decomposing large weight matrices into smaller, low-rank matrices. This not only reduces the number of trainable parameters but also preserves the integrity of the pretrained model.
In this blog post, we’ll break down the LoRA technique step by step, using a simple Python code example to show how it works. By the end of the article, you’ll have a clear understanding of how LoRA can be used to fine-tune large matrices efficiently.
The Problem with Fine-Tuning Large Models
When you fine-tune a large model, you typically adjust its weight matrices based on new training data. In modern neural networks, these weight matrices can be massive. For example, a single layer may have a weight matrix \(W\) whose dimensions run into the hundreds or even thousands.
Fine-tuning this large matrix directly has two major challenges:
- Memory Constraints: The size of these matrices means that storing and updating them can be expensive in terms of memory.
- Overfitting Risk: If you update all parameters, especially when your new training dataset is small, you risk overfitting to the new data and losing the general knowledge from the pretrained model.
Enter LoRA: Fine-Tuning with Low-Rank Matrices
LoRA provides an elegant solution to these challenges by freezing the original weight matrix \(W\) and introducing two smaller matrices, \(A\) and \(B\), which have a much lower rank. The idea is that instead of updating \(W\) directly, you approximate its updates using the product of two smaller matrices.
Here’s how LoRA works:
- Keep the original weight matrix \(W\) fixed. This ensures that the model retains the general knowledge from pretraining.
- Introduce two low-rank matrices:
  - \(A\) with dimensions \((d_{\text{out}}, r)\)
  - \(B\) with dimensions \((r, d_{\text{in}})\)
  where \(r\) is a small integer that represents the rank. \(r\) is much smaller than either \(d_{\text{out}}\) or \(d_{\text{in}}\), the dimensions of the original weight matrix \(W\).
- Compute the low-rank update: The update to \(W\) is given by the product of these two matrices:
  \(\Delta W = A \times B\)
- Add the update to the original matrix: The final weight matrix used during training becomes:
  \(W_{\text{new}} = W + A \times B\)
This allows the model to adjust its weights for the new task without needing to update all parameters in the original weight matrix.
By learning these two smaller matrices, LoRA drastically reduces the number of trainable parameters, making the fine-tuning process much more efficient.
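To make the savings concrete, here is a quick back-of-the-envelope check. The layer size below is hypothetical and chosen only for illustration, not taken from any particular model: for a \(4096 \times 4096\) weight matrix with rank \(r = 8\), full fine-tuning updates about 16.8 million parameters, while LoRA trains only about 65 thousand.
# Rough parameter-count comparison for a single layer (illustrative numbers only)
d_out, d_in, r = 4096, 4096, 8
full_params = d_out * d_in          # parameters updated by full fine-tuning
lora_params = d_out * r + r * d_in  # parameters trained by LoRA (A and B)
print(full_params)                  # 16777216
print(lora_params)                  # 65536
print(lora_params / full_params)    # 0.00390625, i.e. about 0.4%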
Step-by-Step LoRA Algorithm
Let’s break down the LoRA algorithm step by step:
Step 1: Initialization
- Start with a pretrained model: We begin with a model that has already been trained on a large dataset. The weight matrices in this model, like \(W\), are already optimized for general tasks.
- Choose a low rank \(r\): The rank \(r\) is a small number that dictates the size of the low-rank matrices \(A\) and \(B\). For example, if \(W\) is a \(6 \times 6\) matrix, we might choose \(r = 2\).
Step 2: LoRA Decomposition
- Decompose the update: Instead of updating \(W\) directly, we introduce two smaller matrices, \(A\) and \(B\). These matrices are of much lower rank than \(W\), making them much smaller and easier to train.
- Freeze the original weights: The original matrix \(W\) is frozen during fine-tuning. This means its values are not changed, ensuring that the general knowledge captured during pretraining is preserved.
Step 3: Training
- Train \(A\) and \(B\): During the fine-tuning process, only the two small matrices, \(A\) and \(B\), are trained. This allows the model to adapt to new data without requiring updates to the entire weight matrix.
- Apply the low-rank update: The update matrix \(\Delta W = A \times B\) is added to the original matrix \(W\). This small update allows the model to fine-tune itself for the new task.
Step 4: Inference
- Inference with LoRA: At inference time, you can either compute the update \(A \times B\) on the fly or precompute the new matrix \(W_{\text{new}} = W + A \times B\). The latter option is often more efficient, as it allows you to use the updated weight matrix just like a normal layer in the model.
Do We Need \(A\) and \(B\) During Inference?
This brings up an important question: Do we need \(A\) and \(B\) during inference?
The answer depends on whether you want to compute the update matrix \(A \times B\) during inference or precompute it after training.
- Compute on the fly: If you compute \(A \times B\) during inference, you will need both \(A\) and \(B\) at inference time. This allows you to dynamically adjust the weights as needed.
- Precompute \(W_{\text{new}}\): Alternatively, you can precompute \(W_{\text{new}} = W + A \times B\) after training is complete. In this case, you no longer need \(A\) and \(B\) during inference, as you’re using the updated weight matrix directly. This method is computationally more efficient during inference; a quick check of this equivalence is sketched below.
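As a quick sanity check, here is a minimal NumPy sketch, using the same \(A\) and \(B\) shapes as the example below (this is illustrative code, not part of any particular LoRA implementation). It shows that applying \(A\) and \(B\) on the fly gives the same result as using the precomputed merged matrix:
import numpy as np
d_out, d_in, r = 6, 6, 2
W = np.random.randn(d_out, d_in)   # frozen pretrained weights
A = np.random.randn(d_out, r)      # low-rank factor A
B = np.random.randn(r, d_in)       # low-rank factor B
x = np.random.randn(d_in)          # an example input vector
on_the_fly = W @ x + A @ (B @ x)   # keep A and B separate at inference time
merged = (W + A @ B) @ x           # precompute W_new = W + A*B once
print(np.allclose(on_the_fly, merged))  # Outputs: True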
Python Example: Applying LoRA to a Simple Matrix
Now that we’ve explained how LoRA works, let’s look at a simple Python example that demonstrates the LoRA technique using random matrices. This example shows how to update a large matrix using the product of two smaller, low-rank matrices.
import numpy as np
# Step 1: Initialize the large matrix W (e.g., a 6x6 matrix)
d_out = 6 # Output dimension
d_in = 6 # Input dimension
# Simulate the large matrix W (initialized randomly)
W = np.random.randn(d_out, d_in)
print("Original W matrix:")
print(W)
# Step 2: Initialize low-rank matrices A and B
r = 2 # Rank of the low-rank approximation
# A has dimensions (d_out, r)
A = np.random.randn(d_out, r)
# B has dimensions (r, d_in)
B = np.random.randn(r, d_in)
# Step 3: Compute the update matrix (A * B)
delta_W = np.dot(A, B) # This gives us the low-rank update
# Step 4: Update the original matrix W
W_new = W + delta_W # New matrix after LoRA update
print("\nUpdate matrix (A * B):")
print(delta_W)
print("\nUpdated W matrix:")
print(W_new)
Output:
Original W matrix:
[[ 0.57982682 0.56981892 -0.31648517 0.74758125 -0.31424568 -1.25113497]
[-0.45936964 -1.21938562 0.34957394 0.38277855 0.66722387 0.71315171]
...
Update matrix (A * B):
[[ 0.25153752 -0.31847642 0.27297871 0.16230158 -0.20924836 0.45321087]
[-0.47939022 0.38251002 -0.21547047 -0.27092739 0.32487688 -0.51983357]
...
Updated W matrix:
[[ 0.83136434 0.2513425 -0.04350646 0.90988283 -0.52349405 -0.7979241 ]
[-0.93875986 -0.8368756 0.13410347 0.11185116 0.99210075 0.19331814]
Conclusion
The LoRA technique is a powerful and efficient way to fine-tune large models without needing to update all their parameters. By learning low-rank matrices \(A\) and \(B\), we can significantly reduce the computational and memory costs of fine-tuning while still allowing the model to adapt to new tasks.
05 Sep 2024
Dunder methods (short for “double underscore methods”) are special methods in Python that have double underscores (__) before and after their names. They define how objects of a class behave when used with built-in operations such as print(), indexing, comparisons, or mathematical operations. Here’s a breakdown of some common dunder methods and how you can use them to customize your Python objects.
1. __init__ – The Constructor
The __init__ method is the constructor of a class and is called when you create a new instance. This is where you can initialize attributes for the object.
class Person:
    def __init__(self, name):
        self.name = name

p = Person("Sarah")
print(p.name)  # Outputs: Sarah
Explanation: In this example, we define a Person class with a constructor that takes a name as an argument. When an instance of the Person class is created, the __init__ method is called to initialize the name attribute.
2. __str__ – String Representation for Humans
The __str__ method controls how an object is printed or represented as a string using print() or str().
class Person:
    def __init__(self, name):
        self.name = name

    def __str__(self):
        return f"Person: {self.name}"

p = Person("Sarah")
print(p)  # Outputs: Person: Sarah
Explanation: Here, the __str__ method provides a human-readable representation of the object. When you print the object, it displays "Person: Sarah" instead of the default representation like <Person object at 0x...>.
3. __repr__ – String Representation for Developers
The __repr__ method is like __str__ but is meant to provide a detailed, unambiguous string that can be used for debugging.
class Person:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"Person({self.name!r})"

p = Person("Sarah")
print(repr(p))  # Outputs: Person('Sarah')
Explanation: The __repr__ method provides a string that shows exactly how to recreate the object, making it helpful for debugging.
4. __len__ – Define Object Length
If you want to define a length for your object, implement the __len__ method, which is used by the len() function.
class Group:
    def __init__(self, members):
        self.members = members

    def __len__(self):
        return len(self.members)

g = Group(["Sarah", "John", "Alice"])
print(len(g))  # Outputs: 3
Explanation: The __len__ method returns the length of the group, which is the number of members. Here, len(g) gives us 3.
5. __getitem__ – Access Items with Indexing
The __getitem__ method allows your object to be indexed like a list or dictionary.
class Group:
    def __init__(self, members):
        self.members = members

    def __getitem__(self, index):
        return self.members[index]

g = Group(["Sarah", "John", "Alice"])
print(g[1])  # Outputs: John
Explanation: The __getitem__ method defines how the object should behave when accessed with square brackets (e.g., g[1]).
6. __setitem__ – Set Items with Indexing
The __setitem__ method defines how an object’s item can be updated via indexing.
class Group:
    def __init__(self):
        self.members = {}

    def __setitem__(self, key, value):
        self.members[key] = value

g = Group()
g[0] = "Sarah"
print(g.members)  # Outputs: {0: 'Sarah'}
Explanation: Here, we can assign values to g using indexing, and the __setitem__ method updates the members dictionary.
7. __delitem__ – Delete Items with Indexing
The __delitem__ method allows you to delete an item from the object using del.
class Group:
    def __init__(self, members):
        self.members = members

    def __delitem__(self, index):
        del self.members[index]

g = Group(["Sarah", "John", "Alice"])
del g[1]
print(g.members)  # Outputs: ['Sarah', 'Alice']
Explanation: The __delitem__ method is called when you delete an item from the object using del.
8. __call__ – Make Objects Callable
If you want an object to behave like a function, implement the __call__ method.
class Greet:
    def __call__(self, name):
        return f"Hello, {name}!"

greet = Greet()
print(greet("Sarah"))  # Outputs: Hello, Sarah!
Explanation: Here, the Greet object can be called like a function because of the __call__ method.
9. __eq__ – Equality Comparison
The __eq__ method defines how two objects should be compared for equality using ==.
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        return self.x == other.x and self.y == other.y

p1 = Point(1, 2)
p2 = Point(1, 2)
print(p1 == p2)  # Outputs: True
Explanation: The __eq__ method compares two Point objects for equality by checking their x and y values.
10. __lt__ – Less Than Comparison
The __lt__ method allows you to define how one object should be compared to another using the less-than operator (<).
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __lt__(self, other):
        return (self.x + self.y) < (other.x + other.y)

p1 = Point(1, 2)
p2 = Point(2, 3)
print(p1 < p2)  # Outputs: True
Explanation: The __lt__ method compares two Point objects by their combined x and y values.
11. __add__ – Define Addition Behavior
The __add__ method allows you to define how objects should behave when added with the + operator.
class Vector:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __add__(self, other):
        return Vector(self.x + other.x, self.y + other.y)

    def __repr__(self):
        return f"Vector({self.x}, {self.y})"

v1 = Vector(1, 2)
v2 = Vector(3, 4)
print(v1 + v2)  # Outputs: Vector(4, 6)
Explanation: The __add__ method enables the addition of two Vector objects by adding their x and y values together.
12. __enter__ and __exit__ – Context Managers
These methods define the behavior of an object when used in a with statement, allowing for setup and teardown logic.
class File:
    def __init__(self, filename, mode):
        self.filename = filename
        self.mode = mode

    def __enter__(self):
        self.file = open(self.filename, self.mode)
        return self.file

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

with File("example.txt", "w") as f:
    f.write("Hello, World!")
Explanation: The __enter__ method opens the file, and the __exit__ method ensures the file is closed, even if an error occurs.
13. __contains__ – Define in Keyword Behavior
The __contains__ method defines how the in keyword works for your object.
class Group:
    def __init__(self, members):
        self.members = members

    def __contains__(self, item):
        return item in self.members

g = Group(["Sarah", "John", "Alice"])
print("John" in g)  # Outputs: True
Explanation: The __contains__ method allows you to check if an item is in the Group using the in keyword.
14. __iter__ – Make Objects Iterable
The __iter__ method makes your object iterable so it can be used in a for loop.
class Group:
    def __init__(self, members):
        self.members = members

    def __iter__(self):
        return iter(self.members)

g = Group(["Sarah", "John", "Alice"])
for member in g:
    print(member)
# Outputs: Sarah, John, Alice
Explanation: The __iter__ method allows the Group object to be used in a loop to iterate over its members.
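Finally, here is a small sketch that pulls several of these methods together in one class. The Playlist class below is a hypothetical example written for this recap, not one of the snippets above.
class Playlist:
    def __init__(self, songs):
        self.songs = list(songs)

    def __repr__(self):
        return f"Playlist({self.songs!r})"

    def __len__(self):
        return len(self.songs)

    def __getitem__(self, index):
        return self.songs[index]

    def __contains__(self, song):
        return song in self.songs

    def __add__(self, other):
        return Playlist(self.songs + other.songs)

p = Playlist(["Track A"]) + Playlist(["Track B", "Track C"])
print(p)               # Outputs: Playlist(['Track A', 'Track B', 'Track C'])
print(len(p))          # Outputs: 3
print(p[0])            # Outputs: Track A
print("Track B" in p)  # Outputs: True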
These are just some of the many dunder methods available in Python. They give you a lot of control over how objects behave and interact with Python’s built-in functionality, making your code more flexible and powerful.
You can execute all the code in this Google Colab notebook.
05 Sep 2024
Extracting complex data such as one or multiple dates from text can be tedious and time-consuming, especially when dealing with long articles or documents. Fortunately, using language models like Llama, we can automate this process. This blog post will show you how to extract dates from a passage of text using the Llama model with Ollama, and we’ll walk through the code step-by-step.
Step 1: Install and Set Up Ollama
First, you need to install Ollama, a tool that allows you to run large language models, like Llama, locally. Run the following command to install it:
curl -fsSL https://ollama.com/install.sh | sh
After installation, you can run the Llama model locally. For the model used in this post, that would typically be a command along the lines of ollama run llama3.1 (assuming a standard Ollama setup).
This sets up the environment for using Llama to handle the text extraction task.
Step 2: Define the Date and Dates Classes
Before jumping into the extraction process, we need to structure our data. The Date class represents individual dates, while the Dates class holds a list of these dates.
from typing import List
from pydantic import BaseModel, Field

class Date(BaseModel):
    """
    Represents a date with day, month, and year.
    """
    day: int = Field(description="Day of the month")
    month: int = Field(description="Month of the year")
    year: int = Field(description="Year")

class Dates(BaseModel):
    """
    Represents a list of dates (Dates class).
    """
    dates: List[Date] = Field(..., description="List of dates.")
These classes allow us to store each date with the day, month, and year clearly defined.
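As a quick illustration of how a parsed JSON response maps onto these models, a dictionary with a 'dates' key can be unpacked straight into Dates. The values below are made up for illustration; in the extractor they come from the model's output.
# Hypothetical parsed output; in the extractor this comes from the LLM chain.
response = {"dates": [{"day": 29, "month": 1, "year": 1886}]}
dates = Dates(**response)    # Pydantic validates the nested day/month/year fields
print(dates.dates[0].day)    # Outputs: 29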
Next, we create the DatesExtractor class, which handles the process of extracting dates from the text. The class initializes a prompt, which instructs the Llama model to find and return the dates in a structured JSON format. Here’s how the class is structured:
import logging

# These import paths assume the langchain-core and langchain-ollama packages.
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama

logger = logging.getLogger(__name__)

class DatesExtractor:
    """
    Extracts dates from a given passage and returns them in JSON format.
    """
    def __init__(self, model_name: str = "llama3.1", temperature: float = 0):
        self.prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "You are a helpful AI assistant. Extract dates from the given passage and return them in JSON format.",
                ),
                (
                    "user",
                    "Identify the dates from the passage and extract the day, month and year. Return the dates as a JSON object with a 'dates' key containing a list of date objects. Each date object should have 'day', 'month', and 'year' keys. Return the number of month, not name. Here is the text: {passage}",
                ),
            ]
        )
        self.llm = ChatOllama(model=model_name, temperature=temperature)
        self.output_parser = JsonOutputParser()
        self.chain = self.prompt | self.llm | self.output_parser
In this section, we define the AI’s behavior through a prompt. The AI is instructed to extract dates from the text, ensuring that each date is formatted with day, month (as a number), and year, and returned as a JSON object.
Once the DatesExtractor class is set up, we define the extract method, which takes a passage of text as input, processes it through the Llama model, and returns the extracted dates.
    def extract(self, passage: str) -> Dates:
        try:
            response = self.chain.invoke({"passage": passage})
            return Dates(**response)
        except Exception as e:
            logger.error(f"Error in extracting dates: {e}")
            return None
The extract method calls the AI model with the given text and captures the result. It returns the dates in a structured format, or logs an error if something goes wrong during processing.
To see the code in action, let’s extract dates from a passage that contains various date formats:
if __name__ == "__main__":
    passage = "The history of the automobile is marked by key milestones. Karl Benz built the first car on 29/01/1886, revolutionizing transportation. Later, on June 8th, 1948, Porsche unveiled its iconic 356 model. In 1964-04-17, Ford introduced the Mustang, forever changing the sports car industry. Electric cars gained momentum, with Tesla's Model S launched on July 22, 2012."
    extractor = DatesExtractor()
    dates: Dates = extractor.extract(passage)
    for date in dates.dates:
        print(date)
In this passage, we have dates in different formats, such as dd/mm/yyyy, yyyy-mm-dd, and written-out month names. When we run the extractor, the AI will return each of these dates in a standardized JSON format.
The output of this code looks like:
day=29 month=1 year=1886
day=8 month=6 year=1948
day=17 month=4 year=1964
day=22 month=7 year=2012
Conclusion
Using the combination of Ollama and Llama, extracting dates from text is simple and efficient. By structuring the data using Pydantic models and leveraging powerful AI tools, we can easily handle date extraction tasks in various formats. Whether you’re working with historical data, articles, or other documents, this approach allows for a smooth and scalable solution to date extraction.
Code
The full code is available here
13 Mar 2024
Brissy: The Australian Slang Chatbot
Ever wanted to chat like a true Aussie? Now you can with Brissy, the Australian slang chatbot. Powered by ChatGPT, Brissy provides answers on everything from Australian culture to general queries—all in authentic Aussie slang. Whether you’re curious about Aussie traditions or simply want to experience a casual, fun conversation in local lingo, Brissy has you covered.
This project demonstrates the versatility of ChatGPT in creating engaging, user-friendly chatbots. Developed with LangChain, Brissy uses the ChatGPT API to deliver responses in Australian slang. It’s an open-source project, and the code is available on GitHub. The chatbot is hosted on Hugging Face Spaces, making it accessible to anyone interested in Australian slang.
Give it a try: Brissy.
TabgpT: A Chrome Extension to Interact with Your Tabs
TabgpT is a Chrome extension that lets you interact with your open browser tabs through natural language queries. Powered by ChatGPT, TabgpT can answer questions based on the content of your tabs, making it easier than ever to find the information you need without navigating between pages.
Built using LangChain, TabgpT uses the ChatGPT API to extract information from open tabs. The extension leverages AWS Lambda and ECR in the backend, along with a Retrieval Augmented Generation (RAG) pipeline to answer questions. TabgpT stores data in a Chroma vector store using OpenAI embeddings. When you ask a question, relevant content chunks are retrieved with Maximum Marginal Relevance (MMR) search and passed to OpenAI’s API to generate an accurate response.
The user interface is designed with Twitter Bootstrap and jQuery for a seamless experience. Available on the Chrome Web Store, TabgpT is your go-to tool for browsing more efficiently.
Give it a try: TabgpT.
06 Oct 2023
Demo https://qmaruf-talk-to-data.hf.space
Have you ever contemplated the possibility of engaging in a conversation with your data? Imagine conversing with a chatbot that possesses comprehensive knowledge about your dataset. This intriguing problem is the focus of this note, where we explore how to achieve it using the ChatGPT API.
To illustrate this concept, we will employ the first book of the Harry Potter series, “Harry Potter and the Philosopher’s Stone,” as our dataset. Our goal is to engage in a conversation with the content of the book. To facilitate this, we will utilize LangChain, a powerful tool for parsing and interacting with data, in conjunction with the ChatGPT API.
Our first step is to load the relevant data from the book. We will use the TextLoader from LangChain to achieve this:
from langchain.document_loaders import TextLoader
book_txt = 'docs/potter1.txt'
loader = TextLoader(book_txt)
docs = loader.load()
Next, we need to break down the text into manageable chunks. This allows us to work with smaller sections of the book at a time. We define the chunk_size and chunk_overlap parameters for this purpose:
chunk_size = 1000
chunk_overlap = 250
In the next step, we will create a text splitter based on RecursiveCharacterTextSplitter. This text splitter is ideal for generic text: it splits on a prioritized list of separator characters, applying them in sequence until the chunks are small enough, which helps keep paragraphs, sentences, and words together for maximum semantic coherence.
from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap,
    length_function=len
)
splits = text_splitter.split_documents(docs)
Here length_function measures the length of given chunks. We will use len as our length_function for this example.
The split_documents method will return a list of Document objects. For example, here is the first Document object from the list. We can print this content using the print method: print(splits[0])
Document(page_content="Harry Potter and the Sorcerer's Stone\n\n\nCHAPTER ONE\n\nTHE BOY WHO LIVED\n\nMr. and Mrs. Dursley, of number four, Privet Drive, were proud to say\nthat they were perfectly normal, thank you very much. They were the last\npeople you'd expect to be involved in anything strange or mysterious,\nbecause they just didn't hold with such nonsense.\n\nMr. Dursley was the director of a firm called Grunnings, which made\ndrills. He was a big, beefy man with hardly any neck, although he did\nhave a very large mustache. Mrs. Dursley was thin and blonde and had\nnearly twice the usual amount of neck, which came in very useful as she\nspent so much of her time craning over garden fences, spying on the\nneighbors. The Dursleys had a small son called Dudley and in their\nopinion there was no finer boy anywhere.", metadata={'source': 'docs/potter1.txt'})
In this step, we will create a vector database to store the embeddings of the chunks. For any query, we search the vector database and extract the most similar chunks to the query. We will use Chroma vector db for this example.
from langchain.vectorstores import Chroma
from langchain.embeddings.openai import OpenAIEmbeddings
embedding = OpenAIEmbeddings()
persist_directory = 'docs/chroma'
vectordb = Chroma.from_documents(
    documents=splits,
    embedding=embedding,
    persist_directory=persist_directory
)
Here embedding is the OpenAI embedding function. persist_directory is the directory to store the embeddings.
We can now search the vector database using the vectordb.max_marginal_relevance_search (MMR) function. MMR returns chunks selected using maximal marginal relevance, which optimizes for similarity to the query and diversity among the selected documents. It takes the query string and returns the k most similar chunks. We will use k=10 for this example.
query = "Write the names of all of Harry Potter's teachers."
answers = vectordb.max_marginal_relevance_search(query, k=10)
Here answers contains the k most similar chunks to the query.
We can check all the chunks using a for loop. Here is the first answer from the list.
Professor Flitwick, the Charms teacher, was a tiny little wizard who had to stand on a pile of books to see over his desk. At the start of their first class, he took the roll call, and when he reached Harry's name, he gave an excited squeak and toppled out of sight. Professor McGonagall was again different. Harry had been quite right to think she wasn't a teacher to cross. Strict and clever, she gave them a talking-to the moment they sat down in her first class.
Up to this point, we have created the vector database and searched the database using a query for relevant documents. Now we will use the ChatGPT API to chat with the content of the book. We will use answers as chat context.
Now using langchain, we will create a ChatOpenAI model to interact with the ChatGPT API.
from langchain.chat_models import ChatOpenAI
llm = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
Here model_name is the name of the model. We will use gpt-3.5-turbo for this example. temperature will control the randomness of the chatbot response.
We will also define a retriever to extract the most similar chunks to the query. It uses MMR search, returning k chunks after re-ranking fetch_k candidates.
retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={'k': 10, 'fetch_k': 50})
Now we will create a RetrievalQA chain using the llm and retriever.
from langchain.chains import RetrievalQA

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever
)
Here qa_chain combines a retriever and a language model to retrieve relevant documents for a query and answer questions based on those documents.
Let’s check how qa_chain performs for a query. We will use the same query that we used earlier.
response = qa_chain({"query": query})
print(response['result'])
Professor Flitwick, Professor McGonagall, Professor Sprout, Professor Binns, Professor Snape, Madam Hooch
We can use a custom prompt to tell qa_chain how we want the answer. Here it is:
from langchain.prompts import PromptTemplate
template = """Use only the following context to answer the question at the end. Always say "thanks for asking!" at the end of the answer. Write a summary of the question and then give the answer.
Context: {context}
Question: {question}
Answer:"""
qa_chain_prompt = PromptTemplate.from_template(template)
We will now fit the template into the qa_chain and check the result.
qa_chain = RetrievalQA.from_chain_type(
    llm,
    retriever=retriever,
    chain_type_kwargs={"prompt": qa_chain_prompt}
)
response = qa_chain({"query": query})
print(response['result'])
The names of all of Harry Potter's teachers are Professor Flitwick, Professor McGonagall, Professor Binns, Professor Snape, Madam Hooch, and Hagrid. Thanks for asking!
qa_chain is able to understand the context of the query and give a reasonable answer.
Until now, qa_chain has no memory. That means we can’t ask any question based on the previous answer. We will use ConversationBufferMemory to create a new type of chain that can remember the previous conversation. Let’s define memory as an instance of ConversationBufferMemory and use it to create a new chain named ConversationalRetrievalChain.
from langchain.memory import ConversationBufferMemory
memory = ConversationBufferMemory(
    memory_key='chat_history',
    return_messages=True
)
Here’s how to create ConversationalRetrievalChain using memory, vectordb, and llm.
from langchain.chains import ConversationalRetrievalChain

qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=vectordb.as_retriever(),
    memory=memory
)
We will ask three related questions to qa_chain and check the result.
q1 = "Write the names of all of Harry Potters Teachers."
q2 = "sort the name of the teacher based on how frequently they are mentioned"
q3 = "tell me more about this professor"
for q in [q1, q2, q3]:
response = qa({'question': q})
print (f'Q: {q}')
print (f'A: {response["answer"]}')
print ('\n')
Q: Write the names of all of Harry Potters Teachers.
A: The names of Harry Potter's teachers mentioned in the given context are:
1. Professor Flitwick (Charms teacher)
2. Professor McGonagall (unknown subject)
3. Professor Sprout (Herbology teacher)
4. Professor Binns (History of Magic teacher)
5. Professor Snape (Potions teacher)
6. Madam Hooch (unknown subject)
7. Professor Quirrell (unknown subject)
Please note that there may be other teachers at Hogwarts that are not mentioned in
this context.
Q: sort the name of the teacher based on how frequently they are mentioned
A: Professor McGonagall is mentioned most frequently in the given context.
Q: tell me more about this professor
A: Professor McGonagall is described as strict and clever. She is a teacher at
Hogwarts School of Witchcraft and Wizardry and teaches Transfiguration, which
is described as complex and dangerous magic. She gives the students a talking-to
in her first class, emphasizing the importance of taking her class seriously.
She is also shown to be observant and recognizes Harry's talent as a Seeker in
Quidditch, recommending him to the Gryffindor team captain. Additionally, she
is a member of the staff and is seen interacting with other professors, such
as Professor Flitwick.
Well, it looks like qa_chain is able to remember the previous conversation and answer follow-up questions based on it.
This is the end of this note. We have seen how to use langchain to create a chatbot that can chat with the content of a book. We have also seen how to create a chatbot that can remember the previous conversation.