src/main.rs
1//! # Sentiment Classification with BERT
2//!
3//! This application performs sentiment analysis on text using a fine-tuned BERT model.
4//! It loads a pre-trained model from Hugging Face Hub and classifies input text
5//! into sentiment categories (e.g., positive, negative).
6
7/*
8Tangled strings don't support multiple files yet, so throw the following bit into Cargo.toml:
9
10[package]
11name = "bmini"
12version = "0.1.0"
13edition = "2024"
14authors = ["Nick Gerakines @ngerakines.me"]
15description = "A sentiment classification tool using BERT models for text analysis"
16license = "MIT"
17
18[dependencies]
19anyhow = "1.0.100"
20candle-core = "0.9.1"
21candle-nn = "0.9.1"
22candle-transformers = "0.9.1"
23hf-hub = "0.4.3"
24serde = { version = "1.0", features = ["derive"] }
25serde_json = "1.0"
26tokenizers = "0.22.1"
27*/
28
29use anyhow::{Context, Result};
30use candle_core::{Device, IndexOp, Tensor};
31use candle_nn::{linear, Module, VarBuilder};
32use candle_transformers::models::bert::{BertModel, Config};
33use hf_hub::{api::sync::Api, Repo, RepoType};
34use std::collections::HashMap;
35use tokenizers::Tokenizer;
36
/// Hub identifiers and revision shared by the model and tokenizer loaders.
mod constants {
    /// Git revision requested for every Hub repository this program reads.
    pub const DEFAULT_REVISION: &str = "main";

    /// Repository that provides `tokenizer.json` (the stock uncased BERT tokenizer).
    pub const TOKENIZER_MODEL_ID: &str = "google-bert/bert-base-uncased";

