aboutsummaryrefslogblamecommitdiffstats
path: root/src/main.rs
blob: 5e0fd4ab67299f5c81aad5401856aeb01deea67d (plain) (tree)
1
2
3
4
5
6
7
8
9
                   





                                              

                                                         

 
 
                        

                                                                                          
 
                                 
                                                                                                              
     





                                     

                                                                                                                    

         



























                                                      

 
                    
                                 
                        


                                         
                                                                     



                                                               











                                                                       
 
              
                         


            
 

                 
                        










                                                                                     

       

   
                       














                                           
               

                                                                    
                                  





                                                                                    





                                                            

 
                                                          






                                                              




                                  
 















                                                                  
                                              




















                                                                         
                                                           



                                                                                              

                     


                                                                                     
 
                                                                



                                                                                 
 



                                                                        
 

                         




          
use std::io::Write;
use dumb_cgi::{Request, EmptyResponse, Query};
use postgres::{Client, NoTls};
use std::process::exit;
use regex::Regex;
use dotenv;
use url::Url;
use rand::{thread_rng, Rng, distributions::Alphanumeric};




// Get URL from database
fn get_url(db:&mut Client, short:&str, update:bool) -> Result<String, postgres::Error> {
    let row = db.query_one("SELECT url FROM shorts WHERE short = $1 LIMIT 1", &[&short])?;

    if row.len() == 1 && update {
        db.execute("UPDATE shorts SET count = count + 1, last_visited = now() WHERE short = $1;", &[&short])?;
    }

    let url: String = row.get("url");

    Ok(url)
}

fn insert_short(db:&mut Client, url: &str, short: &str, user: &str) -> Result<u64, postgres::Error> {
    let n = db.execute("INSERT INTO shorts (url, short, created_by) VALUES ($1, $2, $3);", &[&url, &short, &user])?;
    Ok(n)
}

/// Return the HTTP reason phrase for status `code`.
///
/// Only the status codes this program might send are listed; any other code
/// falls back to `"Not Found"`.
fn status(code: u16) -> &'static str {
    match code {
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        204 => "No Content",
        301 => "Moved Permanently",
        302 => "Found",
        // Sent by `main` to bounce query-less create requests back to the form.
        303 => "See Other",
        304 => "Not Modified",
        307 => "Temporary Redirect",
        308 => "Permanent Redirect",
        400 => "Bad Request",
        401 => "Unauthorized",
        402 => "Payment Required",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        410 => "Gone",
        414 => "URI Too Long",
        500 => "Internal Server Error",
        501 => "Not Implemented",
        503 => "Service Unavailable",
        508 => "Loop Detected",
        _ => "Not Found",
    }
}

/// Whether `target` may be sent verbatim as a `Location` header value.
///
/// Accepts absolute URLs and absolute paths on this host (`/...`, but not the
/// protocol-relative `//...`). Anything containing a control character is
/// rejected: `Url::parse` strips CR/LF while validating, but the raw string is
/// what would be written into the header.
fn is_redirect_target(target: &str) -> bool {
    if target.contains(|c: char| c.is_control()) {
        return false;
    }
    Url::parse(target).is_ok() || (target.starts_with('/') && !target.starts_with("//"))
}

/// Send a plain-text CGI response with status `code` and terminate the process.
///
/// * For 3xx codes whose `body` passes `is_redirect_target`, `body` is also
///   sent as the `Location` header so clients follow the redirect.
/// * For codes >= 300 the output starts with a `"<code> <reason>"` line,
///   followed by a blank line when `body` is non-empty.
/// * A non-empty `body` is written followed by a newline.
///
/// This function does not return to its caller: it calls
/// `std::process::exit(0)` after the response is written. It panics if
/// writing or sending the response fails.
fn respond(code: u16, body: &str) {
    let mut r = EmptyResponse::new(code).with_content_type("text/plain");

    // Redirect statuses span 300 through 399 inclusive.
    if (300..400).contains(&code) && is_redirect_target(body) {
        r.add_header("Location", body);
    }

    // Non-success responses carry the status line in the body as well.
    if code >= 300 {
        writeln!(&mut r, "{} {}", code, status(code)).unwrap();
        if !body.is_empty() {
            writeln!(&mut r).unwrap();
        }
    }

    if !body.is_empty() {
        writeln!(&mut r, "{}", body).unwrap();
    }

    r.respond().unwrap();
    exit(0);
}


