/* purl-rs
*
* Author : Dennis Eriksen <d@ennis.no>
* Project : purl-rs
* File : src/main.rs
* Created : 2023-06-27
*
*/
// standard imports
use std::collections::HashMap;
use std::io::Write;
use std::process::exit;
// imports from external crates
use dotenv;
use postgres::{Client, NoTls};
use dumb_cgi::{Request, EmptyResponse, Query};
use pledge::pledge_promises;
use rand::{thread_rng, Rng, distributions::Alphanumeric};
use regex::Regex;
use unveil::unveil;
use url::Url;
// Do the dirty
/// CGI entry point.
///
/// Drops privileges, reads configuration from dotenv (falling back to built-in
/// defaults), connects to the database, and then dispatches on the request's
/// `DOCUMENT_URI`:
///
/// * equal to `FORM_URI`                -> print the HTML form
/// * equal to `CREATE_URI`              -> store a new short URL and print it
/// * starts with `SHORT_URI_PREFIX`     -> look up the short and redirect
/// * anything else                      -> 500, most likely a configuration error
///
/// Every branch finishes through `respond` (or `print_form`), both of which
/// terminate the process, so nothing after such a call is executed.
///
/// # Panics
///
/// Panics if the database connection, the CGI request parsing, or a lookup /
/// counter-update query fails.
fn main() {
    // Let's drop some privileges before we do anything else
    drop_privs();

    // Get variables from dotenv, or use defaults. The closures make sure the
    // default strings are only allocated when they are actually needed.
    let dburl: &str = &dotenv::var("DATABASE_URL")
        .unwrap_or_else(|_| "postgresql://localhost/purl-rs".to_string());
    let create_uri: &str = &dotenv::var("CREATE_URI").unwrap_or_else(|_| "/create".to_string());
    let form_uri: &str = &dotenv::var("FORM_URI").unwrap_or_else(|_| "/form".to_string());
    let short_uri_prefix: &str =
        &dotenv::var("SHORT_URI_PREFIX").unwrap_or_else(|_| "/+".to_string());

    // Connect to db
    let mut db = Client::connect(dburl, NoTls).unwrap();
    // TODO: Close connection when done.

    // Gather all request data from the environment and stdin.
    let req = Request::new().unwrap();

    // Get some variables from the current request
    let scheme: &str = req.var("REQUEST_SCHEME").unwrap_or("https");
    let server_name: &str = req.var("SERVER_NAME").unwrap_or("example.com");
    let docuri: &str = req.var("DOCUMENT_URI").unwrap_or("no docuri");

    //
    // Form
    //
    if docuri == form_uri {
        print_form();
    }
    //
    // Create
    //
    else if docuri == create_uri {
        // Temporary empty map, used if the query is missing or could not be parsed
        let empty: &HashMap<String, String> = &HashMap::new();
        let query = match req.query() {
            Query::Some(map) => map,
            _ => empty,
        };

        // Get url from query. It is obligatory.
        let url: &str = match query.get("url") {
            Some(u) => u.as_str(),
            None => {
                respond(400, "Error, no url in query");
                return; // not reached: respond() exits the process
            }
        };

        // Get short from query. Use supplied short if it exists, else generate random.
        // Use mut String, since we will have to change it if it already exists
        let mut short: String = match query.get("short") {
            Some(s) => s.to_string(),
            None => gen_short(),
        };

        // Get user from proxy, if user is logged in. Else use "nobody".
        let user: &str = req.var("REMOTE_USER").unwrap_or("nobody");

        // Check that url and short are valid
        if Url::parse(url).is_err() {
            respond(400, "Invalid url");
        } else if !check_short(&short) {
            respond(400, "Invalid short");
        }

        // Check if short already exists.
        // If it does exist, set a new random short. Try this five times. If we cannot find a
        // unique random short in five tries, abort.
        // The range is inclusive on purpose: with `1..5` the fifth iteration never ran, so the
        // abort below was unreachable and an unchecked short could be inserted.
        let lookup_sql = "SELECT url FROM shorts WHERE short = $1 LIMIT 1";
        for i in 1..=5 {
            match db.query_opt(lookup_sql, &[&short]).unwrap() {
                // A row was returned, so the short was not unique. Try another one.
                Some(_row) => short = gen_short(),
                // Nothing was returned, so the short IS unique. Break out of loop.
                None => break,
            }
            // Throw error if we couldn't create a unique short in five tries
            if i == 5 {
                respond(500, "Could not find unique short");
            }
        }

        let insert_sql = "INSERT INTO shorts (url, short, created_by) VALUES ($1, $2, $3);";
        match db.execute(insert_sql, &[&url, &short, &user]) {
            Ok(_v) => respond(
                200,
                &format!("{}://{}{}{}", scheme, server_name, short_uri_prefix, short),
            ),
            Err(_e) => respond(500, "Could not save shortened url to database"),
        };
    }
    //
    // Redirect
    //
    else if docuri.starts_with(short_uri_prefix) {
        // Remove the configured prefix. Stripping exactly `short_uri_prefix` (instead of a
        // hard-coded '/' and '+') keeps this branch working when SHORT_URI_PREFIX is changed.
        let short = docuri.strip_prefix(short_uri_prefix).unwrap_or(docuri);

        // Make sure short is valid
        if !check_short(short) {
            respond(400, "Invalid short");
        }

        // Fetch URL from postgres and redirect, or return 404
        let sql = "SELECT url FROM shorts WHERE short = $1 LIMIT 1";
        match db.query_opt(sql, &[&short]).unwrap() {
            Some(row) => {
                // Record the visit before redirecting
                let sql =
                    "UPDATE shorts SET count = count + 1, last_visited = now() WHERE short = $1;";
                db.execute(sql, &[&short]).unwrap();
                respond(301, row.get("url"));
            }
            None => respond(404, ""),
        }
    }
    //
    // Else? Oops.
    //
    else {
        respond(500, "Reached end of script. There is probably a config-error somewhere.");
    }
}
//
// send cgi-response
//
/// Send a `text/plain` CGI response and terminate the process.
///
/// * `code` - HTTP status code of the response.
/// * `body` - text to send; may be empty.
///
/// For redirection codes (`300 <= code < 400`) a `Location` header is added
/// when `body` parses as a URL. For every code `>= 300` the response text
/// starts with a "`<code> <reason phrase>`" line. A non-empty `body` is then
/// written followed by a newline.
///
/// This function never returns to its caller: it ends with `exit(0)`.
///
/// # Panics
///
/// Panics if writing to, or sending, the response fails.
fn respond(code: u16, body: &str) {
    // initiate response
    let mut r = EmptyResponse::new(code).with_content_type("text/plain");

    // if 300 <= status < 400, and body is a url, perform redirection.
    // The upper bound is exclusive, so the range must end at 400 to cover 399.
    if (300..400).contains(&code) && Url::parse(body).is_ok() {
        r.add_header("Location", body);
    }

    // if status >= 300, we want to include http-status in the response
    if code >= 300 {
        let mut toptext: String = format!("{} {}\n", code, status(code));
        if !body.is_empty() {
            toptext.push('\n'); // Add extra linebreak before error-text
        }
        write!(&mut r, "{}", toptext).unwrap();
    }

    // write body, unless there is no body
    if !body.is_empty() {
        writeln!(&mut r, "{}", body).unwrap();
    }

    // respond and exit
    r.respond().unwrap();
    exit(0);
}
//
// HTTP status codes
//
/// Return the HTTP reason phrase for `code`.
///
/// Only the status codes this program *might* use are listed. Any code that
/// is not listed falls back to `"Not Found"`.
fn status<'a>(code: u16) -> &'a str {
    match code {
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        204 => "No Content",
        301 => "Moved Permanently",
        302 => "Found",
        304 => "Not Modified",
        307 => "Temporary Redirect",
        308 => "Permanent Redirect",
        400 => "Bad Request",
        401 => "Unauthorized",
        402 => "Payment Required",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        410 => "Gone",
        414 => "URI Too Long",
        500 => "Internal Server Error",
        501 => "Not Implemented",
        503 => "Service Unavailable",
        508 => "Loop Detected",
        // Unlisted codes keep the original fallback phrase
        _ => "Not Found",
    }
}
//
// Print form
//
fn print_form() {
let body = r#"<html>
<head>
<title>PURL</title>
</head>
<body>
<form method="post" action="/purl.cgi" enctype="multipart/form-data" name="main">
Username: $user<br>
URL to shorten: <input type="text" name="url" size="50"><br>
Custom short: <input type="text" name="short"><br>
<input type="hidden" name="form" value="html">
<input type="submit" value="Submit">
</form>
</body>
</html>
"#;
respond(200, body);
}
//
// Generate random short
//
/// Build a random alphanumeric short.
///
/// The length is drawn from `2..6` (upper bound exclusive), and each character
/// is sampled from the `Alphanumeric` distribution.
fn gen_short() -> String {
    let mut rng = thread_rng();
    // Pick the length first, then sample that many characters
    let length: usize = rng.gen_range(2..6);
    rng.sample_iter(&Alphanumeric)
        .take(length)
        .map(char::from)
        .collect()
}
//
// Check if short is valid
//
/// Tell whether `short` is an acceptable short name.
///
/// A valid short is non-empty, consists only of ASCII letters, digits, `_`
/// and `-`, and is at most 128 characters long.
///
/// # Panics
///
/// Panics if the hard-coded pattern fails to compile.
fn check_short(short: &str) -> bool {
    // Pattern describing the characters a short may contain
    let allowed = Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap();

    let has_valid_chars = allowed.is_match(short);
    let fits_length = short.chars().count() <= 128;

    has_valid_chars && fits_length
}
//
// Drop privileges
//
/// Restrict what the process may do before any request data is handled.
///
/// Makes three calls, in this order: `unveil(".env", "r")`, `unveil("", "")`
/// and `pledge_promises![Stdio Rpath Inet Dns]`. The order matters, so keep
/// it when editing. Each result is passed through the crate's
/// `ignore_platform` handler and then unwrapped, so any error that handler
/// does not absorb causes a panic.
fn drop_privs() {
// Restrict what files we can access. See unveil(2)
// Only ".env" is exposed, with read ("r") permission.
unveil(".env", "r")
.or_else(unveil::Error::ignore_platform)
.unwrap();
// NOTE(review): the empty path/permission pair presumably locks the unveil
// list against further changes — confirm against the unveil crate docs.
unveil("", "")
.or_else(unveil::Error::ignore_platform)
.unwrap();
// Restrict what system calls we can access. See pledge(2)
// NOTE(review): Inet and Dns are presumably needed for the TCP database
// connection; a unix-socket DATABASE_URL may need more promises — verify.
pledge_promises![Stdio Rpath Inet Dns]
.or_else(pledge::Error::ignore_platform)
.unwrap();
}
// end of file