Commit 1c9f6e9
Parent(s): fe3f1e2

perf: replace metal with wgpu support and add http file for predict request

- model/src/inference.rs +1 -1
- model/src/lib.rs +1 -1
- requests/predict.http +10 -0
- server/src/api.rs +8 -2
- server/src/main.rs +1 -0
- trainer/src/main.rs +20 -14

model/src/inference.rs CHANGED
@@ -55,7 +55,7 @@ impl<B: Backend, D: TextClassificationDataset + 'static> SamModel<B, D> {
         println!("Loading weights ...");
         let record = CompactRecorder::new()
             .load(format!("{artifact_dir}/model").into(), &device)
-            .expect("Trained model weights
+            .expect("Trained model weights to be present");
 
         // Create model using loaded weights
         println!("Creating model ...");

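For reference, the weights loaded here are written by the trainer with the same recorder. A minimal sketch of the matching save call, assuming a trained model value and the same artifact_dir as above (names taken from this diff, not from the trainer source):

    use burn::module::Module;
    use burn::record::CompactRecorder;

    // Persist weights so the load above finds them at "{artifact_dir}/model".
    model
        .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
        .expect("Trained model weights to be saved");
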
model/src/lib.rs CHANGED
@@ -1,3 +1,3 @@
 pub mod data;
 pub mod inference;
-pub mod model;
+pub mod model;

requests/predict.http ADDED
@@ -0,0 +1,10 @@
+### POST predict sentiment
+POST http://127.0.0.1:8000/predict
+Accept: application/json
+Content-Type: application/json
+
+{
+  "text": "I love programming!"
+}
+
+###

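The .http file targets REST-client plugins; the same request can also be sent from Rust. A minimal sketch, assuming the server from this repo is running locally and that reqwest (with the blocking and json features) and serde_json are available as dependencies:

    use serde_json::json;

    fn main() -> Result<(), reqwest::Error> {
        // Mirrors requests/predict.http: POST a JSON body to the predict route.
        let response = reqwest::blocking::Client::new()
            .post("http://127.0.0.1:8000/predict")
            .header("Accept", "application/json")
            .json(&json!({ "text": "I love programming!" }))
            .send()?;
        println!("{}", response.text()?);
        Ok(())
    }
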
server/src/api.rs CHANGED
@@ -1,5 +1,5 @@
 use crate::dtos::{SentimentRequest, SentimentResponse};
-use burn::backend::
+use burn::backend::Wgpu;
 use model::data::TweetSentimentDataset;
 use model::model::SamModel;
 use rocket::serde::json::Json;
@@ -7,7 +7,13 @@ use rocket::tokio::sync::Mutex;
 use rocket::{State, post};
 use std::sync::Arc;
 
-
+#[cfg(not(feature = "f16"))]
+#[allow(dead_code)]
+type ElemType = f32;
+#[cfg(feature = "f16")]
+type ElemType = burn::tensor::f16;
+
+pub type BackendImpl = Wgpu<ElemType, i32>;
 
 // ---- API Route ---- //
 #[post("/predict", format = "json", data = "<input>")]

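The new BackendImpl alias pins the whole server to one concrete backend type. Based on the imports in this file (Arc, the tokio Mutex, rocket State), the shared state presumably wraps the loaded model behind that alias; a hedged sketch, not the file's actual definition:

    // Hypothetical shape of AppState: one loaded model shared across handlers,
    // relying on the imports already shown in this diff.
    pub struct AppState {
        pub model: Arc<Mutex<SamModel<BackendImpl, TweetSentimentDataset>>>,
    }
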
server/src/main.rs CHANGED
@@ -1,4 +1,5 @@
 #![recursion_limit = "256"]
+#![allow(unexpected_cfgs)]
 
 use crate::api::{AppState, BackendImpl, predict};
 use model::data::TweetSentimentDataset;

trainer/src/main.rs CHANGED
@@ -11,18 +11,10 @@ fn main() {
     inference_runner::run();
 }
 
-#[cfg(not(any(feature = "f16", feature = "flex32")))]
-#[allow(unused)]
-type ElemType = f32;
-#[cfg(feature = "f16")]
-type ElemType = burn::tensor::f16;
-#[cfg(feature = "flex32")]
-type ElemType = burn::tensor::flex32;
-
 #[cfg(feature = "training")]
 mod training_runner {
-    use crate::
-    use burn::backend::{Autodiff,
+    use crate::training;
+    use burn::backend::{Autodiff, Wgpu};
     use burn::nn::transformer::TransformerEncoderConfig;
     use burn::optim::AdamConfig;
     use burn::optim::decay::WeightDecayConfig;
@@ -30,6 +22,14 @@ mod training_runner {
     use model::data::TweetSentimentDataset;
     use model::inference::ExperimentConfig;
 
+    #[cfg(not(any(feature = "f16", feature = "flex32")))]
+    #[allow(unused)]
+    pub type ElemType = f32;
+    #[cfg(feature = "f16")]
+    pub type ElemType = burn::tensor::f16;
+    #[cfg(feature = "flex32")]
+    pub type ElemType = burn::tensor::flex32;
+
     pub fn launch<B: AutodiffBackend>(devices: Vec<B::Device>) {
         let config = ExperimentConfig::new(
             TransformerEncoderConfig::new(256, 1024, 8, 4)
@@ -48,18 +48,24 @@ mod training_runner {
     }
 
     pub fn run() {
-        launch::<Autodiff<
+        launch::<Autodiff<Wgpu<ElemType, i32>>>(vec![Default::default()]);
     }
 }
 
 #[cfg(feature = "inference")]
 mod inference_runner {
-    use
-    use burn::backend::
+    use burn::backend::Wgpu;
+    use burn::backend::wgpu::WgpuDevice;
     use burn::prelude::Backend;
     use model::data::TweetSentimentDataset;
     use model::inference;
 
+    #[cfg(not(feature = "f16"))]
+    #[allow(dead_code)]
+    type ElemType = f32;
+    #[cfg(feature = "f16")]
+    type ElemType = burn::tensor::f16;
+
     pub fn launch<B: Backend>(device: B::Device) {
         inference::infer::<B, TweetSentimentDataset>(
             device,
@@ -80,6 +86,6 @@ mod inference_runner {
     }
 
     pub fn run() {
-        launch::<
+        launch::<Wgpu<ElemType, i32>>(WgpuDevice::default());
    }
 }
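
Net effect in the trainer: both runners now target the wgpu backend, with Autodiff layered on top only for training. A standalone sketch of that wiring under the default f32 element type (the f16/flex32 variants sit behind the Cargo features shown above); the crate-specific training and launch plumbing is omitted:

    #![allow(dead_code)] // aliases are illustrative; not all are used below

    use burn::backend::wgpu::WgpuDevice;
    use burn::backend::{Autodiff, Wgpu};

    type ElemType = f32;

    type InferenceBackend = Wgpu<ElemType, i32>; // what inference_runner::run selects
    type TrainingBackend = Autodiff<InferenceBackend>; // what training_runner::run selects

    fn main() {
        // WgpuDevice::default() lets wgpu choose the adapter (Metal on macOS,
        // Vulkan/DX12 elsewhere), which is what replaces the explicit metal backend.
        let device = WgpuDevice::default();
        println!("selected wgpu device: {device:?}");
    }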