Transformers Stream Generator

0.0.5 · active · verified Thu Apr 16

This library provides a text-generation method that returns a generator, streaming out each token in real time during inference, built on top of Hugging Face `transformers`. It offers a simple way to enable token-by-token streaming for `transformers` models, which is especially useful for large language models (LLMs). The library is currently at version 0.0.5 and appears to be in an early stage of development, with updates released as features and fixes are integrated.

Common errors

Warnings

Install
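
The package is published on PyPI and can be installed with pip:

pip install transformers-stream-generator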

Imports
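
Streaming support is enabled through a single entry point exported by the package, as used in the Quickstart below:

from transformers_stream_generator import init_stream_support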

Quickstart

This example demonstrates how to set up and use `transformers-stream-generator` with a Hugging Face model. First, `init_stream_support()` is called to patch the `transformers` generation methods. Then, `model.generate()` is called with `do_stream=True` and `do_sample=True` (and usually `num_beams=1`) to obtain a generator that yields token IDs as they are produced.

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_stream_generator import init_stream_support
import os

# Initialize streaming support
init_stream_support()

# Load model and tokenizer (e.g., a small GPT-2 for demonstration)
# Replace with your desired model
model_name = os.environ.get('TRANSFORMERS_MODEL', 'gpt2')
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Encode input
input_text = "Hello, I am a language model and I can"
input_ids = tokenizer.encode(input_text, return_tensors='pt')

# Generate text with streaming enabled
# do_stream=True requires do_sample=True and typically num_beams=1
print(f"Generating with {model_name} in streaming mode...")
generator = model.generate(
    input_ids,
    max_new_tokens=50,
    do_stream=True,
    do_sample=True, # Required for do_stream=True
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    num_beams=1 # Streaming generally works best with num_beams=1
)

# Iterate and print tokens as they are generated
print(input_text, end="")
for token_id in generator:
    word = tokenizer.decode(token_id, skip_special_tokens=True)
    print(word, end="", flush=True)
print("\n\nGeneration complete.")
