tons of changes, mostly type checking and prompt cleanup
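The core of the type-checking work: `App.messages` changes from `Vec<String>` to `Vec<chat::Message>`, and each request now sends the whole history as `chat::Prompt` values that borrow via `Cow`. The `chat` module itself is not part of this diff, so the following is only a rough sketch of the shapes `main.rs` now leans on — field names, derives, and the `From` impl are inferred from the call sites below and should be treated as assumptions:

    // Hypothetical reconstruction of the `chat` types used by main.rs;
    // only the call sites appear in this diff, so everything here is a guess.
    use std::borrow::Cow;
    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize, Clone, Debug)]
    pub struct Message {
        pub role: String,
        pub content: String,
    }

    #[derive(Serialize)]
    pub struct Prompt<'a> {
        pub role: Cow<'a, str>,
        pub content: Cow<'a, str>,
    }

    // The system prompt stays borrowed (Cow::Borrowed) for the whole loop
    // iteration, while history entries cloned out of app.messages become
    // owned (Cow::Owned), so one request type serves both without copying
    // the large system prompt.
    impl<'a> From<Message> for Prompt<'a> {
        fn from(msg: Message) -> Self {
            Prompt {
                role: Cow::Owned(msg.role),
                content: Cow::Owned(msg.content),
            }
        }
    }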
src/main.rs | 293
@@ -1,9 +1,10 @@
+use std::borrow::Cow;
 use std::time::{Duration, Instant};
 
+use chat::Message;
 use clap::Parser;
 use futures_util::StreamExt;
 use reqwest::Client;
-use serde::{Deserialize, Serialize};
 
 use crossterm::event::{self, DisableMouseCapture, EnableMouseCapture, Event, KeyCode};
 use crossterm::terminal::{
@@ -18,13 +19,16 @@ use ratatui::{
     text::{Line, Span},
     widgets::{Block, Borders, Paragraph},
 };
+use serde::{Deserialize, Serialize};
+
+mod chat;
 
 #[derive(Parser)]
 struct Args {
     #[arg(
         short,
         long,
-        default_value = "mixtral:8x7b-instruct-v0.1-q5_K_M",
+        default_value = "mistral:latest",
         help = "Model name to use"
     )]
     model: String,
@@ -35,60 +39,29 @@ struct Args {
         help = "Should the response be streamed from ollama or sent all at once"
     )]
     stream: bool,
-}
-
-#[derive(Deserialize, Debug)]
-struct StreamChunk {
-    message: StreamMessage,
-}
-
-#[derive(Deserialize, Debug)]
-struct StreamMessage {
-    role: String,
-    content: String,
-}
-
-#[derive(Serialize, Deserialize)]
-struct Prompt<'a> {
-    role: &'a str,
-    content: &'a str,
-}
-
-#[derive(Serialize)]
-struct ChatRequest<'a> {
-    model: &'a str,
-    messages: Vec<Prompt<'a>>,
-    stream: bool,
-}
-
-#[derive(Serialize, Deserialize)]
-struct Message {
-    role: String,
-    content: String,
-}
-
-#[derive(Deserialize)]
-struct ChatResponse {
-    model: String,
-    created_at: String,
-    message: Message,
-    done_reason: String,
-    done: bool,
-    total_duration: u64,
-    eval_count: u64,
-    eval_duration: u64,
-    prompt_eval_count: u64,
-    prompt_eval_duration: u64,
+    #[arg(short, long, help = "Show statistics in non-stream mode?")]
+    nerd_stats: bool,
 }
 
 struct App {
     args: Args,
     prompt: String,
-    messages: Vec<String>,
-    waiting: bool
+    messages: Vec<Message>,
+    waiting: bool,
 }
 
 #[tokio::main]
 async fn main() -> anyhow::Result<()> {
+    // parse arguments
+    let args = match Args::try_parse() {
+        Ok(args) => args,
+        Err(e) => {
+            e.print().expect("Error writing clap error");
+            std::process::exit(0);
+        }
+    };
+
     // setup crossterm
     enable_raw_mode()?;
     let mut stdout = std::io::stdout();
@@ -97,19 +70,54 @@ async fn main() -> anyhow::Result<()> {
     let mut terminal = Terminal::new(backend)?;
 
     let mut app = App {
         args,
         prompt: String::new(),
-        messages: vec![String::from("Welcome to the OxiAI TUI Interface!")],
-        waiting: false
+        messages: vec![],
+        waiting: false,
     };
 
-    // parse arguments
-    let args = Args::parse();
-
     let client = Client::new();
-    let model_name = &args.model;
 
-    let system_prompt =
-        "[INST]You are a helpful, logical and extremely technical AI assistant.[INST]";
+    let header_prompt =
+        r#"SYSTEM: You are "OxiAI", a logical, personal assistant that answers *only* via JSON"#;
+
+    let tools_list = include_str!("tool/tools_list.json")
+        .parse::<serde_json::Value>()?
+        .to_string();
+
+    let rules_prompt = r#"Rules:
+1. Think silently, Never reveal your chain-of-thought.
+2. To use a tool: {"action":"<tool>","arguments":{...}}
+3. To reply directly: {"action":"chat","arguments":{"response":"..."}
+4. If a question is vague, comparative, descriptive, or about ideas rather than specifics: use the web_search tool.
+5. If a question clearly names a specific object, animal, person, place: use the wiki_search tool.
+6. Base claims strictly on provided data or tool results. If unsure, say so.
+7. Perform a JSON Self-check to ensure valid, minified, UTF-8 JSON.
+8. Finish with a coherent sentence; if you reach four consecutive newlines: **STOP.**"#;
+
+    let example_prompt = format!(
+        "Example 1:{user_q_1}\n{assistant_tool_request_1}\n{tool_result_1}\n{assistant_a_1}",
+        user_q_1 = r#"user: {"action":"chat", "arguments":{"response":"Provide a summary of the American Crow.", "source":"user"}}"#,
+        assistant_tool_request_1 = format!(
+            "assistant: {{ \"action\":\"wiki_search\",\"arguments\":{{\"query\":\"American Crow\"}} }}"
+        ),
+        tool_result_1 = format!(
+            "tool: {{ \"action\":\"wiki_search\",\"arguments\":{{\"result\":\"{search_data}\"}} }}",
+            search_data = include_str!("tool/american_crow_wikipedia.md").to_string()
+        ),
+        assistant_a_1 = format!(
+            "assistant: {{ \"action\":\"chat\",\"arguments\":{{\"response\":\"{example1_assistant_message}\"}} }}",
+            example1_assistant_message =
+                include_str!("tool/american_crow_example1_message.md").to_string()
+        )
+    );
+
+    //let user_info_prompt = r#""#;
+    let system_prompt = format!(
+        "{header_prompt}\n
+        {tools_list}\n\n
+        {rules_prompt}\n"
+    );
 
     loop {
         terminal.draw(|f| chat_ui(f, &app))?;
@@ -123,36 +131,49 @@ async fn main() -> anyhow::Result<()> {
                 }
                 KeyCode::Enter => {
                     //TODO: refactor to a parser function to take the contents of the app.prompt vec and do fancy stuff with it (like commands)
-                    let prompt = app.prompt.clone();
-                    app.messages.push(format!("[INST]{}[INST]", prompt));
+                    let message_args = args_builder! {
+                        "response" => app.prompt.clone(),
+                    };
                     app.prompt.clear();
 
-                    let user_prompt = app.messages.pop()
-                        .expect("No user prompt received (empty user_prompt)");
+                    app.messages.push(chat::Message::new(
+                        chat::MessageRoles::User,
+                        chat::Action::ChatMessage,
+                        message_args));
 
-                    let req = ChatRequest {
-                        model: model_name,
-                        stream: args.stream,
-                        messages: vec![
-                            Prompt {
-                                role: "system",
-                                content: system_prompt,
-                            },
-                            Prompt {
-                                role: "user",
-                                content: &user_prompt,
-                            },
-                        ],
+                    let mut prompts = vec![chat::Prompt {
+                        role: Cow::Borrowed("system"),
+                        content: Cow::Borrowed(&system_prompt),
+                    }];
+                    prompts.extend(
+                        app.messages
+                            .iter()
+                            .map(|msg| chat::Prompt::from(msg.clone())),
+                    );
+
+                    let req = chat::ChatRequest {
+                        model: &app.args.model.clone(),
+                        stream: app.args.stream,
+                        format: "json",
+                        stop: vec!["\n\n\n\n"],
+                        options: Some(chat::ChatOptions {
+                            temperature: Some(0.3),
+                            top_p: Some(0.92),
+                            top_k: Some(50),
+                            repeat_penalty: Some(1.1),
+                            seed: None,
+                        }),
+                        messages: prompts,
                    };
 
-                    match args.stream {
+                    app.waiting = true;
+                    match app.args.stream {
                         true => {
-                            stream_ollama_response(&mut app, client.clone(), req)
-                                .await?;
+                            todo!();
+                            stream_ollama_response(&mut app, client.clone(), req).await?;
                         }
                         false => {
-                            ollama_response(&mut app, client.clone(), req)
-                                .await?;
+                            batch_ollama_response(&mut app, client.clone(), req).await?;
                         }
                     }
                 }
@@ -175,12 +196,12 @@ async fn main() -> anyhow::Result<()> {
     Ok(())
 }
 
+//FIXME: streaming replies are harder to work with for now, save this for the future
 async fn stream_ollama_response(
     app: &mut App,
     client: Client,
-    req: ChatRequest<'_>,
+    req: chat::ChatRequest<'_>,
 ) -> anyhow::Result<()> {
-    app.waiting = true;
     let mut resp = client
         .post("http://localhost:11434/api/chat")
         .json(&req)
@@ -198,60 +219,95 @@ async fn stream_ollama_response(
             if line.is_empty() {
                 continue;
             }
-            let parsed: serde_json::Result<StreamChunk> = serde_json::from_slice(line);
+            let parsed: serde_json::Result<chat::StreamChunk> = serde_json::from_slice(line);
             if let Ok(parsed) = parsed {
-                assistant_line.push_str(&parsed.message.content);
+                assistant_line.push_str(&parsed.message.content.to_string());
             }
         }
     }
 
-    app.messages.push(assistant_line);
-    app.waiting = false;
+    //FIXME: fix this later
+    //app.messages.push(assistant_line);
 
     Ok(())
 }
 
-async fn ollama_response<'a>(
+async fn batch_ollama_response<'a>(
     app: &mut App,
     client: Client,
-    req: ChatRequest<'a>,
+    req: chat::ChatRequest<'a>,
 ) -> anyhow::Result<()> {
 
     app.waiting = true;
 
     let start = Instant::now();
-    let resp: ChatResponse = client
+    let resp = client
         .post("http://localhost:11434/api/chat")
         .json(&req)
         .send()
-        .await?
-        .json()
         .await?;
     let elapsed = start.elapsed();
 
+    let status = resp.status();
+    let headers = resp.headers().clone();
+    let body_bytes = resp.bytes().await?;
+
+    match serde_json::from_slice::<chat::ChatResponse>(&body_bytes) {
+        Ok(r) => app.messages.push(r.message),
+        Err(e) => {
+            println!("Failed to parse JSON: {}", e);
+            println!("Status: {}", status);
+            println!("Headers: {:#?}", headers);
+            // Try to print the body as text for debugging
+            if let Ok(body_text) = std::str::from_utf8(&body_bytes) {
+                println!("Body text: {}", body_text);
+            } else {
+                println!("Body was not valid UTF-8");
+            }
+        }
+    }
+
-    app.messages.push(format!("{} : {}", resp.message.role, resp.message.content));
-    app.messages.push(format!("System : Response generated via {} model with timestamp {}",
-        resp.model, resp.created_at
-    ));
+    /*
     if app.args.nerd_stats {
+        app.messages.push(format!(
+            "System : Response generated via {} model with timestamp {}",
+            resp.model, resp.created_at
+        ));
+
-        app.messages.push(format!("System : done_reason = {}, done = {}",
-            resp.done_reason, resp.done
-        ));
+        app.messages.push(format!(
+            "System : done_reason = {}, done = {}",
+            resp.done_reason, resp.done
+        ));
 
-        app.messages.push(format!("System : Response timing statistics..."));
+        app.messages
+            .push(format!("System : Response timing statistics..."));
 
-        app.messages.push(format!("System : Total elapsed wall time: {:.2?}", elapsed));
-        app.messages.push(format!("System : Prompt tokens: {}", resp.prompt_eval_count));
-        app.messages.push(format!("System : Prompt eval duration: {} ns", resp.prompt_eval_duration));
-        app.messages.push(format!("System : Output tokens: {}", resp.eval_count));
-        app.messages.push(format!("System : Output eval duration: {} ns", resp.eval_duration));
-        app.messages.push(format!("System : Model 'warm up' time {}", (resp.total_duration - (resp.prompt_eval_duration + resp.eval_duration))));
-
-        let token_speed = resp.eval_count as f64 / (resp.eval_duration as f64 / 1_000_000_000.0);
-        app.messages.push(format!("System > Output generation speed: {:.2} tokens/sec", token_speed));
+        app.messages
+            .push(format!("System : Total elapsed wall time: {:.2?}", elapsed));
+        app.messages.push(format!(
+            "System : Prompt tokens: {}",
+            resp.prompt_eval_count
+        ));
+        app.messages.push(format!(
+            "System : Prompt eval duration: {} ns",
+            resp.prompt_eval_duration
+        ));
+        app.messages
+            .push(format!("System : Output tokens: {}", resp.eval_count));
+        app.messages.push(format!(
+            "System : Output eval duration: {} ns",
+            resp.eval_duration
+        ));
+        app.messages.push(format!(
+            "System : Model 'warm up' time {}",
+            (resp.total_duration - (resp.prompt_eval_duration + resp.eval_duration))
+        ));
+
+        let token_speed = resp.eval_count as f64 / (resp.eval_duration as f64 / 1_000_000_000.0);
+        app.messages.push(format!(
+            "System > Output generation speed: {:.2} tokens/sec",
+            token_speed
+        ));
     }
+    */
     app.waiting = false;
 
     Ok(())
 }
 
@@ -265,7 +321,7 @@ fn chat_ui(f: &mut ratatui::Frame, app: &App) {
     let messages: Vec<Line> = app
         .messages
         .iter()
-        .map(|m| Line::from(Span::raw(m.clone())))
+        .map(|m| Line::from(Span::raw(m.to_string())))
         .collect();
 
     let messages_block = Paragraph::new(ratatui::text::Text::from(messages))
@@ -279,7 +335,7 @@ fn chat_ui(f: &mut ratatui::Frame, app: &App) {
     ));
 
     f.render_widget(messages_block, chunks[0]);
-
+
     let input_text = if app.waiting {
         format!("> {} (waiting...)", &app.prompt)
     } else {
@@ -292,13 +348,10 @@ fn chat_ui(f: &mut ratatui::Frame, app: &App) {
     f.render_widget(input, chunks[1]);
 
     use ratatui::layout::Position;
-    f.set_cursor_position(
-        Position::new(
-            // the +3 comes from the 3 'characters' of space between the terminal edge and the text location
-            // this places the text cursor after the last entered character
-            chunks[1].x + app.prompt.len() as u16 + 3,
-            chunks[1].y + 1,
-        )
-    );
+    f.set_cursor_position(Position::new(
+        // the +3 comes from the 3 'characters' of space between the terminal edge and the text location
+        // this places the text cursor after the last entered character
+        chunks[1].x + app.prompt.len() as u16 + 3,
+        chunks[1].y + 1,
+    ));
 }
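A note on the new request shape: `chat::ChatRequest` now carries `format: "json"`, a top-level `stop` list, and sampling knobs in `chat::ChatOptions`. Since `chat.rs` is not part of this commit, here is only a minimal sketch of how that options struct could serialize, assuming plain serde derives; `skip_serializing_if` keeps unset knobs such as `seed: None` out of the request body entirely:

    use serde::Serialize;

    // Hypothetical chat::ChatOptions; field names mirror the call site in
    // main.rs above, the derive and attributes are assumptions.
    #[derive(Serialize)]
    pub struct ChatOptions {
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_k: Option<u32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub repeat_penalty: Option<f32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub seed: Option<u64>,
    }

With the values used in the diff, this would serialize to {"temperature":0.3,"top_p":0.92,"top_k":50,"repeat_penalty":1.1}.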