// Print form
fn print_form() {
    let body = r#"<html>
<head>
    <title>PURL</title>
</head>
<body>
    <form method="post" action="/purl.cgi" enctype="multipart/form-data" name="main">
    Username: $user<br>
    URL to shorten: <input type="text" name="url" size="50"><br>
    Custom short: <input type="text" name="short"><br>
    <input type="hidden" name="form" value="html">
    <input type="submit" value="Submit">
    </form>
</body>
</html>
"#;

    respond(200, body);
}


/// Produce a random candidate short of 2 to 5 ASCII alphanumeric characters.
///
/// Uniqueness is not checked here; callers must verify the short is unused.
fn gen_short() -> String {
    let mut rng = thread_rng();
    // Choose the length first, then draw exactly that many symbols.
    let length = rng.gen_range(2..6);
    rng.sample_iter(&Alphanumeric)
        .take(length)
        .map(char::from)
        .collect()
}


// Do the dirty
// return std::io::Result<()> so we can use ? -
// https://doc.rust-lang.org/std/result/#the-question-mark-operator-
fn main() -> std::io::Result<()> {

    // short = ID for shortened url
    // url = url to be shortened / that has been shortened
    // shorturl = full shortened url
    // shortprefix = part of url that comes before short id
    // docuri = DOCUMENT_URI from http request. This is everything after the domain.
    //          Example: "/hey" in "https://example.com/hey"


    // Connect to db
    let dburl:String = dotenv::var("DATABASE_URL").unwrap();
    let mut db = Client::connect(&dburl, NoTls).unwrap();


    let shortprefix = dotenv::var("SHORTPREFIX").unwrap();

    // Gather all request data from the environment and stdin.
    let req = Request::new().unwrap();

    // Find out what action we're performing
    let action:&str = req.var("ACTION").unwrap_or("none");

    if action == "form" {
        print_form();
    }

    else if action == "redirect" {

        // get DOCUMENT_URI
        let docuri:&str = req.var("DOCUMENT_URI").unwrap_or("");
        if docuri == "" {
            respond(400, "No short provided");
        }

        // Trim / and + from start
        let docuri = docuri.trim_start_matches(&['/', '+']);

        // Make SURE docuri is a valid short
        let shortregex = Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap();
        if !shortregex.is_match(docuri) {
            respond(400, "Not a valid short");
        }

        // Fetch URL from postgres, and redirect or return 404
        match get_url(&mut db, docuri, true) {
            Ok(url) => respond(301, &url),
            Err(_e) => respond(404, ""),
        };
    }

    else if action == "create" {
        match req.query() {
            Query::None => respond(303, "/purl.cgi"),
            Query::Some(map) => {
                if map.iter().count() != 1 {
                    respond(400, "Incorrect number of query items");
                }
                let url:&str = &map["url"];
                if !Url::parse(url).is_ok() {
                    respond(400, "Invalid url");
                }

                let user:&str = req.var("REMOTE_USER").unwrap_or("none");
                let mut short:String = gen_short();

                for i in 1..5 {
                    match get_url(&mut db, &short, false) {
                        Ok(_url) => short = gen_short(), // If Ok, then short is already in
                                                         // use. Try a new one.
                        Err(_e) => break, // we assume that we found a unique short if get_url
                                          // returns an error
                    }

                    // Throw error if we couldn't create a unique short in fire tries
                    if i == 5 { respond(500, "Could not find unique short"); }
                }

                match insert_short(&mut db, url, &short, user) {
                    Ok(_v) => respond(200, &format!("{}{}", shortprefix, short)),
                    Err(_e) => respond(400, "looool"),
                };
                exit(0);

            },
            Query::Err(_e) => respond(400,"Error reading query string"),
        }
    }

    else {
        respond(400, "");
    }


    Ok(())
}