LM Format Enforcer

0.11.3 · active · verified Thu Apr 09

LM Format Enforcer is a Python library that constrains the output of large language models (LLMs) to specific formats such as JSON Schema or regular expressions. It integrates with popular LLM frameworks such as Hugging Face Transformers and vLLM. The current version is 0.11.3; minor updates ship frequently to support new integrations and fix compatibility issues.

Warnings

Install
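Install from PyPI (note the package name uses hyphens):

```shell
pip install lm-format-enforcer
```

The Transformers integration shown below also requires `transformers` and `torch` to be installed separately.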

Imports

Quickstart

This quickstart demonstrates how to enforce a JSON Schema output using a Hugging Face Transformers model. It initializes a tokenizer and model, defines a JSON schema, creates a `JsonSchemaParser`, and then uses `build_transformers_prefix_allowed_tokens_fn` to generate text that strictly adheres to the defined format.

from transformers import AutoTokenizer, AutoModelForCausalLM
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Define the JSON schema
json_schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer", "minimum": 0},
        "isStudent": {"type": "boolean"}
    },
    "required": ["name", "age", "isStudent"]
}

# Create the parser
json_parser = JsonSchemaParser(json_schema)

# Build the prefix_allowed_tokens_fn for transformers integration
prefix_allowed_tokens_fn = build_transformers_prefix_allowed_tokens_fn(tokenizer, json_parser)

prompt = "Please generate a JSON object describing a person with name, age, and student status:\n"

# Encode the prompt
input_ids = tokenizer.encode(prompt, return_tensors="pt")

# Generate text with format enforcement
# GPT2 might not perfectly follow instructions but the *format* will be enforced.
output = model.generate(
    input_ids,
    max_new_tokens=100,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=False, # For deterministic generation where possible
    num_beams=1
)

# Decode and print the result
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)
# Example output (format enforced):
# {"name": "Alice", "age": 25, "isStudent": true}
