I usually build chat tools at work, and I've long felt that I should also understand AI-powered chatbots that answer without a human operator. So recently, I've been experimenting with LangChain.
An article titled "Mitoyo City and Matsuo Lab's Half-Year Journey: Why They 'Gave Up' on ChatGPT for Business Efficiency Despite a 94% Accuracy Rate – Couldn't Trust AI for Garbage Disposal Guidance" mentioned that one of their improvements was displaying responses in real time (streaming). Inspired by this, I decided to implement a similar feature. I referenced the following course and added my own code to it.
Implementation Overview
In this implementation, we use LangChain to process LLM responses and display them in real time. The main components are:

- `ChatOpenAI`: A class for using OpenAI's chat models.
- `ChatPromptTemplate`: A template for formatting chat inputs.
- `LLMChain`: A chain that formats the prompt and passes it to the LLM.
- `StreamingHandler`: A custom callback handler that receives the LLM's response token by token, in real time.
- Threads and queues: Run the LLM call in the background and pass data between threads.
Design
To achieve real-time display, the design includes the following elements:
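Asynchronous Processing: The LLM call runs in a separate Python thread, so the calling code is not blocked while it waits for the response. In an application with a UI (not implemented here), this keeps the interface from freezing and lets other work continue. Because the call spends most of its time waiting on network I/O, Python's GIL is not a bottleneck for this kind of threading.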
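A minimal, LangChain-free sketch of this idea: a hypothetical `fake_llm_call` blocks in a worker thread (a `threading.Event` stands in for the pending network response) while the main thread keeps doing its own work. The events exist only to make the ordering deterministic; a real UI loop would simply keep running.

```python
import threading
from typing import List


def run_with_background_llm() -> List[str]:
    """Simulate a UI that keeps working while a slow 'LLM call' runs in a worker thread."""
    events: List[str] = []
    worker_started = threading.Event()
    response_arrived = threading.Event()  # stands in for the network response arriving

    def fake_llm_call() -> None:
        worker_started.set()
        response_arrived.wait()  # blocks only this worker thread, like a pending HTTP request
        events.append("llm: response received")

    worker = threading.Thread(target=fake_llm_call)
    worker.start()
    worker_started.wait()

    # The main thread is free while the call is pending, e.g. to redraw a UI.
    events.append("ui: still responsive")

    response_arrived.set()  # let the simulated response "arrive"
    worker.join()
    return events


if __name__ == "__main__":
    for line in run_with_background_llm():
        print(line)
```

The main thread records its work before the simulated response arrives, which is exactly what a frozen, single-threaded call would prevent.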
Queues for Data Consistency: Tokens are passed between threads through Python's `queue.Queue`, which handles its own locking, so the worker thread can add tokens while the main thread reads them without race conditions or manual locks. This prevents inconsistent data and hard-to-reproduce errors, which matters in any real-time system.
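To illustrate the thread-safety claim, this sketch has several producer threads put items into one `queue.Queue` while a single consumer drains it, using `None` as a per-producer end marker (the same convention `StreamingHandler` uses). The producer and item counts are arbitrary assumptions. No items are lost or duplicated, and each producer's items arrive in the order it put them, because `Queue` is FIFO and locks internally.

```python
import threading
from queue import Queue
from typing import List, Tuple


def produce(q: Queue, producer_id: int, count: int) -> None:
    for i in range(count):
        q.put((producer_id, i))  # Queue.put is thread-safe; no explicit lock needed
    q.put(None)  # end marker for this producer


def collect(num_producers: int = 4, per_producer: int = 1000) -> List[Tuple[int, int]]:
    q: Queue = Queue()
    threads = [
        threading.Thread(target=produce, args=(q, p, per_producer))
        for p in range(num_producers)
    ]
    for t in threads:
        t.start()

    received: List[Tuple[int, int]] = []
    finished = 0
    while finished < num_producers:
        item = q.get()  # blocks until some producer has put an item
        if item is None:
            finished += 1
        else:
            received.append(item)

    for t in threads:
        t.join()
    return received


if __name__ == "__main__":
    items = collect()
    print(len(items), "items received")
```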
Streaming Processing: By receiving and processing the LLM's response token by token, we can show output to the user as soon as the first tokens arrive instead of waiting for the complete answer.
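A toy sketch of token-by-token display, with a hard-coded token list standing in for the LLM (the token boundaries are made up; real ones depend on the model's tokenizer). The point is that the first visible frame exists after a single token, long before the full answer is complete.

```python
from typing import Iterable, Iterator, List


def fake_token_stream() -> Iterator[str]:
    # Hypothetical tokens; real token boundaries depend on the model's tokenizer.
    yield from ["Poké", "mon", " is", " a", " media", " franchise", "."]


def render_incrementally(tokens: Iterable[str]) -> List[str]:
    """Return the text the user would see after each token arrives."""
    shown = ""
    frames: List[str] = []
    for token in tokens:
        shown += token
        frames.append(shown)  # a UI would redraw here instead of waiting for the full answer
    return frames


if __name__ == "__main__":
    for frame in render_incrementally(fake_token_stream()):
        print(frame)
```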
Implementation
Below, I'll walk through the code (it is available on GitHub). If you have an `OPENAI_API_KEY`, you can also try it out with Docker.
Basic Setup
Import the necessary libraries and call `load_dotenv()` to load environment variables from a `.env` file. This keeps external configuration such as `OPENAI_API_KEY` out of the source code.
```python
import logging
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.chains import LLMChain
from langchain.callbacks.base import BaseCallbackHandler
from dotenv import load_dotenv
from queue import Queue
from threading import Thread

load_dotenv()  # reads OPENAI_API_KEY (and any other settings) from .env into the environment
```
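`load_dotenv()` only copies values from `.env` into `os.environ` (by default it does not overwrite variables that are already set), so a missing or misspelled key is easy to overlook. A small fail-fast check makes the problem obvious at startup. This is a sketch using only the standard library; `require_env` is a hypothetical helper, not part of python-dotenv or LangChain.

```python
import os


def require_env(name: str) -> str:
    """Hypothetical helper: return a required environment variable or fail with a clear message."""
    value = os.environ.get(name, "").strip()
    if not value:
        raise RuntimeError(f"{name} is not set; add it to .env or export it in your shell")
    return value


if __name__ == "__main__":
    # In the article's setup, load_dotenv() would run before this check.
    try:
        key = require_env("OPENAI_API_KEY")
        print(f"OPENAI_API_KEY is set ({len(key)} characters)")
    except RuntimeError as exc:
        print(exc)
```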
StreamingHandler
The `StreamingHandler` class is a custom callback handler that processes LLM responses in real time. LangChain calls its methods whenever the LLM emits a new token, finishes, or raises an error.

- `on_llm_new_token`: Adds each new token to the queue.
- `on_llm_end`: Puts `None` on the queue to signal that processing has ended.
- `on_llm_error`: Logs the error and puts `None` on the queue to signal that processing has ended.
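```python
class StreamingHandler(BaseCallbackHandler):
    """Pushes each streamed token onto a queue; None marks the end of the stream."""

    def __init__(self, queue):
        self.queue = queue

    def on_llm_new_token(self, token, **kwargs):
        # Called once per token; ChatOpenAI only streams tokens when created with streaming=True
        self.queue.put(token)

    def on_llm_end(self, response, **kwargs):
        # End marker: tells the consumer that no more tokens will arrive
        self.queue.put(None)

    def on_llm_error(self, error, **kwargs):
        logging.error(f"Error in LLM: {error}")
        self.queue.put(None)  # also end the stream so the consumer does not wait forever
```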
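Because the handler's job is just a queue protocol (tokens, then a single `None`), it can be checked without calling the OpenAI API. The sketch below uses a duck-typed copy of the handler, so it runs without LangChain installed, and a hypothetical `fake_llm_run` driver that fires the callbacks in the order a streaming run would, including the error path. It does not reproduce LangChain's internal callback machinery.

```python
import logging
from queue import Queue
from typing import Any, Iterable, List, Optional


class QueueHandler:
    """Duck-typed copy of StreamingHandler; no LangChain import, so it runs anywhere."""

    def __init__(self, queue: Queue) -> None:
        self.queue = queue

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.queue.put(token)

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        self.queue.put(None)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        logging.error(f"Error in LLM: {error}")
        self.queue.put(None)


def fake_llm_run(tokens: Iterable[str], handler: QueueHandler, fail_after: Optional[int] = None) -> None:
    """Hypothetical driver: fire callbacks in the order a streaming LLM run would."""
    try:
        for i, token in enumerate(tokens):
            if fail_after is not None and i == fail_after:
                raise ConnectionError("simulated network failure")
            handler.on_llm_new_token(token)
    except ConnectionError as exc:
        handler.on_llm_error(exc)
        return
    handler.on_llm_end(response=None)


def drain(queue: Queue) -> List[Optional[str]]:
    """Read everything currently queued (safe here: single thread, nothing else writing)."""
    items: List[Optional[str]] = []
    while not queue.empty():
        items.append(queue.get())
    return items


if __name__ == "__main__":
    q: Queue = Queue()
    fake_llm_run(["Hel", "lo"], QueueHandler(q))
    print(drain(q))
```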
StreamingChain
The `StreamingChain` class is the main class for streaming data from the LLM. It uses a thread and a queue to process LLM responses in real time.

- `stream` method: Starts the chain with the given input on a separate thread, then, in the calling thread, yields tokens from the queue until it receives the `None` end marker.
- `cleanup` method: After streaming ends, waits for the thread to finish if it is still running.
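```python
class StreamingChain:
    def __init__(self, llm, prompt):
        self.llm_chain = LLMChain(llm=llm, prompt=prompt)
        self.thread = None

    def stream(self, input):
        queue = Queue()
        handler = StreamingHandler(queue)

        def task():
            try:
                # Runs in the worker thread; the handler pushes tokens onto the queue
                self.llm_chain(input, callbacks=[handler])
            finally:
                # Always send the end marker, even if the chain fails before the LLM runs
                # (e.g. a missing prompt variable), so queue.get() below cannot block forever
                queue.put(None)

        self.thread = Thread(target=task)
        self.thread.start()

        try:
            while True:
                token = queue.get()  # blocks until the worker produces something
                if token is None:
                    break
                yield token
        finally:
            self.cleanup()

    def cleanup(self):
        if self.thread and self.thread.is_alive():
            self.thread.join()
```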
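The same thread-plus-queue pattern can be exercised without LangChain or an API key. In this sketch, a hypothetical `produce(emit)` callable stands in for `self.llm_chain(input, callbacks=[handler])`, and a private sentinel object replaces `None` so the end marker cannot collide with a real value. The worker always sends the sentinel in `finally` and hands any exception back to the consumer, so a failing producer ends the stream with an error instead of leaving `queue.get()` blocked forever.

```python
import threading
from queue import Queue
from typing import Callable, Iterator, List

_DONE = object()  # private end marker; cannot be confused with any real token


def stream_in_thread(produce: Callable[[Callable[[str], None]], None]) -> Iterator[str]:
    """Run produce(emit) in a worker thread and yield tokens as they are emitted.

    `produce` is a hypothetical stand-in for the LLM chain call: it calls emit(token)
    once per token, the way StreamingHandler.on_llm_new_token feeds the queue.
    """
    queue: Queue = Queue()
    errors: List[BaseException] = []

    def worker() -> None:
        try:
            produce(queue.put)
        except Exception as exc:  # hand the error to the consumer instead of losing it
            errors.append(exc)
        finally:
            queue.put(_DONE)  # always unblock the consumer, even after a failure

    thread = threading.Thread(target=worker, daemon=True)
    thread.start()
    try:
        while True:
            item = queue.get()
            if item is _DONE:
                break
            yield item
    finally:
        thread.join()  # wait for the worker; the queue is unbounded, so it never blocks on put
    if errors:
        raise errors[0]


if __name__ == "__main__":
    def fake_chain(emit: Callable[[str], None]) -> None:
        for token in ["Pika", "chu", "!"]:
            emit(token)

    print("".join(stream_in_thread(fake_chain)))
```

Re-raising in the consumer is one design choice among several; the `StreamingChain` class instead lets the worker's exception surface through the thread's default error reporting.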
Usage Example
The following example uses `StreamingChain` to stream the LLM's response to user input and display it in real time, here for the input "Explain Pokémon in 100 characters."
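```python
chat = ChatOpenAI(streaming=True)  # assumed setup; streaming=True makes on_llm_new_token fire
prompt = ChatPromptTemplate.from_messages([("human", "{content}")])  # assumed setup
chain = StreamingChain(llm=chat, prompt=prompt)
for output in chain.stream(input={"content": "Explain Pokémon in 100 characters."}):
    print(output, end="", flush=True)  # show tokens on one line as they arrive
print()
```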
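Besides displaying tokens, an application often needs the complete response afterwards, for example to store chat history. A small helper can do both. This is a sketch; `print_stream` is a hypothetical name, and accepting any file-like object makes it easy to test without a terminal or an API key.

```python
import sys
from typing import Iterable, List, Optional, TextIO


def print_stream(tokens: Iterable[str], out: Optional[TextIO] = None) -> str:
    """Write tokens on one line as they arrive and return the full response text."""
    out = out if out is not None else sys.stdout
    parts: List[str] = []
    for token in tokens:
        out.write(token)
        out.flush()  # show each token immediately instead of waiting for a newline
        parts.append(token)
    out.write("\n")
    return "".join(parts)


if __name__ == "__main__":
    # Stand-in for chain.stream(...), which needs an API key
    text = print_stream(iter(["Poké", "mon", " is", " fun."]))
    print(len(text), "characters received")
```

With the chain from the example, this would be called as `full_text = print_stream(chain.stream(input={"content": "Explain Pokémon in 100 characters."}))`.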