    /// Repository that provides the sentiment model's `config.json` and weights.
    pub const SENTIMENT_MODEL_ID: &str = "Varnikasiva/sentiment-classification-bert-mini";
}
48
/// Outcome of classifying a single piece of text.
#[derive(Debug, Clone)]
struct SentimentPrediction {
    /// Text that was submitted for classification.
    text: String,
    /// Label of the winning class (for example "positive" or "negative").
    label: String,
    /// Probability assigned to the winning class, in the range 0.0 to 1.0.
    confidence: f32,
    /// `(label, probability)` pairs, one per class, in class-index order.
    class_probabilities: Vec<(String, f32)>,
}
61
62/// Encapsulates a sentiment classification model with BERT backbone
63struct SentimentClassifier {
64 /// The underlying BERT model for feature extraction
65 bert_model: BertModel,
66 /// Linear classification head that maps BERT embeddings to sentiment classes
67 classifier: candle_nn::Linear,
68 /// Tokenizer for processing input text
69 tokenizer: Tokenizer,
70 /// Maps class indices to human-readable labels
71 label_map: HashMap<usize, String>,
72 /// Computation device (CPU or CUDA)
73 device: Device,
74}
75
76impl SentimentClassifier {
77 /// Creates a new sentiment classifier by loading model weights from Hugging Face Hub
78 ///
79 /// # Arguments
80 /// * `device` - The device to run computations on (CPU or CUDA)
81 ///
82 /// # Returns
83 /// A configured sentiment classifier ready for inference
84 fn new(device: Device) -> Result<Self> {
85 // Initialize Hugging Face API client
86 let api = Api::new()?;
87
88 // Download model files
89 let (config, weights_path) = Self::download_model_files(&api)
90 .context("Failed to download model files")?;
91
92 // Download tokenizer
93 let tokenizer = Self::load_tokenizer(&api)
94 .context("Failed to load tokenizer")?;
95
96 // Parse label mapping from config
97 let label_map = Self::extract_label_mapping(&config)
98 .context("Failed to extract label mapping")?;
99
100 let num_labels = label_map.len();
101
102 // Parse BERT configuration
103 let bert_config: Config = serde_json::from_value(config.clone())
104 .context("Failed to parse BERT config")?;
105
106 // Load model weights
107 let (bert_model, classifier) = Self::load_model_weights(
108 weights_path,
109 &bert_config,
110 num_labels,
111 &device,
112 ).context("Failed to load model weights")?;
113
114 Ok(Self {
115 bert_model,
116 classifier,
117 tokenizer,
118 label_map,
119 device,
120 })
121 }
122
123 /// Downloads model configuration and weights from Hugging Face Hub
124 fn download_model_files(api: &Api) -> Result<(serde_json::Value, std::path::PathBuf)> {
125 let repo = api.repo(Repo::with_revision(
126 constants::SENTIMENT_MODEL_ID.to_string(),
127 RepoType::Model,
128 constants::DEFAULT_REVISION.to_string(),
129 ));
130
131 let config_path = repo.get("config.json")
132 .context("Failed to download config.json")?;
133 let weights_path = repo.get("model.safetensors")
134 .context("Failed to download model weights")?;
135
136 let config_content = std::fs::read_to_string(&config_path)
137 .context("Failed to read config file")?;
138 let config: serde_json::Value = serde_json::from_str(&config_content)
139 .context("Failed to parse config JSON")?;
140
141 Ok((config, weights_path))
142 }
143
144 /// Loads a compatible tokenizer from Hugging Face Hub
145 fn load_tokenizer(api: &Api) -> Result<Tokenizer> {
146 let tokenizer_repo = api.repo(Repo::with_revision(
147 constants::TOKENIZER_MODEL_ID.to_string(),
148 RepoType::Model,
149 constants::DEFAULT_REVISION.to_string(),
150 ));
151
152 let tokenizer_path = tokenizer_repo.get("tokenizer.json")
153 .context("Failed to download tokenizer")?;
154
155 Tokenizer::from_file(tokenizer_path)
156 .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))
157 }
158
159 /// Extracts the label mapping from model configuration
160 fn extract_label_mapping(config: &serde_json::Value) -> Result<HashMap<usize, String>> {
161 let mut label_map = HashMap::new();
162
163 let id2label = config["id2label"]
164 .as_object()
165 .ok_or_else(|| anyhow::anyhow!("Missing id2label in config"))?;
166
167 for (id_str, label_value) in id2label {
168 let id = id_str.parse::<usize>()
169 .context(format!("Invalid label ID: {}", id_str))?;
170 let label = label_value
171 .as_str()
172 .ok_or_else(|| anyhow::anyhow!("Invalid label value for ID {}", id))?
173 .to_string();
174
175 label_map.insert(id, label);
176 }
177
178 if label_map.is_empty() {
179 anyhow::bail!("No labels found in model configuration");
180 }
181
182 Ok(label_map)
183 }
184
185 /// Loads BERT model and classifier weights from safetensors file
186 fn load_model_weights(
187 weights_path: std::path::PathBuf,
188 config: &Config,
189 num_labels: usize,
190 device: &Device,
191 ) -> Result<(BertModel, candle_nn::Linear)> {
192 // Load weights from safetensors format
193 let weights = candle_core::safetensors::load(&weights_path, device)
194 .context("Failed to load safetensors")?;
195
196 let vb = VarBuilder::from_tensors(weights, candle_core::DType::F32, device);
197
198 // Initialize BERT backbone
199 let bert_model = BertModel::load(vb.pp("bert"), config)
200 .context("Failed to load BERT model")?;
201
202 // Initialize classification head
203 let classifier = linear(config.hidden_size, num_labels, vb.pp("classifier"))
204 .context("Failed to load classifier layer")?;
205
206 Ok((bert_model, classifier))
207 }
208
209 /// Performs sentiment classification on input text
210 ///
211 /// # Arguments
212 /// * `text` - The text to classify
213 ///
214 /// # Returns
215 /// A `SentimentPrediction` containing the classification results
216 pub fn predict(&self, text: &str) -> Result<SentimentPrediction> {
217 // Tokenize input text
218 let encoding = self.tokenizer
219 .encode(text, true)
220 .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
221
222 // Convert tokens to tensors
223 let input_ids = Tensor::new(encoding.get_ids(), &self.device)?
224 .unsqueeze(0)?; // Add batch dimension
225
226 let attention_mask = Tensor::new(encoding.get_attention_mask(), &self.device)?
227 .unsqueeze(0)?; // Add batch dimension
228
229 // Run BERT forward pass
230 let bert_output = self.bert_model
231 .forward(&input_ids, &attention_mask, None)
232 .context("BERT forward pass failed")?;
233
234 // Extract pooled output ([CLS] token representation)
235 let cls_embedding = bert_output
236 .i((0, 0))? // First sequence, first token
237 .unsqueeze(0)?; // Add batch dimension back
238
239 // Apply classification head
240 let logits = self.classifier
241 .forward(&cls_embedding)
242 .context("Classification forward pass failed")?;
243
244 // Convert logits to probabilities using softmax
245 let probabilities = candle_nn::ops::softmax(&logits, 1)?;
246 let prob_values: Vec<f32> = probabilities.squeeze(0)?.to_vec1()?;
247
248 // Find predicted class (highest probability)
249 let (predicted_idx, &confidence) = prob_values
250 .iter()
251 .enumerate()
252 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
253 .ok_or_else(|| anyhow::anyhow!("Failed to find max probability"))?;
254
255 // Get predicted label
256 let label = self.label_map
257 .get(&predicted_idx)
258 .cloned()
259 .unwrap_or_else(|| format!("class_{}", predicted_idx));
260
261 // Build probability distribution for all classes
262 let class_probabilities: Vec<(String, f32)> = prob_values
263 .iter()
264 .enumerate()
265 .map(|(idx, &prob)| {
266 let class_label = self.label_map
267 .get(&idx)
268 .cloned()
269 .unwrap_or_else(|| format!("class_{}", idx));
270 (class_label, prob)
271 })
272 .collect();
273
274 Ok(SentimentPrediction {
275 text: text.to_string(),
276 label,
277 confidence,
278 class_probabilities,
279 })
280 }
281}
282
283/// Formats and displays a sentiment prediction result
284fn display_prediction(prediction: &SentimentPrediction) {
285 println!("Text: \"{}\"", prediction.text);
286 println!(
287 "Predicted: {} (confidence: {:.1}%)",
288 prediction.label,
289 prediction.confidence * 100.0
290 );
291 println!("Probabilities:");
292 for (label, probability) in &prediction.class_probabilities {
293 println!(" {}: {:.1}%", label, probability * 100.0);
294 }
295}
296
297/// Main entry point for the sentiment classification application
298fn main() -> Result<()> {
299 // Initialize computation device (CPU by default, can be modified for GPU)
300 let device = Device::Cpu;
301 println!("Initializing sentiment classifier on {:?}...", device);
302
303 // Load the sentiment classification model
304 let classifier = SentimentClassifier::new(device)
305 .context("Failed to initialize sentiment classifier")?;
306
307 println!("Model loaded successfully!\n");
308
309 // Define test sentences for demonstration
310 let test_sentences = [
311 "Nice",
312 "Very cool",
313 "Great work",
314 "I love this!",
315 "I hate this!",
316 ];
317
318 // Run predictions and display results
319 println!("Sentiment Classification Results:");
320 println!("{}", "=".repeat(60));
321
322 for sentence in &test_sentences {
323 let prediction = classifier.predict(sentence)
324 .context(format!("Failed to classify: {}", sentence))?;
325
326 display_prediction(&prediction);
327 println!("{}", "-".repeat(60));
328 }
329
330 Ok(())
331}