رفتن به نوشته‌ها

چگونه در تنسورفلو شی گرا کد بزنیم؟

یکی از چیزهایی که معمولا در کارهای علوم داده (data science) نادیده گرفته می‌شود مباحث مربوط به مهندسی نرم‌افزار است. بیشتر افراد به فکر این می‌کنند که نتایج کاری که میکنن درست باشه و به همین خاطر به تمیزی و ساختارمند بودن کد کمتر توجه می‌شود. این موضوع شاید دلایل مختلفی داشته باشه و یکی‌اش شاید به این دلیل باشه که زبان‌هایی که اغلب باهاشون کار میشه اسکریپتی هستند (مثل R و پایتون) و خیلی راحت بدون کلاس و پکیج و … امکان کد زدن را به شما می‌د‌هند یا اینکه هدف اصلی گرفتن نتایج بهتر است تا تمیزی کد!!!
احتمالا با اینکه چه مزایایی شی‌گرایی کد زدن داره آشنا هستید ولی به نظرم در بحث علوم داده مهمترین مزیتی که برای شما به ارمغان میاره سهولت فرآیند خطایابی و مدیریت مدل‌های مختلف است.

در این پست می‌خوام  مدل ورد۲وک رو به صورت مرحله به مرحله و با رعایت (برخی) قواعد شی‌گرایی پیاده‌سازی کنم. اون چیزی که میگم شاید بهترین نحوه پیاده‌سازی نباشه ولی در مواردی که من باهاش سر و کله زدم اغلب جواب داده است.

به طور کلی میشه گفت یک مدل در تنسورفلو از قسمت‌های زیر تشکیل شده است:

  • خواندن داده (با کمک tf.data یا placeholder)
  • تعریف وزن‌ها و پارامترهای شبکه
  • تعریف مدل
  • تعریف تابع خطا (loss function)
  • تعریف بهینه‌ساز (optimizer)

اگر همین ۵تا قسمت رو به صورت توابع مختلف در یک کلاس پایتون برای مدل Skipgram بنویسیم چیزی که خواهیم داشت شبیه زیر خواهد بود:

قسمت اول: ورودی‌های مدل

اینکه مدل ورودی‌تون چه چیزهایی رو به عنوان ورودی دریافت کنه (ورودی‌های کانستراکتور) خیلی بستگی به مدل و شبکه‌ موردنظرتون داره. مثلا در ورد۲وک اون چیزهایی که اغلب به عنوان هایپرپارامتر در نظر گرفته میشوند شامل: نرخ یادگیری (learning rate)، اندازه پنجره (window size)، اندازه بردارها (embedding length)، تعداد نمونه‌های منفی (negative samples)، تعداد کلمات (vocabulary size) و متن (پیکره) ورودی است. پس در مثال ما، این قسمت چیزی شبیه به زیر خواهد بود:

قسمت دوم: خواندن داده‌

همانطور که در اینجا توضیح دادم بهترین راه‌حل برای خواندن داده‌ها در تنسورفلو استفاده از tf.data است. همانطور که میدونید در مدل skipgram ما میخواستیم با استفاده از کلمه وسط (center) کلمات اطراف (target) رو حدس بزنیم. پس تابع get_next باید این دو مقدار را برای ما برگرداند. همچنین به ازای توابع اصلی مون یک tf.name_scope هم تعریف میکنیم تا خواندن گراف محاسباتی در تنسورفلو راحت‌تر باشد. پس این قسمت مشابه زیر خواهد شد:

قسمت سوم: تعریف وزن‌ها و مدل شبکه

از اونجایی که ورد۲وک یک مدل خیلی ساده محسوب میشه و فقط یک لایه داره وزن‌ها (weights) و پارامترهای زیادی نداره بخاطر همین میتونیم در این حالت خاص این قسمت‌های ۲ و ۳ (تعریف وزن‌ها و تعریف مدل) رو یک تابع کنیم و چیزی شبیه قسمت پایین داشته باشیم:

در تابع بالا، embed_matrix ماتریس کل کلمات است و اندازه آن به طول کل کلمات پیکره در طول بردارهای هر کلمه است (همان ماتریس C در مقاله آقای bengio و یا ماتریس W در مقاله ورد۲وک). تابع embedding_lookup یک تابع کمکی است که به صورت بهینه با گرفتن id کلمات بردارهای نهفته آنها را از embed_matrix واکشی می‌کند. center_words قبلا توسط _import_data ایجاد شده بودند پس کافی است آنها را به تابع embedding_lookup پاس بدهیم.

اتفاقی که میافتد چیزی شبیه شکل زیر:

id کلماتی که لازم داریم (center_words) از embed_matrix واکشی میشود.

قسمت چهارم: تعریف خطا

اگر خاطرتون باشه در ورد۲وک از تکنیک نمونه‌برداری منفی (negative sampling) برای اندازه‌گیری خطا استفاده می‌شد. خوشبختانه این تابع در تنسورفلو از قبل پیاده‌سازی شده است و میتوانیم از آن استفاده کنیم. ولی این تابع خطا، اندکی با توابع خطای دیگر (مثل cross_entropy یا mse) متفاوت است. در این تابع خطا لازم است که وزن (weight) و بایاس (bias) به تابع خطا پاس داده شود (چرا؟). علاوه‌ بر این، تعداد کلاس‌ها (در مورد ما میشود اندازه کلمات (vocab size)) و تعداد نمونه‌های منفی (negative samples) باید به آن فرستاده شود. پس در نتیجه تابع خطای ما مشابه زیر خواهد شد:

قسمت پنجم: تعریف بهینه‌ساز

