// src/main.rs
1//! # Sentiment Classification with BERT 2//! 3//! This application performs sentiment analysis on text using a fine-tuned BERT model. 4//! It loads a pre-trained model from Hugging Face Hub and classifies input text 5//! into sentiment categories (e.g., positive, negative). 6 7/* 8Tangled strings don't support multiple files yet, so throw the following bit into Cargo.toml: 9 10[package] 11name = "bmini" 12version = "0.1.0" 13edition = "2024" 14authors = ["Nick Gerakines @ngerakines.me"] 15description = "A sentiment classification tool using BERT models for text analysis" 16license = "MIT" 17 18[dependencies] 19anyhow = "1.0.100" 20candle-core = "0.9.1" 21candle-nn = "0.9.1" 22candle-transformers = "0.9.1" 23hf-hub = "0.4.3" 24serde = { version = "1.0", features = ["derive"] } 25serde_json = "1.0" 26tokenizers = "0.22.1" 27*/ 28 29use anyhow::{Context, Result}; 30use candle_core::{Device, IndexOp, Tensor}; 31use candle_nn::{linear, Module, VarBuilder}; 32use candle_transformers::models::bert::{BertModel, Config}; 33use hf_hub::{api::sync::Api, Repo, RepoType}; 34use std::collections::HashMap; 35use tokenizers::Tokenizer; 36 37/// Model configuration constants 38mod constants { 39 /// Hugging Face model identifier for sentiment classification 40 pub const SENTIMENT_MODEL_ID: &str = "Varnikasiva/sentiment-classification-bert-mini"; 41 42 /// Tokenizer model identifier (using standard BERT tokenizer) 43 pub const TOKENIZER_MODEL_ID: &str = "google-bert/bert-base-uncased"; 44 45 /// Default git revision to use 46 pub const DEFAULT_REVISION: &str = "main"; 47} 48 49/// Represents a sentiment classification result 50#[derive(Debug, Clone)] 51struct SentimentPrediction { 52 /// The input text that was classified 53 text: String, 54 /// Predicted sentiment label (e.g., "positive", "negative") 55 label: String, 56 /// Confidence score for the prediction (0.0 to 1.0) 57 confidence: f32, 58 /// Probability distribution across all classes 59 class_probabilities: Vec<(String, 
f32)>, 60} 61 62/// Encapsulates a sentiment classification model with BERT backbone 63struct SentimentClassifier { 64 /// The underlying BERT model for feature extraction 65 bert_model: BertModel, 66 /// Linear classification head that maps BERT embeddings to sentiment classes 67 classifier: candle_nn::Linear, 68 /// Tokenizer for processing input text 69 tokenizer: Tokenizer, 70 /// Maps class indices to human-readable labels 71 label_map: HashMap<usize, String>, 72 /// Computation device (CPU or CUDA) 73 device: Device, 74} 75 76impl SentimentClassifier { 77 /// Creates a new sentiment classifier by loading model weights from Hugging Face Hub 78 /// 79 /// # Arguments 80 /// * `device` - The device to run computations on (CPU or CUDA) 81 /// 82 /// # Returns 83 /// A configured sentiment classifier ready for inference 84 fn new(device: Device) -> Result<Self> { 85 // Initialize Hugging Face API client 86 let api = Api::new()?; 87 88 // Download model files 89 let (config, weights_path) = Self::download_model_files(&api) 90 .context("Failed to download model files")?; 91 92 // Download tokenizer 93 let tokenizer = Self::load_tokenizer(&api) 94 .context("Failed to load tokenizer")?; 95 96 // Parse label mapping from config 97 let label_map = Self::extract_label_mapping(&config) 98 .context("Failed to extract label mapping")?; 99 100 let num_labels = label_map.len(); 101 102 // Parse BERT configuration 103 let bert_config: Config = serde_json::from_value(config.clone()) 104 .context("Failed to parse BERT config")?; 105 106 // Load model weights 107 let (bert_model, classifier) = Self::load_model_weights( 108 weights_path, 109 &bert_config, 110 num_labels, 111 &device, 112 ).context("Failed to load model weights")?; 113 114 Ok(Self { 115 bert_model, 116 classifier, 117 tokenizer, 118 label_map, 119 device, 120 }) 121 } 122 123 /// Downloads model configuration and weights from Hugging Face Hub 124 fn download_model_files(api: &Api) -> Result<(serde_json::Value, 
std::path::PathBuf)> { 125 let repo = api.repo(Repo::with_revision( 126 constants::SENTIMENT_MODEL_ID.to_string(), 127 RepoType::Model, 128 constants::DEFAULT_REVISION.to_string(), 129 )); 130 131 let config_path = repo.get("config.json") 132 .context("Failed to download config.json")?; 133 let weights_path = repo.get("model.safetensors") 134 .context("Failed to download model weights")?; 135 136 let config_content = std::fs::read_to_string(&config_path) 137 .context("Failed to read config file")?; 138 let config: serde_json::Value = serde_json::from_str(&config_content) 139 .context("Failed to parse config JSON")?; 140 141 Ok((config, weights_path)) 142 } 143 144 /// Loads a compatible tokenizer from Hugging Face Hub 145 fn load_tokenizer(api: &Api) -> Result<Tokenizer> { 146 let tokenizer_repo = api.repo(Repo::with_revision( 147 constants::TOKENIZER_MODEL_ID.to_string(), 148 RepoType::Model, 149 constants::DEFAULT_REVISION.to_string(), 150 )); 151 152 let tokenizer_path = tokenizer_repo.get("tokenizer.json") 153 .context("Failed to download tokenizer")?; 154 155 Tokenizer::from_file(tokenizer_path) 156 .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e)) 157 } 158 159 /// Extracts the label mapping from model configuration 160 fn extract_label_mapping(config: &serde_json::Value) -> Result<HashMap<usize, String>> { 161 let mut label_map = HashMap::new(); 162 163 let id2label = config["id2label"] 164 .as_object() 165 .ok_or_else(|| anyhow::anyhow!("Missing id2label in config"))?; 166 167 for (id_str, label_value) in id2label { 168 let id = id_str.parse::<usize>() 169 .context(format!("Invalid label ID: {}", id_str))?; 170 let label = label_value 171 .as_str() 172 .ok_or_else(|| anyhow::anyhow!("Invalid label value for ID {}", id))? 
173 .to_string(); 174 175 label_map.insert(id, label); 176 } 177 178 if label_map.is_empty() { 179 anyhow::bail!("No labels found in model configuration"); 180 } 181 182 Ok(label_map) 183 } 184 185 /// Loads BERT model and classifier weights from safetensors file 186 fn load_model_weights( 187 weights_path: std::path::PathBuf, 188 config: &Config, 189 num_labels: usize, 190 device: &Device, 191 ) -> Result<(BertModel, candle_nn::Linear)> { 192 // Load weights from safetensors format 193 let weights = candle_core::safetensors::load(&weights_path, device) 194 .context("Failed to load safetensors")?; 195 196 let vb = VarBuilder::from_tensors(weights, candle_core::DType::F32, device); 197 198 // Initialize BERT backbone 199 let bert_model = BertModel::load(vb.pp("bert"), config) 200 .context("Failed to load BERT model")?; 201 202 // Initialize classification head 203 let classifier = linear(config.hidden_size, num_labels, vb.pp("classifier")) 204 .context("Failed to load classifier layer")?; 205 206 Ok((bert_model, classifier)) 207 } 208 209 /// Performs sentiment classification on input text 210 /// 211 /// # Arguments 212 /// * `text` - The text to classify 213 /// 214 /// # Returns 215 /// A `SentimentPrediction` containing the classification results 216 pub fn predict(&self, text: &str) -> Result<SentimentPrediction> { 217 // Tokenize input text 218 let encoding = self.tokenizer 219 .encode(text, true) 220 .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?; 221 222 // Convert tokens to tensors 223 let input_ids = Tensor::new(encoding.get_ids(), &self.device)? 224 .unsqueeze(0)?; // Add batch dimension 225 226 let attention_mask = Tensor::new(encoding.get_attention_mask(), &self.device)? 
227 .unsqueeze(0)?; // Add batch dimension 228 229 // Run BERT forward pass 230 let bert_output = self.bert_model 231 .forward(&input_ids, &attention_mask, None) 232 .context("BERT forward pass failed")?; 233 234 // Extract pooled output ([CLS] token representation) 235 let cls_embedding = bert_output 236 .i((0, 0))? // First sequence, first token 237 .unsqueeze(0)?; // Add batch dimension back 238 239 // Apply classification head 240 let logits = self.classifier 241 .forward(&cls_embedding) 242 .context("Classification forward pass failed")?; 243 244 // Convert logits to probabilities using softmax 245 let probabilities = candle_nn::ops::softmax(&logits, 1)?; 246 let prob_values: Vec<f32> = probabilities.squeeze(0)?.to_vec1()?; 247 248 // Find predicted class (highest probability) 249 let (predicted_idx, &confidence) = prob_values 250 .iter() 251 .enumerate() 252 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) 253 .ok_or_else(|| anyhow::anyhow!("Failed to find max probability"))?; 254 255 // Get predicted label 256 let label = self.label_map 257 .get(&predicted_idx) 258 .cloned() 259 .unwrap_or_else(|| format!("class_{}", predicted_idx)); 260 261 // Build probability distribution for all classes 262 let class_probabilities: Vec<(String, f32)> = prob_values 263 .iter() 264 .enumerate() 265 .map(|(idx, &prob)| { 266 let class_label = self.label_map 267 .get(&idx) 268 .cloned() 269 .unwrap_or_else(|| format!("class_{}", idx)); 270 (class_label, prob) 271 }) 272 .collect(); 273 274 Ok(SentimentPrediction { 275 text: text.to_string(), 276 label, 277 confidence, 278 class_probabilities, 279 }) 280 } 281} 282 283/// Formats and displays a sentiment prediction result 284fn display_prediction(prediction: &SentimentPrediction) { 285 println!("Text: \"{}\"", prediction.text); 286 println!( 287 "Predicted: {} (confidence: {:.1}%)", 288 prediction.label, 289 prediction.confidence * 100.0 290 ); 291 println!("Probabilities:"); 292 for (label, probability) in 
&prediction.class_probabilities { 293 println!(" {}: {:.1}%", label, probability * 100.0); 294 } 295} 296 297/// Main entry point for the sentiment classification application 298fn main() -> Result<()> { 299 // Initialize computation device (CPU by default, can be modified for GPU) 300 let device = Device::Cpu; 301 println!("Initializing sentiment classifier on {:?}...", device); 302 303 // Load the sentiment classification model 304 let classifier = SentimentClassifier::new(device) 305 .context("Failed to initialize sentiment classifier")?; 306 307 println!("Model loaded successfully!\n"); 308 309 // Define test sentences for demonstration 310 let test_sentences = [ 311 "Nice", 312 "Very cool", 313 "Great work", 314 "I love this!", 315 "I hate this!", 316 ]; 317 318 // Run predictions and display results 319 println!("Sentiment Classification Results:"); 320 println!("{}", "=".repeat(60)); 321 322 for sentence in &test_sentences { 323 let prediction = classifier.predict(sentence) 324 .context(format!("Failed to classify: {}", sentence))?; 325 326 display_prediction(&prediction); 327 println!("{}", "-".repeat(60)); 328 } 329 330 Ok(()) 331}