WarriorsSami committed on
Commit
1c9f6e9
·
1 Parent(s): fe3f1e2

perf: replace metal with wgpu support and add http file for predict request

Browse files
model/src/inference.rs CHANGED
@@ -55,7 +55,7 @@ impl<B: Backend, D: TextClassificationDataset + 'static> SamModel<B, D> {
55
  println!("Loading weights ...");
56
  let record = CompactRecorder::new()
57
  .load(format!("{artifact_dir}/model").into(), &device)
58
- .expect("Trained model weights tb");
59
 
60
  // Create model using loaded weights
61
  println!("Creating model ...");
 
55
  println!("Loading weights ...");
56
  let record = CompactRecorder::new()
57
  .load(format!("{artifact_dir}/model").into(), &device)
58
+ .expect("Trained model weights to be present");
59
 
60
  // Create model using loaded weights
61
  println!("Creating model ...");
model/src/lib.rs CHANGED
@@ -1,3 +1,3 @@
1
  pub mod data;
2
  pub mod inference;
3
- pub mod model;
 
1
  pub mod data;
2
  pub mod inference;
3
+ pub mod model;
requests/predict.http ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ### POST predict sentiment
2
+ POST http://127.0.0.1:8000/predict
3
+ Accept: application/json
4
+ Content-Type: application/json
5
+
6
+ {
7
+ "text": "I love programming!"
8
+ }
9
+
10
+ ###
server/src/api.rs CHANGED
@@ -1,5 +1,5 @@
1
  use crate::dtos::{SentimentRequest, SentimentResponse};
2
- use burn::backend::Metal;
3
  use model::data::TweetSentimentDataset;
4
  use model::model::SamModel;
5
  use rocket::serde::json::Json;
@@ -7,7 +7,13 @@ use rocket::tokio::sync::Mutex;
7
  use rocket::{State, post};
8
  use std::sync::Arc;
9
 
10
- pub type BackendImpl = Metal<f32, i32>;
 
 
 
 
 
 
11
 
12
  // ---- API Route ---- //
13
  #[post("/predict", format = "json", data = "<input>")]
 
1
  use crate::dtos::{SentimentRequest, SentimentResponse};
2
+ use burn::backend::Wgpu;
3
  use model::data::TweetSentimentDataset;
4
  use model::model::SamModel;
5
  use rocket::serde::json::Json;
 
7
  use rocket::{State, post};
8
  use std::sync::Arc;
9
 
10
+ #[cfg(not(feature = "f16"))]
11
+ #[allow(dead_code)]
12
+ type ElemType = f32;
13
+ #[cfg(feature = "f16")]
14
+ type ElemType = burn::tensor::f16;
15
+
16
+ pub type BackendImpl = Wgpu<ElemType, i32>;
17
 
18
  // ---- API Route ---- //
19
  #[post("/predict", format = "json", data = "<input>")]
server/src/main.rs CHANGED
@@ -1,4 +1,5 @@
1
  #![recursion_limit = "256"]
 
2
 
3
  use crate::api::{AppState, BackendImpl, predict};
4
  use model::data::TweetSentimentDataset;
 
1
  #![recursion_limit = "256"]
2
+ #![allow(unexpected_cfgs)]
3
 
4
  use crate::api::{AppState, BackendImpl, predict};
5
  use model::data::TweetSentimentDataset;
trainer/src/main.rs CHANGED
@@ -11,18 +11,10 @@ fn main() {
11
  inference_runner::run();
12
  }
13
 
14
- #[cfg(not(any(feature = "f16", feature = "flex32")))]
15
- #[allow(unused)]
16
- type ElemType = f32;
17
- #[cfg(feature = "f16")]
18
- type ElemType = burn::tensor::f16;
19
- #[cfg(feature = "flex32")]
20
- type ElemType = burn::tensor::flex32;
21
-
22
  #[cfg(feature = "training")]
23
  mod training_runner {
24
- use crate::{ElemType, training};
25
- use burn::backend::{Autodiff, Metal};
26
  use burn::nn::transformer::TransformerEncoderConfig;
27
  use burn::optim::AdamConfig;
28
  use burn::optim::decay::WeightDecayConfig;
@@ -30,6 +22,14 @@ mod training_runner {
30
  use model::data::TweetSentimentDataset;
31
  use model::inference::ExperimentConfig;
32
 
 
 
 
 
 
 
 
 
33
  pub fn launch<B: AutodiffBackend>(devices: Vec<B::Device>) {
34
  let config = ExperimentConfig::new(
35
  TransformerEncoderConfig::new(256, 1024, 8, 4)
@@ -48,18 +48,24 @@ mod training_runner {
48
  }
49
 
50
  pub fn run() {
51
- launch::<Autodiff<Metal<ElemType, i32>>>(vec![Default::default()]);
52
  }
53
  }
54
 
55
  #[cfg(feature = "inference")]
56
  mod inference_runner {
57
- use crate::ElemType;
58
- use burn::backend::Metal;
59
  use burn::prelude::Backend;
60
  use model::data::TweetSentimentDataset;
61
  use model::inference;
62
 
 
 
 
 
 
 
63
  pub fn launch<B: Backend>(device: B::Device) {
64
  inference::infer::<B, TweetSentimentDataset>(
65
  device,
@@ -80,6 +86,6 @@ mod inference_runner {
80
  }
81
 
82
  pub fn run() {
83
- launch::<Metal<ElemType, i32>>(Default::default());
84
  }
85
  }
 
11
  inference_runner::run();
12
  }
13
 
 
 
 
 
 
 
 
 
14
  #[cfg(feature = "training")]
15
  mod training_runner {
16
+ use crate::training;
17
+ use burn::backend::{Autodiff, Wgpu};
18
  use burn::nn::transformer::TransformerEncoderConfig;
19
  use burn::optim::AdamConfig;
20
  use burn::optim::decay::WeightDecayConfig;
 
22
  use model::data::TweetSentimentDataset;
23
  use model::inference::ExperimentConfig;
24
 
25
+ #[cfg(not(any(feature = "f16", feature = "flex32")))]
26
+ #[allow(unused)]
27
+ pub type ElemType = f32;
28
+ #[cfg(feature = "f16")]
29
+ pub type ElemType = burn::tensor::f16;
30
+ #[cfg(feature = "flex32")]
31
+ pub type ElemType = burn::tensor::flex32;
32
+
33
  pub fn launch<B: AutodiffBackend>(devices: Vec<B::Device>) {
34
  let config = ExperimentConfig::new(
35
  TransformerEncoderConfig::new(256, 1024, 8, 4)
 
48
  }
49
 
50
  pub fn run() {
51
+ launch::<Autodiff<Wgpu<ElemType, i32>>>(vec![Default::default()]);
52
  }
53
  }
54
 
55
  #[cfg(feature = "inference")]
56
  mod inference_runner {
57
+ use burn::backend::Wgpu;
58
+ use burn::backend::wgpu::WgpuDevice;
59
  use burn::prelude::Backend;
60
  use model::data::TweetSentimentDataset;
61
  use model::inference;
62
 
63
+ #[cfg(not(feature = "f16"))]
64
+ #[allow(dead_code)]
65
+ type ElemType = f32;
66
+ #[cfg(feature = "f16")]
67
+ type ElemType = burn::tensor::f16;
68
+
69
  pub fn launch<B: Backend>(device: B::Device) {
70
  inference::infer::<B, TweetSentimentDataset>(
71
  device,
 
86
  }
87
 
88
  pub fn run() {
89
+ launch::<Wgpu<ElemType, i32>>(WgpuDevice::default());
90
  }
91
  }