شاید راحت‌ترین قسمت، تعریف بهینه‌ساز یا Optimizer باشد در انتخاب بهینه‌ساز دست‌مون باز هست و تقریبا میتونیم از هر نوع بهینه‌سازی استفاده کنیم که در اینجا برای سادگی از گرادیان نزولی (GradientDecent) استفاده میکنیم:

خب قسمت‌های اصلی مدل‌مون رو ساختیم. اگر دقت کرده باشید تمام توابع بالا خصوصی (private) هستند و از بیرون قابل صدا زدن نیستند. پس باید یک تابع کمکی درست کنید که یکی یکی گراف رو برای ما بسازد. این تابع رو اصلاحا build_graph نام گذاری میکنیم و به صورت زیر مینویسم:

بعد از ساختن گراف،‌ حالا نوبت به آموزش (train) مدل میرسد. برای آموزش تنها کافی است تعداد iterationها را داشته باشیم و از بیرون آن را دریافت کنیم و سپس آن را صدا بزنیم. به این ترتیب تابع train خواهد شد:

خب تا الان مدل رو ساختید و آموزش دادید و اگر اجرا کنید میبینید که loss کاهش میابد و به خوبی کار میکند ولی چطوری بیاییم از این مدل استفاده کنیم! یا اصن چطوری ببینیم؟ برای اینکه بخواهید بردارهاتون رو ببینید باید از tensorboard کمک بگیریم. ولی متاسفانه هیچی رو لاگ نکردیم که بتونیم مشاهده کنیم. خوب پس بیایید اندکی تغییر در کدمون ایجاد کنیم. در تابع build_graph یک تابع خصوصی دیگر به نام _create_summary اضافه میکنیم که محتوای آن به صورت زیر است:

این تابع امکان این را به ما اجازه میدهد تغییرات loss رو مشاهده کنیم و در صورت وجود مشکل در مدل راحت آن را برطرف کنیم.

همانطور که میدانید در تابع بالا summary_op یک operation است و باید محاسبه شود در نتیجه باید تغییراتی در کد train ایجاد کنیم. علاوه بر این، بیایید وزن‌های مدل هم ذخیره کنیم تا بعدا اگر مدل رو لازم داشتیم دیگه لازم نباشه آن را از صفر آموزش بدهیم. پس تابع train با تغییرات بالا میشود چیزی شبیه زیر:

همچنین یک تابع عمومی (public) دیگر به اسم visualize درست میکنیم که به صورت زیر تعریف می‌شود که برای نمایش بردارها در tensorboard از آن استفاده میکنیم.

اولین نکته‌ای که در تابع بالا هست و من قبلا راجع اش صحبت نکردم word2vec_utils است کد کامل آن را میتونید از اینجا ملاحضه کنید. تابعی most_common_words، کلمات پربسامد در دادگان ما را واکشی میکند و برای نمایش در تنسوربورد استفاده می‌کند. سپس مدلی که ذخیره کردیم بازیابی (restore) میکنیم و سپس بردارهای مورد نظر خود را (بردار وزن‌ها یا همان ورد۲وک) را به دست میاوریم و یک سری اطلاعات ذخیره میکنیم تا از آن در تسنوربورد استفاده کنیم.

خب کارمون تقریبا تموم شد. یک تابع main ایجاد میکنیم و یکی یکی موارد بالا رو صدا میزنیم همچنین برای اینکه پارامترهای شبکه رو بتونیم با کامند لاین تعریف کنیم یک سری FLAGS هم اضافه میکنیم در نتیجه کد نهایی مشابه زیر خواهد شد:

تنها قسمتی که توضیح ندادم بخش tf.data و تولید جفت x و y برای آموزش مدل است. من این کارو به کمک یک تابع کمکی از بیرون کلاس SkipGram انجام دادم تا کارهایی که منطقشون جداست از یکدیگر هم جدا باشند. در skipgram هدف ما حدس کلمات اطراف (target) به کمک کلمه وسط (center) است. پس باید یک سری batch درست کنیم تا این کارو برای ما انجام بدهد. Xهای ما (ورودی) میشود همان کلمات center و Yهای ما میشود کلمات اطراف آن (target). از طرفی چون در هر لحظه یک کلمه را پیش‌بینی میکنیم پس اندازه Yها میشود [batch_size,1].

پیشنهاد میکنم برای اینکه این پیاده‌سازی براتون تمرین بشه مدل CBow رو خودتون پیاده‌سازی کنید. کل کد این آموزش رو میتونید از اینجا ملاحضه بفرمایید. برای اینکه بیشتر شی‌‌گرایی رو دخیل کنید میتونید یک base_model درست کنید و کلاس‌های CBow و Skipgram ازش ارث‌بری کنند و قسمت‌های مشترک رو داخل آن قرار بدید به این ترتیب حجم کدی که میزنید بسیار کمتر خواهد شد. همچنین همه رو در یک پکیج قرار بدید تا این قسمت از کدتون کاملا مجزا از سایر قسمت‌ها باشد.

منتشر شده در آموزشپردازش زبان طبیعیتنسورفلومتفرقه

نظر

  1. ساره ساره

    سلام خسته نباشید
    ببخشید من سوالات زیادی دارم چظور میتونم لحظاتی وقتتون رو بگیرم تا در مورد پیاده سازی یک فایل متنی تکست به فرمت word2vec کمکم کنید

  2. محسن محسن

    سلام. خط زیر در فایل word2vec_utils.py خطا می دهد.
    from helper.utils import download_one_file, safe_mkdir
    با تشکر
    #win10

دیدگاهتان را بنویسید

نشانی ایمیل شما منتشر نخواهد شد. بخش‌های موردنیاز علامت‌گذاری شده‌اند *