//! # Sentiment Classification with BERT //! //! This application performs sentiment analysis on text using a fine-tuned BERT model. //! It loads a pre-trained model from Hugging Face Hub and classifies input text //! into sentiment categories (e.g., positive, negative). /* Tangled strings don't support multiple files yet, so throw the following bit into Cargo.toml: [package] name = "bmini" version = "0.1.0" edition = "2024" authors = ["Nick Gerakines @ngerakines.me"] description = "A sentiment classification tool using BERT models for text analysis" license = "MIT" [dependencies] anyhow = "1.0.100" candle-core = "0.9.1" candle-nn = "0.9.1" candle-transformers = "0.9.1" hf-hub = "0.4.3" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" tokenizers = "0.22.1" */ use anyhow::{Context, Result}; use candle_core::{Device, IndexOp, Tensor}; use candle_nn::{linear, Module, VarBuilder}; use candle_transformers::models::bert::{BertModel, Config}; use hf_hub::{api::sync::Api, Repo, RepoType}; use std::collections::HashMap; use tokenizers::Tokenizer; /// Model configuration constants mod constants { /// Hugging Face model identifier for sentiment classification pub const SENTIMENT_MODEL_ID: &str = "Varnikasiva/sentiment-classification-bert-mini"; /// Tokenizer model identifier (using standard BERT tokenizer) pub const TOKENIZER_MODEL_ID: &str = "google-bert/bert-base-uncased"; /// Default git revision to use pub const DEFAULT_REVISION: &str = "main"; } /// Represents a sentiment classification result #[derive(Debug, Clone)] struct SentimentPrediction { /// The input text that was classified text: String, /// Predicted sentiment label (e.g., "positive", "negative") label: String, /// Confidence score for the prediction (0.0 to 1.0) confidence: f32, /// Probability distribution across all classes class_probabilities: Vec<(String, f32)>, } /// Encapsulates a sentiment classification model with BERT backbone struct SentimentClassifier { /// The underlying BERT model for feature extraction bert_model: BertModel, /// Linear classification head that maps BERT embeddings to sentiment classes classifier: candle_nn::Linear, /// Tokenizer for processing input text tokenizer: Tokenizer, /// Maps class indices to human-readable labels label_map: HashMap, /// Computation device (CPU or CUDA) device: Device, } impl SentimentClassifier { /// Creates a new sentiment classifier by loading model weights from Hugging Face Hub /// /// # Arguments /// * `device` - The device to run computations on (CPU or CUDA) /// /// # Returns /// A configured sentiment classifier ready for inference fn new(device: Device) -> Result { // Initialize Hugging Face API client let api = Api::new()?; // Download model files let (config, weights_path) = Self::download_model_files(&api) .context("Failed to download model files")?; // Download tokenizer let tokenizer = Self::load_tokenizer(&api) .context("Failed to load tokenizer")?; // Parse label mapping from config let label_map = Self::extract_label_mapping(&config) .context("Failed to extract label mapping")?; let num_labels = label_map.len(); // Parse BERT configuration let bert_config: Config = serde_json::from_value(config.clone()) .context("Failed to parse BERT config")?; // Load model weights let (bert_model, classifier) = Self::load_model_weights( weights_path, &bert_config, num_labels, &device, ).context("Failed to load model weights")?; Ok(Self { bert_model, classifier, tokenizer, label_map, device, }) } /// Downloads model configuration and weights from Hugging Face Hub fn download_model_files(api: &Api) -> Result<(serde_json::Value, std::path::PathBuf)> { let repo = api.repo(Repo::with_revision( constants::SENTIMENT_MODEL_ID.to_string(), RepoType::Model, constants::DEFAULT_REVISION.to_string(), )); let config_path = repo.get("config.json") .context("Failed to download config.json")?; let weights_path = repo.get("model.safetensors") .context("Failed to download model weights")?; let config_content = std::fs::read_to_string(&config_path) .context("Failed to read config file")?; let config: serde_json::Value = serde_json::from_str(&config_content) .context("Failed to parse config JSON")?; Ok((config, weights_path)) } /// Loads a compatible tokenizer from Hugging Face Hub fn load_tokenizer(api: &Api) -> Result { let tokenizer_repo = api.repo(Repo::with_revision( constants::TOKENIZER_MODEL_ID.to_string(), RepoType::Model, constants::DEFAULT_REVISION.to_string(), )); let tokenizer_path = tokenizer_repo.get("tokenizer.json") .context("Failed to download tokenizer")?; Tokenizer::from_file(tokenizer_path) .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e)) } /// Extracts the label mapping from model configuration fn extract_label_mapping(config: &serde_json::Value) -> Result> { let mut label_map = HashMap::new(); let id2label = config["id2label"] .as_object() .ok_or_else(|| anyhow::anyhow!("Missing id2label in config"))?; for (id_str, label_value) in id2label { let id = id_str.parse::() .context(format!("Invalid label ID: {}", id_str))?; let label = label_value .as_str() .ok_or_else(|| anyhow::anyhow!("Invalid label value for ID {}", id))? .to_string(); label_map.insert(id, label); } if label_map.is_empty() { anyhow::bail!("No labels found in model configuration"); } Ok(label_map) } /// Loads BERT model and classifier weights from safetensors file fn load_model_weights( weights_path: std::path::PathBuf, config: &Config, num_labels: usize, device: &Device, ) -> Result<(BertModel, candle_nn::Linear)> { // Load weights from safetensors format let weights = candle_core::safetensors::load(&weights_path, device) .context("Failed to load safetensors")?; let vb = VarBuilder::from_tensors(weights, candle_core::DType::F32, device); // Initialize BERT backbone let bert_model = BertModel::load(vb.pp("bert"), config) .context("Failed to load BERT model")?; // Initialize classification head let classifier = linear(config.hidden_size, num_labels, vb.pp("classifier")) .context("Failed to load classifier layer")?; Ok((bert_model, classifier)) } /// Performs sentiment classification on input text /// /// # Arguments /// * `text` - The text to classify /// /// # Returns /// A `SentimentPrediction` containing the classification results pub fn predict(&self, text: &str) -> Result { // Tokenize input text let encoding = self.tokenizer .encode(text, true) .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?; // Convert tokens to tensors let input_ids = Tensor::new(encoding.get_ids(), &self.device)? .unsqueeze(0)?; // Add batch dimension let attention_mask = Tensor::new(encoding.get_attention_mask(), &self.device)? .unsqueeze(0)?; // Add batch dimension // Run BERT forward pass let bert_output = self.bert_model .forward(&input_ids, &attention_mask, None) .context("BERT forward pass failed")?; // Extract pooled output ([CLS] token representation) let cls_embedding = bert_output .i((0, 0))? // First sequence, first token .unsqueeze(0)?; // Add batch dimension back // Apply classification head let logits = self.classifier .forward(&cls_embedding) .context("Classification forward pass failed")?; // Convert logits to probabilities using softmax let probabilities = candle_nn::ops::softmax(&logits, 1)?; let prob_values: Vec = probabilities.squeeze(0)?.to_vec1()?; // Find predicted class (highest probability) let (predicted_idx, &confidence) = prob_values .iter() .enumerate() .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) .ok_or_else(|| anyhow::anyhow!("Failed to find max probability"))?; // Get predicted label let label = self.label_map .get(&predicted_idx) .cloned() .unwrap_or_else(|| format!("class_{}", predicted_idx)); // Build probability distribution for all classes let class_probabilities: Vec<(String, f32)> = prob_values .iter() .enumerate() .map(|(idx, &prob)| { let class_label = self.label_map .get(&idx) .cloned() .unwrap_or_else(|| format!("class_{}", idx)); (class_label, prob) }) .collect(); Ok(SentimentPrediction { text: text.to_string(), label, confidence, class_probabilities, }) } } /// Formats and displays a sentiment prediction result fn display_prediction(prediction: &SentimentPrediction) { println!("Text: \"{}\"", prediction.text); println!( "Predicted: {} (confidence: {:.1}%)", prediction.label, prediction.confidence * 100.0 ); println!("Probabilities:"); for (label, probability) in &prediction.class_probabilities { println!(" {}: {:.1}%", label, probability * 100.0); } } /// Main entry point for the sentiment classification application fn main() -> Result<()> { // Initialize computation device (CPU by default, can be modified for GPU) let device = Device::Cpu; println!("Initializing sentiment classifier on {:?}...", device); // Load the sentiment classification model let classifier = SentimentClassifier::new(device) .context("Failed to initialize sentiment classifier")?; println!("Model loaded successfully!\n"); // Define test sentences for demonstration let test_sentences = [ "Nice", "Very cool", "Great work", "I love this!", "I hate this!", ]; // Run predictions and display results println!("Sentiment Classification Results:"); println!("{}", "=".repeat(60)); for sentence in &test_sentences { let prediction = classifier.predict(sentence) .context(format!("Failed to classify: {}", sentence))?; display_prediction(&prediction); println!("{}", "-".repeat(60)); } Ok(()) }